Spaces:
Sleeping
Sleeping
| import os | |
| import spaces | |
| import gradio as gr | |
| import torch | |
| from pdf2image import convert_from_path | |
| from PIL import Image | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from colpali_engine.models import ColQwen2, ColQwen2Processor | |
# Hub checkpoint shared by the retrieval model and its processor.
_CHECKPOINT = "manu/colqwen2-v1.0-alpha"

# Load ColQwen2 once at import time, in bfloat16 on the first CUDA device,
# and switch it to eval mode for inference.
model = ColQwen2.from_pretrained(
    _CHECKPOINT,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",  # or "mps" if on Apple Silicon
    attn_implementation="flash_attention_2",  # original note: should work on A100
).eval()

# Processor that turns queries / page images into model inputs and scores them.
processor = ColQwen2Processor.from_pretrained(_CHECKPOINT)
def search(query: str, ds, images, k):
    """Return the ``k`` indexed pages that score highest against ``query``.

    NOTE(review): ``spaces`` is imported at module level but never used; if this
    Space runs on ZeroGPU hardware, this function likely needs ``@spaces.GPU``.
    Confirm the target hardware.

    Args:
        query: Free-text search query.
        ds: Page embeddings from ``index_gpu`` (one CPU tensor per page), in the
            same order as ``images``.
        images: Indexed page images, aligned index-for-index with ``ds``.
        k: Number of results requested (the Gradio slider value; may be a float).

    Returns:
        A list of ``(image, caption)`` tuples for ``gr.Gallery``, best match first.

    Raises:
        gr.Error: If no documents have been indexed yet.
    """
    if not ds:
        # Scoring against an empty index would fail deep inside the processor.
        raise gr.Error("Please index documents before searching.")

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # model.device is a torch.device; comparing it with a plain str is never
    # equal, so normalize first instead of moving the model on every call.
    if torch.device(device) != model.device:
        model.to(device)

    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    scores = processor.score(qs, ds, device=device)
    # topk raises if k exceeds the number of indexed pages; it also needs an int.
    top_k = min(int(k), len(ds))
    top_k_indices = scores[0].topk(top_k).indices.tolist()
    return [(images[idx], f"Page {idx}") for idx in top_k_indices]
def index(files, ds):
    """Convert uploaded PDFs to page images and embed every page.

    Args:
        files: Uploaded PDF paths from ``gr.File``; presumably ``None`` (or empty)
            when the user clicks before uploading.
        ds: Previous embedding state. It is deliberately not reused: the
            returned images always replace the old ones, so the embeddings must
            be rebuilt from scratch to stay aligned with them.

    Returns:
        ``(status_message, embeddings, images)`` as produced by ``index_gpu``.

    Raises:
        gr.Error: If no files were uploaded, or if ``convert_files`` rejects
            the upload.
    """
    if not files:
        raise gr.Error("Please upload at least one PDF before indexing.")
    print("Converting files")
    images = convert_files(files)
    print(f"Files converted with {len(images)} images.")
    # Start from a fresh list: extending the old ``ds`` would keep embeddings
    # from a previous upload, so search() would map scores to the wrong page
    # (or index past the end of ``images``).
    return index_gpu(images, [])
def convert_files(files, max_pages=150):
    """Render every page of the given PDFs to images.

    Args:
        files: Iterable of PDF paths accepted by ``convert_from_path``; ``None``
            is treated as no files.
        max_pages: Exclusive upper bound on the total page count (150 by
            default, the original demo limit).

    Returns:
        A list of page images, in file order and then page order.

    Raises:
        gr.Error: As soon as the running page count reaches ``max_pages``.
    """
    images = []
    for f in files or []:
        images.extend(convert_from_path(f, thread_count=4))
        # Check after each file so we stop before rendering the remaining PDFs
        # once the limit has already been reached.
        if len(images) >= max_pages:
            raise gr.Error(f"The number of images in the dataset should be less than {max_pages}.")
    return images
def index_gpu(images, ds):
    """Embed page images with ColQwen2 and append the embeddings to ``ds``.

    NOTE(review): ``spaces`` is imported at module level but never used; if this
    Space runs on ZeroGPU hardware, this function likely needs ``@spaces.GPU``.
    Confirm the target hardware.

    Args:
        images: Sequence of page images to embed.
        ds: List that receives one CPU embedding tensor per image, in the same
            order as ``images``. It is extended in place and also returned.

    Returns:
        ``(status_message, ds, images)``, matching the Gradio outputs
        ``[message, embeds, imgs]``.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # model.device is a torch.device; comparing it with a plain str is never
    # equal, so normalize first instead of moving the model on every call.
    if torch.device(device) != model.device:
        model.to(device)

    # Preprocess pages in batches of 4 and move each batch to the target device
    # once, inside the collate function; no second transfer is needed below.
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: processor.process_images(batch).to(device),
    )

    with torch.no_grad():
        for batch_doc in tqdm(dataloader):
            embeddings_doc = model(**batch_doc)
            # Store one CPU tensor per page, preserving page order.
            ds.extend(torch.unbind(embeddings_doc.to("cpu")))
    return f"Uploaded and converted {len(images)} pages", ds, images
# Gradio UI: the left column uploads and indexes PDFs, the right column takes a
# query; per-session embeddings and page images are held in gr.State.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # NOTE(review): the 4-byte emoji in this block were reconstructed from
    # mojibake (only the leading byte survived); confirm against the original.
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
    gr.Markdown("""Demo to test ColQwen2 (ColPali) on PDF documents.
ColPali is a model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
This demo allows you to upload PDF files and search for the most relevant pages based on your query.
Refresh the page if you change documents!
⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
Other models will be released with better robustness towards different languages and document formats!
""")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            file = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDFs")
            convert_button = gr.Button("🔄 Index documents")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            # Page embeddings and the matching page images, aligned by index.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)

    # Define the actions
    search_button = gr.Button("🔍 Search", variant="primary")
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)

    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
if __name__ == "__main__":
    # Limit the event queue to 10 pending events, then start the app with debug=True.
    app = demo.queue(max_size=10)
    app.launch(debug=True)