replicate params once
Browse files
app.py
CHANGED
|
@@ -24,10 +24,10 @@ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
|
|
| 24 |
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
| 25 |
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
|
| 26 |
)
|
|
|
|
|
|
|
| 27 |
|
| 28 |
def infer(prompts, negative_prompts, image):
|
| 29 |
-
params["controlnet"] = controlnet_params
|
| 30 |
-
|
| 31 |
num_samples = 1 #jax.device_count()
|
| 32 |
rng = create_key(0)
|
| 33 |
rng = jax.random.split(rng, jax.device_count())
|
|
@@ -38,7 +38,6 @@ def infer(prompts, negative_prompts, image):
|
|
| 38 |
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
|
| 39 |
processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
|
| 40 |
|
| 41 |
-
p_params = replicate(params)
|
| 42 |
prompt_ids = shard(prompt_ids)
|
| 43 |
negative_prompt_ids = shard(negative_prompt_ids)
|
| 44 |
processed_image = shard(processed_image)
|
|
|
|
| 24 |
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
| 25 |
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
|
| 26 |
)
|
| 27 |
+
params["controlnet"] = controlnet_params
|
| 28 |
+
p_params = replicate(params)
|
| 29 |
|
| 30 |
def infer(prompts, negative_prompts, image):
|
|
|
|
|
|
|
| 31 |
num_samples = 1 #jax.device_count()
|
| 32 |
rng = create_key(0)
|
| 33 |
rng = jax.random.split(rng, jax.device_count())
|
|
|
|
| 38 |
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
|
| 39 |
processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
|
| 40 |
|
|
|
|
| 41 |
prompt_ids = shard(prompt_ids)
|
| 42 |
negative_prompt_ids = shard(negative_prompt_ids)
|
| 43 |
processed_image = shard(processed_image)
|