hollywoodfrancis committed on
Commit b8ab4a2 · verified · 1 Parent(s): 6061012

Upload 11 files

math_expert/__pycache__/config.cpython-312.pyc ADDED
Binary file (9.17 kB).
 
math_expert/__pycache__/expert.cpython-312.pyc ADDED
Binary file (1.24 kB).
 
math_expert/config.py ADDED
@@ -0,0 +1,381 @@
+ # Math Expert Configuration
+
+ # 1.1. Mathematical Domains and Specializations
+ MATH_DOMAINS = {
+     "algebra": {
+         "level": "expert",
+         "topics": [
+             "linear algebra",
+             "abstract algebra",
+             "polynomial equations",
+             "matrix operations",
+             "group theory",
+             "ring theory",
+             "field theory",
+             "representation theory",
+             "homological algebra",
+             "category theory",
+             "universal algebra",
+             "non-associative algebras",
+             "Lie algebras",
+             "quantum groups",
+             "Hopf algebras",
+             "K-theory"
+         ]
+     },
+     "calculus": {
+         "level": "expert",
+         "topics": [
+             "single variable calculus",
+             "multivariable calculus",
+             "differential equations",
+             "partial differential equations",
+             "vector calculus",
+             "complex analysis",
+             "functional analysis",
+             "measure theory",
+             "differential geometry",
+             "geometric measure theory",
+             "non-standard analysis",
+             "stochastic calculus",
+             "calculus of variations",
+             "symplectic geometry"
+         ]
+     },
+     "proof_writing": {
+         "level": "expert",
+         "topics": [
+             "induction",
+             "contradiction",
+             "direct proof",
+             "proof by cases",
+             "epsilon-delta proofs",
+             "existence proofs",
+             "uniqueness proofs",
+             "category theory proofs",
+             "homotopy type theory",
+             "model theory",
+             "proof theory",
+             "set theory",
+             "constructive mathematics",
+             "proof complexity"
+         ]
+     },
+     "probability": {
+         "level": "expert",
+         "topics": [
+             "probability theory",
+             "random variables",
+             "distributions",
+             "stochastic processes",
+             "Bayesian inference",
+             "Markov chains",
+             "measure-theoretic probability",
+             "stochastic calculus",
+             "martingales",
+             "large deviations",
+             "ergodic theory",
+             "random matrix theory",
+             "stochastic PDEs"
+         ]
+     },
+     "statistics": {
+         "level": "expert",
+         "topics": [
+             "descriptive statistics",
+             "inferential statistics",
+             "hypothesis testing",
+             "regression analysis",
+             "time series analysis",
+             "Bayesian statistics",
+             "non-parametric methods",
+             "statistical learning theory",
+             "high-dimensional statistics",
+             "causal inference",
+             "spatial statistics",
+             "robust statistics",
+             "computational statistics"
+         ]
+     },
+     "number_theory": {
+         "level": "expert",
+         "topics": [
+             "prime numbers",
+             "modular arithmetic",
+             "diophantine equations",
+             "cryptography",
+             "analytic number theory",
+             "algebraic number theory",
+             "elliptic curves",
+             "automorphic forms",
+             "arithmetic geometry",
+             "p-adic analysis",
+             "analytic continuation",
+             "modular forms",
+             "zeta functions"
+         ]
+     },
+     "geometry": {
+         "level": "expert",
+         "topics": [
+             "euclidean geometry",
+             "non-euclidean geometry",
+             "differential geometry",
+             "topology",
+             "algebraic geometry",
+             "projective geometry",
+             "symplectic geometry",
+             "algebraic topology",
+             "geometric analysis",
+             "geometric group theory",
+             "Riemannian geometry",
+             "Kähler geometry",
+             "hyperbolic geometry"
+         ]
+     },
+     "combinatorics": {
+         "level": "expert",
+         "topics": [
+             "graph theory",
+             "enumerative combinatorics",
+             "combinatorial optimization",
+             "matroid theory",
+             "combinatorial designs",
+             "extremal combinatorics",
+             "probabilistic combinatorics",
+             "algebraic combinatorics",
+             "topological combinatorics",
+             "combinatorial geometry",
+             "Ramsey theory"
+         ]
+     },
+     "logic": {
+         "level": "expert",
+         "topics": [
+             "first-order logic",
+             "model theory",
+             "proof theory",
+             "set theory",
+             "computability theory",
+             "type theory",
+             "category theory",
+             "modal logic",
+             "temporal logic",
+             "constructive logic",
+             "intuitionistic logic",
+             "proof complexity"
+         ]
+     },
+     "theoretical_cs": {
+         "level": "expert",
+         "topics": [
+             "computational complexity",
+             "algorithms",
+             "cryptography",
+             "quantum computing",
+             "machine learning theory",
+             "formal verification",
+             "type systems",
+             "programming language theory",
+             "distributed computing",
+             "parallel algorithms",
+             "computational geometry",
+             "randomized algorithms"
+         ]
+     },
+     "applied_math": {
+         "level": "expert",
+         "topics": [
+             "numerical analysis",
+             "optimization",
+             "control theory",
+             "mathematical physics",
+             "fluid dynamics",
+             "quantum mechanics",
+             "relativity",
+             "mathematical biology",
+             "financial mathematics",
+             "signal processing",
+             "data assimilation",
+             "inverse problems"
+         ]
+     }
+ }
+
+ # 1.2. Core Tasks
+ CORE_TASKS = [
+     {
+         "task_type": "problem_solving",
+         "description": "Solve complex mathematical problems",
+         "example": "Prove the Riemann Hypothesis",
+         "difficulty_levels": ["basic", "intermediate", "advanced", "research_level", "open_problem"]
+     },
+     {
+         "task_type": "proof_writing",
+         "description": "Prove mathematical statements with advanced techniques",
+         "example": "Prove Fermat's Last Theorem using elliptic curves",
+         "proof_types": ["induction", "contradiction", "direct", "cases", "category_theory", "homotopy_type", "model_theory", "proof_complexity", "constructive"]
+     },
+     {
+         "task_type": "calculus_computation",
+         "description": "Perform advanced calculus operations",
+         "example": "Solve Navier-Stokes equations for turbulence",
+         "operation_types": ["differentiation", "integration", "limits", "functional_analysis", "measure_theory", "stochastic_calculus", "geometric_measure_theory"]
+     },
+     {
+         "task_type": "symbolic_computation",
+         "description": "Manipulate complex mathematical expressions",
+         "example": "Simplify tensor equations in general relativity",
+         "expression_types": ["polynomial", "rational", "trigonometric", "exponential", "tensor", "operator", "Lie_algebra", "Hopf_algebra"]
+     },
+     {
+         "task_type": "concept_explanation",
+         "description": "Explain advanced mathematical concepts",
+         "example": "Explain the Langlands program",
+         "explanation_types": ["definition", "intuition", "application", "example", "formal", "geometric", "historical", "pedagogical"]
+     },
+     {
+         "task_type": "statistical_analysis",
+         "description": "Perform advanced statistical analysis",
+         "example": "Analyze high-dimensional genomic data",
+         "statistical_methods": ["regression", "hypothesis_testing", "confidence_intervals", "bayesian_methods", "non_parametric", "causal_inference", "computational_methods"]
+     },
+     {
+         "task_type": "probability_calculation",
+         "description": "Calculate complex probabilities",
+         "example": "Calculate phase transitions in random matrix theory",
+         "distributions": ["binomial", "normal", "poisson", "exponential", "multivariate", "stochastic_processes", "random_matrix", "levy_processes"]
+     },
+     {
+         "task_type": "number_theory_problem",
+         "description": "Solve advanced number theory problems",
+         "example": "Prove the Birch and Swinnerton-Dyer conjecture",
+         "problem_types": ["prime", "modular", "diophantine", "analytic", "algebraic", "elliptic_curve", "modular_form"]
+     },
+     {
+         "task_type": "geometric_construction",
+         "description": "Construct and analyze complex geometric objects",
+         "example": "Construct a Calabi-Yau manifold",
+         "construction_types": ["euclidean", "non_euclidean", "projective", "differential", "algebraic", "symplectic", "topological"]
+     },
+     {
+         "task_type": "mathematical_modeling",
+         "description": "Create advanced mathematical models",
+         "example": "Model quantum field theory",
+         "model_types": ["continuous", "discrete", "stochastic", "partial_differential", "non_linear", "quantum", "statistical"]
+     },
+     {
+         "task_type": "proof_verification",
+         "description": "Verify complex mathematical proofs",
+         "example": "Verify the proof of the Four Color Theorem",
+         "verification_methods": ["formal_verification", "model_checking", "proof_assistant", "automated_reasoning", "interactive_theorem_proving"]
+     },
+     {
+         "task_type": "algorithm_design",
+         "description": "Design and analyze mathematical algorithms",
+         "example": "Design a quantum algorithm for factorization",
+         "algorithm_types": ["numerical", "combinatorial", "geometric", "algebraic", "probabilistic", "quantum", "parallel"]
+     },
+     {
+         "task_type": "research_paper_analysis",
+         "description": "Analyze and explain mathematical research papers",
+         "example": "Explain Wiles' proof of Fermat's Last Theorem",
+         "analysis_types": ["technical", "historical", "pedagogical", "critical", "extensional"]
+     },
+     {
+         "task_type": "open_problem_analysis",
+         "description": "Analyze and make progress on open mathematical problems",
+         "example": "Analyze the Collatz conjecture",
+         "problem_classes": ["number_theory", "combinatorics", "analysis", "algebra", "geometry", "probability"]
+     },
+     {
+         "task_type": "mathematical_philosophy",
+         "description": "Analyze philosophical aspects of mathematics",
+         "example": "Explain the foundations of mathematics",
+         "philosophical_topics": ["foundations", "philosophy_of_math", "logic", "set_theory", "constructivism", "intuitionism"]
+     },
+     {
+         "task_type": "mathematical_software_development",
+         "description": "Develop mathematical software and algorithms",
+         "example": "Implement a new numerical method",
+         "software_types": ["numerical", "symbolic", "proof_assistant", "visualization", "simulation", "optimization"]
+     }
+ ]
+
+ # Dataset Configuration
+ DATASETS = {
+     "proofnet": {
+         "source": "huggingface",
+         "dataset_name": "proofnet",
+         "split": "train",
+         "use_fields": ["problem", "solution", "proof_steps"]
+     },
+     "math_dataset": {
+         "source": "huggingface",
+         "dataset_name": "deepmind/mathematics_dataset",
+         "split": "train-hard",
+         "use_fields": ["question", "answer", "steps"]
+     },
+     "gsm8k": {
+         "source": "huggingface",
+         "dataset_name": "gsm8k",
+         "split": "train",
+         "use_fields": ["question", "answer"]
+     },
+     "mathlib": {
+         "source": "huggingface",
+         "dataset_name": "mathlib",
+         "split": "train",
+         "use_fields": ["theorem", "proof", "dependencies"]
+     },
+     "arxiv_math": {
+         "source": "huggingface",
+         "dataset_name": "arxiv_math",
+         "split": "train",
+         "use_fields": ["paper", "equations", "proofs"]
+     },
+     "clay_institute": {
+         "source": "huggingface",
+         "dataset_name": "clay_institute_problems",
+         "split": "train",
+         "use_fields": ["problem", "background", "current_status", "approaches"]
+     },
+     "open_problems": {
+         "source": "huggingface",
+         "dataset_name": "open_math_problems",
+         "split": "train",
+         "use_fields": ["problem", "category", "history", "attempts"]
+     },
+     "research_papers": {
+         "source": "huggingface",
+         "dataset_name": "math_research_papers",
+         "split": "train",
+         "use_fields": ["title", "abstract", "content", "proofs", "theorems"]
+     }
+ }
+
+ # Data Processing Configuration
+ DATA_PROCESSING = {
+     "format": "jsonl",
+     "normalization": {
+         "equations": "sympy",
+         "latex": "plaintext",
+         "proof_steps": "yaml",
+         "tensor_operations": "torch",
+         "quantum_operations": "qiskit",
+         "geometric_objects": "geometric_algebra",
+         "category_theory": "category_theory"
+     },
+     "validation": {
+         "min_steps": 2,
+         "max_steps": 200,
+         "min_length": 10,
+         "max_length": 100000
+     }
+ }
+
+ if __name__ == "__main__":
+     print("Math Expert Configuration Loaded")
+     print(f"Number of domains: {len(MATH_DOMAINS)}")
+     print(f"Number of tasks: {len(CORE_TASKS)}")
+     print(f"Number of datasets: {len(DATASETS)}")
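Since the configuration is plain Python dictionaries, downstream code can consume it directly. A minimal sketch (assuming config.py is importable from the working directory):

from config import MATH_DOMAINS, CORE_TASKS, DATASETS

# Enumerate every (domain, topic) pair the expert claims to cover.
coverage = [(domain, topic)
            for domain, spec in MATH_DOMAINS.items()
            for topic in spec["topics"]]
print(f"{len(coverage)} topics across {len(MATH_DOMAINS)} domains")

# Index the task definitions by task_type for quick lookup.
tasks_by_type = {task["task_type"]: task for task in CORE_TASKS}
print(tasks_by_type["proof_writing"]["description"])
print(sorted(DATASETS.keys()))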
math_expert/data_processor.py ADDED
@@ -0,0 +1,282 @@
+ import json
+ import yaml
+ import sympy
+ from sympy.parsing.latex import parse_latex
+ from huggingface_hub import hf_hub_download
+ from pathlib import Path
+ import jsonlines
+ from typing import Dict, List, Any
+
+ from config import DATASETS, DATA_PROCESSING
+
+ class MathDataProcessor:
+     def __init__(self):
+         self.processed_data = []
+         self.dataset_paths = {}
+         self.math_operations = {
+             "differentiation": self._process_differentiation,
+             "integration": self._process_integration,
+             "limits": self._process_limits,
+             "simplification": self._process_simplification,
+             "matrix": self._process_matrix,
+             "probability": self._process_probability,
+             "statistics": self._process_statistics
+         }
+
+     def download_dataset(self, dataset_name: str) -> Path:
+         """Download dataset from Hugging Face"""
+         if dataset_name not in DATASETS:
+             raise ValueError(f"Dataset {dataset_name} not defined in configuration")
+
+         dataset_config = DATASETS[dataset_name]
+         dataset_path = Path(f"data/{dataset_name}")
+
+         # Download from Hugging Face (repo_type="dataset" is required for dataset repos)
+         hf_hub_download(
+             repo_id=dataset_config["dataset_name"],
+             filename=f"{dataset_config['split']}.jsonl",
+             repo_type="dataset",
+             local_dir=dataset_path
+         )
+
+         self.dataset_paths[dataset_name] = dataset_path
+         return dataset_path
+
+     def normalize_equation(self, equation: str) -> str:
+         """Normalize mathematical equations using sympy"""
+         try:
+             # Try to parse LaTeX first
+             if "\\" in equation:
+                 eq = parse_latex(equation)
+             else:
+                 eq = sympy.sympify(equation)
+             return str(eq)
+         except Exception:
+             return equation
+
+     def process_proof_steps(self, steps: List[str]) -> List[Dict[str, str]]:
+         """Process proof steps into structured format"""
+         processed_steps = []
+
+         for step in steps:
+             try:
+                 # Try to parse as YAML if it contains structured data
+                 structured_step = yaml.safe_load(step)
+                 if isinstance(structured_step, dict):
+                     processed_steps.append(structured_step)
+                 else:
+                     processed_steps.append({"step": step})
+             except Exception:
+                 processed_steps.append({"step": step})
+
+         return processed_steps
+
+     def _process_differentiation(self, expression: str) -> str:
+         """Process and validate differentiation operations"""
+         x = sympy.Symbol('x')
+         try:
+             expr = sympy.sympify(expression)
+             derivative = sympy.diff(expr, x)
+             return str(derivative)
+         except Exception:
+             return expression
+
+     def _process_integration(self, expression: str) -> str:
+         """Process and validate integration operations"""
+         x = sympy.Symbol('x')
+         try:
+             expr = sympy.sympify(expression)
+             integral = sympy.integrate(expr, x)
+             return str(integral)
+         except Exception:
+             return expression
+
+     def _process_limits(self, expression: str) -> str:
+         """Process and validate limit operations"""
+         x = sympy.Symbol('x')
+         try:
+             expr = sympy.sympify(expression)
+             limit = sympy.limit(expr, x, sympy.oo)
+             return str(limit)
+         except Exception:
+             return expression
+
+     def _process_simplification(self, expression: str) -> str:
+         """Process and validate expression simplification"""
+         try:
+             expr = sympy.sympify(expression)
+             simplified = sympy.simplify(expr)
+             return str(simplified)
+         except Exception:
+             return expression
+
+     def _process_matrix(self, matrix_str: str) -> str:
+         """Process and validate matrix operations"""
+         try:
+             matrix = sympy.Matrix([[float(n) for n in row.split()]
+                                    for row in matrix_str.split(';')])
+             return str(matrix)
+         except Exception:
+             return matrix_str
+
+     def _process_probability(self, problem: str) -> Dict:
+         """Process probability problems and extract key parameters"""
+         try:
+             # Basic parsing for probability problems
+             if "probability" in problem.lower():
+                 return {
+                     "type": "probability",
+                     "parameters": self._extract_parameters(problem),
+                     "distribution": self._identify_distribution(problem)
+                 }
+             return {"type": "unknown"}
+         except Exception:
+             return {"type": "unknown"}
+
+     def _process_statistics(self, data: str) -> Dict:
+         """Process statistical data and extract key metrics"""
+         try:
+             # Basic statistical processing
+             if "," in data:
+                 numbers = [float(n) for n in data.split(',')]
+                 return {
+                     "mean": sum(numbers) / len(numbers),
+                     "median": sorted(numbers)[len(numbers) // 2],  # upper median for even-length lists
+                     "std_dev": self._calculate_std_dev(numbers)
+                 }
+             return {"error": "Invalid data format"}
+         except Exception:
+             return {"error": "Processing failed"}
+
+     def _extract_parameters(self, text: str) -> Dict:
+         """Extract parameters from mathematical text"""
+         parameters = {}
+         # Basic parameter extraction logic
+         if "=" in text:
+             parts = text.split("=")
+             parameters["equation"] = parts[0].strip()
+             parameters["value"] = parts[1].strip()
+         return parameters
+
+     def _identify_distribution(self, text: str) -> str:
+         """Identify probability distribution from text"""
+         distributions = {
+             "binomial": ["binomial", "bernoulli"],
+             "normal": ["normal", "gaussian"],
+             "poisson": ["poisson"],
+             "exponential": ["exponential"]
+         }
+
+         text_lower = text.lower()
+         for dist, keywords in distributions.items():
+             if any(keyword in text_lower for keyword in keywords):
+                 return dist
+         return "unknown"
+
+     def _calculate_std_dev(self, numbers: List[float]) -> float:
+         """Calculate (population) standard deviation"""
+         mean = sum(numbers) / len(numbers)
+         variance = sum((x - mean) ** 2 for x in numbers) / len(numbers)
+         return variance ** 0.5
+
+     def process_math_operation(self, operation_type: str, content: str) -> Any:
+         """Process a specific mathematical operation"""
+         if operation_type in self.math_operations:
+             return self.math_operations[operation_type](content)
+         return content
+
+     def validate_entry(self, entry: Dict[str, Any]) -> bool:
+         """Enhanced validation with mathematical checks"""
+         steps = entry.get("steps", [])
+         text = entry.get("question", "") + entry.get("answer", "")
+
+         # Basic validation
+         if len(steps) < DATA_PROCESSING["validation"]["min_steps"]:
+             return False
+
+         if len(text) < DATA_PROCESSING["validation"]["min_length"]:
+             return False
+
+         # Mathematical validation
+         try:
+             # Check if equations are valid
+             if "equation" in entry:
+                 sympy.sympify(entry["equation"])
+
+             # Check if steps follow logical progression
+             if len(steps) > 1:
+                 for i in range(len(steps) - 1):
+                     if not self._check_step_continuity(steps[i], steps[i + 1]):
+                         return False
+
+             # Check for circular logic in proofs
+             if "proof" in entry:
+                 if not self._check_proof_validity(entry["proof"]):
+                     return False
+
+             return True
+
+         except Exception:
+             return False
+
+     def _check_step_continuity(self, step1: str, step2: str) -> bool:
+         """Check if mathematical steps are logically connected"""
+         try:
+             # Basic check: the RHS of one step should match the LHS of the next
+             if "=" in step1 and "=" in step2:
+                 s1 = step1.split("=")[1].strip()
+                 s2 = step2.split("=")[0].strip()
+                 return s1 == s2
+             return True
+         except Exception:
+             return False
+
+     def _check_proof_validity(self, proof: str) -> bool:
+         """Check if a proof is logically valid"""
+         # Basic proof validation
+         if "assume" in proof.lower() and "therefore" not in proof.lower():
+             return False
+
+         if "contradiction" in proof.lower() and "false" not in proof.lower():
+             return False
+
+         return True
+
+     def process_dataset(self, dataset_name: str):
+         """Process a specific dataset according to its configuration"""
+         dataset_path = self.download_dataset(dataset_name)
+         dataset_config = DATASETS[dataset_name]
+
+         with jsonlines.open(dataset_path / f"{dataset_config['split']}.jsonl") as reader:
+             for entry in reader:
+                 processed_entry = {}
+
+                 # Process each field
+                 for field in dataset_config["use_fields"]:
+                     value = entry.get(field)
+                     if value:
+                         if field == "equation":
+                             processed_entry[field] = self.normalize_equation(value)
+                         elif field == "proof_steps":
+                             processed_entry[field] = self.process_proof_steps(value)
+                         else:
+                             processed_entry[field] = value
+
+                 # Validate and add if valid
+                 if self.validate_entry(processed_entry):
+                     self.processed_data.append(processed_entry)
+
+     def save_processed_data(self, output_path: str):
+         """Save processed data to JSONL format"""
+         Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+         with jsonlines.open(output_path, mode='w') as writer:
+             writer.write_all(self.processed_data)
+
+ if __name__ == "__main__":
+     processor = MathDataProcessor()
+
+     # Process all defined datasets
+     for dataset in DATASETS.keys():
+         processor.process_dataset(dataset)
+
+     # Save processed data
+     output_path = "processed_data/math_expert_data.jsonl"
+     processor.save_processed_data(output_path)
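A short usage sketch of the processor's symbolic helpers, which can be exercised without downloading any dataset (the entry below is hypothetical, chosen so each step's right-hand side matches the next step's left-hand side, as _check_step_continuity requires):

from data_processor import MathDataProcessor

processor = MathDataProcessor()

# Normalize an expression and run one registered symbolic operation.
print(processor.normalize_equation("(x**2 - 1)/(x - 1)"))
print(processor.process_math_operation("differentiation", "x**3 + 2*x"))  # -> 3*x**2 + 2

# Validate a hand-written entry against the DATA_PROCESSING rules.
entry = {
    "question": "Factor x**2 + 2*x + 1.",
    "answer": "(x + 1)**2",
    "steps": ["y = x**2 + 2*x + 1", "x**2 + 2*x + 1 = (x + 1)**2"]
}
print(processor.validate_entry(entry))  # True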
math_expert/expert.py ADDED
@@ -0,0 +1,19 @@
+ """
+ Math Expert Module
+ """
+ from typing import Dict, Any, List
+
+ class MathExpert:
+     def __init__(self):
+         self.name = "math"
+         self.domains = ["mathematics", "calculus", "algebra"]
+
+     def handle_query(self, query: str, context: Dict[str, Any]) -> Dict[str, Any]:
+         return {
+             'response': f"Math expert response to: {query}",
+             'confidence': 0.9,
+             'metadata': {'domains': self.domains}
+         }
+
+     def get_domains(self) -> List[str]:
+         return self.domains
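handle_query is currently a stub that echoes the query with a fixed confidence, so the module can be exercised standalone:

from expert import MathExpert

expert = MathExpert()
result = expert.handle_query("What is the derivative of x**2?", context={})
print(result["response"])    # Math expert response to: What is the derivative of x**2?
print(result["confidence"])  # 0.9
print(expert.get_domains())  # ['mathematics', 'calculus', 'algebra']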
math_expert/prepare_data.py ADDED
@@ -0,0 +1,245 @@
+ import json
+ from pathlib import Path
+ from sympy.parsing.latex import parse_latex
+ from sympy.parsing.sympy_parser import parse_expr
+ from datasets import load_dataset
+ import jsonlines
+ from typing import Dict, List, Any
+ import sys
+ import psutil
+
+ class MathDataPreparer:
+     def __init__(self, output_dir: str = "processed_data"):
+         self.output_dir = Path(output_dir)
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+         self.datasets = {
+             "gsm8k": {
+                 "source": "gsm8k",
+                 "config": "main",
+                 "split": "train",
+                 "fields": ["question", "answer"]
+             },
+             "proofnet": {
+                 "source": "hoskinson-center/proofnet",
+                 "split": "validation",
+                 "fields": ["problem", "solution", "proof_steps"]
+             }
+         }
+
+     def normalize_equation(self, equation: str) -> str:
+         """Normalize mathematical equations using sympy"""
+         try:
+             # Try LaTeX first
+             if "\\" in equation:
+                 eq = parse_latex(equation)
+             # Then try $-delimited math
+             elif equation.startswith('$') and equation.endswith('$'):
+                 eq = parse_expr(equation[1:-1])
+             # Otherwise treat it as a plain SymPy expression
+             else:
+                 eq = parse_expr(equation)
+             return str(eq)
+         except Exception as e:
+             print(f"Error normalizing equation: {equation} ({e})", file=sys.stderr)
+             return equation
+
+     def process_proof_steps(self, steps: List[str]) -> List[Dict[str, Any]]:
+         """Process and validate proof steps"""
+         processed_steps = []
+         for step in steps:
+             try:
+                 # Basic validation
+                 if not step.strip():
+                     continue
+
+                 # Try to parse as structured data
+                 try:
+                     structured_step = json.loads(step)
+                     if isinstance(structured_step, dict):
+                         processed_steps.append(structured_step)
+                         continue
+                 except json.JSONDecodeError:
+                     pass
+
+                 # Process as plain text
+                 processed_steps.append({
+                     "text": step.strip(),
+                     "valid": True
+                 })
+             except Exception as e:
+                 print(f"Error processing proof step: {step}", file=sys.stderr)
+                 processed_steps.append({
+                     "text": step,
+                     "valid": False,
+                     "error": str(e)
+                 })
+         return processed_steps
+
+     def process_gsm8k(self, dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+         """Process GSM8K dataset"""
+         processed = []
+         for example in dataset:
+             try:
+                 processed_example = {
+                     "question": example["question"].strip(),
+                     "answer": example["answer"].strip()
+                 }
+
+                 # Normalize equations in question
+                 if "=" in processed_example["question"]:
+                     processed_example["question"] = self.normalize_equation(processed_example["question"])
+
+                 # Normalize equations in answer
+                 if "=" in processed_example["answer"]:
+                     processed_example["answer"] = self.normalize_equation(processed_example["answer"])
+
+                 processed.append(processed_example)
+             except Exception as e:
+                 print(f"Error processing GSM8K example: {str(e)}", file=sys.stderr)
+         return processed
+
+     def process_proofnet(self, dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+         """Process ProofNet dataset"""
+         # NOTE: this assumes each example exposes "problem"/"solution"/"proof_steps" keys.
+         # If the dataset uses different field names, every record comes out empty
+         # (see proofnet_processed.jsonl below), so check the schema printed here first.
+         processed = []
+         error_count = 0
+
+         # First, let's print some dataset info
+         print("\nProofNet dataset info:")
+         print(f"Dataset type: {type(dataset)}")
+         if hasattr(dataset, 'features'):
+             print("\nDataset features:")
+             for feature, dtype in dataset.features.items():
+                 print(f"{feature}: {dtype}")
+
+         # Print first example structure
+         if len(dataset) > 0:
+             first_example = dataset[0]
+             print("\nFirst example keys:", list(first_example.keys()))
+             print("\nFirst example preview:")
+             for key, value in first_example.items():
+                 print(f"\n{key}:")
+                 print(f"Type: {type(value)}")
+                 if isinstance(value, str):
+                     print(f"Length: {len(value)}")
+                 elif isinstance(value, list):
+                     print(f"List length: {len(value)}")
+                     if len(value) > 0:
+                         print(f"First item type: {type(value[0])}")
+             print("\n")
+
+         for idx, example in enumerate(dataset):
+             try:
+                 processed_example = {
+                     "problem": example.get("problem", "").strip(),
+                     "solution": example.get("solution", "").strip(),
+                     "proof_steps": []
+                 }
+
+                 # Handle proof steps
+                 if "proof_steps" in example:
+                     steps = example["proof_steps"]
+                     print(f"\nExample {idx} proof steps info:")
+                     print(f"Type: {type(steps)}")
+                     if isinstance(steps, str):
+                         print(f"Length: {len(steps)}")
+                         # Try to split string into steps
+                         steps = steps.split('\n')
+                         print(f"Split into {len(steps)} steps")
+                     elif isinstance(steps, list):
+                         print(f"List length: {len(steps)}")
+                         if len(steps) > 0:
+                             print(f"First item type: {type(steps[0])}")
+                     else:
+                         print(f"Warning: Unexpected proof steps type: {type(steps)}")
+                         steps = []
+
+                     processed_example["proof_steps"] = self.process_proof_steps(steps)
+
+                 # Normalize equations
+                 for field in ["problem", "solution"]:
+                     if "=" in processed_example[field]:
+                         try:
+                             processed_example[field] = self.normalize_equation(processed_example[field])
+                         except Exception as e:
+                             print(f"Error normalizing {field} in ProofNet example {idx}: {str(e)}")
+
+                 processed.append(processed_example)
+             except Exception as e:
+                 print(f"Error processing ProofNet example {idx}: {str(e)}")
+                 error_count += 1
+
+         print(f"\nProcessed {len(processed)} examples from ProofNet")
+         print(f"Encountered {error_count} errors during processing")
+         return processed
+
+     def save_to_jsonl(self, data: List[Dict[str, Any]], filename: str):
+         """Save processed data to JSONL file"""
+         filepath = self.output_dir / filename
+         with jsonlines.open(filepath, mode='w') as writer:
+             writer.write_all(data)
+         return filepath
+
+     def print_memory_usage(self):
+         """Print current memory usage"""
+         process = psutil.Process()
+         memory_info = process.memory_info()
+         print(f"Current memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
+
+     def print_sample(self, data: List[Dict[str, Any]], count: int = 3):
+         """Print sample of processed data"""
+         print("\nSample data:")
+         for i, example in enumerate(data[:count]):
+             print(f"\nSample {i + 1}:")
+             if "proof_steps" in example:
+                 # For ProofNet samples, show proof steps
+                 print(f"Problem: {example['problem']}")
+                 print(f"Solution: {example['solution']}")
+                 print("\nProof Steps:")
+                 for step in example["proof_steps"]:
+                     # JSON-structured steps may lack a 'text' key
+                     print(f"- {step.get('text', step)}")
+             else:
+                 # For GSM8K samples
+                 print(json.dumps(example, indent=2))
+
+ def main():
+     preparer = MathDataPreparer()
+
+     # Load and process GSM8K
+     print("\nProcessing GSM8K dataset...")
+     gsm8k_dataset = load_dataset("gsm8k", "main", split="train")
+     print(f"Loaded {len(gsm8k_dataset)} samples from GSM8K")
+
+     processed_gsm8k = preparer.process_gsm8k(gsm8k_dataset)
+     print(f"Processed {len(processed_gsm8k)} samples")
+
+     preparer.print_sample(processed_gsm8k)
+
+     # Save GSM8K
+     gsm8k_path = preparer.save_to_jsonl(processed_gsm8k, "gsm8k_processed.jsonl")
+     print(f"\nSaved GSM8K processed data to: {gsm8k_path}")
+
+     # Load and process ProofNet
+     print("\nProcessing ProofNet dataset...")
+     try:
+         proofnet_dataset = load_dataset("hoskinson-center/proofnet", split="validation")
+         print(f"Loaded {len(proofnet_dataset)} samples from ProofNet")
+
+         processed_proofnet = preparer.process_proofnet(proofnet_dataset)
+         print(f"Processed {len(processed_proofnet)} samples")
+
+         preparer.print_sample(processed_proofnet)
+
+         # Save ProofNet
+         proofnet_path = preparer.save_to_jsonl(processed_proofnet, "proofnet_processed.jsonl")
+         print(f"\nSaved ProofNet processed data to: {proofnet_path}")
+     except Exception as e:
+         print(f"Error processing ProofNet dataset: {str(e)}")
+         print("Continuing with GSM8K data only")
+
+     # Print memory usage
+     preparer.print_memory_usage()
+
+ if __name__ == "__main__":
+     main()
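The normalizer accepts three input shapes: raw LaTeX, $-delimited math, and plain SymPy syntax. A quick illustration (LaTeX parsing additionally requires the antlr4-python3-runtime package noted in requirements.txt below):

from prepare_data import MathDataPreparer

preparer = MathDataPreparer(output_dir="processed_data")
print(preparer.normalize_equation("2*x + 3*x"))                   # plain SymPy  -> 5*x
print(preparer.normalize_equation("$x**2 + x**2$"))               # $-delimited  -> 2*x**2
print(preparer.normalize_equation(r"\frac{1}{2} + \frac{1}{2}"))  # LaTeX        -> 1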
math_expert/processed_data/gsm8k_processed.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
math_expert/processed_data/proofnet_processed.jsonl ADDED
@@ -0,0 +1,185 @@
+ {"problem": "", "solution": "", "proof_steps": []}
(this identical empty record repeats for all 185 lines of the file)
math_expert/requirements.txt ADDED
@@ -0,0 +1,11 @@
+ transformers>=4.30.0
+ sympy>=1.11.1
+ antlr4-python3-runtime  # required by sympy.parsing.latex.parse_latex; install the version matching your sympy release
+ torch>=2.0.0
+ numpy>=1.24.0
+ scipy>=1.10.0
+ pandas>=2.0.0
+ huggingface_hub>=0.16.0
+ jsonlines>=3.0.0
+ pyyaml>=5.4.1
+ datasets>=2.14.0
+ psutil>=5.9.0
math_expert/train.py ADDED
@@ -0,0 +1,126 @@
+ from transformers import TrainingArguments, Trainer
+ import jsonlines
+ import os
+ import torch
+ from model import Transformer, ModelArgs
+ from tokenizer import Tokenizer
+
+ class MathDataset(torch.utils.data.Dataset):
+     def __init__(self, tokenizer, data_paths, max_length=512):
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+         self.data = []
+
+         # Load and combine data from all files
+         for path in data_paths:
+             with jsonlines.open(path) as reader:
+                 self.data.extend(list(reader))
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         example = self.data[idx]
+
+         # Format the input text
+         if "proof_steps" in example:
+             # For ProofNet-style data
+             text = f"Problem: {example['problem']}\nSolution: {example['solution']}\nProof Steps:\n"
+             for step in example["proof_steps"]:
+                 text += f"- {step.get('text', step)}\n"
+         else:
+             # For GSM8K-style data
+             text = f"Question: {example['question']}\nAnswer: {example['answer']}"
+
+         # Tokenize
+         inputs = self.tokenizer(
+             text,
+             padding="max_length",
+             truncation=True,
+             max_length=self.max_length,
+             return_tensors="pt"
+         )
+
+         # Remove batch dimension
+         inputs = {k: v.squeeze(0) for k, v in inputs.items()}
+
+         return {
+             "input_ids": inputs["input_ids"],
+             "attention_mask": inputs["attention_mask"],
+             "labels": inputs["input_ids"]  # For causal LM training
+         }
+
+ def main():
+     # Initialize your custom model
+     model_args = ModelArgs(
+         dim=512,
+         n_layers=8,
+         n_heads=8,
+         vocab_size=50000,  # Adjust based on your tokenizer
+         max_seq_len=1024
+     )
+     model = Transformer(model_args)
+
+     # Initialize your custom tokenizer
+     tokenizer = Tokenizer()
+
+     # Configure tokenizer
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Set up training data paths
+     data_dir = os.path.join(os.path.dirname(__file__), "processed_data")
+     data_paths = [
+         os.path.join(data_dir, "gsm8k_processed.jsonl"),
+         os.path.join(data_dir, "proofnet_processed.jsonl")
+     ]
+
+     # Create dataset
+     dataset = MathDataset(
+         tokenizer=tokenizer,
+         data_paths=data_paths,
+         max_length=1024  # Increased max_length for longer proofs
+     )
+
+     # Define training arguments.
+     # NOTE: no eval_dataset is passed to the Trainer, so evaluation
+     # (evaluation_strategy / eval_steps / load_best_model_at_end) is disabled
+     # here; re-enable those options once a held-out split is available.
+     training_args = TrainingArguments(
+         output_dir="./math_expert_output",
+         overwrite_output_dir=True,
+         num_train_epochs=3,
+         per_device_train_batch_size=2,
+         gradient_accumulation_steps=4,
+         save_steps=1000,
+         save_total_limit=2,
+         logging_dir="./math_expert_logs",
+         logging_steps=100,
+         learning_rate=5e-5,
+         warmup_steps=500,
+         weight_decay=0.01,
+         fp16=torch.cuda.is_available()
+     )
+
+     # Create trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=dataset,
+         tokenizer=tokenizer,
+     )
+
+     # Start training
+     print("Starting training with your custom model...")
+     trainer.train()
+
+     # Save the model
+     output_dir = "./math_expert_model"
+     os.makedirs(output_dir, exist_ok=True)
+     torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
+     model_args.save(os.path.join(output_dir, "config.json"))
+     print(f"Model saved to {output_dir}")
+
+ if __name__ == "__main__":
+     main()
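To sanity-check the data pipeline independently of the custom model, the dataset class can be driven by any Hugging Face tokenizer (gpt2 here is just a hypothetical stand-in for the project's own Tokenizer):

from transformers import AutoTokenizer
from train import MathDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

dataset = MathDataset(
    tokenizer=tokenizer,
    data_paths=["processed_data/gsm8k_processed.jsonl"],
    max_length=512,
)
item = dataset[0]
print(item["input_ids"].shape)            # torch.Size([512])
print(int(item["attention_mask"].sum()))  # count of non-padding tokens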
math_expert/validation.py ADDED
@@ -0,0 +1,173 @@
+ """
+ Validation module for the Math Expert model
+ """
+ import json
+ from pathlib import Path
+ import hashlib
+ import datetime
+ from typing import Dict, Any, List, Optional
+ import numpy as np
+ from sympy import simplify, sympify
+
+ class MathValidator:
+     def __init__(self, checkpoint_dir: str = "checkpoints"):
+         self.checkpoint_dir = Path(checkpoint_dir)
+         self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+         self.validation_dir = self.checkpoint_dir / "validation"
+         self.validation_dir.mkdir(exist_ok=True)
+
+         # Initialize validation metrics
+         self.metrics = {
+             "accuracy": [],
+             "equation_simplification": [],
+             "proof_validation": [],
+             "memory_usage": []
+         }
+
+     def validate_equation(self, equation: str, expected_result: str) -> Dict[str, Any]:
+         """Validate mathematical equation correctness"""
+         try:
+             lhs = sympify(equation)
+             rhs = sympify(expected_result)
+
+             # The two expressions are equivalent iff their difference simplifies to zero
+             is_correct = simplify(lhs - rhs) == 0
+
+             return {
+                 "is_correct": is_correct,
+                 "simplified_lhs": str(simplify(lhs)),
+                 "simplified_rhs": str(simplify(rhs)),
+                 "validation_score": float(is_correct)
+             }
+         except Exception as e:
+             return {
+                 "is_correct": False,
+                 "error": str(e),
+                 "validation_score": 0.0
+             }
+
+     def validate_proof(self, proof_steps: List[str], expected_theorem: str) -> Dict[str, Any]:
+         """Validate mathematical proof steps"""
+         try:
+             # Check if each step logically follows from previous steps
+             current_context = set()
+             validation_score = 1.0
+
+             for step in proof_steps:
+                 # Try to parse the step as an equation
+                 try:
+                     lhs, rhs = step.split('=')
+                     if simplify(sympify(lhs) - sympify(rhs)) != 0:
+                         validation_score *= 0.9  # Penalize incorrect steps
+                 except Exception:
+                     pass  # Not all steps are equations
+
+                 # Update context
+                 current_context.add(step)
+
+             # Check if final step matches expected theorem
+             final_step = proof_steps[-1]
+             matches_theorem = expected_theorem in final_step
+
+             return {
+                 "is_valid": validation_score > 0.5,
+                 "validation_score": validation_score,
+                 "matches_theorem": matches_theorem,
+                 "context_size": len(current_context)
+             }
+         except Exception as e:
+             return {
+                 "is_valid": False,
+                 "error": str(e),
+                 "validation_score": 0.0
+             }
+
+     def create_checkpoint(self, data: Dict[str, Any], name: Optional[str] = None) -> str:
+         """Create a checkpoint of validation data"""
+         if name is None:
+             name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+
+         checkpoint_path = self.validation_dir / f"checkpoint_{name}.json"
+
+         # Add timestamp and hash
+         data["timestamp"] = str(datetime.datetime.now())
+         data["hash"] = hashlib.sha256(str(data).encode()).hexdigest()
+
+         with open(checkpoint_path, 'w') as f:
+             json.dump(data, f, indent=2)
+
+         return str(checkpoint_path)
+
+     def load_checkpoint(self, name: str) -> Optional[Dict[str, Any]]:
+         """Load a validation checkpoint"""
+         checkpoint_path = self.validation_dir / f"checkpoint_{name}.json"
+         if not checkpoint_path.exists():
+             return None
+
+         with open(checkpoint_path, 'r') as f:
+             return json.load(f)
+
+     def validate_dataset(self, dataset: List[Dict[str, Any]]) -> Dict[str, Any]:
+         """Validate a complete dataset"""
+         results = []
+         error_count = 0
+
+         for idx, example in enumerate(dataset):
+             try:
+                 # Validate equations
+                 if "equation" in example:
+                     eq_result = self.validate_equation(
+                         example["equation"],
+                         example.get("expected_result", "")
+                     )
+                     results.append(eq_result)
+
+                 # Validate proofs
+                 if "proof_steps" in example:
+                     proof_result = self.validate_proof(
+                         example["proof_steps"],
+                         example.get("theorem", "")
+                     )
+                     results.append(proof_result)
+             except Exception as e:
+                 error_count += 1
+                 results.append({
+                     "error": str(e),
+                     "validation_score": 0.0
+                 })
+
+         # Calculate overall metrics
+         scores = [r["validation_score"] for r in results if "validation_score" in r]
+         if scores:
+             avg_score = np.mean(scores)
+         else:
+             avg_score = 0.0
+
+         return {
+             "total_examples": len(dataset),
+             "processed_examples": len(results),
+             "error_count": error_count,
+             "average_score": float(avg_score),
+             "detailed_results": results
+         }
+
+     def save_validation_report(self, report: Dict[str, Any], name: Optional[str] = None) -> str:
+         """Save a validation report"""
+         if name is None:
+             name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+
+         report_path = self.validation_dir / f"report_{name}.json"
+
+         # Add timestamp and summary metrics
+         report["timestamp"] = str(datetime.datetime.now())
+         report["summary"] = {
+             "accuracy": report.get("average_score", 0.0),
+             "error_rate": report.get("error_count", 0) / max(report.get("total_examples", 1), 1)
+         }
+
+         with open(report_path, 'w') as f:
+             json.dump(report, f, indent=2)
+
+         return str(report_path)
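A minimal end-to-end sketch of the validator on hand-made entries (both entries are hypothetical; scores follow the rules above: an equation scores 1.0 iff the difference simplifies to zero, and each failed proof step multiplies the score by 0.9):

from validation import MathValidator

validator = MathValidator(checkpoint_dir="checkpoints")

# Equation check: both sides simplify to the same expression.
print(validator.validate_equation("(x + 1)**2", "x**2 + 2*x + 1"))  # is_correct: True

report = validator.validate_dataset([
    {"equation": "2*x + 2*x", "expected_result": "4*x"},
    {"proof_steps": ["x + x = 2*x", "2*x = 2*x"], "theorem": "2*x"},
])
print(report["average_score"], report["error_count"])  # 1.0 0

report_path = validator.save_validation_report(report)
print(f"Validation report written to {report_path}")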