import gradio as gr
from pathlib import Path
import os

os.system('pip install tensorflow')
os.system('pip install nltk')

from transformers import pipeline
from transformers import MarianMTModel, MarianTokenizer
from nltk.tokenize import sent_tokenize
from nltk.tokenize import LineTokenizer
import math
import torch
import nltk
nltk.download('punkt')

docs = None

if torch.cuda.is_available():  
  dev = "cuda"
else:  
  dev = "cpu" 
device = torch.device(dev)

# Definimos los modelos:
mname = "Helsinki-NLP/opus-mt-es-en"
tokenizer_es_en = MarianTokenizer.from_pretrained(mname)
model_es_en = MarianMTModel.from_pretrained(mname)
model_es_en.to(device)

mname = "Helsinki-NLP/opus-mt-en-es"
tokenizer_en_es = MarianTokenizer.from_pretrained(mname)
model_en_es = MarianMTModel.from_pretrained(mname)
model_en_es.to(device)

lt = LineTokenizer()

question_answerer = pipeline("question-answering", model='distilbert-base-cased-distilled-squad')

def request_pathname(files):
    if files is None:
        return [[]]
    return [[file.name, file.name.split('/')[-1]] for file in files]

def traducir_parrafos(parrafos, tokenizer, model, tam_bloque=8, ):
  parrafos_traducidos = []
  for parrafo in parrafos:
    frases = sent_tokenize(parrafo)
    batches = math.ceil(len(frases) / tam_bloque)     
    traducido = []
    for i in range(batches):

        bloque_enviado = frases[i*tam_bloque:(i+1)*tam_bloque]
        model_inputs = tokenizer(bloque_enviado, return_tensors="pt", 
                                 padding=True, truncation=True, 
                                 max_length=500).to(device)
        with torch.no_grad():
            bloque_traducido = model.generate(**model_inputs)
        traducido += bloque_traducido
    traducido = [tokenizer.decode(t, skip_special_tokens=True) for t in traducido]
    parrafos_traducidos += [" ".join(traducido)]
  return parrafos_traducidos

def traducir_es_en(texto):
    parrafos = lt.tokenize(texto)
    par_tra = traducir_parrafos(parrafos, tokenizer_es_en, model_es_en) 
    return "\n".join(par_tra)    

def traducir_en_es(texto):
    parrafos = lt.tokenize(texto)
    par_tra = traducir_parrafos(parrafos, tokenizer_en_es, model_en_es) 
    return "\n".join(par_tra)

def validate_dataset(dataset):
    global docs
    docs = None  # clear it out if dataset is modified
    docs_ready = dataset.iloc[-1, 0] != ""
    if docs_ready:
        return "✨Listo✨"
    else:
        return "⚠️Esperando documentos..."

def do_ask(question, button, dataset):
    global docs
    docs_ready = dataset.iloc[-1, 0] != ""
    if button == "✨Listo✨" and docs_ready:
        for _, row in dataset.iterrows():
            path = row['filepath']
            text = Path(f'{path}').read_text()
            text_en = traducir_es_en(text)
            QA_input = {
                'question': traducir_es_en(question),
                'context': text_en
            }
            return traducir_en_es(question_answerer(QA_input)['answer'])
    else:        
        return ""

# def do_ask(question, button, dataset, progress=gr.Progress()):
#     global docs
#     docs_ready = dataset.iloc[-1, 0] != ""
#     if button == "✨Listo✨" and docs_ready:
#         if docs is None:  # don't want to rebuild index if it's already built
#             import paperqa
#             docs = paperqa.Docs()
#             # dataset is pandas dataframe
#             for _, row in dataset.iterrows():
#                 key = None
#                 if ',' not in row['citation string']:
#                     key = row['citation string']
#                 docs.add(row['filepath'], row['citation string'], key=key)
#     else:
#         return ""
#     progress(0, "Construyendo índices...")
#     docs._build_faiss_index()
#     progress(0.25, "Encolando...")
#     result = docs.query(question)
#     progress(1.0, "¡Hecho!")
#     return result.formatted_answer, result.context


with gr.Blocks() as demo:
    gr.Markdown("""
    # Document Question and Answer adaptado al castellano por Pablo Ascorbe.

    Este espacio ha sido clonado y adaptado de: https://huggingface.co/spaces/whitead/paper-qa

    La idea es utilizar un modelo preentrenado de HuggingFace como "distilbert-base-cased-distilled-squad"
    y responder las preguntas en inglés, para ello, será necesario hacer primero una traducción de los textos en castellano
    a inglés y luego volver a traducir en sentido contrario.

    ## Instrucciones:

    Adjunte su documento, ya sea en formato .txt o .pdf, y pregunte lo que desee.
    
    """)
    uploaded_files = gr.File(
        label="Sus documentos subidos (PDF o txt)", file_count="multiple", )
    dataset = gr.Dataframe(
        headers=["filepath", "citation string"],
        datatype=["str", "str"],
        col_count=(2, "fixed"),
        interactive=True,
        label="Documentos y citas"
    )
    buildb = gr.Textbox("⚠️Esperando documentos...",
                        label="Estado", interactive=False, show_label=True)
    dataset.change(validate_dataset, inputs=[
                   dataset], outputs=[buildb])
    uploaded_files.change(request_pathname, inputs=[
                          uploaded_files], outputs=[dataset])
    query = gr.Textbox(
        placeholder="Introduzca su pregunta aquí...", label="Pregunta")
    ask = gr.Button("Preguntar")
    gr.Markdown("## Respuesta")
    answer = gr.Markdown(label="Respuesta")
    with gr.Accordion("Contexto", open=False):
        gr.Markdown(
            "### Contexto\n\nEl siguiente contexto ha sido utilizado para generar la respuesta:")
        context = gr.Markdown(label="Contexto")
    # ask.click(fn=do_ask, inputs=[query, buildb,
    #                              dataset], outputs=[answer, context])
    ask.click(fn=do_ask, inputs=[query, buildb,
                                 dataset], outputs=[answer])

demo.queue(concurrency_count=20)
demo.launch(show_error=True)