gokaygokay committed
Commit 87966a5
1 Parent(s): 0c2f001
Files changed (2)
  1. app.py +58 -5
  2. requirements.txt +9 -2
app.py CHANGED
@@ -9,8 +9,9 @@ from huggingface_hub import InferenceClient
 import subprocess
 import torch
 from PIL import Image
-from transformers import AutoProcessor, AutoModelForCausalLM
-import random
+from transformers import AutoProcessor, AutoModelForCausalLM, Qwen2VLForConditionalGeneration
+from qwen_vl_utils import process_vision_info
+import numpy as np
 
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
@@ -21,6 +22,10 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to(device).eval()
 florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True)
 
+# Initialize Qwen2-VL-2B model
+qwen_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype="auto").to(device).eval()
+qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
+
 # Florence caption function
 @spaces.GPU
 def florence_caption(image):
@@ -44,6 +49,50 @@ def florence_caption(image):
     )
     return parsed_answer["<MORE_DETAILED_CAPTION>"]
 
+# Qwen2-VL-2B caption function
+@spaces.GPU
+def qwen_caption(image):
+    if not isinstance(image, Image.Image):
+        image = Image.fromarray(image)
+
+    image_path = array_to_image_path(np.array(image))
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": image_path,
+                },
+                {"type": "text", "text": "Describe this image in detail."},
+            ],
+        }
+    ]
+
+    text = qwen_processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    image_inputs, video_inputs = process_vision_info(messages)
+    inputs = qwen_processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to(device)
+
+    generated_ids = qwen_model.generate(**inputs, max_new_tokens=256)
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = qwen_processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+
+    return output_text[0]
+
 # Load JSON files
 def load_json_file(file_name):
     file_path = os.path.join("data", file_name)
@@ -469,6 +518,7 @@ def create_interface():
         with gr.Accordion("Image and Caption", open=False):
             input_image = gr.Image(label="Input Image (optional)")
             caption_output = gr.Textbox(label="Generated Caption", lines=3)
+            caption_model = gr.Radio(["Florence", "Qwen"], label="Caption Model", value="Florence")
             create_caption_button = gr.Button("Create Caption")
             add_caption_button = gr.Button("Add Caption to Prompt")
 
@@ -488,14 +538,17 @@ def create_interface():
         generate_text_button = gr.Button("Generate Prompt with LLM (Llama 3.1 70B)")
         text_output = gr.Textbox(label="Generated Text", lines=10)
 
-        def create_caption(image):
+        def create_caption(image, model):
             if image is not None:
-                return florence_caption(image)
+                if model == "Florence":
+                    return florence_caption(image)
+                elif model == "Qwen":
+                    return qwen_caption(image)
             return ""
 
         create_caption_button.click(
             create_caption,
-            inputs=[input_image],
+            inputs=[input_image, caption_model],
             outputs=[caption_output]
         )
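Note that the new qwen_caption function calls array_to_image_path, which this commit does not add; presumably it is already defined elsewhere in app.py. If it is not, a minimal sketch of such a helper, inferred only from the call site (the name, signature, and temp-file behavior below are assumptions, not part of the commit), would be:

import os
import tempfile
import uuid

import numpy as np
from PIL import Image

def array_to_image_path(image_array: np.ndarray) -> str:
    # Assumed helper: save a NumPy image array to a temporary PNG and return the
    # file path, which qwen_caption() then hands to the Qwen2-VL chat template.
    image = Image.fromarray(np.uint8(image_array))
    path = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}.png")
    image.save(path)
    return path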
requirements.txt CHANGED
@@ -1,4 +1,11 @@
 spaces
-transformers
 timm
-openai==1.37.0
+openai==1.37.0
+numpy==1.24.4
+Pillow==10.3.0
+Requests==2.31.0
+torch
+torchvision
+git+https://github.com/huggingface/transformers.git
+accelerate
+qwen-vl-utils
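requirements.txt now installs transformers from the GitHub main branch rather than a released version, presumably because Qwen2-VL support had not yet shipped in a stable release at the time, and adds qwen-vl-utils, which provides the process_vision_info helper used in app.py. As an optional local sanity check after pip install -r requirements.txt (an assumption about local setup, not part of the commit):

# Local sanity check: confirms the classes and helpers that app.py imports
# are available in the installed environment.
from transformers import AutoProcessor, AutoModelForCausalLM, Qwen2VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import numpy as np
import torch

print(torch.__version__, np.__version__)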