import gradio as gr
import json
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed, wait, FIRST_COMPLETED
from datetime import datetime
from typing import Any, Dict, List, Tuple
from warnings import warn

import numpy as np
from termcolor import cprint
from tqdm import tqdm

from bigcodebench.data import get_bigcodebench, get_bigcodebench_hash, load_solutions
from bigcodebench.data.utils import CACHE_DIR
from bigcodebench.eval import PASS, compatible_eval_result, estimate_pass_at_k, untrusted_check
from bigcodebench.gen.util import trusted_check
					
						
Result = Tuple[str, List[bool]]

					
						
def get_groundtruth(n_workers, problems, hashcode, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit):
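    """Time each task's canonical solution against its own test suite.

    Runs trusted_check for every problem and returns a dict mapping task_id to
    the measured runtime. Results are cached under CACHE_DIR keyed by the
    dataset hash; check_gt_only deletes any existing cache so the ground truth
    is re-verified from scratch.
    """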
					
						
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        if check_gt_only:
            # Force a fresh ground-truth check by discarding the cached timings.
            os.remove(cache_file)
        else:
            print(f"Loading ground-truth from {cache_file}")
            with open(cache_file, "rb") as f:
                return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("\nAsserting the groundtruth...")
    tbegin = time.time()

    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = []
        n_samples = 0
        expected_time = dict()

        for problem in problems.values():
            args = (
                problem["complete_prompt"] + "\n" + problem["canonical_solution"],
                problem["test"],
                problem["task_id"],
                max_as_limit,
                max_data_limit,
                max_stack_limit,
                min_time_limit,
            )

            futures.append(executor.submit(trusted_check, *args))
            n_samples += 1

        for future in tqdm(as_completed(futures), total=n_samples):
            result = future.result()
            expected_time[result["task_id"]] = result["time"]

    print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")

    # Cache the timings only if at least one task produced a valid runtime.
    if any(expected_time.values()):
        with open(cache_file, "wb") as f:
            pickle.dump(expected_time, f)

    return expected_time
					
						
def check_correctness(
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Result]:
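    """Run one model solution against its task's tests via untrusted_check.

    The solution executes under the given memory and stack limits; the returned
    dict carries the check outcome under "base" plus the sample's identifying
    metadata.
    """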
					
						
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        solution,
        problem["test"],
        problem["entry_point"],
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    return ret
					
						
def evaluate(
    split: str,
    subset: str,
    samples: str,
    pass_k: str = "1,5,10",
    parallel: int = None,
    min_time_limit: float = 1,
    max_as_limit: int = 30 * 1024,
    max_data_limit: int = 30 * 1024,
    max_stack_limit: int = 10,
    check_gt_only: bool = False,
    no_gt: bool = False,
):
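    """Evaluate a JSONL file of model samples on BigCodeBench.

    Optionally (re)computes the ground-truth timings, runs every sample through
    check_correctness in parallel, and returns the raw per-task results
    together with a pass@k summary; split and subset are recorded in that
    summary.
    """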
					
						
    pass_k = [int(k.strip()) for k in pass_k.split(',') if k.strip().isdigit()]
    if parallel is None:
        n_workers = max(1, multiprocessing.cpu_count() // 2)
    else:
        n_workers = parallel

    if check_gt_only:
        # No samples are evaluated in this mode; use a placeholder file name.
        samples = "__dummy__.jsonl"

    extra = subset + "_" if subset != "full" else ""

    problems = get_bigcodebench(subset=subset)
    dataset_hash = get_bigcodebench_hash(subset=subset)

    if not no_gt:
        expected_time = get_groundtruth(n_workers, problems, dataset_hash, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit)
    else:
        expected_time = {task_id: None for task_id in problems}

    gt_pass_rate = np.mean([1 if v is not None else 0 for k, v in expected_time.items() if k in problems])
    failed_tasks = [k for k, v in expected_time.items() if v is None and k in problems]

    pass_at_k = dict()
    # Initialized here so `return results, pass_at_k` is well defined even when
    # check_gt_only skips the evaluation block below.
    results = {}
					
						
    if not check_gt_only:

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)
            remainings = set()

            print("Reading samples...")
            for sample in tqdm(load_solutions(samples)):
                task_id = sample["task_id"]

                if task_id not in problems:
                    warn(
                        f"Task {task_id} is found in the samples but not in the dataset"
                    )
                    continue
                solution = (
                    sample["solution"]
                    if "solution" in sample
                    else problems[task_id]["complete_prompt"] + sample["completion"]
                )
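                # For calibrated sample files, prepend the task's code prompt
                # (plus a stub `pass` body) to the solution before evaluation.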
					
						
						|  | if "sanitized-calibrated" in samples: | 
					
						
						|  | solution = problems[task_id]["code_prompt"] + "\n    pass\n" + solution | 
					
						
						|  | remainings.add(sample["_identifier"]) | 
					
						
						|  | args = ( | 
					
						
						|  | completion_id[task_id], | 
					
						
						|  | problems[task_id], | 
					
						
						|  | solution, | 
					
						
						|  | max_as_limit, | 
					
						
						|  | max_data_limit, | 
					
						
						|  | max_stack_limit, | 
					
						
						|  | sample["_identifier"], | 
					
						
						|  | min_time_limit, | 
					
						
						|  | expected_time[task_id] if expected_time[task_id] else 20 | 
					
						
						|  | ) | 
					
						
						|  | futures.append(executor.submit(check_correctness, *args)) | 
					
						
						|  | completion_id[task_id] += 1 | 
					
						
						|  | n_samples += 1 | 
					
						
						|  |  | 
					
						
						|  | assert n_samples == len(remainings), "Missing problems in unfinished" | 
					
						
						|  | assert len(completion_id) == len(problems), "Missing problems in samples" | 
					
						
					
						
            for future in tqdm(as_completed(futures), total=n_samples):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)
					
						
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )
					
						
        total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
        base_correct = []

        for key, res in results["eval"].items():
            if key not in problems:
                continue
            bc = sum([r["status"] == PASS for r in res])
            base_correct.append(bc)

        base_correct = np.array(base_correct)
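
        # estimate_pass_at_k gives the standard unbiased pass@k estimate; an
        # entry is only added for k values no larger than the smallest
        # per-task sample count.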
					
						
        pass_at_k.update({
            f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean()
            for k in pass_k
            if total.min() >= k
        })

        pass_at_k["model"] = os.path.basename(samples).split("--bigcodebench-")[0]
        pass_at_k["split"] = split
        pass_at_k["subset"] = subset
        pass_at_k["calibrated"] = "sanitized-calibrated" in samples
        pass_at_k["gt_pass_rate"] = gt_pass_rate
        pass_at_k["failed_tasks"] = failed_tasks

    return results, pass_at_k
					
						
def run_gradio():
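    """Expose evaluate() through a minimal Gradio interface."""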
					
						
    interface = gr.Interface(
        fn=evaluate,
        inputs=[
            gr.Dropdown(["complete", "instruct"], label="BigCodeBench Split"),
            gr.Dropdown(["full", "hard"], label="BigCodeBench Subset"),
            gr.File(label="Samples Path (.jsonl)"),
            gr.Textbox(label="Pass k Values (comma-separated)", value="1,5,10"),
            gr.Slider(1, multiprocessing.cpu_count(), step=1, label="Parallel Workers"),
            gr.Slider(0.1, 10, step=0.1, label="Min Time Limit", value=1),
            gr.Slider(1, 100 * 1024, step=1024, label="Max AS Limit", value=30 * 1024),
            gr.Slider(1, 100 * 1024, step=1024, label="Max Data Limit", value=30 * 1024),
            gr.Slider(1, 100, step=1, label="Max Stack Limit", value=10),
            gr.Checkbox(label="Check GT Only"),
            gr.Checkbox(label="No GT"),
        ],
        outputs=[
            gr.JSON(label="Results"),
            gr.JSON(label="Eval Results"),
        ],
    )
    interface.queue(default_concurrency_limit=None)
    interface.launch(show_error=True)
					
						
if __name__ == "__main__":
    run_gradio()
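
# Programmatic use (a sketch; the sample file name below is hypothetical, but
# its "<model>--bigcodebench-..." prefix and "sanitized-calibrated" marker
# follow the naming that evaluate() parses from the path):
#
#   results, pass_at_k = evaluate(
#       split="complete",
#       subset="hard",
#       samples="mymodel--bigcodebench-hard-complete--sanitized-calibrated.jsonl",
#   )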
					
						