import logging
import os

import gradio as gr
import nltk
import openai
import torch
from datasets import load_dataset
from langchain.docstore.document import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from nltk.tokenize import sent_tokenize

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Download NLTK sentence-tokenizer data (the only NLTK resources this script uses)
nltk.download('punkt')
nltk.download('punkt_tab')

# Read the OpenAI API key from the environment; never hard-code secrets in source
openai.api_key = os.environ["OPENAI_API_KEY"]

# Load selected RagBench datasets
logger.info("Starting dataset loading...")
ragbench = {}
datasets_to_load = ['covidqa', 'hotpotqa', 'pubmedqa']

for dataset in datasets_to_load:
    try:
        ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset, split='train')
        logger.info(f"Successfully loaded {dataset}")
    except Exception as e:
        logger.error(f"Failed to load {dataset}: {e}")
        continue

logger.info(f"Loaded {len(ragbench)} datasets successfully")

# Initialize the embedding model, on GPU when available
model_name = 'sentence-transformers/all-mpnet-base-v2'
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs={'device': device},
)


def chunk_documents_semantic(documents, max_chunk_size=500):
    """Split documents into chunks of whole sentences, each at most max_chunk_size characters."""
    chunks = []

    def chunk_text(text):
        current_chunk = ""
        for sentence in sent_tokenize(text):
            if len(current_chunk) + len(sentence) <= max_chunk_size:
                current_chunk += sentence + " "
            else:
                chunks.append(current_chunk.strip())
                current_chunk = sentence + " "
        if current_chunk:
            chunks.append(current_chunk.strip())

    for doc in documents:
        # RagBench stores documents either as a single string or as a list of passages
        if isinstance(doc, list):
            for passage in doc:
                chunk_text(passage)
        else:
            chunk_text(doc)
    return chunks


# Chunk every dataset's documents and wrap each chunk as a LangChain Document
documents = []
for dataset_name, dataset in ragbench.items():
    logger.info(f"Processing {dataset_name}")
    original_documents = dataset['documents']
    chunked_documents = chunk_documents_semantic(original_documents)
    documents.extend([Document(page_content=chunk) for chunk in chunked_documents])
    logger.info(f"Processed {len(chunked_documents)} chunks from {dataset_name}")

# Build the Chroma vector store over all chunks and persist it to disk
vectordb = Chroma.from_documents(
    documents=documents,
    embedding=embedding_model,
    persist_directory='./docs/chroma/'
)
vectordb.persist()
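# Illustrative sketch (not in the original script): a later run can reload the
# persisted index instead of re-embedding everything, assuming the same
# embedding model and persist directory are used:
#
#   vectordb = Chroma(
#       persist_directory='./docs/chroma/',
#       embedding_function=embedding_model,
#   )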

def process_query(query, dataset_choice):
    try:
        logger.info(f"Processing query for {dataset_choice}: {query}")

        # Maximal marginal relevance search trades off relevance against diversity
        # among the retrieved chunks. Note that retrieval runs over the combined
        # index; dataset_choice is only passed through to the prompt below.
        relevant_docs = vectordb.max_marginal_relevance_search(
            query,
            k=5,
            fetch_k=10
        )
        context = " ".join([doc.page_content for doc in relevant_docs])

        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "system",
                    "content": "You are a specialized assistant for the RagBench dataset. Provide precise answers based solely on the given context."
                },
                {
                    "role": "user",
                    "content": f"Dataset: {dataset_choice}\nContext: {context}\nQuestion: {query}\n\nProvide a detailed answer using only the information from the context above."
                }
            ],
            max_tokens=300,
            temperature=0.7,
        )
        return response.choices[0].message.content.strip()

    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        return f"Error: {str(e)}"


# Create the Gradio interface
demo = gr.Interface(
    fn=process_query,
    inputs=[
        gr.Textbox(label="Question", placeholder="Type your question here...", lines=2),
        gr.Dropdown(
            choices=list(ragbench.keys()),
            label="Select Dataset",
            value="hotpotqa"
        )
    ],
    outputs=gr.Textbox(label="Answer", lines=5),
    title="RagBench Question Answering System",
    description="Ask questions across different RagBench datasets",
    examples=[
        ["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?", "covidqa"],
        ["In what school district is Governor John R. Rogers High School located?", "hotpotqa"],
        ["Is there a functional neural correlate of individual differences in cardiovascular reactivity?", "pubmedqa"]
    ]
)

if __name__ == "__main__":
    demo.launch(debug=True)
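# Usage note (illustrative): debug=True serves the UI locally with verbose
# errors; Gradio can also expose a temporary public URL via demo.launch(share=True).
# A quick non-UI smoke test, assuming the index above has been built:
#
#   print(process_query("In what school district is Governor John R. Rogers High School located?", "hotpotqa"))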