# Spaces: Runtime error  (page status text captured along with the source; not part of the program)
import sys
import traceback
from collections import Counter
from functools import lru_cache

import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Fetch the NLTK data packages that summary_nlp relies on. This runs at import
# time, so importing this module may hit the network.
nltk.download('stopwords')
nltk.download('punkt')
# Recent NLTK releases load the Punkt tokenizer from 'punkt_tab' instead of
# 'punkt'; request both so word_tokenize/sent_tokenize work on either version.
# On older releases this extra request is reported as unavailable and skipped.
nltk.download('punkt_tab')
def summary_nlp(text):
    """Return an extractive summary of ``text`` based on word frequencies.

    Every token from ``word_tokenize`` is lower-cased and, unless it is an
    English stop word, counted. Each sentence from ``sent_tokenize`` is then
    scored as the sum of the counts of all counted tokens that occur in the
    lower-cased sentence. Sentences whose score exceeds 1.2 times the
    (truncated) mean score are kept, in their original order.

    Args:
        text: The document to summarise.

    Returns:
        The selected sentences, each prefixed with a single space and
        concatenated. An empty string is returned when ``text`` is empty,
        whitespace-only or ``None``, or when no sentence receives a score.
    """
    # Nothing to tokenise: return early instead of dividing by zero below.
    if not text or not text.strip():
        return ''

    stop_words = set(stopwords.words("english"))

    # Frequency of every non-stop-word token, case-folded.
    freq_table = Counter(
        token
        for token in (raw.lower() for raw in word_tokenize(text))
        if token not in stop_words
    )

    sentences = sent_tokenize(text)
    sentence_value = {}
    for sentence in sentences:
        lowered = sentence.lower()  # hoisted: invariant across the vocabulary
        # Substring containment (not token equality) is the established
        # scoring rule; it is kept so existing summaries do not change.
        score = sum(freq for word, freq in freq_table.items() if word in lowered)
        if score:
            # A sentence that appears more than once accumulates its score.
            sentence_value[sentence] = sentence_value.get(sentence, 0) + score

    # E.g. the text consisted solely of stop words: no sentence was scored.
    if not sentence_value:
        return ''

    average = int(sum(sentence_value.values()) / len(sentence_value))
    threshold = 1.2 * average

    return ''.join(
        " " + sentence
        for sentence in sentences
        if sentence in sentence_value and sentence_value[sentence] > threshold
    )
@lru_cache(maxsize=1)
def _load_bart(checkpoint):
    """Load the tokenizer/model pair for ``checkpoint`` once and memoize it."""
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return tokenizer, model


def Summary_BART(text):
    """Return an abstractive summary of ``text`` using a DistilBART checkpoint.

    The input is tokenised with truncation to at most 1024 tokens, passed to
    the model's ``generate`` method with its default generation settings, and
    the first decoded sequence is returned.

    Args:
        text: The document to summarise.

    Returns:
        The decoded summary string, with special tokens removed.
    """
    checkpoint = "sshleifer/distilbart-cnn-12-6"
    # Instantiating the tokenizer and model on every call is expensive;
    # reuse the memoized pair after the first call.
    tokenizer, model = _load_bart(checkpoint)

    inputs = tokenizer(
        text,
        max_length=1024,
        truncation=True,
        return_tensors="pt",
    )
    # Pass the tokenizer's attention mask explicitly rather than leaving
    # generate() to infer one from the input ids.
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    summary = tokenizer.batch_decode(
        summary_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return summary[0]