The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
cc @sayakpaul here
Thanks for adding this pipeline. Could you also share some comparative results with and without text encoder LoRA fine-tuning? Just trying to gauge the effectiveness.
Thanks for your hard work!
Maybe let's wait for #3778 to get merged, as there are some refactoring-related changes that will make this simpler.
Is that okay with you?
@sayakpaul No problem.
Gentle ping @sayakpaul - do you think we could merge this?
@patrickvonplaten Should we update the code based on this PR?
@sayakpaul I updated the code, but it caused the following errors.
It looks like this issue.
Do you have any ideas on how to solve it?
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/ubuntu/Desktop/cartoon-name-generation/tools/train_text_to_image_lora_new.py:1054 in │
│ <module> │
│ │
│ 1051 │
│ 1052 │
│ 1053 if __name__ == "__main__": │
│ ❱ 1054 │ main() │
│ 1055 │
│ │
│ /home/ubuntu/Desktop/cartoon-name-generation/tools/train_text_to_image_lora_new.py:961 in main │
│ │
│ 958 │ │ │ │ images = [] │
│ 959 │ │ │ │ for _ in range(args.num_validation_images): │
│ 960 │ │ │ │ │ #with torch.cuda.amp.autocast(enabled=False): #dtype=torch.float32): │
│ ❱ 961 │ │ │ │ │ image = pipeline(args.validation_prompt, num_inference_steps=30, gen │
│ 962 │ │ │ │ │ images.append(image) │
│ 963 │ │ │ │ │
│ 964 │ │ │ │ for tracker in accelerator.trackers: │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27 in │
│ decorate_context │
│ │
│ 24 │ │ @functools.wraps(func) │
│ 25 │ │ def decorate_context(*args, **kwargs): │
│ 26 │ │ │ with self.clone(): │
│ ❱ 27 │ │ │ │ return func(*args, **kwargs) │
│ 28 │ │ return cast(F, decorate_context) │
│ 29 │ │
│ 30 │ def _wrap_generator(self, func): │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_st │
│ able_diffusion.py:728 in __call__ │
│ │
│ 725 │ │ │ │ latent_model_input = self.scheduler.scale_model_input(latent_model_input │
│ 726 │ │ │ │ │
│ 727 │ │ │ │ # predict the noise residual │
│ ❱ 728 │ │ │ │ noise_pred = self.unet( │
│ 729 │ │ │ │ │ latent_model_input, │
│ 730 │ │ │ │ │ t, │
│ 731 │ │ │ │ │ encoder_hidden_states=prompt_embeds, │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl │
│ │
│ 1191 │ │ # this function, and just call forward. │
│ 1192 │ │ if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o │
│ 1193 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1194 │ │ │ return forward_call(*input, **kwargs) │
│ 1195 │ │ # Do not call functions when jit is used │
│ 1196 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1197 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/accelerate/utils/operations.py:521 in forward │
│ │
│ 518 │ model_forward = ConvertOutputsToFp32(model_forward) │
│ 519 │ │
│ 520 │ def forward(*args, **kwargs): │
│ ❱ 521 │ │ return model_forward(*args, **kwargs) │
│ 522 │ │
│ 523 │ # To act like a decorator so that it can be popped when doing `extract_model_from_pa │
│ 524 │ forward.__wrapped__ = model_forward │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/accelerate/utils/operations.py:509 in __call__ │
│ │
│ 506 │ │ update_wrapper(self, model_forward) │
│ 507 │ │
│ 508 │ def __call__(self, *args, **kwargs): │
│ ❱ 509 │ │ return convert_to_fp32(self.model_forward(*args, **kwargs)) │
│ 510 │ │
│ 511 │ def __getstate__(self): │
│ 512 │ │ raise pickle.PicklingError( │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/amp/autocast_mode.py:14 in │
│ decorate_autocast │
│ │
│ 11 │ @functools.wraps(func) │
│ 12 │ def decorate_autocast(*args, **kwargs): │
│ 13 │ │ with autocast_instance: │
│ ❱ 14 │ │ │ return func(*args, **kwargs) │
│ 15 │ decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in │
│ 16 │ return decorate_autocast │
│ 17 │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py:905 in │
│ forward │
│ │
│ 902 │ │ down_block_res_samples = (sample,) │
│ 903 │ │ for downsample_block in self.down_blocks: │
│ 904 │ │ │ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has │
│ ❱ 905 │ │ │ │ sample, res_samples = downsample_block( │
│ 906 │ │ │ │ │ hidden_states=sample, │
│ 907 │ │ │ │ │ temb=emb, │
│ 908 │ │ │ │ │ encoder_hidden_states=encoder_hidden_states, │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl │
│ │
│ 1191 │ │ # this function, and just call forward. │
│ 1192 │ │ if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o │
│ 1193 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1194 │ │ │ return forward_call(*input, **kwargs) │
│ 1195 │ │ # Do not call functions when jit is used │
│ 1196 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1197 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py:993 in │
│ forward │
│ │
│ 990 │ │ │ │ )[0] │
│ 991 │ │ │ else: │
│ 992 │ │ │ │ hidden_states = resnet(hidden_states, temb) │
│ ❱ 993 │ │ │ │ hidden_states = attn( │
│ 994 │ │ │ │ │ hidden_states, │
│ 995 │ │ │ │ │ encoder_hidden_states=encoder_hidden_states, │
│ 996 │ │ │ │ │ cross_attention_kwargs=cross_attention_kwargs, │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl │
│ │
│ 1191 │ │ # this function, and just call forward. │
│ 1192 │ │ if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o │
│ 1193 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1194 │ │ │ return forward_call(*input, **kwargs) │
│ 1195 │ │ # Do not call functions when jit is used │
│ 1196 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1197 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/transformer_2d.py:291 in │
│ forward │
│ │
│ 288 │ │ │
│ 289 │ │ # 2. Blocks │
│ 290 │ │ for block in self.transformer_blocks: │
│ ❱ 291 │ │ │ hidden_states = block( │
│ 292 │ │ │ │ hidden_states, │
│ 293 │ │ │ │ attention_mask=attention_mask, │
│ 294 │ │ │ │ encoder_hidden_states=encoder_hidden_states, │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl │
│ │
│ 1191 │ │ # this function, and just call forward. │
│ 1192 │ │ if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o │
│ 1193 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1194 │ │ │ return forward_call(*input, **kwargs) │
│ 1195 │ │ # Do not call functions when jit is used │
│ 1196 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1197 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/attention.py:170 in forward │
│ │
│ 167 │ │ │ │ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self │
│ 168 │ │ │ ) │
│ 169 │ │ │ │
│ ❱ 170 │ │ │ attn_output = self.attn2( │
│ 171 │ │ │ │ norm_hidden_states, │
│ 172 │ │ │ │ encoder_hidden_states=encoder_hidden_states, │
│ 173 │ │ │ │ attention_mask=encoder_attention_mask, │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl │
│ │
│ 1191 │ │ # this function, and just call forward. │
│ 1192 │ │ if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o │
│ 1193 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1194 │ │ │ return forward_call(*input, **kwargs) │
│ 1195 │ │ # Do not call functions when jit is used │
│ 1196 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1197 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/attention_processor.py:321 in │
│ forward │
│ │
│ 318 │ │ # The `Attention` class can call different attention processors / attention func │
│ 319 │ │ # here we simply pass along all tensors to the selected processor class │
│ 320 │ │ # For standard processors that are defined here, `**cross_attention_kwargs` is e │
│ ❱ 321 │ │ return self.processor( │
│ 322 │ │ │ self, │
│ 323 │ │ │ hidden_states, │
│ 324 │ │ │ encoder_hidden_states=encoder_hidden_states, │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/attention_processor.py:1224 in │
│ __call__ │
│ │
│ 1221 │ │ key = attn.head_to_batch_dim(key).contiguous() │
│ 1222 │ │ value = attn.head_to_batch_dim(value).contiguous() │
│ 1223 │ │ │
│ ❱ 1224 │ │ hidden_states = xformers.ops.memory_efficient_attention( │
│ 1225 │ │ │ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=att │
│ 1226 │ │ ) │
│ 1227 │ │ hidden_states = attn.batch_to_head_dim(hidden_states) │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py:192 in │
│ memory_efficient_attention │
│ │
│ 189 │ │ and options. │
│ 190 │ :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]`` │
│ 191 │ """ │
│ ❱ 192 │ return _memory_efficient_attention( │
│ 193 │ │ Inputs( │
│ 194 │ │ │ query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale │
│ 195 │ │ ), │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py:290 in │
│ _memory_efficient_attention │
│ │
│ 287 ) -> torch.Tensor: │
│ 288 │ # fast-path that doesn't require computing the logsumexp for backward computation │
│ 289 │ if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]): │
│ ❱ 290 │ │ return _memory_efficient_attention_forward( │
│ 291 │ │ │ inp, op=op[0] if op is not None else None │
│ 292 │ │ ) │
│ 293 │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py:303 in │
│ _memory_efficient_attention_forward │
│ │
│ 300 def _memory_efficient_attention_forward( │
│ 301 │ inp: Inputs, op: Optional[Type[AttentionFwOpBase]] │
│ 302 ) -> torch.Tensor: │
│ ❱ 303 │ inp.validate_inputs() │
│ 304 │ output_shape = inp.normalize_bmhk() │
│ 305 │ if op is None: │
│ 306 │ │ op = _dispatch_fw(inp) │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/xformers/ops/fmha/common.py:73 in │
│ validate_inputs │
│ │
│ 70 │ │ if any(x.device != self.query.device for x in qkv): │
│ 71 │ │ │ raise ValueError("Query/Key/Value should all be on the same device") │
│ 72 │ │ if any(x.dtype != self.query.dtype for x in qkv): │
│ ❱ 73 │ │ │ raise ValueError( │
│ 74 │ │ │ │ "Query/Key/Value should all have the same dtype\n" │
│ 75 │ │ │ │ f" query.dtype: {self.query.dtype}\n" │
│ 76 │ │ │ │ f" key.dtype : {self.key.dtype}\n" │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Query/Key/Value should all have the same dtype
query.dtype: torch.float32
key.dtype : torch.float16
value.dtype: torch.float16
Hi,
Could you also try out the solutions provided in that thread to see if the errors persist?
Also, what happens if we run this on PyTorch 2.0, taking advantage of SDPA and disabling xformers?
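A minimal sketch of that check for the validation loop (just a suggestion, assuming `pipeline`, `args`, and `generator` are the objects already built in the script; on PyTorch 2.0 the default attention processors fall back to torch.nn.functional.scaled_dot_product_attention):

```python
# Sketch only: turn xformers off for the validation pipeline and rely on SDPA instead,
# to see whether the Query/Key/Value dtype error still shows up.
pipeline.disable_xformers_memory_efficient_attention()

images = []
for _ in range(args.num_validation_images):
    image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
    images.append(image)
```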
When disabling xformers, I get the following:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/ubuntu/Desktop/cartoon-name-generation/tools/train_text_to_image_lora_new.py:1054 in │
│ <module> │
│ │
│ 1051 │
│ 1052 │
│ 1053 if __name__ == "__main__": │
│ ❱ 1054 │ main() │
│ 1055 │
│ │
│ /home/ubuntu/Desktop/cartoon-name-generation/tools/train_text_to_image_lora_new.py:961 in main │
│ │
│ 958 │ │ │ │ images = [] │
│ 959 │ │ │ │ for _ in range(args.num_validation_images): │
│ 960 │ │ │ │ │ #with torch.cuda.amp.autocast(enabled=False): #dtype=torch.float32): │
│ ❱ 961 │ │ │ │ │ image = pipeline(args.validation_prompt, num_inference_steps=30, gen │
│ 962 │ │ │ │ │ images.append(image) │
│ 963 │ │ │ │ │
│ 964 │ │ │ │ for tracker in accelerator.trackers: │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27 in │
│ decorate_context │
│ │
│ 24 │ │ @functools.wraps(func) │
│ 25 │ │ def decorate_context(*args, **kwargs): │
│ 26 │ │ │ with self.clone(): │
│ ❱ 27 │ │ │ │ return func(*args, **kwargs) │
│ 28 │ │ return cast(F, decorate_context) │
│ 29 │ │
│ 30 │ def _wrap_generator(self, func): │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_st │
│ able_diffusion.py:755 in __call__ │
│ │
│ 752 │ │ │ │ │ │ callback(i, t, latents) │
│ 753 │ │ │
│ 754 │ │ if not output_type == "latent": │
│ ❱ 755 │ │ │ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dic │
│ 756 │ │ │ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embe │
│ 757 │ │ else: │
│ 758 │ │ │ image = latents │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/utils/accelerate_utils.py:46 in │
│ wrapper │
│ │
│ 43 │ def wrapper(self, *args, **kwargs): │
│ 44 │ │ if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"): │
│ 45 │ │ │ self._hf_hook.pre_forward(self) │
│ ❱ 46 │ │ return method(self, *args, **kwargs) │
│ 47 │ │
│ 48 │ return wrapper │
│ 49 │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py:264 in decode │
│ │
│ 261 │ │ │ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] │
│ 262 │ │ │ decoded = torch.cat(decoded_slices) │
│ 263 │ │ else: │
│ ❱ 264 │ │ │ decoded = self._decode(z).sample │
│ 265 │ │ │
│ 266 │ │ if not return_dict: │
│ 267 │ │ │ return (decoded,) │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py:250 in │
│ _decode │
│ │
│ 247 │ │ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > │
│ 248 │ │ │ return self.tiled_decode(z, return_dict=return_dict) │
│ 249 │ │ │
│ ❱ 250 │ │ z = self.post_quant_conv(z) │
│ 251 │ │ dec = self.decoder(z) │
│ 252 │ │ │
│ 253 │ │ if not return_dict: │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl │
│ │
│ 1191 │ │ # this function, and just call forward. │
│ 1192 │ │ if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o │
│ 1193 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1194 │ │ │ return forward_call(*input, **kwargs) │
│ 1195 │ │ # Do not call functions when jit is used │
│ 1196 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1197 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:463 in forward │
│ │
│ 460 │ │ │ │ │ │ self.padding, self.dilation, self.groups) │
│ 461 │ │
│ 462 │ def forward(self, input: Tensor) -> Tensor: │
│ ❱ 463 │ │ return self._conv_forward(input, self.weight, self.bias) │
│ 464 │
│ 465 class Conv3d(_ConvNd): │
│ 466 │ __doc__ = r"""Applies a 3D convolution over an input signal composed of several inpu │
│ │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:459 in _conv_forward │
│ │
│ 456 │ │ │ return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=sel │
│ 457 │ │ │ │ │ │ │ weight, bias, self.stride, │
│ 458 │ │ │ │ │ │ │ _pair(0), self.dilation, self.groups) │
│ ❱ 459 │ │ return F.conv2d(input, weight, bias, self.stride, │
│ 460 │ │ │ │ │ │ self.padding, self.dilation, self.groups) │
│ 461 │ │
│ 462 │ def forward(self, input: Tensor) -> Tensor: │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Input type (float) and bias type (c10::Half) should be the same
I tried the following patterns, but they all failed.
```python
# Attempt 1: wrap the pipeline call in torch.cuda.amp.autocast()
images = []
for _ in range(args.num_validation_images):
    with torch.cuda.amp.autocast():
        image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        images.append(image)

# Attempt 2: torch.autocast('cuda')
images = []
for _ in range(args.num_validation_images):
    with torch.autocast('cuda'):
        image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        images.append(image)

# Attempt 3: autocast with the training weight dtype
images = []
for _ in range(args.num_validation_images):
    with torch.cuda.amp.autocast(dtype=weight_dtype):
        image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        images.append(image)

# Attempt 4: no autocast at all
images = []
for _ in range(args.num_validation_images):
    image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
    images.append(image)
```
@sayakpaul PyTorch 2.0 with LoRAAttnProcessor2_0 works well.
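For reference, a quick sanity check of that setup (a sketch only, assuming a diffusers version where LoRAAttnProcessor2_0 exists and that `unet` already has the LoRA processors attached):

```python
import torch
from diffusers.models.attention_processor import LoRAAttnProcessor2_0

# PyTorch 2.0 provides scaled_dot_product_attention, which the *_2_0 processors rely on.
assert hasattr(torch.nn.functional, "scaled_dot_product_attention")

# Every attention processor on the UNet should be the SDPA-based LoRA variant.
assert all(isinstance(proc, LoRAAttnProcessor2_0) for proc in unet.attn_processors.values())
```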
This PR is ready for review.
Thanks for letting me know and for your efforts.
Could we maybe try to dive a bit deeper to see what the issue is? For example, does the issue persist when not training the text encoder?
For example, does the issue persist when not training the text encoder?
It also persists when training only the UNet.
Okay. Can we verify if the version in the main branch works with the following settings?
Just trying to narrow down the issue. If you have other ideas to try out, please let me know.
PyTorch 2.0 with SDPA works well.
PyTorch 1.13.1 with xformers fails.
Thanks for reporting.
This is indeed weird, as the main version should support both. Did you try it out on the main version of the script or on your version?
The error was raised when I used my branch of the script. The script on the main branch works well.
Oh okay. Let's try to investigate the differences then :-)
The error was resolved by changing where xformers is applied.
I don't know the reason :)
@sayakpaul This PR is ready.
```python
if is_xformers_available():
    import xformers

    xformers_version = version.parse(xformers.__version__)
    if xformers_version == version.parse("0.0.16"):
        logger.warn(
            "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
        )
    unet.enable_xformers_memory_efficient_attention()
```
❤️
Excellent work here. Thanks so much for all your experiments and iterations!
Let's also update the README of the example so that readers are aware that text encoder training is supported. We also need to add a test case for text encoder training to test_examples.py. Then this PR is good to be shipped 🚀
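For the test, something along these lines might do (a rough sketch following the style of the existing tests in examples/test_examples.py; the helper names run_command and self._launch_args, the tiny checkpoint, and the dummy dataset id are assumptions based on those tests and may need adjusting):

```python
def test_text_to_image_lora_with_text_encoder(self):
    # Smoke test: run a couple of steps with --train_text_encoder and check
    # that the LoRA weights file is written out.
    with tempfile.TemporaryDirectory() as tmpdir:
        test_args = f"""
            examples/text_to_image/train_text_to_image_lora.py
            --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
            --dataset_name hf-internal-testing/dummy_image_text_data
            --resolution 64
            --train_batch_size 1
            --gradient_accumulation_steps 1
            --max_train_steps 2
            --learning_rate 5.0e-04
            --train_text_encoder
            --output_dir {tmpdir}
            """.split()
        run_command(self._launch_args + test_args)
        self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
```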
Also, could we do a full run with the changes to ensure the results you're getting are as expected?
Also, could we do a full run with the changes to ensure the results you're getting are as expected?
https://civitai.com/models/25613/classic-anime-expressions
You can download the images from this page's Training Images tab.
The CSV file is here.
The params are as follows:
accelerate launch tools/train_text_to_image_lora.py \
--pretrained_model_name_or_path=models/anythingv5 \
--train_data_dir=data/ExpressionTraining \
--image_column=image \
--center_crop --random_flip \
--output_dir=work_dirs/ExpressionTraining \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--num_train_epochs=100 \
--learning_rate=1e-4 \
--lr_scheduler="cosine_with_restarts" --lr_warmup_steps=150 \
--validation_prompt="1girl, >_<" \
--checkpointing_steps=500 \
--rank=128 \
--snr_gamma=5 \
--train_text_encoder \
--seed="0"
weights. For example:

```python
from huggingface_hub.repocard import RepoCard
from diffusers import StableDiffusionPipeline
import torch

lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.load_lora_weights(lora_model_id)
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```
Let's use a LoRA you trained with this script? We'd also need to update the example prompt :)
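For instance, something like this (only a sketch: the repository id below is a placeholder, not an actual uploaded checkpoint, and the prompt simply reuses the validation prompt from the training command above):

```python
from huggingface_hub.repocard import RepoCard
from diffusers import StableDiffusionPipeline
import torch

# Placeholder repo id for a LoRA trained with --train_text_encoder using this script.
lora_model_id = "your-username/anything-v5-expressions-lora"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
# load_lora_weights also picks up the text encoder LoRA layers when they are present.
pipe.load_lora_weights(lora_model_id)
image = pipe("1girl, >_<", num_inference_steps=25).images[0]
```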
@williamberman could you also take a look here?
```diff
 vae.to(accelerator.device, dtype=weight_dtype)
 text_encoder.to(accelerator.device, dtype=weight_dtype)

+if args.enable_xformers_memory_efficient_attention:
+    if is_xformers_available():
+        import xformers
+
+        xformers_version = version.parse(xformers.__version__)
+        if xformers_version == version.parse("0.0.16"):
+            logger.warn(
+                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+            )
+        unet.enable_xformers_memory_efficient_attention()
+    else:
+        raise ValueError("xformers is not available. Make sure it is installed correctly")
```
Unless I'm missing something, in the future it'd be helpful not to move a code block unrelated to the PR, as it makes the diff harder to read :)
#3912 (comment)
#3912 (comment)
By moving the code block, I avoided this error.
```diff
 generator = generator.manual_seed(args.seed)
 images = []
 for _ in range(args.num_validation_images):
-    images.append(
-        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
-    )
+    with torch.cuda.amp.autocast():
```
Did something change that required autocast to be added?
It's also related to this error.
You can check this thread.
Looks basically good, a few small questions :)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
Fixes #3418
Adds a --train_text_encoder argument to train_text_to_image_lora.py so that the text encoder can be fine-tuned as well.
Before submitting
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.