transformers
ba6815e8 - Replace NumPy Operations with JAX NumPy Equivalents for JIT Compilation Compatibility (#23356)

Commit
2 years ago
Replace NumPy Operations with JAX NumPy Equivalents for JIT Compilation Compatibility (#23356) * Replace numpy operations with jax.numpy for JIT compatibility Replaced numpy operations with their jax.numpy equivalents in the transformer library. This change was necessary to prevent errors during JIT compilation. Specifically, the modifications involve changing numpy's in-place assignments to jax.numpy's immutable update methods. * rm numpy import * rm numpy import and fix np->jnp * fixed slices bug * fixed decoder_start_tokens -> decoder_start_token_id * fixed jnp in modleing mt5 * doc fix * rm numpy import * make
Author
Parents
Loading