Flourish committed on
Commit ff3266f · verified · 1 parent: 2ec78b1

Upload 12 files

.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ imgs/car.png filter=lfs diff=lfs merge=lfs -text
+ imgs/chair.png filter=lfs diff=lfs merge=lfs -text
+ imgs/count.png filter=lfs diff=lfs merge=lfs -text
+ imgs/foot.webp filter=lfs diff=lfs merge=lfs -text
+ imgs/table.webp filter=lfs diff=lfs merge=lfs -text
+ imgs/train.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,14 @@
  ---
  title: Ovis U1 3B
- emoji: 🦀
- colorFrom: gray
- colorTo: pink
+ emoji: 🎨
+ colorFrom: green
+ colorTo: indigo
  sdk: gradio
  sdk_version: 5.35.0
  app_file: app.py
  pinned: false
  license: apache-2.0
+ short_description: Demo for multimodal understanding and generation
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,344 @@
import os
import subprocess
subprocess.run('pip install flash-attn==2.6.3 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
import random
import spaces
import numpy as np
import torch
from PIL import Image
import gradio as gr
from transformers import AutoModelForCausalLM
from test_img_edit import pipe_img_edit
from test_img_to_txt import pipe_txt_gen
from test_txt_to_img import pipe_t2i


# Constants
MAX_SEED = 10000

hf_token = os.getenv("HF_TOKEN")

HUB_MODEL_ID = "AIDC-AI/Ovis-U1-3B"
model, loading_info = AutoModelForCausalLM.from_pretrained(
    HUB_MODEL_ID,
    torch_dtype=torch.bfloat16,
    output_loading_info=True,
    token=hf_token,
    trust_remote_code=True
)
print(f'Loading info of Ovis-U1:\n{loading_info}')

model = model.eval().to("cuda")
model = model.to(torch.bfloat16)

def set_global_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def randomize_seed_fn(seed: int, randomize: bool) -> int:
    return random.randint(0, MAX_SEED) if randomize else seed

@spaces.GPU
def process_txt_to_img(prompt: str, height: int, width: int, steps: int, final_seed: int, guidance_scale: float, progress: gr.Progress = gr.Progress(track_tqdm=True)) -> list[Image.Image]:
    set_global_seed(final_seed)
    images = pipe_t2i(model, prompt, height, width, steps, cfg=guidance_scale, seed=final_seed)
    return images

@spaces.GPU
def process_img_to_txt(prompt: str, img: Image.Image, progress: gr.Progress = gr.Progress(track_tqdm=True)) -> str:
    output_text = pipe_txt_gen(model, img, prompt)
    return output_text

@spaces.GPU
def process_img_txt_to_img(prompt: str, img: Image.Image, steps: int, final_seed: int, txt_cfg: float, img_cfg: float, progress: gr.Progress = gr.Progress(track_tqdm=True)) -> list[Image.Image]:
    set_global_seed(final_seed)
    images = pipe_img_edit(model, img, prompt, steps, txt_cfg, img_cfg, seed=final_seed)
    return images

# Gradio UI
with gr.Blocks(title="Ovis-U1-3B") as demo:
    gr.Markdown('''# Ovis-U1-3B
    ''')

    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Image + Text → Image"):
                    edit_image_input = gr.Image(label="Input Image", type="pil")
                    with gr.Row():
                        edit_prompt_input = gr.Textbox(
                            label="Prompt",
                            show_label=False,
                            placeholder="Describe the editing instruction...",
                            container=False,
                            lines=1
                        )
                        run_edit_image_btn = gr.Button("Run", scale=0)

                    with gr.Accordion("Advanced Settings", open=False):

                        with gr.Row():

                            edit_img_guidance_slider = gr.Slider(
                                label="Image Guidance Scale",
                                minimum=1.0, maximum=10.0,
                                step=0.1, value=1.5
                            )

                            edit_txt_guidance_slider = gr.Slider(
                                label="Text Guidance Scale",
                                minimum=1.0, maximum=30.0,
                                step=0.5, value=6.0
                            )

                        edit_num_steps_slider = gr.Slider(
                            label='Steps',
                            minimum=40, maximum=100,
                            value=50, step=1
                        )
                        edit_seed_slider = gr.Slider(
                            label="Seed",
                            minimum=0, maximum=int(MAX_SEED),
                            step=1, value=42
                        )
                        edit_randomize_checkbox = gr.Checkbox(
                            label="Randomize seed", value=False
                        )

                    img_edit_examples_data = [
                        ["imgs/train.png", "Modify this image in a Ghibli style. "],
                        ["imgs/chair.png", "Transfer the image into a faceted low-poly 3-D render style."],
                        ["imgs/car.png", "Replace the tiny house on wheels in the image with a vintage car."],
                    ]
                    gr.Examples(
                        examples=img_edit_examples_data,
                        inputs=[edit_image_input, edit_prompt_input],
                        cache_examples=False,
                        label="Image Editing Examples"
                    )

                with gr.TabItem("Text → Image"):
                    with gr.Row():
                        prompt_gen_input = gr.Textbox(
                            label="Prompt",
                            show_label=False,
                            placeholder="Describe the image you want...",
                            container=False,
                            lines=1
                        )
                        run_image_gen_btn = gr.Button("Run", scale=0)

                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            height_slider = gr.Slider(
                                label='height',
                                minimum=256, maximum=1536,
                                value=1024, step=32
                            )
                            width_slider = gr.Slider(
                                label='width',
                                minimum=256, maximum=1536,
                                value=1024, step=32
                            )

                        guidance_slider = gr.Slider(
                            label="Guidance Scale",
                            minimum=1.0, maximum=30.0,
                            step=0.5, value=5.0
                        )

                        num_steps_slider = gr.Slider(
                            label='Steps',
                            minimum=40, maximum=100,
                            value=50, step=1
                        )
                        seed_slider = gr.Slider(
                            label="Seed",
                            minimum=0, maximum=int(MAX_SEED),
                            step=1, value=42
                        )
                        randomize_checkbox = gr.Checkbox(
                            label="Randomize seed", value=False
                        )

                    text_gen_examples_data = [
                        ["A breathtaking fairy with teal wings sits gracefully on a lotus flower in a serene pond, exuding elegance."],
                        ["A winter mountain landscape at deep night with snowy terrain and colorful flowers, under beautiful clouds and no people, portrayed as an anime background illustration with intricate detail and sharp focus."],
                        ["A photo of a pug wearing a cowboy hat and bandana, sitting on a hay bale."]
                    ]
                    gr.Examples(
                        examples=text_gen_examples_data,
                        inputs=[prompt_gen_input],
                        cache_examples=False,
                        label="Image Generation Examples"
                    )

                with gr.TabItem("Image → Text"):
                    image_understand_input = gr.Image(label="Input Image", type="pil")
                    with gr.Row():
                        prompt_understand_input = gr.Textbox(
                            label="Prompt",
                            show_label=False,
                            placeholder="Describe the question about image...",
                            container=False,
                            lines=1
                        )
                        run_image_understand_btn = gr.Button("Run", scale=0)

                    image_understanding_examples_data = [
                        ["imgs/table.webp", "In what scenario does this picture take place?"],
                        ["imgs/count.png", "How many broccoli are there in the picture?"],
                        ["imgs/foot.webp", "Where is this picture located?"],
                    ]
                    gr.Examples(
                        examples=image_understanding_examples_data,
                        inputs=[image_understand_input, prompt_understand_input],
                        cache_examples=False,
                        label="Image Understanding Examples"
                    )

            clean_btn = gr.Button("Clear All Inputs/Outputs")

        with gr.Column():
            output_gallery = gr.Gallery(label="Generated Images", columns=2, visible=True)  # Default to visible, content will control
            output_text = gr.Textbox(label="Generated Text", visible=False, lines=5, interactive=False)

    @spaces.GPU
    def run_img_txt_to_img_tab(prompt, img, steps, seed, txt_cfg, img_cfg, progress=gr.Progress(track_tqdm=True)):
        if img is None:
            return (
                gr.update(value=[], visible=False),
                gr.update(value="Please upload an image for editing.", visible=True)
            )
        # Seed is already finalized by the randomize_seed_fn in the click chain
        imgs = process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg, progress=progress)
        return (
            gr.update(value=imgs, visible=True),
            gr.update(value="", visible=False)
        )

    @spaces.GPU
    def run_txt_to_img_tab(prompt, height, width, steps, seed, guidance, progress=gr.Progress(track_tqdm=True)):
        # Seed is already finalized by the randomize_seed_fn in the click chain
        imgs = process_txt_to_img(prompt, height, width, steps, seed, guidance, progress=progress)
        return (
            gr.update(value=imgs, visible=True),
            gr.update(value="", visible=False)
        )

    @spaces.GPU
    def run_img_to_txt_tab(img, prompt, progress=gr.Progress(track_tqdm=True)):
        if img is None:
            return (
                gr.update(value=[], visible=False),
                gr.update(value="Please upload an image for understanding.", visible=True)
            )
        txt = process_img_to_txt(prompt, img, progress=progress)
        return (
            gr.update(value=[], visible=False),
            gr.update(value=txt, visible=True)
        )

    def clean_all_fn():
        return (
            # Tab 1 inputs
            gr.update(value=None),
            gr.update(value=""),
            gr.update(value=1.5),
            gr.update(value=6.0),
            gr.update(value=50),
            gr.update(value=42),
            gr.update(value=False),
            # Tab 2 inputs
            gr.update(value=""),     # prompt_gen_input
            gr.update(value=1024),
            gr.update(value=1024),
            gr.update(value=5.0),
            gr.update(value=50),
            gr.update(value=42),     # seed_slider
            gr.update(value=False),  # randomize_checkbox
            # Tab 3 inputs
            gr.update(value=None),   # image_understand_input
            gr.update(value=""),     # prompt_understand_input
            # Outputs
            gr.update(value=[], visible=True),   # output_gallery (reset and keep visible for next gen)
            gr.update(value="", visible=False)   # output_text (reset and hide)
        )

    # Event listeners for Image + Text -> Image
    edit_inputs = [edit_prompt_input, edit_image_input, edit_num_steps_slider, edit_seed_slider, edit_txt_guidance_slider, edit_img_guidance_slider]

    run_edit_image_btn.click(
        fn=randomize_seed_fn,
        inputs=[edit_seed_slider, edit_randomize_checkbox],
        outputs=[edit_seed_slider]
    ).then(
        fn=run_img_txt_to_img_tab,
        inputs=edit_inputs,
        outputs=[output_gallery, output_text]
    )

    edit_prompt_input.submit(
        fn=randomize_seed_fn,
        inputs=[edit_seed_slider, edit_randomize_checkbox],
        outputs=[edit_seed_slider]
    ).then(
        fn=run_img_txt_to_img_tab,
        inputs=edit_inputs,
        outputs=[output_gallery, output_text]
    )

    # Event listeners for Text -> Image
    gen_inputs = [prompt_gen_input, height_slider, width_slider, num_steps_slider, seed_slider, guidance_slider]

    run_image_gen_btn.click(
        fn=randomize_seed_fn,
        inputs=[seed_slider, randomize_checkbox],
        outputs=[seed_slider]
    ).then(
        fn=run_txt_to_img_tab,
        inputs=gen_inputs,
        outputs=[output_gallery, output_text]
    )

    prompt_gen_input.submit(
        fn=randomize_seed_fn,
        inputs=[seed_slider, randomize_checkbox],
        outputs=[seed_slider]
    ).then(
        fn=run_txt_to_img_tab,
        inputs=gen_inputs,
        outputs=[output_gallery, output_text]
    )

    # Event listeners for Image -> Text
    understand_inputs = [image_understand_input, prompt_understand_input]

    run_image_understand_btn.click(
        fn=run_img_to_txt_tab,
        inputs=understand_inputs,
        outputs=[output_gallery, output_text]
    )

    prompt_understand_input.submit(
        fn=run_img_to_txt_tab,
        inputs=understand_inputs,
        outputs=[output_gallery, output_text]
    )

    clean_btn.click(
        fn=clean_all_fn,
        inputs=[],
        outputs=[
            edit_image_input, edit_prompt_input, edit_img_guidance_slider, edit_txt_guidance_slider,
            edit_num_steps_slider, edit_seed_slider, edit_randomize_checkbox,
            prompt_gen_input, height_slider, width_slider, guidance_slider, num_steps_slider, seed_slider, randomize_checkbox,
            image_understand_input, prompt_understand_input,
            output_gallery, output_text
        ]
    )

if __name__ == "__main__":
    demo.launch(share=True)
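For quick local testing, the same three pipelines can be driven without the Gradio UI. The snippet below is a minimal sketch, assuming it runs from this repository's root with a CUDA GPU and access to the AIDC-AI/Ovis-U1-3B weights; the function signatures are taken from the test_*.py files in this commit.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM

from test_txt_to_img import pipe_t2i
from test_img_to_txt import pipe_txt_gen
from test_img_edit import pipe_img_edit

# Load Ovis-U1 once and reuse it for all three tasks.
model = AutoModelForCausalLM.from_pretrained(
    "AIDC-AI/Ovis-U1-3B",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval().to("cuda")

# Text -> Image
images = pipe_t2i(model, "a cute cat", height=1024, width=1024, steps=50, cfg=5.0, seed=42)
images[0].save("t2i.png")

# Image -> Text
caption = pipe_txt_gen(model, Image.open("t2i.png").convert("RGB"), "What is in this image?")
print(caption)

# Image + Text -> Image
edited = pipe_img_edit(model, Image.open("t2i.png").convert("RGB"), "add a hat to the cat",
                       steps=50, txt_cfg=6.0, img_cfg=1.5, seed=42)
edited[0].save("edit.png")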
imgs/car.png ADDED

Git LFS Details

  • SHA256: dc363061b5b227fb8da0906dcad0e59620ef68f4e118a1579f4289c3609c2e8a
  • Pointer size: 131 Bytes
  • Size of remote file: 598 kB
imgs/chair.png ADDED

Git LFS Details

  • SHA256: 58575e9530fa8ffbbea71afa46a4681af453cf4628ae61876bcd8a45092a2eeb
  • Pointer size: 131 Bytes
  • Size of remote file: 316 kB
imgs/count.png ADDED

Git LFS Details

  • SHA256: 9a3b0eb4ef918255b16707f5c298fdef23e3dc793eb1a2f22f3889b752929d0c
  • Pointer size: 131 Bytes
  • Size of remote file: 646 kB
imgs/foot.webp ADDED

Git LFS Details

  • SHA256: a386ff4face1f463fba9fc273a4e21f2943e646f585dcdb27f25ffcdc1c58ef1
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
imgs/table.webp ADDED

Git LFS Details

  • SHA256: 63ec6557a9e6cda427539bbb8ccd3723744bd109da616b877e9bb12c6d322d4f
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
imgs/train.png ADDED

Git LFS Details

  • SHA256: 927c6853f4059ca1b61fc9077f209a451ba0247de8591ce42a253002174cc192
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB
requirements.txt ADDED
@@ -0,0 +1,29 @@
torch==2.4.0
transformers==4.51.3
tokenizers==0.21.1
sentencepiece==0.1.99
pyarrow==18.0.0
accelerate==1.1.0
pydantic==2.8.2
markdown2[all]
numpy==1.24.3
scikit-learn==1.2.2
requests
httpx
uvicorn
fastapi==0.112.4
einops==0.6.1
einops-exts==0.0.4
timm==1.0.11
tiktoken
transformers_stream_generator==0.0.4
scipy
pandas
torchaudio
xformers
pillow==10.3.0
pysubs2==1.7.2
trl==0.12.1
moviepy==1.0.3
diffusers==0.31.0
gradio
test_img_edit.py ADDED
@@ -0,0 +1,132 @@
import os
import argparse
import math
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM


def parse_args():
    parser = argparse.ArgumentParser(description="Test Image Editing")
    parser.add_argument(
        "--model_path",
        type=str,
        default="AIDC-AI/Ovis-U1-3B",
    )
    parser.add_argument(
        "--steps", type=int, default=50,
    )
    parser.add_argument(
        "--img_cfg", type=float, default=1.5,
    )
    parser.add_argument(
        "--txt_cfg", type=float, default=6,
    )
    args = parser.parse_args()
    return args

def load_blank_image(width, height):
    pil_image = Image.new("RGB", (width, height), (255, 255, 255)).convert('RGB')
    return pil_image

def build_inputs(model, text_tokenizer, visual_tokenizer, prompt, pil_image, target_width, target_height):
    if pil_image is not None:
        target_size = (int(target_width), int(target_height))
        pil_image, vae_pixel_values, cond_img_ids = model.visual_generator.process_image_aspectratio(pil_image, target_size)
        cond_img_ids[..., 0] = 1.0
        vae_pixel_values = vae_pixel_values.unsqueeze(0).to(device=model.device)
        width = pil_image.width
        height = pil_image.height
        resized_height, resized_width = visual_tokenizer.smart_resize(height, width, max_pixels=visual_tokenizer.image_processor.min_pixels)
        pil_image = pil_image.resize((resized_width, resized_height))
    else:
        vae_pixel_values = None
        cond_img_ids = None

    prompt, input_ids, pixel_values, grid_thws = model.preprocess_inputs(
        prompt,
        [pil_image],
        generation_preface=None,
        return_labels=False,
        propagate_exception=False,
        multimodal_type='single_image',
        fix_sample_overall_length_navit=False
    )
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    if pixel_values is not None:
        pixel_values = torch.cat([
            pixel_values.to(device=visual_tokenizer.device, dtype=torch.bfloat16) if pixel_values is not None else None
        ], dim=0)
    if grid_thws is not None:
        grid_thws = torch.cat([
            grid_thws.to(device=visual_tokenizer.device) if grid_thws is not None else None
        ], dim=0)
    return input_ids, pixel_values, attention_mask, grid_thws, vae_pixel_values

def pipe_img_edit(model, input_img, prompt, steps, txt_cfg, img_cfg, seed=42):
    text_tokenizer = model.get_text_tokenizer()
    visual_tokenizer = model.get_visual_tokenizer()

    width, height = input_img.size
    height, width = visual_tokenizer.smart_resize(height, width, factor=32)

    gen_kwargs = dict(
        max_new_tokens=1024,
        do_sample=False,
        top_p=None,
        top_k=None,
        temperature=None,
        repetition_penalty=None,
        eos_token_id=text_tokenizer.eos_token_id,
        pad_token_id=text_tokenizer.pad_token_id,
        use_cache=True,
        height=height,
        width=width,
        num_steps=steps,
        seed=seed,
        img_cfg=img_cfg,
        txt_cfg=txt_cfg,
    )
    uncond_image = load_blank_image(width, height)
    uncond_prompt = "<image>\nGenerate an image."
    input_ids, pixel_values, attention_mask, grid_thws, _ = build_inputs(model, text_tokenizer, visual_tokenizer, uncond_prompt, uncond_image, width, height)
    with torch.inference_mode():
        no_both_cond = model.generate_condition(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, grid_thws=grid_thws, **gen_kwargs)

    input_img = input_img.resize((width, height))
    prompt = "<image>\n" + prompt.strip()
    with torch.inference_mode():
        input_ids, pixel_values, attention_mask, grid_thws, _ = build_inputs(model, text_tokenizer, visual_tokenizer, uncond_prompt, input_img, width, height)
        no_txt_cond = model.generate_condition(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, grid_thws=grid_thws, **gen_kwargs)

    input_ids, pixel_values, attention_mask, grid_thws, vae_pixel_values = build_inputs(model, text_tokenizer, visual_tokenizer, prompt, input_img, width, height)
    with torch.inference_mode():
        cond = model.generate_condition(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, grid_thws=grid_thws, **gen_kwargs)
        cond["vae_pixel_values"] = vae_pixel_values
        images = model.generate_img(cond=cond, no_both_cond=no_both_cond, no_txt_cond=no_txt_cond, **gen_kwargs)
    return images

def main():
    args = parse_args()
    model, loading_info = AutoModelForCausalLM.from_pretrained(args.model_path,
        torch_dtype=torch.bfloat16,
        output_loading_info=True,
        trust_remote_code=True
    )
    print(f'Loading info of Ovis-U1:\n{loading_info}')

    model = model.eval().to("cuda")
    model = model.to(torch.bfloat16)
    image_path = os.path.join(os.path.dirname(__file__), "docs", "imgs", "cat.png")
    pil_img = Image.open(image_path).convert('RGB')
    prompt = "add a hat to this cat."
    image = pipe_img_edit(model, pil_img, prompt,
        args.steps, args.txt_cfg, args.img_cfg)[0]
    image.save("test_image_edit.png")


if __name__ == "__main__":
    main()
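Note how pipe_img_edit prepares three condition sets — no_both_cond (blank image, generic prompt), no_txt_cond (input image, generic prompt), and cond (input image plus the edit instruction) — and hands them to model.generate_img together with img_cfg and txt_cfg. The actual combination lives inside the model's remote code; the sketch below only illustrates the standard dual classifier-free guidance used by instruction-based editors, and this weighting scheme is an assumption rather than something taken from Ovis-U1.

import torch

def dual_cfg(pred_uncond: torch.Tensor,
             pred_img: torch.Tensor,
             pred_full: torch.Tensor,
             img_cfg: float,
             txt_cfg: float) -> torch.Tensor:
    # Hypothetical combination of the three predictions:
    #   pred_uncond -> no_both_cond (blank image, generic prompt)
    #   pred_img    -> no_txt_cond  (input image, generic prompt)
    #   pred_full   -> cond         (input image + edit instruction)
    return (pred_uncond
            + img_cfg * (pred_img - pred_uncond)
            + txt_cfg * (pred_full - pred_img))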
test_img_to_txt.py ADDED
@@ -0,0 +1,84 @@
import os
import argparse
import torch
from PIL import Image
from transformers import AutoModelForCausalLM

def parse_args():
    parser = argparse.ArgumentParser(description="Test Text Generation")
    parser.add_argument(
        "--model_path",
        type=str,
        default="AIDC-AI/Ovis-U1-3B",
    )
    args = parser.parse_args()
    return args


def build_inputs(model, text_tokenizer, visual_tokenizer, prompt, pil_image):
    prompt, input_ids, pixel_values, grid_thws = model.preprocess_inputs(
        prompt,
        [pil_image],
        generation_preface=None,
        return_labels=False,
        propagate_exception=False,
        multimodal_type='single_image',
        fix_sample_overall_length_navit=False
    )
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    if pixel_values is not None:
        pixel_values = torch.cat([
            pixel_values.to(device=visual_tokenizer.device, dtype=torch.bfloat16) if pixel_values is not None else None
        ], dim=0)
    if grid_thws is not None:
        grid_thws = torch.cat([
            grid_thws.to(device=visual_tokenizer.device) if grid_thws is not None else None
        ], dim=0)
    return input_ids, pixel_values, attention_mask, grid_thws


def pipe_txt_gen(model, pil_image, prompt):
    text_tokenizer = model.get_text_tokenizer()
    visual_tokenizer = model.get_visual_tokenizer()
    gen_kwargs = dict(
        max_new_tokens=4096,
        do_sample=False,
        top_p=None,
        top_k=None,
        temperature=None,
        repetition_penalty=None,
        eos_token_id=text_tokenizer.eos_token_id,
        pad_token_id=text_tokenizer.pad_token_id,
        use_cache=True,
    )
    prompt = "<image>\n" + prompt
    input_ids, pixel_values, attention_mask, grid_thws = build_inputs(model, text_tokenizer, visual_tokenizer, prompt, pil_image)
    with torch.inference_mode():
        output_ids = model.generate(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, grid_thws=grid_thws, **gen_kwargs)[0]
        gen_text = text_tokenizer.decode(output_ids, skip_special_tokens=True)
    return gen_text


def main():
    # load model
    args = parse_args()
    model, loading_info = AutoModelForCausalLM.from_pretrained(args.model_path,
        torch_dtype=torch.bfloat16,
        output_loading_info=True,
        trust_remote_code=True
    )
    print(f'Loading info of Ovis-U1:\n{loading_info}')

    model = model.eval().to("cuda")
    model = model.to(torch.bfloat16)
    image_path = os.path.join(os.path.dirname(__file__), "docs", "imgs", "cat.png")
    pil_img = Image.open(image_path).convert('RGB')
    prompt = "What is it?"
    gen_txt = pipe_txt_gen(model, pil_img, prompt)
    print(gen_txt)


if __name__ == "__main__":
    main()
test_txt_to_img.py ADDED
@@ -0,0 +1,132 @@
import os
import argparse
import math
import torch
from PIL import Image
from transformers import AutoModelForCausalLM


def parse_args():
    parser = argparse.ArgumentParser(description="Test Text-to-Image")
    parser.add_argument(
        "--model_path",
        type=str,
        default="AIDC-AI/Ovis-U1-3B",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1024,
    )
    parser.add_argument(
        "--seed", type=int, default=42,
    )
    parser.add_argument(
        "--steps", type=int, default=50,
    )
    parser.add_argument(
        "--txt_cfg", type=float, default=5,
    )
    args = parser.parse_args()
    return args


def load_blank_image(width, height):
    pil_image = Image.new("RGB", (width, height), (255, 255, 255)).convert('RGB')
    return pil_image

def build_inputs(model, text_tokenizer, visual_tokenizer, prompt, pil_image, target_width, target_height):
    if pil_image is not None:
        target_size = (int(target_width), int(target_height))
        pil_image, vae_pixel_values, cond_img_ids = model.visual_generator.process_image_aspectratio(pil_image, target_size)
        cond_img_ids[..., 0] = 1.0
        vae_pixel_values = vae_pixel_values.unsqueeze(0).to(device=model.device)
        width = pil_image.width
        height = pil_image.height
        resized_height, resized_width = visual_tokenizer.smart_resize(height, width, max_pixels=visual_tokenizer.image_processor.min_pixels)
        pil_image = pil_image.resize((resized_width, resized_height))
    else:
        vae_pixel_values = None
        cond_img_ids = None

    prompt, input_ids, pixel_values, grid_thws = model.preprocess_inputs(
        prompt,
        [pil_image],
        generation_preface=None,
        return_labels=False,
        propagate_exception=False,
        multimodal_type='single_image',
        fix_sample_overall_length_navit=False
    )
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    if pixel_values is not None:
        pixel_values = torch.cat([
            pixel_values.to(device=visual_tokenizer.device, dtype=torch.bfloat16) if pixel_values is not None else None
        ], dim=0)
    if grid_thws is not None:
        grid_thws = torch.cat([
            grid_thws.to(device=visual_tokenizer.device) if grid_thws is not None else None
        ], dim=0)
    return input_ids, pixel_values, attention_mask, grid_thws, vae_pixel_values


def pipe_t2i(model, prompt, height, width, steps, cfg, seed=42):
    text_tokenizer = model.get_text_tokenizer()
    visual_tokenizer = model.get_visual_tokenizer()
    gen_kwargs = dict(
        max_new_tokens=1024,
        do_sample=False,
        top_p=None,
        top_k=None,
        temperature=None,
        repetition_penalty=None,
        eos_token_id=text_tokenizer.eos_token_id,
        pad_token_id=text_tokenizer.pad_token_id,
        use_cache=True,
        height=height,
        width=width,
        num_steps=steps,
        seed=seed,
        img_cfg=0,
        txt_cfg=cfg,
    )
    uncond_image = load_blank_image(width, height)
    uncond_prompt = "<image>\nGenerate an image."
    input_ids, pixel_values, attention_mask, grid_thws, _ = build_inputs(model, text_tokenizer, visual_tokenizer, uncond_prompt, uncond_image, width, height)
    with torch.inference_mode():
        no_both_cond = model.generate_condition(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, grid_thws=grid_thws, **gen_kwargs)
    prompt = "<image>\nDescribe the image by detailing the color, shape, size, texture, quantity, text, and spatial relationships of the objects:" + prompt
    no_txt_cond = None
    input_ids, pixel_values, attention_mask, grid_thws, vae_pixel_values = build_inputs(model, text_tokenizer, visual_tokenizer, prompt, uncond_image, width, height)
    with torch.inference_mode():
        cond = model.generate_condition(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, grid_thws=grid_thws, **gen_kwargs)
        cond["vae_pixel_values"] = vae_pixel_values
        images = model.generate_img(cond=cond, no_both_cond=no_both_cond, no_txt_cond=no_txt_cond, **gen_kwargs)
    return images


def main():
    args = parse_args()
    model, loading_info = AutoModelForCausalLM.from_pretrained(args.model_path,
        torch_dtype=torch.bfloat16,
        output_loading_info=True,
        trust_remote_code=True
    )
    print(f'Loading info of Ovis-U1:\n{loading_info}')

    model = model.eval().to("cuda")
    model = model.to(torch.bfloat16)
    prompt = "a cute cat"
    image = pipe_t2i(model, prompt, args.height, args.width, args.steps, args.txt_cfg)[0]
    image.save("test_t2i.png")


if __name__ == "__main__":
    main()