import base64
import glob
import json
import os

import boto3
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from colpali_engine.interpretability import (
    get_similarity_maps_from_embeddings,
    plot_all_similarity_maps,
)
from opensearchpy import OpenSearch, RequestsHttpConnection
from PIL import Image
from requests.auth import HTTPBasicAuth

# from colpali_engine.models import ColPali, ColPaliProcessor
# from colpali_engine.utils.torch_utils import get_torch_device
model_name = "vidore/colpali-v1.3"

# colpali_model = ColPali.from_pretrained(
#     model_name,
#     torch_dtype=torch.bfloat16,
#     device_map="cuda:0",  # Use "cuda:0" for GPU, "cpu" for CPU, or "mps" for Apple Silicon
# ).eval()
# colpali_processor = ColPaliProcessor.from_pretrained(model_name)
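# The commented-out block above shows how the ColPali model could be loaded
# locally; in this app, embedding inference is instead delegated to the
# SageMaker endpoint configured below ("colpali-endpoint").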
awsauth = HTTPBasicAuth('master', st.secrets['ml_search_demo_api_access'])
headers = {"Content-Type": "application/json"}

aos_client = OpenSearch(
    hosts=[{'host': 'search-opensearchservi-shjckef2t7wo-iyv6rajdgxg6jas25aupuxev6i.us-west-2.es.amazonaws.com', 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    pool_maxsize=20,
)
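# This OpenSearch domain stores the ColPali page embeddings; the
# 'colpali-vs-reindex' index is queried in colpali_search_rerank below.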
region_endpoint = "us-east-1"

# Your SageMaker endpoint name
endpoint_name = "colpali-endpoint"

# Create a SageMaker runtime client
runtime = boto3.client(
    "sagemaker-runtime",
    aws_access_key_id=st.secrets['user_access_key'],
    aws_secret_access_key=st.secrets['user_secret_key'],
    region_name=region_endpoint,
)
if 'top_img' not in st.session_state:
    st.session_state['top_img'] = ""
if 'query_token_vectors' not in st.session_state:
    st.session_state['query_token_vectors'] = ""
if 'query_tokens' not in st.session_state:
    st.session_state['query_tokens'] = ""
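# These keys cache the top-ranked page image and the query token data from the
# last search, so token-level similarity maps can be rendered on demand later
# without re-running retrieval.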
def call_nova(
    model,
    messages,
    system_message="",
    streaming=False,
    max_tokens=512,
    temp=0.0001,
    top_p=0.99,
    top_k=20,
    tools=None,
    verbose=False,
):
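    """Invoke an Amazon Nova model on Bedrock using the Nova-native request schema.

    Returns (full_response, answer_text) when streaming is False, or the raw
    response body stream when streaming is True.
    """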
    client = boto3.client(
        'bedrock-runtime',
        aws_access_key_id=st.secrets['user_access_key'],
        aws_secret_access_key=st.secrets['user_secret_key'],
        region_name=region_endpoint,
    )
    system_list = [{"text": system_message}]
    inf_params = {
        "max_new_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temp,
    }
    request_body = {
        "messages": messages,
        "system": system_list,
        "inferenceConfig": inf_params,
    }
    if tools is not None:
        tool_config = []
        for tool in tools:
            tool_config.append({"toolSpec": tool})
        request_body["toolConfig"] = {"tools": tool_config}
    if verbose:
        print("Request Body", request_body)
    if not streaming:
        response = client.invoke_model(modelId=model, body=json.dumps(request_body))
        model_response = json.loads(response["body"].read())
        return model_response, model_response["output"]["message"]["content"][0]["text"]
    else:
        response = client.invoke_model_with_response_stream(
            modelId=model, body=json.dumps(request_body)
        )
        return response["body"]
def get_base64_encoded_value(media_path):
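    """Read a file from disk and return its contents as a base64-encoded string."""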
    with open(media_path, "rb") as media_file:
        return base64.b64encode(media_file.read()).decode("utf-8")
def generate_ans(top_result, query):
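    """Ask Amazon Nova Pro to answer `query` from the top-ranked page image."""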
    print(query)
    system_message = (
        "Given an image of a PDF page, answer the question accurately. "
        "If you don't find the answer in the page, please say: I don't know."
    )
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "image": {
                        "format": "jpeg",
                        "source": {"bytes": get_base64_encoded_value(top_result)},
                    }
                },
                {"text": query},  # e.g., "what is the proportion of female new hires 2021-2023?"
            ],
        }
    ]
    model_response, content_text = call_nova(
        "amazon.nova-pro-v1:0", messages, system_message=system_message, max_tokens=300
    )
    print(content_text)
    return content_text
def img_highlight(img, batch_queries, query_tokens):
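    """Render per-token similarity maps for `img`, reusing cached maps if present.

    `batch_queries` holds the query token embeddings and `query_tokens` the
    matching token strings returned by the ColPali endpoint.
    """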
    img_name = os.path.basename(img)  # e.g., "my_image.png"
    # Construct the search pattern
    search_pattern = f"/home/user/app/similarity_maps/similarity_map_{img_name}_token_*"
    # Search for matching files
    matching_files = glob.glob(search_pattern)
    # If maps for this image already exist, return them instead of recomputing
    map_images = []
    if matching_files:
        print("✅ Matching similarity map exists:")
        for file_path in matching_files:
            print(f" - {file_path}")
            map_images.append({'file': file_path})
        return map_images
    # Reference: https://github.com/tonywu71/colpali-cookbooks/blob/main/examples/gen_colpali_similarity_maps.ipynb
    with open(img, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    # Construct payload with only the image
    payload = {"images": [img_b64]}

    # Send to endpoint
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )

    # Read response
    img_colpali_res = json.loads(response["Body"].read().decode())
    # Convert outputs to tensors
    image_embeddings = torch.tensor(img_colpali_res["image_embeddings"][0])  # [T, D] for this single image
    query_embeddings = torch.tensor(batch_queries)  # [T_q, D]

    # Ensure we use the full 1D mask vector, not a single value
    image_mask_list = img_colpali_res["image_mask"]
    if isinstance(image_mask_list[0], list):
        # Expected case: list of lists
        image_mask = torch.tensor(image_mask_list[0]).bool()
    else:
        # Edge case: already flattened
        image_mask = torch.tensor(image_mask_list).bool()
    print("Valid patch count:", image_mask.sum().item())

    # Add a batch dimension wherever the tensors arrive unbatched
    if query_embeddings.dim() == 2:
        query_embeddings = query_embeddings.unsqueeze(0)  # [1, T_q, D]
    if image_embeddings.dim() == 2:
        image_embeddings = image_embeddings.unsqueeze(0)  # [1, T, D]
    if image_mask.dim() == 1:
        image_mask = image_mask.unsqueeze(0)  # [1, T]

    print("query_embeddings shape:", query_embeddings.shape)
    print("image_embeddings shape:", image_embeddings.shape)
    print("image_mask shape:", image_mask.shape)

    # Get the number of image patches
    image = Image.open(img)
    n_patches = (img_colpali_res["patch_shape"]['height'], img_colpali_res["patch_shape"]['width'])
    print(f"Number of image patches: {n_patches}")
    # Generate the similarity maps
    batched_similarity_maps = get_similarity_maps_from_embeddings(
        image_embeddings=image_embeddings,
        query_embeddings=query_embeddings,
        n_patches=n_patches,
        image_mask=image_mask,
    )

    # Get the similarity map for our (only) input image
    similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)
    query_tokens_from_model = query_tokens[0]['tokens']

    plots = plot_all_similarity_maps(
        image=image,
        query_tokens=query_tokens_from_model,
        similarity_maps=similarity_maps,
        figsize=(8, 8),
        show_colorbar=False,
        add_title=True,
    )

    map_images = []
    for idx, (fig, ax) in enumerate(plots):
        if idx < 3:  # skip the first three tokens (likely special/prompt tokens, not query content)
            continue
        savepath = (
            "/home/user/app/similarity_maps/similarity_map_"
            + img.split("/")[-1] + "_token_" + str(idx) + "_"
            + query_tokens_from_model[idx] + ".png"
        )
        fig.savefig(savepath, bbox_inches="tight")
        map_images.append({'file': savepath})
        print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`")
        plt.close("all")
    return map_images
def colpali_search_rerank(query):
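    """Two-stage ColPali retrieval: ANN recall in OpenSearch, then exact MaxSim rerank.

    Returns a dict with the generated answer, the source image path, the
    image(s) to display, and an (empty) table slot.
    """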
    # If the user asked to view token-level maps for the previous answer,
    # render them from cached state instead of running a new search.
    if st.session_state.show_columns:
        print("show columns activated------------------------")
        st.session_state.maxSimImages = img_highlight(
            st.session_state.top_img,
            st.session_state.query_token_vectors,
            st.session_state.query_tokens,
        )
        st.session_state.show_columns = False
        return_val = {
            'text': st.session_state.answers_[0]['answer'],
            'source': st.session_state.answers_[0]['source'],
            'image': st.session_state.maxSimImages,
            'table': [],
        }
        st.session_state.input_query = st.session_state.questions_[-1]["question"]
        st.session_state.answers_.pop()
        st.session_state.questions_.pop()
        return return_val
    # Embed the query via the ColPali SageMaker endpoint
    payload = {"queries": [query]}
    body = json.dumps(payload)

    # Call the endpoint
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=body,
    )

    # Read the response
    result = json.loads(response["Body"].read().decode())
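    # Stage 1 (recall): mean-pool the query token vectors into a single vector
    # and run a nested k-NN search against the per-page patch vectors.
    # Stage 2 (precision): re-score the returned pages with the exact MaxSim
    # late-interaction sum over all query tokens, then keep the top 20.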
    final_docs_sorted_20 = []
    for batch_embeddings in result['query_embeddings']:
        # Mean-pool the query token vectors into one vector for the ANN search
        vec = np.array(batch_embeddings).mean(axis=0)
        hits = []
        query_ = {
            "size": 200,
            "query": {
                "nested": {
                    "path": "page_sub_vectors",
                    "query": {
                        "knn": {
                            "page_sub_vectors.page_sub_vector": {
                                "vector": vec.tolist(),
                                "k": 200,
                            }
                        }
                    },
                }
            },
        }
        response = aos_client.search(body=query_, index='colpali-vs-reindex')

        query_token_vectors = batch_embeddings
        final_docs = []
        hits += response['hits']['hits']
        for hit in hits:
            doc = {"id": hit["_id"], "score": hit["_score"], "image": hit["_source"]["image"]}
            page_vectors = hit['_source']['page_sub_vectors']
            # MaxSim: sum, over query tokens, of each token's best dot product
            # against the page's patch vectors
            add_score = 0
            for token_vector in query_token_vectors:
                query_token_vector = np.array(token_vector)
                scores = [
                    np.dot(query_token_vector, np.array(m['page_sub_vector']))
                    for m in page_vectors
                ]
                add_score += max(scores)
            doc["total_score"] = add_score
            final_docs.append(doc)
        final_docs_sorted = sorted(final_docs, key=lambda d: d['total_score'], reverse=True)
        final_docs_sorted_20.append(final_docs_sorted[:20])
| img = "/home/user/app/vs/"+final_docs_sorted_20[0][0]['image'] | |
| ans = generate_ans(img,query) | |
| images_highlighted = [{'file':img}] | |
| st.session_state.top_img = img | |
| st.session_state.query_token_vectors = query_token_vectors | |
| st.session_state.query_tokens = result['query_tokens'] | |
| return {'text':ans,'source':img,'image':images_highlighted,'table':[]}#[{'file':img}] | |