Arabic-KW / kwextractor.py
medmediani
first app
d5e322f
raw
history blame
1.22 kB
import torch
import csv, os, sys
import argparse
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
class KeyWordExtractor():
    """Keyword extractor built on KeyBERT with a SentenceTransformer backbone.

    The embedding model is loaded once at construction time and reused for
    every call to :meth:`extract`.
    """

    def __init__(self):
        """Load the pretrained sentence model and wrap it in KeyBERT."""
        KWE_PRETRAINED = 'medmediani/Arabic-KW-Mdel/model'

        # Not read anywhere in this class; kept as a public attribute in case
        # external callers rely on it.
        self.SEQ_LENGTH = 512
        # Default upper bound of the keyphrase n-gram range used by extract().
        self.MAX_KW_NGS = 3
        # Default number of keywords returned by extract().
        self.NKW = 3

        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # The device is applied to the SentenceTransformer itself: KeyBERT is
        # a plain wrapper object (not a torch module) and exposes no `.to()`,
        # so calling `.to()` on it would fail at construction time.
        sentence_model = SentenceTransformer(KWE_PRETRAINED, device=str(self.device))
        self.kw_model = KeyBERT(model=sentence_model)

    def extract(self, ctxt, nkws=None, max_kw_ngs=None):
        """Return the top keywords of ``ctxt`` as one comma-separated string.

        Args:
            ctxt: Text to extract keywords from. Assumed to be a single
                document string, since the result is unpacked as a flat
                sequence of ``(keyword, score)`` pairs -- TODO confirm
                against callers.
            nkws: Number of keywords to return; defaults to ``self.NKW``
                when ``None``.
            max_kw_ngs: Largest n-gram length allowed for a keyphrase;
                defaults to ``self.MAX_KW_NGS`` when ``None``.

        Returns:
            The extracted keywords joined with ``", "``, in the order the
            model returned them; scores are discarded.
        """
        if nkws is None:
            nkws = self.NKW
        if max_kw_ngs is None:
            max_kw_ngs = self.MAX_KW_NGS

        # Diversification options (use_maxsum / use_mmr) are not passed, so
        # KeyBERT's defaults apply; no stop-word filtering is requested.
        scored_keywords = self.kw_model.extract_keywords(
            ctxt,
            keyphrase_ngram_range=(1, max_kw_ngs),
            top_n=nkws,
            stop_words=None,
        )
        return ", ".join(keyword for keyword, _ in scored_keywords)