from statistics import mean
import sys
import os
import json
from datetime import datetime
import warnings
from pprint import pprint

from langchain.text_splitter import RecursiveCharacterTextSplitter

warnings.filterwarnings("ignore")

# Make the app modules and the shared financial_dataset helpers importable.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
# sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', 'financial_dataset')))
dataset_dir = os.path.abspath(os.path.join(os.getcwd(), '..', '..', 'financial_dataset'))
sys.path.append(dataset_dir)

from load_test_data import get_labels_df, get_texts
from app import (
    summarize,
    read_and_split_file,
    get_label_prediction
)
from config import (
    labels, headers_inference_api, headers_inference_endpoint,
    # summarization_prompt_template,
    prompt_template,
    # task_explain_for_predictor_model,
    summarizers, predictors, summary_scores_template,
    summarization_system_msg, summarization_user_prompt, prediction_user_prompt, prediction_system_msg,
    # prediction_prompt,
    chat_prompt, instruction_prompt
)
def split_text(text, chunk_size=1200, chunk_overlap=200):
    """Split raw text into overlapping chunks; sizes are measured in characters."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap,
        length_function=len, separators=[" ", ",", "\n"]
    )
    text_chunks = text_splitter.create_documents([text])
    return text_chunks
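# Quick sanity check of the splitter (hedged: the sample text below is made
# up). create_documents returns langchain Document objects, so the chunk
# text lives in .page_content:
# demo_chunks = split_text("lorem ipsum dolor sit amet " * 100, chunk_size=300, chunk_overlap=50)
# print(len(demo_chunks), max(len(c.page_content) for c in demo_chunks))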
predictions = {
    # method: {name: [predicted labels, aligned with the rows of df]}
    'summarization+classification': {
        'bart-pegasus+gpt': [],  # list of predicted labels
        'gpt+gpt': [],
    },
    'chunk_classification': {},
    'embedding_classification': {},
    'zero-shot_classification': {},
    'full_text_classification': {},
    'QA_classification': {}
}
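# Hedged scoring sketch: each method's list is meant to be filled in parallel
# with the rows of df, so exact-match accuracy against the gold labels is
# enough (the gold-label column is the `label` value unpacked in the loops
# below; the helper name is ours, not from app or config).
def accuracy(pred_labels, gold_labels):
    """Fraction of exact label matches between two aligned lists."""
    pairs = list(zip(pred_labels, gold_labels))
    return sum(p == g for p, g in pairs) / len(pairs) if pairs else 0.0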
# if __name__ == '__main__':
labels_dir = dataset_dir + '/csvs/'
df = get_labels_df(labels_dir)
texts_dir = dataset_dir + '/txts/'
texts = get_texts(texts_dir)
# print(len(df), len(texts))
# print(mean(list(map(len, texts))))
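# The zip() loops below pair texts with df rows positionally; a cheap guard
# against misalignment (assumption: get_texts preserves the label-row order):
assert len(texts) == len(df), 'texts and label rows must be aligned'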
# summarization+classification
# for selected_summarizer in summarizers:
#     print(selected_summarizer)
#     # for selected_predictor in predictors:
#     #     predictions['summarization+classification'][selected_summarizer + '+' + selected_predictor] = []
#     for text, (idx, (year, label, company)) in zip(texts, df.iterrows()):
#         print(year, label, company)
#         # summary_filename = f'./texts/{year}_{company}_{selected_summarizer}_summary.txt'
#         summary_filename = f'./texts/{company}_{year}_{selected_summarizer}_summary.txt'
#         if os.path.isfile(summary_filename):
#             print('Loading summary from the cache')
#             with open(summary_filename, 'r') as f:
#                 summary = f.read()
#         else:
#             print(f'Making request to {selected_summarizer} to summarize {company}, {year}')
#             text_chunks = split_text(text,
#                                      chunk_size=summarizers[selected_summarizer]['chunk_size'],
#                                      chunk_overlap=100)
#             summary, summary_score = summarize(selected_summarizer, text_chunks)
#             with open(summary_filename, 'w') as f:
#                 f.write(summary)
#         print('-' * 50)
#         # break
#         # summary_chunks = split_text(summary, chunk_size=3_600)
#         # predicted_label = get_label_prediction(selected_predictor, summary_chunks)
#         # if predicted_label in labels:
#         #     predictions['summarization+classification'][selected_summarizer + '+' + selected_predictor].append(predicted_label)
#         print()
#         break
# # chunk_classification
# for selected_predictor in predictors:
#     predictions['chunk_classification'][selected_predictor] = []
#     for text, (idx, (year, label, company)) in zip(texts, df.iterrows()):
#         print(year, label, company)
#         text_chunks = split_text(text, chunk_size=3600)
#         predicted_label = get_label_prediction(selected_predictor, text_chunks)
#         if predicted_label in labels:
#             predictions['chunk_classification'][selected_predictor].append(predicted_label)
#         print('-' * 50)
# with open(f'predictions/predictions_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.json', 'w') as json_file:
#     json.dump(predictions, json_file, indent=4)
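# Hedged follow-up sketch: reload the newest dump and report per-method
# accuracy with the helper above. This assumes the JSON layout of the
# `predictions` dict and that the gold labels live in a df column named
# 'label' (both assumptions, not confirmed by load_test_data):
# import glob
# latest = sorted(glob.glob('predictions/predictions_*.json'))[-1]
# with open(latest) as f:
#     saved = json.load(f)
# for method, runs in saved.items():
#     for name, pred_labels in runs.items():
#         if pred_labels:
#             print(method, name, round(accuracy(pred_labels, list(df['label'])), 3))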