terryyz committed
Commit f175fac · verified · 1 parent: 7836cdd

Update app.py

Files changed (1): app.py (+0 -296)
app.py CHANGED
@@ -1,296 +0,0 @@
import gradio as gr
import json
import logging
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed, wait, FIRST_COMPLETED
from datetime import datetime
from typing import Any, Dict, List, Tuple
from warnings import warn
import gc

import numpy as np
from huggingface_hub import HfApi
from bigcodebench.data import get_bigcodebench, get_bigcodebench_hash, load_solutions
from bigcodebench.data.utils import CACHE_DIR
from bigcodebench.eval import PASS, compatible_eval_result, estimate_pass_at_k, untrusted_check
from bigcodebench.gen.util import trusted_check
from apscheduler.schedulers.background import BackgroundScheduler

REPO_ID = "bigcode/bigcodebench-evaluator"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
API = HfApi(token=HF_TOKEN)
Result = Tuple[str, List[bool]]


def get_groundtruth(n_workers, problems, hashcode, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit):
    # Ground-truth runtimes are cached per dataset hash so they are only measured once.
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)
    tbegin = time.time()

    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = []
        n_samples = 0
        expected_time = dict()

        for problem in problems.values():
            args = (
                problem["complete_prompt"] + "\n" + problem["canonical_solution"],
                problem["test"],
                problem["task_id"],
                max_as_limit,
                max_data_limit,
                max_stack_limit,
                min_time_limit,
            )

            futures.append(executor.submit(trusted_check, *args))
            n_samples += 1

        for future in as_completed(futures):
            result = future.result()
            expected_time[result["task_id"]] = result["time"]

    if any(expected_time.values()):
        with open(cache_file, "wb") as f:
            pickle.dump(expected_time, f)

    return expected_time

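
# The runtimes recorded above are passed to check_correctness below as
# gt_time_limit; tasks whose ground truth failed to run fall back to a
# fixed 20-second budget (see the submission loop in evaluate).
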
def check_correctness(
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Result]:
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        solution,
        problem["test"],
        problem["entry_point"],
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    return ret


def evaluate(
    split: str,
    subset: str,
    samples: str,
    pass_k: str = "1,5,10",
    parallel: int = -1,
    min_time_limit: float = 1,
    max_as_limit: int = 30 * 1024,
    max_data_limit: int = 30 * 1024,
    max_stack_limit: int = 10,
    calibrated: bool = True,
    check_gt_only: bool = False,
    no_gt: bool = False,
    selective_evaluate: str = "",
):
    passk = [int(k.strip()) for k in pass_k.split(',') if k.strip().isdigit()]
    if parallel < 1:
        n_workers = max(1, multiprocessing.cpu_count() // 2)
    else:
        n_workers = parallel

    if check_gt_only:
        samples = "__dummy__.jsonl"

    extra = subset + "_" if subset != "full" else ""

    problems = get_bigcodebench(subset=subset)

    # Add selective evaluation logic
    if selective_evaluate:
        selected_ids = ["BigCodeBench/" + id for id in sorted(set(selective_evaluate.split(",")))]
        problems = {k: v for k, v in problems.items() if k in selected_ids}
        if not problems:
            raise ValueError(f"None of the provided task IDs {selected_ids} were found in the dataset")

    dataset_hash = get_bigcodebench_hash(subset=subset)

    if not no_gt:
        expected_time = get_groundtruth(n_workers, problems, dataset_hash, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit)
    else:
        expected_time = {task_id: None for task_id in problems}

    gt_pass_rate = np.mean([1 if v is not None else 0 for k, v in expected_time.items() if k in problems])
    failed_tasks = [k for k, v in expected_time.items() if v is None and k in problems]

    pass_at_k = dict()
    results = {
        "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
        "eval": {},
    }

    if not check_gt_only:

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of check_correctness results
            remainings = set()

            for sample in load_solutions(samples):
                task_id = sample["task_id"]

                if task_id not in problems:
                    continue
                solution = (
                    sample["solution"]
                    if "solution" in sample
                    else problems[task_id]["complete_prompt"] + sample["completion"]
                )
                if calibrated:
                    solution = problems[task_id]["code_prompt"] + "\n    pass\n" + solution
                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    problems[task_id],
                    solution,
                    max_as_limit,
                    max_data_limit,
                    max_stack_limit,
                    sample["_identifier"],
                    min_time_limit,
                    expected_time[task_id] if expected_time[task_id] else 20
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Missing problems in unfinished"
            assert len(completion_id) == len(problems), "Missing problems in samples"

            for future in as_completed(futures):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)
                del future, result
                gc.collect()

        # sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )

        # Calculate pass@k.
        total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
        base_correct = []

        for key, res in results["eval"].items():
            if key not in problems:
                continue
            bc = sum([r["status"] == PASS for r in res])
            base_correct.append(bc)

        base_correct = np.array(base_correct)

        pass_at_k.update({
            f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean()
            for k in passk
            if total.min() >= k
        })

        del problems, futures
        gc.collect()

    pass_at_k["model"] = os.path.basename(samples).split("--bigcodebench-")[0]
    pass_at_k["split"] = split
    pass_at_k["subset"] = subset
    pass_at_k["calibrated"] = calibrated
    pass_at_k["gt_pass_rate"] = gt_pass_rate
    pass_at_k["failed_tasks"] = failed_tasks

    return results, pass_at_k

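
# Calibration note: with calibrated=True, each sample above is evaluated as
#     <code_prompt>
#         pass
#     <sample solution>
# i.e. the task's code prompt (with a stub body) is prepended, so a solution
# can rely on the imports and function signature from the prompt even if it
# does not repeat them.
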
# def run_gradio():
interface = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.Dropdown(["complete", "instruct"], label="BigCodeBench Split"),
        gr.Dropdown(["full", "hard"], label="BigCodeBench Subset"),
        gr.File(label="Samples Path (.jsonl)"),
        gr.Textbox(label="Pass k Values (comma-separated)", value="1,5,10"),
        gr.Slider(-1, multiprocessing.cpu_count(), step=1, label="Parallel Workers", value=-1),
        gr.Slider(0.1, 10, step=0.1, label="Min Time Limit", value=1),
        gr.Slider(1, 100 * 1024, step=1024, label="Max AS Limit", value=30 * 1024),
        gr.Slider(1, 100 * 1024, step=1024, label="Max Data Limit", value=30 * 1024),
        gr.Slider(1, 100, step=1, label="Max Stack Limit", value=10),
        gr.Checkbox(label="Calibrated", value=True),
        gr.Checkbox(label="Check GT Only"),
        gr.Checkbox(label="No GT"),
        gr.Textbox(label="Selective Evaluated Task IDs (comma-separated, e.g. '0,1,2')", value=""),
    ],
    outputs=[
        gr.JSON(label="Results"),
        gr.JSON(label="Eval Results"),
    ],
    # concurrency_limit=None
)
interface.queue(default_concurrency_limit=None)


def preload_gt():
    # Warm the ground-truth runtime cache for both subsets before serving requests.
    evaluate(split="complete", subset="full", samples="", check_gt_only=True)
    evaluate(split="complete", subset="hard", samples="", check_gt_only=True)


def restart_space():
    logging.info(f"Restarting space with repo ID: {REPO_ID}")
    try:
        # Now restart the space
        API.restart_space(repo_id=REPO_ID, token=HF_TOKEN)
        logging.info("Space restarted successfully.")
    except Exception as e:
        logging.error(f"Failed to restart space: {e}")


# if __name__ == "__main__":
while True:
    try:
        preload_gt()
        break
    except Exception:
        # Retry until the ground-truth cache is successfully built.
        continue

scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "interval", hours=1)  # restart every hour
scheduler.start()
interface.launch(show_error=True)
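
For reference, a minimal sketch of the standard unbiased pass@k estimator (Chen et al., 2021) that `estimate_pass_at_k` from `bigcodebench.eval` is assumed to implement; `pass_at_k` below is a hypothetical standalone helper, not the library function:

import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    # Probability that at least one of k samples drawn without replacement
    # from n total samples (c of them correct) passes: 1 - C(n-c, k) / C(n, k).
    if n - c < k:
        return 1.0  # every size-k subset must contain a correct sample
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

Averaged over tasks, np.mean([pass_at_k(n, c, k) for n, c in zip(total, base_correct)]) would match the pass@k values reported by evaluate, under that assumption.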