Add out_sharding argument to call methods for layers with jax calls that support it #5102
samanklesaria
force pushed
from
67bafc18
to
11950870
110 days ago
samanklesaria
force pushed
from
11950870
to
96f0a646
110 days ago
Add out_sharding argument to call methods for standard layers
1f61d6b7
samanklesaria
force pushed
from
96f0a646
to
1f61d6b7
110 days ago
samanklesaria
changed the title Add out_sharding argument to call methods for standard layers Add out_sharding argument to call methods for layers with jax calls that support it 110 days ago
cgarciae
approved these changes
on 2025-11-25
Fix nits
66544200
Assignees
No one assigned
Login to write a write a comment.
Login via GitHub