diffusers
[Core] introduce _no_split_modules to `ModelMixin`
#6396
Merged

[Core] introduce _no_split_modules to `ModelMixin` #6396

sayakpaul merged 38 commits into main from feat/device-map-auto
sayakpaul
sayakpaul1 year ago (edited 1 year ago)🚀 2

What does this PR do?

Adds utilities to support _no_split_modules to the ModelMixin. Closely follows what's done in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py.

Part of #6240.

I think it's better to tackle the introduction of device_map="auto" to pipelines in multiple PRs. @SunMarc laid out a very nice plan here (internal Slack link).

TODO

  • Get initial reviews from an accelerate core maintainer
  • Propagate to other important models inheriting ModelMixin
  • Add tests
  • Docs (if needed)
sayakpaul introduce _no_split_modules.
0843c467
sayakpaul sayakpaul requested a review from patrickvonplaten patrickvonplaten 1 year ago
sayakpaul sayakpaul requested a review from SunMarc SunMarc 1 year ago
sayakpaul sayakpaul marked this pull request as draft 1 year ago
sayakpaul
sayakpaul commented on 2023-12-30
Conversation is marked as resolved
Show resolved
src/diffusers/models/unet_2d_condition.py
163163 """
164164
165165 _supports_gradient_checkpointing = True
166
_no_split_modules = ["Block"]
sayakpaul1 year ago

I'm a little confused as to what should go here. Seems like all the blocks having "Block" in their names (such as DownBlock2D, UNetMidBlock2DCrossAttn, etc.) should go here as they have residual connections inside of them (recommendation comes from here).

SunMarc1 year ago

Yeah, that's right. We need to add all potential blocks name that have residual connections inside of them. What is tricky with diffusers models it that it allows different types of transformers block (e.g. DownBlock2D, CrossAttnDownBlock2D) which is not the case in transformers. This makes the testing harder.

sayakpaul1 year ago

That indeed sounds tricky.

So, let's think this through.

If you take at the config.json of the UNet of SDXL here, you would notice the distinct block types there. Now these blocks are composed of lower-level blocks that have residual connections. Would adding these higher-level blocks cut it for us?

SunMarc1 year ago

We can definitely add these higher level-blocks instead of the lower-level blocks if they are not occupying too much memory space. I think that if we can have blocks of < ~1GB, it should be good enough.

sayakpaul1 year ago👍 1

I think that if we can have blocks of < ~1GB, it should be good enough.

Could you elaborate this a bit? Additionally, it would be helpful if you maybe locally test this PR and let me know your findings, that would be great!

SunMarc1 year ago

I probably got confused by the term higher level-block. Can you explain what do you mean by that ? I was saying that it is okay to not split high level modules if they are small enough.
For example, in this example where the submodules have residual connections, the best case would be to set _no_split_modules = [SubModule 1, SubModule 2, SubModule 3]. But we can decide to set _no_split_modules = [Module 1] if Module 1 is not too big.
(Module 1)
----(SubModule 1)
----(SubModule 2)
----(SubModule 3)

sayakpaul1 year ago👍 1

Asked a question over Slack.

sayakpaul unnecessary spaces.
e60abd3c
sayakpaul remove unnecessary kwargs and style
c3745125
sayakpaul
sayakpaul commented on 2023-12-30
src/diffusers/models/modeling_utils.py
sayakpaul1 year ago👍 1

I am not sure why we used to pass this but these are NOT used in configuration_utils.py anywhere. Given that, I think they are best removed:

  • No unwanted cognitive burden in thinking about what these are doing for configuration parsing.
  • Reduces LoC (albeit small)
sayakpaul fix: accelerate imports.
6f5ae67b
SunMarc
SunMarc commented on 2024-01-08
SunMarc1 year ago

Thanks for working on this @sayakpaul ! This is exactly what I was thinking ! Let's first make it work on diffusers models and extending it to pipeline should be straightforward !

Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
829909 else: # else let accelerate handle loading and dispatching.
830910 # Load weights and dispatch according to the device_map
831911 # by default the device_map is None and the weights are loaded on the CPU
912
device_map = _determine_device_map_from_string(model, device_map, max_memory, torch_dtype)
SunMarc1 year ago👍 1

maybe change the naming since we are passing the device_map even when it is not a string

sayakpaul1 year ago

Done in 726df08.

sayakpaul merge main and resolve conflicts.
6c401563
sayakpaul change to _determine_device_map
726df088
HuggingFaceDocBuilderDev
HuggingFaceDocBuilderDev1 year ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

sayakpaul add the blocks that have residual connections.
f8d8afae
sayakpaul
sayakpaul1 year ago

@SunMarc so, I incorporated the changes and tested with:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto")
print(unet.hf_device_map)

It prints:

{'': 0}

I tested this on a single GPU. Does this look correct?

@patrickvonplaten I have gone through the structures but would appreciate a confirmation if BasicTransformerBlock and ResnetBlock2D are indeed the only blocks that contain a residual path in their forward() method (consider the base model is an SDXL UNet).

SunMarc
SunMarc1 year ago

I tested this on a single GPU. Does this look correct?

Yes, it looks correct. Try to play with multiple gpu and if you are able to run the model correctly since users uses device_map to split the model on multiple gpus.

sayakpaul
sayakpaul1 year ago

Try to play with multiple gpu and if you are able to run the model correctly since users uses device_map to split the model on multiple gpus.

Do you mean using the same code example but on multiple GPUs? How should the inputs be constructed, then? How should we handle device placement for them?

sayakpaul
sayakpaul1 year ago (edited 1 year ago)

@SunMarc I tried on two GPUs. Here are some findings.

Test code
from diffusers import UNet2DConditionModel
import torch 

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="sequential"
)
print(unet.hf_device_map)

# Inputs
sample = torch.randn(1, 4, 128, 128).to("cuda")
t = torch.randint(1, 1000, size=(1, )).to("cuda")
encoder_hidden_states = torch.randn(1, 77, 2048).to("cuda")
add_text_embeds = torch.randn(1, 1280).to("cuda")
add_time_ids = torch.randn(1, 6).to("cuda")
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

# Forward
with torch.no_grad():
    outputs = unet(
        sample=sample,
        timestep=t,
        encoder_hidden_states=encoder_hidden_states,
        added_cond_kwargs=added_cond_kwargs
    ).sample
    print(outputs.shape)

With ["BasicTransformerBlock", "ResnetBlock2D"] specified in _no_split_modules of UNet2DConditionModel, it leads to the following device map:

{'conv_in': 0, 'time_proj': 0, 'time_embedding': 0, 'add_time_proj': 0, 'add_embedding': 0, 'down_blocks': 0, 'up_blocks.0.attentions.0': 0, 'up_blocks.0.attentions.1.norm': 0, 'up_blocks.0.attentions.1.proj_in': 0, 'up_blocks.0.attentions.1.transformer_blocks.0': 0, 'up_blocks.0.attentions.1.transformer_blocks.1': 1, 'up_blocks.0.attentions.1.transformer_blocks.2': 1, 'up_blocks.0.attentions.1.transformer_blocks.3': 1, 'up_blocks.0.attentions.1.transformer_blocks.4': 1, 'up_blocks.0.attentions.1.transformer_blocks.5': 1, 'up_blocks.0.attentions.1.transformer_blocks.6': 1, 'up_blocks.0.attentions.1.transformer_blocks.7': 1, 'up_blocks.0.attentions.1.transformer_blocks.8': 1, 'up_blocks.0.attentions.1.transformer_blocks.9': 1, 'up_blocks.0.attentions.1.proj_out': 1, 'up_blocks.0.attentions.2': 1, 'up_blocks.0.resnets': 1, 'up_blocks.0.upsamplers': 1, 'up_blocks.1': 1, 'up_blocks.2': 1, 'mid_block': 1, 'conv_norm_out': 1, 'conv_act': 1, 'conv_out': 1}

However, it leads to the following error:

Traceback (most recent call last):
  File "/home/sayak/diffusers/test_single_file.py", line 19, in <module>
    outputs = unet(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/unet_2d_condition.py", line 1197, in forward
    sample = upsample_block(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/unet_2d_blocks.py", line 2324, in forward
    hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument tensors in method wrapper_CUDA_cat)

"CrossAttnUpBlock2D" is the block that causes this and when added to _no_split_modules alongside ["BasicTransformerBlock", "ResnetBlock2D"], the error went away and I was able to obtain the output. The device map prints as follows:

{'': 0}

Seems like nothing is being split, which I think is the expected result here?

SunMarc
SunMarc1 year ago

Do you mean using the same code example but on multiple GPUs? How should the inputs be constructed, then? How should we handle device placement for them?

The inputs will be automatically dispatched to the right device because accelerate adds hooks for that to the modules.

"CrossAttnUpBlock2D" is the block that causes this and when added to _no_split_modules alongside ["BasicTransformerBlock", "ResnetBlock2D"], the error went away and I was able to obtain the output. The device map prints as follows:

{'': 0}
Seems like nothing is being split, which I think is the expected result here?

No that should be the case since we want the model to be split. The results we should get is something like:

{'conv_in': 0, 'time_proj': 0, 'time_embedding': 0, 'add_time_proj': 0, 'add_embedding': 0, 'down_blocks': 0, 'up_blocks.0.attentions.0': 0, 'up_blocks.0.attentions.1': 1, 'up_blocks.0.attentions.2': 1, 'up_blocks.0.resnets': 1, 'up_blocks.0.upsamplers': 1, 'up_blocks.1': 1, 'up_blocks.2': 1, 'mid_block': 1, 'conv_norm_out': 1, 'conv_act': 1, 'conv_out': 1}

In the previous example, the inference failed since the CrossAttnUpBlock2D is concatenating hidden_states that are coming from different devices. I suspect the problem comes from this mapping which splits the attention block. So indeed, we should add CrossAttnUpBlock2D inside _no_split_modules. Another way would be to make that that hidden_states, res_hidden_states are on the same device but I prefer not to add anything in the modeling code :

{'up_blocks.0.attentions.1.norm': 0, 'up_blocks.0.attentions.1.proj_in': 0, 'up_blocks.0.attentions.1.transformer_blocks.0': 0, 'up_blocks.0.attentions.1.transformer_blocks.1': 1, 'up_blocks.0.attentions.1.transformer_blocks.2': 1, 'up_blocks.0.attentions.1.transformer_blocks.3': 1, 'up_blocks.0.attentions.1.transformer_blocks.4': 1, 'up_blocks.0.attentions.1.transformer_blocks.5': 1, 'up_blocks.0.attentions.1.transformer_blocks.6': 1, 'up_blocks.0.attentions.1.transformer_blocks.7': 1, 'up_blocks.0.attentions.1.transformer_blocks.8': 1, 'up_blocks.0.attentions.1.transformer_blocks.9': 1, 'up_blocks.0.attentions.1.proj_out': 1}

sayakpaul
sayakpaul1 year ago👍 1

Thanks for providing your inputs.

Another way would be to make that that hidden_states, res_hidden_states are on the same device but I prefer not to add anything in the modeling code :

Indeed this should be preferred. We don't want to touch the forward call until and unless absolutely necessary.

I suspect the problem comes from this mapping which splits the attention block. So indeed, we should add CrossAttnUpBlock2D inside _no_split_modules.

But when I did that the model doesn't seem to split though. What are we missing here? Would you be able to take deeper look or provide me pointers to see this through further?

SunMarc
SunMarc1 year ago👍 2

I've traced back to the issue. It is an issue on accelerate where the memory allocation + module placement is not very good when we have models where the largest non splittable layer is very big compared to the whole model. In our case, by specifying CrossAttnUpBlock2D , the module up_blocks.0 become non splittable and the fact that it represent half of the memory (5GB out of 10GB) and we get a bad module placement. This is why I was recommending to have smaller non splittable blocks. Nevertheless, this is what needs to be added into _no_split_modules if we don't want to modify the modeling file.
I can try to fix it in accelerate but I might require quite some time since it can impacting all models on transformers depending on the fix. This model is pretty small, so it will fit in one gpu. To continue with the PR, can you try other model by adding the _no_split_modules ? This way, we can try to see if this is a recurrent issue or not.

I forgot to mention but you can also put your own device_map to check if the inference works for a specific placement since the generated device_map is not optimal. For example, this device map works with the UNet2DConditionModel .
It shows that you indeed need to have the up_blocks non split.

device_map = {
    "conv_in": 0,
    "time_proj": 0,
    "time_embedding": 0,
    "add_time_proj": 0,
    "add_embedding": 0,
    "down_blocks": 0,
    "up_blocks.0": 0, 
    "up_blocks.1": 1,
    "up_blocks.2": 1,
    "mid_block": 1,
    "conv_norm_out": 1,
    "conv_act": 1,
    "conv_out": 1,
}
sayakpaul
sayakpaul1 year ago👍 1

Nevertheless, this is what needs to be added into _no_split_modules if we don't want to modify the modeling file.

I think we definitely don't want to change the modeling code following what transformers does.

I will try on other models and maybe even on a smaller GPU. The smallest I have access to is 16GB.

sayakpaul add: CrossAttnUpBlock2D
9215e357
sayakpaul add: testin
706d96e7
sayakpaul style
149ba913
sayakpaul line-spaces
4c88038b
sayakpaul
sayakpaul1 year ago (edited 1 year ago)🚀 1

@SunMarc seems like a good progress now.

Since I am trying on a machine having two 4090s, tried the following to restrict the memory so that device_map takes effect:

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet", 
    device_map="auto",
    max_memory={0: "6GiB", 1: "10GiB"},
)
print(unet.hf_device_map)

Worked like a charm!

The device map:

{'conv_in': 0, 'time_proj': 0, 'time_embedding': 0, 'add_time_proj': 0, 'add_embedding': 0, 'down_blocks.0': 0, 'down_blocks.1': 0, 'down_blocks.2.attentions.0.norm': 0, 'down_blocks.2.attentions.0.proj_in': 0, 'down_blocks.2.attentions.0.transformer_blocks.0': 0, 'down_blocks.2.attentions.0.transformer_blocks.1': 0, 'down_blocks.2.attentions.0.transformer_blocks.2': 0, 'down_blocks.2.attentions.0.transformer_blocks.3': 0, 'down_blocks.2.attentions.0.transformer_blocks.4': 0, 'down_blocks.2.attentions.0.transformer_blocks.5': 0, 'down_blocks.2.attentions.0.transformer_blocks.6': 0, 'down_blocks.2.attentions.0.transformer_blocks.7': 0, 'down_blocks.2.attentions.0.transformer_blocks.8': 0, 'down_blocks.2.attentions.0.transformer_blocks.9': 1, 'down_blocks.2.attentions.0.proj_out': 1, 'down_blocks.2.attentions.1': 1, 'down_blocks.2.resnets': 1, 'up_blocks': 1, 'mid_block': 1, 'conv_norm_out': 1, 'conv_act': 1, 'conv_out': 1}

I have also added two tests closely following this and this. Have tested it too with the following:

RUN_SLOW=1 pytest tests/models/test_models_unet_2d_condition.py -k "offload"

I think we can add docs after we ship this feature to pipelines because that provides a fuller context.

Meanwhile, could you go through the PR once in detail and let me know your thoughts? Once that's done, will add _no_split_modules to other models and mark it ready for review.

Cc: @DN6 for awareness.

sayakpaul quality
796ee054
sayakpaul Merge branch 'main' into feat/device-map-auto
6801c22e
SunMarc
SunMarc commented on 2024-01-18
SunMarc1 year ago

Thanks for adding the cpu/disk offload tests ! Can you also add the multi-gpu test ?

Conversation is marked as resolved
Show resolved
tests/models/test_modeling_common.py
708 self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
709
710 @require_torch_gpu
711
def test_disk_offload(self):
SunMarc1 year ago

For disk offload, you need to distinguish two cases: with or without safetensors. See related tests in transformers. This one this the safetensors case.

sayakpaul1 year ago

Done in e9290c9.

sayakpaul
sayakpaul1 year ago

@SunMarc we don't run multi-gpu tests yet because this hasn't been a strong case for us.

SunMarc
SunMarc1 year ago

@SunMarc we don't run multi-gpu tests yet because this hasn't been a strong case for us.

Makes sense. But still can we have them and skip them in the CI ? They are useful to check that we did the splitting correctly and are able to run them (_no_split_modules )

sayakpaul
sayakpaul1 year ago

@SunMarc we don't run multi-gpu tests yet because this hasn't been a strong case for us.

Makes sense. But still can we have them and skip them in the CI ? They are useful to check that we did the splitting correctly and are able to run them (_no_split_modules )

Good idea. I will add it. Apart from that changes you requested, is there anything else on you would like me to change as far as the core design goes?

SunMarc
SunMarc1 year ago❤ 1

No, the core design looks very good ! It is similar to transformers and device_map is working well there.

sayakpaul Merge branch 'main' into feat/device-map-auto
7ecb597a
sayakpaul add disk offload test without safetensors.
e9290c9c
sayakpaul checking disk offloading percentages.
73a99dc5
sayakpaul change model split
b4e9f2f1
sayakpaul add: utility for checking multi-gpu requirement.
26fd4af2
sayakpaul model parallelism test
1055ba7d
sayakpaul splits.
6d336dca
sayakpaul splits.
8575ad6a
sayakpaul splits
92da3bb5
sayakpaul splits.
57bc6f71
sayakpaul splits.
945458aa
sayakpaul splits.
7e8f602f
sayakpaul offload folder to test_disk_offload_with_safetensors
a48d25cc
sayakpaul
sayakpaul1 year ago👍 1

@SunMarc I added the multi-GPU parallelism test and also test_disk_offload_without_safetensors. Some notes:

  • model_split_percents = [0.5, 0.3, 0.4] is the one that seems to work for both multi-GPU and single-GPU environments for the UNet under consideration. The size of the UNet is definitely small.
  • I had to pass an offload_folder to test_disk_offload_with_safetensors to make it work with the model_split_percents for the given UNet.

Let me know your thoughts.

SunMarc
SunMarc1 year ago

model_split_percents = [0.5, 0.3, 0.4] is the one that seems to work for both multi-GPU and single-GPU environments for the UNet under consideration. The size of the UNet is definitely small.

Makes sense, the model is small + the non splittable modules are big.

I had to pass an offload_folder to test_disk_offload_with_safetensors to make it work with the model_split_percents for the given UNet.

You can use disk offload without having to pass offload_folder when using safentesors format. Check this PR in transformers. This can be implemented in a follow up PR since it is not essential. LMK if you want my help on that.

sayakpaul Merge branch 'main' into feat/device-map-auto
cd9c29eb
sayakpaul
sayakpaul1 year ago👍 1

You can use disk offload without having to pass offload_folder when using safentesors format. Check this PR in transformers. This can be implemented in a follow up PR since it is not essential. LMK if you want my help on that.

@SunMarc thanks! I think it would be better to have it in a follow-up PR. Would appreciate your help in that :)

sayakpaul add _no_split_modules
fcd277c7
sayakpaul sayakpaul marked this pull request as ready for review 1 year ago
sayakpaul
sayakpaul1 year ago

@patrickvonplaten this is ready for a review now. I propose to add the docs after we ship device map support to pipelines to have more context. Let me know what you think.

@SunMarc in case you want to take another look (which I would appreciate since this is an important PR). But, really thank you for all your help thus far!

sayakpaul sayakpaul changed the title [WIP] introduce _no_split_modules to `ModelMixin` [Core] introduce _no_split_modules to `ModelMixin` 1 year ago
patrickvonplaten
patrickvonplaten commented on 2024-01-23
patrickvonplaten1 year ago

How can we use device_map="auto" for inference here? Can it be used when loading a pipeline?

Also cc @yiyixuxu

sayakpaul
sayakpaul1 year ago

How can we use device_map="auto" for inference here? Can it be used when loading a pipeline?

As stated in the PR description and the internal message, we need to be able to add support to models first. Once this is merged, we need to add support for pipelines accordingly. Passing device_map="auto" to models should work, given its _no_split_modules has been set just like transformers. Let me know if that's clear.

sayakpaul merge main and resolve conflicts.
9dff2c1d
sayakpaul fix-copies
9e5a5e94
zhangvia
zhangvia1 year ago (edited 1 year ago)

@sayakpaul

How can we use device_map="auto" for inference here? Can it be used when loading a pipeline?

As stated in the PR description and the internal message, we need to be able to add support to models first. Once this is merged, we need to add support for pipelines accordingly. Passing device_map="auto" to models should work, given its _no_split_modules has been set just like transformers. Let me know if that's clear.

hey, i noticed that you guys are working on something i interested. i'm seeking some elegant solution to execute pipeline on multiple gpus. but is this feature only can be used through device_map=‘auto’? i think the auto device_map is analyzed according to the model parameters, but what about the model input? like the sd pipeline, generate resolution also significantly impacts memory usage。so if the gpu that every model use can be set manually will be a better solution. because i can test every different gpu setting

SunMarc
SunMarc1 year ago

so if the gpu that every model use can be set manually will be a better solution. because i can test every different gpu setting

You will be able set the max memory usage for each gpu (e.g. max_memory={0: "6GiB", 1: "10GiB"}). This way you can make sure to leave enough space for the model input.

SunMarc
SunMarc approved these changes on 2024-01-24
SunMarc1 year ago

Thanks for your great work @sayakpaul !

tests/models/test_modeling_common.py
664681 " from `_deprecated_kwargs = [<deprecated_argument>]`"
665682 )
666683
SunMarc1 year ago (edited 1 year ago)

Maybe add a reference to transformers code for the tests

sayakpaul1 year ago

It was there, though in this place. Could you elaborate a bit more?

zhangvia
zhangvia1 year ago (edited 1 year ago)

so if the gpu that every model use can be set manually will be a better solution. because i can test every different gpu setting

You will be able set the max memory usage for each gpu (e.g. max_memory={0: "6GiB", 1: "10GiB"}). This way you can make sure to leave enough space for the model input.

i'm not just talking about model input but the increase of the memory it brings. for example, when i generate 512*512 images using text2img pipeline, the memory cost will be much lower than generate 1024 * 1024. and the memory cost of single model like vae, unet ,controlnet are different. i test this case: i load two controlnet, a full img2img pipeline. i put the unet,vae to gpu0, the rest models to gpu1. i use two 2080ti (11g). it will get oom on gpu0 when generate 1024 * 1024. but if i put unet and controlnets to gpu0, the rest models to gpu1, i can generate 1024 * 1024

sayakpaul
sayakpaul1 year ago

But the ControlNet model doesn't yet have _no_split_modules yet. Let's maybe revisit your usecase once we add support for device_map to the pipelines.

sayakpaul
sayakpaul1 year ago👍 1

The memory specification also varies a bit from how it's done in the language modeling world. For example, the memory specification for generating 512x512 resolution images will be different from that of generating 1024x1024 images, naturally. So, you will need to take that into consideration. @SunMarc am I thinking in the right direction?

sayakpaul Merge branch 'main' into feat/device-map-auto
c9b81f6f
zhangvia
zhangvia1 year ago

The memory specification also varies a bit from how it's done in the language modeling world. For example, the memory specification for generating 512x512 resolution images will be different from that of generating 1024x1024 images, naturally. So, you will need to take that into consideration. @SunMarc am I thinking in the right direction?

that is what i'm thinking about. in my use case, you actually can find a model placement policy to generate 1024*1024 images when using two 2080ti(12g). but the device_map=auto may get oom

sayakpaul
sayakpaul1 year ago

That could be because it's not input-aware. In those cases, handcrafting the memory map is better.

zhangvia
zhangvia1 year ago

what do u mean memory map? how can i ensure my use case won't get oom through memory map?

sayakpaul
sayakpaul1 year ago

The device map where you can specify which device should get what ratios for splitting.

There is no one single answer to the other question, as it requires analysing the memory consumption w.r.t the inputs (as with resolution scaling, it can grow more drastically than how it is with language models) and then crafting a device map that works for you.

As mentioned, the support for device maps in pipelines is not there. So, I cannot give you more concrete guidelines yet. But we will be sure consider these things and clearly document them.

zhangvia
zhangvia1 year ago

The device map where you can specify which device should get what ratios for splitting.

There is no one single answer to the other question, as it requires analysing the memory consumption w.r.t the inputs (as with resolution scaling, it can grow more drastically than how it is with language models) and then crafting a device map that works for you.

As mentioned, the support for device maps in pipelines is not there. So, I cannot give you more concrete guidelines yet. But we will be sure consider these things and clearly document them.

Thank you for your patient explanation. i will definitely try it when it's done

patrickvonplaten
patrickvonplaten1 year ago (edited 1 year ago)

I'm still not sure whether the way we support device_map here is the right way to do so. Instead of splitting the unet over multiple devices, it would be much better to move each component to one device - e.g. text_encoder is on device_0, unet in device_1, vae on device_0 again etc... IMO we first should try to do map different components to different devices before splitting one component over multiple devices.

What is the use case exactly for splitting the unet over multiple devices (and how should the text_encoder and text_encoder_2 then be split?)

sayakpaul
sayakpaul1 year ago (edited 1 year ago)

I was hoping that would be possible if we start from models itself, i.e., add support for device map the way it's supported in transformers. And then we implement th e device mapping support like the way you described.

@SunMarc WDYT?

patrickvonplaten
patrickvonplaten1 year ago

My 2 cents here:

  • A big difference between Transformes and Diffusers is that Transformer models always only have a single model file where as Diffusers are a chain of models (text_encoder -> text_encoder_2 -> unet -> vae)
  • Also Transformer models can be much much bigger. It's rare / non-existing that a single UNet model has more than 6 billion parameters
  • While the UNet model makes up for most parameters of the pipeline it's not like it's 99% of the parameters. In all T5-based pipelines the text encoder is actually bigger and in SDXL, both text encoders combined are roughly 30% of the parameters => so it makes a lot of sense IMO to first move different components to different GPUs before splitting the unet over multiple GPUs
  • Diffusers pipelines already define a graph that explains how tensors from through each component and in which order (see here) => Can't we leverage this?

Also cc @yiyixuxu here

sayakpaul
sayakpaul1 year ago

I will let @SunMarc comment further here.

sayakpaul merge main and resolve conflicts
089b6bd5
sayakpaul
sayakpaul1 year ago

Let's try to think about the API design we're aiming to achieve here.

Maybe it's better to think from the pipelines first.

For a single GPU, I think it's ALWAYS better to ask users to rely on enable_model_cpu_offload and enable_sequential_cpu_offload. This is already heavily used and quite well-tested. It works well in practice as well.

Now, we want to have better support to execute pipelines where users may want to use multiple GPUs. Note that, we do support distributed inference (see here).

However, we want to provide more flexibility to the users by providing them with an easy way to define how they want to place the nn.Modules of a pipeline on different devices.

Questions to think about:

  • How should we define the behavior of each acceptable device_map ("balanced", "sequential", "auto", and "balanced_low_0") value on the pipeline-level?
    • As mentioned by @patrickvonplaten the sequential graph defined in each pipeline (example) should influence this definition.
  • Do we want to allow the user to provide a custom device_map. E.g. - device_map = {"unet": 0, "text_encoder": 1, ...}. I think this has the potential to quickly mess up the code. So, I won't prefer this.
yiyixuxu
yiyixuxu1 year ago (edited 1 year ago)

@patrickvonplaten
what happens if not all components can fit in separate devices like you described here? How should we decide which model to split in that case?

it would be much better to move each component to one device - e.g. text_encoder is on device_0, unet in device_1, vae on device_0 again etc... IMO we first should try to do map different components to different devices before splitting one component over multiple devices.

sayakpaul
sayakpaul1 year ago (edited 1 year ago)

Then they are split using the available GPU, CPU, and disk. See some of the comments above and tests added in the PR.

yiyixuxu
yiyixuxu1 year ago (edited 1 year ago)👍 1

@sayakpaul
I think the PR is about how to split a model. But the question is how to support device_map = auto for the pipeline as the next step.

From what I understand, if we were to support device_map the same way it is supported in transformers, we would pretty much always have to split the unet. e.g., for a pipeline with sequential graphtext_encoder -> unet -> vae, currently, we would likely do

                 text_encoder ( device_0) ->  unet (device_0 and device_1) -> vae(device_1) 

Patrick suggests that we should try to see if we can move each model to the different devices first, e.g. we might be able to just do:

                text_encoder ( device_0) ->  unet (device_1) -> vae(device_0) 

This makes sense, and it seems to me it's something we can do inside diffusers. We didn't even have to support _no_split_modules for that. I'm just trying to understand what should be the expected behavior if we can't get away with that and have to split something.

sayakpaul
sayakpaul1 year ago

But the question is how to support device_map = auto for the pipeline as the next step.

You asked about the following in #6396 (comment).

How should we decide which model to split in that case?

Usually, accelerate utilities take care of those things i.e., deciding which model to split in case a user didn't provide anything. This PR already takes care of the single-model splitting behaviour, following what's done in transformers. Hopefully, this also answers:

I'm just trying to understand what should be the expected behavior if we can't get away with that and have to split something.

From what I understand, if we were to support device_map the same way it is supported in transformers, we would pretty much always have to split the unet

But that completely depends on what the user wants and the available GPU memory. As @SunMarc mentioned in #6396 (comment), it also depends on memory allocation and the largest non-splittable layer is very big compared to the whole model.

Coming to:

Patrick suggests that we should try to see if we can move each model to the different devices first, e.g. we might be able to just do:

This makes sense, and it seems to me it's something we can do inside diffusers. We didn't even have to support _no_split_modules for that.

Yeah I am with Patrick and that is why I elaborated the design my level best in #6396 (comment). I am play going this route but I would like to take @SunMarc's inputs here and how to go about this in light of offloading because this proposition concerns offloading too.

Let me know if anything is still unclear.

zhangvia
zhangvia1 year ago

i got a new use case. when support the device_map=auto for pipeline, if we split the single model, we cannot use some pipeline like onnx or tensorrt pipeline,because those pipeline don't have any torch.nn.module. but actually those pipeline can be executed on multi gpu

patrickvonplaten
patrickvonplaten1 year ago

I'd recommend the following logic for device_map="auto" for diffusers:

    1. Retrieve parameter count of all components
    1. Pick the largest component and put in on device 0, pick the second largest component and put in on device 1 , put the third largest component on device 2, ... If the last device is hit, I'd go in reverse again and start with the last device. So for 4 components and device 0 and 1, I'd do:
    • device 0: 1st compontent, 4th component
    • device 1: 2nd component, 3rd component

At the moment, there is no diffusion model really that is larger than 10GB so I don't think there is a need to split individually components over mulitple devices

sayakpaul
sayakpaul1 year ago (edited 1 year ago)

I'd recommend the following logic for device_map="auto" for diffusers:

This is for pipelines yeah?

And the execution of the modules will be lazy, right, much akin to how it's done with offloading? This is because a single module may very well fit in a single device but two modules might not.

@SunMarc a gentle ping here as relly interested to hear your thoughts too.

SunMarc
SunMarc1 year ago (edited 1 year ago)❤ 2

It makes sense to support the logic proposed by Patrick since the models are < 10GB. In transformers, the no_split_modules attribute is not defined for most small model since it can fit in most GPU. However, we can still keep this PR open since it could be useful when bigger models gets added to the library. We would only have to set device_map = "sequential" at the pipeline level to make everything work and get for example:

text_encoder ( device_0) ->  unet (device_0 and device_1) -> vae(device_1) 

Do we want to allow the user to provide a custom device_map. E.g. - device_map = {"unet": 0, "text_encoder": 1, ...}. I think this has the potential to quickly mess up the code. So, I won't prefer this.

I also think that we should not let the user provide a custom device_map since the device_map was designed at the level of the model, not at the pipeline level. It will be confusing.

As for the logic for device_map="auto" at the pipeline-level:

  1. Retrieve parameter count of all components
  1. Pick the largest component and put in on device 0, pick the second largest component and put in on device 1 , put the third largest component on device 2, ... If the last device is hit, I'd go in reverse again and start with the last device.
  • To do that, we need to load all component on meta device and calculate their size. Then, using get_max_memory to get the memory of each device, we can allocate each component. We can do a sequential allocation to limit data transfert between gpus (no need to load all models on meta). Or we can try to maximise gpu allocation with the option proposed by Patrick to or any other methods.
  • Concerning the offload, we can't offload the whole component to the gpu since there is no space(enable_model_cpu_offload case). We need to use cpu_offload instead (used in enable_sequential_cpu_offload).
  • Since each model are loaded on only one device, we don't add hooks by default. You need to set force_hook=True in load_checkpoint_and_dispatch. By doing that, we will add hooks that will move the data to the correct device when performing inference.
github-actions
github-actions1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions github-actions added stale
sayakpaul
sayakpaul1 year ago

Not stale. This PR serves as a valuable reference for models that would need splitting.

yiyixuxu yiyixuxu removed stale
github-actions
github-actions1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions github-actions added stale
sayakpaul sayakpaul removed stale
pcuenca pcuenca added wip
sayakpaul merge main and resolve conflicts.
968e857c
sayakpaul
sayakpaul1 year ago👀 1

@yiyixuxu I think this PR is ready for a review now.

To test its somewhat extremes, I did the following:

Created a ~8B Transformer variant:

import torch 
from accelerate import init_empty_weights


with init_empty_weights():
    pixart_transformer = Transformer2DModel.from_config("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer")

actual_bigger_transformer = Transformer2DModel.from_config(
    pixart_transformer.config, num_layers=72, num_attention_heads=36, cross_attention_dim=2592
)
actual_bigger_transformer.save_pretrained("/raid/.cache/actual_bigger_transformer")

Has about 7.8B parameters and takes ~29.135GB to store.

I then ran the model like so:

from diffusers import Transformer2DModel
import tempfile
import torch
import os

def get_inputs():
    sample = torch.randn(1, 4, 128, 128)
    timestep = torch.randint(0, 1000, size=(1, ))
    encoder_hidden_states = torch.randn(1, 120, 4096)

    resolution = torch.tensor([1024, 1024]).repeat(1, 1)
    aspect_ratio = torch.tensor([1.]).repeat(1, 1)
    added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
    return sample, timestep, encoder_hidden_states, added_cond_kwargs

with torch.no_grad():
    max_memory = {0: "15GB"} # reasonable estimate for a consumer-gpu.
    with tempfile.TemporaryDirectory() as tmp_dir:
        new_model = Transformer2DModel.from_pretrained(
            "/raid/.cache/actual_bigger_transformer", 
            device_map="auto",
            max_memory=max_memory, 
            offload_folder=os.path.join("/raid/.cache/huggingface", tmp_dir)
        )

        sample, timestep, encoder_hidden_states, added_cond_kwargs = get_inputs()
        out = new_model(
            hidden_states=sample,
            encoder_hidden_states=encoder_hidden_states,
            timestep=timestep, 
            added_cond_kwargs=added_cond_kwargs
        ).sample
        print(out.shape)

It successfully runs.

Happy to add it to the test suite if you want.

Would be quite nice to also support:

  • Sharding of big checkpoints
  • Loading of shared checkpoints

Since we will have bigger models, I think it makes sense to at least support these because downloading a single big checkpoint is quite messy. Both of the above can be easily done with accelerate.

Ccing @SunMarc for visibility and awareness. If you have any comments, feel free to do so.

sayakpaul Merge branch 'main' into feat/device-map-auto
2fd13938
sayakpaul Merge branch 'main' into feat/device-map-auto
99585706
sayakpaul Merge branch 'main' into feat/device-map-auto
4761b5cf
sayakpaul
sayakpaul1 year ago

@yiyixuxu a gentle ping here.

SunMarc
SunMarc approved these changes on 2024-04-29
SunMarc1 year ago (edited 1 year ago)

LGTM ! Thanks for testing this PR on a real example. As for a sharding, it makes sense to support it but maybe in a follow up PR. For the sharding, you can take inspiration from transformers save_pretrained function or save_model from accelerate where I tried to mimic save_pretrained but removing specific code to transformers. I think it will be better to add your own logic for flexibility. As for the loading, sharded checkpoint is supported in load_checkpoint_in_model.

yiyixuxu
yiyixuxu1 year ago

so we only support this on the model level, not the pipeline level, right?

sayakpaul Merge branch 'main' into feat/device-map-auto
2e79e566
sayakpaul
sayakpaul1 year ago👍 2

Yes. This is because pipeline-level device mapping strategy and model-level device mapping strategy are conceptually different from one another.

Once this feature is in, I will work on the pipeline-level device-mapping strategy to facilitate the change.

yiyixuxu
yiyixuxu approved these changes on 2024-04-30
sayakpaul Merge branch 'main' into feat/device-map-auto
b870eb73
sayakpaul
sayakpaul1 year ago

I am going to merge this without docs because in a follow-up PR, I am going to add support for serializing shared checkpoints and tests to ensure we can load them.

sayakpaul sayakpaul merged 3fd31eef into main 1 year ago
sayakpaul sayakpaul deleted the feat/device-map-auto branch 1 year ago

Login to write a write a comment.

Login via GitHub

Assignees
No one assigned
Labels
Milestone