import pandas as pd
import os
from datetime import datetime, timedelta, timezone
import json
from Bio import Entrez, Medline
from huggingface_hub import HfApi, hf_hub_download, DatasetCard, DatasetCardData
from datasets import Dataset, load_dataset
from hf_api import (
    evaluate_relevance,
    summarize_abstract,
    compose_newsletter
)
import logging
import argparse
from huggingface_hub import HfFileSystem
import pdfkit
from jinja2 import Environment, FileSystemLoader
import markdown2

# Configure logging: every record goes to both app.log and the console.
# The log file is opened as UTF-8 explicitly; logged PubMed queries and error
# messages can contain non-ASCII text, which the platform default encoding
# (e.g. cp1252 on Windows) may be unable to write.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("app.log", encoding="utf-8"),
        logging.StreamHandler()
    ]
)

# Retrieve environment variables
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_NAME = os.environ.get("DATASET_NAME", "cmcmaster/this_week_in_rheumatology")

if not HF_TOKEN:
    logging.error("Hugging Face token not found. Set the HF_TOKEN environment variable.")
    # `exit()` is the interactive-shell helper added by the `site` module and
    # is not meant for programs; raising SystemExit is what sys.exit does.
    raise SystemExit(1)

# Initialize Hugging Face Hub API client (authenticated with HF_TOKEN).
api = HfApi(token=HF_TOKEN)

def ensure_repo_exists(api, repo_id, repo_type, token):
    """Make sure ``repo_id`` exists on the Hub, creating it with a dataset card if not.

    Args:
        api: An ``HfApi`` client used for the lookup, creation and upload.
        repo_id: Repository id, e.g. ``"user/name"``.
        repo_type: Hub repository type (the caller passes ``"dataset"``).
        token: Hugging Face access token for creation and the README upload.

    Only a "repository not found" response from the existence check triggers
    creation; any other error from that check propagates unchanged. If the
    repository must be created and creation or the README upload fails, the
    error is logged and the process exits with status 1.
    """
    # Local import keeps the module's top-level import block unchanged.
    from huggingface_hub.utils import RepositoryNotFoundError

    try:
        api.repo_info(repo_id=repo_id, repo_type=repo_type)
        logging.info(f"Repository {repo_id} already exists.")
        return
    except RepositoryNotFoundError:
        # Catching every exception here used to turn a transient network or
        # auth failure on an EXISTING repo into a "create" path, which (with
        # exist_ok=True) silently succeeds and then overwrites README.md.
        logging.info(f"Repository {repo_id} not found. Creating a new one.")

    try:
        api.create_repo(
            repo_id=repo_id,
            repo_type=repo_type,
            token=token,
            private=False,
            exist_ok=True
        )
        # Create a dataset card (YAML front matter + short description).
        card_data = DatasetCardData(
            language="en",
            license="cc-by-sa-4.0",
            task_categories=["text-classification"],
            tags=["rheumatology", "medical-research"]
        )
        card = DatasetCard("---\n" + card_data.to_yaml() + "\n---\n# This Week in Rheumatology\n\nA weekly collection of relevant rheumatology papers.")
        api.upload_file(
            path_or_fileobj=str(card).encode(),
            path_in_repo="README.md",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message="Add dataset card",
            token=token
        )
        logging.info(f"Repository {repo_id} created successfully with a dataset card.")
    except Exception as create_error:
        logging.error(f"Failed to create repository {repo_id}: {create_error}")
        raise SystemExit(1)

# Ensure the repository exists before proceeding
ensure_repo_exists(api, DATASET_NAME, repo_type="dataset", token=HF_TOKEN)

# Load the PubMed search strategy consumed by build_query().
# Read explicitly as UTF-8: JSON text is UTF-8 by specification, while the
# platform default encoding could mis-decode non-ASCII search terms.
with open('search_terms.json', 'r', encoding='utf-8') as f:
    search_terms = json.load(f)

def build_query(terms=None) -> str:
    """Build the PubMed boolean query from the search-strategy configuration.

    Args:
        terms: Search configuration dict with a ``search_strategy`` mapping
            (``mesh_terms``, ``keywords``, ``specific_conditions``,
            ``research_related_terms``, ``exclusion_terms``) plus top-level
            ``journals`` and ``publication_types`` lists. Defaults to the
            module-level ``search_terms`` loaded from ``search_terms.json``.

    Returns:
        The query string, shaped as::

            (inclusion) AND (research) NOT (exclusion) AND "humans"[MeSH Terms]
            AND "english"[Language] AND (publication types)
            NOT "Case Reports"[Publication Type]

        where "inclusion" ORs together MeSH terms, keywords, specific
        conditions and journals. A clause whose term list is empty or missing
        is omitted instead of being emitted as an invalid empty group such
        as ``NOT ()``.

    Raises:
        ValueError: If MeSH terms, keywords, specific conditions and journals
            are all empty, leaving the query without an anchor clause.
    """
    if terms is None:
        terms = search_terms
    strategy = terms.get('search_strategy', {})

    def _or_join(values, field):
        # '"a"[Field] OR "b"[Field]'; empty string for an empty list.
        return ' OR '.join(f'"{value}"[{field}]' for value in values)

    inclusion_groups = [
        _or_join(strategy.get('mesh_terms', []), 'MeSH Terms'),
        _or_join(strategy.get('keywords', []), 'Title/Abstract'),
        _or_join(strategy.get('specific_conditions', []), 'Title/Abstract'),
        _or_join(terms.get('journals', []), 'Journal'),
    ]
    inclusion = ' OR '.join(group for group in inclusion_groups if group)
    if not inclusion:
        raise ValueError(
            "Search terms define no inclusion criteria "
            "(MeSH terms, keywords, specific conditions or journals)."
        )

    research = _or_join(strategy.get('research_related_terms', []), 'Title/Abstract')
    exclusion = _or_join(strategy.get('exclusion_terms', []), 'Title/Abstract')
    pub_types = _or_join(terms.get('publication_types', []), 'Publication Type')

    # Each clause carries its own leading operator so optional clauses can be
    # dropped without leaving a dangling AND/NOT.
    clauses = [f"({inclusion})"]
    if research:
        clauses.append(f"AND ({research})")
    if exclusion:
        clauses.append(f"NOT ({exclusion})")
    # Restrict to human studies written in English.
    clauses.append('AND "humans"[MeSH Terms]')
    clauses.append('AND "english"[Language]')
    if pub_types:
        clauses.append(f"AND ({pub_types})")
    # Case reports are always excluded, whatever publication types are allowed.
    clauses.append('NOT "Case Reports"[Publication Type]')

    query = ' '.join(clauses)
    logging.info(f"Built PubMed query: {query}")
    return query

def search_pubmed(query, start_date: datetime, end_date: datetime, retmax: int = 1000):
    """Run an Entrez ESearch against PubMed for ``query`` within a date window.

    Args:
        query: PubMed query string (as produced by ``build_query``).
        start_date: Lower bound, sent as ``mindate`` in ``YYYY/MM/DD`` form.
        end_date: Upper bound, sent as ``maxdate`` in ``YYYY/MM/DD`` form.
        retmax: Maximum number of PMIDs returned in ``IdList`` (default 1000,
            the previous hard-coded value).

    Returns:
        The parsed ESearch result (includes ``Count`` and ``IdList``).

    Raises:
        Re-raises any error from the request or parsing after logging the
        query and date range.
    """
    # Contact email sent with E-utilities requests; overridable via the
    # ENTREZ_EMAIL environment variable instead of editing the source.
    Entrez.email = os.environ.get("ENTREZ_EMAIL", "mcmastc1@gmail.com")
    try:
        handle = Entrez.esearch(
            db="pubmed",
            term=query,
            mindate=start_date.strftime('%Y/%m/%d'),
            maxdate=end_date.strftime('%Y/%m/%d'),
            usehistory="y",
            retmax=retmax
        )
        try:
            results = Entrez.read(handle)
        finally:
            # Release the HTTP connection even if parsing fails.
            handle.close()
        logging.info(f"PubMed search completed. Found {results['Count']} papers.")
        total = int(results['Count'])
        if total > retmax:
            # IdList holds at most `retmax` ids; surface the truncation
            # instead of silently dropping the remaining matches.
            logging.warning(
                f"Only the first {retmax} of {total} matching PMIDs were retrieved; "
                f"increase retmax to include the rest."
            )
        return results
    except Exception as e:
        logging.error(f"Error searching PubMed: {e}")
        logging.error(f"Query: {query}")
        logging.error(f"Date range: {start_date.strftime('%Y/%m/%d')} to {end_date.strftime('%Y/%m/%d')}")
        raise

def fetch_details(id_list):
    """Fetch MEDLINE records for the given PubMed IDs.

    Args:
        id_list: Iterable of PMID strings.

    Returns:
        A list of records parsed by ``Bio.Medline.parse`` (dict-like, keyed by
        MEDLINE tags such as ``PMID``, ``TI``, ``AB``). Returns an empty list
        without contacting NCBI when ``id_list`` is empty.
    """
    ids = ",".join(id_list)
    if not ids:
        # Nothing to request; avoid an EFetch call with an empty id parameter.
        logging.info("Fetched details for 0 papers.")
        return []
    handle = Entrez.efetch(db="pubmed", id=ids, rettype="medline", retmode="text")
    try:
        # Medline.parse is a lazy generator, so materialise it before closing.
        records = list(Medline.parse(handle))
    finally:
        handle.close()
    logging.info(f"Fetched details for {len(records)} papers.")
    return records

def process_papers(records, relevance_threshold: float = 8):
    """Score each MEDLINE record for relevance and summarise the kept ones.

    Args:
        records: List of MEDLINE record dicts (as returned by ``fetch_details``).
        relevance_threshold: A paper is kept only when its relevance score is
            strictly greater than this value (default 8, the previous
            hard-coded cut-off).

    Returns:
        DataFrame with one row per kept paper and columns ``PMID``, ``Title``,
        ``Authors``, ``Journal``, ``Summary``, ``Topic``; empty if none were
        kept. A paper whose scoring or summarisation raises is logged and
        skipped.
    """
    data = []
    for record in records:
        article = {
            "PMID": record.get("PMID", ""),
            "Title": record.get("TI", ""),
            "Authors": ", ".join(record.get("AU", [])),
            "Journal": record.get("JT", ""),
            "Abstract": record.get("AB", ""),
            "Publication Type": ", ".join(record.get("PT", [])),
        }
        try:
            relevance = evaluate_relevance(article["Title"], article["Abstract"])
            # NOTE(review): the score's type is whatever evaluate_relevance
            # returns; float() lets a numeric string work and treats a missing
            # or None score as 0 instead of raising TypeError on comparison.
            score = float(relevance.get("relevance_score", 0) or 0)
            if score > relevance_threshold:
                summary = summarize_abstract(article["Abstract"])
                article["Summary"] = summary.get("summary", "")
                article["Topic"] = summary.get("topic", "")
                # Only the summary is published; drop the raw abstract and
                # publication type from the stored row.
                article.pop("Abstract", None)
                article.pop("Publication Type", None)
                data.append(article)
            logging.info(f"Paper PMID {article['PMID']} processed successfully. Relevance Score: {relevance.get('relevance_score', 0)}")
        except json.JSONDecodeError as json_err:
            logging.error(f"JSON decode error for paper PMID {article['PMID']}: {json_err}")
        except Exception as e:
            logging.error(f"Error processing paper PMID {article['PMID']}: {e}")

    logging.info(f"Processed {len(records)} papers. {len(data)} were deemed relevant.")
    return pd.DataFrame(data)

def get_rheumatology_papers(start_date: datetime, end_date: datetime, test: bool = False):
    """Search PubMed for the window and return the relevant, summarised papers.

    Args:
        start_date: Start of the search window (passed to ``search_pubmed``).
        end_date: End of the search window (passed to ``search_pubmed``).
        test: When True, only the first 50 PMIDs are fetched and processed.

    Returns:
        The DataFrame from ``process_papers``, or an empty DataFrame when the
        search returns no PMIDs.
    """
    # build_query logs the full query string itself.
    query = build_query()
    logging.info(f"Searching PubMed for papers between {start_date.strftime('%Y-%m-%d')} and {end_date.strftime('%Y-%m-%d')}")
    search_results = search_pubmed(query, start_date, end_date)
    id_list = search_results.get("IdList", [])
    if not id_list:
        logging.info("No new papers found.")
        return pd.DataFrame()

    if test:
        logging.info("Running in test mode. Processing only 50 papers.")
        # Trim before fetching so a test run doesn't download every MEDLINE
        # record only to discard all but the first 50.
        id_list = id_list[:50]

    logging.info(f"Fetching details for {len(id_list)} papers.")
    records = fetch_details(id_list)
    return process_papers(records)

def cache_dataset(papers_df: pd.DataFrame, start_date: datetime, end_date: datetime):
    """Upload the processed papers to ``<YYYYMMDD>/papers.jsonl`` in the dataset repo.

    The file is written as JSON Lines (one JSON object per row), matching its
    ``.jsonl`` extension. Previously a single JSON array was written under
    that name. Upload failures are logged and not re-raised.

    Args:
        papers_df: Papers to cache, one row per paper.
        start_date: Window start, used only in the commit message.
        end_date: Window end; determines the directory name and commit message.
    """
    try:
        # DataFrame.to_json serialises numpy scalar types (which json.dumps
        # rejects) and writes missing values as null.
        jsonl = papers_df.to_json(orient="records", lines=True)
        if jsonl and not jsonl.endswith("\n"):
            jsonl += "\n"
        repo_path = f"{end_date.strftime('%Y%m%d')}/papers.jsonl"
        # Upload to the Hub
        api.upload_file(
            path_or_fileobj=jsonl.encode('utf-8'),
            path_in_repo=repo_path,
            repo_id=DATASET_NAME,
            repo_type="dataset",
            commit_message=f"Add papers from {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}",
            token=HF_TOKEN
        )
        logging.info(f"Papers cached successfully to repository {DATASET_NAME}.")
    except Exception as e:
        logging.error(f"Failed to cache papers: {e}")

def load_cached_papers(start_date: datetime, end_date: datetime, test: bool = False) -> pd.DataFrame:
    """Return cached papers for ``end_date`` from the Hub, or process them fresh.

    Looks for ``<YYYYMMDD>/papers.jsonl`` in the ``DATASET_NAME`` dataset repo.
    If the file exists and parses, its rows are returned as a DataFrame.
    Otherwise — no file, or any error while reading it — the papers are
    fetched and processed via ``get_rheumatology_papers``.

    Both JSON Lines and a single top-level JSON array are accepted, so caches
    written in either format load. Errors raised by
    ``get_rheumatology_papers`` propagate; the expensive search/LLM pipeline
    is never re-run as a retry.
    """
    # Use the configured repo (not a hard-coded one) so reads match the
    # location cache_dataset and process_new_papers use.
    cache_path = f"datasets/{DATASET_NAME}/{end_date.strftime('%Y%m%d')}/papers.jsonl"
    try:
        fs = HfFileSystem()
        if fs.exists(cache_path):
            text = fs.read_text(cache_path, encoding="utf-8")
            if text.lstrip().startswith("["):
                # Legacy cache: one JSON array of records.
                cached_records = json.loads(text)
            else:
                # JSON Lines: one record per non-blank line.
                cached_records = [json.loads(line) for line in text.splitlines() if line.strip()]
            logging.info(f"Loaded {len(cached_records)} cached papers from {cache_path}.")
            return pd.DataFrame(cached_records)
        logging.info(f"No cache found for {end_date.strftime('%Y-%m-%d')}. Processing new papers.")
    except Exception as e:
        # Only cache *reading* is guarded; a failure here falls back to a
        # fresh run below instead of aborting.
        logging.warning(f"Error loading cache: {e}. Processing new papers.")
    return get_rheumatology_papers(start_date, end_date, test)

def generate_pdf_newsletter(content: dict, end_date: datetime) -> bool:
    """Render the newsletter to PDF with pdfkit and upload it to the dataset repo.

    Args:
        content: Newsletter dict with a ``"date"`` string (used in the title)
            and ``"content"`` (Markdown text converted with markdown2).
        end_date: Week end date; determines both the local output path and
            the in-repo path ``<YYYYMMDD>/newsletter.pdf``.

    Returns:
        True if the PDF was generated and uploaded, False if any step failed.
        Failures are logged with a traceback rather than raised, so callers
        that ignore the return value behave as before.
    """
    try:
        # Convert markdown to HTML
        html_content = markdown2.markdown(content['content'])

        # Setup Jinja2 template environment (templates/ relative to the CWD)
        env = Environment(loader=FileSystemLoader('templates'))
        template = env.get_template('newsletter_pdf.html')

        # Render the template
        html = template.render(
            title=f"This Week in Rheumatology - {content['date']}",
            content=html_content
        )

        # Configure PDF options
        options = {
            'page-size': 'A4',
            'margin-top': '2cm',
            'margin-right': '2cm',
            'margin-bottom': '2cm',
            'margin-left': '2cm',
            'encoding': 'UTF-8',
            'enable-local-file-access': None,
            'quiet': ''
        }

        # Generate the PDF locally; the same relative path is used in the repo.
        pdf_path = f"{end_date.strftime('%Y%m%d')}/newsletter.pdf"
        os.makedirs(os.path.dirname(pdf_path), exist_ok=True)

        # NOTE(review): the rendered template is wrapped in a second
        # <html>/<head>/<body> document here. If newsletter_pdf.html is itself
        # a full HTML document this nests documents — confirm which styling
        # is intended to apply.
        html_with_style = f"""
        <html>
        <head>
            <style>
                body {{ 
                    font-family: Arial, sans-serif; 
                    line-height: 1.6; 
                    margin: 0 auto;
                    max-width: 21cm;  /* A4 width */
                    color: #333;
                }}
                h1, h2 {{ color: #2c3e50; }}
                h1 {{ font-size: 24px; margin-top: 2em; }}
                h2 {{ font-size: 20px; margin-top: 1.5em; }}
                a {{ color: #3498db; text-decoration: none; }}
                p {{ margin-bottom: 1em; }}
            </style>
        </head>
        <body>
            {html}
        </body>
        </html>
        """

        pdfkit.from_string(html_with_style, pdf_path, options=options)

        # Upload PDF to Hub
        with open(pdf_path, 'rb') as f:
            api.upload_file(
                path_or_fileobj=f,
                path_in_repo=pdf_path,
                repo_id=DATASET_NAME,
                repo_type="dataset",
                commit_message=f"Add PDF newsletter for {end_date.strftime('%Y-%m-%d')}",
                token=HF_TOKEN
            )
        logging.info("PDF newsletter generated and uploaded successfully")
        return True

    except Exception as e:
        # Best-effort step: log with traceback and report failure to the caller.
        logging.exception(f"Failed to generate PDF newsletter: {e}")
        return False

def generate_and_store_newsletter(papers_df: pd.DataFrame, end_date: datetime):
    """Compose the newsletter and push its JSON and PDF versions to the dataset repo.

    Args:
        papers_df: Relevant papers to include; nothing is done if it is empty.
        end_date: Week end date; used for the ``date`` field and the
            ``<YYYYMMDD>/newsletter.json`` repo path.

    Errors during composition or the JSON upload are logged, not raised.
    """
    if papers_df.empty:
        logging.info("No papers to include in the newsletter.")
        return

    try:
        logging.info(f"Generating newsletter with {len(papers_df)} papers.")
        newsletter_content = compose_newsletter(papers_df)
        newsletter_data = {
            "date": end_date.strftime('%Y-%m-%d'),
            "content": newsletter_content
        }

        # Store JSON version
        newsletter_json = json.dumps(newsletter_data, indent=4)
        repo_path = f'{end_date.strftime("%Y%m%d")}/newsletter.json'
        api.upload_file(
            path_or_fileobj=newsletter_json.encode('utf-8'),
            path_in_repo=repo_path,
            repo_id=DATASET_NAME,
            repo_type="dataset",
            commit_message=f"Add newsletter for {end_date.strftime('%Y-%m-%d')}",
            token=HF_TOKEN
        )

        # Generate and store PDF version. The PDF helper handles its own
        # errors; only an explicit False return is treated as failure, so any
        # other return value (including None) counts as success.
        pdf_result = generate_pdf_newsletter(newsletter_data, end_date)
        if pdf_result is False:
            logging.warning(f"Newsletter JSON pushed to repository {DATASET_NAME}, but the PDF version failed.")
        else:
            logging.info(f"Newsletter (JSON and PDF) successfully pushed to repository {DATASET_NAME}.")
    except Exception as e:
        logging.error(f"Failed to generate or store newsletter: {e}")

def process_new_papers(end_date: datetime = None, test: bool = False):
    """Run the weekly pipeline for the 7 days ending at ``end_date``.

    Args:
        end_date: End of the week; defaults to the current UTC time.
        test: Test mode. It processes at most 50 papers, does not skip an
            already-published week, and continues even when no papers are
            found.

    Steps:
        1. Outside test mode, stop if ``<YYYYMMDD>/newsletter.json`` exists.
        2. Load cached papers or search/process new ones.
        3. Cache non-empty results to the Hub.
        4. Compose and upload the newsletter.
    """
    end_date = end_date or datetime.now(timezone.utc)
    start_date = end_date - timedelta(days=7)

    print(f"End date: {end_date.strftime('%Y-%m-%d')}")

    logging.info(f"Processing papers for the week: {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}")

    # Check whether a newsletter already exists for this date. The check is
    # skipped in test mode, so don't pay for the Hub lookup there.
    newsletter_path = f"datasets/{DATASET_NAME}/{end_date.strftime('%Y%m%d')}/newsletter.json"
    if not test and HfFileSystem().exists(newsletter_path):
        logging.info(f"Newsletter already exists for {end_date.strftime('%Y-%m-%d')}. Skipping generation.")
        return

    papers_df = load_cached_papers(start_date, end_date, test)

    if papers_df.empty and not test:
        logging.info("No relevant papers found in cache or recent search.")
        return

    logging.info(f"Found {len(papers_df)} relevant papers for the newsletter.")

    # NOTE(review): test mode still writes to the same <YYYYMMDD>/ paths as
    # production runs (a ≤50-paper sample here, and the newsletter below);
    # confirm that is intended.
    if papers_df.empty:
        # An empty cache file would be returned by load_cached_papers on later
        # runs for this date as if the week had no relevant papers.
        logging.info("Skipping cache upload: no papers to cache.")
    else:
        # Cache the papers_df in the Hugging Face dataset repo
        cache_dataset(papers_df, start_date, end_date)

    # Generate and store the newsletter
    generate_and_store_newsletter(papers_df, end_date)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate a weekly Rheumatology newsletter.")
    parser.add_argument('--end_date', type=str, help='End date for the newsletter in YYYY-MM-DD format. Defaults to today.')
    parser.add_argument('--test', action='store_true', help='Run the script in test mode.')
    args = parser.parse_args()

    end_date = None
    if args.end_date:
        try:
            # Attach UTC so an explicit date is timezone-aware, like the
            # datetime.now(timezone.utc) default in process_new_papers.
            end_date = datetime.strptime(args.end_date, '%Y-%m-%d').replace(tzinfo=timezone.utc)
        except ValueError:
            logging.error("Invalid date format for --end_date. Use YYYY-MM-DD.")
            # SystemExit instead of the interactive-only site `exit()` helper.
            raise SystemExit(1) from None

    process_new_papers(end_date, args.test)