Spaces:
Runtime error
Runtime error
# external libraries
import ast
import logging
import pickle

import numpy as np
import pandas as pd
from transformers import BartForConditionalGeneration, BartTokenizer

# internal libraries
from config import config
from src.nltk_utilities import NltkSegmentizer
from src.stanza_utilities import StanzaSegmentizer
from src.spacy_utilities import SpacySegmentizer
from src.preprocessing import remove_patterns
from src.summarization_utilities import SummarizationUtilities, BARTSummarizer, T5Summarizer, ProphetNetSummarizer
# Sentence segmentizer instance (NLTK, spaCy or Stanza); assigned by define_models().
nltkUtilsObj = None

# Sentence-transformer wrapper, unpickled once at import time and moved to CPU.
# The file is opened in a `with` block so the handle is closed after loading.
# NOTE(review): pickle.load can execute arbitrary code stored in the file, so
# config.sent_trans_path must point at a trusted artifact.
with open(config.sent_trans_path, 'rb') as _sent_trans_file:
    sentTransfModelUtilsObj = pickle.load(_sent_trans_file)
sentTransfModelUtilsObj.model = sentTransfModelUtilsObj.model.to('cpu')

# Path of the topic CSV selected by define_models(); '' until a model is chosen.
TopicModelling = ''
# Summarizer instance; assigned by define_models().
summUtilsObj = None
def text_to_sentences(data):
    """Split raw text into a list of sentences.

    Delegates to the segmentizer currently held in the module-global
    ``nltkUtilsObj`` (set by ``define_models``). Whatever iterable the
    segmentizer yields is materialised into a list.
    """
    segmented = nltkUtilsObj.segment_into_sentences(data)
    return list(segmented)
def preprocess(list_sentences, sentTransfModelUtilsObj):
    """Clean sentences and embed the ones that are still non-empty.

    Each sentence is passed through ``remove_patterns``. Sentences that end up
    empty are dropped *before* embedding, so the two returned lists are
    index-aligned: ``embeddings[j]`` is the embedding of ``sentences[j]``.
    ``compute_similarity_matrix`` relies on that alignment when it maps matrix
    columns back to sentences. Previously only the embeddings were filtered,
    which shifted every later sentence onto the wrong topic.

    Args:
        list_sentences: iterable of sentence strings.
        sentTransfModelUtilsObj: object exposing ``get_embeddings(text)``.

    Returns:
        tuple ``(embeddings, sentences)`` of equal length.
    """
    cleaned = (remove_patterns(x) for x in list_sentences)
    kept_sentences = [x for x in cleaned if len(x) > 0]
    embeddings = [sentTransfModelUtilsObj.get_embeddings(x) for x in kept_sentences]
    return embeddings, kept_sentences
def get_emb_cluster_topic(sentTransfModelUtilsObj):
    """Load the selected topic CSV and embed each topic's keyword "sentence".

    Reads the CSV whose path is in the module-global ``TopicModelling`` (set by
    ``define_models``). The ``list_topic_words`` cell is parsed with
    ``ast.literal_eval`` and its words are space-joined into a new
    ``sentence_from_words`` column, which is then embedded row by row.

    Args:
        sentTransfModelUtilsObj: object exposing ``get_embeddings(text)``.

    Returns:
        tuple ``(embeddings, dataframe)``. ``embeddings[i]`` corresponds to
        row ``i`` of the returned dataframe.

    Raises:
        ValueError: if no topic CSV has been selected yet.
    """
    if not TopicModelling:
        raise ValueError("No topic-modelling CSV selected; call define_models() first.")
    df_latVectorRep = pd.read_csv(TopicModelling)
    # ast.literal_eval only accepts Python literals, so cell contents are never executed.
    df_latVectorRep["sentence_from_words"] = df_latVectorRep["list_topic_words"].map(
        lambda x: " ".join(ast.literal_eval(x)))
    # Iterate the column directly: iterrows() builds a Series per row and is much slower.
    list_embeddings_cluster_sentences = [
        sentTransfModelUtilsObj.get_embeddings(sentence)
        for sentence in df_latVectorRep["sentence_from_words"]
    ]
    return list_embeddings_cluster_sentences, df_latVectorRep
def compute_similarity_matrix(list_sentences_per_doc_embeddings, list_sentences, sentTransfModelUtilsObj):
    """Assign every sentence to its best-matching topic and group sentences by label.

    Builds a (n_topics x n_sentences) matrix of scores from
    ``compute_cosine_similarity`` between each topic embedding and each sentence
    embedding. For every sentence (column) it then picks the topic row with the
    highest score.

    Args:
        list_sentences_per_doc_embeddings: sentence embeddings, index-aligned
            with ``list_sentences``.
        list_sentences: the sentences themselves.
        sentTransfModelUtilsObj: object exposing ``compute_cosine_similarity``
            (and ``get_embeddings``, used for the topic embeddings).

    Returns:
        dict mapping each topic's ``label_class`` to the list of sentences
        assigned to it, with labels in first-seen order.
    """
    topic_embeddings, df_topics = get_emb_cluster_topic(sentTransfModelUtilsObj)

    scores = np.zeros((len(topic_embeddings), len(list_sentences_per_doc_embeddings)))
    for row, topic_vec in enumerate(topic_embeddings):
        for col, sentence_vec in enumerate(list_sentences_per_doc_embeddings):
            scores[row, col] = sentTransfModelUtilsObj.compute_cosine_similarity(topic_vec, sentence_vec)

    # Best topic row for each sentence column.
    best_topic_per_sentence = scores.argmax(axis=0)

    grouped = {}
    for sentence_idx, topic_row in enumerate(best_topic_per_sentence):
        label = df_topics.iloc[topic_row]["label_class"]
        grouped.setdefault(label, []).append(list_sentences[sentence_idx])
    return grouped
def summarize(dict_topic_sentences):
    """Summarise the sentences grouped under each topic label.

    Topics with at least ``config.MIN_NUM_SENTENCES_FOR_SUMMARY_CREATION``
    sentences are summarised by the module-global ``summUtilsObj`` (set by
    ``define_models``), applied to the space-joined sentences. Smaller topics
    get a fixed placeholder message instead.

    Args:
        dict_topic_sentences: mapping of topic label -> list of sentences.

    Returns:
        dict mapping each label to ``{"source": sentences, "summary": ...}``.
        The summary is whatever ``summUtilsObj.summarize`` returns, or the
        placeholder string.
    """
    summaries_report = dict()
    for class_label, sentences in dict_topic_sentences.items():
        if len(sentences) >= config.MIN_NUM_SENTENCES_FOR_SUMMARY_CREATION:
            # Debug-level log replaces a stray print() that dumped every topic to stdout.
            logging.getLogger(__name__).debug("Summarising topic %r: %s", class_label, sentences)
            summary = summUtilsObj.summarize(" ".join(sentences))
        else:
            summary = "X -> not possible to generate a summary due to threshold"
        summaries_report[class_label] = {"source": sentences, "summary": summary}
    return summaries_report
def define_models(MODELS):
    """Select the pipeline components named in ``MODELS`` and store them in module globals.

    Sets the following globals, which the rest of the pipeline reads:
        * ``summUtilsObj``: a freshly constructed summarizer instance.
        * ``TopicModelling``: the path of the topic CSV, read from ``config``.
        * ``nltkUtilsObj``: a freshly constructed sentence segmentizer.

    All three names are validated before anything is constructed, so a typo
    does not first pay for loading a summarizer model.

    Args:
        MODELS: mapping with keys ``'summarizer'``, ``'topic_modelling'`` and
            ``'segmentizer'``.

    Raises:
        ValueError: if any selected name is not recognised. Previously an
            unknown name silently left the stale value from an earlier call,
            or None/'', in place.
    """
    global TopicModelling
    global summUtilsObj
    global nltkUtilsObj

    summarizers = {
        'Pegasus': SummarizationUtilities,
        'Bart': BARTSummarizer,
        'T5-base': T5Summarizer,
        'Prophetnet': ProphetNetSummarizer,
    }
    # Values are config attribute *names*, so only the selected path is looked up.
    # LDA and CTM previously pointed at each other's CSV; each now uses its own.
    topic_csv_attrs = {
        'BERTopic': 'PATH_20_NEWS_CLUSTERID_LABEL_WORDS',
        'LDA': 'PATH_20_NEWS_CLUSTERID_LABEL_WORDS_LDA',
        'CTM': 'PATH_20_NEWS_CLUSTERID_LABEL_WORDS_CTM',
        'NMF': 'PATH_20_NEWS_CLUSTERID_LABEL_WORDS_NMF',
        'Top2Vec': 'PATH_20_NEWS_CLUSTERID_LABEL_WORDS_TOP2VEC',
    }
    segmentizers = {
        'Nltk': NltkSegmentizer,
        'Spacy': SpacySegmentizer,
        'Stanza': StanzaSegmentizer,
    }

    selected = {}
    for key, table in (('summarizer', summarizers),
                       ('topic_modelling', topic_csv_attrs),
                       ('segmentizer', segmentizers)):
        name = MODELS[key]
        if name not in table:
            raise ValueError(f"Unknown {key} {name!r}; expected one of {sorted(table)}")
        selected[key] = table[name]

    summUtilsObj = selected['summarizer']()
    TopicModelling = getattr(config, selected['topic_modelling'])
    nltkUtilsObj = selected['segmentizer']()
def run(data, MODELS):
    """Run the whole pipeline on ``data`` and return the per-topic summary report.

    Steps: select models, split into sentences, clean and embed them, assign
    each sentence to a topic, then summarise each topic.

    Args:
        data: raw input text.
        MODELS: component selection forwarded to ``define_models``.

    Returns:
        the report dict produced by ``summarize``.
    """
    define_models(MODELS)
    raw_sentences = text_to_sentences(data)
    embeddings, cleaned_sentences = preprocess(raw_sentences, sentTransfModelUtilsObj)
    topic_groups = compute_similarity_matrix(embeddings, cleaned_sentences, sentTransfModelUtilsObj)
    return summarize(topic_groups)