kbmlcoding commited on
Commit
d1b8990
·
1 Parent(s): f68d1c6

create metrics test

Browse files
README.md CHANGED
@@ -1,12 +1,50 @@
1
  ---
2
- title: Apps Metric
3
- emoji: 📈
4
  colorFrom: blue
5
- colorTo: indigo
 
 
 
 
6
  sdk: gradio
7
- sdk_version: 4.38.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: APPS Metric
3
+ emoji: 📊
4
  colorFrom: blue
5
+ colorTo: pink
6
+ tags:
7
+ - evaluate
8
+ - metric
9
+ description: "Evaluation metric for the APPS benchmark"
10
  sdk: gradio
11
+ sdk_version: 3.0.2
12
  app_file: app.py
13
  pinned: false
14
  ---
15
 
16
+ # Metric Card for apps_metric
17
+
18
+ ## Metric Description
19
+ This metric is used to evaluate code generation on the [APPS benchmark](https://huggingface.co/datasets/codeparrot/apps).
20
+
21
+ ## How to Use
22
+ You can load the metric and use it with the following commands:
23
+
24
+ ```python
25
+ from evaluate import load
26
+ apps_metric = load('codeparrot/apps_metric')
27
+ # to evaluate generations made for all levels for example
28
+ results = apps_metric.compute(predictions=generations, level="all")
29
+ ```
30
+
31
+ ### Inputs
32
+ **generations** list(list(str)): List of code generations, each sub-list corresponds to the generations for a problem in APPS dataset, **the order of the samples in the dataset must be kept (with respect to the difficulty level)**.
33
+
34
+ ### Output Values
35
+
36
+ **average accuracy**: when a single solution is generated, average accuracy computes the average of test cases that are passed.
37
+
38
+ **strict accuracy**: when a single solution is generated, strict accuracy computes the average number of problems that pass all their test cases.
39
+
40
+ **pass@k**: when multiple solutions are generated per problem, pass@k is the metric originally used for the [HumanEval](https://huggingface.co/datasets/openai_humaneval) benchmark. For more details please refer to the [metric space](https://huggingface.co/spaces/evaluate-metric/code_eval) and [Codex paper](https://arxiv.org/pdf/2107.03374v2.pdf).
41
+
42
+ ## Citation
43
+ ```
44
+ @article{hendrycksapps2021,
45
+ title={Measuring Coding Challenge Competence With APPS},
46
+ author={Dan Hendrycks and Steven Basart and Saurav Kadavath and Mantas Mazeika and Akul Arora and Ethan Guo and Collin Burns and Samir Puranik and Horace He and Dawn Song and Jacob Steinhardt},
47
+ journal={NeurIPS},
48
+ year={2021}
49
+ }
50
+ ```
app.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import evaluate
2
+ from evaluate.utils import launch_gradio_widget
3
+
4
+
5
+ module = evaluate.load("loubnabnl/apps_metric")
6
+ launch_gradio_widget(module)
apps_metric.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Evaluation of code generation on the APPS benchmark"""
15
+
16
+ import evaluate
17
+ import datasets
18
+ from .utils import compute_metrics
19
+ from .testing_util import run_test
20
+
21
+
22
+ _CITATION = """\
23
+ @article{hendrycksapps2021,
24
+ title={Measuring Coding Challenge Competence With APPS},
25
+ author={Dan Hendrycks and Steven Basart and Saurav Kadavath and Mantas Mazeika and Akul Arora and Ethan Guo and Collin Burns and Samir Puranik and Horace He and Dawn Song and Jacob Steinhardt},
26
+ journal={NeurIPS},
27
+ year={2021}
28
+ }
29
+ """
30
+
31
+
32
+ _DESCRIPTION = """\
33
+ This is a metric to evaluate code generation using the APPS benchmark "Measuring Coding Challenge Competence With
34
+ APPS" (https://arxiv.org/pdf/2105.09938.pdf).
35
+ """
36
+
37
+
38
+ # TODO: Add description of the arguments of the module here
39
+ _KWARGS_DESCRIPTION = """
40
+ Computes Average accuracy and strict accuracy for single generations, and pass@k for multiple generations.
41
+ Args:
42
+ predictions: list of code generations to score. It's a list of list(s), each corresponding to a problem from APPS dataset.
43
+
44
+ Returns:
45
+ metrics: dict of three metrics: average accuracy, stric accuracy, and pass@k.
46
+ Examples:
47
+ >>> my_new_module = evaluate.load("loubnabnl/apps_metric")
48
+ >>> results = my_new_module.compute(predictions=[["s=input()\nprint(s)"]])
49
+ >>> print(results)
50
+ {'avg_accuracy': 0, 'strict_accuracy': 0, 'pass_at_k': None}
51
+ """
52
+
53
+
54
+
55
+
56
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
57
+ class apps_metric(evaluate.EvaluationModule):
58
+ """Evaluate code generation on APPS benchmark.
59
+ The generations are compiled and their corresponding unit tests are run"""
60
+
61
+ def _info(self):
62
+
63
+ return evaluate.EvaluationModuleInfo(
64
+
65
+ module_type="metric",
66
+ description=_DESCRIPTION,
67
+ citation=_CITATION,
68
+ inputs_description=_KWARGS_DESCRIPTION,
69
+
70
+ features=datasets.Features({
71
+ 'predictions': datasets.Sequence(datasets.Value("string")),
72
+ }),
73
+ homepage="https://github.com/hendrycks/apps",
74
+ reference_urls=["https://huggingface.co/datasets/codeparrot/apps"]
75
+ )
76
+
77
+
78
+
79
+ def _compute(self, predictions, k_list=[1, 10, 100], count_errors=True, level="all", debug=False):
80
+ """Returns the scores"""
81
+ metrics = compute_metrics(predictions, k_list=k_list, count_errors=count_errors, level=level, debug=debug)
82
+ return metrics
example_script.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This is an example script to evaluate a code generation model on APPS, you can also use the APPS solutions as code generations
2
+ > python example_script.py --model_ckpt MODEL_NAME --num_tasks 10 --difficulty introductory --n_samples 1
3
+ > python example_script.py --use_solutions True --num_tasks 10 --difficulty introductory --n_samples 1"""
4
+
5
+ import json
6
+ import pprint
7
+ from tqdm import tqdm
8
+ from datasets import load_dataset
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed
10
+ from evaluate import load
11
+
12
+ def generate_prompt(sample):
13
+ starter_code = None if len(sample["starter_code"]) == 0 else sample["starter_code"]
14
+ try:
15
+ input_outpout = json.loads(sample["input_output"])
16
+ fn_name = None if not input_outpout.get("fn_name") else input_outpout["fn_name"]
17
+ except ValueError:
18
+ fn_name = None
19
+ _input = "\nQUESTION:\n"
20
+ _input += sample["question"]
21
+ if starter_code:
22
+ _input += starter_code
23
+ if fn_name:
24
+ _input += "\nUse Standard Input format"
25
+ else:
26
+ _input += "\nUse Call-Based format"
27
+
28
+ _input += "\nANSWER:\n"
29
+ return _input
30
+
31
+
32
+ def complete_code(pipe, prompt, num_completions=1, max_length=256, **gen_kwargs):
33
+ """Complete prompt with text generation pipeline and return num_completions."""
34
+ prompt = pipe.tokenizer.eos_token + prompt
35
+ try:
36
+ code_gens = pipe(prompt, num_return_sequences=num_completions, max_length=max_length, **gen_kwargs)
37
+ return [code_gen["generated_text"][len(prompt):] for code_gen in code_gens]
38
+ except IndexError:
39
+ print("prompt is longer than the context size of the model, generation skipped")
40
+ code_gens = ""
41
+ return [""]
42
+
43
+
44
+ def make_generations(dataset, args, model, tokenizer):
45
+ set_seed(args.seed)
46
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int)
47
+
48
+ # Generation settings
49
+ gen_kwargs = {
50
+ "do_sample": args.do_sample,
51
+ "temperature": args.temperature,
52
+ "top_p": args.top_p,
53
+ "top_k": args.top_k
54
+ }
55
+
56
+ # Generate completions for evaluation set
57
+ n_tasks = args.num_tasks if args.num_tasks is not None else len(dataset)
58
+ print(f"ntasks is {n_tasks}")
59
+ generations = []
60
+ for task in tqdm(range(n_tasks)):
61
+ task_generations = []
62
+ prompt = generate_prompt(dataset[task]).strip()
63
+ task_generations.extend(complete_code(pipe, prompt, num_completions=args.n_samples, max_length=args.max_length, **gen_kwargs))
64
+ generations.append([gen.replace(args.eos, "") for gen in task_generations])
65
+ return generations
66
+
67
+
68
+ def main(args):
69
+ DATA_PATH = "codeparrot/apps"
70
+ argsdict = vars(args)
71
+ print(pprint.pformat(argsdict))
72
+
73
+ # setup
74
+ print("Loading evaluation dataset...")
75
+ dataset = load_dataset(DATA_PATH, split="test", difficulties=[args.difficulty])
76
+ if args.use_solutions:
77
+ print("Using data solutions as code generations")
78
+ model = None
79
+ tokenizer = None
80
+ generations = []
81
+ for index in range(args.num_tasks+1):
82
+ try:
83
+ sol = json.loads(dataset[index]["solutions"])
84
+ generations.append(sol[:args.n_solutions])
85
+ except ValueError:
86
+ print(f"No solutions for task {index} or not enough to have {args.n_solutions} solutions")
87
+ break
88
+
89
+ else:
90
+ print("Loading tokenizer and model...")
91
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
92
+ model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
93
+ generations = make_generations(dataset, args, model, tokenizer)
94
+
95
+ metric = load("loubnabnl/apps_metric")
96
+ results = metric.compute(predictions=generations, level=args.difficulty, k_list=args.k_list, count_errors=args.count_errors, debug=args.debug)
97
+ print(results)
98
+ with open(args.output_file, "w") as fp:
99
+ json.dump(results, fp)
100
+
101
+
102
+ if __name__ == "__main__":
103
+ import argparse
104
+
105
+ parser = argparse.ArgumentParser(description="Testing a Language Model on APPS Python Code dataset")
106
+ #model and tokenizer arguments
107
+ parser.add_argument("--model_ckpt", default="loubnabnl/apps-1.5B-model", type=str, help="path to model checkpoint.")
108
+ parser.add_argument("--tokenizer", default="gpt2", type=str, help="tokenizer to use.")
109
+ parser.add_argument("--eos", default="<|endoftext|>", type=str, help="end of sentence token.")
110
+ # generation arguments
111
+ parser.add_argument("--do_sample", default=True, type=bool, help="do sampling in generation")
112
+ parser.add_argument("--temperature", default=0.2, type=float, help="temperature for sampling")
113
+ parser.add_argument("--top_p", default=0.95, type=float, help="top p for sampling")
114
+ parser.add_argument("--top_k", default=0, type=float, help="top k for sampling")
115
+ parser.add_argument("--max_length", default=1024, type=int, help="max length of generated code")
116
+ # evaluation arguments
117
+ parser.add_argument("--difficulty", default="all", type=str, help="difficulty level to select in the dataset from:\
118
+ 'all', 'introductory', 'interview' and 'competition' ")
119
+ parser.add_argument("--num_tasks", default=6, type=int, help="number of tasks to evaluate")
120
+ parser.add_argument("--use_solutions", default=False, type=bool, help="use solutions instead of generating new code")
121
+ parser.add_argument("--n_samples", default=1, type=int, help="number of samples to generate")
122
+ parser.add_argument("--n_solutions", default=1, type=int, help="number of solutions to use")
123
+ parser.add_argument("--k_list", default=[1, 2, 3], type=list, help="list of k values to evaluate pass@k")
124
+ parser.add_argument("--count_errors", default=False, type=bool, help="count compilation and runtime errors for single generations")
125
+ # configuration
126
+ parser.add_argument("--seed", default=0, type=int, help="generation seed")
127
+ parser.add_argument("--device_int", default=-1, type=int, help="device on which code generation is run, if positive use GPU")
128
+ parser.add_argument("--debug", default=False, type=bool, help="debug mode")
129
+ # save
130
+ parser.add_argument("--output_file", default="apps_metrics.json", type=str, help="output file to save the results")
131
+
132
+ args = parser.parse_args()
133
+ main(args)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ evaluate==0.1.0
2
+ datasets~=2.0
3
+ pyext==0.7
test_examples/solutions_problem_1.json ADDED
@@ -0,0 +1 @@
 
 
1
+ ["s = input()\nn = len(s)\nind = -1\nf = False\nfor i in range(n):\n if s[i] == '[':\n f = True\n elif s[i] == ':':\n if f:\n ind = i\n break\nbind = -1\nf = False\nfor i in range(n-1,-1,-1):\n if s[i] == ']':\n f = True\n elif s[i] == ':':\n if f:\n bind = i\n break\n# print(ind,bind)\nif ind == -1 or bind == -1:\n print(-1)\nelif ind >= bind:\n print(-1)\nelse:\n ans = 4\n for i in range(ind+1,bind):\n if s[i] == '|':\n ans += 1\n print(ans)\n", "def main():\n s = input()\n \n if s.count('[') == 0 or s.count(']') == 0:\n print(-1)\n return\n \n t = s[s.find('['):s.rfind(']')+1]\n \n if t.count(':') < 2:\n print(-1)\n return\n \n t = t[t.find(':'):t.rfind(':')+1]\n print(4 + t.count('|'))\n\nmain()", "s = input()\nif '[' in s:\n s = s[s.find('[') + 1:]\n if ']' in s:\n s = s[:s.rfind(']')]\n if s.count(':') >= 2:\n s = s[s.find(':') + 1 : s.rfind(':')]\n print(s.count('|') + 4)\n\n else:\n print(-1)\n else:\n print(-1)\nelse:\n print(-1)", "import sys\ns = input()\nst = s.find('[')\nif st==-1: print((-1)); return\ns = s[st+1:]\n#print(s)\nst = s.find(':')\nif st==-1: print((-1)); return\ns = s[st+1:]\n#print(s)\ns = s[::-1]\nst = s.find(']')\nif st==-1: print((-1)); return\ns = s[st+1:]\n#print(s)\nst = s.find(':')\nif st==-1: print((-1)); return\ns = s[st+1:]\n#print(s)\nx = s.count('|')\nprint(x+4 if x>=0 else -1)\n", "s = input()\n\nsb,eb,sc,ec = -1, -1, -1, -1\n\nfor i in range(len(s)):\n\tif s[i] == '[' and sb == -1:\n\t\tsb = i\n\telif s[i] == ']':\n\t\teb = i\n\telif s[i] == ':' and sc == -1 and sb!=-1:\n\t\tsc = i\n\nif eb <= sb or sc>eb:\n\tprint(-1)\nelif sb ==-1 or eb==-1 or sc==-1:\n\tprint(-1)\nelse:\n\tfor i in range(sc+1, eb):\n\t\tif s[i] == ':':\n\t\t\tec = i\n\tif ec == -1:\n\t\tprint(-1)\n\telse:\n\t\tcnt = 0\n\t\tfor i in range(sc,ec):\n\t\t\tif (s[i] == '|'):\n\t\t\t\tcnt += 1\n\t\tprint(cnt+4)", "s = input()\nt_d = 0\ntry:\n left = -1\n was_b = False\n for i in range(len(s)):\n if s[i] == '[' and not was_b:\n was_b = True\n continue\n if s[i] == ':' and was_b:\n left = i\n break\n t_d += 1\n if left == -1:\n raise ArithmeticError()\n right = -1\n was_b = False\n for i in range(len(s) - 1, -1, -1):\n if s[i] == ']' and not was_b:\n was_b = True\n continue\n if s[i] == ':' and was_b:\n right = i\n break\n t_d += 1\n if right == -1 or right <= left:\n raise ArithmeticError()\n for i in range(left + 1, right):\n if s[i] != '|':\n t_d += 1\n print(len(s) - t_d)\nexcept:\n print(-1)\n \n", "s = input()\n\nmode = 0\nl = len(s)\nr = -1\nfor i in range(len(s)):\n if mode == 0:\n if s[i] == \"[\":\n mode = 1\n if mode == 1:\n if s[i] == \":\":\n l = i\n break\n\nmode = 0\nfor i in range(len(s)-1, -1, -1):\n if mode == 0:\n if s[i] == \"]\":\n mode = 1\n if mode == 1:\n if s[i] == \":\":\n r = i\n break\n \nif l >= r:\n print(-1)\nelse:\n c = 0\n for i in range(l+1, r):\n if s[i] == \"|\":\n c += 1\n print(c+4)\n", "s = input()\n\nf1 = False\nf2 = False\nl1 = -1\nfor l in range(len(s)):\n if f1 == False and s[l] == '[':\n f1 = True\n elif f1 == True and s[l] == ':':\n f2 = True\n l1 = l\n break\ng1 = False\ng2 = False\nr1 = -1\nfor r in range(len(s) - 1, -1, -1):\n if g1 == False and s[r] == ']':\n g1 = True\n elif g1 == True and s[r] == ':':\n g2 = True\n r1 = r\n break\nif (l1 == -1 or r1 == -1) or (r1 <= l1):\n print(-1)\n \nelse:\n ans = 4\n for i in range(l1 + 1, r1):\n if s[i] == '|': ans += 1\n print(ans)", "s=input()\npos1=-1\npos2=-1\npos3=-1\npos4=-1\nfor i in range(0,len(s)):\n if(s[i]=='['):\n pos1=i\n break\nfor i in range(len(s)-1,pos1,-1):\n if(s[i]==']'):\n pos2=i\n break\nfor i in range(pos1,pos2+1):\n if(s[i]==':'):\n pos3=i\n break\nfor i in range(pos2,pos3,-1):\n if(s[i]==':'):\n pos4=i\n break\n \nif(pos1==-1 or pos2==-1 or pos3==-1 or pos4==-1 or len(s)<4):\n print('-1')\nelse:\n c=0\n for j in range(pos3,pos4):\n if(s[j]=='|'):\n c=c+1\n print(c+4)\n", "def ii():\n return int(input())\ndef mi():\n return list(map(int, input().split()))\ndef li():\n return list(mi())\n\ns = input().strip()\nn = len(s)\nans = -1\nfb = s.find('[')\nif fb >= 0:\n fc = s.find(':', fb)\n if fc >= 0:\n lb = s.rfind(']')\n if lb > fc:\n lc = s.rfind(':', 0, lb)\n if lc > fc:\n ans = 4 + s[fc:lc].count('|')\nprint(ans)\n", "s = input()\n\ndef sovle(s):\n\n i1 = s.find('[')\n if i1 == -1:\n return -1\n s = s[i1+1:]\n i2 = s.find(':')\n if i2 == -1:\n return -1\n\n s = s[i2+1 :]\n i1 = s.rfind(']')\n if i1 == -1:\n return -1\n s = s[:i1]\n i2 = s.rfind(':')\n if i2 == -1:\n return -1\n s = s[:i2]\n x = s.count('|')\n return x+4\n\nprint(sovle(s))", "def solve(s):\n if s.find('[') == -1:\n return -1\n s = s[s.find('['):]\n #print(s)\n if s.find(':') == -1:\n return -1\n s = s[s.find(':') + 1:]\n #print(s)\n if s.find(']') == -1:\n return -1\n s = s[:s.rfind(']')]\n #print(s)\n if s.find(':') == -1:\n return -1\n s = s[:s.rfind(':')]\n #print(s)\n return s.count('|') + 4\n\ns = input()\nprint(solve(s))", "s=input()\ni=s.find('[')\nif i==-1:\n print(-1)\n return\ns=s[i:]\ni=s.rfind(']')\n\nif i==-1:\n print(-1)\n return\ns=s[:i+1]\nl,h=0,0\nfor i,d in enumerate(s):\n if d==':':\n l=i\n break\nfor i,d in enumerate(s):\n if d==':':\n h=i\nif l==h:\n print(-1)\n return\nc=0\nfor i in range(l+1,h):\n if s[i]=='|':\n c+=1\nprint(c+4)\n", "from sys import stdin\ns=stdin.readline().strip()\nx=-1\nfor i in range(len(s)):\n if s[i]==\"[\":\n x=i\n break\ny=-1\nfor i in range(len(s)-1,-1,-1):\n if s[i]==\"]\":\n y=i\n break\nif x==-1 or y==-1 or y<x:\n print(-1)\n return\nx1=-1\nfor i in range(x,y):\n if s[i]==\":\":\n x1=i\n break\ny1=-1\nfor i in range(y-1,x,-1):\n if s[i]==\":\":\n y1=i\n break\nif x1==-1 or y1==-1 or y1<=x1:\n print(-1)\n return\nans=4\nfor i in range(x1,y1):\n if s[i]==\"|\":\n ans+=1\nprint(ans)\n", "s = str(input().strip())\ni = 0\nn = len(s)\nwhile i < n and s[i] != '[':\n i+=1\nif(i == n):\n print(-1)\n return\nj = n-1\nwhile j > i and s[j] != ']':\n j-=1\nif(j <= i):\n print(-1)\n return\nwhile i < j and s[i] != ':':\n i+=1\nif(i == j):\n print(-1)\n return\nwhile j > i and s[j] != ':':\n j-=1\nif(j == i):\n print(-1)\n return\nk = i+1\nc = 0\nwhile k < j:\n if(s[k] == '|'):\n c+=1\n k+=1\nprint(c+4)\n", "import sys\ns = input()\nl = len(s)\ns_list = [x for x in s]\n\ncounter = 0\ntry:\n\ta = s_list.index('[')\n\tcounter += a\n\ts_list = s_list[a + 1:]\nexcept:\n\tprint(-1)\n\treturn\n\ntry:\n\ta = s_list.index(':')\n\tcounter += a\n\ts_list = s_list[a + 1:]\nexcept:\n\tprint(-1)\n\treturn\n\ns_list_rev = s_list.copy()\ns_list_rev.reverse()\n\ntry:\n\tb = s_list_rev.index(']')\n\tcounter += b\n\ts_list_rev = s_list_rev[b+1:]\nexcept:\n\tprint(-1)\n\treturn\n\ntry:\n\tb = s_list_rev.index(':')\n\tcounter += b\n\ts_list_rev = s_list_rev[b+1:]\nexcept:\n\tprint(-1)\n\treturn\ns_list_rev = [x for x in s_list_rev if x != '|']\ncounter += len(s_list_rev)\nprint(l - counter)", "MOD = 10**9 + 7\nI = lambda:list(map(int,input().split()))\n\ns = input()\nres = 0\nn = len(s)\nst = -1\ne = -1\nfor i in range(n):\n if s[i] == '[':\n st = i\n break\nfor i in range(n-1, -1, -1):\n if s[i] == ']':\n e = i\n break\n# print(st , e)\nif st > e or st == -1 or e == -1:\n print(-1)\n return\na = -1\nb = -1\nfor i in range(st, e):\n if s[i] == ':':\n a = i\n break\nfor i in range(e, st, -1):\n if s[i] == ':':\n b = i\n break\nif a == b or a == -1 or b == -1:\n print(-1)\n return\ncount = 0\nfor i in range(a, b):\n if s[i] == '|':\n count += 1\nprint(4 + count)", "s=input()\nst=\"\"\nidx=-1\nfor i in range(len(s)):\n if s[i]=='[':\n idx=i\n break\nif idx==-1:\n print(-1)\n return\nidxl=-1\nfor i in range(len(s)-1,-1,-1):\n if s[i]==']' and i>idx:\n idxl=i\n break\nif idxl==-1:\n print(-1)\n return\ncol=col2=-1\nfor i in range(len(s)):\n if s[i]==':' and i>idx and i<idxl:\n col=i\n break\nif col==-1:\n print(-1)\n return\nfor i in range(len(s)-1,-1,-1):\n if s[i]==':' and i>col and i<idxl:\n col2=i\n break\nif col2==-1:\n print(-1)\n return\nans=0\nfor i in range(col+1,col2):\n if s[i]=='|':\n ans+=1\nprint(4+ans)\n \n\n\n", "s = input()\nrev = s[::-1]\n\nleft = s.find(\"[\")\nif left != -1:\n left = s.find(\":\", left)\n\nright = rev.find(\"]\")\nif right != -1:\n right = rev.find(\":\", right)\n\nif left == -1 or right == -1:\n print(-1)\n return\nright = len(s)-right-1\nif left >= right:\n print(-1)\n return\n\nprint(4 + s[left:right].count(\"|\"))\n", "def ba(s):\n c1 = s.find('[')\n c2 = s.find(':', c1+1)\n c3 = s.rfind(']', c2+1)\n c4 = s.rfind(':', c2+1, c3)\n if -1 in [c1, c2, c3, c4]:\n return -1\n return s.count('|', c2, c4)+4\n\n\nprint(ba(input()))\n\n", "s = input()\nif '[' in s and ']' in s:\n a = s.index('[') + 1\n b = len(s)-s[::-1].index(']') - 1\nelse:\n print(-1)\n return\ns = s[a:b]\nif s.count(':') >= 2:\n a = s.index(':')+1\n b = len(s)-s[::-1].index(':')-1\nelse:\n print(-1)\n return\nc = 0\nfor el in s[a:b]:\n if el =='|':\n c += 1\nprint(4 + c)", "s = input()\n\nb = [0]*len(s)\n\nob = 0\ncc = 0\np = -1\nq = -1\n\ncount = 0\n\nfor ind,c in enumerate(s):\n if c == '[':\n ob = 1\n elif c == ':' and p >= 0:\n q = ind\n elif c == ':' and ob == 1 and p < 0:\n p = ind\n elif c == ']' and q >= 0:\n cc = q\n elif c == '|':\n count += 1\n b[ind] = count\n\nif cc > 0:\n print( 4 + b[cc]-b[p])\nelse:\n print(-1)\n", "s = input()\nif '[' in s and ']' in s and ':' in s:\n e = s.count(':')\n if e<2:\n print(-1)\n else:\n a = s.index('[')\n b = len(s)-1-s[::-1].index(']')\n if b<a:\n print(-1)\n else:\n if s[a+1:b].count(':')<2:\n print(-1)\n else:\n st1 = True\n count = 0\n for i in range(a+1, b):\n if st1 and s[i]==':':\n pos1 = i\n st1 = False\n if s[i]==':':\n pos2 = i\n \n for i in range(pos1+1, pos2):\n if s[i]=='|':\n count+=1\n \n print(count+4)\nelse:\n print(-1) ", "s=input()\ni1=-1\ni2=-1\nk1=-1\nk2=-1\nc=0\nfor i in range(len(s)):\n if(s[i]=='['):\n i1=i\n break\nfor i in range(len(s)-1,-1,-1):\n if(s[i]==']'):\n i2=i\n break\nfor i in range(i1,i2+1):\n if(s[i]==':'):\n k1=i\n break\nfor i in range(i2,i1-1,-1):\n if(s[i]==':'):\n k2=i\n break\nfor i in range(k1,k2+1):\n if(s[i]=='|'):\n c+=1\n\nif(i1==-1 or i2==-1 or i1>=i2 or k1==-1 or k2==-1 or k1==k2):\n print(-1)\nelse:\n print(4+c)", "s = input()\nl = 0\nend = 0\ni = 1\n\nwhile i <= len(s):\n if l == 0 and s[-i] == ']':\n l += 1\n elif l == 1 and s[-i] == ':':\n l += 1\n end = len(s) - i\n break\n i += 1\n\nif l < 2:\n print(-1)\n return\n\nfor i in range(0, end):\n if l >= 4 and s[i] == '|':\n l += 1\n elif l == 2 and s[i] == '[':\n l += 1\n elif l == 3 and s[i] == ':':\n l += 1\n\nif l >= 4:\n print(l)\nelse:\n print(-1)"]
test_examples/solutions_problem_2.json ADDED
@@ -0,0 +1 @@
 
 
1
+ ["num = list(map(int, input()))\nbest = num[:]\nfor i in range(-1, -len(num) - 1, -1):\n if num[i] == 0:\n continue\n num[i] -= 1\n for j in range(i + 1, 0):\n num[j] = 9\n if sum(num) > sum(best):\n best = num[:]\ns = ''.join(map(str, best)).lstrip('0')\nprint(s)\n", "s_num = input()\nnum = int(s_num)\ndigs = [int(s_num[i]) for i in range(len(s_num))]\n\nmax_sum = sum(digs)\nres = num\nfor i in range(len(s_num)):\n if (digs[i] != 0):\n digs[i] -= 1\n n_sum = sum(digs[:i + 1]) + 9 * (len(s_num) - i - 1)\n if n_sum >= max_sum:\n n_res = int(''.join([str(digs[i]) for i in range(i + 1)]) + '9' * (len(s_num) - i - 1))\n if (n_sum == max_sum):\n res = max(n_res, res)\n else:\n res = n_res\n max_sum = n_sum\n\n digs[i] += 1\nprint(res)\n", "a=int(input())\nif(a//10==0):\n print(a)\n return\nk=9\nwhile(k<a):\n k=k*10+9\nif(k==a):\n print(k)\nelse:\n k//=10\n k=int(str(a)[0]+str(k))\n i=len(str(k))-1\n z=k\n while(z>a):\n z=int(str(k)[0:i]+str(int(str(k)[i])-1)+str(k)[i+1:len(str(k))])\n i-=1\n print(z) ", "x = int(input())\nif x < 10:\n print(x)\nelif x == int(str(x)[0] + '9'*(len(str(x))-1)):\n print(x)\nelse:\n a = str(x)[0] + '9' * (len(str(x)) - 1)\n a = list(a)\n for i in range(len(a) - 1, -1, -1):\n k = a[i]\n a[i] = str(int(a[i]) - 1)\n if x >= int(''.join(a)):\n print(int(''.join(a)))\n break\n a[i] = k\n", "def sum_str(y):\n return sum(map(int, str(y)))\n\n\nx = input()\nlength = len(x)\nbad_answer = str(int(x[0]) - 1) + '9' * (length - 1) \ntotal = sum_str(bad_answer)\n\n\nif length == 1 or sum_str(x) >= total:\n print(x)\nelse:\n for i in range(length - 1, 0, -1):\n new_total = 9 * (length - i)\n new_answer = str(int(x[:i]) - 1)\n new_total += sum_str(new_answer)\n\n if new_total >= total:\n new_answer = new_answer if new_answer != '0' else ''\n print(new_answer + '9' * (length - i))\n break\n else:\n print(bad_answer)\n", "import sys\n\ndef calc(s):\n res =0\n for c in s:\n res+= int(c)\n return res\n\n\ns = list(sys.stdin.readline().rstrip())\nbest = \"\".join(s) \ncount = calc(s)\n\ni = len(s)-1\nwhile i!=0:\n i-=1\n if s[i+1]!= '9':\n s[i+1] = '9'\n while s[i]=='0':\n s[i]='9'\n i-=1\n s[i] = chr(ord(s[i])-1)\n c = calc(s)\n if count < c:\n count = c\n best = \"\".join(s)\n\nif best[0] == '0':\n best = best[1:]\n\nprint(best)", "x = input()\nn = len(x)\nif n == 1:\n print(x)\n return\nans = \"\"\ns = 0\nps = 0\npn = \"\"\nfor i in range(n):\n ts = ps + int(x[i]) - 1 + 9 * (n - i - 1)\n if ts >= s:\n ans = pn + str(int(x[i]) - 1) + \"9\" * (n - i - 1)\n s = ts\n ps += int(x[i])\n pn += x[i]\nif ps >= s:\n ans = pn\nprint(int(ans))", "n = int(input())\n\ndef f(numb):\n lst = [numb]\n cap = 10\n\n while numb // cap > 0:\n lst.append((numb // cap - 1) * cap + cap - 1)\n cap *= 10\n\n return lst\n\ndef g(numb):\n lst = []\n while numb != 0:\n lst.append(numb % 10)\n numb //= 10\n\n return lst\n\n\nmaximum = max([sum(g(i)) for i in f(n)])\n\nmaximum = [i for i in f(n) if maximum == sum(g(i))]\n\nprint(max(maximum))", "\"\"\" Created by Shahen Kosyan on 3/11/17 \"\"\"\n\ndef __starting_point():\n x = input()\n\n if int(x) < 10:\n print(x)\n return\n\n arr = [int(a) for a in list(x)]\n x_sum = sum(arr)\n\n i = len(arr) - 1\n answer = ''\n while i > 0:\n if arr[i] != 9 and arr[i] != 8:\n arr[i - 1] -= 1\n answer = '9' + answer\n else:\n change = False\n for j in range(i - 1, 0, -1):\n if arr[j] < 9:\n change = True\n break\n\n if arr[i] == 8 and change:\n answer = '9' + answer\n arr[i - 1] -= 1\n else:\n if not change:\n answer = str(arr[i]) + answer\n else:\n answer = '9' + answer\n\n if i == 1 and arr[0] != 0:\n answer = str(arr[0]) + answer\n i -= 1\n\n answer = [int(a) for a in list(answer)]\n if x_sum == sum(answer):\n print(x)\n else:\n answer = [str(a) for a in answer]\n print(''.join(answer))\n\n__starting_point()", "x=input()\nl=len(x)\nx=int(x)\ns='9'*l\nsx=str(x)\nm=int(s)\nc=0\nwhile c!=1:\n if m>x:\n m=m-10**(l-1)\n else:\n c=1\nsm=str(m)\nmm=[] \nfor i in range(len(sm)):\n mm.append(int(sm[i]))\nxx=[] \nfor i in range(l):\n xx.append(int(sx[i]))\nif m==x:\n print(m)\nelif sum(xx)==sum(mm):\n print(x)\nelse:\n k=len(xx)-1\n while k>=0:\n if sum(xx)<sum(mm):\n if xx[k]==9:\n k-=1\n else:\n xx[k]=9\n xx[k-1]-=1\n k-=1\n else:\n if xx[0]==0:\n xx.remove(0)\n for b in range(len(xx)):\n xx[b]=str(xx[b])\n ww=''.join(xx)\n print(ww)\n break", "x = input()\nvariants = [x] + [str(int(x[:i]) - 1) +\n '9' * (len(x) - i) for i in range(1, len(x))]\nprint(int(max(variants, key=lambda x: (sum(map(int, x)), int(x)))))\n", "def sum_div(n):\n summa = 0\n while n > 0:\n summa = summa + n % 10\n n = n // 10\n return summa\n\n\ndef run(n):\n l_n = len(n)\n left = ''\n if l_n > 2 and '9' * l_n != n and n[1] == '9' and '9' * (l_n - 1) != n[1:]:\n left = n[0]\n n = n[1:]\n while l_n > 1 and n[1] == '9':\n left += n[1]\n n = n[1:]\n l_n = len(n)\n l_n = len(n)\n if len(n) == 1:\n return n\n elif '9' * (l_n - 1) == n[1:]:\n return left + n\n elif n[0] != '1':\n min_number = int(str(int(n[0]) - 1) + '9' * (l_n - 1))\n if sum_div(min_number) > sum_div(int(n)):\n return left + str(min_number)\n else:\n return left + n\n else:\n min_number = int('9' * (l_n - 1)) if l_n > 1 else 0\n if sum_div(min_number) > sum_div(int(n)):\n return left + str(min_number)\n else:\n return left + n\n\n\nn = input()\nprint(run(n))\n", "#This code is dedicated to Olya S.\n\ndef e(x):\n s=0\n while x>0:\n s+=x%10\n x//=10\n return s\n\ndef down(x):\n l=len(x)-1\n return str(int(x[0])-1)+'9'*l\n\nn=input()\nif len(n)>1 and n[1]=='9':\n print(n[0],end='')\n n=n[1:]\n while len(n)>1 and n[0]=='9' and n[1]=='9':\n print('9',end='')\n n=n[1:]\n\nif e(int(n))>=e(int(down(n))):\n print(n)\nelse:\n print(int(down(n)))\n\n \n \n\n\n\n \n\n", "def sum_n(n):\n l = len(n)\n\n summ = 0\n for i in range(l):\n summ += int(n[i])\n\n return summ\n\ndef transfer(x, i):\n x = list(x)\n \n x[i+1] = '9'\n if x[i] != '0':\n x[i] = str(int(x[i])-1)\n else:\n j = i\n while (j > 0) and (int(x[j]) == 0):\n x[j] = '9'\n j -= 1\n x[j] = str(int(x[j])-1)\n if (x[0] == '0'):\n del x[0]\n\n return x\n\nx = list(input())\nmax_cifr = sum_n(x)\nmaxnum = x\nres = ''\n\nfor i in range(len(x)-2, -1, -1):\n x = transfer(x, i)\n if(max_cifr < sum_n(x)):\n max_cifr = sum_n(x)\n maxnum = x\n\nfor i in range(len(maxnum)):\n res = res+maxnum[i]\n \nprint(res)\n", "x = input()\nsum = 0\nfor i in x:\n temp = int(i)\n sum += temp\n\nxlen = len(x)\none = int(x[0])\ntry:\n two = int(x[1])\nexcept:\n two = 0\n\nif (two == 9):\n count = 1\n for i in range(1, xlen):\n z = int(x[i])\n if (z == 9):\n count = i\n else:\n break\n answ = x[0:count] + \"8\" + (\"9\" * (xlen - count - 1))\nelif (one == 1):\n answ = '9' * (xlen - 1)\nelse:\n answ = str((one - 1)) + (\"9\" * (xlen-1))\n\nansw = str(answ)\nsumansw = 0\nfor i in answ:\n temp = int(i)\n sumansw += temp\n\nif (sum >= sumansw):\n print(x)\nelse:\n print(answ)", "def sum1(x): # \u043f\u043e\u0434\u0441\u0447\u0451\u0442 \u0441\u0443\u043c\u043c\u044b \u0446\u0438\u0444\u0440 \u0447\u0438\u0441\u043b\u0430 x\n summa = 0\n for i in x:\n summa += int(i)\n return summa\n\n\nx = input()\nc = sum1(x)\nresult = int(x)\nn = len(x) - 1\nj = n\nfor i in range(0, n):\n if x[i] != '0':\n ni = int(x[i]) - 1 # \u0443\u043c\u0435\u043d\u044c\u0448\u0430\u044e i-\u044b\u0439 \u0440\u0430\u0437\u0440\u044f\u0434 \u043d\u0430 1\n xi = x[0:i] + str(ni) + '9' * j # \u0441\u0442\u0440\u043e\u044e \u043d\u043e\u0432\u043e\u0435 \u0447\u0438\u0441\u043b\u043e\n j -= 1\n ci = sum1(xi)\n if c < ci:\n c = ci\n result = int(xi)\n elif c == ci and result < int(xi):\n result = int(xi)\n else:\n j -= 1\n continue\nprint(result)\n", "def f(n, k):\n n = str(n)\n if n[k] == \"0\":\n return f(n, k - 1)\n a = []\n for i in n:\n a.append(int(i))\n n = a\n n[k] = int(n[k]) - 1\n n[k + 1::] = [9] * (len(n) - k - 1)\n return n\na = input()\nn = len(a)\nans = [int(x) for x in a]\nms = sum(ans)\nfor i in range(0, n):\n ca = f(a, i)\n cs = sum(ca)\n if cs> ms:\n ans = ca\n ms = cs\n elif cs == ms:\n if int(''.join([str(_) for _ in ca])) > int(''.join([str(_) for _ in ans])):\n ans = ca\nprint(int(''.join([str(_) for _ in ans])))", "n = int(input().strip())\n\ns = []\nwhile n > 0:\n s.append(n % 10)\n n //= 10\ns = s[::-1]\n\nn = len(s)\nans = 0\nbest = -1\nfor i in range(n):\n res = sum(s[:i + 1]) - 1 + 9 * (n - i - 1)\n if res >= ans:\n ans = res\n best = i\n\ndef get(s, pos):\n ans = 0\n for i in range(len(s)):\n if i > pos:\n ans = ans * 10 + 9\n else:\n ans = ans * 10 + s[i]\n if i == pos:\n ans -= 1\n return ans\n\nif sum(s) >= ans:\n print(get(s, n))\nelse:\n print(get(s, best))\n\n", "def main():\n\n\tdef sum(x):\n\t\tres = 0\n\n\t\twhile x > 0:\n\t\t\tres += x % 10\n\t\t\tx //= 10\n\n\t\treturn res\n\n\tn = input()\n\tfirst = n[0]\n\tp = [1]\n\n\tfor i in range(1, 20):\n\t\tp.append(p[-1] * 10)\n\n\tdata = []\t\n\tfor i in range(len(n)):\n\t\tif i > 0 and n[i] == '0':\n\t\t\tcontinue\n\t\ttemp = n[:i] + str(max(0, int(n[i]) - 1)) + \"9\"* (len(n) - i - 1)\n\t\tdata.append((sum(int(temp)), int(temp)))\n\n\tdata.append((sum(int(n)), int(n)))\n\t\n\tdata.sort(reverse=True)\n\n\tprint(data[0][1])\n\n\treturn\n\ndef __starting_point():\n\tmain()\n__starting_point()", "def cnt_sum(str_num):\n\tsum = 0\n\tfor a in str_num:\n\t\tsum += ord(a) - ord('0')\n\treturn sum\n\nstr_a = input().strip()\nmax_sum = cnt_sum(str_a)\nans = str_a\ncnt_digit = len(str_a)\n\nfor i in range(cnt_digit - 1, -1, -1):\n\tif str_a[i] != '0':\n\t\tnew_str = str_a[:i] + chr(ord(str_a[i]) - 1) + '9'*(cnt_digit - i - 1)\n\t\tcur_sum = cnt_sum(new_str)\n\t\tif cur_sum > max_sum:\n\t\t\tmax_sum = cur_sum\n\t\t\tans = new_str\n\nprint(int(ans))\n", "def summaX(x):\n k=0\n for el in x:\n k+=int(el)\n return k\nn=input();N=[];Z=[]\nfor el in n:\n N.append(el)\nz=summaX(N)\nZ=N.copy()\nfor i in range(1,len(N)):\n if int(N[i])!=9:\n N[i-1]=int(N[i-1])-1\n for j in range(i,len(n)):\n N[j]=9\nif z>=summaX(N):\n for el in Z:\n print(el,end='')\nelse:\n if N[0]==0:\n N.pop(0)\n for el in N:\n print(el,end='')\n", "n = int(input())\n\ndef sumd(n):\n\tj = n\n\tsumn = 0\n\twhile j:\n\t\tsumn += j % 10\n\t\tj //= 10\n\treturn sumn\n\nj = n\nstrn = str(n)\nl = len(strn)\nsumn = sumd(n)\n\nstra = [i for i in str(n)]\ni = 1\nwhile i < l and stra[i] == '9':\n\ti += 1\nif (i != l):\n\tstra[i - 1] = str(int(stra[i - 1]) - 1)\n\twhile i < l:\n\t\tstra[i] = '9'\n\t\ti += 1\n\nss = ''\nfor i in range(l):\n\tss += stra[i]\nif ss[0] == '0':\n\tss = ss[1:]\nsn = int(ss)\n\nif sn < n and sumd(sn) <= sumn:\n\tss = strn\n\tsn = n\n\nprint(ss)\n", "from random import randint\n\ndef f(s):\n a = 0\n for i in s:\n a += int(i)\n return a\n\ndef solve(n):\n n1 = list(str(n))\n ans = 0\n maxx = 0\n for i in range(len(n1)):\n n2 = n1[:i] + [str(int(n1[i]) - 1)] + ['9' for j in range(len(n1) - i - 1)]\n if f(n2) >= maxx:\n maxx = f(n2)\n ans = n2\n if f(n1) >= maxx:\n maxx = f(n1)\n ans = n1\n return [int(''.join(ans)), maxx]\n\ndef tl(n):\n ans = 0\n maxx = 0\n for i in range(1, n + 1):\n if f(list(str(i))) >= maxx:\n maxx = f(list(str(i)))\n ans = i\n return [ans, maxx]\n\n'''for kkk in range(100):\n n = randint(1, 10 ** 5)\n c1 = solve(n)\n c2 = tl(n)\n if c1 != c2:\n print(n)\n print(c1)\n print(c2)\nprint('ok')'''\nn = int(input())\nprint(solve(n)[0])\n", "a = [1, 2, 3, 4, 5, 6, 7, 8, 9]\nfor length in range(2, 30):\n for first in range(1, 10):\n for pos in range(1, length):\n a.append(int(str(first) + '9' * (pos - 1) + '8' + '9' * (length - pos - 1)))\n a.append(int(str(first) + '9' * (length - 1)))\n \nn = int(input())\nl = 0\nr = len(a)\nwhile l < r - 1:\n middle = (l + r) // 2\n if (a[middle] <= n):\n l = middle\n else:\n r = middle\n \nprint(a[l])", "def get(s):\n ans = 0\n for i in s:\n ans += (ord(i) - ord('0'))\n return ans\n\n\ndef solve1():\n x = input()\n n = len(x)\n best_ans = x\n best_val = get(x)\n ans = str('' if int(x[0]) - 1 == 0 else int(x[0]) - 1) + '9' * (n - 1)\n if get(ans) > best_val or (get(ans) >= best_val and int(ans) > int(best_ans)):\n best_ans = ans\n best_val = get(ans)\n for i in range(1, n):\n #print(ans)\n ans = x[:i] + str(int(x[i]) - 1) + '9' * (n - i - 1)\n if get(ans) > best_val or (get(ans) >= best_val and int(ans) > int(best_ans)):\n best_ans = ans\n best_val = get(ans)\n return best_ans\n \nbest = [0] * 10000\ndef solve2():\n nonlocal best\n was = 0\n for i in range(1, 10000):\n if get(str(i)) >= was:\n best[i] = i\n was = get(str(i))\n else:\n best[i] = best[i - 1]\n \ndef stress():\n solve2()\n for i in range(1, 10000):\n if int(solve1(str(i))) != best[i]:\n print(i, best[i], solve1(str(i)))\n\n#stress()\nprint(solve1())"]
testing_util.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+ import faulthandler
4
+ import platform
5
+
6
+ # used for debugging to time steps
7
+ from datetime import datetime
8
+
9
+ # to run the solution files we're using a timing based approach
10
+ import signal
11
+
12
+ import numpy as np
13
+ # for capturing the stdout
14
+ from io import StringIO
15
+ # used for testing the code that reads from input
16
+ from unittest.mock import patch, mock_open
17
+
18
+ from pyext import RuntimeModule
19
+
20
+ from enum import Enum
21
+ class CODE_TYPE(Enum):
22
+ call_based = 0
23
+ standard_input = 1
24
+
25
+ # stuff for setting up signal timer
26
+ class TimeoutException(Exception):
27
+ pass
28
+ def timeout_handler(signum, frame):
29
+ print("alarm went off")
30
+ #return
31
+ raise TimeoutException
32
+ signal.signal(signal.SIGALRM, timeout_handler)
33
+ timeout = 4 # seconds
34
+
35
+ # used to capture stdout as a list
36
+ # from https://stackoverflow.com/a/16571630/6416660
37
+ # alternative use redirect_stdout() from contextlib
38
+ class Capturing(list):
39
+ def __enter__(self):
40
+ self._stdout = sys.stdout
41
+ sys.stdout = self._stringio = StringIO()
42
+ # Make closing the StringIO a no-op
43
+ self._stringio.close = lambda x: 1
44
+ return self
45
+ def __exit__(self, *args):
46
+ self.extend(self._stringio.getvalue().splitlines())
47
+ del self._stringio # free up some memory
48
+ sys.stdout = self._stdout
49
+
50
+
51
+ def run_test(sample, test=None, debug=False):
52
+ """
53
+ if test(generated_code) is not None it'll try to run the code.
54
+ otherwise it'll just return an input and output pair.
55
+ """
56
+ # Disable functionalities that can make destructive changes to the test.
57
+ reliability_guard()
58
+
59
+ if debug:
60
+ print(f"start = {datetime.now().time()}")
61
+
62
+ try:
63
+ in_outs = json.loads(sample["input_output"])
64
+ except ValueError:
65
+ in_outs = None
66
+ if in_outs:
67
+ if in_outs.get("fn_name") is None:
68
+ which_type = CODE_TYPE.standard_input # Standard input
69
+ method_name = None
70
+ else:
71
+ which_type = CODE_TYPE.call_based # Call-based
72
+ method_name = in_outs["fn_name"]
73
+
74
+ if debug:
75
+ print(f"loaded input_output = {datetime.now().time()}")
76
+
77
+ if test is None:
78
+ return in_outs
79
+ elif test is not None:
80
+ results = []
81
+ sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n"
82
+ if debug:
83
+ print(f"loading test code = {datetime.now().time()}")
84
+
85
+ if which_type == CODE_TYPE.call_based:
86
+ sol += test
87
+ if debug:
88
+ print(f"sol = {sol}")
89
+ signal.alarm(timeout)
90
+ try:
91
+ tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
92
+ if "class Solution" not in test:
93
+ tmp = tmp_sol
94
+ else:
95
+ tmp = tmp_sol.Solution()
96
+ signal.alarm(0)
97
+ except Exception as e:
98
+ signal.alarm(0)
99
+ if debug:
100
+ print(f"type 0 compilation error = {e}")
101
+ results.append(-2)
102
+ return results
103
+ signal.alarm(0)
104
+
105
+ elif which_type == CODE_TYPE.standard_input:
106
+ # sol
107
+ tmp_test = test.split("\n")
108
+
109
+ new_test = []
110
+ for x in tmp_test:
111
+ if (not x.startswith("from ")) and (not x.startswith("import ")):
112
+ new_test.append("\t" + x + "\n")
113
+ else:
114
+ new_test.append(x + "\n")
115
+ tmp_test = new_test
116
+
117
+ new_test = ""
118
+ started = False
119
+ for i in tmp_test:
120
+ if i.startswith("\t") and not started:
121
+ new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
122
+ new_test += "def code():\n"
123
+ new_test += i
124
+ started = True
125
+ elif started and ((i.startswith("from ")) or (i.startswith("import "))):
126
+ new_test += "\t" + i
127
+ else:
128
+ new_test += i
129
+ tmp_test = new_test
130
+
131
+ sol += tmp_test
132
+ if debug:
133
+ print(f"sol = {sol}")
134
+ method_name = "code"
135
+ signal.alarm(timeout)
136
+ try:
137
+ tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
138
+ tmp = tmp_sol
139
+ signal.alarm(0)
140
+ except Exception as e:
141
+ signal.alarm(0)
142
+ if debug:
143
+ print(f"type 1 compilation error = {e}")
144
+ results.append(-2)
145
+ return results
146
+ signal.alarm(0)
147
+ if debug:
148
+ print(f"get method = {datetime.now().time()}")
149
+
150
+ try:
151
+ method = getattr(tmp, method_name) # get_attr second arg must be str
152
+ except:
153
+ signal.alarm(0)
154
+ e = sys.exc_info()
155
+ print(f"unable to get function error = {e}")
156
+ results.append(-2)
157
+ return results
158
+
159
+ for index, inputs in enumerate(in_outs["inputs"]):
160
+ # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
161
+ try:
162
+ if isinstance(inputs[0], dict):
163
+ inputs = [{int(k): v for k,v in inputs[0].items()}]
164
+ except:
165
+ True
166
+ try:
167
+ if isinstance(in_outs["outputs"][index], dict):
168
+ in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index].items()}]
169
+ except:
170
+ True
171
+ try:
172
+ if isinstance(in_outs["outputs"][index][0], dict):
173
+ in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index][0].items()}]
174
+ except:
175
+ True
176
+
177
+ if debug:
178
+ print(f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}")
179
+ if which_type == CODE_TYPE.call_based: # Call-based
180
+ signal.alarm(timeout)
181
+ faulthandler.enable()
182
+ try:
183
+ output = method(*inputs)
184
+
185
+ # ground truth sequences are not tuples
186
+ if isinstance(output, tuple):
187
+ output = list(output)
188
+
189
+ tmp_result = output == in_outs["outputs"][index]
190
+ if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]:
191
+ tmp_result = tmp_result or (output == in_outs["outputs"][index][0])
192
+
193
+ # ground truth sequences are not tuples
194
+ try:
195
+ if isinstance(output[0], tuple):
196
+ tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0])
197
+ except:
198
+ True
199
+ results.append(tmp_result)
200
+
201
+ # reset the alarm
202
+ signal.alarm(0)
203
+ except Exception as e:
204
+ signal.alarm(0)
205
+ faulthandler.disable()
206
+ if debug:
207
+ print(f"Standard input runtime error or time limit exceeded error = {e}")
208
+ results.append(-1)
209
+ continue
210
+ faulthandler.disable()
211
+ signal.alarm(0)
212
+ if debug:
213
+ print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
214
+ elif which_type == CODE_TYPE.standard_input: # Standard input
215
+ faulthandler.enable()
216
+ signal.alarm(timeout)
217
+ passed = False
218
+
219
+ if isinstance(inputs, list):
220
+ inputs = "\n".join(inputs)
221
+ if isinstance(in_outs['outputs'][index], list):
222
+ in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index])
223
+
224
+ with Capturing() as output:
225
+ try:
226
+ call_method(method, inputs)
227
+ # reset the alarm
228
+ signal.alarm(0)
229
+ passed = True
230
+ except Exception as e:
231
+ # runtime error or took too long
232
+ signal.alarm(0)
233
+ print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
234
+ results.append(-1)
235
+ signal.alarm(0)
236
+
237
+ if not passed:
238
+ if debug:
239
+ nl = "\n"
240
+ if not isinstance(inputs, list):
241
+ print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
242
+ else:
243
+ print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
244
+ continue
245
+
246
+ if passed and debug:
247
+ print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}")
248
+
249
+ if custom_compare_(output, in_outs['outputs'][index]):
250
+ tmp_result = True
251
+ results.append(tmp_result)
252
+ continue
253
+
254
+ # ground truth sequences are expressed as lists not tuples
255
+ if isinstance(output, tuple):
256
+ output = list(output)
257
+
258
+ tmp_result = False
259
+ try:
260
+ tmp_result = (output == [in_outs["outputs"][index]])
261
+ if isinstance(in_outs["outputs"][index], list):
262
+ tmp_result = tmp_result or (output == in_outs["outputs"][index])
263
+ if isinstance(output[0], str):
264
+ tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index])
265
+ except Exception as e:
266
+ if debug:
267
+ print(f"Failed check1 exception = {e}")
268
+ pass
269
+
270
+ if tmp_result == True:
271
+ results.append(tmp_result)
272
+ continue
273
+
274
+ # try one more time without \n
275
+ if isinstance(in_outs["outputs"][index], list):
276
+ for tmp_index, i in enumerate(in_outs["outputs"][index]):
277
+ in_outs["outputs"][index][tmp_index] = i.split("\n")
278
+ in_outs["outputs"][index][tmp_index] = [x.strip() for x in in_outs["outputs"][index][tmp_index] if x]
279
+ else:
280
+ in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
281
+ in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index]))
282
+ in_outs["outputs"][index] = list(map(lambda x:x.strip(), in_outs["outputs"][index]))
283
+
284
+ try:
285
+ tmp_result = (output == [in_outs["outputs"][index]])
286
+ if isinstance(in_outs["outputs"][index], list):
287
+ tmp_result = tmp_result or (output == in_outs["outputs"][index])
288
+ except Exception as e:
289
+ if debug:
290
+ print(f"Failed check2 exception = {e}")
291
+ pass
292
+
293
+ if tmp_result == True:
294
+ results.append(tmp_result)
295
+ continue
296
+
297
+ # try by converting the output into a split up list too
298
+ if isinstance(output, list):
299
+ output = list(filter(len, output))
300
+
301
+ if debug:
302
+ nl = "\n"
303
+ if not isinstance(inputs, list):
304
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
305
+ else:
306
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
307
+
308
+ if tmp_result == True:
309
+ results.append(tmp_result)
310
+ continue
311
+
312
+ try:
313
+ tmp_result = (output == [in_outs["outputs"][index]])
314
+ if isinstance(in_outs["outputs"][index], list):
315
+ tmp_result = tmp_result or (output == in_outs["outputs"][index])
316
+ except Exception as e:
317
+ if debug:
318
+ print(f"Failed check3 exception = {e}")
319
+ pass
320
+
321
+ try:
322
+ output_float = [float(e) for e in output]
323
+ gt_float = [float(e) for e in in_outs['outputs'][index]]
324
+ tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
325
+ except Exception as e:
326
+ pass
327
+ try:
328
+ if isinstance(output[0], list):
329
+ output_float = [float(e) for e in output[0]]
330
+ gt_float = [float(e) for e in in_outs['outputs'][index][0]]
331
+ tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
332
+ except Exception as e:
333
+ pass
334
+
335
+ if tmp_result == True:
336
+ results.append(tmp_result)
337
+ continue
338
+
339
+ # try by converting the stuff into split up list
340
+ if isinstance(in_outs["outputs"][index], list):
341
+ for tmp_index, i in enumerate(in_outs["outputs"][index]):
342
+ in_outs["outputs"][index][tmp_index] = set(i.split())
343
+ else:
344
+ in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
345
+
346
+ try:
347
+ tmp_result = (output == in_outs["outputs"][index])
348
+ except Exception as e:
349
+ if debug:
350
+ print(f"Failed check4 exception = {e}")
351
+ continue
352
+
353
+ if tmp_result == True:
354
+ results.append(tmp_result)
355
+ continue
356
+
357
+ # try by converting the output into a split up list too
358
+ if isinstance(output, list):
359
+ for tmp_index, i in enumerate(output):
360
+ output[tmp_index] = i.split()
361
+ output = list(filter(len, output))
362
+ for tmp_index, i in enumerate(output):
363
+ output[tmp_index] = set(i)
364
+ else:
365
+ output = output.split()
366
+ output = list(filter(len, output))
367
+ output = set(output)
368
+
369
+ try:
370
+ tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index]))
371
+ except Exception as e:
372
+ if debug:
373
+ print(f"Failed check5 exception = {e}")
374
+
375
+
376
+ # if they are all numbers, round so that similar numbers are treated as identical
377
+ try:
378
+ tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\
379
+ set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
380
+ except Exception as e:
381
+ if debug:
382
+ print(f"Failed check6 exception = {e}")
383
+
384
+ if tmp_result == True and debug:
385
+ print("PASSED")
386
+
387
+ results.append(tmp_result)
388
+
389
+ if debug:
390
+ nl = "\n"
391
+ if not isinstance(inputs, list):
392
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
393
+ else:
394
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
395
+
396
+
397
+ return results
398
+
399
+
400
+ def custom_compare_(output, ground_truth):
401
+
402
+ if isinstance(output, list):
403
+ output_1 = "\n".join(output)
404
+ if stripped_string_compare(output_1, ground_truth):
405
+ return True
406
+
407
+ if isinstance(output, list):
408
+ output_2 = [o.lstrip().rstrip() for o in output]
409
+ output_2 = "\n".join(output_2)
410
+ if stripped_string_compare(output_2, ground_truth):
411
+ return True
412
+
413
+ return False
414
+
415
+ def stripped_string_compare(s1, s2):
416
+ s1 = s1.lstrip().rstrip()
417
+ s2 = s2.lstrip().rstrip()
418
+ return s1 == s2
419
+
420
+ def call_method(method, inputs):
421
+
422
+ if isinstance(inputs, list):
423
+ inputs = "\n".join(inputs)
424
+
425
+ inputs_line_iterator = iter(inputs.split("\n"))
426
+
427
+ # sys.setrecursionlimit(10000)
428
+
429
+ # @patch('builtins.input', side_effect=inputs.split("\n"))
430
+ @patch('builtins.open', mock_open(read_data=inputs))
431
+ @patch('sys.stdin', StringIO(inputs))
432
+ @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator))
433
+ @patch('sys.stdin.readlines', lambda *args: inputs.split("\n"))
434
+ @patch('sys.stdin.read', lambda *args: inputs)
435
+ # @patch('sys.stdout.write', print)
436
+ def _inner_call_method(_method):
437
+ try:
438
+ return _method()
439
+ except SystemExit as e:
440
+ pass
441
+ finally:
442
+ pass
443
+ return _inner_call_method(method)
444
+
445
+
446
+
447
+
448
+ def reliability_guard(maximum_memory_bytes=None):
449
+ """
450
+ This disables various destructive functions and prevents the generated code
451
+ from interfering with the test (e.g. fork bomb, killing other processes,
452
+ removing filesystem files, etc.)
453
+ WARNING
454
+ This function is NOT a security sandbox. Untrusted code, including, model-
455
+ generated code, should not be blindly executed outside of one. See the
456
+ Codex paper for more information about OpenAI's code sandbox, and proceed
457
+ with caution.
458
+ """
459
+
460
+ if maximum_memory_bytes is not None:
461
+ import resource
462
+
463
+ resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
464
+ resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
465
+ if not platform.uname().system == "Darwin":
466
+ resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
467
+
468
+ faulthandler.disable()
469
+
470
+ import builtins
471
+
472
+ builtins.exit = None
473
+ builtins.quit = None
474
+
475
+ import os
476
+
477
+ os.environ["OMP_NUM_THREADS"] = "1"
478
+
479
+ os.kill = None
480
+ os.system = None
481
+ os.putenv = None
482
+ os.remove = None
483
+ os.removedirs = None
484
+ os.rmdir = None
485
+ os.fchdir = None
486
+ os.setuid = None
487
+ os.fork = None
488
+ os.forkpty = None
489
+ os.killpg = None
490
+ os.rename = None
491
+ os.renames = None
492
+ os.truncate = None
493
+ os.replace = None
494
+ os.unlink = None
495
+ os.fchmod = None
496
+ os.fchown = None
497
+ os.chmod = None
498
+ os.chown = None
499
+ os.chroot = None
500
+ os.fchdir = None
501
+ os.lchflags = None
502
+ os.lchmod = None
503
+ os.lchown = None
504
+ os.getcwd = None
505
+ os.chdir = None
506
+
507
+ import shutil
508
+
509
+ shutil.rmtree = None
510
+ shutil.move = None
511
+ shutil.chown = None
512
+
513
+ import subprocess
514
+
515
+ subprocess.Popen = None # type: ignore
516
+
517
+ __builtins__["help"] = None
518
+
519
+ import sys
520
+
521
+ sys.modules["ipdb"] = None
522
+ sys.modules["joblib"] = None
523
+ sys.modules["resource"] = None
524
+ sys.modules["psutil"] = None
525
+ sys.modules["tkinter"] = None
tests.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from evaluate import load
3
+
4
+ solution_sample1 = json.load(open("test_examples/solutions_problem_1.json", "r"))
5
+ solution_sample2 = json.load(open("test_examples/solutions_problem_2.json", "r"))
6
+ single_solutions = [solution_sample1[:1], solution_sample2[:1]]
7
+ multiple_solutions = [solution_sample1[:3], solution_sample2[:3]]
8
+ from evaluate import evaluator
9
+
10
+ metric = load("kmlcoding/apps_metric")
11
+ result_1 = metric.compute(predictions=single_solutions, level="all")
12
+ result_2 = metric.compute(predictions=multiple_solutions, level="all", k_list=[1, 2, 3])
13
+
14
+ assert result_1 == {'avg_accuracy': 1.0, 'strict_accuracy': 1.0, 'pass_at_k': None}
15
+ assert result_2 == {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}}
utils.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import json
3
+ import multiprocessing
4
+ import numpy as np
5
+ from typing import Dict
6
+ from datasets import load_dataset
7
+ from .testing_util import run_test
8
+
9
+ DATASET = "codeparrot/apps"
10
+ TIMEOUT = 10
11
+
12
+ def check_correctness(sample, generation, timeout, debug=True):
13
+ """Check correctness of code generation with a global timeout.
14
+ The global timeout is to catch some extreme/rare cases not handled by the timeouts
15
+ inside `run_test`"""
16
+ def _temp_run(sample, generation, debug, result):
17
+ result.append(run_test(sample, test=generation, debug=debug))
18
+
19
+ manager = multiprocessing.Manager()
20
+ result = manager.list()
21
+ p = multiprocessing.Process(target=_temp_run, args=(sample, generation, debug, result))
22
+ p.start()
23
+ p.join(timeout=timeout + 1)
24
+ if p.is_alive():
25
+ p.kill()
26
+ if not result:
27
+ in_outs = json.loads(sample["input_output"])
28
+ # consider that all tests failed
29
+ result = [[-1 for i in range(len(in_outs["inputs"]))]]
30
+ if debug:
31
+ print(f"global timeout")
32
+ return result[0]
33
+
34
+
35
+ def evaluate_generations(generations: list, level: str = "all", debug: bool = False):
36
+ """We take the list of code generations and try to compile them
37
+ and the run their corresponding unit tests which are retrieved from the APPS dataset.
38
+
39
+ Args:
40
+ generations: list of code generations (same order as samples in APPS dataset)
41
+ level: difficulty level used in the generation, can be "all", "introductory", "interview" or "competition"
42
+
43
+ Returns:
44
+ results: dictionary of results, key is the problem index, value is a list of results for each generation
45
+ [-2] = compile error, [-1] = runtime error [False] = failed test case [True] = passed test case
46
+ """
47
+
48
+ # generations are code generations in the same order of the dataset
49
+ apps_eval = load_dataset(DATASET, split="test", difficulties=[level])
50
+ results = {}
51
+ for index in range(len(generations)):
52
+ # code generations for problem (index)
53
+ problem_generations = generations[index]
54
+ # get corresponding samples from APPS dataset
55
+ sample = apps_eval[index]
56
+ res = []
57
+ # loop over the generations
58
+ for o_idx, o in enumerate(problem_generations):
59
+ curr_res = [-2]
60
+ try:
61
+ curr_res = check_correctness(sample, o, timeout=TIMEOUT, debug=debug)
62
+ if debug:
63
+ print(f"\nSuccessful compilation of task {index}!")
64
+ fixed = []
65
+ for e in curr_res:
66
+ if isinstance(e, np.ndarray):
67
+ e = e.item(0)
68
+ if isinstance(e, np.bool_):
69
+ e = bool(e)
70
+ fixed.append(e)
71
+ curr_res = fixed
72
+ if not np.all(curr_res):
73
+ if debug:
74
+ print(f"Results were not True for all test cases")
75
+ except Exception as e:
76
+ if debug:
77
+ print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
78
+ break
79
+ finally:
80
+ assert isinstance(curr_res, list)
81
+ res.append(curr_res)
82
+ results[index] = res
83
+ return results
84
+
85
+
86
+ def estimate_pass_at_k(num_samples, num_correct, k):
87
+ """Estimates pass@k of each problem and returns them in an array."""
88
+
89
+ def estimator(n: int, c: int, k: int) -> float:
90
+ """Calculates 1 - comb(n - c, k) / comb(n, k)."""
91
+ if n - c < k:
92
+ return 1.0
93
+ return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
94
+
95
+ if isinstance(num_samples, int):
96
+ num_samples_it = itertools.repeat(num_samples, len(num_correct))
97
+ else:
98
+ assert len(num_samples) == len(num_correct)
99
+ num_samples_it = iter(num_samples)
100
+
101
+ return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
102
+
103
+
104
+ def get_results(results: Dict[int, list], count_errors: bool = False, k_list: list = [1, 10, 100]):
105
+ """
106
+ Given the results evaluated against the testcases we output some statistics.
107
+ For single generations:
108
+ >>> example_results = {0: [[-2]], 1: [[False,False]], 2: [[True,True]], 3: [[False,True,False,True]], 4: [[-1,-1]]}
109
+ >>> get_results(example_results, count_errors=True)
110
+ Computing accuracy metrics...
111
+ number of compile errors = 1 avg = 0.2
112
+ number of runtime errors = 1 avg = 0.2
113
+ number of problems evaluated = 5
114
+ Average Accuracy : 0.3
115
+ Strict Accuracy : 0.2
116
+ {'avg_accuracy': 0.3, 'strict_accuracy': 0.2, 'pass_at_k': None}
117
+
118
+ For multiple generations:
119
+ >>> example_results = {0: [[-2], [True, True, True]], 1: [[-1,-1, -1], [True, False, True]]}
120
+ >>> get_results(example_results, k_list=[1, 2])
121
+ Computing pass@k metric for multiple generations...
122
+ {'pass@1': 0.25, 'pass@2': 0.5}
123
+ {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 0.25, 'pass@2': 0.5}}
124
+ """
125
+
126
+ metrics = {"avg_accuracy": None, "strict_accuracy": None, "pass_at_k": None}
127
+
128
+ if len(results[0]) == 1:
129
+ # for single generations we compute average accuracy and stric accuracy: original APPS metrics
130
+ print("Computing accuracy metrics...")
131
+ res = []
132
+ per_prob_res = []
133
+ all_correct = []
134
+ for index in results:
135
+ problem_results = np.asarray(results[index])
136
+ res.extend(problem_results)
137
+ per_prob_res.append(np.mean(problem_results > 0))
138
+ all_correct.append(np.all(problem_results > 0))
139
+ # we count campilation and runtime errors once per pronlem
140
+ compile_errors = len([e for e in res if -2 in e])
141
+ runtime_errors = len([e for e in res if -1 in e])
142
+ total_testcases = len(res)
143
+ if count_errors:
144
+ print(f"number of compile errors = {compile_errors} avg = {compile_errors / total_testcases}")
145
+ print(f"number of runtime errors = {runtime_errors} avg = {runtime_errors / total_testcases}")
146
+ print(f"number of problems evaluated = {total_testcases}")
147
+
148
+ print(f"Average Accuracy : {np.mean(per_prob_res)}")
149
+ print(f"Strict Accuracy : {np.mean(all_correct)}")
150
+ metrics["avg_accuracy"] = np.mean(per_prob_res)
151
+ metrics["strict_accuracy"] = np.mean(all_correct)
152
+
153
+ else:
154
+ # for multiple generations we use pass@k metric used in the HumanEval benchmark
155
+ # we use strict accuracy, a generation is valid if it has to pass all the tests
156
+ print("Computing pass@k metric for multiple generations...")
157
+ # total is list with nb generations per task (task=index)
158
+ # correct is number of generations that passed all tests per task
159
+ total = []
160
+ correct = []
161
+ for index in results:
162
+ all_correct = []
163
+ for generation in results[index]:
164
+ gen = np.array(generation)
165
+ all_correct.append(np.all(gen>0))
166
+ total.append(len(all_correct))
167
+ correct.append(sum(all_correct))
168
+ total = np.array(total)
169
+ correct = np.array(correct)
170
+ ks = k_list
171
+ pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in ks if (total >= k).all()}
172
+ print(pass_at_k)
173
+ metrics["pass_at_k"] = pass_at_k
174
+ return metrics
175
+
176
+ def compute_metrics(generations, level="all", k_list=[1, 10, 100], count_errors=True, debug=False):
177
+ """Return metrics for the given generations.
178
+ Args:
179
+ generations: list of code generations for each problem (each generation is a list of generations)
180
+ k_list: list of k values to compute pass@k when using multiple generations
181
+ count_errors: whether to count compilation and runtime errors when using single generations
182
+ level: difficulty level in APPS dataset that was used for the given generations (from: "all", "introductory", "interview", "competition")
183
+ Returns:
184
+ metrics: dict of metrics
185
+
186
+ Examples:
187
+
188
+ >>> import json
189
+ >>> # lists of solutions to the two first APPS problems (note not all solutions pass all tests)
190
+ >>> solution_sample1 = json.load(open("test_examples/solutions_problem_1.json", "r"))
191
+ >>> solution_sample2 = json.load(open("test_examples/solutions_problem_2.json", "r"))
192
+ >>> single_solutions = [solution_sample1[:1], solution_sample2[:1]]
193
+ >>> compute_metrics(single_solutions, level="all")
194
+ Computing accuracy metrics...
195
+ number of compile errors = 0 avg = 0.0
196
+ number of runtime errors = 0 avg = 0.0
197
+ number of problems evaluated = 2
198
+ Average Accuracy : 1.0
199
+ Strict Accuracy : 1.0
200
+ {'avg_accuracy': 1.0, 'strict_accuracy': 1.0, 'pass_at_k': None}
201
+ >>> multiple_solutions = [solution_sample1[:3], solution_sample2[:3]]
202
+ >>> compute_metrics(multiple_solutions, level="all", k_list=[1, 2, 3])
203
+ Computing pass@k metric for multiple generations...
204
+ {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}
205
+ {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}}
206
+ """
207
+ results = evaluate_generations(generations, level=level, debug=debug)
208
+ metrics = get_results(results, count_errors=count_errors, k_list=k_list)
209
+ return metrics
210
+
211
+ # import doctest
212
+ # doctest.testmod()