The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Very cool API! I like the design and how easy it is to use. I left a few comments, mainly around the `split_points`.
```python
def forward(*args, **kwargs):
    return model_forward(*args, **kwargs)

# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
forward.__wrapped__ = model_forward
```
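For context, a minimal sketch of why the `__wrapped__` attribute matters: it lets a later utility (here, `extract_model_from_parallel`) detect the wrapper and restore the original forward. The helper below is hypothetical, not the PR's actual code:

```python
# Hypothetical helper illustrating the unwrap step.
def pop_forward_wrapper(model):
    forward = model.forward
    if hasattr(forward, "__wrapped__"):
        # Restore the original, unwrapped forward
        model.forward = forward.__wrapped__
    return model
```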
Nice!
21 | """ | ||
22 | Calculates the device map for `model` with an offset for PiPPy | ||
23 | """ | ||
24 | no_split_module_classes = getattr(model, "_no_split_modules", []) | ||
25 | if num_processes == 1: | ||
26 | return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False) | ||
27 | model_size, shared = calculate_maximum_sizes(model) | ||
28 | |||
29 | # Split into `n` chunks for each GPU | ||
30 | memory = (model_size + shared[0]) / num_processes | ||
31 | memory = convert_bytes(memory) | ||
32 | value, ending = memory.split(" ") | ||
33 | |||
34 | # Add a chunk to deal with potential extra shared memory instances | ||
35 | memory = math.ceil(float(value)) * 1.1 | ||
36 | memory = f"{memory} {ending}" | ||
37 | device_map = infer_auto_device_map( | ||
38 | model, | ||
39 | max_memory={i: memory for i in range(num_processes)}, | ||
40 | no_split_module_classes=no_split_module_classes, | ||
41 | clean_result=False, | ||
42 | ) |
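To make the chunking concrete, here is a rough worked example of the budget computation above; all numbers are made up for illustration:

```python
import math

# Illustrative numbers only (not from the PR): a ~10 GB model on 4 GPUs.
model_size = 10_000_000_000   # bytes
shared = (700_000_000,)       # bytes of shared weights, hypothetical
num_processes = 4

chunk = (model_size + shared[0]) / num_processes   # 2_675_000_000.0 bytes
# convert_bytes(chunk) would render this human-readable, e.g. "2.49 GB";
# the ceil + 1.1 padding then budgets each GPU roughly "3.3 GB":
budget = math.ceil(chunk / 1024**3) * 1.1          # ≈ 3.3
```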
We can definitely generate a balanced `device_map` exclusively for PiPPy (e.g. `device_map="balanced_pippy"`) if the current balanced option is not the best fit for that. However, I think it would be great if the user could also use other options like `"sequential"`. I didn't try it, but what happens when we only fill 2 GPUs out of the 4 available (a possible `"sequential"` case)?
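For illustration, the 2-of-4-GPUs case could be probed by only budgeting memory for two devices; this is a sketch with placeholder memory strings, not code from the PR:

```python
from accelerate import infer_auto_device_map

# Sketch of the "2 of 4 GPUs" sequential case: only budget two devices.
device_map = infer_auto_device_map(
    model,
    max_memory={0: "12GiB", 1: "12GiB"},  # GPUs 2 and 3 get no budget
    no_split_module_classes=getattr(model, "_no_split_modules", []),
)
```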
78 | """ | ||
79 | example_args = send_to_device(example_args, "cpu") | ||
80 | example_kwargs = send_to_device(example_kwargs, "cpu") | ||
81 | if device_map == "auto": | ||
82 | device_map = generate_device_map(model, PartialState().num_processes) | ||
83 | stage = build_pipeline(model, device_map, example_args, example_kwargs) |
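As a point of reference for the discussion below: a generated `device_map` is an ordered mapping from module names to device indices. It might look something like this (the module names are hypothetical):

```python
from collections import OrderedDict

# Hypothetical shape of a generated device_map for a 2-process run:
# module names map to device indices, in execution order.
device_map = OrderedDict(
    [
        ("embeddings", 0),
        ("encoder.layer.0", 0),
        ("encoder.layer.1", 0),
        ("encoder.layer.2", 1),
        ("encoder.layer.3", 1),
        ("pooler", 1),
    ]
)
```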
Just a thought about how to handle the split points:

1. A `device_map` with predefined options (`"sequential"`, `"balanced_pippy"`).
2. A custom `device_map`. For the custom case, it can be complicated, since the user needs to be careful about the order (`OrderedDict()`) and needs to assign the GPUs in a sequential manner because of `split_points.append(next(k for k, v in device_map.items() if v == i))`. So that can be quite complicated.
3. Passing the split points directly as a `List[str]`.

Agreed to do 1 and 3
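For option 3, usage might look like the sketch below; the `split_points` keyword is the one visible in this PR, while the module names are purely illustrative:

```python
# Hypothetical usage of option 3: three split points produce four
# pipeline stages for four processes.
model = prepare_pippy(
    model,
    split_points=["transformer.h.8", "transformer.h.16", "transformer.h.24"],
)
```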
The API is in good shape! Let's document the main functions a bit and then we can merge. I left a few comments, but nothing blocking.
```python
state = PartialState()
example_args = send_to_device(example_args, "cpu")
example_kwargs = send_to_device(example_kwargs, "cpu")
if split_points == "auto":
    device_map = generate_device_map(model, state.num_processes, no_split_module_classes=no_split_module_classes)
    split_points = []
    for i in range(1, state.num_processes):
        split_points.append(next(k for k, v in device_map.items() if v == i))
```
It would be great to have a sanity check to make sure that we indeed end up with the right number of split points for `state.num_processes`, both when we generate the `split_points` and when the user passes them manually.
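Such a check might look like the sketch below, assuming the invariant is one split point per stage boundary (i.e. `num_processes - 1` of them, matching the generation loop above); this is not code from the PR:

```python
# Sketch of the suggested sanity check, covering both the generated
# and the user-provided cases.
if len(split_points) != state.num_processes - 1:
    raise ValueError(
        f"Expected {state.num_processes - 1} split points for "
        f"{state.num_processes} processes, got {len(split_points)}."
    )
```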
Thanks a lot for the integration effort!
LGTM!
Thanks for iterating! LGTM
Thanks for writing the doc so quickly! Looks good to me!
cc @MKhalusova for the docs!
Nice work! I left a few comments to polish things in the docs a bit.
Final comment before merging. Things that still need to be done in a later PR at some point (but okay not being in the first iteration of this joint effort):

1. Create a `balanced_pippy` device map and allow a `sequential` `device_map` when making the pipeline via `prepare_pippy`
2. Support `model.generate()` through an alternative hook into the model forward, if possible
3. Move the `pippy-device-map-playground` examples over to here as part of our `examples` folder

(I'll be doing 3 & 4 this week as a follow-up prior to release)
Example use:
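A sketch of how the API might be used, based on the parameter names visible in the diffs above (`split_points`, `example_args`); the model choice and input shapes are illustrative assumptions:

```python
import torch
from transformers import AutoModelForSequenceClassification
from accelerate import prepare_pippy

# Illustrative model and input; any transformers model works the same way.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
model.eval()

inputs = torch.randint(low=0, high=model.config.vocab_size, size=(2, 512))
model = prepare_pippy(model, split_points="auto", example_args=(inputs,))

# Launched with e.g. `accelerate launch script.py` so each process owns a stage.
with torch.no_grad():
    output = model(inputs)
```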
Speed up: benchmarked with 2x 4090s in full precision on Bert, GPT2, and T5.