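"""Evaluate multiple chat-model APIs (OpenAI-compatible, Together, Anthropic,
MiniMax) on story-comprehension test cases answered 对/错/不知道
(yes / no / unknown, mapped to T/F/N).

Models are queried concurrently for each case; running accuracy rankings,
challenging cases, and full per-case logs are checkpointed to a per-shot
log folder so that interrupted runs can resume.
"""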
import argparse
import concurrent.futures
import json
import os
import random
from functools import partial

import requests
from anthropic import Anthropic
from openai import OpenAI
from together import Together
from tqdm import tqdm

from model_configs import models
from prompt import simple_system_prompt, system_prompt_with_2shots
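
# The `models` list imported from model_configs is not shown in this file.
# Based on how call_api reads it below, each entry is assumed to look roughly
# like the following (field values here are purely illustrative):
#
#   {
#       "name": "GPT-4o-mini",
#       "type": "openai",               # one of "openai", "together", "anthropic", "minimax"
#       "config": {
#           "apiKey": "...",
#           "baseURL": "https://...",   # used by the OpenAI-compatible branches
#           "model": "gpt-4o-mini",
#           "maxTokens": 1024,
#           "temperature": 0.7,
#           "top_p": 0.95,
#           # Together entries additionally use "repetition_penalty" and "stop";
#           # MiniMax entries additionally need "groupId".
#       },
#   }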

# Load stories
with open("data/stories.json", "r", encoding="utf-8") as f:
    stories = json.load(f)
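
# Each story in data/stories.json is expected to provide "title", "surface",
# and "bottom" fields: evaluate_models looks stories up by title and splices
# surface/bottom into the system prompt template.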

def load_test_cases(filename):
    """Load tab-separated test cases: question, story title, ground truth (T/F/N)."""
    with open(filename, "r", encoding="utf-8") as f:
        _test_cases = []
        for line in f:
            parts = line.strip().replace(" ", "").split("\t")
            if len(parts) != 3:
                print(f"Invalid test case: {line}")
                continue
            if parts[2] not in ["T", "F", "N"]:
                print(f"Skipping line with invalid ground truth: {line}")
                continue
            _test_cases.append(parts)
    return _test_cases

def starts_with_answer(response, answer):
    return response.strip().lower().startswith(answer)

def call_api(model, prompt, user_input):
    """Send one chat request to `model` and return the response text, or None on error."""
    try:
        if model["type"] == "openai":
            if model["name"] == "Doubao-4k":
                client = OpenAI(
                    api_key=model["config"]["apiKey"],
                    base_url=model["config"]["baseURL"]
                )
                messages = [
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": user_input}
                ]
                response = client.chat.completions.create(
                    model=model["config"]["model"],
                    messages=messages,
                    max_tokens=model["config"]["maxTokens"],
                    temperature=model["config"]["temperature"],
                    top_p=model["config"]["top_p"],
                    stream=False
                )
                return response.choices[0].message.content
            else:
                # Generic OpenAI-compatible endpoint, called via raw HTTP
                url = model["config"]["baseURL"] + "/chat/completions"
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {model['config']['apiKey']}"
                }
                data = {
                    "model": model["config"]["model"],
                    "messages": [
                        {"role": "system", "content": prompt},
                        {"role": "user", "content": user_input}
                    ],
                    "max_tokens": model["config"]["maxTokens"],
                    "temperature": model["config"]["temperature"],
                }
                if "top_p" in model["config"]:
                    data["top_p"] = model["config"]["top_p"]
                response = requests.post(url, headers=headers, json=data)
                if response.status_code != 200:
                    raise Exception(f"API call failed with status {response.status_code}: {response.text}")
                result = response.json()
                return result["choices"][0]["message"]["content"]
        elif model["type"] == "together":
            client = Together(api_key=model["config"]["apiKey"])
            messages = [
                {"role": "system", "content": prompt},
                {"role": "user", "content": user_input}
            ]
            response = client.chat.completions.create(
                model=model["config"]["model"],
                messages=messages,
                max_tokens=model["config"]["maxTokens"],
                temperature=model["config"]["temperature"],
                top_p=model["config"]["top_p"],
                repetition_penalty=model["config"]["repetition_penalty"],
                stop=model["config"]["stop"],
                stream=False
            )
            return response.choices[0].message.content
        elif model["type"] == "anthropic":
            client = Anthropic(api_key=model["config"]["apiKey"])
            message = client.messages.create(
                model=model["config"]["model"],
                max_tokens=model["config"]["maxTokens"],
                temperature=model["config"]["temperature"],
                system=prompt,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": user_input
                            }
                        ]
                    }
                ]
            )
            return message.content[0].text
        elif model["type"] == "minimax":
            url = f"https://api.minimax.chat/v1/text/chatcompletion_v2?GroupId={model['config']['groupId']}"
            headers = {
                "Authorization": f"Bearer {model['config']['apiKey']}",
                "Content-Type": "application/json"
            }
            payload = {
                "model": model["config"]["model"],
                "messages": [
                    {
                        "role": "system",
                        "name": "MM智能助理",
                        "content": prompt
                    },
                    {
                        "role": "user",
                        "content": user_input
                    }
                ],
                "tools": [],
                "tool_choice": "none",
                "stream": False,
                "max_tokens": model["config"]["maxTokens"],
                "temperature": model["config"]["temperature"],
                "top_p": model["config"]["top_p"]
            }
            response = requests.post(url, headers=headers, json=payload)
            if response.status_code != 200:
                raise Exception(f"API call failed with status {response.status_code}: {response.text}")
            result = response.json()
            return result["choices"][0]["message"]["content"]
        else:
            raise ValueError(f"Unsupported model type: {model['type']}")
    except Exception as e:
        print(f"Error in call_api for model {model['name']}: {str(e)}")
        return None

def call_api_with_timeout(model, prompt, user_input, timeout=20):
    # Run the API call in a worker thread so the timeout can be enforced.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(call_api, model, prompt, user_input)
        return future.result(timeout=timeout)
    except Exception as e:
        print(f"Error in call_api_with_timeout for model {model['name']}: {str(e)}")
        return None
    finally:
        executor.shutdown(wait=False)

def evaluate_models(models, test_cases, stories, shot_type):
    """Evaluate every model on the test cases, resuming from the latest saved checkpoint."""
    results = {model['name']: {'correct': 0, 'total': 0} for model in models}
    logs = {model['name']: [] for model in models}
    challenging_cases = []
    all_cases = []

    # Determine the appropriate log folder based on shot_type
    log_folder = f"logs_with_{shot_type}shots"
    os.makedirs(log_folder, exist_ok=True)

    # Find the last tested sample
    last_tested = 0
    for i in range(len(test_cases), 0, -1):
        if os.path.exists(f"{log_folder}/all_cases_simple_prompt_{i}.json"):
            with open(f"{log_folder}/all_cases_simple_prompt_{i}.json", "r", encoding="utf-8") as f:
                all_cases = json.load(f)
            last_tested = i
            break

    # Update results with previously tested samples (invalid outputs are
    # skipped, matching how fresh responses are counted below)
    for case in all_cases:
        for model_name, result in case['results'].items():
            if result in ("T", "F", "N"):
                results[model_name]['total'] += 1
                if (case['ground_truth'] == "T" and result == "T") or \
                   ((case['ground_truth'] == "F" or case['ground_truth'] == "N") and result != "T"):
                    results[model_name]['correct'] += 1

    # Start from the next untested sample
    start_index = len(all_cases)
    for i, (user_input, story_title, ground_truth) in enumerate(tqdm(test_cases[start_index:]), start_index + 1):
        try:
            story = next((s for s in stories if s["title"] == story_title), None)
            if not story:
                print(f"Story not found: {story_title}")
                continue

            # Use the appropriate prompt based on shot_type
            if shot_type == "2":
                prompt_template = system_prompt_with_2shots
            else:
                prompt_template = simple_system_prompt
            prompt = prompt_template.replace("{surface}", story["surface"]).replace("{bottom}", story["bottom"])

            gt_map = {"T": "对", "F": "错", "N": "不知道"}
            case_results = {}
            all_responses_valid = True

            # Use ThreadPoolExecutor for concurrent API calls
            with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
                future_to_model = {executor.submit(partial(call_api_with_timeout, timeout=20), model, prompt, user_input): model for model in models}
                for future in concurrent.futures.as_completed(future_to_model):
                    model = future_to_model[future]
                    try:
                        response = future.result()
                        if response is None:
                            all_responses_valid = False
                            print(f"Timeout or error for model {model['name']}")
                        else:
                            case_results[model['name']] = response
                    except Exception as exc:
                        print(f'{model["name"]} generated an exception: {exc}')
                        all_responses_valid = False

            # If any model timed out or had an error, skip this entire test case
            if not all_responses_valid:
                print(f"Skipping test case {i} due to timeout or error")
                continue

            # Process all responses
            for model in models:
                if model['name'] not in case_results:
                    continue
                response = case_results[model['name']].strip().lower()
                if starts_with_answer(response, "对") or starts_with_answer(response, "错") or starts_with_answer(response, "不知道"):
                    results[model['name']]['total'] += 1
                    # Save the actual model output
                    if starts_with_answer(response, "对"):
                        case_results[model['name']] = "T"
                    elif starts_with_answer(response, "错"):
                        case_results[model['name']] = "F"
                    else:
                        case_results[model['name']] = "N"
                    # Calculate accuracy (merging N and F)
                    if (ground_truth == "T" and case_results[model['name']] == "T") or \
                       ((ground_truth == "F" or ground_truth == "N") and case_results[model['name']] != "T"):
                        results[model['name']]['correct'] += 1
                    else:
                        # Print only wrong answers
                        print(f"Wrong Answer - Model: {model['name']}, Input: {user_input}, Response: {response}, GT: {gt_map[ground_truth]}, Model Output: {case_results[model['name']]}")
                else:
                    # Handle invalid responses
                    case_results[model['name']] = "Invalid"
                    print(f"Invalid Response - Model: {model['name']}, Input: {user_input}, Response: {response}, GT: {gt_map[ground_truth]}, Model Output: {case_results[model['name']]}")

                log_entry = {
                    "Input": user_input,
                    "Response": response,
                    "GT": gt_map[ground_truth],
                    "Model_Output": case_results[model['name']],
                    "Accuracy": f"{results[model['name']]['correct']}/{results[model['name']]['total']} ({results[model['name']]['correct']/max(results[model['name']]['total'], 1):.2f})"
                }
                logs[model['name']].append(log_entry)

            case = {
                "input": user_input,
                "story_title": story_title,
                "ground_truth": ground_truth,
                "results": case_results
            }
            all_cases.append(case)
            if any(result != "T" for result in case_results.values()):
                challenging_cases.append(case)

            # Save log and print accuracy ranking every 10 steps
            if i % 10 == 0 or i == len(test_cases):
                print(f"\nCurrent rankings after {i} items:")
                current_results = [(name, res['correct'] / max(res['total'], 1), res['correct'], res['total'])
                                   for name, res in results.items()]
                current_results.sort(key=lambda x: x[1], reverse=True)
                for rank, (name, accuracy, correct, total) in enumerate(current_results, 1):
                    print(f"{rank}. {name}: {accuracy:.2f} ({correct}/{total})")
                # Update challenging cases file
                with open(f"{log_folder}/challenging_cases_simple_prompt_{i}.json", "w", encoding="utf-8") as f:
                    json.dump(challenging_cases, f, ensure_ascii=False, indent=2)
                # Update all cases file
                with open(f"{log_folder}/all_cases_simple_prompt_{i}.json", "w", encoding="utf-8") as f:
                    json.dump(all_cases, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"Error processing test case {i}: {str(e)}")
            continue

    # Final update to challenging cases file
    final_index = start_index + len(test_cases[start_index:])
    with open(f"{log_folder}/challenging_cases_simple_prompt_{final_index}.json", "w", encoding="utf-8") as f:
        json.dump(challenging_cases, f, ensure_ascii=False, indent=2)
    # Final update to all cases file
    with open(f"{log_folder}/all_cases_simple_prompt_{final_index}.json", "w", encoding="utf-8") as f:
        json.dump(all_cases, f, ensure_ascii=False, indent=2)

    return results, challenging_cases, all_cases

def save_all_cases(all_cases, output_file):
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_cases, f, ensure_ascii=False, indent=2)
    print(f"All cases have been saved to {output_file}")

def parse_challenging_cases(input_file, output_file):
    with open(input_file, 'r', encoding='utf-8') as f:
        challenging_cases = json.load(f)
    with open(output_file, 'w', encoding='utf-8') as f:
        for case in challenging_cases:
            user_input = case['input']
            story_title = case['story_title']
            ground_truth = case['ground_truth']
            f.write(f"{user_input}\t{story_title}\t{ground_truth}\n")
    print(f"Parsed challenging cases have been written to {output_file}")

def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Run story understanding evaluation")
    parser.add_argument("--shot", choices=["0", "2"], default="2", help="Number of shots (0 or 2)")
    args = parser.parse_args()

    _models = [model for model in models if model['name'] in ['DEEPSEEK', 'Kimi-Chat', 'GPT-4o-mini']]
    test_cases = load_test_cases("data/cases.list")
    _test_cases = random.sample(test_cases, k=100)

    results, challenging_cases, all_cases = evaluate_models(_models, _test_cases, stories, args.shot)

    final_results = [(name, res['correct'] / max(res['total'], 1), res['correct'], res['total'])
                     for name, res in results.items()]
    final_results.sort(key=lambda x: x[1], reverse=True)
    print(f"\nFinal Rankings ({args.shot}-shot):")
    for rank, (name, accuracy, correct, total) in enumerate(final_results, 1):
        print(f"{rank}. {name}: {accuracy:.2f} ({correct}/{total})")

    log_folder = f"logs_with_{args.shot}shots"
    print(f"Evaluation complete. Logs have been saved in the '{log_folder}' directory.")


if __name__ == "__main__":
    main()
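
# Example invocation (the script file name here is assumed):
#   python evaluate.py --shot 2   # two-shot system prompt (default)
#   python evaluate.py --shot 0   # zero-shot system prompt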