Spaces:
Runtime error
Runtime error
class HebEMO:
    '''
    Emotion tagger that runs one "sentiment-analysis" pipeline per emotion.

    For every emotion name a Hugging Face pipeline is created from the model
    "avichr/hebEMO_<emotion>" with the "avichr/heBERT" tokenizer.  `hebemo`
    runs all pipelines over the input texts and collects the labels and
    scores in a single pandas DataFrame.
    '''

    # Emotions whose models are loaded when the caller does not pass a list.
    DEFAULT_EMOTIONS = ('anticipation', 'joy', 'trust', 'fear', 'surprise', 'anger',
                        'sadness', 'disgust')

    # Raw pipeline labels and the integer flags they are reported as.
    _LABEL_TO_FLAG = {'LABEL_0': 0, 'LABEL_1': 1}

    def __init__(self, device=-1, emotions=DEFAULT_EMOTIONS):
        '''
        Load one pipeline per emotion.

        device: -1 runs on CPU, any other value is passed on as the device ID.
        emotions: iterable of emotion names; each name is appended to
            "avichr/hebEMO_" to form the model identifier.
        '''
        from transformers import pipeline

        self.device = device
        # Copy so that later changes to the caller's sequence (or to the
        # shared default) cannot alter this instance.
        self.emotions = list(emotions)
        self.hebemo_models = {}
        for emo in self._progress(self.emotions):
            self.hebemo_models[emo] = pipeline(
                "sentiment-analysis",
                model="avichr/hebEMO_" + emo,
                tokenizer="avichr/heBERT",
                device=self.device,  # -1 run on CPU, else - device ID
            )

    @staticmethod
    def _progress(iterable):
        '''Wrap *iterable* in a tqdm progress bar when tqdm is installed.'''
        try:
            from tqdm import tqdm
        except ImportError:
            # The progress bar is cosmetic; iterate plainly without it.
            return iterable
        return tqdm(iterable)

    @staticmethod
    def _load_texts(text, input_path, read_lines):
        '''Return the list of texts described by the `hebemo` arguments.'''
        if text is None and isinstance(input_path, str):
            # One instance per row; drop the line terminator so file input
            # matches what text.split('\n') yields for string input.
            with open(input_path, encoding='utf8') as p:
                return [line.rstrip('\n') for line in p]
        if text is not None and (input_path is None or input_path is False):
            if isinstance(text, str):
                return text.split('\n') if read_lines else [text]
            if isinstance(text, list):
                return text
            raise ValueError('text should be text or list of text.')
        raise ValueError('you should provide a text string, list of strings or text path.')

    def hebemo(self, text=None, input_path=False, save_results=False, read_lines=False, plot=False):
        '''
        Run every loaded emotion pipeline over the given texts.

        text (str or list of str): a text or list of texts to analyze.
        input_path (str): path to a UTF-8 text file, one instance per row.
            Exactly one of `text` / `input_path` must be supplied.
        save_results: any value other than False writes the results to
            "<timestamp>_heEMOed.csv"; a str is used as the target directory,
            otherwise the file goes to the current working directory.
        read_lines (bool): when `text` is a str, split it on newlines and
            treat every line as a separate instance.
        plot (bool): when true, plot the first text only (plutchik wheel,
            falling back to spider_plot if that fails) and return
            (first_text, plot_result) instead of the DataFrame.

        Returns a pandas DataFrame whose column 0 holds the texts, followed by
        an "<emotion>" column (LABEL_0 -> 0, LABEL_1 -> 1) and a
        "confidence_<emotion>" column for every emotion.

        Raises ValueError when the text / input_path combination is invalid.
        '''
        import os
        import time
        import pandas as pd
        import torch

        txt = self._load_texts(text, input_path, read_lines)

        # run hebEMO
        hebEMO_df = pd.DataFrame(txt)
        flags = self._LABEL_TO_FLAG
        for emo in self._progress(self.emotions):
            raw = self.hebemo_models[emo](txt)
            scored = pd.DataFrame(raw).rename(
                columns={'label': emo, 'score': 'confidence_' + emo})
            # Convert labels on the label column only, so an input text that
            # happens to read "LABEL_0"/"LABEL_1" is left untouched.
            if emo in scored.columns:
                scored[emo] = scored[emo].map(lambda label: flags.get(label, label))
            hebEMO_df = hebEMO_df.join(scored)
            # Drop the per-model results and release cached GPU memory
            # before the next model runs.
            del raw, scored
            torch.cuda.empty_cache()

        if save_results is not False:
            file_name = str(int(time.time() * 1e7)) + '_heEMOed.csv'
            if isinstance(save_results, str):
                file_name = os.path.join(save_results, file_name)
            hebEMO_df.to_csv(file_name, encoding='utf8')

        if not plot:
            return hebEMO_df

        # Plotting dependencies are only needed (and only imported) here.
        from pyplutchik import plutchik
        from spider_plot import spider_plot
        import matplotlib.pyplot as plt

        # Fold label and confidence into one value per emotion: the score
        # itself when the label is 1, and (1 - score) when the label is 0.
        hebEMO = pd.DataFrame()
        for emo in self.emotions:
            hebEMO[emo] = abs(hebEMO_df[emo] - (1 - hebEMO_df['confidence_' + emo]))

        first = 0  # only the first text is plotted
        try:
            ax = plutchik(hebEMO.to_dict(orient='records')[first])
        except Exception:
            # Best-effort fallback when the plutchik wheel cannot be drawn.
            ax = spider_plot(hebEMO)
        print(hebEMO_df[0][first])
        plt.show()
        return (hebEMO_df[0][first], ax)
# HebEMO_model = HebEMO()