diffusers
add OnnxStableDiffusionUpscalePipeline pipeline
#2158
Merged

add OnnxStableDiffusionUpscalePipeline pipeline #2158

ssube
ssube2 years ago (edited 2 years ago)

I think I have a working implemention of an OnnxStableDiffusionUpscalePipeline, which extends StableDiffusionUpscalePipeline to be compatible with OnnxRuntimeModel. I'm hoping to get some feedback on whether this is the right approach, and if so, what else I need to do before this can be merged besides writing tests. There are a few spots in the code that I have questions about, marked with # TODOs and noted at the bottom here.

Motivation

Running the current StableDiffusionUpscalePipeline on a machine without CUDA acceleration can be pretty slow, even with relatively small 128x128 input images. I am writing a web UI for running ONNX pipelines that allows you to run a series of upscaling models (or one model repeatedly), but running StableDiffusionUpscalePipeline on a 1024px square input (split into 128px tiles) can easily take 60+ minutes on a 16 core CPU. Using the ONNX runtime is much faster, but that combination was not available, so I wrote this pipeline.

  • Per 128x128 tile:
    • Using StableDiffusionUpscalePipeline: 2.98s/it or 02:28 per tile
    • Using OnnxStableDiffusionUpscalePipeline w/ ROCmExecutionProvider: 6.46it/s or 00:07 per tile
    • Using OnnxStableDiffusionUpscalePipeline w/ DMLExecutionProvider: 1.17it/s or 00:42 per tile
  • Upscaling 512x512 -> 2048x2048, 16 runs with 50 inference steps each:
    • Using StableDiffusionUpscalePipeline: finished pipeline in 0:41:00.270845
    • Using OnnxStableDiffusionUpscalePipeline w/ ROCmExecutionProvider: finished pipeline in 0:02:10.359478
  • Upscaling 1024x1024 -> 4096x4096, 64 runs with 50 inference steps each:
    • Using StableDiffusionUpscalePipeline: still running
    • Using OnnxStableDiffusionUpscalePipeline w/ ROCmExecutionProvider: finished pipeline in 0:05:53.323918

I have only tested this using the CPUExecutionProvider and ROCmExecutionProvider so far, but I have machines set up for testing the CUDAExecutionProvider and DMLExecutionProviders and will check on them as well.

I tried to make the least-necessary changes and ended up only overriding a few methods. It looks like the preference in some of the other pipelines is to copy methods, which I can also do, but I wanted to find the minimum viable diff. Most of the changes are around passing named parameters to the models and replacing .sample with [0], but there are a few ndarray.int() calls that I'm not sure about, and the StableDiffusionUpscalePipeline code used some config values that do not appear to exist on OnnxRuntimeModel.

Example

prompt = "an astronaut eating a hamburger"
steps = 50

txt2img = StableDiffusionOnnxPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="onnx",
    provider="CUDAExecutionProvider",
)
small_image = txt2img(
    prompt,
    num_inference_steps=steps,
).images[0]

generator = torch.manual_seed(0)
upscale = OnnxStableDiffusionUpscalePipeline.from_pretrained(
    "ssube/stable-diffusion-x4-upscaler-onnx",
    provider="CUDAExecutionProvider",
)
large_image = upscale(
    prompt,
    small_image,
    generator=generator,
    num_inference_steps=steps,
).images[0]

TODOs

HuggingFaceDocBuilderDev
HuggingFaceDocBuilderDev2 years ago (edited 2 years ago)

The documentation is not available anymore as the PR was closed or merged.

ssube
ssube2 years ago

I added a basic test, which is passing locally (13 passed, 10 skipped in 67.31s (0:01:07)), but relies on an ONNX revision of stabilityai/stable-diffusion-x4-upscaler that does not exist in the https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/tree/main repo.

ForserX
ForserX2 years ago (edited 2 years ago)

@ssube
And how did you translate this model into ONYX format? I catch a bunch of errors. (AMD GPU & Windows)
i wanna check into DML mode

ssube
ssube2 years ago

@ForserX I'm using this script: https://github.com/ssube/onnx-web/blob/main/api/onnx_web/convert.py#L206
It's very close to https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py, but the single_vae branches are new for upscaling. ssube/onnx-web@bacce0a#diff-2b8422f2625f7e1cd0ca3fa3e9975deed7d4962823108c2fc29f14c53e2c0cc6 is the bulk of the changes. I got it to work by switching between class_labels and return_dict on the UNet inputs and export a single VAE rather than splitting the encoder/decoder. No idea if that's right. 😄

ForserX
ForserX2 years ago

How difficult everything is... I'll try, if it doesn't work out, I'll ask for a ready-made model))

ssube
ssube2 years ago

Using that convert.py script, I was able to convert the model on Windows 10 and run it using the DirectMLExecutionProvider on an AMD GPU. The output looks about right, nothing unusual showing up. I've added the iteration and 128px tile times to the description. It's not as fast as ROCm, from initial testing, but still much faster than CPU (roughly 5x).

Some logs from that:

Fetching 17 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [02:49<00:00,  9.95s/it]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensur
e that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team a
nd Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior 
or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
C:\Users\ssube\stabdiff\onnx-try-2\onnx-web\api\onnx_env\lib\site-packages\transformers\models\clip\modeling_clip.py:754: TracerWarning: torch.tensor results are registered as c
onstants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this fu
nction. In any other case, this might cause the trace to be incorrect.
  mask.fill_(torch.tensor(torch.finfo(dtype).min))
C:\Users\ssube\stabdiff\onnx-try-2\onnx-web\api\onnx_env\lib\site-packages\torch\onnx\symbolic_opset9.py:5408: UserWarning: Exporting aten::index operator of advanced indexing i
n opset 14 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will p
roduce incorrect results.
  warnings.warn(
[2023-01-30 21:03:37,446] INFO: __main__: UNET config: FrozenDict([('sample_size', 128), ('in_channels', 7), ('out_channels', 4), ('center_input_sample', False), ('flip_sin_to_c
os', True), ('freq_shift', 0), ('down_block_types', ['DownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D']), ('mid_block_type', 'UNetMidBlock2DC
rossAttn'), ('up_block_types', ['CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'UpBlock2D']), ('only_cross_attention', [True, True, True, False]), ('block_out
_channels', [256, 512, 512, 1024]), ('layers_per_block', 2), ('downsample_padding', 1), ('mid_block_scale_factor', 1), ('act_fn', 'silu'), ('norm_num_groups', 32), ('norm_eps', 
1e-05), ('cross_attention_dim', 1024), ('attention_head_dim', 8), ('dual_cross_attention', False), ('use_linear_projection', True), ('class_embed_type', None), ('num_class_embed
s', 1000), ('upcast_attention', False), ('resnet_time_scale_shift', 'default'), ('_class_name', 'UNet2DConditionModel'), ('_diffusers_version', '0.9.0.dev0'), ('_name_or_path', 
'C:\\Users\\ssube/.cache\\huggingface\\diffusers\\models--stabilityai--stable-diffusion-x4-upscaler\\snapshots\\19b610c68ca7572defb6e09e64d1063f32b4db83\\unet')])
[2023-01-30 21:04:33,172] INFO: __main__: VAE config: FrozenDict([('in_channels', 3), ('out_channels', 3), ('down_block_types', ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'Dow
nEncoderBlock2D']), ('up_block_types', ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D']), ('block_out_channels', [128, 256, 512]), ('layers_per_block', 2), ('act_fn'
, 'silu'), ('latent_channels', 4), ('norm_num_groups', 32), ('sample_size', 256), ('_class_name', 'AutoencoderKL'), ('_diffusers_version', '0.9.0.dev0'), ('_name_or_path', 'C:\\
Users\\ssube/.cache\\huggingface\\diffusers\\models--stabilityai--stable-diffusion-x4-upscaler\\snapshots\\19b610c68ca7572defb6e09e64d1063f32b4db83\\vae')])
[2023-01-30 21:04:43,174] INFO: __main__: exporting ONNX model
[2023-01-30 21:04:43,225] INFO: __main__: ONNX pipeline saved to ..\models\upscaling-stable-diffusion-x4
[2023-01-30 21:04:47,267] INFO: __main__: ONNX pipeline is loadable

and

[2023-01-30 21:29:02,983] INFO: onnx_web.chain.upscale_outpaint: final output image size: 1024x1024
[2023-01-30 21:29:02,984] INFO: onnx_web.chain.base: finished stage expand, result size: 1024x1024
[2023-01-30 21:29:02,984] INFO: onnx_web.chain.base: running stage upscale on image with dimensions 1024x1024, dict_keys(['output', 'size', 'prompt', 'scale', 'outscale', 'tile_
size', 'upscale'])
[2023-01-30 21:29:02,984] INFO: onnx_web.chain.base: image larger than tile size of SizeChart.mini, tiling stage
[2023-01-30 21:29:02,992] INFO: onnx_web.chain.utils: processing tile 1 of 64, 0.0
[2023-01-30 21:29:02,993] INFO: onnx_web.chain.upscale_stable_diffusion: upscaling with Stable Diffusion, 50 steps
2023-01-30 21:29:03.0777243 [W:onnxruntime:, inference_session.cc:493 onnxruntime::InferenceSession::RegisterExecutionProvider] Having memory pattern enabled is not supported while using the DML Execution Provider. So disabling it for this session since it uses the DML Execution Provider.
2023-01-30 21:29:04.0214862 [W:onnxruntime:, session_state.cc:1030 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-01-30 21:29:04.0253514 [W:onnxruntime:, session_state.cc:1032 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
2023-01-30 21:29:05.8783892 [W:onnxruntime:, inference_session.cc:493 onnxruntime::InferenceSession::RegisterExecutionProvider] Having memory pattern enabled is not supported while using the DML Execution Provider. So disabling it for this session since it uses the DML Execution Provider.
2023-01-30 21:29:05.9192614 [W:onnxruntime:, session_state.cc:1030 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-01-30 21:29:05.9231882 [W:onnxruntime:, session_state.cc:1032 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
2023-01-30 21:29:06.5044530 [W:onnxruntime:, inference_session.cc:493 onnxruntime::InferenceSession::RegisterExecutionProvider] Having memory pattern enabled is not supported while using the DML Execution Provider. So disabling it for this session since it uses the DML Execution Provider.
2023-01-30 21:29:06.7290335 [W:onnxruntime:, session_state.cc:1030 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-01-30 21:29:06.7331632 [W:onnxruntime:, session_state.cc:1032 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
  8%|███████████▏                                                                                                                                | 4/50 [00:09<01:53,  2.47s/it]
ForserX
ForserX2 years ago (edited 2 years ago)

(roughly 5x)

The Vulkan variation of ESRGAN works even faster

Check mail, Please

ssube
ssube2 years ago👍 1

I pushed a copy of the model that I have been using to https://huggingface.co/ssube/stable-diffusion-x4-upscaler-onnx and updated the tests accordingly 🤞

patrickvonplaten
patrickvonplaten2 years ago

Cool, cc @anton-l @echarlaix for review

patrickvonplaten patrickvonplaten requested a review from anton-l anton-l 2 years ago
ForserX
ForserX2 years ago

It remains to wait custiom VAE and LoRA for ONNX))

ssube
ssube2 years ago

I added another, longer test and fixed up a few of the TODOs. The remaining ones are all related to hard-coded channel counts and the text_embeddings dtype, and I'm not sure where to look those up, they don't seem to be present on the OnnxRuntimeModel.

I also tried adding attention_mask back to the text encoder, but I don't see it being used in the other ONNX pipelines, and attempting to add it causes an 2 : INVALID_ARGUMENT : Invalid Feed Input Name:attention_mask error.

ssube ssube changed the title [WIP] add OnnxStableDiffusionUpscalePipeline pipeline add OnnxStableDiffusionUpscalePipeline pipeline 2 years ago
patrickvonplaten
patrickvonplaten2 years ago
patrickvonplaten patrickvonplaten assigned anton-l anton-l 2 years ago
ssube
ssube2 years ago

Is there anything else I can/should add to this? I'm not sure where to look up the vae.config/unet.config equivalents, or how important that is.

patrickvonplaten
patrickvonplaten2 years ago

@anton-l can you take a look here?

ssube ssube force pushed from 39bdc34a to 295a96d6 2 years ago
ssube
ssube2 years ago👍 1

I've been using and testing this pipeline more, with more schedulers, and fixed a couple of issues related to the mix of numpy and torch types. There was an unsupported operand type(s) for *: 'numpy.ndarray' and 'Tensor' error with some (but not all) schedulers, which I fixed based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py#L437. I've added tests for all of the same schedulers that are tested in https://github.com/huggingface/diffusers/blob/main/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py except for a fast test for LMS discrete, which was timing out.

There were a few .config lookups that I wasn't sure about, but it looks like the other ONNX pipelines declare them as constants, so I did the same: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py#L34

The last issue I'm aware of is a slight difference between the parameter types to the scheduler.step() call: many of the other ONNX pipelines use something like torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs, but I think due to the restrictions in ORT, converting latents to a tensor causes an TypeError: expected np.ndarray (got Tensor) error and does not seem right here. torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs does appear to work.

I did run into one issue with int32 vs int64 types, but that appears to be related to how the model is trained or serialized, and exporting it again with the 4th input as a torch.long solved that:

     # UNET
     if single_vae:
         unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
-        unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
+        unet_scale = torch.tensor(4).to(
+            device=ctx.training_device, dtype=torch.long
+        )
anton-l
anton-l approved these changes on 2023-02-16
anton-l2 years ago

Very impressive work @ssube, thank you so much for contributing!
Overall your implementation looks good to me, just left a couple of minor comments :)

For the int32 vs int64 issue: maybe it would be possible to infer the type at runtime, similar to

?

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
14logger = getLogger(__name__)
15
16
17
NUM_LATENT_CHANNELS = 4
18
NUM_UNET_INPUT_CHANNELS = 7
anton-l2 years ago

Yes, this works 👍

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
18NUM_UNET_INPUT_CHANNELS = 7
19
20# TODO: should this be a lookup? it needs to match the conversion script
21
class_labels_dtype = np.int64
anton-l2 years ago👍 1

The integer types stay the same even in fp16 mode, so you can safely move it inline

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
108
109 # 5. Add noise to image
110 noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
111
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
anton-l2 years ago

text_embeddings_dtype can be inferred from text_embeddings (fp32 or fp16), so this shouldn't be a constant

ssube2 years ago

I thought so, let me fix that up

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
23# TODO: should this be a lookup or converted? can it vary on ONNX?
24text_embeddings_dtype = torch.float32
25
26
###
27
# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
28
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
29
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
30
###
anton-l2 years ago👍 1

Probably ok to remove this disclaimer now 😄

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
8888 ):
8989 super().__init__()
9090
91 # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
92 is_vae_scaling_factor_set_to_0_08333 = (
93 hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
94 )
95 if not is_vae_scaling_factor_set_to_0_08333:
96 deprecation_message = (
97 "The configuration file of the vae does not contain `scaling_factor` or it is set to"
98 f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned"
99 " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to 0.08333"
100 " Please make sure to update the config accordingly, as not doing so might lead to incorrect results"
101 " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
102 " very nice if you could open a Pull Request for the `vae/config.json` file"
91
if hasattr(vae, "config"):
92
# check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
93
is_vae_scaling_factor_set_to_0_08333 = (
94
hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
anton-l2 years ago

cc @patrickvonplaten @patil-suraj for this change

ssube2 years ago

I wasn't sure about this part, but if the VAE doesn't have .config, the current implement will throw without logging much.

patrickvonplaten2 years ago

Ok for me!

ssube
ssube2 years ago

I inlined the integer type and put in lookups for the other two. One of them needed to go from numpy to the torch dtype since that's what the StableDiffusionUpscalePipeline expects, so I put in a little lookup table for that, hopefully that is ok: 75cadf2#diff-3815a0888bb607ca69fe4022fa3b4a809687fe2b3ae4d0ea0397288fac3c920bR20-R23

For the int32/64 issue that I mentioned, I tested that a little bit more, and everything seems to work as long as the type in the convert/export code and the pipeline match. Is there any reason not to use int64 there? For more context, this is my convert script and the relevant part is:

    # UNET
    if single_vae: # upscale pipeline
        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
        unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.long) # <- this is the type that needs to match
    else:
        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
        unet_scale = torch.tensor(False).to(
            device=ctx.training_device, dtype=torch.bool
        )

    unet_in_channels = pipeline.unet.config.in_channels
    unet_sample_size = pipeline.unet.config.sample_size
    unet_path = output_path / "unet" / "model.onnx"
    onnx_export(
        pipeline.unet,
        model_args=(
            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
                device=ctx.training_device, dtype=dtype
            ),
            torch.randn(2).to(device=ctx.training_device, dtype=dtype),
            torch.randn(2, num_tokens, text_hidden_size).to(
                device=ctx.training_device, dtype=dtype
            ),
            unet_scale,
        ),
        output_path=unet_path,
        ordered_input_names=unet_inputs,
        # has to be different from "sample" for correct tracing
        output_names=["out_sample"],
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=ctx.opset,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )
ssube [Onnx] add Stable Diffusion Upscale pipeline
538f37f8
ssube add a test for the OnnxStableDiffusionUpscalePipeline
8a3fbc4c
ssube check for VAE config before adjusting scaling factor
3b55710f
ssube update test assertions, lint fixes
a9a255bc
ssube run fix-copies target
26f762c0
ssube switch test checkpoint to one hosted on huggingface
e7d433ab
ssube partially restore attention mask
a771002b
ssube reshape embeddings after running text encoder
72f26cea
ssube add longer nightly test for ONNX upscale pipeline
b975750e
ssube use package import to fix tests
81a6b7a2
ssube fix scheduler compatibility and class labels dtype
4df84473
ssube use more precise type
13731d14
ssube remove LMS from fast tests
bab362f8
ssube lookup latent and timestamp types
760321aa
ssube add docs for ONNX upscaling, rename lookup table
9b7810fe
ssube ssube force pushed from 3d102b0a to 9b7810fe 2 years ago
ssube replace deprecated pipeline names in ONNX docs
9b2c347c
ssube ssube force pushed from c2b748e9 to 9b2c347c 2 years ago
patrickvonplaten patrickvonplaten requested a review from williamberman williamberman 2 years ago
patrickvonplaten
patrickvonplaten2 years ago

Looks good to me - thanks for checking the PR @anton-l :-)

cc @williamberman could you also take a quick look?

patrickvonplaten
patrickvonplaten2 years ago

Merging to not block the community contributor here

patrickvonplaten patrickvonplaten merged 9920c333 into main 2 years ago
zetyquickly
zetyquickly1 year ago

Hello. On version diffusers > 0.16.0 this pipeline throws exception due to vae.config attribute check is removed.

File "/opt/conda/envs/lora/lib/python3.9/site-packages/diffusers/pipelines/pipeline_utils.py", line 1101, in from_pretrained
    model = pipeline_class(**init_kwargs)
  File "/opt/conda/envs/lora/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py", line 59, in __init__
    super().__init__(
  File "/opt/conda/envs/lora/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py", line 134, in __init__
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
AttributeError: 'OnnxRuntimeModel' object has no attribute 'config'
patrickvonplaten
patrickvonplaten1 year ago

Thanks for the ping @zetyquickly ! Would you like to open an issue to fix it?

Login to write a write a comment.

Login via GitHub

Assignees
Labels
Milestone