import os

import gradio as gr
import PIL.Image
import torch
from huggingface_hub import login
from peft import PeftConfig, PeftModel
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

from utils import parse_bbox_and_labels, display_boxes

# Log in with the token stored as a Space secret.
hf_token = os.environ["HF_TOKEN"]
login(token=hf_token)


def get_response(
    image: PIL.Image.Image, prompt: str, max_new_tokens: str
) -> tuple[str, PIL.Image.Image]:
    # Task prompts ("detect", "ocr", ...) are passed through as-is;
    # free-form questions get PaliGemma's "answer en" VQA prefix.
    if any(task in prompt for task in tasks):
        prompt = f"{prompt} \n"
        if "handwritten" in prompt:
            prompt = prompt + " \n"
    else:
        prompt = f"answer en {prompt} \n"

    raw_image = image.convert("RGB")
    width, height = raw_image.size
    inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(device)

    with torch.inference_mode():
        output = peft_model.generate(**inputs, max_new_tokens=int(max_new_tokens))

    # Decode only the newly generated tokens, skipping the echoed prompt.
    input_len = inputs["input_ids"].shape[-1]
    output = processor.decode(output[0][input_len:], skip_special_tokens=True)
    print(prompt)
    print(output)

    # Detection/segmentation outputs contain <locXXXX> tokens; draw the boxes.
    if "loc" in output:
        boxes, labels = parse_bbox_and_labels(output)
        raw_image = display_boxes(raw_image, boxes, labels, target_size=(width, height))
    return output, raw_image


if __name__ == "__main__":
    tasks = ["detect", "extract handwritten_text", "ocr", "segment"]
    device = torch.device("cpu")
    # bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)  # for GPU

    # Load the base model, then attach the fine-tuned PEFT (LoRA) adapter.
    peft_model_id = "vk888/paligemma2_vqav2_hw"
    model_id = "google/paligemma2-3b-pt-448"
    config = PeftConfig.from_pretrained(peft_model_id)
    base_model = PaliGemmaForConditionalGeneration.from_pretrained(
        config.base_model_name_or_path, device_map=device
    )  # , quantization_config=bnb_config
    peft_model = PeftModel.from_pretrained(base_model, peft_model_id)
    processor = PaliGemmaProcessor.from_pretrained(model_id)

    examples = [
        ["example/elron.png", "what is the total amount ?", 80],
        ["example/invoice360.jpg", "extract handwritten_text", 128],
        ["example/invoice1.png", "detect signature", 80],
        ["example/invoice1.png", "ocr", 200],
    ]

    iface = gr.Interface(
        cache_examples=False,
        fn=get_response,
        inputs=[
            gr.Image(type="pil"),
            gr.Textbox(placeholder="what is the balance due ?"),
            gr.Textbox(placeholder="200"),
        ],
        examples=examples,
        outputs=[gr.Textbox(), gr.Image(type="pil")],
        title="DocVQA with PaliGemma2 VLM",
        description=(
            "DocVQA with the PaliGemma2 VLM, running on CPU. "
            "Each prompt can take 4-5 minutes, so it is better to clone and run locally. "
            "Thanks for your patience :)"
        ),
    )
    iface.launch(share=True)