from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from flask_cors import CORS
import os
import torch
import fitz  # PyMuPDF
import pytesseract
from pdf2image import convert_from_path
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import tempfile
from PIL import Image

import logging

# Module-level configuration. Everything below runs once, at import time, in
# the order written.

# Set up logging: configure the root logger at INFO and create the module
# logger used by every function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Point the Hugging Face / XDG caches at /tmp (intended to fix the caching
# issue on Hugging Face Spaces).
# NOTE(review): transformers and sentence_transformers are imported at the top
# of the file, before these assignments run. If those libraries resolve their
# cache directories at import time, these values may not take effect; consider
# moving the assignments above the third-party imports. TODO confirm.
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["HF_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Directory where uploaded PDFs are written while a request is processed; the
# /ask handler removes each upload in its `finally` clause.
UPLOAD_FOLDER = "/tmp/uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Detected accelerator. NOTE(review): initialize_models() places its models on
# CPU regardless of this value; only some lazy-loading paths consult it.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# Global model variables. They start as None and are filled either by
# initialize_models() or lazily by the helper that first needs them.
embedder = None
qa_pipeline = None
tokenizer = None
model = None

# Eagerly load all models (called from the __main__ entry point only)
def initialize_models():
    """Eagerly load every model the request handlers use into module globals.

    Populates ``embedder`` (sentence embeddings for FAISS retrieval),
    ``qa_pipeline`` (extractive question answering) and ``tokenizer``/``model``
    (the causal LM used for generated answers). All models are placed on CPU.

    Raises:
        Exception: any loading error is logged and re-raised so that startup
            fails loudly instead of serving requests with missing models.
    """
    global embedder, qa_pipeline, tokenizer, model
    try:
        logger.info("Loading SentenceTransformer model...")
        embedder = SentenceTransformer("all-MiniLM-L6-v2")

        logger.info("Loading QA pipeline...")
        qa_pipeline = pipeline(
            "question-answering",
            model="distilbert-base-cased-distilled-squad",
            tokenizer="distilbert-base-cased",
            device=-1  # Force CPU
        )

        logger.info("Loading language model...")
        model_name = "Qwen/Qwen2.5-1.5B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # bfloat16 halves memory just like float16, but float16 kernels are
            # poorly supported on CPU: slow, and on some PyTorch builds
            # generation fails with "not implemented for 'Half'".
            torch_dtype=torch.bfloat16,
            device_map="cpu",  # Explicitly set to CPU
            low_cpu_mem_usage=True  # Optimize memory loading
        )

        # Some causal-LM tokenizers ship without a pad token; reuse EOS so
        # padding/attention masks work during generation.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id

        logger.info("Models initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing models: {str(e)}")
        raise

# Generation-based answering (Qwen chat-template variant; a later definition of the same name replaces it)
def answer_with_generation(index, embeddings, chunks, question):
    """Answer *question* from retrieved PDF chunks using the Qwen chat model.

    NOTE: ``answer_with_generation`` is defined again later in this module.
    That later definition rebinds the name at import time, so this version is
    not the one the ``/ask`` route ends up calling.

    Args:
        index: FAISS index built over the embeddings of *chunks* (see setup_faiss).
        embeddings: chunk embeddings; not used here, kept for signature compatibility.
        chunks: chunk texts, in the order they were added to *index*.
        question: the user's question.

    Returns:
        The generated answer text, or a fixed apology string if any step fails
        (the error is logged).
    """
    global tokenizer, model
    try:
        logger.info(f"Answering with generation model: '{question}'")

        if tokenizer is None or model is None:
            logger.info("Generation models not initialized, creating now...")
            model_name = "Qwen/Qwen2.5-1.5B-Instruct"
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                # bfloat16 keeps float16's memory saving but is much better
                # supported by CPU kernels.
                torch_dtype=torch.bfloat16,
                device_map="cpu",
                low_cpu_mem_usage=True
            )

            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                model.config.pad_token_id = model.config.eos_token_id

        # Retrieve up to three chunks closest to the question.
        relevant_chunks = []
        if chunks:
            q_embedding = embedder.encode([question])
            _, top_k_indices = index.search(q_embedding, k=min(3, len(chunks)))
            # FAISS reports missing neighbours as -1, which would select chunks[-1].
            relevant_chunks = [chunks[i] for i in top_k_indices[0] if i >= 0]

        # Cap the context. This bounds the prompt, so the tokenized prompt is not
        # truncated below (right-truncation would drop the question and the
        # assistant header).
        context = " ".join(relevant_chunks)[:2000]

        # ChatML prompt ending with an open assistant turn, so the model answers
        # as the assistant instead of continuing the user message.
        prompt = f"""<|im_start|>system
You are a helpful assistant answering questions based on provided PDF content. Use the information below to give a clear, concise, and accurate answer. Avoid speculation and focus on the context.
<|im_end|>
<|im_start|>user
**Context**: {context}

**Question**: {question}

**Instruction**: Provide a detailed and accurate answer based on the context. If the context doesn't contain enough information, say so clearly. <|im_end|>
<|im_start|>assistant
"""

        inputs = tokenizer(prompt, return_tensors="pt")
        # Place the inputs wherever the model's weights live.
        inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
        prompt_len = inputs["input_ids"].shape[-1]

        output = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=2,
            no_repeat_ngram_size=2
        )

        # Decode only the generated continuation. Splitting the full decode on
        # "<|im_end|>" never matched, because the chat markers are special tokens
        # that skip_special_tokens=True removes. The "Instruction" fallback then
        # returned the prompt's own instruction text along with the answer.
        answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

        logger.info(f"Generation answer: '{answer[:50]}...' (length: {len(answer)})")
        return answer
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        return "I couldn't generate a good answer based on the PDF content."

        
        

# Cleanup function for temporary files
def cleanup_temp_files(filepath):
    """Delete the file at *filepath*, logging the outcome; never raises.

    A file that is already gone is not an error and is not logged. Any other
    failure (e.g. *filepath* is a directory or not removable) is logged as a
    warning so cleanup problems never mask the caller's own result.
    """
    # Remove directly instead of checking os.path.exists() first: the
    # check-then-remove pair races with anything else deleting the file.
    try:
        os.remove(filepath)
    except FileNotFoundError:
        return
    except Exception as e:
        logger.warning(f"Failed to clean up file {filepath}: {str(e)}")
    else:
        logger.info(f"Removed temporary file: {filepath}")

# OCR fallback: render each PDF page to an image and run Tesseract on it
def ocr_pdf(pdf_path):
    """Rasterize every page of *pdf_path* and OCR it with Tesseract.

    Returns:
        The text of all pages concatenated in page order, or "" if
        rasterization or OCR fails (the failure is logged).
    """
    try:
        logger.info(f"Starting OCR for {pdf_path}")
        pages = convert_from_path(
            pdf_path,
            dpi=300,              # higher resolution improves recognition
            grayscale=False,      # keep colour here; preprocessing converts later
            thread_count=2,       # render pages in parallel
            use_pdftocairo=True,  # pdftocairo often renders better than pdftoppm
        )

        total = len(pages)
        page_texts = []
        for page_no, page_image in enumerate(pages, start=1):
            logger.info(f"Processing page {page_no} of {total}")
            page_texts.append(
                pytesseract.image_to_string(
                    preprocess_image_for_ocr(page_image),
                    # --psm 1: automatic page segmentation; --oem 3: default engine
                    config='--psm 1 --oem 3 -l eng',
                )
            )

        text = "".join(page_texts)
        logger.info(f"OCR completed with {len(text)} characters extracted")
        return text
    except Exception as e:
        logger.error(f"OCR error: {str(e)}")
        return ""

# Image preprocessing function for better OCR
def preprocess_image_for_ocr(img):
    """Return a single-channel (mode ``"L"``) copy of *img* for Tesseract.

    Further clean-up steps such as thresholding, noise removal or contrast
    enhancement could be added here.
    """
    return img.convert('L')

# Extract the PDF's embedded text, falling back to OCR when it looks too sparse
def extract_text(pdf_path):
    """Return the text of the PDF at *pdf_path*, using OCR when needed.

    The embedded text layer is read with PyMuPDF first. If it looks too sparse
    (fewer than 20 distinct words longer than two characters, or fewer than
    100 characters after stripping whitespace), OCR is run. Its output is used
    when it is longer than the embedded text.

    Returns:
        The extracted text, or "" on any error (the error is logged).
    """
    try:
        logger.info(f"Extracting text from {pdf_path}")
        page_texts = []
        # The context manager closes the document (and its file handle) even if
        # a page fails. Previously it stayed open until garbage collection,
        # after the caller had already tried to delete the upload.
        with fitz.open(pdf_path) as doc:
            for page_num, page in enumerate(doc):
                page_text = page.get_text()
                page_texts.append(page_text)
                logger.info(f"Extracted {len(page_text)} characters from page {page_num+1}")
        text = "".join(page_texts)

        # Heuristic for "meaningful" text: enough distinct non-trivial words.
        words = text.split()
        unique_words = {word.lower() for word in words if len(word) > 2}

        logger.info(f"PDF text extraction: {len(text)} chars, {len(words)} words, {len(unique_words)} unique words")

        # Too little usable text likely means an image-only (scanned) PDF.
        if len(unique_words) < 20 or len(text.strip()) < 100:
            logger.info("Text extraction yielded insufficient results, trying OCR...")
            ocr_text = ocr_pdf(pdf_path)
            # Keep whichever result has more content.
            if len(ocr_text.strip()) > len(text.strip()):
                logger.info(f"Using OCR result: {len(ocr_text)} chars (better than {len(text)} chars)")
                text = ocr_text

        return text
    except Exception as e:
        logger.error(f"Text extraction error: {str(e)}")
        return ""

# Split into chunks
def split_into_chunks(text, max_tokens=300, overlap=50):
    """Split *text* into sentence-aligned chunks for retrieval and QA.

    Sentences are delimited naively on ``"."``. Consecutive sentences, joined
    by single spaces, are packed into a chunk while it stays shorter than
    *max_tokens*.

    Args:
        text: the document text.
        max_tokens: chunk size budget. Despite the name it is measured in
            characters, not model tokens. A single sentence longer than the
            budget still becomes one (oversized) chunk.
        overlap: when a finished chunk has more than this many words, its last
            *overlap* words are repeated at the start of the next chunk.
            ``0`` or a negative value disables overlap.

    Returns:
        A list of non-empty chunk strings; empty when *text* has no content.
    """
    logger.info(f"Splitting text into chunks (max_tokens={max_tokens}, overlap={overlap})")
    # Drop empty fragments (e.g. after a trailing period) so they do not turn
    # into stray "." sentences.
    fragments = (fragment.strip() for fragment in text.split('.'))
    sentences = [fragment + '.' for fragment in fragments if fragment]

    chunks, current = [], ''
    for sentence in sentences:
        # Join with a space so adjacent sentences don't fuse into one "word".
        candidate = f"{current} {sentence}" if current else sentence
        if len(candidate) < max_tokens:
            current = candidate
            continue
        if not current:
            # Oversized first sentence: start with it instead of emitting an
            # empty chunk.
            current = sentence
            continue
        chunks.append(current)
        words = current.split()
        # overlap <= 0 must mean "no overlap": words[-0:] is the whole chunk.
        if 0 < overlap < len(words):
            current = ' '.join(words[-overlap:]) + ' ' + sentence
        else:
            current = sentence
    if current:
        chunks.append(current)
    logger.info(f"Split text into {len(chunks)} chunks")
    return chunks

# Setup FAISS
def setup_faiss(chunks):
    """Embed *chunks* and build an exact L2 FAISS index over them.

    Creates the module-level ``embedder`` on first use if it is not loaded yet.

    Args:
        chunks: non-empty list of chunk texts.

    Returns:
        ``(index, embeddings, chunks)``: the ``IndexFlatL2`` index, the float32
        embedding matrix (row *i* embeds ``chunks[i]``) and *chunks* unchanged.

    Raises:
        ValueError: if *chunks* is empty.
        Exception: any embedding or indexing error is logged and re-raised.
    """
    global embedder
    try:
        logger.info("Setting up FAISS index")
        if not chunks:
            # Fail with a clear message instead of later on embeddings.shape[1].
            raise ValueError("Cannot build a FAISS index from an empty list of chunks")
        if embedder is None:
            embedder = SentenceTransformer("all-MiniLM-L6-v2")

        # FAISS requires a C-contiguous float32 matrix; this is a no-op when
        # encode() already returns one.
        embeddings = np.ascontiguousarray(embedder.encode(chunks), dtype=np.float32)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatL2(dim)
        index.add(embeddings)
        logger.info(f"FAISS index created with {len(chunks)} chunks and dimension {dim}")
        return index, embeddings, chunks
    except Exception as e:
        logger.error(f"FAISS setup error: {str(e)}")
        raise

# QA pipeline
def answer_with_qa_pipeline(chunks, question):
    """Extract an answer span for *question* from the leading chunks of the PDF.

    Uses the module-level extractive QA pipeline, creating it on first use
    (GPU 0 when ``device`` is "cuda", otherwise CPU). Only the first five
    chunks are considered, capped at 5000 characters.

    Returns:
        The answer span, or "" if the pipeline fails (the error is logged).
    """
    global qa_pipeline
    try:
        logger.info(f"Answering with QA pipeline: '{question}'")
        if qa_pipeline is None:
            logger.info("QA pipeline not initialized, creating now...")
            qa_pipeline = pipeline(
                "question-answering",
                model="distilbert-base-cased-distilled-squad",
                tokenizer="distilbert-base-cased",
                device=0 if device == "cuda" else -1,
            )

        # Rough character cap as a guard against overly long inputs.
        context = " ".join(chunks[:5])[:5000]

        result = qa_pipeline(question=question, context=context)
        answer, score = result["answer"], result["score"]
        logger.info(f"QA pipeline answer: '{answer}' (score: {score})")
        return answer
    except Exception as e:
        logger.error(f"QA pipeline error: {str(e)}")
        return ""

# Generation-based answering (this later definition replaces the one above and is the one in effect)
def answer_with_generation(index, embeddings, chunks, question):
    """Answer *question* by retrieving relevant chunks and generating free text.

    This is the second definition of ``answer_with_generation`` in this module.
    Being later, it is the one bound at import time and called by ``/ask``.

    It uses whatever causal LM is already in the module-level
    ``tokenizer``/``model`` globals (``initialize_models`` loads Qwen there).
    If none is loaded, distilgpt2 is loaded lazily.

    Args:
        index: FAISS index built over the embeddings of *chunks* (see setup_faiss).
        embeddings: chunk embeddings; not used here, kept for signature compatibility.
        chunks: chunk texts, in the order they were added to *index*.
        question: the user's question.

    Returns:
        The generated answer, or a fixed apology string if any step fails
        (the error is logged).
    """
    global tokenizer, model
    max_prompt_tokens = 512
    # Re-tokenizing decoded text can shift token counts slightly at the joins.
    boundary_slack = 4
    try:
        logger.info(f"Answering with generation model: '{question}'")

        if tokenizer is None or model is None:
            logger.info("Generation models not initialized, creating now...")
            tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
            model = AutoModelForCausalLM.from_pretrained(
                "distilgpt2",
                device_map="auto",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            )

            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                model.config.pad_token_id = model.config.eos_token_id

        # Retrieve up to three chunks closest to the question.
        relevant_chunks = []
        if chunks:
            q_embedding = embedder.encode([question])
            _, top_k_indices = index.search(q_embedding, k=min(3, len(chunks)))
            # FAISS reports missing neighbours as -1, which would select chunks[-1].
            relevant_chunks = [chunks[i] for i in top_k_indices[0] if i >= 0]
        context = " ".join(relevant_chunks)[:4000]

        # Fit the context into the token budget *before* building the prompt.
        # Truncating the finished prompt from the right cut off the question and
        # the "Detailed answer:" cue whenever the context was long.
        head = "Answer the following question based on this information:\n\nInformation: "
        tail = f"\n\nQuestion: {question}\n\nDetailed answer:"
        fixed_len = len(tokenizer(head + tail)["input_ids"])
        budget = max(max_prompt_tokens - fixed_len - boundary_slack, 0)
        context_ids = tokenizer(context, add_special_tokens=False)["input_ids"][:budget]
        context = tokenizer.decode(context_ids, skip_special_tokens=True)
        prompt = f"{head}{context}{tail}"

        # Truncation here is only a final guard; the prompt was sized above.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_prompt_tokens)
        # Send inputs to the model's own device. initialize_models() keeps the
        # model on CPU, so moving inputs to CUDA just because it exists caused a
        # device mismatch on GPU hosts.
        inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
        prompt_len = inputs["input_ids"].shape[-1]

        output = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=3,
            no_repeat_ngram_size=2
        )

        # Decode only the newly generated tokens so no part of the prompt can
        # leak into the answer.
        answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

        logger.info(f"Generation answer: '{answer[:50]}...' (length: {len(answer)})")
        return answer
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        return "I couldn't generate a good answer based on the PDF content."

# API routes
@app.route('/')
def home():
    """Health check: report that the API is up."""
    status = {"message": "PDF QA API is running!"}
    return jsonify(status)

@app.route('/ask', methods=['POST'])
def ask():
    """Answer a question about an uploaded PDF.

    Expects a multipart form with a ``pdf`` file and a non-blank ``question``
    field. The extractive QA pipeline is tried first. When its answer is
    missing or shorter than 20 characters, retrieval + generation is used.

    Returns:
        200 ``{"answer": ...}`` on success; 400 ``{"error": ...}`` for missing
        input or a PDF with no extractable text; 500 ``{"error": ...}`` when
        processing fails.
    """
    file = request.files.get("pdf")
    question = request.form.get("question", "").strip()
    filepath = None

    if not file or not question:
        return jsonify({"error": "Both PDF file and question are required"}), 400

    try:
        # The client's filename is used for logging only. The upload itself is
        # stored under a unique server-generated path. With the sanitized client
        # name, concurrent uploads of e.g. "report.pdf" overwrote (and then
        # deleted) each other's file, and a name that sanitized to "" resolved
        # to UPLOAD_FOLDER itself.
        filename = secure_filename(file.filename or "") or "upload.pdf"
        fd, filepath = tempfile.mkstemp(suffix=".pdf", dir=UPLOAD_FOLDER)
        os.close(fd)
        file.save(filepath)

        logger.info(f"Processing file: {filename}, Question: '{question}'")

        # Process PDF and generate answer
        text = extract_text(filepath)
        if not text.strip():
            return jsonify({"error": "Could not extract text from the PDF"}), 400

        chunks = split_into_chunks(text)
        if not chunks:
            return jsonify({"error": "PDF content couldn't be processed"}), 400

        try:
            answer = answer_with_qa_pipeline(chunks, question)
        except Exception as e:
            logger.warning(f"QA pipeline failed: {str(e)}")
            answer = ""

        # Extractive answers under 20 characters are treated as unhelpful;
        # fall back to retrieval + generation.
        if not answer or len(answer.strip()) < 20:
            try:
                logger.info("QA pipeline answer insufficient, trying generation...")
                index, embeddings, chunks = setup_faiss(chunks)
                answer = answer_with_generation(index, embeddings, chunks, question)
            except Exception as e:
                logger.error(f"Generation fallback failed: {str(e)}")
                return jsonify({"error": "Failed to generate answer from PDF content"}), 500

        return jsonify({"answer": answer})

    except Exception as e:
        # Full details (with traceback) go to the server log only; echoing
        # str(e) to clients can leak internal paths and library errors.
        logger.exception(f"Error processing request: {str(e)}")
        return jsonify({"error": "An error occurred processing your request"}), 500
    finally:
        # Always remove the upload, including on early returns and errors.
        if filepath:
            cleanup_temp_files(filepath)

if __name__ == "__main__":
    # Models are loaded eagerly only when this file is run as a script. When the
    # module is imported instead (e.g. by a WSGI server), initialize_models() is
    # not called and the helpers load models lazily on first use.
    try:
        initialize_models()
        logger.info("Starting Flask application")
        app.run(host="0.0.0.0", port=7860)
    except Exception as e:
        logger.critical(f"Failed to start application: {str(e)}", exc_info=True)
        # Exit non-zero so process supervisors see the failure; previously the
        # exception was swallowed and the interpreter exited with status 0.
        raise SystemExit(1) from e