abiabidali committed on
Commit
9404648
·
verified ·
1 Parent(s): b5d2078

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -61
app.py CHANGED
@@ -4,30 +4,19 @@ import torch
4
  from PIL import Image
5
  from transformers import AutoProcessor, AutoModelForCausalLM
6
 
7
- # Install necessary dependencies (for local testing; skip if already installed)
8
- subprocess.run(
9
- "pip install flash-attn --no-build-isolation --global-option='--skip-cuda-build'",
10
- shell=True
11
- )
12
 
13
  # Initialize Florence model
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
- florence_model = AutoModelForCausalLM.from_pretrained(
16
- "microsoft/Florence-2-base", trust_remote_code=True
17
- ).to(device).eval()
18
- florence_processor = AutoProcessor.from_pretrained(
19
- "microsoft/Florence-2-base", trust_remote_code=True
20
- )
21
 
22
- # Define the caption generation function
23
  def generate_caption(image):
24
  if not isinstance(image, Image.Image):
25
  image = Image.fromarray(image)
26
-
27
- inputs = florence_processor(
28
- text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt"
29
- ).to(device)
30
 
 
31
  generated_ids = florence_model.generate(
32
  input_ids=inputs["input_ids"],
33
  pixel_values=inputs["pixel_values"],
@@ -46,48 +35,10 @@ def generate_caption(image):
46
  print("\n\nGeneration completed!:" + prompt)
47
  return prompt
48
 
49
- # Gradio Interface
50
- def save_to_csv(images, captions):
51
- import csv
52
- from io import StringIO
53
-
54
- # Create CSV content
55
- output = StringIO()
56
- writer = csv.writer(output)
57
- writer.writerow(["Filename", "Title", "Keywords"])
58
-
59
- for img, caption in zip(images, captions):
60
- filename = img.name if hasattr(img, "name") else "uploaded_image"
61
- title = caption[:50]
62
- keywords = caption.split(" ") # Simple keyword generation (replace with a better method)
63
- writer.writerow([filename, title, ", ".join(keywords)])
64
-
65
- output.seek(0)
66
- return output
67
-
68
- with gr.Blocks() as demo:
69
- with gr.Row():
70
- with gr.Column():
71
- input_images = gr.Image(
72
- label="Upload Images", type="pil", multiple=True
73
- )
74
- generate_button = gr.Button("Generate Captions")
75
- with gr.Column():
76
- output_texts = gr.Textbox(
77
- label="Generated Captions", lines=5, interactive=False
78
- )
79
- csv_output = gr.File(label="Download CSV")
80
-
81
- # Define event logic
82
- def process(images):
83
- captions = [generate_caption(img) for img in images]
84
- csv_file = save_to_csv(images, captions)
85
- return captions, csv_file
86
-
87
- generate_button.click(
88
- fn=process,
89
- inputs=[input_images],
90
- outputs=[output_texts, csv_output]
91
- )
92
-
93
- demo.launch(debug=True)
 
4
  from PIL import Image
5
  from transformers import AutoProcessor, AutoModelForCausalLM
6
 
7
+ # Install flash-attn library
8
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 
 
9
 
10
  # Initialize Florence model
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
13
+ florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
 
 
 
 
14
 
 
15
  def generate_caption(image):
16
  if not isinstance(image, Image.Image):
17
  image = Image.fromarray(image)
 
 
 
 
18
 
19
+ inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
20
  generated_ids = florence_model.generate(
21
  input_ids=inputs["input_ids"],
22
  pixel_values=inputs["pixel_values"],
 
35
  print("\n\nGeneration completed!:" + prompt)
36
  return prompt
37
 
38
+ # Gradio interface
39
+ io = gr.Interface(
40
+ generate_caption,
41
+ inputs=[gr.Image(label="Input Image")],
42
+ outputs=[gr.Textbox(label="Output Prompt", lines=2, show_copy_button=True)]
43
+ )
44
+ io.launch(debug=True)