Spaces:
Runtime error
Runtime error
replicate params once
Browse files
app.py
CHANGED
@@ -24,10 +24,10 @@ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
|
|
24 |
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
25 |
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
|
26 |
)
|
|
|
|
|
27 |
|
28 |
def infer(prompts, negative_prompts, image):
|
29 |
-
params["controlnet"] = controlnet_params
|
30 |
-
|
31 |
num_samples = 1 #jax.device_count()
|
32 |
rng = create_key(0)
|
33 |
rng = jax.random.split(rng, jax.device_count())
|
@@ -38,7 +38,6 @@ def infer(prompts, negative_prompts, image):
|
|
38 |
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
|
39 |
processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
|
40 |
|
41 |
-
p_params = replicate(params)
|
42 |
prompt_ids = shard(prompt_ids)
|
43 |
negative_prompt_ids = shard(negative_prompt_ids)
|
44 |
processed_image = shard(processed_image)
|
|
|
24 |
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
25 |
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
|
26 |
)
|
27 |
+
params["controlnet"] = controlnet_params
|
28 |
+
p_params = replicate(params)
|
29 |
|
30 |
def infer(prompts, negative_prompts, image):
|
|
|
|
|
31 |
num_samples = 1 #jax.device_count()
|
32 |
rng = create_key(0)
|
33 |
rng = jax.random.split(rng, jax.device_count())
|
|
|
38 |
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
|
39 |
processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
|
40 |
|
|
|
41 |
prompt_ids = shard(prompt_ids)
|
42 |
negative_prompt_ids = shard(negative_prompt_ids)
|
43 |
processed_image = shard(processed_image)
|