Upload 11 files
Browse files- math_expert/__pycache__/config.cpython-312.pyc +0 -0
- math_expert/__pycache__/expert.cpython-312.pyc +0 -0
- math_expert/config.py +381 -0
- math_expert/data_processor.py +282 -0
- math_expert/expert.py +19 -0
- math_expert/prepare_data.py +245 -0
- math_expert/processed_data/gsm8k_processed.jsonl +0 -0
- math_expert/processed_data/proofnet_processed.jsonl +185 -0
- math_expert/requirements.txt +11 -0
- math_expert/train.py +126 -0
- math_expert/validation.py +173 -0
math_expert/__pycache__/config.cpython-312.pyc
ADDED
Binary file (9.17 kB). View file
|
|
math_expert/__pycache__/expert.cpython-312.pyc
ADDED
Binary file (1.24 kB). View file
|
|
math_expert/config.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Math Expert Configuration

# 1.1. Mathematical Domains and Specializations
#
# Every domain is registered at "expert" level, so the public mapping is
# derived from a single topics table instead of repeating the
# {"level": ..., "topics": ...} wrapper for each of the eleven domains.
_DOMAIN_TOPICS = {
    "algebra": [
        "linear algebra",
        "abstract algebra",
        "polynomial equations",
        "matrix operations",
        "group theory",
        "ring theory",
        "field theory",
        "representation theory",
        "homological algebra",
        "category theory",
        "universal algebra",
        "non-associative algebras",
        "Lie algebras",
        "quantum groups",
        "Hopf algebras",
        "K-theory",
    ],
    "calculus": [
        "single variable calculus",
        "multivariable calculus",
        "differential equations",
        "partial differential equations",
        "vector calculus",
        "complex analysis",
        "functional analysis",
        "measure theory",
        "differential geometry",
        "geometric measure theory",
        "non-standard analysis",
        "stochastic calculus",
        "calculus of variations",
        "symplectic geometry",
    ],
    "proof_writing": [
        "induction",
        "contradiction",
        "direct proof",
        "proof by cases",
        "epsilon-delta proofs",
        "existence proofs",
        "uniqueness proofs",
        "category theory proofs",
        "homotopy type theory",
        "model theory",
        "proof theory",
        "set theory",
        "constructive mathematics",
        "proof complexity",
    ],
    "probability": [
        "probability theory",
        "random variables",
        "distributions",
        "stochastic processes",
        "Bayesian inference",
        "Markov chains",
        "measure-theoretic probability",
        "stochastic calculus",
        "martingales",
        "large deviations",
        "ergodic theory",
        "random matrix theory",
        "stochastic PDEs",
    ],
    "statistics": [
        "descriptive statistics",
        "inferential statistics",
        "hypothesis testing",
        "regression analysis",
        "time series analysis",
        "bayesian statistics",
        "non-parametric methods",
        "statistical learning theory",
        "high-dimensional statistics",
        "causal inference",
        "spatial statistics",
        "robust statistics",
        "computational statistics",
    ],
    "number_theory": [
        "prime numbers",
        "modular arithmetic",
        "diophantine equations",
        "cryptography",
        "analytic number theory",
        "algebraic number theory",
        "elliptic curves",
        "automorphic forms",
        "arithmetic geometry",
        "p-adic analysis",
        "analytic continuation",
        "modular forms",
        "zeta functions",
    ],
    "geometry": [
        "euclidean geometry",
        "non-euclidean geometry",
        "differential geometry",
        "topology",
        "algebraic geometry",
        "projective geometry",
        "symplectic geometry",
        "algebraic topology",
        "geometric analysis",
        "geometric group theory",
        "Riemannian geometry",
        "Kähler geometry",
        "hyperbolic geometry",
    ],
    "combinatorics": [
        "graph theory",
        "enumerative combinatorics",
        "combinatorial optimization",
        "matroid theory",
        "combinatorial designs",
        "extremal combinatorics",
        "probabilistic combinatorics",
        "algebraic combinatorics",
        "topological combinatorics",
        "combinatorial geometry",
        "Ramsey theory",
    ],
    "logic": [
        "first-order logic",
        "model theory",
        "proof theory",
        "set theory",
        "computability theory",
        "type theory",
        "category theory",
        "modal logic",
        "temporal logic",
        "constructive logic",
        "intuitionistic logic",
        "proof complexity",
    ],
    "theoretical_cs": [
        "computational complexity",
        "algorithms",
        "cryptography",
        "quantum computing",
        "machine learning theory",
        "formal verification",
        "type systems",
        "programming language theory",
        "distributed computing",
        "parallel algorithms",
        "computational geometry",
        "randomized algorithms",
    ],
    "applied_math": [
        "numerical analysis",
        "optimization",
        "control theory",
        "mathematical physics",
        "fluid dynamics",
        "quantum mechanics",
        "relativity",
        "mathematical biology",
        "financial mathematics",
        "signal processing",
        "data assimilation",
        "inverse problems",
    ],
}

# Public mapping: domain name -> {"level": "expert", "topics": [...]}.
# Insertion order matches the topics table above.
MATH_DOMAINS = {
    domain: {"level": "expert", "topics": topics}
    for domain, topics in _DOMAIN_TOPICS.items()
}
|
204 |
+
|
205 |
+
# 1.2. Core Tasks
#
# One entry per task family the expert supports.  The dict() keyword form
# keeps the shared keys (task_type / description / example) visually aligned;
# the fourth key varies per task and enumerates its supported sub-kinds.
CORE_TASKS = [
    dict(
        task_type="problem_solving",
        description="Solve complex mathematical problems",
        example="Prove the Riemann Hypothesis",
        difficulty_levels=["basic", "intermediate", "advanced", "research_level", "open_problem"],
    ),
    dict(
        task_type="proof_writing",
        description="Prove mathematical statements with advanced techniques",
        example="Prove Fermat's Last Theorem using elliptic curves",
        proof_types=["induction", "contradiction", "direct", "cases", "category_theory", "homotopy_type", "model_theory", "proof_complexity", "constructive"],
    ),
    dict(
        task_type="calculus_computation",
        description="Perform advanced calculus operations",
        example="Solve Navier-Stokes equations for turbulence",
        operation_types=["differentiation", "integration", "limits", "functional_analysis", "measure_theory", "stochastic_calculus", "geometric_measure_theory"],
    ),
    dict(
        task_type="symbolic_computation",
        description="Manipulate complex mathematical expressions",
        example="Simplify tensor equations in general relativity",
        expression_types=["polynomial", "rational", "trigonometric", "exponential", "tensor", "operator", "Lie_algebra", "Hopf_algebra"],
    ),
    dict(
        task_type="concept_explanation",
        description="Explain advanced mathematical concepts",
        example="Explain the Langlands program",
        explanation_types=["definition", "intuition", "application", "example", "formal", "geometric", "historical", "pedagogical"],
    ),
    dict(
        task_type="statistical_analysis",
        description="Perform advanced statistical analysis",
        example="Analyze high-dimensional genomic data",
        statistical_methods=["regression", "hypothesis_testing", "confidence_intervals", "bayesian_methods", "non_parametric", "causal_inference", "computational_methods"],
    ),
    dict(
        task_type="probability_calculation",
        description="Calculate complex probabilities",
        example="Calculate phase transitions in random matrix theory",
        distributions=["binomial", "normal", "poisson", "exponential", "multivariate", "stochastic_processes", "random_matrix", "levy_processes"],
    ),
    dict(
        task_type="number_theory_problem",
        description="Solve advanced number theory problems",
        example="Prove the Birch and Swinnerton-Dyer conjecture",
        problem_types=["prime", "modular", "diophantine", "analytic", "algebraic", "elliptic_curve", "modular_form"],
    ),
    dict(
        task_type="geometric_construction",
        description="Construct and analyze complex geometric objects",
        example="Construct a Calabi-Yau manifold",
        construction_types=["euclidean", "non_euclidean", "projective", "differential", "algebraic", "symplectic", "topological"],
    ),
    dict(
        task_type="mathematical_modeling",
        description="Create advanced mathematical models",
        example="Model quantum field theory",
        model_types=["continuous", "discrete", "stochastic", "partial_differential", "non_linear", "quantum", "statistical"],
    ),
    dict(
        task_type="proof_verification",
        description="Verify complex mathematical proofs",
        example="Verify the proof of the Four Color Theorem",
        verification_methods=["formal_verification", "model_checking", "proof_assistant", "automated_reasoning", "interactive_theorem_proving"],
    ),
    dict(
        task_type="algorithm_design",
        description="Design and analyze mathematical algorithms",
        example="Design a quantum algorithm for factorization",
        algorithm_types=["numerical", "combinatorial", "geometric", "algebraic", "probabilistic", "quantum", "parallel"],
    ),
    dict(
        task_type="research_paper_analysis",
        description="Analyze and explain mathematical research papers",
        example="Explain Wiles' proof of Fermat's Last Theorem",
        analysis_types=["technical", "historical", "pedagogical", "critical", "extensional"],
    ),
    dict(
        task_type="open_problem_analysis",
        description="Analyze and make progress on open mathematical problems",
        example="Analyze the Collatz conjecture",
        problem_classes=["number_theory", "combinatorics", "analysis", "algebra", "geometry", "probability"],
    ),
    dict(
        task_type="mathematical_philosophy",
        description="Analyze philosophical aspects of mathematics",
        example="Explain the foundations of mathematics",
        philosophical_topics=["foundations", "philosophy_of_math", "logic", "set_theory", "constructivism", "intuitionism"],
    ),
    dict(
        task_type="mathematical_software_development",
        description="Develop mathematical software and algorithms",
        example="Implement a new numerical method",
        software_types=["numerical", "symbolic", "proof_assistant", "visualization", "simulation", "optimization"],
    ),
]
|
304 |
+
|
305 |
+
# Dataset Configuration
#
# Maps a short dataset key to how it is fetched from the Hugging Face Hub and
# which record fields the pipeline consumes.
# NOTE(review): several entries below ("mathlib", "arxiv_math",
# "clay_institute_problems", "open_math_problems", "math_research_papers")
# look like placeholder repo ids — confirm they resolve on the Hub before
# enabling them in the pipeline.
DATASETS = {
    "proofnet": {
        "source": "huggingface",
        # Fixed: bare "proofnet" is not a resolvable Hub repo id; use the same
        # id that prepare_data.py already uses.
        "dataset_name": "hoskinson-center/proofnet",
        "split": "train",
        "use_fields": ["problem", "solution", "proof_steps"]
    },
    "math_dataset": {
        "source": "huggingface",
        "dataset_name": "deepmind/mathematics_dataset",
        # NOTE(review): confirm "train-hard" is a real split of this dataset.
        "split": "train-hard",
        "use_fields": ["question", "answer", "steps"]
    },
    "gsm8k": {
        "source": "huggingface",
        "dataset_name": "gsm8k",
        # Fixed: gsm8k requires a config name; prepare_data.py uses "main".
        "config": "main",
        "split": "train",
        "use_fields": ["question", "answer"]
    },
    "mathlib": {
        "source": "huggingface",
        "dataset_name": "mathlib",
        "split": "train",
        "use_fields": ["theorem", "proof", "dependencies"]
    },
    "arxiv_math": {
        "source": "huggingface",
        "dataset_name": "arxiv_math",
        "split": "train",
        "use_fields": ["paper", "equations", "proofs"]
    },
    "clay_institute": {
        "source": "huggingface",
        "dataset_name": "clay_institute_problems",
        "split": "train",
        "use_fields": ["problem", "background", "current_status", "approaches"]
    },
    "open_problems": {
        "source": "huggingface",
        "dataset_name": "open_math_problems",
        "split": "train",
        "use_fields": ["problem", "category", "history", "attempts"]
    },
    "research_papers": {
        "source": "huggingface",
        "dataset_name": "math_research_papers",
        "split": "train",
        "use_fields": ["title", "abstract", "content", "proofs", "theorems"]
    }
}
|
356 |
+
|
357 |
+
# Data Processing Configuration

# Backend used to normalize each kind of mathematical content.
_NORMALIZATION_BACKENDS = {
    "equations": "sympy",
    "latex": "plaintext",
    "proof_steps": "yaml",
    "tensor_operations": "torch",
    "quantum_operations": "qiskit",
    "geometric_objects": "geometric_algebra",
    "category_theory": "category_theory",
}

# Hard limits applied when validating a processed entry.
_VALIDATION_LIMITS = {
    "min_steps": 2,
    "max_steps": 200,
    "min_length": 10,
    "max_length": 100000,
}

DATA_PROCESSING = {
    "format": "jsonl",
    "normalization": _NORMALIZATION_BACKENDS,
    "validation": _VALIDATION_LIMITS,
}
|
376 |
+
|
377 |
+
if __name__ == "__main__":
    # Quick sanity report when the config module is run directly.
    print("Math Expert Configuration Loaded")
    for label, collection in (
        ("domains", MATH_DOMAINS),
        ("tasks", CORE_TASKS),
        ("datasets", DATASETS),
    ):
        print(f"Number of {label}: {len(collection)}")
|
math_expert/data_processor.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import yaml
|
3 |
+
import sympy
|
4 |
+
from sympy.parsing.latex import parse_latex
|
5 |
+
from huggingface_hub import hf_hub_download
|
6 |
+
from pathlib import Path
|
7 |
+
import jsonlines
|
8 |
+
from typing import Dict, List, Any
|
9 |
+
|
10 |
+
from config import DATASETS, DATA_PROCESSING
|
11 |
+
|
12 |
+
class MathDataProcessor:
    """Download, normalize and validate the configured math datasets.

    Equations are canonicalized with sympy, proof steps are parsed as YAML
    when possible, and every entry is checked against the limits declared in
    ``DATA_PROCESSING["validation"]`` before being kept.
    """

    def __init__(self):
        # Entries that passed validation, in processing order.
        self.processed_data = []
        # dataset name -> local directory it was downloaded to.
        self.dataset_paths = {}
        # Dispatch table used by process_math_operation().
        self.math_operations = {
            "differentiation": self._process_differentiation,
            "integration": self._process_integration,
            "limits": self._process_limits,
            "simplification": self._process_simplification,
            "matrix": self._process_matrix,
            "probability": self._process_probability,
            "statistics": self._process_statistics,
        }

    def download_dataset(self, dataset_name: str) -> Path:
        """Download a dataset from the Hugging Face Hub into ``data/<name>``.

        Raises:
            ValueError: if *dataset_name* is not declared in DATASETS.
        """
        if dataset_name not in DATASETS:
            raise ValueError(f"Dataset {dataset_name} not defined in configuration")

        dataset_config = DATASETS[dataset_name]
        dataset_path = Path(f"data/{dataset_name}")

        # NOTE(review): assumes each Hub repo ships a "<split>.jsonl" file —
        # confirm against the actual repositories.
        hf_hub_download(
            repo_id=dataset_config["dataset_name"],
            filename=f"{dataset_config['split']}.jsonl",
            local_dir=dataset_path,
        )

        self.dataset_paths[dataset_name] = dataset_path
        return dataset_path

    def normalize_equation(self, equation: str) -> str:
        """Return the sympy-canonical form of *equation*, or the input on failure."""
        try:
            # A backslash is taken as evidence that the input is LaTeX.
            if "\\" in equation:
                eq = parse_latex(equation)
            else:
                eq = sympy.sympify(equation)
            return str(eq)
        # Fixed: narrowed from a bare "except:" so SystemExit /
        # KeyboardInterrupt are no longer swallowed (same below).
        except Exception:
            return equation

    def process_proof_steps(self, steps: List[str]) -> List[Dict[str, str]]:
        """Parse each proof step as YAML when possible, else wrap it as text."""
        processed_steps = []
        for step in steps:
            try:
                structured_step = yaml.safe_load(step)
                if isinstance(structured_step, dict):
                    processed_steps.append(structured_step)
                else:
                    processed_steps.append({"step": step})
            except Exception:  # malformed YAML -> keep the raw text
                processed_steps.append({"step": step})
        return processed_steps

    def _process_differentiation(self, expression: str) -> str:
        """Differentiate *expression* w.r.t. x; return the input on failure."""
        x = sympy.Symbol('x')
        try:
            return str(sympy.diff(sympy.sympify(expression), x))
        except Exception:
            return expression

    def _process_integration(self, expression: str) -> str:
        """Integrate *expression* w.r.t. x; return the input on failure."""
        x = sympy.Symbol('x')
        try:
            return str(sympy.integrate(sympy.sympify(expression), x))
        except Exception:
            return expression

    def _process_limits(self, expression: str) -> str:
        """Take the limit of *expression* as x -> oo; return the input on failure."""
        x = sympy.Symbol('x')
        try:
            return str(sympy.limit(sympy.sympify(expression), x, sympy.oo))
        except Exception:
            return expression

    def _process_simplification(self, expression: str) -> str:
        """Simplify *expression* with sympy; return the input on failure."""
        try:
            return str(sympy.simplify(sympy.sympify(expression)))
        except Exception:
            return expression

    def _process_matrix(self, matrix_str: str) -> str:
        """Parse a ";"-separated-row / space-separated-column matrix string."""
        try:
            matrix = sympy.Matrix([[float(n) for n in row.split()]
                                   for row in matrix_str.split(';')])
            return str(matrix)
        except Exception:
            return matrix_str

    def _process_probability(self, problem: str) -> Dict:
        """Classify a probability problem and pull out its parameters."""
        try:
            if "probability" in problem.lower():
                return {
                    "type": "probability",
                    "parameters": self._extract_parameters(problem),
                    "distribution": self._identify_distribution(problem),
                }
            return {"type": "unknown"}
        except Exception:
            return {"type": "unknown"}

    def _process_statistics(self, data: str) -> Dict:
        """Compute mean / median / population std-dev for a comma-separated list."""
        try:
            if "," in data:
                numbers = [float(n) for n in data.split(',')]
                ordered = sorted(numbers)
                mid = len(ordered) // 2
                # Fixed: the median of an even-length sample is the mean of
                # the two middle values; the original always took the upper one.
                if len(ordered) % 2:
                    median = ordered[mid]
                else:
                    median = (ordered[mid - 1] + ordered[mid]) / 2
                return {
                    "mean": sum(numbers) / len(numbers),
                    "median": median,
                    "std_dev": self._calculate_std_dev(numbers),
                }
            return {"error": "Invalid data format"}
        except Exception:
            return {"error": "Processing failed"}

    def _extract_parameters(self, text: str) -> Dict:
        """Split "lhs = value" text into its two sides."""
        parameters = {}
        if "=" in text:
            # Fixed: split only on the first "=" so values that themselves
            # contain "=" (e.g. chained equalities) are preserved intact.
            lhs, rhs = text.split("=", 1)
            parameters["equation"] = lhs.strip()
            parameters["value"] = rhs.strip()
        return parameters

    def _identify_distribution(self, text: str) -> str:
        """Guess the probability distribution named in *text* by keyword match."""
        distributions = {
            "binomial": ["binomial", "bernoulli"],
            "normal": ["normal", "gaussian"],
            "poisson": ["poisson"],
            "exponential": ["exponential"],
        }
        text_lower = text.lower()
        for dist, keywords in distributions.items():
            if any(keyword in text_lower for keyword in keywords):
                return dist
        return "unknown"

    def _calculate_std_dev(self, numbers: List[float]) -> float:
        """Population standard deviation of *numbers* (divides by N, not N-1)."""
        mean = sum(numbers) / len(numbers)
        variance = sum((x - mean) ** 2 for x in numbers) / len(numbers)
        return variance ** 0.5

    def process_math_operation(self, operation_type: str, content: str) -> Any:
        """Dispatch *content* to its handler; unknown operations pass through."""
        handler = self.math_operations.get(operation_type)
        return handler(content) if handler else content

    def validate_entry(self, entry: Dict[str, Any]) -> bool:
        """Validate an entry against size limits and basic mathematical checks."""
        steps = entry.get("steps", [])
        text = entry.get("question", "") + entry.get("answer", "")
        limits = DATA_PROCESSING["validation"]

        # Size limits.  Fixed: max_steps / max_length are declared in the
        # config but were never enforced by the original.
        if not (limits["min_steps"] <= len(steps) <= limits["max_steps"]):
            return False
        if not (limits["min_length"] <= len(text) <= limits["max_length"]):
            return False

        try:
            # Any declared equation must at least parse.
            if "equation" in entry:
                sympy.sympify(entry["equation"])

            # Consecutive steps must chain (rhs of one == lhs of the next).
            for current, following in zip(steps, steps[1:]):
                if not self._check_step_continuity(current, following):
                    return False

            # Cheap textual sanity checks on proofs.
            if "proof" in entry and not self._check_proof_validity(entry["proof"]):
                return False

            return True
        except Exception:
            return False

    def _check_step_continuity(self, step1: str, step2: str) -> bool:
        """Heuristic: the rhs of *step1* should reappear as the lhs of *step2*."""
        try:
            if "=" in step1 and "=" in step2:
                s1 = step1.split("=")[1].strip()
                s2 = step2.split("=")[0].strip()
                return s1 == s2
            return True
        except Exception:
            return False

    def _check_proof_validity(self, proof: str) -> bool:
        """Cheap textual sanity checks on a proof (not a real logic check)."""
        lowered = proof.lower()
        # An assumption that never reaches a conclusion is suspicious.
        if "assume" in lowered and "therefore" not in lowered:
            return False
        # A proof by contradiction should actually reach a falsehood.
        if "contradiction" in lowered and "false" not in lowered:
            return False
        return True

    def process_dataset(self, dataset_name: str):
        """Download, normalize, validate and collect one configured dataset."""
        dataset_path = self.download_dataset(dataset_name)
        dataset_config = DATASETS[dataset_name]

        with jsonlines.open(dataset_path / f"{dataset_config['split']}.jsonl") as reader:
            for entry in reader:
                processed_entry = {}

                # Copy over the configured fields, normalizing where needed.
                for field in dataset_config["use_fields"]:
                    value = entry.get(field)
                    if value:
                        if field == "equation":
                            processed_entry[field] = self.normalize_equation(value)
                        elif field == "proof_steps":
                            processed_entry[field] = self.process_proof_steps(value)
                        else:
                            processed_entry[field] = value

                if self.validate_entry(processed_entry):
                    self.processed_data.append(processed_entry)

    def save_processed_data(self, output_path: str):
        """Write all collected entries to *output_path* as JSON Lines."""
        with jsonlines.open(output_path, mode='w') as writer:
            writer.write_all(self.processed_data)
|
272 |
+
|
273 |
+
if __name__ == "__main__":
    processor = MathDataProcessor()

    # Run every configured dataset through the pipeline, then persist the
    # surviving entries as a single JSONL file.
    for dataset_name in DATASETS:
        processor.process_dataset(dataset_name)

    processor.save_processed_data("processed_data/math_expert_data.jsonl")
|
math_expert/expert.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Math Expert Module
|
3 |
+
"""
|
4 |
+
from typing import Dict, Any, List
|
5 |
+
|
6 |
+
class MathExpert:
    """Minimal math expert stub: answers every query with a canned response."""

    def __init__(self):
        # Identifier used by the expert registry, and the domains it claims.
        self.name = "math"
        self.domains = ["mathematics", "calculus", "algebra"]

    def handle_query(self, query: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """Return a stub answer for *query*; *context* is currently unused."""
        reply = f"Math expert response to: {query}"
        return {
            'response': reply,
            'confidence': 0.9,
            'metadata': {'domains': self.domains},
        }

    def get_domains(self) -> List[str]:
        """List the domains this expert handles."""
        return self.domains
|
math_expert/prepare_data.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
import sympy
|
5 |
+
from sympy.parsing.latex import parse_latex
|
6 |
+
from sympy.parsing.sympy_parser import parse_expr
|
7 |
+
from datasets import load_dataset
|
8 |
+
import jsonlines
|
9 |
+
from typing import Dict, List, Any
|
10 |
+
import sys
|
11 |
+
import psutil
|
12 |
+
|
13 |
+
class MathDataPreparer:
    """Prepares GSM8K and ProofNet data for math-expert training."""

    def __init__(self, output_dir: str = "processed_data"):
        # Ensure the destination directory exists up front.
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)
        # Hugging Face dataset descriptors: hub id, optional config, split,
        # and the record fields each processor consumes.
        gsm8k_spec = {
            "source": "gsm8k",
            "config": "main",
            "split": "train",
            "fields": ["question", "answer"],
        }
        proofnet_spec = {
            "source": "hoskinson-center/proofnet",
            "split": "validation",
            "fields": ["problem", "solution", "proof_steps"],
        }
        self.datasets = {"gsm8k": gsm8k_spec, "proofnet": proofnet_spec}
|
30 |
+
|
31 |
+
def normalize_equation(self, equation: str) -> str:
|
32 |
+
"""Normalize mathematical equations using sympy"""
|
33 |
+
try:
|
34 |
+
# Try LaTeX first
|
35 |
+
if "\\" in equation:
|
36 |
+
eq = parse_latex(equation)
|
37 |
+
# Then try markdown math
|
38 |
+
elif equation.startswith('$') and equation.endswith('$'):
|
39 |
+
eq = parse_expr(equation[1:-1])
|
40 |
+
# Then try regular expression
|
41 |
+
else:
|
42 |
+
eq = parse_expr(equation)
|
43 |
+
return str(eq)
|
44 |
+
except Exception as e:
|
45 |
+
print(f"Error normalizing equation: {equation}", file=sys.stderr)
|
46 |
+
return equation
|
47 |
+
|
48 |
+
def process_proof_steps(self, steps: List[str]) -> List[Dict[str, Any]]:
|
49 |
+
"""Process and validate proof steps"""
|
50 |
+
processed_steps = []
|
51 |
+
for step in steps:
|
52 |
+
try:
|
53 |
+
# Basic validation
|
54 |
+
if not step.strip():
|
55 |
+
continue
|
56 |
+
|
57 |
+
# Try to parse as structured data
|
58 |
+
try:
|
59 |
+
structured_step = json.loads(step)
|
60 |
+
if isinstance(structured_step, dict):
|
61 |
+
processed_steps.append(structured_step)
|
62 |
+
continue
|
63 |
+
except json.JSONDecodeError:
|
64 |
+
pass
|
65 |
+
|
66 |
+
# Process as plain text
|
67 |
+
processed_steps.append({
|
68 |
+
"text": step.strip(),
|
69 |
+
"valid": True
|
70 |
+
})
|
71 |
+
except Exception as e:
|
72 |
+
print(f"Error processing proof step: {step}", file=sys.stderr)
|
73 |
+
processed_steps.append({
|
74 |
+
"text": step,
|
75 |
+
"valid": False,
|
76 |
+
"error": str(e)
|
77 |
+
})
|
78 |
+
return processed_steps
|
79 |
+
|
80 |
+
def process_gsm8k(self, dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Clean GSM8K examples: strip whitespace and normalize any equations.

    An example that raises during processing is skipped with a message on
    stderr rather than aborting the whole run.
    """
    cleaned: List[Dict[str, Any]] = []
    for sample in dataset:
        try:
            record = {
                "question": sample["question"].strip(),
                "answer": sample["answer"].strip(),
            }

            # Only invoke the normalizer on text that actually looks like
            # it contains an equation.
            for key in ("question", "answer"):
                if "=" in record[key]:
                    record[key] = self.normalize_equation(record[key])

            cleaned.append(record)
        except Exception as exc:
            print(f"Error processing GSM8K example: {str(exc)}", file=sys.stderr)
    return cleaned
|
102 |
+
|
103 |
+
def process_proofnet(self, dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Process ProofNet examples into {problem, solution, proof_steps} records.

    The upstream ``hoskinson-center/proofnet`` split exposes
    ``nl_statement`` / ``nl_proof`` field names rather than ``problem`` /
    ``solution``; both spellings are accepted here.  (Reading only the
    latter previously produced a file of all-empty records.)  Dataset- and
    example-level diagnostics are printed to help debug schema mismatches.
    """
    processed: List[Dict[str, Any]] = []
    error_count = 0

    # Dataset-level diagnostics.
    print("\nProofNet dataset info:")
    print(f"Dataset type: {type(dataset)}")
    if hasattr(dataset, 'features'):
        print("\nDataset features:")
        for feature, dtype in dataset.features.items():
            print(f"{feature}: {dtype}")

    # First-example diagnostics.
    if len(dataset) > 0:
        first_example = dataset[0]
        print("\nFirst example keys:", list(first_example.keys()))
        print("\nFirst example preview:")
        for key, value in first_example.items():
            print(f"\n{key}:")
            print(f"Type: {type(value)}")
            if isinstance(value, str):
                print(f"Length: {len(value)}")
            elif isinstance(value, list):
                print(f"List length: {len(value)}")
                if len(value) > 0:
                    print(f"First item type: {type(value[0])}")
        print("\n")

    for idx, example in enumerate(dataset):
        try:
            # Accept both the canonical field names and ProofNet's native
            # ones.  NOTE(review): nl_statement/nl_proof mapping assumed
            # from the upstream dataset card — confirm against a sample.
            processed_example = {
                "problem": (example.get("problem") or example.get("nl_statement") or "").strip(),
                "solution": (example.get("solution") or example.get("nl_proof") or "").strip(),
                "proof_steps": []
            }

            # Proof steps may arrive as a newline-joined blob or a list.
            if "proof_steps" in example:
                steps = example["proof_steps"]
                print(f"\nExample {idx} proof steps info:")
                print(f"Type: {type(steps)}")
                if isinstance(steps, str):
                    print(f"Length: {len(steps)}")
                    steps = steps.split('\n')
                    print(f"Split into {len(steps)} steps")
                elif isinstance(steps, list):
                    print(f"List length: {len(steps)}")
                    if len(steps) > 0:
                        print(f"First item type: {type(steps[0])}")
                else:
                    print(f"Warning: Unexpected proof steps type: {type(steps)}")
                    steps = []

                processed_example["proof_steps"] = self.process_proof_steps(steps)

            # Normalize anything that looks like an equation.
            for field in ["problem", "solution"]:
                if "=" in processed_example[field]:
                    try:
                        processed_example[field] = self.normalize_equation(processed_example[field])
                    except Exception as e:
                        print(f"Error normalizing {field} in ProofNet example {idx}: {str(e)}")

            processed.append(processed_example)
        except Exception as e:
            print(f"Error processing ProofNet example {idx}: {str(e)}")
            error_count += 1

    print(f"\nProcessed {len(processed)} examples from ProofNet")
    print(f"Encountered {error_count} errors during processing")
    return processed
|
176 |
+
|
177 |
+
def save_to_jsonl(self, data: List[Dict[str, Any]], filename: str):
    """Write *data* to ``self.output_dir/filename`` as JSON Lines.

    One JSON object per line, using the standard-library ``json`` module
    (drops the third-party ``jsonlines`` dependency for this one-line
    format while producing identical output).  Returns the path written.
    """
    filepath = self.output_dir / filename
    with open(filepath, 'w', encoding='utf-8') as fh:
        for record in data:
            fh.write(json.dumps(record) + "\n")
    return filepath
|
183 |
+
|
184 |
+
def print_memory_usage(self):
    """Report this process's resident set size (RSS) in megabytes."""
    rss_bytes = psutil.Process().memory_info().rss
    print(f"Current memory usage: {rss_bytes / 1024 / 1024:.2f} MB")
|
189 |
+
|
190 |
+
def print_sample(self, data: List[Dict[str, Any]], count: int = 3):
    """Pretty-print up to *count* examples from *data*.

    ProofNet-style records (those carrying a ``proof_steps`` key) get a
    field-by-field rendering; anything else (GSM8K-style) is dumped as
    indented JSON.
    """
    print("\nSample data:")
    for i, example in enumerate(data[:count]):
        print(f"\nSample {i+1}:")
        if "proof_steps" not in example:
            # GSM8K-style record: a plain JSON dump is clearest.
            print(json.dumps(example, indent=2))
            continue
        # ProofNet-style record.
        print(f"Problem: {example['problem']}")
        print(f"Solution: {example['solution']}")
        print("\nProof Steps:")
        for step in example["proof_steps"]:
            print(f"- {step['text']}")
|
205 |
+
|
206 |
+
def main():
    """Download, process, preview, and persist the GSM8K and ProofNet datasets.

    GSM8K is treated as mandatory; a ProofNet failure is logged and the run
    continues with GSM8K only.  Outputs go to the preparer's output
    directory as JSONL files.
    """
    preparer = MathDataPreparer()

    # Load and process GSM8K (required — any failure here aborts the run).
    print("\nProcessing GSM8K dataset...")
    gsm8k_dataset = load_dataset("gsm8k", "main", split="train")
    print(f"Loaded {len(gsm8k_dataset)} samples from GSM8K")

    processed_gsm8k = preparer.process_gsm8k(gsm8k_dataset)
    print(f"Processed {len(processed_gsm8k)} samples")

    # Show a few examples so the operator can eyeball the formatting.
    preparer.print_sample(processed_gsm8k)

    # Save GSM8K
    gsm8k_path = preparer.save_to_jsonl(processed_gsm8k, "gsm8k_processed.jsonl")
    print(f"\nSaved GSM8K processed data to: {gsm8k_path}")

    # Load and process ProofNet (best-effort: failures are caught below).
    print("\nProcessing ProofNet dataset...")
    try:
        proofnet_dataset = load_dataset("hoskinson-center/proofnet", split="validation")
        print(f"Loaded {len(proofnet_dataset)} samples from ProofNet")

        processed_proofnet = preparer.process_proofnet(proofnet_dataset)
        print(f"Processed {len(processed_proofnet)} samples")

        preparer.print_sample(processed_proofnet)

        # Save ProofNet
        proofnet_path = preparer.save_to_jsonl(processed_proofnet, "proofnet_processed.jsonl")
        print(f"\nSaved ProofNet processed data to: {proofnet_path}")
    except Exception as e:
        # Broad catch is deliberate: network/schema problems with ProofNet
        # should not discard the already-saved GSM8K output.
        print(f"Error processing ProofNet dataset: {str(e)}")
        print("Continuing with GSM8K data only")

    # Print memory usage
    preparer.print_memory_usage()
|
243 |
+
|
244 |
+
# Run the data-preparation pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
math_expert/processed_data/gsm8k_processed.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
math_expert/processed_data/proofnet_processed.jsonl
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
2 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
3 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
4 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
5 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
6 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
7 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
8 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
9 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
10 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
11 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
12 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
13 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
14 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
15 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
16 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
17 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
18 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
19 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
20 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
21 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
22 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
23 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
24 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
25 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
26 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
27 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
28 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
29 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
30 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
31 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
32 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
33 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
34 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
35 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
36 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
37 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
38 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
39 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
40 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
41 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
42 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
43 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
44 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
45 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
46 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
47 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
48 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
49 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
50 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
51 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
52 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
53 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
54 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
55 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
56 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
57 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
58 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
59 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
60 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
61 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
62 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
63 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
64 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
65 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
66 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
67 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
68 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
69 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
70 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
71 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
72 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
73 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
74 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
75 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
76 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
77 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
78 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
79 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
80 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
81 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
82 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
83 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
84 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
85 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
86 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
87 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
88 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
89 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
90 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
91 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
92 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
93 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
94 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
95 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
96 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
97 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
98 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
99 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
100 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
101 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
102 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
103 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
104 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
105 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
106 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
107 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
108 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
109 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
110 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
111 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
112 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
113 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
114 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
115 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
116 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
117 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
118 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
119 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
120 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
121 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
122 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
123 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
124 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
125 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
126 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
127 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
128 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
129 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
130 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
131 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
132 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
133 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
134 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
135 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
136 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
137 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
138 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
139 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
140 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
141 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
142 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
143 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
144 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
145 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
146 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
147 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
148 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
149 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
150 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
151 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
152 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
153 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
154 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
155 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
156 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
157 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
158 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
159 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
160 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
161 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
162 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
163 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
164 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
165 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
166 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
167 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
168 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
169 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
170 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
171 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
172 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
173 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
174 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
175 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
176 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
177 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
178 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
179 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
180 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
181 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
182 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
183 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
184 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
185 |
+
{"problem": "", "solution": "", "proof_steps": []}
|
math_expert/requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers>=4.30.0
|
2 |
+
sympy>=1.11.1
|
3 |
+
torch>=2.0.0
|
4 |
+
numpy>=1.24.0
|
5 |
+
scipy>=1.10.0
|
6 |
+
pandas>=2.0.0
|
7 |
+
huggingface_hub>=0.16.0
|
8 |
+
jsonlines>=3.0.0
|
9 |
+
pyyaml>=5.4.1
|
10 |
+
datasets>=2.14.0
|
11 |
+
psutil>=5.9.0
|
math_expert/train.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import TrainingArguments, Trainer
|
2 |
+
from datasets import load_dataset
|
3 |
+
import jsonlines
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from model import Transformer, ModelArgs
|
7 |
+
from tokenizer import Tokenizer
|
8 |
+
|
9 |
+
class MathDataset(torch.utils.data.Dataset):
    """Causal-LM dataset over mixed GSM8K / ProofNet JSONL files.

    Each item is rendered to a single training string, tokenized to a fixed
    length, and returned as ``input_ids`` / ``attention_mask`` / ``labels``.
    Labels mirror the input ids, but padding positions are masked with -100
    so they are excluded from the cross-entropy loss (previously the raw
    padded ids were used as labels, training the model to emit pad tokens).
    """

    def __init__(self, tokenizer, data_paths, max_length=512):
        """Load and concatenate every JSONL file in *data_paths*.

        tokenizer: HF-style callable tokenizer with padding support.
        data_paths: iterable of JSONL file paths.
        max_length: fixed sequence length for padding/truncation.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        for path in data_paths:
            with jsonlines.open(path) as reader:
                self.data.extend(list(reader))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]

        # Render the example as one training string.
        if "proof_steps" in example:
            # ProofNet-style record.
            text = f"Problem: {example['problem']}\nSolution: {example['solution']}\nProof Steps:\n"
            for step in example["proof_steps"]:
                text += f"- {step['text']}\n"
        else:
            # GSM8K-style record.
            text = f"Question: {example['question']}\nAnswer: {example['answer']}"

        inputs = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # Drop the batch dimension added by return_tensors="pt".
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # Causal-LM labels: copy the ids and mask padding with the
        # conventional ignore-index -100 so pad tokens carry no loss.
        labels = inputs["input_ids"].clone()
        labels[inputs["attention_mask"] == 0] = -100

        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "labels": labels
        }
|
53 |
+
|
54 |
+
def main():
    """Train the custom Transformer on the processed math datasets.

    Builds the model and tokenizer, assembles the combined GSM8K/ProofNet
    dataset, runs a Hugging Face ``Trainer`` loop, and saves the weights
    and config at the end.
    """
    # Initialize the custom model.
    model_args = ModelArgs(
        dim=512,
        n_layers=8,
        n_heads=8,
        vocab_size=50000,  # NOTE(review): must match the tokenizer's actual vocab size — confirm
        max_seq_len=1024
    )
    model = Transformer(model_args)

    # Initialize the custom tokenizer.
    # NOTE(review): assumes the custom Tokenizer exposes HF-style
    # pad_token / eos_token attributes — confirm against tokenizer.py.
    tokenizer = Tokenizer()
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Training data paths (relative to this file).
    data_dir = os.path.join(os.path.dirname(__file__), "processed_data")
    data_paths = [
        os.path.join(data_dir, "gsm8k_processed.jsonl"),
        os.path.join(data_dir, "proofnet_processed.jsonl")
    ]

    dataset = MathDataset(
        tokenizer=tokenizer,
        data_paths=data_paths,
        max_length=1024  # longer sequences to fit full proofs
    )

    # Training configuration.  Evaluation is disabled: no eval dataset is
    # provided, and the previous evaluation_strategy="steps" +
    # load_best_model_at_end=True combination makes Trainer raise at
    # construction time when eval_dataset is None.
    training_args = TrainingArguments(
        output_dir="./math_expert_output",
        overwrite_output_dir=True,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        save_steps=1000,
        save_total_limit=2,
        logging_dir="./math_expert_logs",
        logging_steps=100,
        learning_rate=5e-5,
        warmup_steps=500,
        weight_decay=0.01,
        fp16=torch.cuda.is_available()  # mixed precision only on GPU
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )

    print("Starting training with your custom model...")
    trainer.train()

    # Persist weights and config.
    output_dir = "./math_expert_model"
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
    model_args.save(os.path.join(output_dir, "config.json"))  # assumes ModelArgs provides save() — TODO confirm
    print(f"Model saved to {output_dir}")
|
124 |
+
|
125 |
+
# Script entry point: kick off training only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
math_expert/validation.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Validation module for the Math Expert model
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
from pathlib import Path
|
7 |
+
import hashlib
|
8 |
+
import datetime
|
9 |
+
from typing import Dict, Any, List, Optional
|
10 |
+
import numpy as np
|
11 |
+
from sympy import simplify, Eq
|
12 |
+
|
13 |
+
class MathValidator:
    """Validation utilities for Math Expert outputs.

    Checks equations and proof steps symbolically (via sympy) and persists
    validation checkpoints and reports as JSON files under
    ``<checkpoint_dir>/validation``.
    """

    def __init__(self, checkpoint_dir: str = "checkpoints"):
        """Create the checkpoint and validation directories if needed."""
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(exist_ok=True)
        self.validation_dir = self.checkpoint_dir / "validation"
        self.validation_dir.mkdir(exist_ok=True)

        # Rolling metric accumulators (appended to by callers).
        self.metrics = {
            "accuracy": [],
            "equation_simplification": [],
            "proof_validation": [],
            "memory_usage": []
        }

    def validate_equation(self, equation: str, expected_result: str) -> Dict[str, Any]:
        """Check whether *equation* and *expected_result* simplify equally.

        Returns ``is_correct``, both simplified forms, and a 0/1
        ``validation_score``; parse failures yield score 0.0 plus the
        error message.
        """
        try:
            lhs = simplify(equation)
            rhs = simplify(expected_result)
            # Structural equality of the simplified sympy expressions.
            is_correct = lhs == rhs
            return {
                "is_correct": is_correct,
                "simplified_lhs": str(lhs),
                "simplified_rhs": str(rhs),
                "validation_score": float(is_correct)
            }
        except Exception as e:
            return {
                "is_correct": False,
                "error": str(e),
                "validation_score": 0.0
            }

    def validate_proof(self, proof_steps: List[str], expected_theorem: str) -> Dict[str, Any]:
        """Heuristically score a sequence of proof steps.

        Steps of the form ``lhs = rhs`` are checked symbolically and each
        incorrect one multiplies the score by 0.9; non-equation steps are
        skipped.  The proof counts as valid while the score stays above
        0.5.  An empty step list is reported as invalid explicitly
        (previously it fell through to ``proof_steps[-1]`` and surfaced
        only as an opaque IndexError payload).
        """
        try:
            if not proof_steps:
                return {
                    "is_valid": False,
                    "validation_score": 0.0,
                    "matches_theorem": False,
                    "context_size": 0
                }

            seen = set()
            validation_score = 1.0

            for step in proof_steps:
                try:
                    lhs, rhs = step.split('=')
                    # Truth-testing a sympy Eq can raise TypeError for
                    # unevaluated relations, so compare the difference
                    # of both sides to zero instead.
                    if simplify(f"({lhs}) - ({rhs})") != 0:
                        validation_score *= 0.9  # penalize incorrect steps
                except Exception:
                    pass  # not every step is a single equation
                seen.add(step)

            # Does the final step state the expected theorem?
            matches_theorem = expected_theorem in proof_steps[-1]

            return {
                "is_valid": validation_score > 0.5,
                "validation_score": validation_score,
                "matches_theorem": matches_theorem,
                "context_size": len(seen)
            }
        except Exception as e:
            return {
                "is_valid": False,
                "error": str(e),
                "validation_score": 0.0
            }

    def create_checkpoint(self, data: Dict[str, Any], name: Optional[str] = None) -> str:
        """Persist *data* (plus a timestamp and content hash) as a checkpoint.

        NOTE: mutates *data* in place by adding ``timestamp`` and ``hash``.
        Returns the path of the written JSON file.
        """
        if name is None:
            name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

        checkpoint_path = self.validation_dir / f"checkpoint_{name}.json"

        data["timestamp"] = str(datetime.datetime.now())
        data["hash"] = hashlib.sha256(str(data).encode()).hexdigest()

        with open(checkpoint_path, 'w') as f:
            json.dump(data, f, indent=2)

        return str(checkpoint_path)

    def load_checkpoint(self, name: str) -> Optional[Dict[str, Any]]:
        """Load a checkpoint by name, or return None if it does not exist."""
        checkpoint_path = self.validation_dir / f"checkpoint_{name}.json"
        if not checkpoint_path.exists():
            return None

        with open(checkpoint_path, 'r') as f:
            return json.load(f)

    def validate_dataset(self, dataset: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Validate every example, collecting per-item results and a mean score."""
        results = []
        error_count = 0

        for example in dataset:
            try:
                if "equation" in example:
                    results.append(self.validate_equation(
                        example["equation"],
                        example.get("expected_result", "")
                    ))
                if "proof_steps" in example:
                    results.append(self.validate_proof(
                        example["proof_steps"],
                        example.get("theorem", "")
                    ))
            except Exception as e:
                error_count += 1
                results.append({
                    "error": str(e),
                    "validation_score": 0.0
                })

        # Overall metric: mean of every per-item validation score.
        scores = [r["validation_score"] for r in results if "validation_score" in r]
        avg_score = np.mean(scores) if scores else 0.0

        return {
            "total_examples": len(dataset),
            "processed_examples": len(results),
            "error_count": error_count,
            "average_score": float(avg_score),
            "detailed_results": results
        }

    def save_validation_report(self, report: Dict[str, Any], name: Optional[str] = None) -> str:
        """Write *report* (plus a timestamp and summary metrics) to disk.

        Guards against division by zero when the report covers zero
        examples: the previous ``report.get("total_examples", 1)`` only
        protected against a *missing* key, not an explicit 0.
        Returns the path of the written report.
        """
        if name is None:
            name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

        report_path = self.validation_dir / f"report_{name}.json"

        total = report.get("total_examples", 0)
        report["timestamp"] = str(datetime.datetime.now())
        report["summary"] = {
            "accuracy": report.get("average_score", 0.0),
            "error_rate": (report.get("error_count", 0) / total) if total else 0.0
        }

        with open(report_path, 'w') as f:
            json.dump(report, f, indent=2)

        return str(report_path)
|