import requests
import nltk
import random
import json
import os
import pickle
import re

nltk.download('punkt')

hf_tokens = []
filepath = __file__.replace("\\", "/").replace("utils.py", "")
with open(filepath + "data/hf_tokens.pkl", "rb") as f:
    hf_tokens = pickle.load(f)

MAX_TOKEN_LENGTH = 4096
MAX_CHUNK_SIZE = 16000
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
def prompt_template(prompt, sys_prompt=""):
    # Wrap the user and system prompts in the Llama 3 instruct chat template.
    return_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<system_prompt><|eot_id|><|start_header_id|>user<|end_header_id|>\n\n<user_prompt><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'.replace('<user_prompt>', prompt).replace('<system_prompt>', sys_prompt)
    return return_prompt
def query(payload: dict, hf_token: str):
    headers = {"Authorization": f"Bearer {hf_token}"}
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()
def gen_prompt(prompt: str, sys_prompt: str = ""):
    input_prompt = prompt_template(prompt, sys_prompt)
    # Pick the first API token that responds without an error.
    selected_token = ''
    for token in hf_tokens:
        test_output = query({
            "inputs": prompt_template("Who are you?"),
            "parameters": {"max_new_tokens": 100}
        }, token)
        if 'error' not in test_output:
            selected_token = token
            break
    if not selected_token:
        raise RuntimeError("No working Hugging Face API token found.")
    output = query({
        "inputs": input_prompt,
        "parameters": {"max_new_tokens": 512},
    }, selected_token)
    # Strip the echoed prompt so only the generated continuation is returned.
    return output[0]['generated_text'][len(input_prompt):]
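# Hedged usage sketch (added for illustration; not part of the original file):
# one end-to-end call through the token-rotation helper above. The prompt and
# system prompt strings here are assumptions.
def _example_gen_prompt() -> str:
    return gen_prompt(
        "List the main decisions made in this meeting.",
        "Act as a professional technical meeting minutes writer.",
    )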
class Node:
    def __init__(self, summary=None):
        self.summary = summary
        self.children = []
        self.parent = None

    def add_child(self, child_node):
        child_node.parent = self
        self.children.append(child_node)
class MemWalker:
    def __init__(self, segments):
        self.segments = segments
        self.root = None

    def build_memory_tree(self):
        # Step 1: Create leaf nodes, each holding a summary of one segment
        leaves = [Node(summarize(seg, 0)) for seg in self.segments]
        # Step 2: Build the tree bottom-up by pairing nodes and summarizing their summaries
        while len(leaves) > 1:
            new_leaves = []
            for i in range(0, len(leaves), 2):
                if i + 1 < len(leaves):
                    combined_summary = summarize(leaves[i].summary + ", " + leaves[i + 1].summary, 1)
                    parent_node = Node(combined_summary)
                    parent_node.add_child(leaves[i])
                    parent_node.add_child(leaves[i + 1])
                else:
                    # Odd node out: carry it up to the next level unchanged
                    parent_node = leaves[i]
                new_leaves.append(parent_node)
            leaves = new_leaves
        self.root = leaves[0]
# LLM-backed summarization helpers
def summarize(text, sum_type: int = 1):
    assert sum_type in [0, 1], "sum_type should be either 0 or 1"
    if sum_type == 0:
        # Summarize a raw transcript segment
        USER_PROMPT = "Write a concise summary of the meeting transcript in maximum 5 sentences:" + "\n\n" + text
    else:
        # Compress already-generated summaries into a shorter one
        USER_PROMPT = "Compress the following summaries into a much shorter summary: " + "\n\n" + text
    SYS_PROMPT = "Act as a professional technical meeting minutes writer."
    tmp = gen_prompt(USER_PROMPT, SYS_PROMPT)
    # The model often prefixes its answer with a short preamble; keep only the
    # text after the first blank line when one is present.
    if len(tmp.split("\n\n")) == 1:
        return tmp
    else:
        return tmp.split("\n\n")[1]
def split_chunk(transcript: str):
    # Split the transcript into chunks of at most MAX_CHUNK_SIZE characters,
    # keeping a small sentence overlap between consecutive chunks.
    sentences = nltk.sent_tokenize(transcript)
    idx = 0
    chunk = []
    current_chunk = ""
    while idx < len(sentences):
        if len(current_chunk + sentences[idx]) < MAX_CHUNK_SIZE:
            current_chunk += sentences[idx] + " "
        else:
            chunk.append(current_chunk)
            current_chunk = ''
            # Start the next chunk with up to the 10 preceding sentences plus the
            # current one, so context carries over across chunk boundaries.
            for i in range(min(10, idx), -1, -1):
                current_chunk += sentences[idx - i] + " "
        idx += 1
    chunk.append(current_chunk)
    return chunk
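# Hedged usage sketch (added for illustration; not part of the original file):
# ties split_chunk and MemWalker together on a single transcript string. The
# function name and variable names below are assumptions, not the app's entry point.
def _example_memory_tree(transcript: str) -> str:
    segments = split_chunk(transcript)  # character-bounded, sentence-aligned chunks
    walker = MemWalker(segments)
    walker.build_memory_tree()          # leaf summaries first, then pairwise merges
    return walker.root.summary          # top-level summary of the whole transcript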
def summarize_three_ways(chunks: list[str]):
    SYS_PROMPT = "Act as a professional technical meeting minutes writer."
    PROMPT_TEMPLATE = "Write a concise summary of the meeting transcript in maximum 5 sentences:" + "\n\n" + "{text}"
    REFINE_TEMPLATE = (
        "Your job is to produce a final summary.\n"
        "We have provided an existing summary up to a certain point: {existing_answer}\n"
        "We have the opportunity to refine the existing summary "
        "(only if needed) with some more context below.\n"
        "------------\n"
        "{text}\n"
        "------------\n"
        "Given the new context, refine the original summary in English within 5 sentences. If the context isn't useful, return the original summary."
    )
    step = 0
    partial_sum = []
    return_dict = {}
    for chunk in chunks:
        if step == 0:
            # First chunk: plain summarization
            CUR_PROMPT = PROMPT_TEMPLATE.replace("{text}", chunk)
            cur_sum = gen_prompt(CUR_PROMPT, SYS_PROMPT)
        else:
            # Later chunks: refine the previous summary with the new context
            CUR_PROMPT = REFINE_TEMPLATE.replace("{existing_answer}", partial_sum[-1])
            CUR_PROMPT = CUR_PROMPT.replace("{text}", chunk)
            cur_sum = gen_prompt(CUR_PROMPT, SYS_PROMPT)
        # Drop any preamble before the first blank line
        if len(cur_sum.split("\n\n")) > 1:
            cur_sum = cur_sum.split("\n\n")[1]
        partial_sum.append(cur_sum)
        step += 1
    # Third strategy: rewrite the concatenation of all partial summaries
    CUR_PROMPT = "Rewrite the following text by maintaining coherency: " + "\n\n"
    CUR_PROMPT += ' '.join(partial_sum)
    tmp = gen_prompt(CUR_PROMPT, SYS_PROMPT)
    final_sum = ''
    if len(tmp.split("\n\n")) == 1:
        final_sum = tmp
    else:
        final_sum = tmp.split("\n\n")[1]
    return_dict['truncated'] = partial_sum[0]    # summary of the first chunk only
    return_dict['accumulate'] = partial_sum[-1]  # iteratively refined summary
    return_dict['rewrite'] = final_sum           # coherent rewrite of all partial summaries
    return return_dict
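# Hedged usage sketch (added for illustration; not part of the original file):
# runs all three summarization strategies on one transcript. Names below are
# assumptions chosen for the example.
def _example_three_ways(transcript: str) -> dict:
    chunks = split_chunk(transcript)
    summaries = summarize_three_ways(chunks)
    # summaries['truncated']  -> first-chunk summary
    # summaries['accumulate'] -> refined across all chunks
    # summaries['rewrite']    -> coherent rewrite of the partial summaries
    return summaries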
def get_example() -> list[str]:
    data = []
    with open(filepath + "data/test.json", "r") as f:
        for line in f:
            data.append(json.loads(line))
    #random_idx = random.sample(list(range(len(data))), 6)
    random_idx = [2, 89, 94, 97]
    #random_idx = [1, 2, 9, 13]
    return ['\n'.join(nltk.sent_tokenize(data[i]['transcript'])) for i in random_idx]
if __name__ == "__main__":
    data = []
    with open(filepath + "data/test.json", "r") as f:
        for line in f:
            data.append(json.loads(line))
    tmp = data[:100]
    for j, i in enumerate(tmp):
        print(j, len(i['transcript']))