accelerate
396bb23e - feat: allow device_map auto to work with XLA devices

Committed 1 year ago
feat: allow device_map auto to work with XLA devices

With this change it is possible to map a model onto XLA devices (tested on a TPU v5e), provided that the `max_memory` map is passed as a parameter (necessary because XLA's `get_memory_info` is not implemented for TPUs). Usage example:

```python
import psutil
import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM

TPU_HBM_MEM = 16 * 1024 * 1024 * 1024  # 16 GiB of HBM per TPU v5e chip

# One entry per XLA device, keyed by device index.
tpu_devices = xm.get_xla_supported_devices()
tpu_device_map = {i: TPU_HBM_MEM for i in range(len(tpu_devices))}

# Cap CPU offload at 80% of the currently available host RAM.
cpu_device_map = {"cpu": psutil.virtual_memory().available * 0.8}
max_memory = tpu_device_map | cpu_device_map

# model_id and torch_dtype as defined elsewhere
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    max_memory=max_memory,
    torch_dtype=torch_dtype,
)
```
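The shape of the `max_memory` map can be sketched without any TPU hardware. The device count and memory sizes below are illustrative placeholders (not values from the commit): integer keys address accelerator devices `0..n-1`, and the `"cpu"` key caps how much host RAM may be used for offloading.

```python
GIB = 1024 * 1024 * 1024

# Hypothetical device count, standing in for len(xm.get_xla_supported_devices()).
num_xla_devices = 4

# Per-device memory budget, keyed by device index (0..n-1).
tpu_device_map = {i: 16 * GIB for i in range(num_xla_devices)}

# Host-RAM budget for CPU offload; 8 GiB is an arbitrary example figure.
cpu_device_map = {"cpu": 8 * GIB}

# Python 3.9+ dict-merge operator combines the two maps.
max_memory = tpu_device_map | cpu_device_map
```

The resulting dict has one entry per accelerator plus one `"cpu"` entry, which is the structure `from_pretrained(..., device_map="auto", max_memory=...)` expects.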