from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from flask_cors import CORS
import os
import torch
import fitz  # PyMuPDF
import pytesseract
from pdf2image import convert_from_path
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import tempfile
from PIL import Image
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fix caching issue on Hugging Face Spaces
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["HF_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

UPLOAD_FOLDER = "/tmp/uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# Global model variables
embedder = None
qa_pipeline = None
tokenizer = None
model = None


# Initialize models once on startup
def initialize_models():
    global embedder, qa_pipeline, tokenizer, model
    try:
        logger.info("Loading SentenceTransformer model...")
        embedder = SentenceTransformer("all-MiniLM-L6-v2")

        logger.info("Loading QA pipeline...")
        qa_pipeline = pipeline(
            "question-answering",
            model="distilbert-base-cased-distilled-squad",
            tokenizer="distilbert-base-cased",
            device=-1  # Force CPU
        )

        logger.info("Loading language model...")
        model_name = "Qwen/Qwen2.5-1.5B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # Use float16 for lower memory on CPU
            device_map="cpu",           # Explicitly set to CPU
            low_cpu_mem_usage=True      # Optimize memory while loading
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id

        logger.info("Models initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing models: {str(e)}")
        raise


# Generation-based answering
def answer_with_generation(index, embeddings, chunks, question):
    try:
        logger.info(f"Answering with generation model: '{question}'")

        global tokenizer, model
        if tokenizer is None or model is None:
            logger.info("Generation models not initialized, creating now...")
            model_name = "Qwen/Qwen2.5-1.5B-Instruct"
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="cpu",
                low_cpu_mem_usage=True
            )
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                model.config.pad_token_id = model.config.eos_token_id

        # Get embedding for the question
        q_embedding = embedder.encode([question])

        # Find the most relevant chunks (never ask FAISS for more neighbours than exist)
        top_k = min(3, len(chunks))
        _, top_k_indices = index.search(q_embedding, top_k)
        relevant_chunks = [chunks[i] for i in top_k_indices[0]]
        context = " ".join(relevant_chunks)

        # Limit context size (in characters) to keep the prompt within the token budget
        if len(context) > 2000:
            context = context[:2000]

        # Create prompt in the Qwen chat format
        prompt = f"""<|im_start|>system
You are a helpful assistant answering questions based on provided PDF content. Use the information below to give a clear, concise, and accurate answer. Avoid speculation and focus on the context.
<|im_end|>
<|im_start|>user
**Context**: {context}

**Question**: {question}

**Instruction**: Provide a detailed and accurate answer based on the context. If the context doesn't contain enough information, say so clearly.
<|im_end|>"""

        # Tokenize the prompt
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)

        # Keep inputs on CPU
        inputs = {k: v.to('cpu') for k, v in inputs.items()}

        # Generate answer
        output = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=2,
            no_repeat_ngram_size=2
        )

        # Decode only the newly generated tokens so the prompt is not echoed back
        prompt_length = inputs["input_ids"].shape[1]
        answer = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

        logger.info(f"Generation answer: '{answer[:50]}...' (length: {len(answer)})")
        return answer.strip()
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        return "I couldn't generate a good answer based on the PDF content."


# Cleanup function for temporary files
def cleanup_temp_files(filepath):
    try:
        if os.path.exists(filepath):
            os.remove(filepath)
            logger.info(f"Removed temporary file: {filepath}")
    except Exception as e:
        logger.warning(f"Failed to clean up file {filepath}: {str(e)}")


# OCR fallback for scanned PDFs
def ocr_pdf(pdf_path):
    try:
        logger.info(f"Starting OCR for {pdf_path}")

        # Render pages at a higher DPI for better OCR quality
        images = convert_from_path(
            pdf_path,
            dpi=300,             # Higher DPI for better quality
            grayscale=False,     # Color can help with some PDFs
            thread_count=2,      # Use multiple threads
            use_pdftocairo=True  # pdftocairo often gives better results
        )

        text = ""
        for i, img in enumerate(images):
            logger.info(f"Processing page {i+1} of {len(images)}")

            # Preprocess the image for better OCR results
            preprocessed = preprocess_image_for_ocr(img)

            # Run Tesseract with explicit options:
            # --psm 1 = automatic page segmentation, --oem 3 = default OCR engine
            page_text = pytesseract.image_to_string(
                preprocessed,
                config='--psm 1 --oem 3 -l eng'
            )
            text += page_text

        logger.info(f"OCR completed with {len(text)} characters extracted")
        return text
    except Exception as e:
        logger.error(f"OCR error: {str(e)}")
        return ""


# Image preprocessing for better OCR
def preprocess_image_for_ocr(img):
    # Convert to grayscale
    gray = img.convert('L')

    # Optional further preprocessing could go here:
    # thresholding, noise removal, contrast enhancement
    return gray


# Extract text from the PDF, falling back to OCR when needed
def extract_text(pdf_path):
    try:
        logger.info(f"Extracting text from {pdf_path}")
        doc = fitz.open(pdf_path)
        text = ""
        for page_num, page in enumerate(doc):
            page_text = page.get_text()
            text += page_text
            logger.info(f"Extracted {len(page_text)} characters from page {page_num+1}")

        # Check whether the extracted text is meaningful
        words = text.split()
        unique_words = set(word.lower() for word in words if len(word) > 2)
        logger.info(
            f"PDF text extraction: {len(text)} chars, {len(words)} words, "
            f"{len(unique_words)} unique words"
        )

        # If there isn't enough meaningful text, try OCR
        if len(unique_words) < 20 or len(text.strip()) < 100:
            logger.info("Text extraction yielded insufficient results, trying OCR...")
            ocr_text = ocr_pdf(pdf_path)

            # If OCR produced more text, use it instead
            if len(ocr_text.strip()) > len(text.strip()):
                logger.info(f"Using OCR result: {len(ocr_text)} chars (better than {len(text)} chars)")
                text = ocr_text

        return text
    except Exception as e:
        logger.error(f"Text extraction error: {str(e)}")
        return ""


# Split text into overlapping chunks
# (max_tokens is a character limit, overlap is a word count)
def split_into_chunks(text, max_tokens=300, overlap=50):
    logger.info(f"Splitting text into chunks (max_tokens={max_tokens}, overlap={overlap})")
    sentences = text.split('.')
    chunks, current = [], ''
    for sentence in sentences:
        sentence = sentence.strip() + '.'
        if len(current) + len(sentence) < max_tokens:
            current += sentence
        else:
            chunks.append(current.strip())
            # Carry the last `overlap` words into the next chunk for continuity
            words = current.split()
            if len(words) > overlap:
                current = ' '.join(words[-overlap:]) + ' ' + sentence
            else:
                current = sentence
    if current:
        chunks.append(current.strip())
    logger.info(f"Split text into {len(chunks)} chunks")
    return chunks
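
# Illustrative usage of the chunker (commented out; not part of the request
# flow, and the sample text is made up). Note that `max_tokens` and `overlap`
# are measured in characters and words respectively, not model tokens:
#
#   sample = ("PDF question answering combines extraction, chunking, "
#             "retrieval, and generation. ") * 10
#   for chunk in split_into_chunks(sample, max_tokens=300, overlap=50):
#       print(len(chunk), chunk[:40])   # chunks of roughly max_tokens characters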

# Build a FAISS index over the chunk embeddings
def setup_faiss(chunks):
    try:
        logger.info("Setting up FAISS index")
        global embedder
        if embedder is None:
            embedder = SentenceTransformer("all-MiniLM-L6-v2")

        embeddings = embedder.encode(chunks)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatL2(dim)
        index.add(np.asarray(embeddings, dtype="float32"))  # FAISS expects float32

        logger.info(f"FAISS index created with {len(chunks)} chunks and dimension {dim}")
        return index, embeddings, chunks
    except Exception as e:
        logger.error(f"FAISS setup error: {str(e)}")
        raise


# Extractive QA over the first few chunks
def answer_with_qa_pipeline(chunks, question):
    try:
        logger.info(f"Answering with QA pipeline: '{question}'")

        global qa_pipeline
        if qa_pipeline is None:
            logger.info("QA pipeline not initialized, creating now...")
            qa_pipeline = pipeline(
                "question-answering",
                model="distilbert-base-cased-distilled-squad",
                tokenizer="distilbert-base-cased",
                device=0 if device == "cuda" else -1
            )

        # Limit context size (rough character cap) to avoid token length issues
        context = " ".join(chunks[:5])
        if len(context) > 5000:
            context = context[:5000]

        result = qa_pipeline(question=question, context=context)
        logger.info(f"QA pipeline answer: '{result['answer']}' (score: {result['score']})")
        return result["answer"]
    except Exception as e:
        logger.error(f"QA pipeline error: {str(e)}")
        return ""


# API routes
@app.route('/')
def home():
    return jsonify({"message": "PDF QA API is running!"})


@app.route('/ask', methods=['POST'])
def ask():
    file = request.files.get("pdf")
    question = request.form.get("question", "")
    filepath = None

    if not file or not question:
        return jsonify({"error": "Both PDF file and question are required"}), 400

    try:
        filename = secure_filename(file.filename)
        filepath = os.path.join(UPLOAD_FOLDER, filename)
        file.save(filepath)
        logger.info(f"Processing file: {filename}, Question: '{question}'")

        # Process the PDF and generate an answer
        text = extract_text(filepath)
        if not text.strip():
            return jsonify({"error": "Could not extract text from the PDF"}), 400

        chunks = split_into_chunks(text)
        if not chunks:
            return jsonify({"error": "PDF content couldn't be processed"}), 400

        # First try the extractive QA pipeline
        try:
            answer = answer_with_qa_pipeline(chunks, question)
        except Exception as e:
            logger.warning(f"QA pipeline failed: {str(e)}")
            answer = ""

        # If the QA pipeline didn't give a good answer, fall back to generation
        if not answer or len(answer.strip()) < 20:
            try:
                logger.info("QA pipeline answer insufficient, trying generation...")
                index, embeddings, chunks = setup_faiss(chunks)
                answer = answer_with_generation(index, embeddings, chunks, question)
            except Exception as e:
                logger.error(f"Generation fallback failed: {str(e)}")
                return jsonify({"error": "Failed to generate answer from PDF content"}), 500

        return jsonify({"answer": answer})

    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        return jsonify({"error": f"An error occurred processing your request: {str(e)}"}), 500
    finally:
        # Always clean up the uploaded file, even if errors occur
        if filepath:
            cleanup_temp_files(filepath)


if __name__ == "__main__":
    try:
        # Initialize models at startup
        initialize_models()
        logger.info("Starting Flask application")
        app.run(host="0.0.0.0", port=7860)
    except Exception as e:
        logger.critical(f"Failed to start application: {str(e)}")
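
# Example client calls (illustrative only; they assume the server is running
# locally on port 7860 and that "document.pdf" exists in the working directory):
#
#   curl -X POST http://localhost:7860/ask \
#        -F "pdf=@document.pdf" \
#        -F "question=What is this document about?"
#
# or with the `requests` library:
#
#   import requests
#   with open("document.pdf", "rb") as f:
#       response = requests.post(
#           "http://localhost:7860/ask",
#           files={"pdf": f},
#           data={"question": "What is this document about?"},
#       )
#   print(response.json())  # {"answer": "..."} on success, {"error": "..."} otherwise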