import gradio as gr from transformers import LlavaForConditionalGeneration, LlavaProcessor from PIL import Image import torch import io import warnings warnings.filterwarnings("ignore") def safe_convert_image(img): if isinstance(img, Image.Image): return img elif isinstance(img, bytes): return Image.open(io.BytesIO(img)) elif hasattr(img, "read"): return Image.open(img) else: raise ValueError("Format d'image non pris en charge.") # Chargement du modèle et du processeur model_id = "llava-hf/llava-1.5-7b-hf" processor = LlavaProcessor.from_pretrained(model_id) model = LlavaForConditionalGeneration.from_pretrained( model_id, torch_dtype=torch.float16, device_map="auto" ) # Fonction de VQA def vqa_llava(image, question): try: image = safe_convert_image(image) # ✅ Prompt spécifique à LLaVA prompt = f"\nUSER: {question.strip()}\nASSISTANT:" # Préparation des entrées inputs = processor( text=prompt, images=image, return_tensors="pt" ).to(model.device) # Génération de la réponse generate_ids = model.generate(**inputs, max_new_tokens=100) response = processor.batch_decode(generate_ids, skip_special_tokens=True)[0] return response.replace(prompt, "").strip() except Exception as e: return f"❌ Erreur : {str(e)}" # Interface Gradio interface = gr.Interface( fn=vqa_llava, inputs=[ gr.Image(type="pil", label="🖼️ Image"), gr.Textbox(lines=2, label="❓ Question (en anglais)") ], outputs=gr.Textbox(label="💬 Réponse"), title="🔎 Visual Question Answering avec LLaVA", description="Téléverse une image et pose une question visuelle (en anglais). Le modèle LLaVA-1.5-7B y répondra." ) if __name__ == "__main__": interface.launch()