manelhalima commited on
Commit
f2ed08b
·
verified ·
1 Parent(s): ff782aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -17
app.py CHANGED
@@ -12,49 +12,56 @@ def safe_convert_image(img):
12
  return img
13
  elif isinstance(img, bytes):
14
  return Image.open(io.BytesIO(img))
15
- elif hasattr(img, "read"): # Cas fichier Gradio
16
  return Image.open(img)
17
  else:
18
  raise ValueError("Format d'image non pris en charge.")
19
 
20
 
21
- #Charger le modèle et le processeur
22
  model_id = "llava-hf/llava-1.5-7b-hf"
23
- #processor = LlavaProcessor.from_pretrained(model_id)
24
- processor = LlavaProcessor.from_pretrained(model_id, use_fast=False)
25
-
26
  model = LlavaForConditionalGeneration.from_pretrained(
27
  model_id,
28
  torch_dtype=torch.float16,
29
  device_map="auto"
30
  )
31
 
32
- #Fonction de réponse à une question visuelle
33
  def vqa_llava(image, question):
34
- if not isinstance(image, Image.Image):
35
- image = Image.open(image)
36
  image = safe_convert_image(image)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- prompt = f"[INST] {question} [/INST]"
40
- inputs = processor(prompt, image, return_tensors="pt").to(model.device)
41
 
42
- generate_ids = model.generate(**inputs, max_new_tokens=100)
43
- response = processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
44
 
45
- # Nettoyage de la réponse
46
- return response.strip().replace(prompt, "").strip()
47
 
48
- #Interface Gradio
49
  interface = gr.Interface(
50
  fn=vqa_llava,
51
  inputs=[
52
  gr.Image(type="pil", label="🖼️ Image"),
53
- gr.Textbox(lines=2, label="❓ Question")
54
  ],
55
  outputs=gr.Textbox(label="💬 Réponse"),
56
  title="🔎 Visual Question Answering avec LLaVA",
57
- description="Pose une question sur une image et LLaVA répondra."
58
  )
59
 
60
  if __name__ == "__main__":
 
12
  return img
13
  elif isinstance(img, bytes):
14
  return Image.open(io.BytesIO(img))
15
+ elif hasattr(img, "read"):
16
  return Image.open(img)
17
  else:
18
  raise ValueError("Format d'image non pris en charge.")
19
 
20
 
21
+ # Chargement du modèle et du processeur
22
  model_id = "llava-hf/llava-1.5-7b-hf"
23
+ processor = LlavaProcessor.from_pretrained(model_id)
 
 
24
  model = LlavaForConditionalGeneration.from_pretrained(
25
  model_id,
26
  torch_dtype=torch.float16,
27
  device_map="auto"
28
  )
29
 
30
+ # Fonction de VQA
31
  def vqa_llava(image, question):
32
+ try:
 
33
  image = safe_convert_image(image)
34
 
35
+ # ✅ Prompt spécifique à LLaVA
36
+ prompt = f"<image>\nUSER: {question.strip()}\nASSISTANT:"
37
+
38
+ # Préparation des entrées
39
+ inputs = processor(
40
+ text=prompt,
41
+ images=image,
42
+ return_tensors="pt"
43
+ ).to(model.device)
44
+
45
+ # Génération de la réponse
46
+ generate_ids = model.generate(**inputs, max_new_tokens=100)
47
+ response = processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
48
 
49
+ return response.replace(prompt, "").strip()
 
50
 
51
+ except Exception as e:
52
+ return f"❌ Erreur : {str(e)}"
53
 
 
 
54
 
55
+ # Interface Gradio
56
  interface = gr.Interface(
57
  fn=vqa_llava,
58
  inputs=[
59
  gr.Image(type="pil", label="🖼️ Image"),
60
+ gr.Textbox(lines=2, label="❓ Question (en anglais)")
61
  ],
62
  outputs=gr.Textbox(label="💬 Réponse"),
63
  title="🔎 Visual Question Answering avec LLaVA",
64
+ description="Téléverse une image et pose une question visuelle (en anglais). Le modèle LLaVA-1.5-7B y répondra."
65
  )
66
 
67
  if __name__ == "__main__":