Use shard_as in scan to ensure that inputs and their gradients have the same sharding #8879
tengyifei
marked this pull request as ready for review 1 year ago
tengyifei
changed the base branch from
master
to
yifeit/call-jax-cache
1 year ago
tengyifei
force pushed
from
3823d66f
to
705814f4
1 year ago
tengyifei
force pushed
from
705814f4
to
b2140fa9
1 year ago
tengyifei
changed the base branch from
yifeit/call-jax-cache
to
master
1 year ago
tengyifei
force pushed
from
b2140fa9
to
ba4b6ab1
1 year ago
qihqi
approved these changes
on 2025-03-28
Use shard_as in scan to ensure that inputs and their gradients have t…
0a85f09e
Add back removed API
903119f9
Simplify
f625d707
Simplify
4637fb09
Add test
5c15061e
yapf
0bbfd1b8
Fix tests
4337f9e7
yapf
9dcfe1c3
tengyifei
force pushed
from
5f84ac1d
to
9dcfe1c3
1 year ago
tengyifei
merged
6d88c089
into master 1 year ago
Assignees
No one assigned
Login to write a write a comment.
Login via GitHub