flax
feat(nnx): add out_sharding to recurrent cells
#5255
Open

feat(nnx): add out_sharding to recurrent cells #5255

aarushisingh04
gemini-code-assist
gemini-code-assist
gemini-code-assist commented on 2026-02-15
aarushisingh04 feat(nnx): Add out_sharding to recurrent cell __call__ methods
fcca8e94
aarushisingh04 aarushisingh04 force pushed from 6e465017 to fcca8e94 129 days ago
aarushisingh04
gemini-code-assist
gemini-code-assist commented on 2026-02-15
aarushisingh04 Update flax/nnx/nn/recurrent.py
331ff8e5
samanklesaria
samanklesaria approved these changes on 2026-04-21
IvyZX
IvyZX approved these changes on 2026-04-23
IvyZX IvyZX added pull ready
IvyZX
IvyZX requested changes on 2026-04-23
IvyZX IvyZX removed pull ready
aarushisingh04 test(nnx): verify out_sharding applies correct sharding to recurrent …
261ddfb4
IvyZX
IvyZX requested changes on 2026-04-28
aarushisingh04 fix: use meaningful sharding spec and device-count batch size in out_…
0a18acc4

Login to write a write a comment.

Login via GitHub

Assignees
No one assigned
Labels
Milestone