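"""Gradio app for PDF question answering.

Uses Gemini (via LangChain) to answer questions about an uploaded PDF and the
Mistral-NeMo-Minitron-8B model to generate a follow-up question.

Requires the GOOGLE_API_KEY environment variable to be set.
"""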
import os
import gradio as gr
from langchain_core.prompts import PromptTemplate
from langchain_community.document_loaders import PyPDFLoader
from langchain_google_genai import ChatGoogleGenerativeAI
import google.generativeai as genai
from langchain.chains.question_answering import load_qa_chain
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Configure Gemini API
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
# Load Mistral model
model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
mistral_tokenizer = AutoTokenizer.from_pretrained(model_path)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16
# Improved model loading with error handling
try:
    mistral_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map=device
    )
    print(f"Mistral model loaded successfully on {device}")
except Exception as e:
    print(f"Error loading Mistral model: {str(e)}")
    mistral_model = None
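# Note: Mistral-NeMo-Minitron-8B is an 8B-parameter model; loading it in bfloat16
# needs roughly 16 GB of memory, so on CPU-only or small-GPU hardware the load above
# may fail and the app falls back to Gemini-only answers (mistral_model stays None).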
def initialize(file_path, question):
    """Answer a question about the PDF at file_path with Gemini, then add a Mistral follow-up question."""
    try:
        # Check if API key is set
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            return "Error: GOOGLE_API_KEY environment variable is not set."

        model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3)

        prompt_template = """Answer the question as precisely as possible using the provided context. If the answer is
        not contained in the context, say "answer not available in context" \n\n
        Context: \n {context} \n
        Question: \n {question} \n
        Answer:
        """
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        if os.path.exists(file_path):
            # Load and process PDF
            pdf_loader = PyPDFLoader(file_path)
            pages = pdf_loader.load_and_split()
            if not pages:
                return "Error: The PDF file appears to be empty or could not be processed."

            # Generate Gemini answer. The "stuff" chain fills the {context} prompt variable
            # from input_documents, so limit the pages passed in (first 30) instead of
            # building a separate context string that the chain would ignore.
            stuff_chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
            stuff_answer = stuff_chain(
                {"input_documents": pages[:30], "question": question},
                return_only_outputs=True
            )
            gemini_answer = stuff_answer['output_text']
            # Use Mistral model for additional text generation
            if mistral_model is not None:
                mistral_prompt = f"Based on this answer: {gemini_answer}\nGenerate a follow-up question:"
                mistral_inputs = mistral_tokenizer.encode(mistral_prompt, return_tensors='pt').to(device)
                with torch.no_grad():
                    mistral_outputs = mistral_model.generate(
                        mistral_inputs,
                        max_new_tokens=100,   # bound new tokens; a fixed max_length could be shorter than a long prompt
                        min_new_tokens=20,    # ensure a non-trivial follow-up
                        do_sample=True,       # enable sampling
                        top_p=0.95,           # nucleus (top-p) sampling
                        temperature=0.7,      # temperature for creativity
                        pad_token_id=mistral_tokenizer.eos_token_id
                    )
                mistral_output = mistral_tokenizer.decode(mistral_outputs[0], skip_special_tokens=True)

                # Clean up the output to get just the follow-up question
                if "Generate a follow-up question:" in mistral_output:
                    mistral_output = mistral_output.split("Generate a follow-up question:")[1].strip()

                combined_output = f"Gemini Answer: {gemini_answer}\n\nMistral Follow-up: {mistral_output}"
            else:
                combined_output = f"Gemini Answer: {gemini_answer}\n\n(Mistral model unavailable)"

            return combined_output
        else:
            return f"Error: File not found at path '{file_path}'. Please ensure the PDF file is valid."

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return f"An error occurred: {str(e)}\n\nDetails: {error_details}"
# Define Gradio Interface with improved error handling
def pdf_qa(file, question):
    if file is None:
        return "Please upload a PDF file first."
    if not question or question.strip() == "":
        return "Please enter a question about the document."
    try:
        return initialize(file.name, question)
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return f"Error processing request: {str(e)}\n\nDetails: {error_details}"
# Create Gradio Interface with additional options
demo = gr.Interface(
    fn=pdf_qa,
    inputs=[
        gr.File(label="Upload PDF File", file_types=[".pdf"]),
        gr.Textbox(label="Ask about the document", placeholder="What is the main topic of this document?")
    ],
    outputs=gr.Textbox(label="Answer - Combined Gemini and Mistral"),
    title="RAG Knowledge Retrieval using Gemini API and Mistral Model",
    description="Upload a PDF file and ask questions about the content. The system uses Gemini for answering and Mistral for generating follow-up questions.",
    examples=[
        [None, "What are the main findings in this document?"],
        [None, "Summarize the key points discussed in this paper."]
    ],
    allow_flagging="never"
)
# Launch the app with additional parameters
if __name__ == "__main__":
    demo.launch(share=True, debug=True)