diffusers
[Feature] Finetune text encoder in train_text_to_image_lora
#3912
Closed

okotaku wants to merge 15 commits into huggingface:main from okotaku:feature/ft_textencoder
okotaku
okotaku1 year ago (edited 1 year ago)

What does this PR do?

Fixes #3418

Adds a --train_text_encoder argument to train_text_to_image_lora.py so that the text encoder can be fine-tuned as well.
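A minimal invocation sketch (the environment variables are placeholders; the full command I used for the comparison below is posted later in this thread):

accelerate launch examples/text_to_image/train_text_to_image_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --output_dir=$OUTPUT_DIR \
  --train_text_encoder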

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

okotaku ft text encoder
8e6553e5
HuggingFaceDocBuilderDev
HuggingFaceDocBuilderDev1 year ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

patrickvonplaten
patrickvonplaten1 year ago

cc @sayakpaul here

sayakpaul
sayakpaul1 year ago

Thanks for adding this feature. Could you also share some comparative results with and without text encoder LoRA fine-tuning? Just trying to gauge the effectiveness.

okotaku
okotaku1 year ago

@sayakpaul
I used this dataset for fine-tuning.

UNet only

test prompt = '1girl, X X'
tmp9

test prompt = '1girl, >_<'
tmp8

test prompt = '1girl, @_@'
tmp11

test prompt = '1girl, =_='
tmp12

With text encoder

test prompt = '1girl, X X'
tmp

test prompt = '1girl, >_<'
tmp2

test prompt = '1girl, @_@'
tmp3

test prompt = '1girl, =_='
tmp4

test prompt = '1girl, Jitome'
tmp5

test prompt = '1girl, :I'
tmp6

test prompt = '1girl, ._.'
tmp7

sayakpaul
sayakpaul1 year ago

Thanks for your hard work!

Maybe let's wait for #3778 to get merged, as there are some refactoring-related changes that will make this simpler.

Is that okay with you?

okotaku
okotaku1 year ago

@sayakpaul No problem.

patrickvonplaten
patrickvonplaten1 year ago

Gentle ping @sayakpaul - do you think we could merge this?

okotaku
okotaku1 year ago

@patrickvonplaten Should we update the code here based on that PR?

sayakpaul
sayakpaul1 year ago👍 1

@okotaku my apologies for forgetting to ping you here.

Yes, let's refactor this PR based on the changes introduced in #3778. Happy to help you out in any way :)

okotaku Merge remote-tracking branch 'origin' into feature/ft_textencoder
3224665c
okotaku merge refactor
3b47815e
okotaku merge refactor
1f690805
okotaku
okotaku1 year ago

@sayakpaul I updated the code, but it now raises the following error.
It looks like this issue.
Do you have any ideas on how to solve it?

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/ubuntu/Desktop/cartoon-name-generation/tools/train_text_to_image_lora_new.py:1054 in       │
│ <module>                                                                                         │
│                                                                                                  │
│   1051                                                                                           │
│   1052                                                                                           │
│   1053 if __name__ == "__main__":                                                                │
│ ❱ 1054 │   main()                                                                                │
│   1055                                                                                           │
│                                                                                                  │
│ /home/ubuntu/Desktop/cartoon-name-generation/tools/train_text_to_image_lora_new.py:961 in main   │
│                                                                                                  │
│    958 │   │   │   │   images = []                                                               │
│    959 │   │   │   │   for _ in range(args.num_validation_images):                               │
│    960 │   │   │   │   │   #with torch.cuda.amp.autocast(enabled=False): #dtype=torch.float32):  │
│ ❱  961 │   │   │   │   │   image = pipeline(args.validation_prompt, num_inference_steps=30, gen  │
│    962 │   │   │   │   │   images.append(image)                                                  │
│    963 │   │   │   │                                                                             │
│    964 │   │   │   │   for tracker in accelerator.trackers:                                      │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27 in                │
│ decorate_context                                                                                 │
│                                                                                                  │
│    24 │   │   @functools.wraps(func)                                                             │
│    25 │   │   def decorate_context(*args, **kwargs):                                             │
│    26 │   │   │   with self.clone():                                                             │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                                               │
│    28 │   │   return cast(F, decorate_context)                                                   │
│    29 │                                                                                          │
│    30 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_st │
│ able_diffusion.py:728 in __call__                                                                │
│                                                                                                  │
│   725 │   │   │   │   latent_model_input = self.scheduler.scale_model_input(latent_model_input   │
│   726 │   │   │   │                                                                              │
│   727 │   │   │   │   # predict the noise residual                                               │
│ ❱ 728 │   │   │   │   noise_pred = self.unet(                                                    │
│   729 │   │   │   │   │   latent_model_input,                                                    │
│   730 │   │   │   │   │   t,                                                                     │
│   731 │   │   │   │   │   encoder_hidden_states=prompt_embeds,                                   │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl    │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/accelerate/utils/operations.py:521 in forward    │
│                                                                                                  │
│   518 │   model_forward = ConvertOutputsToFp32(model_forward)                                    │
│   519 │                                                                                          │
│   520 │   def forward(*args, **kwargs):                                                          │
│ ❱ 521 │   │   return model_forward(*args, **kwargs)                                              │
│   522 │                                                                                          │
│   523 │   # To act like a decorator so that it can be popped when doing `extract_model_from_pa   │
│   524 │   forward.__wrapped__ = model_forward                                                    │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/accelerate/utils/operations.py:509 in __call__   │
│                                                                                                  │
│   506 │   │   update_wrapper(self, model_forward)                                                │
│   507 │                                                                                          │
│   508 │   def __call__(self, *args, **kwargs):                                                   │
│ ❱ 509 │   │   return convert_to_fp32(self.model_forward(*args, **kwargs))                        │
│   510 │                                                                                          │
│   511 │   def __getstate__(self):                                                                │
│   512 │   │   raise pickle.PicklingError(                                                        │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/amp/autocast_mode.py:14 in                 │
│ decorate_autocast                                                                                │
│                                                                                                  │
│    11 │   @functools.wraps(func)                                                                 │
│    12 │   def decorate_autocast(*args, **kwargs):                                                │
│    13 │   │   with autocast_instance:                                                            │
│ ❱  14 │   │   │   return func(*args, **kwargs)                                                   │
│    15 │   decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in    │
│    16 │   return decorate_autocast                                                               │
│    17                                                                                            │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py:905 in     │
│ forward                                                                                          │
│                                                                                                  │
│   902 │   │   down_block_res_samples = (sample,)                                                 │
│   903 │   │   for downsample_block in self.down_blocks:                                          │
│   904 │   │   │   if hasattr(downsample_block, "has_cross_attention") and downsample_block.has   │
│ ❱ 905 │   │   │   │   sample, res_samples = downsample_block(                                    │
│   906 │   │   │   │   │   hidden_states=sample,                                                  │
│   907 │   │   │   │   │   temb=emb,                                                              │
│   908 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                           │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl    │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py:993 in        │
│ forward                                                                                          │
│                                                                                                  │
│    990 │   │   │   │   )[0]                                                                      │
│    991 │   │   │   else:                                                                         │
│    992 │   │   │   │   hidden_states = resnet(hidden_states, temb)                               │
│ ❱  993 │   │   │   │   hidden_states = attn(                                                     │
│    994 │   │   │   │   │   hidden_states,                                                        │
│    995 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                          │
│    996 │   │   │   │   │   cross_attention_kwargs=cross_attention_kwargs,                        │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl    │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/transformer_2d.py:291 in        │
│ forward                                                                                          │
│                                                                                                  │
│   288 │   │                                                                                      │
│   289 │   │   # 2. Blocks                                                                        │
│   290 │   │   for block in self.transformer_blocks:                                              │
│ ❱ 291 │   │   │   hidden_states = block(                                                         │
│   292 │   │   │   │   hidden_states,                                                             │
│   293 │   │   │   │   attention_mask=attention_mask,                                             │
│   294 │   │   │   │   encoder_hidden_states=encoder_hidden_states,                               │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl    │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/attention.py:170 in forward     │
│                                                                                                  │
│   167 │   │   │   │   self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self   │
│   168 │   │   │   )                                                                              │
│   169 │   │   │                                                                                  │
│ ❱ 170 │   │   │   attn_output = self.attn2(                                                      │
│   171 │   │   │   │   norm_hidden_states,                                                        │
│   172 │   │   │   │   encoder_hidden_states=encoder_hidden_states,                               │
│   173 │   │   │   │   attention_mask=encoder_attention_mask,                                     │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl    │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/attention_processor.py:321 in   │
│ forward                                                                                          │
│                                                                                                  │
│    318 │   │   # The `Attention` class can call different attention processors / attention func  │
│    319 │   │   # here we simply pass along all tensors to the selected processor class           │
│    320 │   │   # For standard processors that are defined here, `**cross_attention_kwargs` is e  │
│ ❱  321 │   │   return self.processor(                                                            │
│    322 │   │   │   self,                                                                         │
│    323 │   │   │   hidden_states,                                                                │
│    324 │   │   │   encoder_hidden_states=encoder_hidden_states,                                  │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/attention_processor.py:1224 in  │
│ __call__                                                                                         │
│                                                                                                  │
│   1221 │   │   key = attn.head_to_batch_dim(key).contiguous()                                    │
│   1222 │   │   value = attn.head_to_batch_dim(value).contiguous()                                │
│   1223 │   │                                                                                     │
│ ❱ 1224 │   │   hidden_states = xformers.ops.memory_efficient_attention(                          │
│   1225 │   │   │   query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=att  │
│   1226 │   │   )                                                                                 │
│   1227 │   │   hidden_states = attn.batch_to_head_dim(hidden_states)                             │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py:192 in             │
│ memory_efficient_attention                                                                       │
│                                                                                                  │
│   189 │   │   and options.                                                                       │
│   190 │   :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``                     │
│   191 │   """                                                                                    │
│ ❱ 192 │   return _memory_efficient_attention(                                                    │
│   193 │   │   Inputs(                                                                            │
│   194 │   │   │   query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale       │
│   195 │   │   ),                                                                                 │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py:290 in             │
│ _memory_efficient_attention                                                                      │
│                                                                                                  │
│   287 ) -> torch.Tensor:                                                                         │
│   288 │   # fast-path that doesn't require computing the logsumexp for backward computation      │
│   289 │   if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):             │
│ ❱ 290 │   │   return _memory_efficient_attention_forward(                                        │
│   291 │   │   │   inp, op=op[0] if op is not None else None                                      │
│   292 │   │   )                                                                                  │
│   293                                                                                            │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py:303 in             │
│ _memory_efficient_attention_forward                                                              │
│                                                                                                  │
│   300 def _memory_efficient_attention_forward(                                                   │
│   301 │   inp: Inputs, op: Optional[Type[AttentionFwOpBase]]                                     │
│   302 ) -> torch.Tensor:                                                                         │
│ ❱ 303 │   inp.validate_inputs()                                                                  │
│   304 │   output_shape = inp.normalize_bmhk()                                                    │
│   305 │   if op is None:                                                                         │
│   306 │   │   op = _dispatch_fw(inp)                                                             │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/xformers/ops/fmha/common.py:73 in                │
│ validate_inputs                                                                                  │
│                                                                                                  │
│    70 │   │   if any(x.device != self.query.device for x in qkv):                                │
│    71 │   │   │   raise ValueError("Query/Key/Value should all be on the same device")           │
│    72 │   │   if any(x.dtype != self.query.dtype for x in qkv):                                  │
│ ❱  73 │   │   │   raise ValueError(                                                              │
│    74 │   │   │   │   "Query/Key/Value should all have the same dtype\n"                         │
│    75 │   │   │   │   f"  query.dtype: {self.query.dtype}\n"                                     │
│    76 │   │   │   │   f"  key.dtype  : {self.key.dtype}\n"                                       │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Query/Key/Value should all have the same dtype
  query.dtype: torch.float32
  key.dtype  : torch.float16
  value.dtype: torch.float16
sayakpaul
sayakpaul1 year ago

Hi,

Could you also try out the solutions provided in that thread to see if the errors persist?

Also, what happens if we do this in PyTorch 2.0, taking advantage of SDPA and disabling xformers?

okotaku
okotaku1 year ago

When disabling xformers, I get:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/ubuntu/Desktop/cartoon-name-generation/tools/train_text_to_image_lora_new.py:1054 in       │
│ <module>                                                                                         │
│                                                                                                  │
│   1051                                                                                           │
│   1052                                                                                           │
│   1053 if __name__ == "__main__":                                                                │
│ ❱ 1054 │   main()                                                                                │
│   1055                                                                                           │
│                                                                                                  │
│ /home/ubuntu/Desktop/cartoon-name-generation/tools/train_text_to_image_lora_new.py:961 in main   │
│                                                                                                  │
│    958 │   │   │   │   images = []                                                               │
│    959 │   │   │   │   for _ in range(args.num_validation_images):                               │
│    960 │   │   │   │   │   #with torch.cuda.amp.autocast(enabled=False): #dtype=torch.float32):  │
│ ❱  961 │   │   │   │   │   image = pipeline(args.validation_prompt, num_inference_steps=30, gen  │
│    962 │   │   │   │   │   images.append(image)                                                  │
│    963 │   │   │   │                                                                             │
│    964 │   │   │   │   for tracker in accelerator.trackers:                                      │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27 in                │
│ decorate_context                                                                                 │
│                                                                                                  │
│    24 │   │   @functools.wraps(func)                                                             │
│    25 │   │   def decorate_context(*args, **kwargs):                                             │
│    26 │   │   │   with self.clone():                                                             │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                                               │
│    28 │   │   return cast(F, decorate_context)                                                   │
│    29 │                                                                                          │
│    30 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_st │
│ able_diffusion.py:755 in __call__                                                                │
│                                                                                                  │
│   752 │   │   │   │   │   │   callback(i, t, latents)                                            │
│   753 │   │                                                                                      │
│   754 │   │   if not output_type == "latent":                                                    │
│ ❱ 755 │   │   │   image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dic   │
│   756 │   │   │   image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embe   │
│   757 │   │   else:                                                                              │
│   758 │   │   │   image = latents                                                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/utils/accelerate_utils.py:46 in        │
│ wrapper                                                                                          │
│                                                                                                  │
│   43 │   def wrapper(self, *args, **kwargs):                                                     │
│   44 │   │   if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):             │
│   45 │   │   │   self._hf_hook.pre_forward(self)                                                 │
│ ❱ 46 │   │   return method(self, *args, **kwargs)                                                │
│   47 │                                                                                           │
│   48 │   return wrapper                                                                          │
│   49                                                                                             │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py:264 in decode │
│                                                                                                  │
│   261 │   │   │   decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]      │
│   262 │   │   │   decoded = torch.cat(decoded_slices)                                            │
│   263 │   │   else:                                                                              │
│ ❱ 264 │   │   │   decoded = self._decode(z).sample                                               │
│   265 │   │                                                                                      │
│   266 │   │   if not return_dict:                                                                │
│   267 │   │   │   return (decoded,)                                                              │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py:250 in        │
│ _decode                                                                                          │
│                                                                                                  │
│   247 │   │   if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] >   │
│   248 │   │   │   return self.tiled_decode(z, return_dict=return_dict)                           │
│   249 │   │                                                                                      │
│ ❱ 250 │   │   z = self.post_quant_conv(z)                                                        │
│   251 │   │   dec = self.decoder(z)                                                              │
│   252 │   │                                                                                      │
│   253 │   │   if not return_dict:                                                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl    │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:463 in forward          │
│                                                                                                  │
│    460 │   │   │   │   │   │   self.padding, self.dilation, self.groups)                         │
│    461 │                                                                                         │
│    462 │   def forward(self, input: Tensor) -> Tensor:                                           │
│ ❱  463 │   │   return self._conv_forward(input, self.weight, self.bias)                          │
│    464                                                                                           │
│    465 class Conv3d(_ConvNd):                                                                    │
│    466 │   __doc__ = r"""Applies a 3D convolution over an input signal composed of several inpu  │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:459 in _conv_forward    │
│                                                                                                  │
│    456 │   │   │   return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=sel  │
│    457 │   │   │   │   │   │   │   weight, bias, self.stride,                                    │
│    458 │   │   │   │   │   │   │   _pair(0), self.dilation, self.groups)                         │
│ ❱  459 │   │   return F.conv2d(input, weight, bias, self.stride,                                 │
│    460 │   │   │   │   │   │   self.padding, self.dilation, self.groups)                         │
│    461 │                                                                                         │
│    462 │   def forward(self, input: Tensor) -> Tensor:                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Input type (float) and bias type (c10::Half) should be the same
okotaku
okotaku1 year ago

I tried the following patterns, but they all failed.

images = []
for _ in range(args.num_validation_images):
    with torch.cuda.amp.autocast():
        image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        images.append(image)

images = []
for _ in range(args.num_validation_images):
    with torch.autocast('cuda'):
        image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        images.append(image)

images = []
for _ in range(args.num_validation_images):
    with torch.cuda.amp.autocast(dtype=weight_dtype):
        image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        images.append(image)

images = []
for _ in range(args.num_validation_images):
    image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
    images.append(image)
okotaku merge refactor
fc2b7ec7
okotaku merge refactor
80e54c83
okotaku
okotaku1 year ago

@sayakpaul PyTorch 2.0 with LoRAAttnProcessor2_0 works well.
This PR is ready for review.
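For context, a minimal sketch of how the SDPA-based LoRA processor can be selected on PyTorch 2.0 (class names are from diffusers.models.attention_processor at the time of this PR; treat the exact wiring in the script as an assumption):

import torch
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

# prefer the SDPA-based processor when scaled_dot_product_attention exists (PyTorch >= 2.0)
lora_attn_processor_class = (
    LoRAAttnProcessor2_0
    if hasattr(torch.nn.functional, "scaled_dot_product_attention")
    else LoRAAttnProcessor
)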

sayakpaul
sayakpaul1 year ago

Thanks for letting me know and for your efforts.

Could we maybe dive a bit deeper to see what the issue is? For example, does the issue persist when not training the text encoder?

okotaku
okotaku1 year ago

For example, does the issue persist when not training the text encoder?

It also persists when training only the UNet.

sayakpaul
sayakpaul1 year ago

Okay. Can we verify if the version in the main branch works with the following settings?

  • PyTorch 1.13.1 and xformers
  • PyTorch 2.0 and SDPA

Just trying to narrow down the issue. If you have other ideas to try out, please let me know.

okotaku
okotaku1 year ago

PyTorch 2.0 with SDPA works well.
PyTorch 1.13.1 with xformers fails.

sayakpaul
sayakpaul1 year ago

Thanks for reporting.

This is indeed weird as the main version should support both. Did you try it out on the main version of the script or on your version?

okotaku
okotaku1 year ago

The error is raised when I use my branch of the script. The script on the main branch works fine.

sayakpaul
sayakpaul1 year ago

Oh okay. Let's try to investigate the differences then :-)

okotaku fix xformers place
8905a492
okotaku fix text encoder rule
1350bd8c
okotaku
okotaku1 year ago

The error went away after changing where xformers is applied.
I don't know the reason :)
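Roughly, the placement that works looks like this (a simplified sketch paraphrasing the training script; the exact hunk appears in the review further down this thread):

# cast the frozen weights and move them to the device first
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

# enable xformers on the UNet here, before the LoRA attention processors are attached
if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")

# ... the LoRA attention processors are added to the UNet after this point ...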

okotaku
okotaku1 year ago👀 1

@sayakpaul This PR is ready.

sayakpaul
sayakpaul commented on 2023-07-17
examples/text_to_image/train_text_to_image_lora.py
if is_xformers_available():
    import xformers

    xformers_version = version.parse(xformers.__version__)
    if xformers_version == version.parse("0.0.16"):
        logger.warn(
            "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
        )
    unet.enable_xformers_memory_efficient_attention()
sayakpaul1 year ago

❤️

sayakpaul
sayakpaul commented on 2023-07-17
sayakpaul1 year ago

Excellent work here. Thanks so much for all your experiments and iterations!

Let's also update the README of the example so that readers are aware that text encoder training is supported. We also need to add a test case for the text encoder training to test_examples.py. Then this PR is good to be shipped 🚀
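As a rough illustration, such a test could follow the pattern of the existing LoRA tests in examples/test_examples.py (this is only a sketch, not the test added by this PR; helper names like run_command and self._launch_args, and the tiny dataset id, are assumptions based on that file):

def test_text_to_image_lora_with_text_encoder(self):
    pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"

    with tempfile.TemporaryDirectory() as tmpdir:
        test_args = f"""
            examples/text_to_image/train_text_to_image_lora.py
            --pretrained_model_name_or_path {pretrained_model_name_or_path}
            --dataset_name hf-internal-testing/dummy_image_text_data
            --resolution 64
            --train_batch_size 1
            --gradient_accumulation_steps 1
            --max_train_steps 2
            --learning_rate 5.0e-04
            --train_text_encoder
            --output_dir {tmpdir}
            """.split()

        # launch the example script through accelerate, as the other example tests do
        run_command(self._launch_args + test_args)

        # the saved checkpoint should contain the UNet and text encoder LoRA weights
        self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))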

Also, could we do a full run with the changes to ensure the results you're getting are as expected?

okotaku update docs
760ad207
okotaku
okotaku1 year ago (edited 1 year ago)

@sayakpaul

Also, could we do a full run with the changes to ensure the results you're getting are as expected?

https://civitai.com/models/25613/classic-anime-expressions

You can download the images from this page's Training Images tab.
The CSV file is here:

metadata.csv

The parameters are as follows:

accelerate launch tools/train_text_to_image_lora.py \
  --pretrained_model_name_or_path=models/anythingv5  \
  --train_data_dir=data/ExpressionTraining \
  --image_column=image \
  --center_crop --random_flip \
  --output_dir=work_dirs/ExpressionTraining \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --num_train_epochs=100 \
  --learning_rate=1e-4 \
  --lr_scheduler="cosine_with_restarts" --lr_warmup_steps=150 \
  --validation_prompt="1girl, >_<" \
  --checkpointing_steps=500 \
  --rank=128 \
  --snr_gamma=5 \
  --train_text_encoder \
  --seed="0"
okotaku update test
c274dc7d
sayakpaul
sayakpaul commented on 2023-07-18
docs/source/en/training/lora.mdx
  --seed=1337
```

## Finetuning the text encoder and UNet
sayakpaul1 year ago
Suggested change:
- ## Finetuning the text encoder and UNet
+ ### Finetuning the text encoder and UNet
sayakpaul
sayakpaul commented on 2023-07-18
docs/source/en/training/lora.mdx
weights. For example:

```python
from huggingface_hub.repocard import RepoCard
from diffusers import StableDiffusionPipeline
import torch

lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.load_lora_weights(lora_model_id)
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```
sayakpaul1 year ago

Let's use a LoRA you trained with this script? We'd also need to update the example prompt :)

sayakpaul
sayakpaul commented on 2023-07-18
examples/test_examples.py
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"

with tempfile.TemporaryDirectory() as tmpdir:
    # Run training script with checkpointing
    # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
    # Should create checkpoints at steps 2, 4, 6
    # with checkpoint at step 2 deleted
sayakpaul1 year ago

I think we can safely ignore this, no? We're not checking for checkpointing here.

sayakpaul1 year ago
Suggested change:
- # Run training script with checkpointing
- # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
- # Should create checkpoints at steps 2, 4, 6
- # with checkpoint at step 2 deleted
sayakpaul
sayakpaul approved these changes on 2023-07-18
sayakpaul1 year ago

Looks amazing. Thanks for iterating.

Additionally, I'd also include a note about the --train_text_encoder flag in the README here.

patrickvonplaten requested a review from williamberman 1 year ago
patrickvonplaten
patrickvonplaten1 year ago

@williamberman could you also take a look here?

okotaku Update docs/source/en/training/lora.mdx
00cff152
okotaku Update examples/test_examples.py
c0e80b0c
okotaku format
3e9d8696
okotaku update docs
caf921b5
williamberman
williamberman commented on 2023-07-21
examples/text_to_image/train_text_to_image_lora.py
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            logger.warn(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")
williamberman1 year ago

Unless I'm missing something, in the future it'd be helpful not to move code blocks unrelated to the PR, as it makes the diff harder to read :)

okotaku1 year ago

#3912 (comment)
#3912 (comment)

By moving the code block, I avoid this error.

williamberman
williamberman commented on 2023-07-21
examples/text_to_image/train_text_to_image_lora.py
  generator = generator.manual_seed(args.seed)
  images = []
  for _ in range(args.num_validation_images):
-     images.append(
-         pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
-     )
+     with torch.cuda.amp.autocast():
williamberman1 year ago

Did something change that required autocast to be added?

okotaku1 year ago

#3912 (comment)

It's also related to this error. You can check this thread.
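In other words, the validation loop ends up roughly like this (a sketch based on the hunk above; the append presumably moves inside the autocast context):

images = []
for _ in range(args.num_validation_images):
    with torch.cuda.amp.autocast():
        images.append(
            pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        )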

williamberman
williamberman commented on 2023-07-21
examples/text_to_image/train_text_to_image_lora.py
  # load attention processors
- pipeline.unet.load_attn_procs(args.output_dir)
+ pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")
williamberman1 year ago

Do we have to specify weight_name?

okotaku1 year ago

In my environment, if weight_name is not specified, the following error is raised:

│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/loaders.py:828 in load_lora_weights    │
│                                                                                                  │
│    825 │   │   │   kwargs:                                                                       │
│    826 │   │   │   │   See [`~loaders.LoraLoaderMixin.lora_state_dict`].                         │
│    827 │   │   """                                                                               │
│ ❱  828 │   │   state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_o  │
│    829 │   │   self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet  │
│    830 │   │   self.load_lora_into_text_encoder(                                                 │
│    831 │   │   │   state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lor  │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/loaders.py:940 in lora_state_dict      │
│                                                                                                  │
│    937 │   │   │   │   │   │   user_agent=user_agent,                                            │
│    938 │   │   │   │   │   )                                                                     │
│    939 │   │   │   │   │   state_dict = safetensors.torch.load_file(model_file, device="cpu")    │
│ ❱  940 │   │   │   │   except (IOError, safetensors.SafetensorError) as e:                       │
│    941 │   │   │   │   │   if not allow_pickle:                                                  │
│    942 │   │   │   │   │   │   raise e                                                           │
│    943 │   │   │   │   │   # try loading non-safetensors weights                                 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: module 'safetensors' has no attribute 'SafetensorError'
sayakpaul1 year ago

We don't actually need to specify this. Could you upgrade your local installation of safetensors?

okotaku1 year ago

It works. Thank you! I fixed it.
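With safetensors upgraded, the explicit weight name can presumably be dropped, i.e. the loading call becomes simply:

pipeline.load_lora_weights(args.output_dir)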

williamberman
williamberman1 year ago

Looks basically good, a few small questions :)

okotaku del weight_name
74b07495
sayakpaul requested a review from williamberman 1 year ago
github-actions
github-actions1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions added the stale label
github-actions closed this 1 year ago
