File size: 1,918 Bytes
cf4be8f
96b9bee
cf4be8f
96b9bee
ff782aa
3412694
 
cf4be8f
ff782aa
 
 
 
 
 
f2ed08b
ff782aa
 
 
 
 
f2ed08b
96b9bee
f2ed08b
96b9bee
 
 
 
 
cf4be8f
f2ed08b
96b9bee
f2ed08b
ff782aa
 
f2ed08b
 
 
 
 
 
 
 
 
 
 
 
 
cf4be8f
f2ed08b
cf4be8f
f2ed08b
 
cf4be8f
96b9bee
f2ed08b
96b9bee
 
 
 
f2ed08b
96b9bee
 
 
f2ed08b
96b9bee
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
from transformers import LlavaForConditionalGeneration, LlavaProcessor
from PIL import Image
import torch
import io
import warnings
warnings.filterwarnings("ignore")


def safe_convert_image(img):
    if isinstance(img, Image.Image):
        return img
    elif isinstance(img, bytes):
        return Image.open(io.BytesIO(img))
    elif hasattr(img, "read"):
        return Image.open(img)
    else:
        raise ValueError("Format d'image non pris en charge.")


# Chargement du modèle et du processeur
model_id = "llava-hf/llava-1.5-7b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Fonction de VQA
def vqa_llava(image, question):
    try:
        image = safe_convert_image(image)

        # ✅ Prompt spécifique à LLaVA
        prompt = f"<image>\nUSER: {question.strip()}\nASSISTANT:"

        # Préparation des entrées
        inputs = processor(
            text=prompt,
            images=image,
            return_tensors="pt"
        ).to(model.device)

        # Génération de la réponse
        generate_ids = model.generate(**inputs, max_new_tokens=100)
        response = processor.batch_decode(generate_ids, skip_special_tokens=True)[0]

        return response.replace(prompt, "").strip()

    except Exception as e:
        return f"❌ Erreur : {str(e)}"


# Interface Gradio
interface = gr.Interface(
    fn=vqa_llava,
    inputs=[
        gr.Image(type="pil", label="🖼️ Image"),
        gr.Textbox(lines=2, label="❓ Question (en anglais)")
    ],
    outputs=gr.Textbox(label="💬 Réponse"),
    title="🔎 Visual Question Answering avec LLaVA",
    description="Téléverse une image et pose une question visuelle (en anglais). Le modèle LLaVA-1.5-7B y répondra."
)

if __name__ == "__main__":
    interface.launch()