import os

# Fix caching issue on Hugging Face Spaces. These variables must be set before
# transformers / huggingface_hub are imported, since the cache location is read
# when those libraries are imported.
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["HF_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"

from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from flask_cors import CORS
import torch
import fitz  # PyMuPDF
import pytesseract
from pdf2image import convert_from_path
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import tempfile
from PIL import Image
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

UPLOAD_FOLDER = "/tmp/uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# Global model variables
embedder = None
qa_pipeline = None
tokenizer = None
model = None

# Initialize models once on startup
def initialize_models():
    global embedder, qa_pipeline, tokenizer, model
    try:
        logger.info("Loading SentenceTransformer model...")
        embedder = SentenceTransformer("all-MiniLM-L6-v2")

        logger.info("Loading QA pipeline...")
        qa_pipeline = pipeline(
            "question-answering",
            model="distilbert-base-cased-distilled-squad",
            tokenizer="distilbert-base-cased",
            device=-1  # Force CPU
        )

        logger.info("Loading language model...")
        model_name = "Qwen/Qwen2.5-1.5B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # Use float16 for lower memory
            device_map="cpu",           # Explicitly set to CPU
            low_cpu_mem_usage=True      # Optimize memory loading
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id

        logger.info("Models initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing models: {str(e)}")
        raise

# Generation-based answering
def answer_with_generation(index, embeddings, chunks, question):
    try:
        logger.info(f"Answering with generation model: '{question}'")
        global tokenizer, model
        if tokenizer is None or model is None:
            logger.info("Generation models not initialized, creating now...")
            model_name = "Qwen/Qwen2.5-1.5B-Instruct"
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="cpu",
                low_cpu_mem_usage=True
            )
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                model.config.pad_token_id = model.config.eos_token_id

        # Get embeddings for question
        q_embedding = embedder.encode([question])

        # Find relevant chunks
        _, top_k_indices = index.search(q_embedding, k=3)
        relevant_chunks = [chunks[i] for i in top_k_indices[0]]
        context = " ".join(relevant_chunks)

        # Limit context size
        if len(context) > 2000:
            context = context[:2000]
        # Create prompt using the ChatML turn format expected by Qwen instruct models
        prompt = f"""<|im_start|>system
You are a helpful assistant answering questions based on provided PDF content. Use the information below to give a clear, concise, and accurate answer. Avoid speculation and focus on the context.<|im_end|>
<|im_start|>user
**Context**: {context}
**Question**: {question}
**Instruction**: Provide a detailed and accurate answer based on the context. If the context doesn't contain enough information, say so clearly.<|im_end|>
<|im_start|>assistant
"""
        # Handle inputs
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)

        # Move inputs to CPU
        inputs = {k: v.to('cpu') for k, v in inputs.items()}

        # Generate answer
        output = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=2,
            no_repeat_ngram_size=2
        )
        # Decode only the newly generated tokens, skipping the echoed prompt
        prompt_length = inputs["input_ids"].shape[1]
        answer = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

        logger.info(f"Generation answer: '{answer[:50]}...' (length: {len(answer)})")
        return answer.strip()
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        return "I couldn't generate a good answer based on the PDF content."

# Cleanup function for temporary files
def cleanup_temp_files(filepath):
    try:
        if os.path.exists(filepath):
            os.remove(filepath)
            logger.info(f"Removed temporary file: {filepath}")
    except Exception as e:
        logger.warning(f"Failed to clean up file {filepath}: {str(e)}")

# Improved OCR function
def ocr_pdf(pdf_path):
    try:
        logger.info(f"Starting OCR for {pdf_path}")
        # Use a higher DPI for better quality
        images = convert_from_path(
            pdf_path,
            dpi=300,             # Higher DPI for better quality
            grayscale=False,     # Color might help with some PDFs
            thread_count=2,      # Use multiple threads
            use_pdftocairo=True  # pdftocairo often gives better results
        )
        text = ""
        for i, img in enumerate(images):
            logger.info(f"Processing page {i+1} of {len(images)}")
            # Preprocess the image for better OCR results
            preprocessed = preprocess_image_for_ocr(img)
            # Use tesseract with more options
            page_text = pytesseract.image_to_string(
                preprocessed,
                config='--psm 1 --oem 3 -l eng'  # Page segmentation mode 1 (auto), OCR Engine mode 3 (default)
            )
            text += page_text
        logger.info(f"OCR completed with {len(text)} characters extracted")
        return text
    except Exception as e:
        logger.error(f"OCR error: {str(e)}")
        return ""

# Image preprocessing function for better OCR
def preprocess_image_for_ocr(img):
    # Convert to grayscale
    gray = img.convert('L')
    # Optional: more preprocessing could be added here, such as:
    # - Thresholding
    # - Noise removal
    # - Contrast enhancement
    return gray
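
# A minimal sketch of the optional extra preprocessing mentioned above
# (contrast enhancement plus a simple global threshold), using only PIL.
# This helper is illustrative and is not called anywhere; the contrast
# factor (2.0) and threshold (150) are assumed values, not tuned ones.
def preprocess_image_for_ocr_enhanced(img):
    from PIL import ImageEnhance
    gray = img.convert('L')
    # Boost contrast before binarising
    gray = ImageEnhance.Contrast(gray).enhance(2.0)
    # Simple global threshold to a pure black-and-white image
    return gray.point(lambda p: 255 if p > 150 else 0)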

# Improved extract_text function with better text detection
def extract_text(pdf_path):
    try:
        logger.info(f"Extracting text from {pdf_path}")
        doc = fitz.open(pdf_path)
        text = ""
        for page_num, page in enumerate(doc):
            page_text = page.get_text()
            text += page_text
            logger.info(f"Extracted {len(page_text)} characters from page {page_num+1}")

        # Check if the text is meaningful (more sophisticated check)
        words = text.split()
        unique_words = set(word.lower() for word in words if len(word) > 2)
        logger.info(f"PDF text extraction: {len(text)} chars, {len(words)} words, {len(unique_words)} unique words")

        # If we don't have enough meaningful text, try OCR
        if len(unique_words) < 20 or len(text.strip()) < 100:
            logger.info("Text extraction yielded insufficient results, trying OCR...")
            ocr_text = ocr_pdf(pdf_path)
            # If OCR gave us more text, use it
            if len(ocr_text.strip()) > len(text.strip()):
                logger.info(f"Using OCR result: {len(ocr_text)} chars (better than {len(text)} chars)")
                text = ocr_text
        return text
    except Exception as e:
        logger.error(f"Text extraction error: {str(e)}")
        return ""

# Split into chunks (lengths are measured in characters, despite the parameter name)
def split_into_chunks(text, max_tokens=300, overlap=50):
    logger.info(f"Splitting text into chunks (max_tokens={max_tokens}, overlap={overlap})")
    sentences = text.split('.')
    chunks, current = [], ''
    for sentence in sentences:
        sentence = sentence.strip() + '.'
        if len(current) + len(sentence) < max_tokens:
            current += sentence
        else:
            chunks.append(current.strip())
            # Keep the last `overlap` words as context for the next chunk
            words = current.split()
            if len(words) > overlap:
                current = ' '.join(words[-overlap:]) + ' ' + sentence
            else:
                current = sentence
    if current:
        chunks.append(current.strip())
    logger.info(f"Split text into {len(chunks)} chunks")
    return chunks

# Setup FAISS
def setup_faiss(chunks):
    try:
        logger.info("Setting up FAISS index")
        global embedder
        if embedder is None:
            embedder = SentenceTransformer("all-MiniLM-L6-v2")
        embeddings = embedder.encode(chunks)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatL2(dim)
        index.add(embeddings)
        logger.info(f"FAISS index created with {len(chunks)} chunks and dimension {dim}")
        return index, embeddings, chunks
    except Exception as e:
        logger.error(f"FAISS setup error: {str(e)}")
        raise
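
# Optional alternative (not used above): for cosine similarity rather than raw
# L2 distance, normalize the embeddings and use an inner-product index. A
# minimal sketch, assuming the same SentenceTransformer embedder:
#
#   embeddings = embedder.encode(chunks, normalize_embeddings=True)
#   index = faiss.IndexFlatIP(embeddings.shape[1])
#   index.add(embeddings)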

# QA pipeline
def answer_with_qa_pipeline(chunks, question):
    try:
        logger.info(f"Answering with QA pipeline: '{question}'")
        global qa_pipeline
        if qa_pipeline is None:
            logger.info("QA pipeline not initialized, creating now...")
            qa_pipeline = pipeline(
                "question-answering",
                model="distilbert-base-cased-distilled-squad",
                tokenizer="distilbert-base-cased",
                device=0 if device == "cuda" else -1
            )

        # Limit context size to avoid token length issues
        context = " ".join(chunks[:5])
        if len(context) > 5000:  # rough character cap on the context
            context = context[:5000]

        result = qa_pipeline(question=question, context=context)
        logger.info(f"QA pipeline answer: '{result['answer']}' (score: {result['score']})")
        return result["answer"]
    except Exception as e:
        logger.error(f"QA pipeline error: {str(e)}")
        return ""

# API routes (the '/' and '/ask' route paths are assumptions)
@app.route("/")
def home():
    return jsonify({"message": "PDF QA API is running!"})

@app.route("/ask", methods=["POST"])
def ask():
    file = request.files.get("pdf")
    question = request.form.get("question", "")
    filepath = None

    if not file or not question:
        return jsonify({"error": "Both PDF file and question are required"}), 400

    try:
        filename = secure_filename(file.filename)
        filepath = os.path.join(UPLOAD_FOLDER, filename)
        file.save(filepath)
        logger.info(f"Processing file: {filename}, Question: '{question}'")

        # Process PDF and generate answer
        text = extract_text(filepath)
        if not text.strip():
            return jsonify({"error": "Could not extract text from the PDF"}), 400

        chunks = split_into_chunks(text)
        if not chunks:
            return jsonify({"error": "PDF content couldn't be processed"}), 400

        try:
            answer = answer_with_qa_pipeline(chunks, question)
        except Exception as e:
            logger.warning(f"QA pipeline failed: {str(e)}")
            answer = ""

        # If QA pipeline didn't give a good answer, try generation
        if not answer or len(answer.strip()) < 20:
            try:
                logger.info("QA pipeline answer insufficient, trying generation...")
                index, embeddings, chunks = setup_faiss(chunks)
                answer = answer_with_generation(index, embeddings, chunks, question)
            except Exception as e:
                logger.error(f"Generation fallback failed: {str(e)}")
                return jsonify({"error": "Failed to generate answer from PDF content"}), 500

        return jsonify({"answer": answer})
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        return jsonify({"error": f"An error occurred processing your request: {str(e)}"}), 500
    finally:
        # Always clean up, even if errors occur
        if filepath:
            cleanup_temp_files(filepath)

if __name__ == "__main__":
    try:
        # Initialize models at startup
        initialize_models()
        logger.info("Starting Flask application")
        app.run(host="0.0.0.0", port=7860)
    except Exception as e:
        logger.critical(f"Failed to start application: {str(e)}")
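
# Example request against the assumed /ask route once the server is running
# (a usage sketch, not part of the application):
#
#   curl -X POST http://localhost:7860/ask \
#        -F "pdf=@document.pdf" \
#        -F "question=What is this document about?"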