import os
import io
import sys
import re
import json
import uuid
import traceback

import requests
import numpy as np
import torch
import faiss
import gradio as gr

from datasets import load_dataset, load_from_disk, Dataset, get_dataset_config_names
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer, CrossEncoder
from groq import Groq
# Device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Sentence embedder shared by the non-legal domains
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Preload chunked datasets and FAISS indices built offline
hf_dataset_cs = load_from_disk("cs_dataset")
faiss_index_cs = faiss.read_index("cs_index/faiss.index")
hf_dataset_med = load_from_disk("med_dataset")
faiss_index_med = faiss.read_index("med_index/faiss.index")
hf_dataset_gk = load_from_disk("gk_dataset")
faiss_index_gk = faiss.read_index("gk_index/faiss.index")
hf_dataset_fin = load_from_disk("fin_dataset")
faiss_index_fin = faiss.read_index("fin_index/faiss.index")

# RAGBench test splits used for ground-truth evaluation
legal_dataset = load_dataset("rungalileo/ragbench", "cuad", split="test")
med_dataset = load_dataset("rungalileo/ragbench", "pubmedqa", split="test")
gk_dataset = load_dataset("rungalileo/ragbench", "hotpotqa", split="test")
cs_dataset = load_dataset("rungalileo/ragbench", "emanual", split="test")
fin_dataset = load_dataset("rungalileo/ragbench", "finqa", split="test")

# BGE cross-encoder reranker
reranker = CrossEncoder("BAAI/bge-reranker-base", max_length=512)

# Legal-BERT encoder used for the legal domain
model_name = "nlpaueb/legal-bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
model.eval()
def retrieve_top_k(query, domain="Legal", k=8):
    """Retrieve the top-k chunks for a query from the legal FAISS index using Legal-BERT."""
    # Embed the query with mean pooling over the last hidden state,
    # reusing the Legal-BERT tokenizer/model loaded at module level.
    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    query_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()

    # Load the legal FAISS index and the matching chunked dataset
    faiss_index = faiss.read_index("legal_index/faiss.index")
    dataset = load_from_disk("legal_dataset")

    # Perform FAISS search and return the top-k matching chunks
    D, I = faiss_index.search(query_embedding.astype("float32"), k)
    top_chunks = [dataset[int(idx)]["text"] for idx in I[0]]
    return top_chunks
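# Illustrative usage (hypothetical query, not executed here):
#   chunks = retrieve_top_k("What notice period applies to termination?", "Legal", k=8)
#   # chunks -> list of 8 chunk strings, ordered by ascending FAISS distance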
# Retrieval function for the non-legal domains, using the preloaded datasets and indices
def retrieve_top_c(query, domain, embedder, k=5):
    if domain == "CS":
        hf_dataset, faiss_index = hf_dataset_cs, faiss_index_cs
    elif domain == "Medical":
        hf_dataset, faiss_index = hf_dataset_med, faiss_index_med
    elif domain == "GK":
        hf_dataset, faiss_index = hf_dataset_gk, faiss_index_gk
    elif domain == "Finance":
        hf_dataset, faiss_index = hf_dataset_fin, faiss_index_fin
    else:
        raise ValueError(f"Unknown domain: {domain}")

    # Encode the query and search the domain index
    query_embedding = embedder.encode([query]).astype("float32")
    distances, indices = faiss_index.search(query_embedding, k)
    return [hf_dataset[int(i)]["text"] for i in indices[0]]
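# Illustrative usage (hypothetical query, not executed here): all-MiniLM-L6-v2 produces
# 384-dimensional vectors, so each domain index is assumed to have been built with the
# same embedder for the search to be meaningful.
#   chunks = retrieve_top_c("How do I reset the device to factory settings?", "CS", embedder, k=5)
#   # chunks -> list of 5 context strings from cs_dataset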
# Groq client; the API key is read from the environment instead of being hard-coded
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
def rerank_documents_bge(query, documents, top_n=5, return_scores=False):
    """
    Rerank documents using the BAAI/bge-reranker-base CrossEncoder.

    Args:
        query (str): The query string.
        documents (List[str]): List of candidate documents.
        top_n (int): Number of top results to return.
        return_scores (bool): Whether to return scores along with documents.

    Returns:
        List[str] or List[Tuple[str, float]]
    """
    if not documents:
        return []

    # Score (query, doc) pairs and sort by relevance, descending
    pairs = [(query, doc) for doc in documents]
    scores = reranker.predict(pairs, batch_size=16)
    reranked = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)

    if return_scores:
        return reranked[:top_n]
    return [doc for doc, _ in reranked[:top_n]]
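# Illustrative usage (toy documents, not executed here): the cross-encoder scores each
# (query, doc) pair jointly, so its ordering can differ from the FAISS bi-encoder ranking.
#   rerank_documents_bge("capital of France",
#                        ["Paris is the capital of France.", "Berlin is in Germany."],
#                        top_n=1, return_scores=True)
#   # -> [("Paris is the capital of France.", <score>)], higher score = more relevant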
def generate_response_rag(query, domain):
    # Step 1: Retrieve top-k context chunks from the domain-specific FAISS index
    if domain == "Legal":
        top_chunks = retrieve_top_k(query, "Legal")
    else:
        top_chunks = retrieve_top_c(query, domain, embedder)

    # Step 2: Rerank the retrieved chunks with the BGE cross-encoder and keep the best 5
    reranked_chunks_bge = rerank_documents_bge(query, top_chunks, top_n=5)
    context = "\n\n".join(reranked_chunks_bge)

    # Step 3: Build the RAG prompt
    prompt = f"""You are a helpful assistant.
Using only the information from the retrieved context, answer the following question. If the answer cannot be derived, say "I don't know." Always prefix the answer with **Answer:**

Context: {context}

Question: {query}

Answer:"""

    # Step 4: Call the generator LLM via Groq
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama3-70b-8192",  # alternatives: gemma2-9b-it, qwen/qwen3-32b, deepseek-r1-distill-llama-70b, mistral-saba-24b
        temperature=0.0,
    )
    return context, chat_completion.choices[0].message.content.strip()

    # Alternative generator using the OpenAI API (kept for reference, unreachable):
    # response = openai.chat.completions.create(
    #     model="gpt-3.5-turbo",
    #     messages=[{"role": "user", "content": prompt}],
    #     temperature=0.0,
    #     max_tokens=1024,
    # )
    # return response.choices[0].message.content
# JUDGE LLM
def split_into_keyed_sentences(text, prefix):
    """Splits text into sentences keyed like '0a', '0b', ... (or 'a', 'b', ... for an empty prefix)."""
    # Basic sentence tokenizer with alphabetic keys
    sentences = re.split(r"(?<=[.?!])\s+", text.strip())
    keyed = {}
    for i, s in enumerate(sentences):
        key = f"{prefix}{chr(97 + i)}"  # 'a', 'b', ...; beyond 26 sentences this continues into '{', '|', '}', '~'
        if s:
            keyed[key] = s.strip()
    return keyed
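# Illustrative behavior (not executed here):
#   split_into_keyed_sentences("Paris is in France. It is the capital.", "0")
#   # -> {"0a": "Paris is in France.", "0b": "It is the capital."}
# With an empty prefix the keys are just "a", "b", ...; the overflow characters past 'z'
# are why the key-extraction regex further below also matches '{', '|', '}', '~'.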
def judge_response_rag(query, domain):
    # Generate the answer along with the reranked context it was grounded in
    context, response = generate_response_rag(query, domain)

    # Split the context into keyed sentences ('0a', '0b', ...)
    document_keys = split_into_keyed_sentences(context, "0")

    # Keep only the text after the "**Answer" marker emitted by the generator
    response = response[response.find("**Answer"):].replace("**Answer", "")
    print(f"Response from generator LLM: {response}")
    response_keys = split_into_keyed_sentences(response, "")

    # Rebuild the keyed sections for the judge prompt
    documents_formatted = "\n".join([f"{k}. {v}" for k, v in document_keys.items()])
    response_formatted = "\n".join([f"{k}. {v}" for k, v in response_keys.items()])
| prompt = f"""I asked someone to answer a question based on one or more documents. | |
| Your task is to review their response and assess whether or not each sentence | |
| in that response is supported by text in the documents. And if so, which | |
| sentences in the documents provide that support. You will also tell me which | |
| of the documents contain useful information for answering the question, and | |
| which of the documents the answer was sourced from. | |
| Here are the documents, each of which is split into sentences. Alongside each | |
| sentence is associated key, such as ’0a.’ or ’0b.’ that you can use to refer | |
| to it: | |
| ''' | |
| {documents_formatted} | |
| ''' | |
| The question was: | |
| ''' | |
| {query} | |
| ''' | |
| Here is their response, split into sentences. Alongside each sentence is | |
| associated key, such as ’a.’ or ’b.’ that you can use to refer to it. Note | |
| that these keys are unique to the response, and are not related to the keys | |
| in the documents: | |
| ''' | |
| {response_formatted} | |
| ''' | |
| You must respond with a JSON object matching this schema: | |
| ''' | |
| {{ | |
| "relevance_explanation": string, | |
| "all_relevant_sentence_keys": [string], | |
| "overall_supported_explanation": string, | |
| "overall_supported": boolean, | |
| "sentence_support_information": [ | |
| {{ | |
| "response_sentence_key": string, | |
| "explanation": string, | |
| "supporting_sentence_keys": [string], | |
| "fully_supported": boolean | |
| }}, | |
| ], | |
| "all_utilized_sentence_keys": [string] | |
| }} | |
| ''' | |
| The relevance_explanation field is a string explaining which documents | |
| contain useful information for answering the question. Provide a step-by-step | |
| breakdown of information provided in the documents and how it is useful for | |
| answering the question. | |
| The all_relevant_sentence_keys field is a list of all document sentences keys | |
| (e.g. ’0a’) that are revant to the question. Include every sentence that is | |
| useful and relevant to the question, even if it was not used in the response, | |
| or if only parts of the sentence are useful. Ignore the provided response when | |
| making this judgement and base your judgement solely on the provided documents | |
| and question. Omit sentences that, if removed from the document, would not | |
| impact someone’s ability to answer the question. | |
| The overall_supported_explanation field is a string explaining why the response | |
| *as a whole* is or is not supported by the documents. In this field, provide a | |
| step-by-step breakdown of the claims made in the response and the support (or | |
| lack thereof) for those claims in the documents. Begin by assessing each claim | |
| separately, one by one; don’t make any remarks about the response as a whole | |
| until you have assessed all the claims in isolation. | |
| The overall_supported field is a boolean indicating whether the response as a | |
| whole is supported by the documents. This value should reflect the conclusion | |
| you drew at the end of your step-by-step breakdown in overall_supported_explanation. | |
| In the sentence_support_information field, provide information about the support | |
| *for each sentence* in the response. | |
| The sentence_support_information field is a list of objects, one for each sentence | |
| in the response. Each object MUST have the following fields: | |
| - response_sentence_key: a string identifying the sentence in the response. | |
| This key is the same as the one used in the response above. | |
| - explanation: a string explaining why the sentence is or is not supported by the | |
| documents. | |
| - supporting_sentence_keys: keys (e.g. ’0a’) of sentences from the documents that | |
| support the response sentence. If the sentence is not supported, this list MUST | |
| be empty. If the sentence is supported, this list MUST contain one or more keys. | |
| In special cases where the sentence is supported, but not by any specific sentence, | |
| you can use the string "supported_without_sentence" to indicate that the sentence | |
| is generally supported by the documents. Consider cases where the sentence is | |
| expressing inability to answer the question due to lack of relevant information in | |
| the provided contex as "supported_without_sentence". In cases where the sentence | |
| is making a general statement (e.g. outlining the steps to produce an answer, or | |
| summarizing previously stated sentences, or a transition sentence), use the | |
| sting "general".In cases where the sentence is correctly stating a well-known fact, | |
| like a mathematical formula, use the string "well_known_fact". In cases where the | |
| sentence is performing numerical reasoning (e.g. addition, multiplication), use | |
| the string "numerical_reasoning". | |
| - fully_supported: a boolean indicating whether the sentence is fully supported by | |
| the documents. | |
| - This value should reflect the conclusion you drew at the end of your step-by-step | |
| breakdown in explanation. | |
| - If supporting_sentence_keys is an empty list, then fully_supported must be false. | |
| 17 | |
| - Otherwise, use fully_supported to clarify whether everything in the response | |
| sentence is fully supported by the document text indicated in supporting_sentence_keys | |
| (fully_supported = true), or whether the sentence is only partially or incompletely | |
| supported by that document text (fully_supported = false). | |
| The all_utilized_sentence_keys field is a list of all sentences keys (e.g. ’0a’) that | |
| were used to construct the answer. Include every sentence that either directly supported | |
| the answer, or was implicitly used to construct the answer, even if it was not used | |
| in its entirety. Omit sentences that were not used, and could have been removed from | |
| the documents without affecting the answer. | |
| You must respond with a valid JSON string. Use escapes for quotes, e.g. ‘\\"‘, and | |
| newlines, e.g. ‘\\n‘. Do not write anything before or after the JSON string. Do not | |
| wrap the JSON string in backticks like ‘‘‘ or ‘‘‘json. | |
| As a reminder: your task is to review the response and assess which documents contain | |
| useful information pertaining to the question, and how each sentence in the response | |
| is supported by the text in the documents.\ | |
| """ | |
    # Call the judge LLM
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="meta-llama/llama-4-maverick-17b-128e-instruct",  # alternatives: deepseek-r1-distill-llama-70b, llama3-70b-8192
    )
    return documents_formatted, chat_completion.choices[0].message.content.strip()

    # Alternative judge using the OpenAI API (kept for reference, unreachable):
    # chat_completion = openai.chat.completions.create(
    #     messages=[{"role": "user", "content": prompt}],
    #     model="gpt-4o",
    #     max_tokens=1024,
    # )
    # return documents_formatted, chat_completion.choices[0].message.content
def extract_retrieved_sentence_keys(document_text: str) -> list[str]:
    """
    Extracts sentence keys like '0a.', '0b.', etc. from a formatted document string.

    Parameters:
    - document_text (str): full text of the document with sentence keys

    Returns:
    - List of unique sentence keys in the order they appear
    """
    # Match keys like 0a., 0b., ..., 0z., and the overflow keys 0{., 0|., 0}., 0~.
    pattern = r"\b0[\w\{\|\}~]\."
    matches = re.findall(pattern, document_text)
    return list(dict.fromkeys(matches))  # de-duplicate while preserving order
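# Illustrative behavior (not executed here):
#   extract_retrieved_sentence_keys("0a. First sentence. 0b. Second sentence.")
#   # -> ["0a.", "0b."]
# The keys keep their trailing periods; only the count of these keys is used below,
# so they never have to match the bare "0a"-style keys returned by the judge.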
def compute_ragbench_metrics(judge_response: dict, retrieved_sentence_keys: list[str]) -> dict:
    """
    Computes RAGBench-style metrics from the judge LLM response.

    Parameters:
    - judge_response (dict): JSON response from the judge LLM
    - retrieved_sentence_keys (list of str): all sentence keys from the retrieved documents

    Returns:
    - Dictionary with Context Relevance, Context Utilization, Completeness, and Adherence
    """
    R = set(judge_response.get("all_relevant_sentence_keys", []))  # relevant sentences
    U = set(judge_response.get("all_utilized_sentence_keys", []))  # utilized sentences
    intersection_RU = R & U

    total_retrieved = len(retrieved_sentence_keys)
    len_R = len(R)
    len_U = len(U)
    len_intersection = len(intersection_RU)

    # Context Relevance: fraction of retrieved context that is relevant
    context_relevance = len_R / total_retrieved if total_retrieved else 0.0
    # Context Utilization: fraction of retrieved context that was used
    context_utilization = len_U / total_retrieved if total_retrieved else 0.0
    # Completeness: fraction of relevant content that was used
    completeness = len_intersection / len_R if len_R else 0.0
    # Adherence: 1 if every response sentence is fully supported, else 0
    is_fully_supported = all(
        s.get("fully_supported", False)
        for s in judge_response.get("sentence_support_information", [])
    )
    adherence = 1.0 if is_fully_supported and judge_response.get("overall_supported", False) else 0.0

    return {
        "Context Relevance": round(context_relevance, 4),
        "Context Utilization": round(context_utilization, 4),
        "Completeness": round(completeness, 4),
        "Adherence": adherence,
    }
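# Worked example (illustrative values, not executed here). Suppose the judge marks
# R = {"0a", "0b"} as relevant, U = {"0b", "0c"} as utilized, 4 sentences were retrieved
# in total, and every response sentence is fully supported:
#   Context Relevance   = |R| / 4       = 0.5
#   Context Utilization = |U| / 4       = 0.5
#   Completeness        = |R ∩ U| / |R| = 0.5
#   Adherence           = 1.0  (all sentences fully supported and overall_supported = True)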
def evaluate_rag_pipeline(domain, q_indices):
    from sklearn.metrics import roc_auc_score

    def safe_append(gt_list, pred_list, gt_val, pred_val):
        if gt_val is not None and pred_val is not None:
            gt_list.append(gt_val)
            pred_list.append(pred_val)

    def clean_and_parse_json_block(text):
        # Strip a markdown-style code block if present
        code_block_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
        if code_block_match:
            text = code_block_match.group(1).strip()
        # Remove invalid/control characters that break decoding
        text = re.sub(r"[^\x20-\x7E\n\t]", "", text)
        try:
            return json.loads(text)
        except json.JSONDecodeError as e:
            print("❌ JSON Decode Error:", e)
            print("⚠️ Cleaned text:\n", text)
            raise
    gt_relevance, pred_relevance = [], []
    gt_utilization, pred_utilization = [], []
    gt_completeness, pred_completeness = [], []
    gt_adherence, pred_adherence = [], []

    if domain == "Legal":
        dataset = legal_dataset
    elif domain == "Medical":
        dataset = med_dataset
    elif domain == "GK":
        dataset = gk_dataset
    elif domain == "CS":
        dataset = cs_dataset
    elif domain == "Finance":
        dataset = fin_dataset
    else:
        raise ValueError(f"Unknown domain: {domain}")
    for i in q_indices:
        query = dataset[i]["question"]
        print(f"\n\n\nQuery {i}: {query}\n====================================================================")

        documents_formatted, response = judge_response_rag(query, domain)
        judge_response = clean_and_parse_json_block(response)
        print(f"\ndocuments_formatted: {documents_formatted}")
        print(f"\n======================================================================\nResponse: {judge_response}")

        retrieved_sentences = extract_retrieved_sentence_keys(documents_formatted)
        predicted = compute_ragbench_metrics(judge_response, retrieved_sentences)

        # Ground-truth annotations from RAGBench
        gt_r = dataset[i].get("relevance_score")
        gt_u = dataset[i].get("utilization_score")
        gt_c = dataset[i].get("completeness_score")
        gt_a = dataset[i].get("gpt3_adherence")

        safe_append(gt_relevance, pred_relevance, gt_r, predicted["Context Relevance"])
        safe_append(gt_utilization, pred_utilization, gt_u, predicted["Context Utilization"])
        safe_append(gt_completeness, pred_completeness, gt_c, predicted["Completeness"])
        if gt_a is not None and predicted["Adherence"] is not None:
            safe_append(gt_adherence, pred_adherence, int(gt_a), int(predicted["Adherence"]))
    def compute_rmse(gt, pred):
        return round(np.sqrt(np.mean((np.array(gt) - np.array(pred)) ** 2)), 4)

    result = {
        "Context Relevance": compute_rmse(gt_relevance, pred_relevance),
        "Context Utilization": compute_rmse(gt_utilization, pred_utilization),
        "Completeness": compute_rmse(gt_completeness, pred_completeness),
        "Adherence": compute_rmse(gt_adherence, pred_adherence),
    }
    # AUC-ROC is only defined when both adherence classes appear in the ground truth
    if len(set(gt_adherence)) == 2:
        result["AUC-ROC (Adherence)"] = round(roc_auc_score(gt_adherence, pred_adherence), 4)
    else:
        result["AUC-ROC (Adherence)"] = "N/A - one class only"
    return result
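# Example of the returned dict shape (numbers are placeholders, not real results):
#   {"Context Relevance": 0.18, "Context Utilization": 0.22, "Completeness": 0.25,
#    "Adherence": 0.41, "AUC-ROC (Adherence)": 0.67}
# Lower RMSE means the judge's estimates track the RAGBench ground-truth annotations
# more closely; AUC-ROC is reported only when both adherence classes occur.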
# Gradio wrapper
def evaluate_rag_gradio(domain, q_indices_str):
    # Capture stdout so the run log can be shown in the UI
    log_stream = io.StringIO()
    sys.stdout = log_stream
    try:
        # Parse comma-separated indices
        q_indices = [int(x.strip()) for x in q_indices_str.split(",") if x.strip().isdigit()]
        results = evaluate_rag_pipeline(domain, q_indices)
        logs = log_stream.getvalue()
        return results, logs
    except Exception as e:
        traceback.print_exc()
        return {"error": str(e)}, log_stream.getvalue()
    finally:
        sys.stdout = sys.__stdout__  # Restore stdout
# Gradio interface
iface = gr.Interface(
    fn=evaluate_rag_gradio,
    inputs=[
        gr.Dropdown(choices=["Legal", "Medical", "GK", "CS", "Finance"], label="Domain"),
        gr.Textbox(label="Comma-separated Query Indices (e.g. 89,121,245)", lines=1),
    ],
    outputs=[
        gr.JSON(label="Evaluation Metrics (RMSE & AUC-ROC)"),
        gr.Textbox(label="Execution Log", lines=10, interactive=True),
    ],
    title="RAG Evaluation Dashboard",
    description="Evaluate the RAG pipeline on selected queries using LLM-based generation and an LLM judge.",
)

# Launch the app
iface.launch(server_name="0.0.0.0", server_port=7860, debug=True)