Files changed (3)
  1. README.md +1 -1
  2. app.py +23 -57
  3. src/transformer.py +7 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🌍
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.23.2
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 5.6.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -97,65 +97,31 @@ header = """
97
 
98
 
99
  def create_app():
100
- # with gr.Blocks() as app:
101
- # gr.Markdown(header, elem_id="header")
102
- # # with gr.Tabs():
103
- # # with gr.Tab("Subject-driven"):
104
- # gr.Interface(
105
- # fn=process_image_and_text,
106
- # inputs=[
107
- # gr.Image(type="pil", label="Condition Image", width=300, elem_id="input"),
108
- # gr.Radio(
109
- # [("512", 512), ("1024(beta)", 1024)],
110
- # label="Resolution",
111
- # value=512,
112
- # elem_id="resolution",
113
- # ),
114
- # # gr.Slider(4, 16, 4, step=4, label="Inference Steps"),
115
- # gr.Textbox(lines=2, label="Text Prompt", elem_id="text"),
116
- # ],
117
- # outputs=gr.Image(type="pil", elem_id="output"),
118
- # examples=get_samples(),
119
- # )
120
- # # with gr.Tab("Fill"):
121
- # # gr.Markdown("Coming soon")
122
- # # with gr.Tab("Canny"):
123
- # # gr.Markdown("Coming soon")
124
- # # with gr.Tab("Depth"):
125
- # # gr.Markdown("Coming soon")
126
-
127
  with gr.Blocks() as app:
128
- gr.Markdown(header, elem_id="header")
129
- with gr.Row(equal_height=False):
130
- with gr.Column(variant="panel", elem_classes="inputPanel"):
131
- original_image = gr.Image(
132
- type="pil", label="Condition Image", width=300, elem_id="input"
 
 
 
 
 
 
 
 
 
 
 
 
133
  )
134
- resolution = gr.Radio(
135
- [("512", 512), ("1024(beta)", 1024)],
136
- label="Resolution",
137
- value=512,
138
- elem_id="resolution",
139
- )
140
- text = gr.Textbox(lines=2, label="Text Prompt", elem_id="text")
141
- submit_btn = gr.Button("Run", elem_id="submit_btn")
142
-
143
- with gr.Column(variant="panel", elem_classes="outputPanel"):
144
- output_image = gr.Image(type="pil", elem_id="output")
145
-
146
- with gr.Row():
147
- examples = gr.Examples(
148
- examples=get_samples(),
149
- inputs=[original_image, resolution, text],
150
- label="Examples",
151
- )
152
-
153
- submit_btn.click(
154
- fn=process_image_and_text,
155
- inputs=[original_image, resolution, text],
156
- outputs=output_image,
157
- )
158
-
159
  return app
160
 
161
 
 
97
 
98
 
99
  def create_app():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  with gr.Blocks() as app:
101
+ gr.Markdown(header)
102
+ with gr.Tabs():
103
+ with gr.Tab("Subject-driven"):
104
+ gr.Interface(
105
+ fn=process_image_and_text,
106
+ inputs=[
107
+ gr.Image(type="pil", label="Condition Image", width=300),
108
+ gr.Radio(
109
+ [("512", 512), ("1024(beta)", 1024)],
110
+ label="Resolution",
111
+ value=512,
112
+ ),
113
+ # gr.Slider(4, 16, 4, step=4, label="Inference Steps"),
114
+ gr.Textbox(lines=2, label="Text Prompt"),
115
+ ],
116
+ outputs=gr.Image(type="pil"),
117
+ examples=get_samples(),
118
  )
119
+ with gr.Tab("Fill"):
120
+ gr.Markdown("Coming soon")
121
+ with gr.Tab("Canny"):
122
+ gr.Markdown("Coming soon")
123
+ with gr.Tab("Depth"):
124
+ gr.Markdown("Coming soon")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  return app
126
 
127
 
src/transformer.py CHANGED
@@ -7,6 +7,7 @@ from diffusers.models.transformers.transformer_flux import (
7
  FluxTransformer2DModel,
8
  Transformer2DModelOutput,
9
  USE_PEFT_BACKEND,
 
10
  scale_lora_layers,
11
  unscale_lora_layers,
12
  logger,
@@ -154,6 +155,9 @@ def tranformer_forward(
154
 
155
  return custom_forward
156
 
 
 
 
157
  encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
158
  create_custom_forward(block),
159
  hidden_states,
@@ -200,6 +204,9 @@ def tranformer_forward(
200
 
201
  return custom_forward
202
 
 
 
 
203
  hidden_states = torch.utils.checkpoint.checkpoint(
204
  create_custom_forward(block),
205
  hidden_states,
 
7
  FluxTransformer2DModel,
8
  Transformer2DModelOutput,
9
  USE_PEFT_BACKEND,
10
+ is_torch_version,
11
  scale_lora_layers,
12
  unscale_lora_layers,
13
  logger,
 
155
 
156
  return custom_forward
157
 
158
+ ckpt_kwargs: Dict[str, Any] = (
159
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
160
+ )
161
  encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
162
  create_custom_forward(block),
163
  hidden_states,
 
204
 
205
  return custom_forward
206
 
207
+ ckpt_kwargs: Dict[str, Any] = (
208
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
209
+ )
210
  hidden_states = torch.utils.checkpoint.checkpoint(
211
  create_custom_forward(block),
212
  hidden_states,