diffusers
[do not merge] refactor `Attention` class for HunYuan DIT #8265
Closed

yiyixuxu wants to merge 2 commits into main from hunyuan-yiyi

yiyixuxu commented (1 year ago):
# integration test (hunyuan dit)
import torch
import safetensors.torch
from huggingface_hub import hf_hub_download

from diffusers import HunyuanDiT2DModel, HunyuanDiTPipeline

device = "cuda"

# Build the transformer from its config only; the weights are remapped and loaded below.
model_config = HunyuanDiT2DModel.load_config("XCLiu/HunyuanDiT-0523", subfolder="transformer")
model = HunyuanDiT2DModel.from_config(model_config).to(device)

# Download the original checkpoint and load its state dict.
ckpt_path = hf_hub_download(
    "XCLiu/HunyuanDiT-0523",
    filename="diffusion_pytorch_model.safetensors",
    subfolder="transformer",
)
state_dict = safetensors.torch.load_file(ckpt_path)

# Remap the original attention weights to the diffusers `Attention` layout.
num_layers = 40
for i in range(num_layers):

    # attn1 (self-attention)
    # Wqkv -> to_q, to_k, to_v: split the fused projection into three equal chunks
    q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
    q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
    state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
    state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
    state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
    state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
    state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
    state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
    state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
    state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
    
    # q_norm, k_norm -> norm_q, norm_k
    state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
    state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
    state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
    state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]

    state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
    state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
    state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
    state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")

    # out_proj -> to_out
    state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
    state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
    state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
    state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")

    # attn2 (cross-attention)
    # kv_proj -> to_k, to_v: split the fused key/value projection into two equal chunks
    k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
    k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
    state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
    state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
    state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
    state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
    state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
    state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
    
    # q_proj -> to_q
    state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
    state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
    state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
    state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
    
    # q_norm, k_norm -> norm_q, norm_k
    state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
    state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
    state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
    state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]

    state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
    state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
    state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
    state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")

    # out_proj -> to_out
    state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
    state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
    state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
    state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")

model.load_state_dict(state_dict)

pipe = HunyuanDiTPipeline.from_pretrained("XCLiu/HunyuanDiT-0523", transformer=model, torch_dtype=torch.float32)
pipe.to(device)

### NOTE: HunyuanDiT supports both Chinese and English inputs
prompt = "一个宇航员在骑马"
#prompt = "An astronaut riding a horse"
image = pipe(height=1024, width=1024, prompt=prompt).images[0]

image.save("yiyi_test_5_out.png")
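
If the conversion needs to be reused elsewhere, the repeated fused-projection splits above could be factored into a small helper. A minimal sketch; `remap_fused_proj` is a hypothetical name, not a diffusers API:

# Hypothetical helper, not part of diffusers: splits a fused projection
# (e.g. Wqkv or kv_proj) into separate to_q/to_k/to_v entries in place.
def remap_fused_proj(state_dict, prefix, fused_name, target_names):
    weight = state_dict.pop(f"{prefix}.{fused_name}.weight")
    bias = state_dict.pop(f"{prefix}.{fused_name}.bias")
    weights = torch.chunk(weight, len(target_names), dim=0)
    biases = torch.chunk(bias, len(target_names), dim=0)
    for name, w, b in zip(target_names, weights, biases):
        state_dict[f"{prefix}.{name}.weight"] = w
        state_dict[f"{prefix}.{name}.bias"] = b

# usage:
# remap_fused_proj(state_dict, f"blocks.{i}.attn1", "Wqkv", ["to_q", "to_k", "to_v"])
# remap_fused_proj(state_dict, f"blocks.{i}.attn2", "kv_proj", ["to_k", "to_v"])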

(output image: yiyi_test_5_out.png)

154e1502  Hunyuan Team: add HunyuanDiT related updates
3f85b1d2  yiyixuxu: refactor hunyuan attentions
yiyixuxu commented on 2024-05-24
src/diffusers/models/attention_processor.py (around lines 2810-2813):

     LoRAAttnAddedKVProcessor,
 ]
+
+from typing import Tuple

yiyixuxu (1 year ago):

Putting these here for now (I don't think they belong here).

gnobitab (1 year ago):

Hi, I tested this branch and it gives the same outputs as the old pipeline with the same generator. The code looks good to me.
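
For reference, a same-seed comparison like the one described above could look like this; a minimal sketch, assuming `pipe_old` and `pipe_new` are the two pipelines being compared (those names are placeholders, not from this PR):

import numpy as np
import torch

prompt = "An astronaut riding a horse"

generator = torch.Generator(device="cuda").manual_seed(0)
image_old = pipe_old(prompt=prompt, generator=generator).images[0]

generator = torch.Generator(device="cuda").manual_seed(0)
image_new = pipe_new(prompt=prompt, generator=generator).images[0]

# Identical seeds should yield identical images if the refactor is exact.
assert np.array_equal(np.asarray(image_old), np.asarray(image_new))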

yiyixuxu closed this 1 year ago
yiyixuxu deleted the hunyuan-yiyi branch 1 year ago
