import torch
import csv, os, sys
import argparse
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
class KeyWordExtractor:
    def __init__(self):
        KWE_PRETRAINED = 'medmediani/Arabic-KW-Mdel'
        self.SEQ_LENGTH = 512
        self.MAX_KW_NGS = 3  # default upper bound on keyphrase n-gram length
        self.NKW = 3         # default number of keywords to return
        #self.device = torch.device('cpu')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        sentence_model = SentenceTransformer(KWE_PRETRAINED)
        sentence_model.to(self.device)
        self.kw_model = KeyBERT(model=sentence_model)
        #self.kw_model.to(self.device)  # KeyBERT is not a torch module; the embedder above is already on device
    def _extract_by_paragraph(self, ctxt, nkws=None, max_kw_ngs=None):
        # Extract candidates paragraph by paragraph, then merge the results.
        paragraphs = map(str.strip, ctxt.split("\n"))
        kws = []
        for paragraph in paragraphs:
            if paragraph:
                kws.extend(self.kw_model.extract_keywords(
                    paragraph,
                    keyphrase_ngram_range=(1, max_kw_ngs),
                    top_n=nkws,
                    #use_maxsum=True, nr_candidates=20, top_n=5,
                    #use_mmr=True,
                    diversity=0.8,  # note: diversity only takes effect when use_mmr=True
                    stop_words=None))
        print("KWS=", kws, file=sys.stderr)
        # Rank all (keyword, score) pairs by score and keep the top nkws unique
        # keywords. Note that the set does not preserve this ranking downstream.
        kws.sort(key=lambda x: x[1], reverse=True)
        ukws = set()
        for kw, _ in kws:
            if len(ukws) >= nkws:
                return ukws
            ukws.add(kw)
        return ukws
    def extract(self, ctxt, nkws=None, max_kw_ngs=None):
        nkws = nkws if nkws is not None else self.NKW
        max_kw_ngs = max_kw_ngs if max_kw_ngs is not None else self.MAX_KW_NGS
        # Since the model only sees SEQ_LENGTH (512) tokens, extract paragraph by paragraph
        kw = self._extract_by_paragraph(ctxt, nkws, max_kw_ngs)
        return ", ".join(kw)