pytorch/xla: [torchax]: JittableModule statedict handling #9195
Merged

zmelumian972 99 days ago

torchax aims to improve seamless interoperability between torch and jax.

One part of the torch training pipeline revolves around storing and loading state dicts (checkpoints).

Most of the objects dealing with torch checkpoints expect a (non-nested) dict mapping each weight name to its value (on either a CPU or GPU device).

Since torchax tensors are held in jax containers, torch checkpointers cannot easily handle them.

This change makes JittableModule convert in its state_dict functions (both load and get), so that extracting the state dict before saving it as a checkpoint is seamless for the user.
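A minimal sketch of the get-direction conversion (not the PR's exact code): jax_view is assumed here as the counterpart of the torch_view used later in this thread, and tensor.j2t as the inverse of the tensor.t2j referenced below.

from torch.utils import _pytree as pytree
from torchax import tensor
from torchax.interop import jax_view  # assumed counterpart of torch_view

def to_cpu_state_dict(torchax_state_dict):
  """Convert a flat dict of jax-backed torchax tensors to CPU torch.Tensors."""
  # jax_view unwraps each torchax tensor to its underlying jax array;
  # tensor.j2t (assumed inverse of tensor.t2j) then copies each array back
  # into a regular torch CPU tensor that torch.save can handle.
  jax_sd = jax_view(torchax_state_dict)
  return pytree.tree_map(tensor.j2t, jax_sd)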

zmelumian972 changed the title Torchax: JittableModule statedict handling [Torchax]: JittableModule statedict handling 99 days ago
zmelumian972 changed the title [Torchax]: JittableModule statedict handling [torchax]: JittableModule statedict handling 99 days ago
qihqi requested a review from qihqi 99 days ago
qihqi commented on 2025-05-20
torchax/torchax/interop.py
        self._jitted[key] = call

    def state_dict(self, *args, **kwargs):
qihqi 99 days ago

I think this function should return torchax.Tensors (and leave it to users who want to move them to CPU to call the move).

Or maybe name it cpu_state_dict.

In regular torch, nn.Module's state_dict method returns references to the weights inside the module, so there is an implicit understanding that no copy is made.

zmelumian972 98 days ago

I prefer cpu_state_dict and keeping it torch native.

Keeping it in torchax.Tensors might make it harder for checkpointers that expect native torch tensors (because torch.save does not expect torchax.Tensors).
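For illustration, the user-facing checkpoint flow being targeted might look as follows; the model, path, and setup call are hypothetical, and the exact method name (state_dict vs. cpu_state_dict) is what is being discussed above.

import torch
from torch import nn
import torchax
from torchax.interop import JittableModule

torchax.enable_globally()  # assumed setup call; details vary by torchax version

class MyModel(nn.Module):  # stand-in model for illustration
  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(4, 2)

m = JittableModule(MyModel())
torch.save(m.state_dict(), "ckpt.pt")     # leaves are plain CPU torch tensors
m.load_state_dict(torch.load("ckpt.pt"))  # parameters transferred back to device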

torchax/torchax/interop.py

    """
    Wrapper for load_state_dict

    This function assumes a torch CPU state dict and will transfer the parameters to the correct device
qihqi 99 days ago

Maybe make this function work with both CPU tensors and torchax.Tensors?

zmelumian972 98 days ago

Sure, that's a good idea

I will simply skip converting a torch tensor to torchax.tensor.Tensor if it is already one, meaning that both torchax.tensor.Tensor and torch.Tensor are acceptable.
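A small sketch of that guard, reusing the tensor.t2j and torch_view names that appear in the suggestion below:

from torchax import tensor
from torchax.interop import torch_view

def _as_torchax(v):
  # Already a jax-backed torchax tensor: pass it through without copying.
  if isinstance(v, tensor.Tensor):
    return v
  # Plain torch CPU tensor: move it to jax, then re-wrap it in the torch view.
  return torch_view(tensor.t2j(v))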

torchax/torchax/interop.py
qihqi 99 days ago

Can probably make use of from jax.experimental.shard_alike import shard_alike.

Something like:

# shard_alike returns a pair (x_resharded, y_resharded); keep the first.
sharded_dict_jax = pytree.tree_map(
    lambda cpu_tensor, orig: shard_alike(tensor.t2j(cpu_tensor), orig)[0],
    state_dict, current_state_dict)
state_dict = torch_view(sharded_dict_jax)
...
zmelumian972 98 days ago

It's a bit unnatural here because state_dict and current_state_dict are not guaranteed to have the same tree structure (the user can specify strict=False and call with a partial state dict).

IIRC, pytrees do not support that natively.
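A quick illustration of the mismatch, with hypothetical keys:

import jax

full = {"linear.weight": 0, "linear.bias": 1}  # current_state_dict
partial = {"linear.weight": 0}                 # strict=False partial checkpoint

try:
  # Different key sets are different pytree structures, so tree_map refuses.
  jax.tree_util.tree_map(lambda new, orig: (new, orig), partial, full)
except ValueError as err:
  print(err)  # dict key mismatch between the two trees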

qihqi 98 days ago

I see, makes sense.

zmelumian committed [torchax] Support for JittableModule::state_dict() (0c045cbf)
zmelumian972 force-pushed from 43404e5d to 8b9bf46d 98 days ago
qihqi requested a review from qihqi 98 days ago
qihqi approved these changes on 2025-05-20
qihqi 98 days ago

Thanks! Please fix the lint with

yapf -i -r *.py test/ scripts/ torch_xla/ benchmarks/ torchax/

Thanks

zmelumian committed [torchax] Added JittableModule::load_state_dict mechanism (121cc008)
zmelumian972 force-pushed from 8b9bf46d to 121cc008 98 days ago
zmelumian972 94 days ago

How do I move forward? I am unfamiliar with the pytorch/XLA CI and resources.

qihqi enabled auto-merge (squash) 85 days ago
auto-merge disabled 71 days ago (manually disabled by user)
qihqi enabled auto-merge (squash) 71 days ago
qihqi 71 days ago

Hi @zmelumian972 can you rebase? The GPU CIs are not merge blocking anymore on the new HEAD. I also enabled auto-merge so if CI passes it should merge automatically. Thanks!

qihqi approved these changes on 2025-06-16
zmelumian972 66 days ago

Done :)

auto-merge disabled 65 days ago (manually disabled by user)
qihqi merged afe425e2 into master 65 days ago
