Thanks! Please fix the lint with
yapf -i -r *.py test/ scripts/ torch_xla/ benchmarks/ torchax/
Thanks
how do I move forward? I am unfamiliar with pytorch/XLA CI and resources
Hi @zmelumian972 can you rebase? The GPU CIs are not merge blocking anymore on the new HEAD. I also enabled auto-merge so if CI passes it should merge automatically. Thanks!
Done :)
torchax aims to improve seamless interoperability between torch and jax
one part of the torch training pipeline revolves around storing and loading the state dict (checkpoints)
most objects around torch checkpointing expect a flat (non-nested) dict mapping each weight name to its value (on either a CPU or GPU device)
since torchax tensors are held in a jax container, torch checkpointers cannot easily handle them
this change makes JittableModule convert tensors in its state dict functions (both load_state_dict and state_dict), making it seamless for users to extract the state dict prior to saving it as a checkpoint
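A minimal sketch of the conversion pattern described above. Note this is illustrative only: `ArrayBackedTensor`, `j2t`, and `t2j` are hypothetical stand-ins (the real torchax conversion helpers and JittableModule internals may differ); the idea is that state_dict unwraps jax-backed tensors into plain torch-compatible values, and load_state_dict wraps them back.

```python
class ArrayBackedTensor:
    """Hypothetical stand-in for a torchax tensor holding a jax array."""
    def __init__(self, data):
        self.data = data  # pretend this is a jax.Array


def j2t(x):
    """Pretend jax -> torch conversion; here it just unwraps the value."""
    return x.data


def t2j(x):
    """Pretend torch -> jax conversion; here it just wraps the value."""
    return ArrayBackedTensor(x)


class JittableModuleSketch:
    """Sketch of a module whose params are jax-backed tensors."""
    def __init__(self, params):
        self._params = params  # name -> ArrayBackedTensor

    def state_dict(self):
        # Convert every jax-backed tensor to a plain torch-style value,
        # so ordinary torch checkpoint utilities can consume the dict.
        return {k: j2t(v) for k, v in self._params.items()}

    def load_state_dict(self, sd):
        # Convert plain torch values back into jax-backed tensors on load.
        self._params = {k: t2j(v) for k, v in sd.items()}


m = JittableModuleSketch({"w": ArrayBackedTensor([1.0, 2.0])})
sd = m.state_dict()   # flat dict of plain values, checkpointer-friendly
m.load_state_dict(sd)  # round-trips back into jax-backed tensors
```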