hysts HF Staff committed on
Commit e02a8c2
1 Parent(s): d2d62f5
Files changed (1)
  1. app.py +29 -14
app.py CHANGED
@@ -1,5 +1,3 @@
-import random
-
 import gradio as gr
 import numpy as np
 import PIL.Image
@@ -16,16 +14,33 @@ MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 
 
+def get_seed(randomize_seed: bool, seed: int) -> int:
+    """Determine and return the random seed to use for model generation.
+
+    Args:
+        randomize_seed (bool): If True, a random seed (an integer in [0, MAX_SEED)) is generated using NumPy's default random number generator. If False, the provided seed argument is returned as-is.
+        seed (int): The seed value to use if randomize_seed is False.
+
+    Returns:
+        int: The selected seed value. If randomize_seed is True, a randomly generated integer; otherwise, the value of the seed argument.
+
+    Notes:
+        - MAX_SEED is the maximum value for a 32-bit integer (np.iinfo(np.int32).max).
+        - This function is typically used to ensure reproducibility or to introduce randomness in model generation.
+    """
+    rng = np.random.default_rng()
+    return int(rng.integers(0, MAX_SEED)) if randomize_seed else seed
+
+
 @spaces.GPU
 def infer(
     prompt: str,
     seed: int,
-    randomize_seed: bool,
     width: int = 1024,
     height: int = 1024,
     num_inference_steps: int = 4,
     progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
-) -> tuple[PIL.Image.Image, int]:
+) -> PIL.Image.Image:
     """Generate an image from a text prompt using the FLUX.1 [schnell] model.
 
     Note:
@@ -35,19 +50,16 @@ def infer(
     Args:
         prompt: A text prompt in English used to guide the image generation. Limited to 77 tokens.
         seed: The seed used for deterministic random number generation.
-        randomize_seed: If True, a new random seed will be used instead of the one provided.
         width: Width of the generated image in pixels. Defaults to 1024.
         height: Height of the generated image in pixels. Defaults to 1024.
         num_inference_steps: Number of inference steps to perform. A higher value may improve image quality. Defaults to 4.
        progress: (Internal) Used to display progress in the UI; should not be modified by the user.
 
     Returns:
-        A tuple containing the generated image and the seed used.
+        A PIL.Image.Image object representing the generated image.
     """
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)  # noqa: S311
     generator = torch.Generator().manual_seed(seed)
-    image = pipe(
+    return pipe(
         prompt=prompt,
         width=width,
         height=height,
@@ -55,11 +67,10 @@
         generator=generator,
         guidance_scale=0.0,
     ).images[0]
-    return image, seed
 
 
 def run_example(prompt: str) -> tuple[PIL.Image.Image, int]:
-    return infer(prompt, seed=42, randomize_seed=False)
+    return infer(prompt, seed=42)
 
 
 examples = [
@@ -132,13 +143,17 @@ with gr.Blocks(css=css) as demo:
             examples=examples,
             fn=run_example,
             inputs=prompt,
-            outputs=[result, seed],
+            outputs=result,
         )
 
     prompt.submit(
+        fn=get_seed,
+        inputs=[randomize_seed, seed],
+        outputs=seed,
+    ).then(
         fn=infer,
-        inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
-        outputs=[result, seed],
+        inputs=[prompt, seed, width, height, num_inference_steps],
+        outputs=result,
    )
 
 
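Below is a minimal, self-contained sketch of the seed-handling pattern this commit introduces: `get_seed` draws a seed with NumPy's `default_rng` only when randomization is requested, and the Gradio event chain resolves the seed first, writes it back to the UI, and then runs inference with it via `.then()`. The `fake_infer` function here is a hypothetical stand-in for the FLUX.1 [schnell] pipeline call so the wiring can be tried without a GPU; the rest mirrors the diff above but is not the Space's exact code.

```python
import gradio as gr
import numpy as np
import PIL.Image

MAX_SEED = np.iinfo(np.int32).max


def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return a fresh random seed when requested, otherwise the provided one."""
    rng = np.random.default_rng()
    return int(rng.integers(0, MAX_SEED)) if randomize_seed else seed


def fake_infer(prompt: str, seed: int) -> PIL.Image.Image:
    """Hypothetical stand-in for the FLUX.1 [schnell] pipeline call."""
    # Deterministic per seed: the same seed always yields the same solid-colour image.
    rng = np.random.default_rng(seed)
    color = tuple(int(c) for c in rng.integers(0, 256, size=3))
    return PIL.Image.new("RGB", (256, 256), color)


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
    result = gr.Image(label="Result")

    # Resolve the seed first (and write it back to the slider), then run
    # inference with the resolved value -- the same .then() chaining as app.py.
    prompt.submit(
        fn=get_seed,
        inputs=[randomize_seed, seed],
        outputs=seed,
    ).then(
        fn=fake_infer,
        inputs=[prompt, seed],
        outputs=result,
    )

if __name__ == "__main__":
    demo.launch()
```

Writing the resolved seed back to the `seed` component before inference is what keeps the seed visible in the UI even though `infer` no longer returns it.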