import os
import numpy as np
import spacy
import nltk
import tensorflow as tf
import streamlit as st
from streamlit_extras.add_vertical_space import add_vertical_space
from bs4 import BeautifulSoup
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize
from warnings import filterwarnings
filterwarnings('ignore')



def streamlit_config():
    """Configure the Streamlit page and draw the application header.

    Sets the browser-tab title and a centered layout, injects a small CSS
    snippet that makes Streamlit's top header bar transparent, renders the
    centered page title, and finally adds vertical padding below it.
    """
    st.set_page_config(page_title='Classification', layout='centered')

    # CSS making the background of Streamlit's header element fully transparent.
    transparent_header_css = (
        "\n"
        "    <style>\n"
        "\n"
        '    [data-testid="stHeader"] \n'
        "    {\n"
        "    background: rgba(0,0,0,0);\n"
        "    }\n"
        "\n"
        "    </style>\n"
        "    "
    )
    title_html = '<h1 style="text-align: center;">Financial Document Classification</h1>'

    # Both snippets are raw HTML, so Streamlit must be told not to escape them.
    for markup in (transparent_header_css, title_html):
        st.markdown(markup, unsafe_allow_html=True)

    add_vertical_space(4)


def text_extract_from_html(html_file):
    """Extract the whitespace-delimited words of an uploaded HTML document.

    Args:
        html_file: binary file-like object (e.g. the ``UploadedFile`` returned
            by ``st.file_uploader``) whose ``read()`` yields the HTML bytes.
            The stream is consumed by this call.

    Returns:
        list[str]: the text extracted by BeautifulSoup's ``get_text()``,
        split on whitespace (no empty strings, no surrounding whitespace).
    """
    # Decode leniently: an upload that is not valid UTF-8 (e.g. a Latin-1
    # encoded filing) should still be classified rather than crash with
    # UnicodeDecodeError; undecodable bytes become U+FFFD.
    html_content = html_file.read().decode('utf-8', errors='replace')

    # Parse the HTML and pull out its text nodes as one string.
    text = BeautifulSoup(html_content, 'html.parser').get_text()

    # A bare str.split() already discards empty fields and surrounding
    # whitespace, so no per-token strip() is required.
    return text.split()


@st.cache_resource
def _load_spacy_model():
    """Load the spaCy English pipeline once and reuse it across Streamlit reruns."""
    return spacy.load('en_core_web_lg')


def text_processing(text):
    """Normalize extracted words into the text fed to the embedding step.

    The words are joined with spaces and run through spaCy. The function keeps
    the lower-cased lemma of every purely alphabetic token whose lower-cased
    surface form is not a spaCy stop word. This drops punctuation, digits,
    mixed alphanumeric tokens and stop words.

    Args:
        text: list of word strings, as returned by ``text_extract_from_html``.

    Returns:
        str: the surviving lemmas joined by single spaces, or the literal
        ``'empty'`` when nothing survives. (Presumably the same placeholder
        used when the training data was preprocessed; confirm before changing.)
    """
    # Loading en_core_web_lg is expensive. Streamlit re-runs the whole script
    # on every interaction, so the pipeline is cached via st.cache_resource.
    nlp = _load_spacy_model()
    stop_words = nlp.Defaults.stop_words

    doc = nlp(' '.join(text))

    # Tokenization, lemmatization, and removal of stop words, punctuation and digits.
    token_list = [
        token.lemma_.lower().strip()
        for token in doc
        if token.text.isalpha() and token.text.lower() not in stop_words
    ]

    return ' '.join(token_list) if token_list else 'empty'
    

@st.cache_resource
def _load_word2vec_model():
    """Load the trained Word2Vec model from disk once and reuse it across reruns."""
    return Word2Vec.load(os.path.join('model', 'word2vec_model.bin'))


def sentence_embeddings(sentence):
    """Embed a sentence as the average of its in-vocabulary Word2Vec vectors.

    Args:
        sentence: space-separated text, as produced by ``text_processing``.

    Returns:
        numpy.ndarray: a 1-D vector of length ``model.vector_size``. It is the
        element-wise mean of the vectors for tokens found in the model's
        vocabulary, or all zeros when no token is in the vocabulary.
    """
    # Reading the model from disk on every call is slow; Streamlit re-runs the
    # script on each interaction, so the loaded model is cached instead.
    model = _load_word2vec_model()

    # Split the sentence into separate words.
    words = word_tokenize(sentence)

    # Out-of-vocabulary words are skipped rather than raising KeyError.
    vectors = [model.wv[word] for word in words if word in model.wv]

    if vectors:
        return np.mean(vectors, axis=0)

    # The training setup reportedly used vector_size = 300; model.vector_size
    # keeps the zero vector consistent with whatever model is on disk.
    return np.zeros(model.vector_size)


@st.cache_resource
def _load_classifier_model():
    """Load the trained Keras classifier once and reuse it across Streamlit reruns."""
    return tf.keras.models.load_model(os.path.join('model', 'model.h5'))


def prediction(html_file):
    """Classify an uploaded HTML financial document and render the result.

    Runs the pipeline text extraction -> preprocessing -> Word2Vec sentence
    embedding -> Keras model. It then shows the confidence, the uploaded
    document itself, and the predicted class name in the Streamlit page.

    Args:
        html_file: the ``UploadedFile`` returned by ``st.file_uploader``.
            ``getvalue()`` is used to re-read its bytes for display, because
            ``text_extract_from_html`` consumes the stream.
    """
    # Imported explicitly: the components submodule is not guaranteed to be
    # reachable as an attribute of the top-level ``streamlit`` package.
    import streamlit.components.v1 as components

    # Extract the text from the HTML document, preprocess it, and embed it.
    extracted_text = text_extract_from_html(html_file)
    preprocessed_text = text_processing(extracted_text)
    features = sentence_embeddings(preprocessed_text)

    # (vector_size,) -> (1, vector_size, 1): a batch of one sample with a
    # trailing channel axis, matching the model's expected input shape.
    features = np.expand_dims(features, axis=0)
    features = np.expand_dims(features, axis=2)
    features_tensors = tf.convert_to_tensor(features, dtype=tf.float32)

    # Loading the model from disk is expensive, so it is cached across reruns.
    model = _load_classifier_model()
    probabilities = model.predict(features_tensors)

    # With a batch of one, argmax over the flattened output is the class index.
    target_label = int(np.argmax(probabilities))

    # Class index -> label name (order presumably matches the training labels).
    target = {0: 'Balance Sheets', 1: 'Cash Flow', 2: 'Income Statement', 3: 'Notes', 4: 'Others'}
    predicted_class = target[target_label]

    # Highest class probability, as a percentage rounded to two decimals.
    confidence = round(float(np.max(probabilities)) * 100, 2)

    add_vertical_space(1)
    st.markdown(f'<h4 style="text-align: center; color: orange;">{confidence}% Match Found</h4>',
                unsafe_allow_html=True)

    # components.html() expects the HTML markup as a string, not a file
    # object. getvalue() returns the full upload regardless of stream position.
    html_content = html_file.getvalue().decode('utf-8', errors='replace')
    components.html(html_content, height=600, scrolling=True)

    add_vertical_space(1)
    st.markdown(f'<h3 style="text-align: center; color: green;">{predicted_class}</h3>',
                unsafe_allow_html=True)



# Streamlit Configuration Setup (must run before any other Streamlit element)
streamlit_config()


# word_tokenize() needs NLTK's Punkt tokenizer data. Older NLTK releases load
# the pickled 'punkt' resource, while recent releases load 'punkt_tab' instead.
# Make sure both are present so tokenization works on either version.
for _punkt_resource in ('punkt', 'punkt_tab'):
    try:
        nltk.data.find(f'tokenizers/{_punkt_resource}')
    except LookupError:
        nltk.download(_punkt_resource)


# File uploader for the HTML document (".htm" is accepted as well as ".html").
input_file = st.file_uploader('Upload an HTML file', type=['html', 'htm'])

if input_file is not None:
    prediction(input_file)