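"""Flask API for question answering over uploaded PDFs.

Pipeline: extract text with PyMuPDF (falling back to Tesseract OCR via
pdf2image when a PDF yields little selectable text), split the text into
overlapping chunks, retrieve relevant chunks with a FAISS index over
SentenceTransformer embeddings, and answer with an extractive QA pipeline,
falling back to a generative model (Qwen2.5-1.5B-Instruct) when needed.
"""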
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from flask_cors import CORS
import os
import torch
import fitz # PyMuPDF
import pytesseract
from pdf2image import convert_from_path
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import tempfile
from PIL import Image
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Fix caching issue on Hugging Face Spaces
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["HF_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"
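# (Hugging Face Spaces containers typically guarantee only /tmp to be writable,
# so model downloads and caches are redirected there.)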
app = Flask(__name__)
CORS(app) # Enable CORS for all routes
UPLOAD_FOLDER = "/tmp/uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
# Global model variables
embedder = None
qa_pipeline = None
tokenizer = None
model = None
# Initialize models once on startup
def initialize_models():
global embedder, qa_pipeline, tokenizer, model
try:
logger.info("Loading SentenceTransformer model...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
logger.info("Loading QA pipeline...")
qa_pipeline = pipeline(
"question-answering",
model="distilbert-base-cased-distilled-squad",
tokenizer="distilbert-base-cased",
device=-1 # Force CPU
)
logger.info("Loading language model...")
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16, # Use float16 for lower memory on CPU
device_map="cpu", # Explicitly set to CPU
low_cpu_mem_usage=True # Optimize memory loading
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
logger.info("Models initialized successfully")
except Exception as e:
logger.error(f"Error initializing models: {str(e)}")
raise
# Generation-based answering
def answer_with_generation(index, embeddings, chunks, question):
try:
logger.info(f"Answering with generation model: '{question}'")
global tokenizer, model
if tokenizer is None or model is None:
logger.info("Generation models not initialized, creating now...")
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="cpu",
low_cpu_mem_usage=True
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
# Get embeddings for question
q_embedding = embedder.encode([question])
# Find relevant chunks
_, top_k_indices = index.search(q_embedding, k=3)
relevant_chunks = [chunks[i] for i in top_k_indices[0]]
context = " ".join(relevant_chunks)
# Limit context size
if len(context) > 2000:
context = context[:2000]
# Create prompt
prompt = f"""<|im_start|>system
You are a helpful assistant answering questions based on provided PDF content. Use the information below to give a clear, concise, and accurate answer. Avoid speculation and focus on the context.
<|im_end|>
<|im_start|>user
**Context**: {context}
**Question**: {question}
**Instruction**: Provide a detailed and accurate answer based on the context. If the context doesn't contain enough information, say so clearly.<|im_end|>
<|im_start|>assistant
"""
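        # The prompt above hand-writes Qwen's ChatML turn markers. A sketch of the
        # equivalent construction via the tokenizer's built-in chat template
        # (system_text here is shorthand for the system message above) would be:
        #   messages = [{"role": "system", "content": system_text},
        #               {"role": "user", "content": f"Context: {context}\nQuestion: {question}"}]
        #   prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)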
# Handle inputs
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
# Move inputs to CPU
inputs = {k: v.to('cpu') for k, v in inputs.items()}
# Generate answer
output = model.generate(
**inputs,
max_new_tokens=300,
temperature=0.7,
top_p=0.9,
do_sample=True,
num_beams=2,
no_repeat_ngram_size=2
)
        # Decode only the newly generated tokens so the echoed prompt is not
        # returned as part of the answer (skip_special_tokens also strips the
        # <|im_start|>/<|im_end|> markers)
        answer = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
logger.info(f"Generation answer: '{answer[:50]}...' (length: {len(answer)})")
return answer.strip()
except Exception as e:
logger.error(f"Generation error: {str(e)}")
return "I couldn't generate a good answer based on the PDF content."
# Cleanup function for temporary files
def cleanup_temp_files(filepath):
try:
if os.path.exists(filepath):
os.remove(filepath)
logger.info(f"Removed temporary file: {filepath}")
except Exception as e:
logger.warning(f"Failed to clean up file {filepath}: {str(e)}")
# Improved OCR function
def ocr_pdf(pdf_path):
try:
logger.info(f"Starting OCR for {pdf_path}")
# Use a higher DPI for better quality
images = convert_from_path(
pdf_path,
dpi=300, # Higher DPI for better quality
grayscale=False, # Color might help with some PDFs
thread_count=2, # Use multiple threads
use_pdftocairo=True # pdftocairo often gives better results
)
text = ""
for i, img in enumerate(images):
logger.info(f"Processing page {i+1} of {len(images)}")
# Preprocess the image for better OCR results
preprocessed = preprocess_image_for_ocr(img)
# Use tesseract with more options
page_text = pytesseract.image_to_string(
preprocessed,
config='--psm 1 --oem 3 -l eng' # Page segmentation mode 1 (auto), OCR Engine mode 3 (default)
)
text += page_text
logger.info(f"OCR completed with {len(text)} characters extracted")
return text
except Exception as e:
logger.error(f"OCR error: {str(e)}")
return ""
# Image preprocessing function for better OCR
def preprocess_image_for_ocr(img):
# Convert to grayscale
gray = img.convert('L')
# Optional: You could add more preprocessing here like:
# - Thresholding
# - Noise removal
# - Contrast enhancement
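    # A minimal sketch of such preprocessing (hypothetical threshold value, plain PIL):
    #   from PIL import ImageOps
    #   gray = ImageOps.autocontrast(gray)                      # stretch contrast
    #   gray = gray.point(lambda px: 255 if px > 140 else 0)    # binarize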
return gray
# Improved extract_text function with better text detection
def extract_text(pdf_path):
try:
logger.info(f"Extracting text from {pdf_path}")
doc = fitz.open(pdf_path)
text = ""
for page_num, page in enumerate(doc):
page_text = page.get_text()
text += page_text
            logger.info(f"Extracted {len(page_text)} characters from page {page_num+1}")
        doc.close()
        # Heuristic check for whether the extracted text is meaningful
words = text.split()
unique_words = set(word.lower() for word in words if len(word) > 2)
logger.info(f"PDF text extraction: {len(text)} chars, {len(words)} words, {len(unique_words)} unique words")
# If we don't have enough meaningful text, try OCR
if len(unique_words) < 20 or len(text.strip()) < 100:
logger.info("Text extraction yielded insufficient results, trying OCR...")
ocr_text = ocr_pdf(pdf_path)
# If OCR gave us more text, use it
if len(ocr_text.strip()) > len(text.strip()):
logger.info(f"Using OCR result: {len(ocr_text)} chars (better than {len(text)} chars)")
text = ocr_text
return text
except Exception as e:
logger.error(f"Text extraction error: {str(e)}")
return ""
# Split text into overlapping chunks (sizes are measured in characters as a rough proxy for tokens)
def split_into_chunks(text, max_tokens=300, overlap=50):
logger.info(f"Splitting text into chunks (max_tokens={max_tokens}, overlap={overlap})")
sentences = text.split('.')
chunks, current = [], ''
for sentence in sentences:
sentence = sentence.strip() + '.'
if len(current) + len(sentence) < max_tokens:
current += sentence
else:
chunks.append(current.strip())
words = current.split()
if len(words) > overlap:
current = ' '.join(words[-overlap:]) + ' ' + sentence
else:
current = sentence
if current:
chunks.append(current.strip())
logger.info(f"Split text into {len(chunks)} chunks")
return chunks
# Setup FAISS
def setup_faiss(chunks):
try:
logger.info("Setting up FAISS index")
global embedder
if embedder is None:
embedder = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embedder.encode(chunks)
dim = embeddings.shape[1]
index = faiss.IndexFlatL2(dim)
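        # IndexFlatL2 performs exact (brute-force) L2 search; with all-MiniLM-L6-v2's
        # 384-dimensional embeddings this stays fast for typical document sizes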
index.add(embeddings)
logger.info(f"FAISS index created with {len(chunks)} chunks and dimension {dim}")
return index, embeddings, chunks
except Exception as e:
logger.error(f"FAISS setup error: {str(e)}")
raise
# QA pipeline
def answer_with_qa_pipeline(chunks, question):
try:
logger.info(f"Answering with QA pipeline: '{question}'")
global qa_pipeline
if qa_pipeline is None:
logger.info("QA pipeline not initialized, creating now...")
qa_pipeline = pipeline(
"question-answering",
model="distilbert-base-cased-distilled-squad",
tokenizer="distilbert-base-cased",
device=0 if device == "cuda" else -1
)
# Limit context size to avoid token length issues
context = " ".join(chunks[:5])
        if len(context) > 5000:  # rough character cap; DistilBERT itself handles at most 512 tokens per pass
context = context[:5000]
result = qa_pipeline(question=question, context=context)
logger.info(f"QA pipeline answer: '{result['answer']}' (score: {result['score']})")
return result["answer"]
except Exception as e:
logger.error(f"QA pipeline error: {str(e)}")
return ""
# API route
@app.route('/')
def home():
return jsonify({"message": "PDF QA API is running!"})
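# Example request against the /ask endpoint (file name and question are placeholders):
#   curl -X POST http://localhost:7860/ask \
#        -F "pdf=@sample.pdf" \
#        -F "question=What is the main topic of this document?"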
@app.route('/ask', methods=['POST'])
def ask():
file = request.files.get("pdf")
question = request.form.get("question", "")
filepath = None
if not file or not question:
return jsonify({"error": "Both PDF file and question are required"}), 400
try:
filename = secure_filename(file.filename)
filepath = os.path.join(UPLOAD_FOLDER, filename)
file.save(filepath)
logger.info(f"Processing file: {filename}, Question: '{question}'")
# Process PDF and generate answer
text = extract_text(filepath)
if not text.strip():
return jsonify({"error": "Could not extract text from the PDF"}), 400
chunks = split_into_chunks(text)
if not chunks:
return jsonify({"error": "PDF content couldn't be processed"}), 400
try:
answer = answer_with_qa_pipeline(chunks, question)
except Exception as e:
logger.warning(f"QA pipeline failed: {str(e)}")
answer = ""
# If QA pipeline didn't give a good answer, try generation
if not answer or len(answer.strip()) < 20:
try:
logger.info("QA pipeline answer insufficient, trying generation...")
index, embeddings, chunks = setup_faiss(chunks)
answer = answer_with_generation(index, embeddings, chunks, question)
except Exception as e:
logger.error(f"Generation fallback failed: {str(e)}")
return jsonify({"error": "Failed to generate answer from PDF content"}), 500
return jsonify({"answer": answer})
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
return jsonify({"error": f"An error occurred processing your request: {str(e)}"}), 500
finally:
# Always clean up, even if errors occur
if filepath:
cleanup_temp_files(filepath)
if __name__ == "__main__":
try:
# Initialize models at startup
initialize_models()
logger.info("Starting Flask application")
app.run(host="0.0.0.0", port=7860)
except Exception as e:
logger.critical(f"Failed to start application: {str(e)}")