Replace NumPy Operations with JAX NumPy Equivalents for JIT Compilation Compatibility #23356
Replace numpy operations with jax.numpy for JIT compatibility
96e9dd42
rm numpy import
848f8df1
rm numpy import and fix np->jnp
175ab5dd
fixed slices bug
d110c34d
fixed decoder_start_tokens -> decoder_start_token_id
f25d5881
fixed jnp in modleing mt5
c74abb81
doc fix
33b5904e
rm numpy import
b6b5e67b
make
c90229e0
Assignees
No one assigned
Login to write a write a comment.
Login via GitHub