The documentation is not available anymore as the PR was closed or merged.
Looks good!
Works fine for me, thanks a lot @entrpn!
@patil-suraj can you maybe take a quick look?
thanks @entrpn ❤ am I doing right ?
real_seed = random.randint(0, 2147483647)
prng_seed = jax.random.PRNGKey(real_seed)
prng_seed = jax.random.split(prng_seed, jax.device_count())
num_samples = jax.device_count()
prompt_n = num_samples * [prompt]
prompt_ids = pipe.prepare_inputs(prompt_n)
prompt_ids = shard(prompt_ids)
negative_prompt_n = num_samples * [negative_prompt]
negative_prompt_ids = pipe.prepare_inputs(negative_prompt_n)
negative_prompt_ids = shard(negative_prompt_ids)
images = pipe(prompt_ids, params, prng_seed, neg_prompt_ids=negative_prompt_ids, num_inference_steps=num_inference_steps, height=height, width=width, guidance_scale=guidance_scale, jit=True).images
@camenduru that looks right.
Login to write a write a comment.
Added negative prompt support to jax pipeline.
For example.
prompt :
a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart
With negative prompt:
fog, grainy, purple