# app.py — last commit 44f0002 ("Update app.py") by numBery; file size 3.25 kB.
import time

import nltk
import torch
from keybert import KeyBERT
from keyphrase_vectorizers import KeyphraseCountVectorizer
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from transformers import T5ForConditionalGeneration,T5Tokenizer
# --- One-time setup: NLTK data, Hugging Face auth, and models ---------------
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the downloads and model loads below repeat on each rerun;
# consider moving them into a cached loader (e.g. st.cache_resource).
# 'punkt' backs sent_tokenize; 'stopwords' is not referenced by the code in
# this file (presumably needed by a dependency — confirm before removing).
nltk.download('stopwords')
nltk.download('punkt')
# NOTE(review): snapshot_download is imported but not used in this file.
from huggingface_hub import snapshot_download, HfFolder
import streamlit as st
# Run model inference on the GPU when CUDA is available, otherwise on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Store the Hugging Face token from Streamlit secrets so the
# use_auth_token=True loads below can authenticate to the valurank/* repos.
# Must run before those loads.
HfFolder.save_token(st.secrets["hf-auth-token"])
# Load KeyBert Model: a SentenceTransformer embedding model wrapped by KeyBERT.
tmp_model = SentenceTransformer('valurank/MiniLM-L6-Keyword-Extraction', use_auth_token=True)
kw_extractor = KeyBERT(tmp_model)
# Load T5 for Paraphrasing. Weights come from the fine-tuned repo, but the
# tokenizer is the stock 't5-base' one — presumably the model was fine-tuned
# from t5-base with the same vocabulary (TODO confirm).
t5_model = T5ForConditionalGeneration.from_pretrained('valurank/t5-paraphraser', use_auth_token=True)
t5_tokenizer = T5Tokenizer.from_pretrained('t5-base')
# Move the weights to the selected device; t5_paraphraser moves its inputs there too.
t5_model = t5_model.to(device)
def get_keybert_results_with_vectorizer(text, number_of_results=20):
    """Return KeyBERT's top-ranked keyphrases for ``text``.

    A fresh ``KeyphraseCountVectorizer`` generates the candidate phrases on
    every call, and the module-level ``kw_extractor`` ranks them.

    Args:
        text: Document to extract keyphrases from.
        number_of_results: Maximum number of keyphrases to return (``top_n``).

    Returns:
        The result of ``kw_extractor.extract_keywords`` for a single document.
        Callers in this file read element ``[0]`` of each entry as the
        keyphrase string.
    """
    candidate_vectorizer = KeyphraseCountVectorizer()
    # stop_words=None: no additional stop-word list is handed to KeyBERT.
    ranked_keyphrases = kw_extractor.extract_keywords(
        text,
        vectorizer=candidate_vectorizer,
        stop_words=None,
        top_n=number_of_results,
    )
    return ranked_keyphrases
def t5_paraphraser(text, number_of_results=10, max_length=2048, top_k=50, top_p=0.95):
    """Generate sampled paraphrases of ``text`` with the T5 paraphrasing model.

    Uses the module-level ``t5_model``, ``t5_tokenizer`` and ``device``.

    Args:
        text: Sentence (or short passage) to paraphrase.
        number_of_results: Number of paraphrases to sample
            (passed as ``num_return_sequences``).
        max_length: Maximum length, in tokens, of each generated sequence.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus (top-p) sampling cutoff.

    Returns:
        A list of ``number_of_results`` decoded paraphrase strings. Generation
        samples (``do_sample=True``), so results vary between calls unless the
        torch RNG is seeded.
    """
    # Prompt format the paraphraser is queried with. NOTE(review): recent
    # T5Tokenizer versions also append EOS themselves, so the input ends with
    # two </s> tokens. The explicit " </s>" is kept because the model was
    # presumably fine-tuned on inputs in this exact form (TODO confirm).
    prompt = "paraphrase: " + text + " </s>"
    # A single sequence needs no padding. Padding would only add positions
    # that the attention mask hides anyway.
    encoding = t5_tokenizer.encode_plus(prompt, return_tensors="pt")
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)
    # Pure top-k / top-p sampling, not beam search. Beam-only options such as
    # early_stopping would have no effect here, so none are passed.
    generated = t5_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=max_length,
        top_k=top_k,
        top_p=top_p,
        num_return_sequences=number_of_results,
    )
    return t5_tokenizer.batch_decode(
        generated, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
#### Extract Sentences with Keywords -> Paraphrase multiple versions -> Extract Keywords again
def extract_paraphrased_sentences(article):
    """Mine extra keyphrases by paraphrasing the keyword-bearing sentences of ``article``.

    Pipeline:
      1. Extract keyphrases from the whole article with KeyBERT.
      2. Split the article into sentences and keep those that contain at
         least one of those keyphrases (case-insensitive substring match).
      3. Paraphrase each kept sentence with T5, sampling several variants.
      4. Run keyphrase extraction again on every paraphrase.

    The paraphrasing runtime, all extracted keyphrases, and the ones not
    present in the original keyword list are printed to stdout for debugging.

    Args:
        article: Input document text.

    Returns:
        The keyphrase strings extracted from all paraphrases, in extraction
        order. The list may contain duplicates and phrases that were already
        among the original keywords.
    """
    original_keywords = [result[0] for result in get_keybert_results_with_vectorizer(article)]

    # Keep only sentences that mention at least one original keyphrase. Match
    # whole keyphrase strings (not just their first character), and compare
    # case-insensitively: KeyphraseCountVectorizer lowercases its output by
    # default, while the sentences keep their original casing.
    lowered_keywords = [kw.lower() for kw in original_keywords]
    target_sentences = []
    for sent in sent_tokenize(article):
        lowered_sent = sent.lower()
        if any(kw in lowered_sent for kw in lowered_keywords):
            target_sentences.append(sent)

    start = time.perf_counter()
    t5_paraphrasing_keywords = []
    for sent in target_sentences:
        ### T5: sample several paraphrases, then re-extract keyphrases from each.
        for paraphrase in t5_paraphraser(sent):
            t5_paraphrasing_keywords.extend(
                result[0] for result in get_keybert_results_with_vectorizer(paraphrase)
            )

    print(f'T5 Approach2 PARAPHRASING RUNTIME: {time.perf_counter() - start}\n')
    print('T5 Keywords Extracted: \n{}\n\n'.format(t5_paraphrasing_keywords))
    print('----------------------------')
    # Sorted so the debug output is deterministic across runs.
    new_keywords = sorted(set(t5_paraphrasing_keywords) - set(original_keywords))
    print('T5 Unique New Keywords Extracted: \n{}\n\n'.format(new_keywords))
    return t5_paraphrasing_keywords
# --- Streamlit front end ----------------------------------------------------
# Show a text box. Once it holds non-empty text, run the keyword-expansion
# pipeline on it and render the resulting list.
doc = st.text_area(label="Enter a custom document")
if doc:
    keywords = extract_paraphrased_sentences(doc)
    st.write(keywords)