# Model-hub upload residue preserved as a comment so the file parses:
# alpertml's picture | Upload 4 files | commit 06b4325
# external libraries
import pickle
import numpy as np
import pandas as pd
import ast
from transformers import BartForConditionalGeneration, BartTokenizer
# internal libraries
from config import config
from src.nltk_utilities import NltkSegmentizer
from src.stanza_utilities import StanzaSegmentizer
from src.spacy_utilities import SpacySegmentizer
from src.preprocessing import remove_patterns
from src.summarization_utilities import SummarizationUtilities, BARTSummarizer, T5Summarizer, ProphetNetSummarizer
# Module-level state; the segmentizer/summarizer/topic-model path are
# placeholders until define_models() configures them.
nltkUtilsObj = None   # sentence segmentizer (Nltk/Spacy/Stanza), set in define_models()
summUtilsObj = None   # summarizer backend (Pegasus/Bart/T5/ProphetNet), set in define_models()
TopicModelling = ''   # path to the topic-modelling CSV, set in define_models()

# NOTE(security): pickle.load executes arbitrary code on load — only ever
# point config.sent_trans_path at a trusted file.
# Context manager ensures the file handle is closed (the original leaked it).
with open(config.sent_trans_path, 'rb') as _model_file:
    sentTransfModelUtilsObj = pickle.load(_model_file)
# Force CPU inference so serving does not require a GPU.
sentTransfModelUtilsObj.model = sentTransfModelUtilsObj.model.to('cpu')
def text_to_sentences(data):
    """Segment raw input text into a list of sentences.

    Uses the module-level segmentizer (nltkUtilsObj) selected by
    define_models(); it must be configured before calling this.
    """
    return list(nltkUtilsObj.segment_into_sentences(data))
def preprocess(list_sentences, sentTransfModelUtilsObj):
    """Clean each sentence and compute its embedding.

    Bug fix: the original embedded only the non-empty cleaned sentences but
    returned ALL cleaned sentences (including empties), so the two returned
    lists could fall out of alignment — and they are indexed in lockstep by
    compute_similarity_matrix(). Empty sentences are now dropped from BOTH
    lists, keeping embeddings[i] paired with sentences[i].

    Args:
        list_sentences: raw sentence strings.
        sentTransfModelUtilsObj: wrapper exposing get_embeddings(str).

    Returns:
        (embeddings, sentences) — two equal-length, index-aligned lists.
    """
    cleaned = [remove_patterns(sentence) for sentence in list_sentences]
    # Drop sentences that became empty after cleaning, from both outputs.
    cleaned = [sentence for sentence in cleaned if len(sentence) > 0]
    embeddings = [sentTransfModelUtilsObj.get_embeddings(sentence) for sentence in cleaned]
    return embeddings, cleaned
def get_emb_cluster_topic(sentTransfModelUtilsObj):
    """Load the topic-modelling CSV and embed each topic's word list.

    Reads the CSV at the module-level TopicModelling path (set by
    define_models()). Each row's 'list_topic_words' column holds a
    stringified Python list of words; it is joined into one pseudo-sentence
    and embedded.

    Returns:
        (list of embeddings, one per topic row; the loaded DataFrame).
    """
    df_latVectorRep = pd.read_csv(TopicModelling)
    # 'list_topic_words' is a stringified list — parse it, then join into a sentence.
    df_latVectorRep["sentence_from_words"] = df_latVectorRep["list_topic_words"].map(
        lambda words: " ".join(ast.literal_eval(words))
    )
    # Iterate the column directly instead of the slow DataFrame.iterrows().
    list_embeddings_cluster_sentences = [
        sentTransfModelUtilsObj.get_embeddings(sentence)
        for sentence in df_latVectorRep["sentence_from_words"]
    ]
    return list_embeddings_cluster_sentences, df_latVectorRep
def compute_similarity_matrix(list_sentences_per_doc_embeddings, list_sentences, sentTransfModelUtilsObj):
    """Assign each document sentence to its most similar topic.

    Builds a (num_topics x num_sentences) cosine-similarity matrix between
    topic embeddings and sentence embeddings, takes the argmax over topics
    for every sentence, and groups the sentences by the winning topic's
    'label_class'.

    Args:
        list_sentences_per_doc_embeddings: embeddings index-aligned with list_sentences.
        list_sentences: the sentence strings.
        sentTransfModelUtilsObj: wrapper exposing compute_cosine_similarity(a, b).

    Returns:
        dict mapping label_class -> list of sentences assigned to that topic.
    """
    list_embeddings_cluster_sentences, df_latVectorRep = get_emb_cluster_topic(sentTransfModelUtilsObj)
    similarity_matrix = np.zeros(
        (len(list_embeddings_cluster_sentences), len(list_sentences_per_doc_embeddings))
    )
    for i, cluster_embedding in enumerate(list_embeddings_cluster_sentences):
        for j, sentence_embedding in enumerate(list_sentences_per_doc_embeddings):
            similarity_matrix[i, j] = sentTransfModelUtilsObj.compute_cosine_similarity(
                cluster_embedding, sentence_embedding
            )
    # For each sentence (column), the row index of the best-matching topic.
    best_topic_per_sentence = np.argmax(similarity_matrix, axis=0)
    dict_topic_sentences = {}
    for index_sentence, index_id_topic in enumerate(best_topic_per_sentence):
        label_class = df_latVectorRep.iloc[index_id_topic]["label_class"]
        # setdefault replaces the manual `not in dict.keys()` membership check.
        dict_topic_sentences.setdefault(label_class, []).append(list_sentences[index_sentence])
    return dict_topic_sentences
def summarize(dict_topic_sentences):
    """Produce a summary per topic when it has enough sentences.

    For each label_class, keeps the source sentences and either a generated
    summary (via the module-level summUtilsObj selected in define_models())
    or a placeholder string when the sentence count is below
    config.MIN_NUM_SENTENCES_FOR_SUMMARY_CREATION.

    Removed a leftover debug print of the topic's sentences.

    Returns:
        dict mapping label_class -> {"source": [...], "summary": str}.
    """
    summaries_report = {}
    # Iterate items() instead of keys() + repeated lookups.
    for class_label, sentences in dict_topic_sentences.items():
        entry = {"source": sentences}
        if len(sentences) >= config.MIN_NUM_SENTENCES_FOR_SUMMARY_CREATION:
            entry["summary"] = summUtilsObj.summarize(" ".join(sentences))
        else:
            entry["summary"] = "X -> not possible to generate a summary due to threshold"
        summaries_report[class_label] = entry
    return summaries_report
def define_models(MODELS):
    """Configure module-level model objects from the MODELS selection dict.

    MODELS keys: 'summarizer' (Pegasus/Bart/T5-base/Prophetnet),
    'topic_modelling' (BERTopic/LDA/CTM/NMF/Top2Vec), and
    'segmentizer' (Nltk/Spacy/Stanza).

    Sets the module globals summUtilsObj, TopicModelling (CSV path) and
    nltkUtilsObj. Unrecognized values leave the previous global untouched.
    """
    global TopicModelling
    global summUtilsObj
    global nltkUtilsObj
    if MODELS['summarizer'] == 'Pegasus':
        summUtilsObj = SummarizationUtilities()
    elif MODELS['summarizer'] == 'Bart':
        summUtilsObj = BARTSummarizer()
    elif MODELS['summarizer'] == 'T5-base':
        summUtilsObj = T5Summarizer()
    elif MODELS['summarizer'] == 'Prophetnet':
        summUtilsObj = ProphetNetSummarizer()
    # Bug fix: 'LDA' and 'CTM' previously pointed at each other's CSV paths
    # (LDA -> ..._CTM, CTM -> ..._LDA); each now selects its own file.
    if MODELS['topic_modelling'] == 'BERTopic':
        TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS
    elif MODELS['topic_modelling'] == 'LDA':
        TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_LDA
    elif MODELS['topic_modelling'] == 'CTM':
        TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_CTM
    elif MODELS['topic_modelling'] == 'NMF':
        TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_NMF
    elif MODELS['topic_modelling'] == 'Top2Vec':
        TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_TOP2VEC
    # Bug fix: 'Spacy' used a bare `if`, splitting what should be one
    # mutually-exclusive chain over the segmentizer choices.
    if MODELS['segmentizer'] == 'Nltk':
        nltkUtilsObj = NltkSegmentizer()
    elif MODELS['segmentizer'] == 'Spacy':
        nltkUtilsObj = SpacySegmentizer()
    elif MODELS['segmentizer'] == 'Stanza':
        nltkUtilsObj = StanzaSegmentizer()
def run(data, MODELS):
    """End-to-end pipeline: configure models, segment, embed, assign topics, summarize.

    Args:
        data: raw input text.
        MODELS: selection dict consumed by define_models().

    Returns:
        dict mapping label_class -> {"source": [...], "summary": str}.
    """
    define_models(MODELS)
    sentences = text_to_sentences(data)
    embeddings, cleaned_sentences = preprocess(sentences, sentTransfModelUtilsObj)
    topic_to_sentences = compute_similarity_matrix(embeddings, cleaned_sentences, sentTransfModelUtilsObj)
    return summarize(topic_to_sentences)