[torch_xla2] DistributedDataParallel wrapper with minGPT example #7847
Try shard_map inside forward
a9583207
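A minimal sketch of the kind of `shard_map`-inside-forward usage this commit refers to (illustrative only; the mesh, the `batch` axis name, and the loss function are assumptions, not the PR's code):

```python
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(devices, axis_names=('batch',))

def per_device_loss(x):
    # Each shard reduces its local slice, then averages across the 'batch' axis.
    return jax.lax.pmean(jnp.mean(x ** 2), axis_name='batch')

sharded_loss = shard_map(per_device_loss, mesh=mesh,
                         in_specs=P('batch'), out_specs=P())

x = jnp.ones((8 * len(devices), 4))   # leading dim divisible by the device count
print(sharded_loss(x))
```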
pull out step function
f444fa2f
no explicit collectives, only sharding
5da90a01
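A hedged sketch of the "sharding instead of explicit collectives" idea: annotate the input's sharding and let `jit`/GSPMD insert whatever communication is needed, with no `psum`/`pmean` in the model code. The mesh and axis name are illustrative, not the PR's:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('batch',))
batch_sharding = NamedSharding(mesh, P('batch'))

@jax.jit
def loss_fn(x):
    # No explicit collectives: the compiler sees the sharded input and
    # partitions the reduction itself.
    return jnp.mean(x ** 2)

x = jax.device_put(jnp.ones((8 * jax.device_count(), 4)), batch_sharding)
print(loss_fn(x))
```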
Remove mapping and explicit collectives
87fae37c
multiprocess TPU equivalence check
976ffaf4
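One way to read "equivalence check": run the same seeded step on a single-process CPU reference and on the distributed copy, then compare results. A minimal sketch with hypothetical tensor names and tolerance:

```python
import torch

def assert_equivalent(cpu_out: torch.Tensor, tpu_out: torch.Tensor, atol: float = 1e-3):
    # Both tensors are assumed to have been moved back to CPU before comparing.
    assert torch.allclose(cpu_out, tpu_out, atol=atol), (
        "distributed result diverged from the single-process CPU reference")
```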
mingpt (single process only, no CPU)
53a33f39
fix multiprocess config
18e8636c
shard input over addressable devices
059ee1fd
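A sketch of sharding a per-process input batch over the devices addressable by this process. The helper name `shard_over_local_devices` is hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def shard_over_local_devices(batch):
    local = jax.local_devices()               # devices addressable by this process
    mesh = Mesh(local, axis_names=('batch',))
    return jax.device_put(batch, NamedSharding(mesh, P('batch')))

batch = jnp.zeros((8, 128), dtype=jnp.int32)  # leading dim divisible by local device count
sharded = shard_over_local_devices(batch)
print(sharded.sharding)
```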
scale rate tracker to global batch size
720fc660
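The scaling here is simple arithmetic: the throughput tracker should count the global batch, i.e. the per-process batch multiplied by the number of participating processes. A tiny illustration with an assumed batch size:

```python
import jax

per_process_batch = 8                                  # illustrative value
global_batch = per_process_batch * jax.process_count()
# The rate tracker should then advance by `global_batch` examples per step,
# not by the local batch size.
```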
jit training step
df0de5f2
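A minimal sketch of jit-compiling one training step with JAX; the toy loss and plain SGD update below are placeholders, not the PR's minGPT step:

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, batch, lr):
    def loss_fn(p):
        pred = batch['x'] @ p['w']
        return jnp.mean((pred - batch['y']) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

params = {'w': jnp.zeros((4, 1))}
batch = {'x': jnp.ones((8, 4)), 'y': jnp.ones((8, 1))}
params, loss = train_step(params, batch, 0.1)
print(loss)
```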
correctness check
85e302ca
tweak LR
62c9df5f
disable JIT and show example input/output
9b9f1d8a
small fixes and tweaks
f539e686
add `jit_step` decorator
c5789820
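A hedged sketch of a `jit_step`-style decorator: wrap a pure step function in `jax.jit` once and reuse the compiled version on every call. The name mirrors the commit message; torch_xla2's actual helper may differ:

```python
import functools
import jax
import jax.numpy as jnp

def jit_step(fn):
    compiled = jax.jit(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # All arguments must be jittable (arrays or pytrees of arrays).
        return compiled(*args, **kwargs)

    return wrapper

@jit_step
def step(x):
    return x * 2.0

print(step(jnp.ones(3)))
```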
clip cpu gradients
5542b6f5
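Gradient clipping on the CPU reference model is the standard PyTorch pattern; a sketch with a stand-in module and an assumed `max_norm` of 1.0 (minGPT's usual default):

```python
import torch

model = torch.nn.Linear(4, 4)                 # stand-in for the CPU minGPT copy
loss = model(torch.randn(8, 4)).sum()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```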
clean up
451cc35e
formatting and cleanup
9ab0219a
comments and docstrings
5ed1b2f2
remove redundant div implementation
93928ffb
typo
b7ffa770
will-cromar marked this pull request as ready for review 1 year ago
fix arange dtype
fd75e304
cast arange inputs with astype
a0d12d45
cast using numpy
c3b77b61
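A hedged illustration of the arange dtype issue these three commits address: `jnp.arange` infers its output dtype from the input scalars, so casting the inputs first (here via numpy's `astype`) pins the result dtype. This mirrors the commit messages, not the exact torch_xla2 lowering:

```python
import numpy as np
import jax.numpy as jnp

start, stop = 0.0, 5.0
print(jnp.arange(start, stop).dtype)          # float32: inferred from float inputs

start_i = np.asarray(start).astype(np.int32)
stop_i = np.asarray(stop).astype(np.int32)
print(jnp.arange(start_i, stop_i).dtype)      # int32: follows the casted inputs
```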
add spawn back to unbreak tests
5bf4cd13
ManfeiBai approved these changes on 2024-08-14
fix recompilations
f4dbb756
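A hedged sketch of one common recompilation trigger (not necessarily the PR's actual fix): `jax.jit` caches a compiled executable per input shape and dtype, so keeping shapes fixed across steps avoids repeated compilation:

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.sum(x * x)

x = jnp.ones((8, 128))
step(x)                       # compiles once for shape (8, 128)
step(jnp.zeros((8, 128)))     # same shape and dtype: reuses the cached executable
step(jnp.ones((4, 128)))      # new shape: triggers a second compilation
```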
qihqi approved these changes on 2024-08-15
Merge branch 'master' of github.com:pytorch/xla into wcromar/ddp-wrapper
40126459