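# Topic-aware summarization pipeline: segment a document into sentences, embed
# each sentence with a sentence-transformer model, assign every sentence to the
# most similar precomputed topic (cosine similarity), and generate one
# abstractive summary per topic group.
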
# external libraries
import pickle
import numpy as np
import pandas as pd
import ast

# internal libraries
from config import config
from src.nltk_utilities import NltkSegmentizer
from src.stanza_utilities import StanzaSegmentizer
from src.spacy_utilities import SpacySegmentizer
from src.preprocessing import remove_patterns
from src.summarization_utilities import SummarizationUtilities, BARTSummarizer, T5Summarizer, ProphetNetSummarizer


nltkUtilsObj = None   # sentence segmentizer, set at runtime by define_models()

# The sentence-transformer utilities are unpickled once at import time and kept on the CPU.
with open(config.sent_trans_path, 'rb') as f:
  sentTransfModelUtilsObj = pickle.load(f)
sentTransfModelUtilsObj.model = sentTransfModelUtilsObj.model.to('cpu')

TopicModelling = ''   # path to the CSV of precomputed topic words, set by define_models()

summUtilsObj = None   # abstractive summarizer, set by define_models()

def text_to_sentences(data):

  # Segment the raw document into sentences with the configured segmentizer.
  list_sentences = list(nltkUtilsObj.segment_into_sentences(data))

  return list_sentences

def preprocess(list_sentences, sentTransfModelUtilsObj):

  # Clean every sentence and drop the empty ones, so that the sentence list and
  # the embedding list stay index-aligned downstream.
  list_sentences = [remove_patterns(x) for x in list_sentences]
  list_sentences = [x for x in list_sentences if len(x) > 0]
  list_sentences_per_doc_embeddings = [sentTransfModelUtilsObj.get_embeddings(x) for x in list_sentences]
  return list_sentences_per_doc_embeddings, list_sentences

def get_emb_cluster_topic(sentTransfModelUtilsObj):

  # Load the topic table and turn each topic's top-word list into a pseudo-sentence.
  df_latVectorRep = pd.read_csv(TopicModelling)
  df_latVectorRep["sentence_from_words"] = df_latVectorRep["list_topic_words"].map(lambda x: " ".join(ast.literal_eval(x)))
  list_embeddings_cluster_sentences = list()

  # Embed each topic pseudo-sentence with the same model used for the document sentences.
  for index, row in df_latVectorRep.iterrows():
      list_embeddings_cluster_sentences.append(sentTransfModelUtilsObj.get_embeddings(row["sentence_from_words"]))

  return list_embeddings_cluster_sentences, df_latVectorRep

def compute_similarity_matrix(list_sentences_per_doc_embeddings, list_sentences, sentTransfModelUtilsObj):

  list_embeddings_cluster_sentences, df_latVectorRep = get_emb_cluster_topic(sentTransfModelUtilsObj)

  # Similarity matrix: rows are topics, columns are document sentences.
  similarity_matrix = np.zeros((len(list_embeddings_cluster_sentences), len(list_sentences_per_doc_embeddings)))

  for i, cluster_embedding in enumerate(list_embeddings_cluster_sentences):
    for j, sentence_embedding in enumerate(list_sentences_per_doc_embeddings):
        similarity_matrix[i][j] = sentTransfModelUtilsObj.compute_cosine_similarity(cluster_embedding, sentence_embedding)

  # Assign each sentence to its most similar topic.
  list_index_topics_within_matrix = np.argmax(similarity_matrix, axis=0)

  # Group the sentences under the label of their best-matching topic.
  dict_topic_sentences = dict()

  for index_sentence, index_id_topic in enumerate(list_index_topics_within_matrix):
      label_class = df_latVectorRep.iloc[index_id_topic]["label_class"]

      if label_class not in dict_topic_sentences:
          dict_topic_sentences[label_class] = list()
      dict_topic_sentences[label_class].append(list_sentences[index_sentence])

  return dict_topic_sentences

def summarize(dict_topic_sentences):

  # Build one abstractive summary per topic, but only when the topic has enough sentences.
  summaries_report = dict()
  for class_label in dict_topic_sentences:

      summaries_report[class_label] = {}

      if len(dict_topic_sentences[class_label]) >= config.MIN_NUM_SENTENCES_FOR_SUMMARY_CREATION:
          summaries_report[class_label]["source"] = dict_topic_sentences[class_label]
          summaries_report[class_label]["summary"] = summUtilsObj.summarize(" ".join(dict_topic_sentences[class_label]))
      else:
          summaries_report[class_label]["source"] = dict_topic_sentences[class_label]
          summaries_report[class_label]["summary"] = "X -> not possible to generate a summary: fewer sentences than the configured threshold"

  return summaries_report

def define_models(MODELS):

  # Bind the module-level globals used by the rest of the pipeline to the requested models.
  global TopicModelling
  global summUtilsObj
  global nltkUtilsObj

  # Abstractive summarizer.
  if MODELS['summarizer'] == 'Pegasus':
    summUtilsObj = SummarizationUtilities()
  elif MODELS['summarizer'] == 'Bart':
    summUtilsObj = BARTSummarizer()
  elif MODELS['summarizer'] == 'T5-base':
    summUtilsObj = T5Summarizer()
  elif MODELS['summarizer'] == 'Prophetnet':
    summUtilsObj = ProphetNetSummarizer()

  # Path to the precomputed topic table of the chosen topic-modelling method.
  if MODELS['topic_modelling'] == 'BERTopic':
    TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS
  elif MODELS['topic_modelling'] == 'LDA':
    TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_LDA
  elif MODELS['topic_modelling'] == 'CTM':
    TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_CTM
  elif MODELS['topic_modelling'] == 'NMF':
    TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_NMF
  elif MODELS['topic_modelling'] == 'Top2Vec':
    TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_TOP2VEC

  # Sentence segmentizer.
  if MODELS['segmentizer'] == 'Nltk':
    nltkUtilsObj = NltkSegmentizer()
  elif MODELS['segmentizer'] == 'Spacy':
    nltkUtilsObj = SpacySegmentizer()
  elif MODELS['segmentizer'] == 'Stanza':
    nltkUtilsObj = StanzaSegmentizer()

def run(data, MODELS):

  define_models(MODELS)
  
  data_sentences = text_to_sentences(data)
  data_embed, list_sentences = preprocess(data_sentences, sentTransfModelUtilsObj)

  dict_topic_sentences = compute_similarity_matrix(data_embed, list_sentences, sentTransfModelUtilsObj)
  summaries_report = summarize(dict_topic_sentences)

  return summaries_report
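
# Example usage (illustrative sketch, not part of the pipeline): the document
# text and model choices below are made up; the summarizer / topic_modelling /
# segmentizer values must be ones handled in define_models(), and the config
# paths plus the pickled sentence-transformer utilities must exist locally.
if __name__ == "__main__":

  example_models = {
      'summarizer': 'Bart',           # or 'Pegasus', 'T5-base', 'Prophetnet'
      'topic_modelling': 'BERTopic',  # or 'LDA', 'CTM', 'NMF', 'Top2Vec'
      'segmentizer': 'Nltk',          # or 'Spacy', 'Stanza'
  }

  example_document = (
      "NASA launched a new space telescope this week. "
      "Astronomers expect it to observe distant galaxies. "
      "Meanwhile, the hockey season opened with a surprising upset on home ice."
  )

  report = run(example_document, example_models)
  for topic_label, entry in report.items():
      print(topic_label, "->", entry["summary"])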