Sarthak committed
Commit 1bc7e54 · 1 Parent(s): 454e47c

chore: remove unused scripts and update dependencies

This commit removes the MTEB evaluation script, the distillation script, and the evaluation script, as they are no longer needed. It also updates pyproject.toml to drop the dependencies used only by those scripts and adds typing-extensions to the dependencies.

Files changed (34)
  1. .codemap.yml +294 -0
  2. MTEB_evaluate.py +0 -343
  3. REPORT.md +299 -0
  4. Taskfile.yml +23 -0
  5. analysis_charts/batch_size_scaling.png +3 -0
  6. analysis_charts/benchmark_performance.png +3 -0
  7. analysis_charts/code_performance_radar.png +3 -0
  8. analysis_charts/comparative_radar.png +3 -0
  9. analysis_charts/efficiency_analysis.png +3 -0
  10. analysis_charts/language_heatmap.png +3 -0
  11. analysis_charts/memory_scaling.png +3 -0
  12. analysis_charts/model_comparison.png +3 -0
  13. analysis_charts/model_specifications.png +3 -0
  14. analysis_charts/peer_comparison.png +3 -0
  15. analysis_charts/radar_code_model2vec_Linq_Embed_Mistral.png +3 -0
  16. analysis_charts/radar_code_model2vec_Qodo_Embed_1_15B.png +3 -0
  17. analysis_charts/radar_code_model2vec_Reason_ModernColBERT.png +3 -0
  18. analysis_charts/radar_code_model2vec_all_MiniLM_L6_v2.png +3 -0
  19. analysis_charts/radar_code_model2vec_all_mpnet_base_v2.png +3 -0
  20. analysis_charts/radar_code_model2vec_bge_m3.png +3 -0
  21. analysis_charts/radar_code_model2vec_codebert_base.png +3 -0
  22. analysis_charts/radar_code_model2vec_graphcodebert_base.png +3 -0
  23. analysis_charts/radar_code_model2vec_gte_Qwen2_15B_instruct.png +3 -0
  24. analysis_charts/radar_code_model2vec_gte_Qwen2_7B_instruct.png +3 -0
  25. analysis_charts/radar_code_model2vec_jina_embeddings_v2_base_code.png +3 -0
  26. analysis_charts/radar_code_model2vec_jina_embeddings_v3.png +3 -0
  27. analysis_charts/radar_code_model2vec_nomic_embed_text_v2_moe.png +3 -0
  28. analysis_charts/radar_code_model2vec_paraphrase_MiniLM_L6_v2.png +3 -0
  29. distill.py +0 -116
  30. evaluate.py +0 -422
  31. pyproject.toml +37 -5
  32. src/distiller/distill.py +419 -159
  33. src/distiller/evaluate.py +371 -43
  34. train_code_classification.py +0 -365
.codemap.yml ADDED
@@ -0,0 +1,294 @@
1
+ # CodeMap Configuration File
2
+ # -------------------------
3
+ # This file configures CodeMap's behavior. Uncomment and modify settings as needed.
4
+
5
+ # LLM Configuration - Controls which model is used for AI operations
6
+ llm:
7
+ # Format: "provider:model-name", e.g., "openai:gpt-4o", "anthropic:claude-3-opus"
8
+ model: "google-gla:gemini-2.0-flash"
9
+ temperature: 0.5 # Lower for more deterministic outputs, higher for creativity
10
+ max_input_tokens: 1000000 # Maximum tokens in input
11
+ max_output_tokens: 10000 # Maximum tokens in responses
12
+ max_requests: 25 # Maximum number of requests
13
+
14
+ # Embedding Configuration - Controls vector embedding behavior
15
+ embedding:
16
+ # Recommended model: "minishlab/potion-base-8M". Only Model2Vec static models are supported.
17
+ model_name: "minishlab/potion-base-8M"
18
+ dimension: 256
19
+ # dimension_metric: "cosine" # Metric for dimension calculation (e.g., "cosine", "euclidean")
20
+ # max_retries: 3 # Maximum retries for embedding requests
21
+ # retry_delay: 5 # Delay in seconds between retries
22
+ # max_content_length: 5000 # Maximum characters per file chunk
23
+ # Qdrant (Vector DB) settings
24
+ # qdrant_batch_size: 100 # Batch size for Qdrant uploads
25
+ # url: "http://localhost:6333" # Qdrant server URL
26
+ # timeout: 30 # Qdrant client timeout in seconds
27
+ # prefer_grpc: true # Prefer gRPC for Qdrant communication
28
+
29
+ # Advanced chunking settings - controls how code is split
30
+ # chunking:
31
+ # max_hierarchy_depth: 2 # Maximum depth of code hierarchy to consider
32
+ # max_file_lines: 1000 # Maximum lines per file before splitting
33
+
34
+ # Clustering settings for embeddings
35
+ # clustering:
36
+ # method: "agglomerative" # Clustering method: "agglomerative", "dbscan"
37
+ # agglomerative: # Settings for Agglomerative Clustering
38
+ # metric: "precomputed" # Metric: "cosine", "euclidean", "manhattan", "l1", "l2", "precomputed"
39
+ # distance_threshold: 0.3 # Distance threshold for forming clusters
40
+ # linkage: "complete" # Linkage criterion: "ward", "complete", "average", "single"
41
+ # dbscan: # Settings for DBSCAN Clustering
42
+ # eps: 0.3 # The maximum distance between two samples for one to be considered as in the neighborhood of the other
43
+ # min_samples: 2 # The number of samples in a neighborhood for a point to be considered as a core point
44
+ # algorithm: "auto" # Algorithm to compute pointwise distances: "auto", "ball_tree", "kd_tree", "brute"
45
+ # metric: "precomputed" # Metric for distance computation: "cityblock", "cosine", "euclidean", "l1", "l2", "manhattan", "precomputed"
46
+
47
+ # RAG (Retrieval Augmented Generation) Configuration
48
+ rag:
49
+ max_context_length: 8000 # Maximum context length for the LLM
50
+ max_context_results: 100 # Maximum number of context results to return
51
+ similarity_threshold: 0.75 # Minimum similarity score (0-1) for relevance
52
+ # system_prompt: null # Optional system prompt to guide the RAG model (leave commented or set if needed)
53
+ include_file_content: true # Include file content in context
54
+ include_metadata: true # Include file metadata in context
55
+
56
+ # Sync Configuration - Controls which files are excluded from processing
57
+ sync:
58
+ exclude_patterns:
59
+ - "^node_modules/"
60
+ - "^\\.venv/"
61
+ - "^venv/"
62
+ - "^env/"
63
+ - "^__pycache__/"
64
+ - "^\\.mypy_cache/"
65
+ - "^\\.pytest_cache/"
66
+ - "^\\.ruff_cache/"
67
+ - "^dist/"
68
+ - "^build/"
69
+ - "^\\.git/"
70
+ - "^typings/"
71
+ - "^\\.pyc$"
72
+ - "^\\.pyo$"
73
+ - "^\\.so$"
74
+ - "^\\.dll$"
75
+ - "^\\.lib$"
76
+ - "^\\.a$"
77
+ - "^\\.o$"
78
+ - "^\\.class$"
79
+ - "^\\.jar$"
80
+
81
+ # Generation Configuration - Controls documentation generation
82
+ gen:
83
+ max_content_length: 5000 # Maximum content length per file for generation
84
+ use_gitignore: true # Use .gitignore patterns to exclude files
85
+ output_dir: "documentation" # Directory to store generated documentation
86
+ include_tree: true # Include directory tree in output
87
+ include_entity_graph: true # Include entity relationship graph
88
+ semantic_analysis: true # Enable semantic analysis
89
+ lod_level: "skeleton" # Level of detail: "signatures", "structure", "docs", "skeleton", "full"
90
+
91
+ # Mermaid diagram configuration for entity graphs
92
+ # mermaid_entities:
93
+ # - "module"
94
+ # - "class"
95
+ # - "function"
96
+ # - "method"
97
+ # - "constant"
98
+ # - "variable"
99
+ # - "import"
100
+ # mermaid_relationships:
101
+ # - "declares"
102
+ # - "imports"
103
+ # - "calls"
104
+ mermaid_show_legend: false
105
+ mermaid_remove_unconnected: true # Remove nodes with no connections from the diagram
106
+ mermaid_styled: false # Style the mermaid diagram
107
+
108
+ # Processor Configuration - Controls code processing behavior
109
+ processor:
110
+ enabled: true # Enable the processor
111
+ max_workers: 4 # Maximum number of parallel workers
112
+ ignored_patterns: # Patterns to ignore during processing
113
+ - "**/.git/**"
114
+ - "**/__pycache__/**"
115
+ - "**/.venv/**"
116
+ - "**/node_modules/**"
117
+ - "**/*.pyc"
118
+ - "**/dist/**"
119
+ - "**/build/**"
120
+ default_lod_level: "signatures" # Default level of detail: "signatures", "structure", "docs", "full"
121
+
122
+ # File watcher configuration
123
+ # watcher:
124
+ # enabled: true # Enable file watching
125
+ # debounce_delay: 1.0 # Delay in seconds before processing changes
126
+
127
+ # Commit Command Configuration
128
+ commit:
129
+ strategy: "semantic" # Strategy for splitting diffs: "file", "hunk", "semantic"
130
+ bypass_hooks: false # Whether to bypass git hooks
131
+ use_lod_context: true # Use level of detail context
132
+ is_non_interactive: false # Run in non-interactive mode
133
+
134
+ # Diff splitter configuration
135
+ # diff_splitter:
136
+ # similarity_threshold: 0.6 # Similarity threshold for grouping related changes
137
+ # directory_similarity_threshold: 0.3 # Threshold for considering directories similar (e.g., for renames)
138
+ # file_move_similarity_threshold: 0.85 # Threshold for detecting file moves/renames based on content
139
+ # min_chunks_for_consolidation: 2 # Minimum number of small chunks to consider for consolidation
140
+ # max_chunks_before_consolidation: 20 # Maximum number of chunks before forcing consolidation
141
+ # max_file_size_for_llm: 50000 # Maximum file size (bytes) for LLM processing of individual files
142
+ # max_log_diff_size: 1000 # Maximum size (lines) of diff log to pass to LLM for context
143
+ # default_code_extensions: # File extensions considered as code for semantic splitting
144
+ # - "js"
145
+ # - "jsx"
146
+ # - "ts"
147
+ # - "tsx"
148
+ # - "py"
149
+ # - "java"
150
+ # - "c"
151
+ # - "cpp"
152
+ # - "h"
153
+ # - "hpp"
154
+ # - "cc"
155
+ # - "cs"
156
+ # - "go"
157
+ # - "rb"
158
+ # - "php"
159
+ # - "rs"
160
+ # - "swift"
161
+ # - "scala"
162
+ # - "kt"
163
+ # - "sh"
164
+ # - "pl"
165
+ # - "pm"
166
+
167
+ # Commit convention configuration (Conventional Commits)
168
+ convention:
169
+ types: # Allowed commit types
170
+ - "feat"
171
+ - "fix"
172
+ - "docs"
173
+ - "style"
174
+ - "refactor"
175
+ - "perf"
176
+ - "test"
177
+ - "build"
178
+ - "ci"
179
+ - "chore"
180
+ scopes: [] # Add project-specific scopes here, e.g., ["api", "ui", "db"]
181
+ max_length: 72 # Maximum length of commit message header
182
+
183
+ # Commit linting configuration (based on conventional-changelog-lint rules)
184
+ # lint:
185
+ # # Rules are defined as: {level: "ERROR"|"WARNING"|"DISABLED", rule: "always"|"never", value: <specific_value_if_any>}
186
+ # header_max_length:
187
+ # level: "ERROR"
188
+ # rule: "always"
189
+ # value: 100
190
+ # header_case: # e.g., 'lower-case', 'upper-case', 'camel-case', etc.
191
+ # level: "DISABLED"
192
+ # rule: "always"
193
+ # value: "lower-case"
194
+ # header_full_stop:
195
+ # level: "ERROR"
196
+ # rule: "never"
197
+ # value: "."
198
+ # type_enum: # Types must be from the 'convention.types' list
199
+ # level: "ERROR"
200
+ # rule: "always"
201
+ # type_case:
202
+ # level: "ERROR"
203
+ # rule: "always"
204
+ # value: "lower-case"
205
+ # type_empty:
206
+ # level: "ERROR"
207
+ # rule: "never"
208
+ # scope_case:
209
+ # level: "ERROR"
210
+ # rule: "always"
211
+ # value: "lower-case"
212
+ # scope_empty: # Set to "ERROR" if scopes are mandatory
213
+ # level: "DISABLED"
214
+ # rule: "never"
215
+ # scope_enum: # Scopes must be from the 'convention.scopes' list if enabled
216
+ # level: "DISABLED"
217
+ # rule: "always"
218
+ # # value: [] # Add allowed scopes here if rule is "always" and level is not DISABLED
219
+ # subject_case: # Forbids specific cases in the subject
220
+ # level: "ERROR"
221
+ # rule: "never"
222
+ # value: ["sentence-case", "start-case", "pascal-case", "upper-case"]
223
+ # subject_empty:
224
+ # level: "ERROR"
225
+ # rule: "never"
226
+ # subject_full_stop:
227
+ # level: "ERROR"
228
+ # rule: "never"
229
+ # value: "."
230
+ # subject_exclamation_mark:
231
+ # level: "DISABLED"
232
+ # rule: "never"
233
+ # body_leading_blank: # Body must start with a blank line after subject
234
+ # level: "WARNING"
235
+ # rule: "always"
236
+ # body_empty:
237
+ # level: "DISABLED"
238
+ # rule: "never"
239
+ # body_max_line_length:
240
+ # level: "ERROR"
241
+ # rule: "always"
242
+ # value: 100
243
+ # footer_leading_blank: # Footer must start with a blank line after body
244
+ # level: "WARNING"
245
+ # rule: "always"
246
+ # footer_empty:
247
+ # level: "DISABLED"
248
+ # rule: "never"
249
+ # footer_max_line_length:
250
+ # level: "ERROR"
251
+ # rule: "always"
252
+ # value: 100
253
+
254
+ # Pull Request Configuration
255
+ pr:
256
+ defaults:
257
+ base_branch: null # Default base branch (null = auto-detect, e.g., main, master, develop)
258
+ feature_prefix: "feature/" # Default feature branch prefix
259
+
260
+ strategy: "github-flow" # Git workflow: "github-flow", "gitflow", "trunk-based"
261
+
262
+ # Branch mapping for different PR types (primarily used in gitflow strategy)
263
+ # branch_mapping:
264
+ # feature:
265
+ # base: "develop"
266
+ # prefix: "feature/"
267
+ # release:
268
+ # base: "main"
269
+ # prefix: "release/"
270
+ # hotfix:
271
+ # base: "main"
272
+ # prefix: "hotfix/"
273
+ # bugfix:
274
+ # base: "develop"
275
+ # prefix: "bugfix/"
276
+
277
+ # PR generation configuration
278
+ generate:
279
+ title_strategy: "llm" # Strategy for generating PR titles: "commits" (from commit messages), "llm" (AI generated)
280
+ description_strategy: "llm" # Strategy for descriptions: "commits", "llm"
281
+ # description_template: | # Template for PR description when using 'llm' strategy. Placeholders: {changes}, {testing_instructions}, {screenshots}
282
+ # ## Changes
283
+ # {changes}
284
+ #
285
+ # ## Testing
286
+ # {testing_instructions}
287
+ #
288
+ # ## Screenshots
289
+ # {screenshots}
290
+ use_workflow_templates: true # Use workflow-specific templates if available (e.g., for GitHub PR templates)
291
+
292
+ # Ask Command Configuration
293
+ ask:
294
+ interactive_chat: false # Enable interactive chat mode for the 'ask' command
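
The `embedding:` section above selects a Model2Vec static model. A minimal sketch of what that setting implies, assuming the `model2vec` package is available (the code chunk below is made up for illustration):

```python
# Minimal sketch (not part of this commit): load the static embedding model
# named in the `embedding:` section and embed one code chunk.
# Assumes the `model2vec` package is installed; the chunk text is illustrative.
from model2vec import StaticModel

model = StaticModel.from_pretrained("minishlab/potion-base-8M")

chunk = "def dfs_traversal(root):\n    ..."   # hypothetical code chunk
vectors = model.encode([chunk])               # numpy array, one row per input
print(vectors.shape)                          # e.g. (1, 256) for 256-dim embeddings
```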
MTEB_evaluate.py DELETED
@@ -1,343 +0,0 @@
1
- #!/usr/bin/env python
2
- """
3
- MTEB Evaluation Script with Subprocess Isolation (Code Information Retrieval Tasks).
4
-
5
- This script evaluates models using MTEB with subprocess isolation to prevent
6
- memory issues and process killing.
7
-
8
- Features:
9
- - Each task runs in a separate subprocess to isolate memory
10
- - 30-second timeout per task
11
- - No retries - if task fails or times out, move to next one
12
- - Memory monitoring and cleanup
13
-
14
- Note: Multi-threading is NOT used here because:
15
- 1. Memory is the main bottleneck, not CPU
16
- 2. Running multiple tasks simultaneously would increase memory pressure
17
- 3. Many tasks are being killed (return code -9) due to OOM conditions
18
- 4. Sequential processing with subprocess isolation is more stable
19
- """
20
-
21
- import contextlib
22
- import json
23
- import logging
24
- import subprocess
25
- import sys
26
- import tempfile
27
- import time
28
- from pathlib import Path
29
-
30
- import psutil
31
-
32
- # =============================================================================
33
- # CONFIGURATION
34
- # =============================================================================
35
-
36
- MODEL_PATH = "."
37
- MODEL_NAME = "gte-Qwen2-7B-instruct-M2V-Distilled"
38
- OUTPUT_DIR = "mteb_results"
39
- TASK_TIMEOUT = 30 # 30 seconds timeout per task
40
- MAX_RETRIES = 0 # No retries - move to next task if failed/timeout
41
-
42
- # Constants
43
- SIGKILL_RETURN_CODE = -9 # Process killed by SIGKILL (usually OOM)
44
-
45
- # Configure logging
46
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
47
- logger = logging.getLogger(__name__)
48
-
49
- # =============================================================================
50
- # SINGLE TASK RUNNER SCRIPT
51
- # =============================================================================
52
-
53
- TASK_RUNNER_SCRIPT = """
54
- import sys
55
- import os
56
- import json
57
- import tempfile
58
- import traceback
59
- from pathlib import Path
60
-
61
- # Add current directory to path
62
- sys.path.insert(0, ".")
63
-
64
- try:
65
- import mteb
66
- from model2vec import StaticModel
67
- from mteb import ModelMeta
68
- from evaluation import CustomMTEB
69
-
70
- def run_single_task():
71
- # Get arguments
72
- model_path = sys.argv[1]
73
- task_name = sys.argv[2]
74
- output_dir = sys.argv[3]
75
- model_name = sys.argv[4]
76
-
77
- # Load model
78
- model = StaticModel.from_pretrained(model_path)
79
- model.mteb_model_meta = ModelMeta(
80
- name=model_name, revision="distilled", release_date=None, languages=["eng"]
81
- )
82
-
83
- # Get and run task
84
- task = mteb.get_task(task_name, languages=["eng"])
85
- evaluation = CustomMTEB(tasks=[task])
86
-
87
- results = evaluation.run(
88
- model,
89
- eval_splits=["test"],
90
- output_folder=output_dir,
91
- verbosity=0
92
- )
93
-
94
- # Save results to temp file for parent process
95
- with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as f:
96
- json.dump({
97
- "success": True,
98
- "task_name": task_name,
99
- "results": results
100
- }, f)
101
- temp_file = f.name
102
-
103
- print(f"RESULT_FILE:{temp_file}")
104
- return 0
105
-
106
- if __name__ == "__main__":
107
- exit(run_single_task())
108
-
109
- except Exception as e:
110
- print(f"ERROR: {str(e)}")
111
- print(f"TRACEBACK: {traceback.format_exc()}")
112
- exit(1)
113
- """
114
-
115
-
116
- def get_available_tasks() -> list[str]:
117
- """Get list of available tasks."""
118
- try:
119
- import mteb
120
- import mteb.benchmarks
121
-
122
- # Use main MTEB benchmark for comprehensive evaluation
123
- benchmark = mteb.benchmarks.CoIR
124
- return [str(task) for task in benchmark.tasks] # All tasks
125
- except Exception:
126
- logger.exception("Failed to get tasks")
127
- return []
128
-
129
-
130
- def check_existing_results(output_path: Path, task_names: list[str]) -> list[str]:
131
- """Check for existing results and return remaining tasks."""
132
- remaining_tasks = []
133
-
134
- for task_name in task_names:
135
- result_file = output_path / MODEL_NAME / "distilled" / f"{task_name}.json"
136
- if result_file.exists():
137
- logger.info(f"Skipping {task_name} - results already exist")
138
- else:
139
- remaining_tasks.append(task_name)
140
-
141
- return remaining_tasks
142
-
143
-
144
- def run_task_subprocess(task_name: str, output_dir: str) -> tuple[bool, str, float]:
145
- """Run a single task in a subprocess with memory and time limits."""
146
- # Create temporary script file
147
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
148
- f.write(TASK_RUNNER_SCRIPT)
149
- script_path = f.name
150
-
151
- try:
152
- logger.info(f"Running task: {task_name}")
153
- start_time = time.time()
154
-
155
- # Run subprocess with timeout
156
- # subprocess security: We control all inputs (script path and known arguments)
157
- cmd = [sys.executable, script_path, MODEL_PATH, task_name, output_dir, MODEL_NAME]
158
-
159
- process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) # noqa: S603
160
-
161
- try:
162
- stdout, stderr = process.communicate(timeout=TASK_TIMEOUT)
163
- duration = time.time() - start_time
164
-
165
- if process.returncode == 0:
166
- # Check for result file
167
- for line in stdout.split("\n"):
168
- if line.startswith("RESULT_FILE:"):
169
- result_file = line.split(":", 1)[1]
170
- try:
171
- with Path(result_file).open() as f:
172
- json.load(f)
173
- Path(result_file).unlink() # Clean up temp file
174
- logger.info(f"✓ Completed {task_name} in {duration:.2f}s")
175
- return True, task_name, duration
176
- except (json.JSONDecodeError, OSError):
177
- logger.exception("Failed to read result file")
178
-
179
- logger.info(f"✓ Completed {task_name} in {duration:.2f}s")
180
- return True, task_name, duration
181
- if process.returncode == SIGKILL_RETURN_CODE:
182
- logger.error(f"✗ Task {task_name} killed (OOM) - return code {process.returncode}")
183
- else:
184
- logger.error(f"✗ Task {task_name} failed with return code {process.returncode}")
185
- if stderr:
186
- logger.error(f"Error output: {stderr}")
187
- return False, task_name, duration
188
-
189
- except subprocess.TimeoutExpired:
190
- logger.warning(f"⏱ Task {task_name} timed out after {TASK_TIMEOUT}s")
191
- process.kill()
192
- process.wait()
193
- return False, task_name, TASK_TIMEOUT
194
-
195
- except Exception:
196
- logger.exception(f"✗ Failed to run task {task_name}")
197
- return False, task_name, 0.0
198
-
199
- finally:
200
- # Clean up script file
201
- with contextlib.suppress(Exception):
202
- Path(script_path).unlink()
203
-
204
-
205
- def collect_results(output_path: Path) -> dict:
206
- """Collect all results from completed tasks."""
207
- results_dir = output_path / MODEL_NAME / "distilled"
208
- if not results_dir.exists():
209
- return {}
210
-
211
- task_results = {}
212
- for result_file in results_dir.glob("*.json"):
213
- if result_file.name == "model_meta.json":
214
- continue
215
-
216
- try:
217
- with result_file.open() as f:
218
- data = json.load(f)
219
- task_name = result_file.stem
220
- task_results[task_name] = data
221
- except (json.JSONDecodeError, OSError) as e:
222
- logger.warning(f"Could not load {result_file}: {e}")
223
-
224
- return task_results
225
-
226
-
227
- def save_summary(output_path: Path, results: dict, stats: dict) -> None:
228
- """Save evaluation summary."""
229
- summary = {
230
- "model_name": MODEL_NAME,
231
- "timestamp": time.time(),
232
- "task_timeout": TASK_TIMEOUT,
233
- "stats": stats,
234
- "task_results": results,
235
- }
236
-
237
- summary_file = output_path / "mteb_summary.json"
238
- with summary_file.open("w") as f:
239
- json.dump(summary, f, indent=2, default=str)
240
-
241
- logger.info(f"Summary saved to {summary_file}")
242
-
243
-
244
- def main() -> None:
245
- """Main evaluation function."""
246
- logger.info(f"Starting MTEB evaluation for {MODEL_NAME}")
247
- logger.info(f"Task timeout: {TASK_TIMEOUT}s (no retries)")
248
- logger.info("Memory isolation: Each task runs in separate subprocess")
249
-
250
- # Log system info
251
- memory_info = psutil.virtual_memory()
252
- logger.info(f"System memory: {memory_info.total / (1024**3):.1f} GB total")
253
-
254
- output_path = Path(OUTPUT_DIR)
255
- output_path.mkdir(parents=True, exist_ok=True)
256
-
257
- # Get tasks
258
- all_tasks = get_available_tasks()
259
- if not all_tasks:
260
- logger.error("No tasks found!")
261
- return
262
-
263
- logger.info(f"Found {len(all_tasks)} tasks")
264
-
265
- # Check existing results
266
- remaining_tasks = check_existing_results(output_path, all_tasks)
267
- logger.info(f"Will evaluate {len(remaining_tasks)} remaining tasks")
268
-
269
- if not remaining_tasks:
270
- logger.info("All tasks already completed!")
271
- return
272
-
273
- # Process tasks sequentially (no retries)
274
- start_time = time.time()
275
- successful_tasks = []
276
- failed_tasks = []
277
- timed_out_tasks = []
278
-
279
- for i, task_name in enumerate(remaining_tasks):
280
- logger.info(f"[{i + 1}/{len(remaining_tasks)}] Processing: {task_name}")
281
-
282
- # Run task once (no retries)
283
- success, name, duration = run_task_subprocess(task_name, str(output_path))
284
-
285
- if success:
286
- successful_tasks.append((name, duration))
287
- elif duration == TASK_TIMEOUT:
288
- timed_out_tasks.append(name)
289
- else:
290
- failed_tasks.append(name)
291
- # Check if it was OOM killed (this is logged in run_task_subprocess)
292
-
293
- # Progress update
294
- progress = ((i + 1) / len(remaining_tasks)) * 100
295
- logger.info(f"Progress: {i + 1}/{len(remaining_tasks)} ({progress:.1f}%)")
296
-
297
- # Brief pause between tasks
298
- time.sleep(1)
299
-
300
- total_time = time.time() - start_time
301
-
302
- # Log final summary
303
- logger.info("=" * 80)
304
- logger.info("EVALUATION SUMMARY")
305
- logger.info("=" * 80)
306
- logger.info(f"Total tasks: {len(remaining_tasks)}")
307
- logger.info(f"Successful: {len(successful_tasks)}")
308
- logger.info(f"Failed: {len(failed_tasks)}")
309
- logger.info(f"Timed out: {len(timed_out_tasks)}")
310
- logger.info(f"Total time: {total_time:.2f}s")
311
-
312
- if successful_tasks:
313
- avg_time = sum(duration for _, duration in successful_tasks) / len(successful_tasks)
314
- logger.info(f"Average successful task time: {avg_time:.2f}s")
315
-
316
- if failed_tasks:
317
- logger.warning(f"Failed tasks: {failed_tasks}")
318
-
319
- if timed_out_tasks:
320
- logger.warning(f"Timed out tasks: {timed_out_tasks}")
321
-
322
- logger.info("=" * 80)
323
-
324
- # Collect and save results
325
- all_results = collect_results(output_path)
326
- stats = {
327
- "total_tasks": len(remaining_tasks),
328
- "successful": len(successful_tasks),
329
- "failed": len(failed_tasks),
330
- "timed_out": len(timed_out_tasks),
331
- "total_time": total_time,
332
- "avg_time": avg_time if successful_tasks else 0,
333
- "successful_task_details": successful_tasks,
334
- "failed_tasks": failed_tasks,
335
- "timed_out_tasks": timed_out_tasks,
336
- }
337
-
338
- save_summary(output_path, all_results, stats)
339
- logger.info("Evaluation completed!")
340
-
341
-
342
- if __name__ == "__main__":
343
- main()
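
The deleted script's core pattern is to run each MTEB task in its own subprocess with a hard timeout, so an out-of-memory kill (return code -9) only costs that one task. A condensed sketch of that pattern, with a placeholder worker script name and no MTEB dependency:

```python
# Condensed sketch of the subprocess-isolation pattern used by the deleted
# script; "worker.py" and its argument are placeholders, not real files.
import subprocess
import sys

TASK_TIMEOUT = 30      # seconds per task, mirroring the deleted script
SIGKILL = -9           # return code when the OS kills the child (usually OOM)

def run_isolated(task_name: str) -> bool:
    cmd = [sys.executable, "worker.py", task_name]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=TASK_TIMEOUT)
    except subprocess.TimeoutExpired:
        return False                      # timed out: skip, no retry
    if proc.returncode == SIGKILL:
        return False                      # killed by the kernel, likely OOM
    return proc.returncode == 0
```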
REPORT.md ADDED
@@ -0,0 +1,299 @@
1
+ # Code-Specialized Model2Vec Distillation Analysis
2
+
3
+ ## 🎯 Executive Summary
4
+
5
+ This report presents a comprehensive analysis of Model2Vec distillation experiments using different teacher models for code-specialized embedding generation.
6
+
7
+ ### Evaluated Models Overview
8
+
9
+ **Simplified Distillation Models:** 13
10
+ **Peer Comparison Models:** 19
11
+ **Total Models Analyzed:** 32
12
+
13
+ ### Best Performing Simplified Model: code_model2vec_all_mpnet_base_v2
14
+
15
+ **Overall CodeSearchNet Performance:**
16
+ - **NDCG@10**: 0.7387
17
+ - **Mean Reciprocal Rank (MRR)**: 0.7010
18
+ - **Recall@5**: 0.8017
19
+ - **Mean Rank**: 6.4
20
+
21
+ ## 📊 Comprehensive Model Comparison
22
+
23
+ ### All Simplified Distillation Models Performance
24
+
25
+ | Model | Teacher | NDCG@10 | MRR | Recall@5 | Status |
26
+ |-------|---------|---------|-----|----------|--------|
27
+ | code_model2vec_all_mpnet_base_v2 | [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) | 0.7387 | 0.7010 | 0.8017 | 🥇 Best |
28
+ | code_model2vec_all_MiniLM_L6_v2 | [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) | 0.7385 | 0.7049 | 0.7910 | 🥈 2nd |
29
+ | code_model2vec_jina_embeddings_v2_base_code | [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code) | 0.7381 | 0.6996 | 0.8130 | 🥉 3rd |
30
+ | code_model2vec_paraphrase_MiniLM_L6_v2 | [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) | 0.7013 | 0.6638 | 0.7665 | #4 |
31
+ | code_model2vec_Reason_ModernColBERT | [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT) | 0.6598 | 0.6228 | 0.7260 | #5 |
32
+ | code_model2vec_bge_m3 | [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | 0.4863 | 0.4439 | 0.5514 | #6 |
33
+ | code_model2vec_jina_embeddings_v3 | [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3) | 0.4755 | 0.4416 | 0.5456 | #7 |
34
+ | code_model2vec_nomic_embed_text_v2_moe | [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe) | 0.4532 | 0.4275 | 0.5094 | #8 |
35
+ | code_model2vec_gte_Qwen2_1.5B_instruct | [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct) | 0.4238 | 0.3879 | 0.4719 | #9 |
36
+ | code_model2vec_Qodo_Embed_1_1.5B | [Qodo/Qodo-Embed-1-1.5B](https://huggingface.co/Qodo/Qodo-Embed-1-1.5B) | 0.4101 | 0.3810 | 0.4532 | #10 |
37
+ | code_model2vec_graphcodebert_base | [microsoft/graphcodebert-base](https://huggingface.co/microsoft/graphcodebert-base) | 0.3420 | 0.3140 | 0.3704 | #11 |
38
+ | code_model2vec_Linq_Embed_Mistral | [Linq-AI-Research/Linq-Embed-Mistral](https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral) | 0.2868 | 0.2581 | 0.3412 | #12 |
39
+ | code_model2vec_codebert_base | [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base) | 0.2779 | 0.2534 | 0.3136 | #13 |
40
+
41
+
42
+ ### 📊 Model Specifications Analysis
43
+
44
+ The distilled models share a common 256-dimensional embedding space, while vocabulary size, parameter count, and disk footprint vary with the teacher model's tokenizer:
45
+
46
+ | Model | Vocabulary Size | Parameters | Embedding Dim | Disk Size |
47
+ |-------|----------------|------------|---------------|-----------|
48
+ | all_mpnet_base_v2 | 29,528 | 7.6M | 256 | 14.4MB |
49
+ | all_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
50
+ | jina_embeddings_v2_base_code | 61,053 | 15.6M | 256 | 29.8MB |
51
+ | paraphrase_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
52
+ | Reason_ModernColBERT | 50,254 | 12.9M | 256 | 24.5MB |
53
+ | bge_m3 | 249,999 | 64.0M | 256 | 122.1MB |
54
+ | jina_embeddings_v3 | 249,999 | 64.0M | 256 | 122.1MB |
55
+ | nomic_embed_text_v2_moe | 249,999 | 64.0M | 256 | 122.1MB |
56
+ | gte_Qwen2_1.5B_instruct | 151,644 | 38.8M | 256 | 74.0MB |
57
+ | Qodo_Embed_1_1.5B | 151,644 | 38.8M | 256 | 74.0MB |
58
+ | graphcodebert_base | 50,262 | 12.9M | 256 | 24.5MB |
59
+ | Linq_Embed_Mistral | 31,999 | 8.2M | 256 | 15.6MB |
60
+ | codebert_base | 50,262 | 12.9M | 256 | 24.5MB |
61
+
62
+
63
+ ![Model Specifications](analysis_charts/model_specifications.png)
64
+
65
+ *Comprehensive analysis of our distilled models showing vocabulary size, parameter count, embedding dimensions, and storage requirements.*
66
+
67
+ #### Key Insights from Model Specifications:
68
+
69
+
70
+ - **Vocabulary Size**: Vocabulary sizes range from 29,525 to 249,999 tokens (avg: 106,592), following each teacher's tokenizer
71
+ - **Parameter Efficiency**: Models range from 7.6M to 64.0M parameters (avg: 27.3M)
72
+ - **Storage Efficiency**: Disk usage ranges from 14.4MB to 122.1MB (avg: 52.0MB)
73
+ - **Embedding Dimensions**: Consistent 256 dimensions across all models (optimized for efficiency)
74
+
75
+
76
+ ### Key Findings
77
+
78
+
79
+ - **Best Teacher Model**: code_model2vec_all_mpnet_base_v2 (NDCG@10: 0.7387)
80
+ - **Least Effective Teacher**: code_model2vec_codebert_base (NDCG@10: 0.2779)
81
+ - **Performance Range**: The weakest model scores 62.4% below the best (0.2779 vs. 0.7387 NDCG@10)
82
+ - **Average Performance**: 0.5178 NDCG@10
83
+
84
+
85
+ ## 🎯 Language Performance Radar Charts
86
+
87
+ ### Best Model vs Peer Models Comparison
88
+
89
+ ![Comparative Radar Chart](analysis_charts/comparative_radar.png)
90
+
91
+ *Comparative view showing how the best simplified distillation model performs against top peer models across programming languages.*
92
+
93
+ ### Individual Model Performance by Language
94
+
95
+ #### code_model2vec_all_mpnet_base_v2 (Teacher: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) - NDCG@10: 0.7387
96
+
97
+ ![code_model2vec_all_mpnet_base_v2 Radar Chart](analysis_charts/radar_code_model2vec_all_mpnet_base_v2.png)
98
+
99
+ #### code_model2vec_all_MiniLM_L6_v2 (Teacher: [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)) - NDCG@10: 0.7385
100
+
101
+ ![code_model2vec_all_MiniLM_L6_v2 Radar Chart](analysis_charts/radar_code_model2vec_all_MiniLM_L6_v2.png)
102
+
103
+ #### code_model2vec_jina_embeddings_v2_base_code (Teacher: [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code)) - NDCG@10: 0.7381
104
+
105
+ ![code_model2vec_jina_embeddings_v2_base_code Radar Chart](analysis_charts/radar_code_model2vec_jina_embeddings_v2_base_code.png)
106
+
107
+ #### code_model2vec_paraphrase_MiniLM_L6_v2 (Teacher: [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2)) - NDCG@10: 0.7013
108
+
109
+ ![code_model2vec_paraphrase_MiniLM_L6_v2 Radar Chart](analysis_charts/radar_code_model2vec_paraphrase_MiniLM_L6_v2.png)
110
+
111
+ #### code_model2vec_Reason_ModernColBERT (Teacher: [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT)) - NDCG@10: 0.6598
112
+
113
+ ![code_model2vec_Reason_ModernColBERT Radar Chart](analysis_charts/radar_code_model2vec_Reason_ModernColBERT.png)
114
+
115
+ #### code_model2vec_bge_m3 (Teacher: [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3)) - NDCG@10: 0.4863
116
+
117
+ ![code_model2vec_bge_m3 Radar Chart](analysis_charts/radar_code_model2vec_bge_m3.png)
118
+
119
+ #### code_model2vec_jina_embeddings_v3 (Teacher: [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3)) - NDCG@10: 0.4755
120
+
121
+ ![code_model2vec_jina_embeddings_v3 Radar Chart](analysis_charts/radar_code_model2vec_jina_embeddings_v3.png)
122
+
123
+ #### code_model2vec_nomic_embed_text_v2_moe (Teacher: [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe)) - NDCG@10: 0.4532
124
+
125
+ ![code_model2vec_nomic_embed_text_v2_moe Radar Chart](analysis_charts/radar_code_model2vec_nomic_embed_text_v2_moe.png)
126
+
127
+ #### code_model2vec_gte_Qwen2_1.5B_instruct (Teacher: [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)) - NDCG@10: 0.4238
128
+
129
+ ![code_model2vec_gte_Qwen2_1.5B_instruct Radar Chart](analysis_charts/radar_code_model2vec_gte_Qwen2_15B_instruct.png)
130
+
131
+ #### code_model2vec_Qodo_Embed_1_1.5B (Teacher: [Qodo/Qodo-Embed-1-1.5B](https://huggingface.co/Qodo/Qodo-Embed-1-1.5B)) - NDCG@10: 0.4101
132
+
133
+ ![code_model2vec_Qodo_Embed_1_1.5B Radar Chart](analysis_charts/radar_code_model2vec_Qodo_Embed_1_15B.png)
134
+
135
+ #### code_model2vec_graphcodebert_base (Teacher: [microsoft/graphcodebert-base](https://huggingface.co/microsoft/graphcodebert-base)) - NDCG@10: 0.3420
136
+
137
+ ![code_model2vec_graphcodebert_base Radar Chart](analysis_charts/radar_code_model2vec_graphcodebert_base.png)
138
+
139
+ #### code_model2vec_Linq_Embed_Mistral (Teacher: [Linq-AI-Research/Linq-Embed-Mistral](https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral)) - NDCG@10: 0.2868
140
+
141
+ ![code_model2vec_Linq_Embed_Mistral Radar Chart](analysis_charts/radar_code_model2vec_Linq_Embed_Mistral.png)
142
+
143
+ #### code_model2vec_codebert_base (Teacher: [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base)) - NDCG@10: 0.2779
144
+
145
+ ![code_model2vec_codebert_base Radar Chart](analysis_charts/radar_code_model2vec_codebert_base.png)
146
+
147
+
148
+
149
+ ## 🏆 Peer Model Comparison
150
+
151
+ ![Peer Comparison](analysis_charts/peer_comparison.png)
152
+
153
+ *Comparison with established code-specialized embedding models using actual evaluation results.*
154
+
155
+ ### Complete Model Ranking
156
+
157
+ | Rank | Model | Type | NDCG@10 | MRR | Recall@5 |
158
+ |------|-------|------|---------|-----|----------|
159
+ | 1 | Alibaba-NLP/gte-Qwen2-1.5B-instruct | General | 0.9729 | 0.9676 | 0.9825 |
160
+ | 2 | Qodo/Qodo-Embed-1-1.5B | General | 0.9715 | 0.9659 | 0.9875 |
161
+ | 3 | jina-embeddings-v2-base-code | General | 0.9677 | 0.9618 | 0.9849 |
162
+ | 4 | jinaai/jina-embeddings-v3 | General | 0.9640 | 0.9573 | 0.9839 |
163
+ | 5 | sentence-transformers/all-mpnet-base-v2 | General | 0.9477 | 0.9358 | 0.9732 |
164
+ | 6 | nomic-ai/nomic-embed-text-v2-moe | General | 0.9448 | 0.9357 | 0.9659 |
165
+ | 7 | sentence-transformers/all-MiniLM-L12-v2 | General | 0.9398 | 0.9265 | 0.9732 |
166
+ | 8 | BAAI/bge-m3 | General | 0.9383 | 0.9295 | 0.9643 |
167
+ | 9 | sentence-transformers/all-MiniLM-L6-v2 | General | 0.9255 | 0.9099 | 0.9642 |
168
+ | 10 | lightonai/Reason-ModernColBERT | General | 0.9188 | 0.9036 | 0.9486 |
169
+ | 11 | Linq-AI-Research/Linq-Embed-Mistral | General | 0.9080 | 0.8845 | 0.9650 |
170
+ | 12 | sentence-transformers/paraphrase-MiniLM-L6-v2 | General | 0.8297 | 0.8016 | 0.8828 |
171
+ | 13 | minishlab/potion-base-8M | Model2Vec | 0.8162 | 0.7817 | 0.8931 |
172
+ | 14 | minishlab/potion-retrieval-32M | Model2Vec | 0.8137 | 0.7810 | 0.8792 |
173
+ | 15 | code_model2vec_all_mpnet_base_v2 | **🔥 Simplified Distillation** | 0.7387 | 0.7010 | 0.8017 |
174
+ | 16 | code_model2vec_all_MiniLM_L6_v2 | **🔥 Simplified Distillation** | 0.7385 | 0.7049 | 0.7910 |
175
+ | 17 | code_model2vec_jina_embeddings_v2_base_code | **🔥 Simplified Distillation** | 0.7381 | 0.6996 | 0.8130 |
176
+ | 18 | code_model2vec_paraphrase_MiniLM_L6_v2 | **🔥 Simplified Distillation** | 0.7013 | 0.6638 | 0.7665 |
177
+ | 19 | code_model2vec_Reason_ModernColBERT | **🔥 Simplified Distillation** | 0.6598 | 0.6228 | 0.7260 |
178
+ | 20 | potion-multilingual-128M | Model2Vec | 0.6124 | 0.5683 | 0.7017 |
179
+ | 21 | huggingface/CodeBERTa-small-v1 | Code-Specific | 0.5903 | 0.5350 | 0.6779 |
180
+ | 22 | Salesforce/codet5-base | Code-Specific | 0.4872 | 0.4500 | 0.5742 |
181
+ | 23 | code_model2vec_bge_m3 | **🔥 Simplified Distillation** | 0.4863 | 0.4439 | 0.5514 |
182
+ | 24 | code_model2vec_jina_embeddings_v3 | **🔥 Simplified Distillation** | 0.4755 | 0.4416 | 0.5456 |
183
+ | 25 | code_model2vec_nomic_embed_text_v2_moe | **🔥 Simplified Distillation** | 0.4532 | 0.4275 | 0.5094 |
184
+ | 26 | code_model2vec_gte_Qwen2_1.5B_instruct | **🔥 Simplified Distillation** | 0.4238 | 0.3879 | 0.4719 |
185
+ | 27 | code_model2vec_Qodo_Embed_1_1.5B | **🔥 Simplified Distillation** | 0.4101 | 0.3810 | 0.4532 |
186
+ | 28 | microsoft/graphcodebert-base | Code-Specific | 0.4039 | 0.3677 | 0.4650 |
187
+ | 29 | code_model2vec_graphcodebert_base | **🔥 Simplified Distillation** | 0.3420 | 0.3140 | 0.3704 |
188
+ | 30 | code_model2vec_Linq_Embed_Mistral | **🔥 Simplified Distillation** | 0.2868 | 0.2581 | 0.3412 |
189
+ | 31 | code_model2vec_codebert_base | **🔥 Simplified Distillation** | 0.2779 | 0.2534 | 0.3136 |
190
+ | 32 | microsoft/codebert-base | Code-Specific | 0.1051 | 0.1058 | 0.1105 |
191
+
192
+
193
+ ## 📈 Performance Analysis
194
+
195
+ ### Multi-Model Comparison Charts
196
+
197
+ ![Model Comparison](analysis_charts/model_comparison.png)
198
+
199
+ *Comprehensive comparison across all evaluation metrics.*
200
+
201
+ ### Language Performance Analysis
202
+
203
+ ![Language Heatmap](analysis_charts/language_heatmap.png)
204
+
205
+ *Performance heatmap showing how different models perform across programming languages.*
206
+
207
+ ### Efficiency Analysis
208
+
209
+ ![Efficiency Analysis](analysis_charts/efficiency_analysis.png)
210
+
211
+ *Performance vs model size analysis showing the efficiency benefits of distillation.*
212
+
213
+
214
+
215
+ ## ⚡ Operational Performance Analysis
216
+
217
+ ![Benchmark Performance](analysis_charts/benchmark_performance.png)
218
+
219
+ *Comprehensive performance benchmarking across multiple operational metrics.*
220
+
221
+ ### Performance Scaling Analysis
222
+
223
+ ![Batch Size Scaling](analysis_charts/batch_size_scaling.png)
224
+
225
+ *How performance scales with different batch sizes for optimal throughput.*
226
+
227
+ ![Memory Scaling](analysis_charts/memory_scaling.png)
228
+
229
+ *Memory usage patterns across different batch sizes.*
230
+
231
+
232
+
233
+ ## 🔍 Language-Specific Analysis
234
+
235
+ ### Performance by Programming Language
236
+
237
+ | Language | Best Model Performance | Average Performance | Language Difficulty |
238
+ |----------|------------------------|--------------------|--------------------|
239
+ | Go | 0.9780 | 0.6950 | Easy |
240
+ | Java | 0.9921 | 0.6670 | Easy |
241
+ | JavaScript | 0.9550 | 0.5847 | Easy |
242
+ | PHP | 1.0000 | 0.6379 | Easy |
243
+ | Python | 1.0000 | 0.8604 | Easy |
244
+ | Ruby | 0.9493 | 0.6372 | Easy |
245
+
246
+
247
+ ## 🎯 Conclusions and Recommendations
248
+
249
+ ### Teacher Model Analysis
250
+
251
+ Based on the evaluation results across all simplified distillation models:
252
+
253
+
254
+ 1. **Best Teacher Model**: sentence-transformers/all-mpnet-base-v2 (NDCG@10: 0.7387)
255
+ 2. **Least Effective Teacher**: microsoft/codebert-base (NDCG@10: 0.2779)
256
+ 3. **Teacher Model Impact**: The choice of teacher changes NDCG@10 by up to 62.4% relative to the best result
257
+
258
+ ### Recommendations
259
+
260
+ - **For Production**: Use sentence-transformers/all-mpnet-base-v2 as teacher model for best performance
261
+ - **For Efficiency**: Model2Vec distillation provides significant size reduction with competitive performance
262
+ - **For Code Tasks**: In these results, strong general-purpose teachers outperform the older code-specific encoders (CodeBERT, GraphCodeBERT) on CodeSearchNet retrieval
263
+
264
+
265
+ ## 📄 Methodology
266
+
267
+ ### Evaluation Protocol
268
+ - **Dataset**: CodeSearchNet test sets for 6 programming languages
269
+ - **Metrics**: NDCG@k, MRR, Recall@k following CodeSearchNet methodology
270
+ - **Query Format**: Natural language documentation strings
271
+ - **Corpus Format**: Function code strings
272
+ - **Evaluation**: Retrieval of correct code for each documentation query
273
+
274
+ ### Teacher Models Tested
275
+ - [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) (proven baseline)
276
+ - [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) (general purpose)
277
+ - [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) (paraphrase model)
278
+ - [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base) (code-specialized)
279
+ - [microsoft/graphcodebert-base](https://huggingface.co/microsoft/graphcodebert-base) (graph-aware code model)
280
+ - [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct) (instruction model)
281
+ - [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) (multilingual model)
282
+ - [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3) (modern embedding model)
283
+ - [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe) (mixture of experts)
284
+ - [Qodo/Qodo-Embed-1-1.5B](https://huggingface.co/Qodo/Qodo-Embed-1-1.5B) (code-specialized)
285
+ - [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT) (ColBERT architecture)
286
+ - [Linq-AI-Research/Linq-Embed-Mistral](https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral) (Mistral-based)
287
+ - [BAAI/bge-code-v1](https://huggingface.co/BAAI/bge-code-v1) (code-specialized BGE)
288
+ - [Salesforce/SFR-Embedding-Code-2B_R](https://huggingface.co/Salesforce/SFR-Embedding-Code-2B_R) (large code model)
289
+
290
+ ### Distillation Method
291
+ - **Technique**: Model2Vec static embedding generation
292
+ - **Parameters**: PCA dims=256, SIF coefficient=1e-3, Zipf weighting=True
293
+ - **Training Data**: CodeSearchNet comment-code pairs
294
+ - **Languages**: Python, JavaScript, Java, PHP, Ruby, Go
295
+
296
+ ---
297
+
298
+ *Report generated on 2025-05-31 11:39:39 using automated analysis pipeline.*
299
+ *For questions about methodology or results, please refer to the CodeSearchNet documentation.*
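
The NDCG@10, MRR, and Recall@5 figures above are standard retrieval metrics computed from the rank of the correct code snippet for each documentation query. A small self-contained sketch of those formulas, with made-up example ranks:

```python
# Sketch of the retrieval metrics reported in REPORT.md, computed from the
# 1-based rank of the correct code snippet for each query. Example ranks
# below are illustrative only.
import math

def mrr(ranks):                 # Mean Reciprocal Rank
    return sum(1.0 / r for r in ranks) / len(ranks)

def recall_at_k(ranks, k=5):    # fraction of queries whose answer is in the top k
    return sum(r <= k for r in ranks) / len(ranks)

def ndcg_at_k(ranks, k=10):     # binary-relevance NDCG; ideal DCG is 1.0
    return sum((1.0 / math.log2(r + 1)) if r <= k else 0.0 for r in ranks) / len(ranks)

ranks = [1, 3, 2, 12, 1]        # hypothetical ranks of the correct snippet
print(mrr(ranks), recall_at_k(ranks), ndcg_at_k(ranks))
```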
Taskfile.yml ADDED
@@ -0,0 +1,23 @@
1
+ version: "3"
2
+
3
+ tasks:
4
+ default:
5
+ desc: List all available tasks
6
+ cmds:
7
+ - task -l
8
+
9
+ lint:
10
+ desc: Run all linting checks
11
+ cmds:
12
+ - uv run ruff check src --fix --unsafe-fixes
13
+
14
+ type:
15
+ desc: Run type checker
16
+ cmds:
17
+ - find src/distiller -name "*.py" | xargs uv run mypy
18
+
19
+ format:
20
+ desc: Run all formatters
21
+ cmds:
22
+ - uv run ruff format src
23
+
analysis_charts/batch_size_scaling.png ADDED

Git LFS Details

  • SHA256: ef5210a43a6c09db766650ba45a577916557f05384d127541eceebe1a8f42615
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
analysis_charts/benchmark_performance.png ADDED

Git LFS Details

  • SHA256: 04f5572ddb072b8085e3d22d621c0497de35623703eb5e3698a0bd0b32e2ed9a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.99 MB
analysis_charts/code_performance_radar.png ADDED

Git LFS Details

  • SHA256: 3acee93605e989a6cd8fa31cf3f4e3708df977fcf56ce08da40dd687096d9101
  • Pointer size: 131 Bytes
  • Size of remote file: 190 kB
analysis_charts/comparative_radar.png ADDED

Git LFS Details

  • SHA256: 902f53602bc6780da9a136b38fc8ea6b90d13c0fd2005e02cbd6f58e6f8f9b05
  • Pointer size: 131 Bytes
  • Size of remote file: 285 kB
analysis_charts/efficiency_analysis.png ADDED

Git LFS Details

  • SHA256: a8c5628d39cdd19ee9e77b38545647b80e1092d8cad9a0c7da1cfdb506651e35
  • Pointer size: 131 Bytes
  • Size of remote file: 221 kB
analysis_charts/language_heatmap.png ADDED

Git LFS Details

  • SHA256: a408507a764878d0959da3ce4a972380cdcdba952515aef46292addb4052e182
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
analysis_charts/memory_scaling.png ADDED

Git LFS Details

  • SHA256: d4506ecf38a454ffd5cac28b96a1fb1b287370bb576a74cd34aea8e8eb0b36e5
  • Pointer size: 131 Bytes
  • Size of remote file: 620 kB
analysis_charts/model_comparison.png ADDED

Git LFS Details

  • SHA256: 2e5d458ff85a706db56ff7774adbc26fa6645e94e9154352e3b45e0db890051c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
analysis_charts/model_specifications.png ADDED

Git LFS Details

  • SHA256: 685fe86b51d7175d783948ae1928a03959e9ed67073da720cd4cc881869140d7
  • Pointer size: 131 Bytes
  • Size of remote file: 617 kB
analysis_charts/peer_comparison.png ADDED

Git LFS Details

  • SHA256: a95bc3dc495b1e6ef637b5719611581ca9890b689d72bdf3399a29fca23619ce
  • Pointer size: 131 Bytes
  • Size of remote file: 679 kB
analysis_charts/radar_code_model2vec_Linq_Embed_Mistral.png ADDED

Git LFS Details

  • SHA256: 2b229a777e69b37402b2a0ac77b878436d3a6321912767c20da84bfb6ab893f1
  • Pointer size: 131 Bytes
  • Size of remote file: 190 kB
analysis_charts/radar_code_model2vec_Qodo_Embed_1_15B.png ADDED

Git LFS Details

  • SHA256: 5d124f47dc851222959b8160313eeec01d0fb9071cb5c8b987536e4478d190e7
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB
analysis_charts/radar_code_model2vec_Reason_ModernColBERT.png ADDED

Git LFS Details

  • SHA256: ae23be629335dd36c443cd60ef96ced638c411b6c348433a605d16e3a55212ec
  • Pointer size: 131 Bytes
  • Size of remote file: 213 kB
analysis_charts/radar_code_model2vec_all_MiniLM_L6_v2.png ADDED

Git LFS Details

  • SHA256: 3a63eefc044329f66b19e3ec6f289d41863949c769642f195ec340456ad0ced2
  • Pointer size: 131 Bytes
  • Size of remote file: 193 kB
analysis_charts/radar_code_model2vec_all_mpnet_base_v2.png ADDED

Git LFS Details

  • SHA256: 3acee93605e989a6cd8fa31cf3f4e3708df977fcf56ce08da40dd687096d9101
  • Pointer size: 131 Bytes
  • Size of remote file: 190 kB
analysis_charts/radar_code_model2vec_bge_m3.png ADDED

Git LFS Details

  • SHA256: 530395605197f1c3e69454eafd61a944ffddb16a203478dbc672fb013a4a4685
  • Pointer size: 131 Bytes
  • Size of remote file: 215 kB
analysis_charts/radar_code_model2vec_codebert_base.png ADDED

Git LFS Details

  • SHA256: b016f7e05ceec769707b54ba62d523d731996a2ecd66d4d119474b34bbd63e41
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB
analysis_charts/radar_code_model2vec_graphcodebert_base.png ADDED

Git LFS Details

  • SHA256: 41bbe7c6aee8133e6caa0195b124474767dfa678663925fd19d671db58503a39
  • Pointer size: 131 Bytes
  • Size of remote file: 198 kB
analysis_charts/radar_code_model2vec_gte_Qwen2_15B_instruct.png ADDED

Git LFS Details

  • SHA256: 5d1ce04352fad4653d0854c7045f4d02ae720c6f504313b7ffc9ab6b570fe3f6
  • Pointer size: 131 Bytes
  • Size of remote file: 206 kB
analysis_charts/radar_code_model2vec_gte_Qwen2_7B_instruct.png ADDED

Git LFS Details

  • SHA256: 4b8399fe5ad1247ac4a69d730510d1fd668ae774739b81a437f3ca27fadc1749
  • Pointer size: 131 Bytes
  • Size of remote file: 180 kB
analysis_charts/radar_code_model2vec_jina_embeddings_v2_base_code.png ADDED

Git LFS Details

  • SHA256: 3ea35308256248d44dd9b3a00fbee3bfcfb24bfd282ce2826d55d4dc46c62661
  • Pointer size: 131 Bytes
  • Size of remote file: 182 kB
analysis_charts/radar_code_model2vec_jina_embeddings_v3.png ADDED

Git LFS Details

  • SHA256: f8434a8cebda7e3fc1455b7b3225ca4af945508a72b42af0a54b3810dacd5c3a
  • Pointer size: 131 Bytes
  • Size of remote file: 215 kB
analysis_charts/radar_code_model2vec_nomic_embed_text_v2_moe.png ADDED

Git LFS Details

  • SHA256: 7a16b12430f75580d8003d1e888f6f2b9c8ee84e5f4aba492c2cce5f9594652c
  • Pointer size: 131 Bytes
  • Size of remote file: 205 kB
analysis_charts/radar_code_model2vec_paraphrase_MiniLM_L6_v2.png ADDED

Git LFS Details

  • SHA256: b0ea6532f6e426d474a5f7826c66b350928c230e41855fc3f2fcbdbcfd3b34c4
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
distill.py DELETED
@@ -1,116 +0,0 @@
1
- #!/usr/bin/env python
2
- """
3
- Script to distill Alibaba-NLP/gte-Qwen2-7B-instruct using Model2Vec.
4
-
5
- This script performs the following operations:
6
- 1. Downloads the Alibaba-NLP/gte-Qwen2-7B-instruct model
7
- 2. Distills it using Model2Vec to create a smaller, faster static model
8
- 3. Saves the distilled model for further use
9
- """
10
-
11
- import logging
12
- import shutil
13
- import time
14
- from pathlib import Path
15
-
16
- from model2vec.distill import distill
17
-
18
- # =============================================================================
19
- # CONFIGURATION CONSTANTS
20
- # =============================================================================
21
-
22
- # Model Configuration
23
- MODEL_NAME = "Alibaba-NLP/gte-Qwen2-7B-instruct" # Model name or path for the source model
24
- OUTPUT_DIR = "." # Directory to save the distilled model (current directory)
25
- PCA_DIMS = 256 # Dimensions for PCA reduction (smaller = faster but less accurate)
26
-
27
- # Hub Configuration
28
- SAVE_TO_HUB = False # Whether to push the model to HuggingFace Hub
29
- HUB_MODEL_ID = None # Model ID for HuggingFace Hub (if saving to hub)
30
-
31
- # Generation Configuration
32
- SKIP_README = True # Skip generating the README file
33
-
34
- # =============================================================================
35
-
36
- # Configure logging
37
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
38
- logger = logging.getLogger(__name__)
39
-
40
-
41
- def main() -> None:
42
- """Run the distillation process for Alibaba-NLP/gte-Qwen2-7B-instruct."""
43
- # Create output directory if it doesn't exist
44
- output_dir = Path(OUTPUT_DIR)
45
- output_dir.mkdir(parents=True, exist_ok=True)
46
-
47
- logger.info(f"Starting distillation of {MODEL_NAME}")
48
- logger.info(f"Distilled model will be saved to {output_dir}")
49
- logger.info(f"Using PCA dimensions: {PCA_DIMS}")
50
- logger.info(f"Skipping README generation: {SKIP_README}")
51
-
52
- # Record start time for benchmarking
53
- start_time = time.time()
54
-
55
- # Run the distillation
56
- try:
57
- logger.info("Starting Model2Vec distillation...")
58
- m2v_model = distill(
59
- model_name=MODEL_NAME,
60
- pca_dims=PCA_DIMS,
61
- )
62
-
63
- distill_time = time.time() - start_time
64
- logger.info(f"Distillation completed in {distill_time:.2f} seconds")
65
-
66
- # Save the distilled model
67
- m2v_model.save_pretrained(OUTPUT_DIR)
68
- logger.info(f"Model saved to {OUTPUT_DIR}")
69
-
70
- # Remove README.md if it was created and we want to skip it
71
- if SKIP_README and (output_dir / "README.md").exists():
72
- (output_dir / "README.md").unlink()
73
- logger.info("Removed auto-generated README.md")
74
-
75
- # Get model size information
76
- model_size_mb = sum(
77
- f.stat().st_size for f in output_dir.glob("**/*") if f.is_file() and f.name != "README.md"
78
- ) / (1024 * 1024)
79
- logger.info(f"Distilled model size: {model_size_mb:.2f} MB")
80
-
81
- # Push to hub if requested
82
- if SAVE_TO_HUB:
83
- if HUB_MODEL_ID:
84
- logger.info(f"Pushing model to HuggingFace Hub as {HUB_MODEL_ID}")
85
-
86
- # Create a temporary README for Hub upload if needed
87
- readme_path = output_dir / "README.md"
88
- had_readme = readme_path.exists()
89
-
90
- if SKIP_README and had_readme:
91
- # Backup the README
92
- shutil.move(readme_path, output_dir / "README.md.bak")
93
-
94
- # Push to Hub
95
- m2v_model.push_to_hub(HUB_MODEL_ID)
96
-
97
- # Restore state
98
- if SKIP_README:
99
- if had_readme:
100
- # Restore the backup
101
- shutil.move(output_dir / "README.md.bak", readme_path)
102
- elif (output_dir / "README.md").exists():
103
- # Remove README created during push_to_hub
104
- (output_dir / "README.md").unlink()
105
- else:
106
- logger.error("HUB_MODEL_ID must be specified when SAVE_TO_HUB is True")
107
-
108
- logger.info("Distillation process completed successfully!")
109
-
110
- except Exception:
111
- logger.exception("Error during distillation")
112
- raise
113
-
114
-
115
- if __name__ == "__main__":
116
- main()
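
The deleted distill.py reduces to a single distill() call followed by save_pretrained(). A minimal sketch of that flow plus a load-back sanity check, with a placeholder output directory and test sentence; only APIs that already appear in the repository's scripts are used:

```python
# Minimal sketch of the distillation flow the deleted script performed,
# plus a load-back sanity check. "distilled_model/" and the test sentence
# are placeholders.
from model2vec import StaticModel
from model2vec.distill import distill

m2v = distill(model_name="Alibaba-NLP/gte-Qwen2-7B-instruct", pca_dims=256)
m2v.save_pretrained("distilled_model")

reloaded = StaticModel.from_pretrained("distilled_model")
print(reloaded.encode(["binary search tree"]).shape)  # expect (1, 256)
```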
evaluate.py DELETED
@@ -1,422 +0,0 @@
1
- #!/usr/bin/env python
2
- """
3
- Script to evaluate the performance of the distilled Qodo-Embed model.
4
-
5
- This script performs the following:
6
- 1. Loads both the original Qodo-Embed-1-1.5B model and the distilled version
7
- 2. Compares them on:
8
- - Embedding similarity
9
- - Inference speed
10
- - Memory usage
11
- 3. Outputs a comprehensive evaluation report
12
- """
13
-
14
- import argparse
15
- import gc
16
- import logging
17
- import os
18
- import time
19
- from pathlib import Path
20
- from typing import Any, cast
21
-
22
- import matplotlib.pyplot as plt
23
- import numpy as np
24
- import psutil # type: ignore [import]
25
- import torch
26
- from model2vec import StaticModel
27
- from sentence_transformers import SentenceTransformer
28
- from sklearn.metrics.pairwise import cosine_similarity # type: ignore [import]
29
-
30
- # For transformer models
31
- from transformers import AutoModel, AutoTokenizer
32
- from transformers.modeling_utils import PreTrainedModel
33
-
34
- # =============================================================================
35
- # CONFIGURATION CONSTANTS
36
- # =============================================================================
37
-
38
- # Model Configuration
39
- ORIGINAL_MODEL = "Alibaba-NLP/gte-Qwen2-7B-instruct" # Original model name or path
40
- DISTILLED_MODEL = "." # Path to the distilled model (current directory)
41
- OUTPUT_DIR = "evaluation" # Directory to save evaluation results
42
-
43
- # =============================================================================
44
-
45
- # Constants
46
- BYTES_PER_KB = 1024.0
47
- TEXT_TRUNCATE_LENGTH = 20
48
-
49
- # Configure logging
50
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
51
- logger = logging.getLogger(__name__)
52
-
53
- # Sample texts for evaluation
54
- SAMPLE_TEXTS = [
55
- "def process_data_stream(source_iterator):",
56
- "implement binary search tree",
57
- "how to handle memory efficient data streaming",
58
- """class LazyLoader:
59
- def __init__(self, source):
60
- self.generator = iter(source)
61
- self._cache = []""",
62
- """def dfs_traversal(root):
63
- if not root:
64
- return []
65
- visited = []
66
- stack = [root]
67
- while stack:
68
- node = stack.pop()
69
- visited.append(node.val)
70
- if node.right:
71
- stack.append(node.right)
72
- if node.left:
73
- stack.append(node.left)
74
- return visited""",
75
- ]
76
-
77
-
78
- def load_models(
79
- original_model_name: str, distilled_model_path: str
80
- ) -> tuple[tuple[SentenceTransformer | PreTrainedModel, str], StaticModel]:
81
- """Load both the original and distilled models."""
82
- logger.info(f"Loading original model: {original_model_name}")
83
-
84
- try:
85
- # Try to load as a sentence transformer first
86
- original_model = SentenceTransformer(original_model_name)
87
- model_type = "sentence_transformer"
88
- except (ValueError, OSError, ImportError) as e:
89
- # If that fails, try loading as a Hugging Face transformer
90
- logger.info(f"Failed to load as SentenceTransformer: {e}")
91
- AutoTokenizer.from_pretrained(original_model_name)
92
- original_model = AutoModel.from_pretrained(original_model_name)
93
- model_type = "huggingface"
94
-
95
- logger.info(f"Loading distilled model from: {distilled_model_path}")
96
- distilled_model = StaticModel.from_pretrained(distilled_model_path)
97
-
98
- return (original_model, model_type), distilled_model
99
-
100
-
101
- def measure_memory_usage(model: SentenceTransformer | PreTrainedModel | StaticModel) -> float:
102
- """Measure memory usage of a model in MB."""
103
- gc.collect()
104
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
105
-
106
- process = psutil.Process(os.getpid())
107
- memory_before = process.memory_info().rss / (1024 * 1024) # MB
108
-
109
- # Force model to allocate memory if it hasn't already
110
- if isinstance(model, StaticModel | SentenceTransformer):
111
- _ = model.encode(["Test"])
112
- else:
113
- # For HF models, we need to handle differently
114
- pass
115
-
116
- gc.collect()
117
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
118
-
119
- process = psutil.Process(os.getpid())
120
- memory_after = process.memory_info().rss / (1024 * 1024) # MB
121
-
122
- return memory_after - memory_before
123
-
124
-
125
- def compute_embeddings(
126
- original_model: SentenceTransformer | PreTrainedModel,
127
- original_model_type: str,
128
- distilled_model: StaticModel,
129
- texts: list[str],
130
- original_model_name: str = "unknown",
131
- ) -> tuple[np.ndarray, np.ndarray]:
132
- """Compute embeddings using both models."""
133
- # Original model embeddings
134
- if original_model_type == "sentence_transformer":
135
- # Type narrowing: we know it's a SentenceTransformer here
136
- sentence_model = cast("SentenceTransformer", original_model)
137
- original_embeddings = sentence_model.encode(texts)
138
- else:
139
- # Type narrowing: we know it's a PreTrainedModel here
140
- auto_model = original_model # AutoModel.from_pretrained returns a PreTrainedModel instance
141
-
142
- # For HF models, we need more custom code
143
- # Simple mean pooling function for HF models
144
- def mean_pooling(model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
145
- token_embeddings = model_output
146
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
147
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
148
- input_mask_expanded.sum(1), min=1e-9
149
- )
150
-
151
- # Get model name for tokenizer
152
- model_name = getattr(auto_model.config, "name_or_path", original_model_name)
153
- tokenizer = AutoTokenizer.from_pretrained(model_name)
154
- encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
155
-
156
- with torch.no_grad():
157
- model_output = auto_model(**encoded_input)
158
- original_embeddings = mean_pooling(model_output.last_hidden_state, encoded_input["attention_mask"]).numpy()
159
-
160
- # Distilled model embeddings
161
- distilled_embeddings = distilled_model.encode(texts)
162
-
163
- return original_embeddings, distilled_embeddings
164
-
165
-
166
- def measure_inference_speed(
167
- model: SentenceTransformer | PreTrainedModel | StaticModel, model_type: str, texts: list[str], n_runs: int = 5
168
- ) -> float:
169
- """Measure inference speed in texts/second."""
170
- # Warmup
171
- if model_type in {"sentence_transformer", "static_model"}:
172
- # Type narrowing: we know it has encode method here
173
- encode_model = cast("SentenceTransformer | StaticModel", model)
174
- _ = encode_model.encode(texts[:1])
175
- else:
176
- # Type narrowing: we know it's a PreTrainedModel here
177
- auto_model = cast("PreTrainedModel", model)
178
- # Warmup for HF models
179
- model_name = getattr(auto_model.config, "name_or_path", "unknown")
180
- tokenizer = AutoTokenizer.from_pretrained(model_name)
181
- encoded_input = tokenizer(texts[:1], padding=True, truncation=True, return_tensors="pt")
182
- with torch.no_grad():
183
- _ = auto_model(**encoded_input)
184
-
185
- # Measure speed
186
- start_time = time.time()
187
-
188
- if model_type in {"sentence_transformer", "static_model"}:
189
- # Type narrowing: we know it has encode method here
190
- encode_model = cast("SentenceTransformer | StaticModel", model)
191
- for _ in range(n_runs):
192
- _ = encode_model.encode(texts)
193
- else:
194
- # Type narrowing: we know it's a PreTrainedModel here
195
- auto_model = cast("PreTrainedModel", model)
196
- # For HF models
197
- model_name = getattr(auto_model.config, "name_or_path", "unknown")
198
- tokenizer = AutoTokenizer.from_pretrained(model_name)
199
- for _ in range(n_runs):
200
- encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
201
- with torch.no_grad():
202
- _ = auto_model(**encoded_input)
203
-
204
- total_time = time.time() - start_time
205
- return (len(texts) * n_runs) / total_time
206
-
207
-
208
- def compute_cosine_similarity(embeddings1: np.ndarray, embeddings2: np.ndarray) -> np.ndarray:
209
- """Compute cosine similarity between embeddings, handling different dimensions.
210
-
211
- For embeddings with different dimensions, we compute similarity by comparing
212
- how they rank the same texts (semantically equivalent).
213
- """
214
- # Ensure embeddings1 and embeddings2 are 2D arrays with shapes (n_samples, n_features)
215
- if embeddings1.ndim == 1:
216
- embeddings1 = embeddings1.reshape(1, -1)
217
- if embeddings2.ndim == 1:
218
- embeddings2 = embeddings2.reshape(1, -1)
219
-
220
- # Check and transpose if needed to ensure samples are in rows
221
- if embeddings2.shape[0] != len(SAMPLE_TEXTS) and embeddings2.shape[1] == len(SAMPLE_TEXTS):
222
- embeddings2 = embeddings2.T
223
-
224
- logger.info(f"Embeddings shapes: original={embeddings1.shape}, distilled={embeddings2.shape}")
225
-
226
- # If dimensions differ, we compute similarity matrix based on how each model ranks text pairs
227
- # This is a form of semantic similarity evaluation rather than direct vector comparison
228
- similarity_matrix = np.zeros((len(SAMPLE_TEXTS), len(SAMPLE_TEXTS)))
229
-
230
- # Compute similarity matrices within each embedding space
231
- sim1 = cosine_similarity(embeddings1)
232
- sim2 = cosine_similarity(embeddings2)
233
-
234
- # The similarity between samples i and j is the correlation between how they rank other samples
235
- for i in range(len(SAMPLE_TEXTS)):
236
- for j in range(len(SAMPLE_TEXTS)):
237
- # For diagonal elements (same sample), use a direct measure of how similar
238
- # the two models rank that sample against all others
239
- if i == j:
240
- # Pearson correlation between the rankings (excluding self-comparison)
241
- rankings1 = np.delete(sim1[i], i)
242
- rankings2 = np.delete(sim2[i], i)
243
- # Higher correlation means the models agree on the semantic similarity
244
- similarity_matrix[i, j] = np.corrcoef(rankings1, rankings2)[0, 1]
245
- else:
246
- # For off-diagonal elements, compare how similarly both models relate samples i and j
247
- similarity_matrix[i, j] = 1 - abs(sim1[i, j] - sim2[i, j])
248
-
249
- return similarity_matrix
250
-
251
-
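# Illustrative sketch (not part of the original script): the ranking-based comparison
# described in compute_cosine_similarity() above. Two models with different embedding
# dimensions are compared by how similarly each one ranks the same texts, using the
# within-space cosine-similarity rows. All names and numbers below are made up.
def _example_rank_agreement() -> None:
    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    rng = np.random.default_rng(0)
    teacher = rng.normal(size=(5, 768))   # hypothetical teacher embeddings
    student = rng.normal(size=(5, 256))   # hypothetical student embeddings

    sim_t = cosine_similarity(teacher)    # 5x5 similarities within each space
    sim_s = cosine_similarity(student)

    # Agreement for text 0: correlate how each space ranks text 0 against the rest
    agreement = np.corrcoef(np.delete(sim_t[0], 0), np.delete(sim_s[0], 0))[0, 1]
    print(f"rank agreement for text 0: {agreement:.3f}")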
252
- def format_size(size_bytes: float) -> str:
253
- """Format size in bytes to human-readable format."""
254
- for unit in ["B", "KB", "MB", "GB"]:
255
- if size_bytes < BYTES_PER_KB:
256
- return f"{size_bytes:.2f} {unit}"
257
- size_bytes /= BYTES_PER_KB
258
- return f"{size_bytes:.2f} TB"
259
-
260
-
261
- def plot_comparison(results: dict[str, Any], output_dir: str) -> None:
262
- """Generate comparison plots and save them."""
263
- output_path = Path(output_dir)
264
- output_path.mkdir(exist_ok=True, parents=True)
265
-
266
- # Speed comparison
267
- plt.figure(figsize=(10, 6))
268
- models = ["Original", "Distilled"]
269
- speeds = [results["original_speed"], results["distilled_speed"]]
270
- plt.bar(models, speeds, color=["#1f77b4", "#ff7f0e"])
271
- plt.ylabel("Texts per second")
272
- plt.title("Inference Speed Comparison")
273
- plt.savefig(output_path / "speed_comparison.png", dpi=300, bbox_inches="tight")
274
-
275
- # Memory comparison
276
- plt.figure(figsize=(10, 6))
277
- memories = [results["original_memory"], results["distilled_memory"]]
278
- plt.bar(models, memories, color=["#1f77b4", "#ff7f0e"])
279
- plt.ylabel("Memory Usage (MB)")
280
- plt.title("Memory Usage Comparison")
281
- plt.savefig(output_path / "memory_comparison.png", dpi=300, bbox_inches="tight")
282
-
283
- # Size comparison
284
- plt.figure(figsize=(10, 6))
285
- sizes = [results["original_size"], results["distilled_size"]]
286
- plt.bar(models, sizes, color=["#1f77b4", "#ff7f0e"])
287
- plt.ylabel("Model Size (MB)")
288
- plt.title("Model Size Comparison")
289
- plt.savefig(output_path / "size_comparison.png", dpi=300, bbox_inches="tight")
290
-
291
- # Similarity matrix heatmap
292
- plt.figure(figsize=(8, 6))
293
- plt.imshow(results["similarity_matrix"], cmap="viridis", interpolation="nearest")
294
- plt.colorbar(label="Cosine Similarity")
295
- plt.title("Embedding Similarity Between Original and Distilled Models")
296
- plt.xticks([])
297
- plt.yticks(
298
- range(len(SAMPLE_TEXTS)),
299
- [t[:TEXT_TRUNCATE_LENGTH] + "..." if len(t) > TEXT_TRUNCATE_LENGTH else t for t in SAMPLE_TEXTS],
300
- )
301
- plt.savefig(output_path / "similarity_matrix.png", dpi=300, bbox_inches="tight")
302
-
303
-
304
- def evaluate_models(original_model_name: str, distilled_model_path: str, output_dir: str) -> dict[str, Any]:
305
- """Evaluate the original and distilled models."""
306
- # Load models
307
- (original_model, original_model_type), distilled_model = load_models(original_model_name, distilled_model_path)
308
-
309
- # Measure model sizes
310
- if isinstance(original_model, SentenceTransformer):
311
- # For SentenceTransformer, get parameters from all modules
312
- total_params = 0
313
- for module in original_model.modules():
314
- if hasattr(module, "parameters"):
315
- for param in module.parameters():
316
- total_params += param.numel()
317
- original_model_size = total_params * 4 / (1024 * 1024) # MB (assuming float32)
318
- else:
319
- # For PreTrainedModel
320
- auto_model = original_model # AutoModel.from_pretrained returns a PreTrainedModel instance
321
- original_model_size = sum(p.numel() * 4 for p in auto_model.parameters()) / (
322
- 1024 * 1024
323
- ) # MB (assuming float32)
324
-
325
- # Calculate distilled model size - only count actual model files
326
- model_files = ["model.safetensors", "config.json", "modules.json", "tokenizer.json"]
327
- distilled_model_size = 0.0
328
- for file_name in model_files:
329
- file_path = Path(distilled_model_path) / file_name
330
- if file_path.exists():
331
- distilled_model_size += file_path.stat().st_size
332
- distilled_model_size = distilled_model_size / (1024 * 1024) # Convert to MB
333
-
334
- # Measure memory usage
335
- original_memory = measure_memory_usage(original_model)
336
- distilled_memory = measure_memory_usage(distilled_model)
337
-
338
- # Compute embeddings
339
- original_embeddings, distilled_embeddings = compute_embeddings(
340
- original_model, original_model_type, distilled_model, SAMPLE_TEXTS, original_model_name
341
- )
342
-
343
- # Compute similarity between embeddings
344
- similarity_matrix = compute_cosine_similarity(original_embeddings, distilled_embeddings)
345
- similarity_diagonal = np.diag(similarity_matrix)
346
- avg_similarity = np.mean(similarity_diagonal)
347
-
348
- # Measure inference speed
349
- original_speed = measure_inference_speed(original_model, original_model_type, SAMPLE_TEXTS, n_runs=5)
350
- distilled_speed = measure_inference_speed(distilled_model, "static_model", SAMPLE_TEXTS, n_runs=5)
351
-
352
- # Collect results
353
- results = {
354
- "original_size": original_model_size,
355
- "distilled_size": distilled_model_size,
356
- "original_memory": original_memory,
357
- "distilled_memory": distilled_memory,
358
- "similarity_matrix": similarity_matrix,
359
- "avg_similarity": avg_similarity,
360
- "original_speed": original_speed,
361
- "distilled_speed": distilled_speed,
362
- "speed_improvement": distilled_speed / original_speed if original_speed > 0 else float("inf"),
363
- "size_reduction": original_model_size / distilled_model_size if distilled_model_size > 0 else float("inf"),
364
- "memory_reduction": original_memory / distilled_memory if distilled_memory > 0 else float("inf"),
365
- }
366
-
367
- # Generate plots
368
- plot_comparison(results, output_dir)
369
-
370
- # Print results
371
- separator = "=" * 50
372
- logger.info("\n%s", separator)
373
- logger.info("Model Evaluation Results")
374
- logger.info("%s", separator)
375
- logger.info(f"Original Model Size: {results['original_size']:.2f} MB")
376
- logger.info(f"Distilled Model Size: {results['distilled_size']:.2f} MB")
377
- logger.info(f"Size Reduction Factor: {results['size_reduction']:.2f}x")
378
- logger.info("\n")
379
- logger.info(f"Original Model Memory: {results['original_memory']:.2f} MB")
380
- logger.info(f"Distilled Model Memory: {results['distilled_memory']:.2f} MB")
381
- logger.info(f"Memory Reduction Factor: {results['memory_reduction']:.2f}x")
382
- logger.info("\n")
383
- logger.info(f"Original Model Speed: {results['original_speed']:.2f} texts/second")
384
- logger.info(f"Distilled Model Speed: {results['distilled_speed']:.2f} texts/second")
385
- logger.info(f"Speed Improvement Factor: {results['speed_improvement']:.2f}x")
386
- logger.info("\n")
387
- logger.info(f"Average Embedding Similarity: {results['avg_similarity']:.4f}")
388
- logger.info("%s", separator)
389
-
390
- return results
391
-
392
-
393
- def main() -> None:
394
- """Run the evaluation process."""
395
- parser = argparse.ArgumentParser(description="Evaluate the distilled model against the original")
396
- parser.add_argument("--original_model", default=ORIGINAL_MODEL, help="Original model name or path")
397
- parser.add_argument("--distilled_model", default=DISTILLED_MODEL, help="Path to the distilled model")
398
- parser.add_argument("--output_dir", default=OUTPUT_DIR, help="Directory to save evaluation results")
399
-
400
- args = parser.parse_args()
401
-
402
- # Validate configuration
403
- if not args.distilled_model:
404
- logger.error("Distilled model path must be provided")
405
- logger.error("Use --distilled_model to specify the path or set DISTILLED_MODEL constant")
406
- return
407
-
408
- # Create output directory
409
- output_dir = Path(args.output_dir)
410
- output_dir.mkdir(parents=True, exist_ok=True)
411
-
412
- # Run evaluation
413
- try:
414
- evaluate_models(args.original_model, args.distilled_model, args.output_dir)
415
- logger.info(f"Evaluation completed. Results saved to {args.output_dir}")
416
- except Exception:
417
- logger.exception("Error during evaluation")
418
- raise
419
-
420
-
421
- if __name__ == "__main__":
422
- main()
 
 
pyproject.toml CHANGED
@@ -1,30 +1,57 @@
 
 
 
 
1
  [project]
2
  name = "gte-qwen2-7b-instruct-m2v"
3
  version = "0.1.0"
4
- description = "Add your description here"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
 
 
 
8
  "datasets>=3.6.0",
9
- "evaluation",
 
 
 
 
10
  "iso639>=0.1.4",
 
11
  "lightning>=2.5.1.post0",
12
  "matplotlib>=3.10.3",
13
  "model2vec[train]>=0.5.0",
14
  "mteb>=1.14.15",
15
  "numpy>=1.26.4",
 
16
  "psutil>=7.0.0",
 
 
17
  "scikit-learn>=1.6.1",
 
18
  "sentence-transformers>=4.1.0",
 
 
 
 
19
  "torch>=2.7.0",
 
20
  ]
21
 
 
 
 
22
  [dependency-groups]
23
  dev = [
24
  "mypy>=1.15.0",
25
  "ruff>=0.11.6",
26
  ]
27
 
 
 
 
28
  [tool.mypy]
29
  exclude = [
30
  ".git",
@@ -79,6 +106,14 @@ ignore = [
79
  "E101", # Indentation contains mixed spaces and tabs
80
  "W191", # indentation contains tabs
81
  "D206", # indent with spaces, not tabs
 
 
 
 
 
 
 
 
82
  ]
83
 
84
  [tool.ruff.lint.mccabe]
@@ -97,6 +132,3 @@ quote-style = "double"
97
  indent-style = "tab"
98
  skip-magic-trailing-comma = false
99
  line-ending = "auto"
100
-
101
- [tool.uv.sources]
102
- evaluation = { git = "https://github.com/MinishLab/evaluation.git", rev = "main" }
 
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
  [project]
6
  name = "gte-qwen2-7b-instruct-m2v"
7
  version = "0.1.0"
8
+ description = "Model2Vec distillation pipeline for gte-Qwen2-7B-instruct"
9
  readme = "README.md"
10
  requires-python = ">=3.12"
11
  dependencies = [
12
+ "accelerate>=1.7.0",
13
+ "beam-client>=0.2.155",
14
+ "boto3>=1.38.23",
15
  "datasets>=3.6.0",
16
+ "dotenv>=0.9.9",
17
+ "editables>=0.5",
18
+ "einops>=0.8.1",
19
+ "flash-attn>=2.7.4.post1",
20
+ "hatchling>=1.27.0",
21
  "iso639>=0.1.4",
22
+ "kaleido==1.0.0rc13",
23
  "lightning>=2.5.1.post0",
24
  "matplotlib>=3.10.3",
25
  "model2vec[train]>=0.5.0",
26
  "mteb>=1.14.15",
27
  "numpy>=1.26.4",
28
+ "plotly>=6.1.1",
29
  "psutil>=7.0.0",
30
+ "pydantic>=2.11.5",
31
+ "requests>=2.32.3",
32
  "scikit-learn>=1.6.1",
33
+ "seaborn>=0.13.2",
34
  "sentence-transformers>=4.1.0",
35
+ "setuptools>=80.8.0",
36
+ "smart-open[s3]>=7.1.0",
37
+ "statsmodels>=0.14.4",
38
+ "tokenlearn>=0.2.0",
39
  "torch>=2.7.0",
40
+ "typer>=0.16.0",
41
  ]
42
 
43
+ [project.scripts]
44
+ distiller = "distiller.__main__:app"
45
+
46
  [dependency-groups]
47
  dev = [
48
  "mypy>=1.15.0",
49
  "ruff>=0.11.6",
50
  ]
51
 
52
+ [tool.hatch.build.targets.wheel]
53
+ packages = ["src/distiller"]
54
+
55
  [tool.mypy]
56
  exclude = [
57
  ".git",
 
106
  "E101", # Indentation contains mixed spaces and tabs
107
  "W191", # indentation contains tabs
108
  "D206", # indent with spaces, not tabs
109
+ "PD901", # Avoid using the generic variable name `df` for DataFrames
110
+ "ANN401", # Dynamically typed expressions (typing.Any) are disallowed
111
+ "D103", # Missing docstring in public function
112
+ "BLE001", # Do not catch blind exception: `Exception`
113
+ "T201", # Use `logger.info` instead of `print`
114
+ "E501", # Line too long
115
+ "PLR2004",
116
+ "RUF001",
117
  ]
118
 
119
  [tool.ruff.lint.mccabe]
 
132
  indent-style = "tab"
133
  skip-magic-trailing-comma = false
134
  line-ending = "auto"
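The new `[project.scripts]` entry points the `distiller` command at a Typer app object. A minimal sketch of what `src/distiller/__main__.py` is assumed to expose follows; that module is not shown in this diff, so the command below is illustrative only:

import typer

app = typer.Typer(help="Model2Vec distillation pipeline")

@app.command()
def distill(use_beam: bool = False, train: bool = False) -> None:
    """Run distillation locally or on Beam (illustrative stub)."""
    typer.echo(f"use_beam={use_beam}, train={train}")

if __name__ == "__main__":
    app()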
 
 
 
src/distiller/distill.py CHANGED
@@ -20,23 +20,22 @@ Usage:
20
  distiller distill [--use-beam] [--train] # Basic distillation or with training
21
  """
22
 
 
23
  import json
24
  import logging
 
25
  import time
26
  from pathlib import Path
27
  from typing import Annotated, Any
28
 
29
- import numpy as np
30
  import torch
31
  import typer
32
- from beam import Volume, function
33
  from datasets import load_dataset
34
  from model2vec.distill import distill
35
- from model2vec.train.base import FinetunableStaticModel, TextDataset
36
  from sentence_transformers import SentenceTransformer
37
- from sklearn.model_selection import train_test_split
38
- from torch import nn, optim
39
 
 
40
  from .beam_utils import (
41
  BeamCheckpointManager,
42
  create_beam_utilities,
@@ -46,16 +45,17 @@ from .beam_utils import (
46
  upload_model_to_beam,
47
  )
48
  from .config import (
49
- BEAM_ENV_SETTINGS,
50
- GPU_NAME,
51
- IMAGE,
52
  codesearchnet_config,
53
  directories,
54
  distillation_config,
 
55
  get_volume_config,
56
  languages_config,
57
  )
58
 
 
 
 
59
  # =============================================================================
60
  # CONFIGURATION
61
  # =============================================================================
@@ -70,6 +70,75 @@ logger = logging.getLogger(__name__)
70
  # Teacher models for distillation
71
  DEFAULT_TEACHER_MODELS = list(distillation_config.code_teacher_models)
72
 
 
73
  # =============================================================================
74
  # UTILITY FUNCTIONS
75
  # =============================================================================
@@ -106,13 +175,11 @@ def get_current_config_hash(enable_training: bool) -> str:
106
  }
107
 
108
  if enable_training:
109
- config_params.update(
110
- {
111
- "training_epochs": distillation_config.training_epochs,
112
- "learning_rate": distillation_config.learning_rate,
113
- "max_samples": distillation_config.max_training_samples,
114
- }
115
  )
 
116
 
117
  config_str = str(sorted(config_params.items()))
118
  return hashlib.md5(config_str.encode()).hexdigest()[:12] # noqa: S324
@@ -345,13 +412,10 @@ def simple_distillation(
345
 
346
 
347
  def load_codesearchnet_dataset(
348
- max_samples: int | None = None,
349
  checkpoint_manager: BeamCheckpointManager | None = None,
350
  ) -> list[str]:
351
- """Load and format the CodeSearchNet dataset for training with balanced language distribution."""
352
- if max_samples is None:
353
- max_samples = int(distillation_config.max_training_samples)
354
-
355
  logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
356
  logger.info(f"Limiting to {max_samples} samples for training efficiency")
357
  logger.info(f"Languages: {', '.join(languages_config.all)}")
@@ -542,7 +606,7 @@ def generate_teacher_embeddings(
542
  # Generate embeddings from scratch
543
  logger.info("Generating fresh teacher embeddings...")
544
 
545
- batch_size = int(distillation_config.teacher_model_config.get("batch_size", 16))
546
  embeddings_list = []
547
 
548
  for i in range(0, len(texts), batch_size):
@@ -614,146 +678,351 @@ def generate_teacher_embeddings(
614
  return teacher_embeddings
615
 
616
 
617
- def advanced_training(
 
 
618
  student_model: Any,
619
  teacher_model: SentenceTransformer,
620
- checkpoint_manager: BeamCheckpointManager | None = None,
621
  ) -> Any:
622
- """Perform advanced code specialization training."""
623
- logger.info("🎓 Starting advanced code specialization training...")
624
-
625
- # Load CodeSearchNet training data
626
- training_texts = load_codesearchnet_dataset(checkpoint_manager=checkpoint_manager)
627
-
628
- if not training_texts:
629
- logger.warning("No training data available, skipping advanced training")
630
- return student_model
631
-
632
- # Generate teacher embeddings
633
- teacher_embeddings = generate_teacher_embeddings(teacher_model, training_texts, checkpoint_manager)
634
-
635
- # Create trainable model
636
- student_embedding_dim = student_model.embedding.shape[1]
637
- teacher_embedding_dim = teacher_embeddings.shape[1]
638
-
639
- # Project teacher embeddings if needed
640
- if teacher_embedding_dim != student_embedding_dim:
641
- from sklearn.decomposition import PCA
642
-
643
- logger.info("Performing PCA projection for dimension matching...")
644
- pca = PCA(n_components=student_embedding_dim)
645
- teacher_embeddings_np = teacher_embeddings.cpu().numpy().astype(np.float64)
646
- teacher_embeddings_projected = pca.fit_transform(teacher_embeddings_np)
647
- teacher_embeddings = torch.tensor(teacher_embeddings_projected.astype(np.float32), dtype=torch.float32)
648
-
649
- # Create trainable model
650
- trainable_model = FinetunableStaticModel.from_static_model(
651
- model=student_model,
652
- out_dim=student_embedding_dim,
653
- )
654
- trainable_model = trainable_model.float()
655
-
656
- # Tokenize texts
657
- tokenized_texts = []
658
- for text in training_texts:
659
- tokens = trainable_model.tokenize([text])
660
- if tokens.shape[1] > 0:
661
- tokenized_texts.append(tokens[0].tolist())
662
-
663
- # Prepare training data
664
- targets = teacher_embeddings[: len(tokenized_texts)].to(torch.float32)
665
- train_texts, val_texts, train_targets, val_targets = train_test_split(
666
- tokenized_texts, targets, test_size=0.2, random_state=42
667
- )
668
-
669
- # Training setup
670
- train_dataset = TextDataset(train_texts, train_targets)
671
- val_dataset = TextDataset(val_texts, val_targets)
672
-
673
- optimizer = optim.Adam(trainable_model.parameters(), lr=float(distillation_config.learning_rate))
674
- mse_loss = nn.MSELoss()
675
-
676
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
677
- trainable_model = trainable_model.to(device)
678
-
679
- batch_size = int(distillation_config.batch_size)
680
- epochs = int(distillation_config.training_epochs)
681
-
682
- # Training loop
683
- for epoch in range(epochs):
684
- trainable_model.train()
685
 
686
  try:
687
- train_loader = train_dataset.to_dataloader(shuffle=True, batch_size=batch_size)
688
- epoch_loss = 0.0
689
- num_batches = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
690
 
691
- for _batch_idx, (tokens, targets_batch) in enumerate(train_loader):
692
- batch_tokens = tokens.to(device)
693
- batch_targets = targets_batch.to(device).to(torch.float32)
 
 
 
 
 
 
 
 
 
 
 
694
 
695
- optimizer.zero_grad()
696
- _, student_embeddings = trainable_model(batch_tokens)
697
- student_embeddings = student_embeddings.to(torch.float32)
 
 
698
 
699
- loss = mse_loss(student_embeddings, batch_targets)
700
- loss.backward()
701
- optimizer.step()
702
 
703
- epoch_loss += loss.item()
704
- num_batches += 1
 
705
 
706
- except torch.cuda.OutOfMemoryError:
707
- logger.warning(f"Training OOM with batch size {batch_size}, reducing...")
708
- batch_size = max(1, batch_size // 2)
709
- torch.cuda.empty_cache()
710
- continue
711
 
712
- avg_train_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
 
713
 
714
- # Validation
715
- trainable_model.eval()
716
- val_loader = val_dataset.to_dataloader(shuffle=False, batch_size=batch_size)
717
- val_loss = 0.0
718
- val_batches = 0
 
 
 
 
 
 
 
719
 
720
- with torch.no_grad():
721
- for tokens, targets_batch in val_loader:
722
- batch_tokens = tokens.to(device)
723
- batch_targets = targets_batch.to(device).to(torch.float32)
724
 
725
- _, student_embeddings = trainable_model(batch_tokens)
726
- student_embeddings = student_embeddings.to(torch.float32)
 
 
 
 
727
 
728
- loss = mse_loss(student_embeddings, batch_targets)
729
- val_loss += loss.item()
730
- val_batches += 1
 
 
731
 
732
- avg_val_loss = val_loss / val_batches if val_batches > 0 else 0.0
733
 
734
- logger.info(f"Epoch {epoch + 1}/{epochs} - Train: {avg_train_loss:.6f}, Val: {avg_val_loss:.6f}")
 
 
 
 
735
 
736
- # Save checkpoint
737
- if checkpoint_manager:
738
- checkpoint_data = {
739
- "config_hash": get_current_config_hash(enable_training=True),
740
- "stage": "training",
741
- "step": epoch + 1,
742
- "timestamp": time.time(),
743
- "data": {
744
- "model_state": trainable_model.state_dict(),
745
- "optimizer_state": optimizer.state_dict(),
746
- "train_loss": avg_train_loss,
747
- "val_loss": avg_val_loss,
748
- },
749
- }
750
- checkpoint_manager.save_checkpoint("training", checkpoint_data, epoch + 1)
751
 
752
- # Convert back to static model
753
- refined_model = trainable_model.to_static_model()
754
- logger.info("✅ Advanced training completed")
 
 
755
 
756
- return refined_model
 
 
 
 
 
 
 
 
757
 
758
 
759
  def distill_single_teacher(
@@ -884,15 +1153,15 @@ def distill_single_teacher(
884
 
885
  # Step 3: Handle final model creation
886
  if enable_training and base_model is not None:
887
- # Perform advanced training
888
- logger.info(f"🎓 Starting advanced training for {teacher_name}")
889
 
890
  # Load teacher model for training
891
  device = "cuda" if torch.cuda.is_available() else "cpu"
892
- teacher_st_model = SentenceTransformer(teacher_model, device=device, trust_remote_code=True)
893
 
894
- # Perform advanced training
895
- final_model = advanced_training(base_model, teacher_st_model, checkpoint_mgr)
896
 
897
  # Save final model
898
  final_dir.mkdir(parents=True, exist_ok=True)
@@ -1031,14 +1300,7 @@ def run_local_distillation(
1031
  return results_summary
1032
 
1033
 
1034
- @function(
1035
- gpu=GPU_NAME,
1036
- volumes=[Volume(name=VOLUME_CONFIG.name, mount_path=VOLUME_CONFIG.mount_path)],
1037
- image=IMAGE,
1038
- secrets=["HF_ACCESS_TOKEN"],
1039
- env=BEAM_ENV_SETTINGS,
1040
- timeout=3600 * 12, # 12 hours
1041
- )
1042
  def _beam_distill_models(
1043
  teacher_models: list[str] | None = None,
1044
  enable_training: bool = False,
@@ -1194,7 +1456,8 @@ def main(
1194
  ) -> None:
1195
  """Unified distillation command with optional training."""
1196
  logger.info("🚀 Starting unified Model2Vec distillation workflow")
1197
- logger.info(f"🎓 Training mode: {'Advanced (CodeSearchNet fine-tuning)' if train else 'Basic distillation only'}")
 
1198
  logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")
1199
 
1200
  # Use default models if none specified
@@ -1355,7 +1618,6 @@ def salesforce_model_distillation(
1355
 
1356
  try:
1357
  import torch
1358
- from sentence_transformers import SentenceTransformer
1359
  from transformers import AutoModel, AutoTokenizer
1360
 
1361
  # Enhanced custom model loading for Salesforce models
@@ -1395,9 +1657,8 @@ def salesforce_model_distillation(
1395
 
1396
  # Method 2: Try SentenceTransformer with specific settings
1397
  logger.info("🔄 Falling back to SentenceTransformer method...")
1398
- sentence_model = SentenceTransformer(
1399
  teacher_model,
1400
- trust_remote_code=True,
1401
  device="cpu", # Force CPU loading first
1402
  )
1403
 
@@ -1470,7 +1731,6 @@ def baai_bge_model_distillation(
1470
 
1471
  try:
1472
  import torch
1473
- from sentence_transformers import SentenceTransformer
1474
  from transformers import AutoModel, AutoTokenizer
1475
 
1476
  logger.info("🔧 Loading BAAI model with tokenizer workaround...")
@@ -1481,7 +1741,7 @@ def baai_bge_model_distillation(
1481
  # Method 1: Try SentenceTransformer first (often handles tokenizer issues better)
1482
  try:
1483
  logger.info("🔄 Attempting with SentenceTransformer wrapper...")
1484
- sentence_model = SentenceTransformer(teacher_model, trust_remote_code=True)
1485
 
1486
  # Extract components
1487
  model = sentence_model[0].auto_model
 
20
  distiller distill [--use-beam] [--train] # Basic distillation or with training
21
  """
22
 
23
+ import importlib.util
24
  import json
25
  import logging
26
+ import os
27
  import time
28
  from pathlib import Path
29
  from typing import Annotated, Any
30
 
 
31
  import torch
32
  import typer
33
+ from beam import function
34
  from datasets import load_dataset
35
  from model2vec.distill import distill
 
36
  from sentence_transformers import SentenceTransformer
 
 
37
 
38
+ # Try to import flash_attn to check if it's available
39
  from .beam_utils import (
40
  BeamCheckpointManager,
41
  create_beam_utilities,
 
45
  upload_model_to_beam,
46
  )
47
  from .config import (
 
 
 
48
  codesearchnet_config,
49
  directories,
50
  distillation_config,
51
+ get_distillation_function_kwargs,
52
  get_volume_config,
53
  languages_config,
54
  )
55
 
56
+ # Check if flash_attn is available and compatible
57
+ FLASH_ATTN_AVAILABLE = importlib.util.find_spec("flash_attn") is not None
58
+
59
  # =============================================================================
60
  # CONFIGURATION
61
  # =============================================================================
 
70
  # Teacher models for distillation
71
  DEFAULT_TEACHER_MODELS = list(distillation_config.code_teacher_models)
72
 
73
+ # =============================================================================
74
+ # FLASH ATTENTION UTILITIES
75
+ # =============================================================================
76
+
77
+
78
+ def configure_flash_attention() -> dict[str, Any]:
79
+ """Configure flash attention settings and return model kwargs."""
80
+ model_kwargs: dict[str, Any] = {}
81
+
82
+ if not FLASH_ATTN_AVAILABLE:
83
+ logger.info("⚠️ Flash attention not available - using standard attention")
84
+ return model_kwargs
85
+
86
+ # Set environment variables for flash attention
87
+ os.environ["FLASH_ATTENTION_FORCE_USE"] = "1"
88
+ # Disable torch compile for flash attention compatibility
89
+ os.environ["TORCH_COMPILE_DISABLE"] = "1"
90
+ # Enable flash attention in transformers
91
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
92
+
93
+ # Check if we're on a compatible GPU
94
+ try:
95
+ if torch.cuda.is_available():
96
+ device_capability = torch.cuda.get_device_capability()
97
+ # Flash attention requires compute capability >= 7.5 (Turing, Ampere, Ada, Hopper)
98
+ if device_capability[0] >= 7 and (device_capability[0] > 7 or device_capability[1] >= 5):
99
+ logger.info("✅ Flash attention enabled - compatible GPU detected")
100
+ model_kwargs.update(
101
+ {
102
+ "model_kwargs": {
103
+ "attn_implementation": "flash_attention_2",
104
+ "torch_dtype": torch.float16, # Flash attention works best with fp16
105
+ "use_flash_attention_2": True,
106
+ "_attn_implementation": "flash_attention_2", # Alternative key for some models
107
+ }
108
+ }
109
+ )
110
+ else:
111
+ logger.info(f"⚠️ GPU compute capability {device_capability} < 7.5 - flash attention disabled")
112
+ else:
113
+ logger.info("⚠️ No CUDA available - flash attention disabled")
114
+ except Exception as e:
115
+ logger.warning(f"⚠️ Failed to check GPU compatibility: {e} - flash attention disabled")
116
+
117
+ return model_kwargs
118
+
119
+
120
+ def load_model_with_flash_attention(model_path: str, device: str = "auto") -> SentenceTransformer:
121
+ """Load a SentenceTransformer model with flash attention if available."""
122
+ flash_kwargs = configure_flash_attention()
123
+
124
+ try:
125
+ # Try loading with flash attention first
126
+ if flash_kwargs and "model_kwargs" in flash_kwargs:
127
+ logger.info(f"🚀 Loading model with flash attention: {Path(model_path).name}")
128
+ model = SentenceTransformer(model_path, device=device, trust_remote_code=True, **flash_kwargs)
129
+ logger.info("✅ Model loaded successfully with flash attention")
130
+ return model
131
+ except Exception as e:
132
+ logger.warning(f"⚠️ Failed to load with flash attention: {e}")
133
+ logger.info("🔄 Falling back to standard attention")
134
+
135
+ # Fallback to standard loading
136
+ logger.info(f"📂 Loading model with standard attention: {Path(model_path).name}")
137
+ model = SentenceTransformer(model_path, device=device, trust_remote_code=True)
138
+ logger.info("✅ Model loaded successfully with standard attention")
139
+ return model
140
+
141
+
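# Usage sketch (not from the diff): how a teacher model can be loaded with the helper
# above. Flash attention is tried first, and the helper falls back to standard attention
# when flash-attn is missing or the GPU compute capability is below 7.5. The model name
# here is only an example.
def _example_load_teacher() -> None:
    teacher = load_model_with_flash_attention("BAAI/bge-base-en-v1.5", device="cpu")
    vectors = teacher.encode(["def binary_search(arr, target):"], convert_to_tensor=False)
    print(vectors.shape)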
142
  # =============================================================================
143
  # UTILITY FUNCTIONS
144
  # =============================================================================
 
175
  }
176
 
177
  if enable_training:
178
+ # Add a simple hash of tokenlearn parameters for config validation
179
+ tokenlearn_hash = hash(
180
+ f"{distillation_config.tokenlearn_dataset}_{distillation_config.tokenlearn_dataset_name}_{distillation_config.tokenlearn_text_key}"
 
 
 
181
  )
182
+ config_params["tokenlearn_hash"] = float(abs(tokenlearn_hash) % 1000000) # Convert to float for consistency
183
 
184
  config_str = str(sorted(config_params.items()))
185
  return hashlib.md5(config_str.encode()).hexdigest()[:12] # noqa: S324
 
412
 
413
 
414
  def load_codesearchnet_dataset(
415
+ max_samples: int = 50000,
416
  checkpoint_manager: BeamCheckpointManager | None = None,
417
  ) -> list[str]:
418
+ """Load and format the CodeSearchNet dataset for token frequency computation."""
 
 
 
419
  logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
420
  logger.info(f"Limiting to {max_samples} samples for training efficiency")
421
  logger.info(f"Languages: {', '.join(languages_config.all)}")
 
606
  # Generate embeddings from scratch
607
  logger.info("Generating fresh teacher embeddings...")
608
 
609
+ batch_size = 16 # Fixed batch size for teacher embedding generation
610
  embeddings_list = []
611
 
612
  for i in range(0, len(texts), batch_size):
 
678
  return teacher_embeddings
679
 
680
 
681
+ def compute_token_frequencies_for_sif(
682
+ teacher_model: SentenceTransformer,
683
+ features_dir: Path,
684
+ ) -> None:
685
+ """
686
+ Compute token frequencies from the training corpus for SIF weighting.
687
+
688
+ This follows the POTION approach for post-training re-regularization.
689
+ """
690
+ import json
691
+ from collections import Counter
692
+
693
+ logger.info("📊 Computing token frequencies for SIF weighting...")
694
+
695
+ try:
696
+ # Load CodeSearchNet dataset to compute frequencies (limited sample for efficiency)
697
+ dataset_texts = load_codesearchnet_dataset(max_samples=10000)
698
+
699
+ logger.info(f"📊 Computing frequencies on {len(dataset_texts)} texts...")
700
+
701
+ # Tokenize all texts and count token frequencies
702
+ tokenizer = teacher_model.tokenizer
703
+ token_counts: Counter[int] = Counter()
704
+
705
+ # Process in batches to avoid memory issues
706
+ batch_size = 100
707
+ for i in range(0, len(dataset_texts), batch_size):
708
+ batch_texts = dataset_texts[i : i + batch_size]
709
+
710
+ for text in batch_texts:
711
+ # Tokenize the text
712
+ tokens = tokenizer.encode(text, add_special_tokens=False)
713
+ token_counts.update(tokens)
714
+
715
+ if i % (batch_size * 10) == 0:
716
+ logger.info(f" Processed {i + len(batch_texts)}/{len(dataset_texts)} texts...")
717
+
718
+ # Convert to frequencies (token_id -> count)
719
+ token_frequencies = dict(token_counts)
720
+
721
+ # Save token frequencies to features directory for post-training regularization
722
+ freq_file = features_dir / "token_frequencies.json"
723
+ with freq_file.open("w") as f:
724
+ json.dump(token_frequencies, f, indent=2)
725
+
726
+ logger.info(f"✅ Token frequencies saved to {freq_file}")
727
+ logger.info(f"📊 Total unique tokens: {len(token_frequencies)}")
728
+ logger.info(f"📊 Total token occurrences: {sum(token_frequencies.values())}")
729
+
730
+ except Exception as e:
731
+ logger.warning(f"⚠️ Failed to compute token frequencies: {e}")
732
+ logger.warning("⚠️ Post-training re-regularization will use default Zipf weighting")
733
+
734
+
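# Small illustration (not part of the original code): json stores the Counter's integer
# token ids as string keys, which is why apply_post_training_regularization() below looks
# tokens up with str(i) after reloading token_frequencies.json.
def _example_frequency_roundtrip() -> None:
    import json
    from collections import Counter

    counts = Counter({101: 7, 2005: 3})               # int token ids
    restored = json.loads(json.dumps(dict(counts)))
    print(restored)                                    # {'101': 7, '2005': 3} - keys are now str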
735
+ def apply_post_training_regularization(
736
+ model: Any,
737
+ features_dir: Path,
738
+ pca_dims: int = 256,
739
+ ) -> Any:
740
+ """
741
+ Apply post-training re-regularization following the POTION approach.
742
+
743
+ This includes:
744
+ 1. Token frequency weighting using corpus frequencies
745
+ 2. PCA application
746
+ 3. SIF weighting using formula: w = 1e-3 / (1e-3 + proba)
747
+ """
748
+ import json
749
+
750
+ import numpy as np
751
+ from sklearn.decomposition import PCA
752
+
753
+ logger.info("🔧 Starting post-training re-regularization (POTION Step 4)")
754
+
755
+ # Step 4a: Load token frequencies from the training corpus
756
+ logger.info("📊 Computing token frequencies from training corpus...")
757
+
758
+ # Try to load token frequencies from features directory
759
+ freq_file = features_dir / "token_frequencies.json"
760
+
761
+ if freq_file.exists():
762
+ with freq_file.open("r") as f:
763
+ token_frequencies = json.load(f)
764
+ logger.info(f"✅ Loaded token frequencies from {freq_file}")
765
+ else:
766
+ logger.warning("⚠️ Token frequencies not found - using default Zipf weighting")
767
+ # Fallback to basic frequency estimation based on rank
768
+ vocab_size = model.embedding.shape[0]
769
+ token_frequencies = {str(i): 1.0 / (i + 1) for i in range(vocab_size)}
770
+
771
+ # Step 4b: Apply PCA to the embeddings
772
+ logger.info(f"🔄 Applying PCA with {pca_dims} dimensions...")
773
+
774
+ # Get current embeddings
775
+ embeddings = model.embedding.cpu().numpy().astype(np.float64)
776
+ original_shape = embeddings.shape
777
+ logger.info(f"Original embedding shape: {original_shape}")
778
+
779
+ # Apply PCA if dimensions don't match
780
+ if original_shape[1] != pca_dims:
781
+ pca = PCA(n_components=pca_dims, random_state=42)
782
+ embeddings_pca = pca.fit_transform(embeddings)
783
+ logger.info(f"PCA applied: {original_shape} → {embeddings_pca.shape}")
784
+
785
+ # Explained variance ratio
786
+ explained_var = pca.explained_variance_ratio_.sum()
787
+ logger.info(f"PCA explained variance ratio: {explained_var:.4f}")
788
+ else:
789
+ embeddings_pca = embeddings
790
+ logger.info("PCA dimensions match - no PCA transformation needed")
791
+
792
+ # Step 4c: Apply SIF weighting using corpus frequencies
793
+ logger.info("⚖️ Applying SIF weighting based on token frequencies...")
794
+
795
+ # Convert token frequencies to probabilities
796
+ total_tokens = sum(token_frequencies.values())
797
+ token_probs = {token: freq / total_tokens for token, freq in token_frequencies.items()}
798
+
799
+ # Apply SIF weighting: w = 1e-3 / (1e-3 + proba)
800
+ sif_coefficient = 1e-3 # Standard SIF coefficient
801
+
802
+ for i in range(embeddings_pca.shape[0]):
803
+ token_id = str(i)
804
+ prob = token_probs[token_id] if token_id in token_probs else 1.0 / len(token_probs)
805
+
806
+ # Apply SIF weighting formula
807
+ sif_weight = sif_coefficient / (sif_coefficient + prob)
808
+ embeddings_pca[i] *= sif_weight
809
+
810
+ logger.info("✅ SIF weighting applied successfully")
811
+
812
+ # Step 4d: Create new model with re-regularized embeddings
813
+ logger.info("📦 Creating final model with re-regularized embeddings...")
814
+
815
+ # Convert back to float32 numpy array
816
+ final_embeddings = embeddings_pca.astype(np.float32)
817
+
818
+ # Create new model with updated embeddings
819
+ from model2vec.model import StaticModel
820
+
821
+ # Save tokenizer and config from original model
822
+ tokenizer = model.tokenizer
823
+ config = model.config
824
+
825
+ # Create new model with re-regularized embeddings
826
+ final_model = StaticModel(vectors=final_embeddings, tokenizer=tokenizer, config=config)
827
+
828
+ logger.info("✅ Post-training re-regularization completed successfully")
829
+ logger.info(f"Final model embedding shape: {final_model.embedding.shape}")
830
+
831
+ return final_model
832
+
833
+
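# Worked example (illustrative only) of the SIF formula applied above,
# w = 1e-3 / (1e-3 + p): very frequent tokens are damped strongly while rare tokens
# keep weights close to 1. The probabilities below are made up.
def _example_sif_weights() -> None:
    sif_coefficient = 1e-3
    for token, prob in [("def", 0.02), ("return", 0.01), ("quicksort", 1e-6)]:
        weight = sif_coefficient / (sif_coefficient + prob)
        print(f"{token:10s} p={prob:.6f} -> w={weight:.4f}")
    # def        p=0.020000 -> w=0.0476
    # return     p=0.010000 -> w=0.0909
    # quicksort  p=0.000001 -> w=0.9990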
834
+ def tokenlearn_training(
835
  student_model: Any,
836
  teacher_model: SentenceTransformer,
837
+ checkpoint_manager: BeamCheckpointManager | None = None, # noqa: ARG001
838
  ) -> Any:
839
+ """
840
+ Perform tokenlearn training following the official POTION approach.
841
+
842
+ This follows the 4-step process:
843
+ 1. Model2Vec distillation (already done - student_model)
844
+ 2. Sentence transformer inference (create features)
845
+ 3. Tokenlearn training
846
+ 4. Post-training re-regularization (PCA + SIF weighting)
847
+ """
848
+ import subprocess
849
+ import tempfile
850
+ from pathlib import Path
851
+
852
+ logger.info("🧪 Starting tokenlearn training (POTION approach)...")
853
+
854
+ # Create temporary directories for tokenlearn workflow
855
+ with tempfile.TemporaryDirectory() as temp_dir:
856
+ temp_path = Path(temp_dir)
857
+ features_dir = temp_path / "features"
858
+ model_dir = temp_path / "base_model"
859
+ trained_dir = temp_path / "trained_model"
860
+
861
+ features_dir.mkdir(exist_ok=True)
862
+ model_dir.mkdir(exist_ok=True)
863
+ trained_dir.mkdir(exist_ok=True)
864
+
865
+ # Save the base distilled model for tokenlearn
866
+ student_model.save_pretrained(str(model_dir))
867
+ logger.info(f"💾 Saved base model to {model_dir}")
868
+
869
+ # Step 2: Create features using tokenlearn-featurize
870
+ logger.info("🔍 Step 2: Creating features using sentence transformer...")
871
+
872
+ # Get teacher model name/path for tokenlearn
873
+ teacher_model_name = getattr(teacher_model, "model_name", None)
874
+ if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: # noqa: SLF001
875
+ # Try to extract from the first module if it's a SentenceTransformer
876
+ # _modules is a dict-like container, get the first module by iterating
877
+ first_module = next(iter(teacher_model._modules.values())) # noqa: SLF001
878
+ if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
879
+ teacher_model_name = first_module.auto_model.name_or_path
880
+
881
+ if not teacher_model_name:
882
+ logger.warning("⚠️ Could not determine teacher model name, using fallback")
883
+ teacher_model_name = "BAAI/bge-base-en-v1.5" # Fallback to a common model
884
+
885
+ logger.info(f"📊 Using teacher model: {teacher_model_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886
 
887
  try:
888
+ # Use configured dataset for code specialization
889
+ featurize_cmd = [
890
+ "python",
891
+ "-m",
892
+ "tokenlearn.featurize",
893
+ "--model-name",
894
+ str(teacher_model_name),
895
+ "--output-dir",
896
+ str(features_dir),
897
+ "--dataset-path",
898
+ str(distillation_config.tokenlearn_dataset),
899
+ "--dataset-name",
900
+ str(distillation_config.tokenlearn_dataset_name),
901
+ "--dataset-split",
902
+ "train",
903
+ "--key",
904
+ str(distillation_config.tokenlearn_text_key), # Use configured text field
905
+ "--batch-size",
906
+ "1024", # Optimized batch size for A100-40G
907
+ ]
908
 
909
+ logger.info("🔄 Running tokenlearn featurization...")
910
+ logger.info(
911
+ f"📊 Dataset: {distillation_config.tokenlearn_dataset} (config: {distillation_config.tokenlearn_dataset_name})"
912
+ )
913
+ logger.info(f"📝 Text field: {distillation_config.tokenlearn_text_key}")
914
+ logger.info(f"Command: {' '.join(featurize_cmd)}")
915
+ print(f"\n🔄 Executing: {' '.join(featurize_cmd)}\n")
916
+
917
+ result = subprocess.run( # noqa: S603
918
+ featurize_cmd,
919
+ text=True,
920
+ timeout=distillation_config.tokenlearn_timeout_featurize,
921
+ check=False,
922
+ )
923
 
924
+ if result.returncode != 0:
925
+ logger.error(f"❌ Featurization failed with return code: {result.returncode}")
926
+ logger.error("💥 Tokenlearn featurization is required for training - cannot proceed")
927
+ msg = f"Tokenlearn featurization failed with return code: {result.returncode}"
928
+ raise RuntimeError(msg)
929
 
930
+ logger.info("✅ Featurization completed successfully")
 
 
931
 
932
+ # Generate token frequencies for post-training re-regularization
933
+ logger.info("📊 Computing token frequencies for SIF weighting...")
934
+ compute_token_frequencies_for_sif(teacher_model, features_dir)
935
 
936
+ except Exception as e:
937
+ logger.exception("💥 Tokenlearn featurization failed")
938
+ logger.exception("💥 Tokenlearn featurization is required for training - cannot proceed")
939
+ msg = f"Tokenlearn featurization failed: {e}"
940
+ raise RuntimeError(msg) from e
941
 
942
+ # Step 3: Train using tokenlearn-train
943
+ logger.info("🎓 Step 3: Training using tokenlearn...")
944
 
945
+ try:
946
+ train_cmd = [
947
+ "python",
948
+ "-m",
949
+ "tokenlearn.train",
950
+ "--model-name",
951
+ str(teacher_model_name),
952
+ "--data-path",
953
+ str(features_dir),
954
+ "--save-path",
955
+ str(trained_dir),
956
+ ]
957
 
958
+ logger.info("🔄 Running tokenlearn training...")
959
+ logger.info(f"Command: {' '.join(train_cmd)}")
960
+ print(f"\n🎓 Executing: {' '.join(train_cmd)}\n")
 
961
 
962
+ result = subprocess.run( # noqa: S603
963
+ train_cmd,
964
+ text=True,
965
+ timeout=distillation_config.tokenlearn_timeout_train,
966
+ check=False,
967
+ )
968
 
969
+ if result.returncode != 0:
970
+ logger.error(f"❌ Tokenlearn training failed with return code: {result.returncode}")
971
+ logger.error("💥 Tokenlearn training is required - cannot proceed")
972
+ msg = f"Tokenlearn training failed with return code: {result.returncode}"
973
+ raise RuntimeError(msg)
974
 
975
+ logger.info("✅ Tokenlearn training completed successfully")
976
 
977
+ except Exception as e:
978
+ logger.exception("💥 Tokenlearn training failed")
979
+ logger.exception("💥 Tokenlearn training is required - cannot proceed")
980
+ msg = f"Tokenlearn training failed: {e}"
981
+ raise RuntimeError(msg) from e
982
 
983
+ # Step 4: Load the trained model and apply post-training re-regularization
984
+ logger.info("📦 Step 4: Loading trained model and applying post-training re-regularization...")
 
 
 
 
 
 
 
 
 
 
 
 
 
985
 
986
+ try:
987
+ from model2vec.model import StaticModel
988
+
989
+ # Load the trained model from tokenlearn
990
+ trained_model_path = trained_dir / "model"
991
+ if not trained_model_path.exists():
992
+ # Try alternative paths
993
+ possible_paths = [
994
+ trained_dir / "model_weighted",
995
+ trained_dir,
996
+ ]
997
+
998
+ for path in possible_paths:
999
+ if path.exists() and any(path.glob("*.json")):
1000
+ trained_model_path = path
1001
+ break
1002
+ else:
1003
+ logger.error(f"❌ Could not find trained model in {trained_dir}")
1004
+ msg = f"Tokenlearn training failed - no model found in {trained_dir}"
1005
+ raise RuntimeError(msg)
1006
+
1007
+ # Load the model before re-regularization
1008
+ logger.info("🔄 Loading model from tokenlearn training...")
1009
+ trained_model = StaticModel.from_pretrained(str(trained_model_path))
1010
+
1011
+ # Apply post-training re-regularization (POTION Step 4)
1012
+ logger.info("🔧 Applying post-training re-regularization (PCA + SIF weighting)...")
1013
+ final_model = apply_post_training_regularization(
1014
+ trained_model, features_dir, pca_dims=distillation_config.optimal_pca_dims
1015
+ )
1016
 
1017
+ logger.info("✅ Tokenlearn training pipeline with post-training re-regularization completed successfully")
1018
+
1019
+ return final_model
1020
+
1021
+ except Exception as e:
1022
+ logger.exception("💥 Failed to load tokenlearn trained model")
1023
+ logger.exception("💥 Cannot load trained model - training failed")
1024
+ msg = f"Failed to load tokenlearn trained model: {e}"
1025
+ raise RuntimeError(msg) from e
1026
 
1027
 
1028
  def distill_single_teacher(
 
1153
 
1154
  # Step 3: Handle final model creation
1155
  if enable_training and base_model is not None:
1156
+ # Perform tokenlearn training (POTION approach)
1157
+ logger.info(f"🧪 Starting tokenlearn training for {teacher_name}")
1158
 
1159
  # Load teacher model for training
1160
  device = "cuda" if torch.cuda.is_available() else "cpu"
1161
+ teacher_st_model = load_model_with_flash_attention(teacher_model, device)
1162
 
1163
+ # Perform tokenlearn training (POTION approach)
1164
+ final_model = tokenlearn_training(base_model, teacher_st_model, checkpoint_mgr)
1165
 
1166
  # Save final model
1167
  final_dir.mkdir(parents=True, exist_ok=True)
 
1300
  return results_summary
1301
 
1302
 
1303
+ @function(**get_distillation_function_kwargs())
 
 
 
 
 
 
 
1304
  def _beam_distill_models(
1305
  teacher_models: list[str] | None = None,
1306
  enable_training: bool = False,
 
1456
  ) -> None:
1457
  """Unified distillation command with optional training."""
1458
  logger.info("🚀 Starting unified Model2Vec distillation workflow")
1459
+
1460
+ logger.info(f"🎓 Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
1461
  logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")
1462
 
1463
  # Use default models if none specified
 
1618
 
1619
  try:
1620
  import torch
 
1621
  from transformers import AutoModel, AutoTokenizer
1622
 
1623
  # Enhanced custom model loading for Salesforce models
 
1657
 
1658
  # Method 2: Try SentenceTransformer with specific settings
1659
  logger.info("🔄 Falling back to SentenceTransformer method...")
1660
+ sentence_model = load_model_with_flash_attention(
1661
  teacher_model,
 
1662
  device="cpu", # Force CPU loading first
1663
  )
1664
 
 
1731
 
1732
  try:
1733
  import torch
 
1734
  from transformers import AutoModel, AutoTokenizer
1735
 
1736
  logger.info("🔧 Loading BAAI model with tokenizer workaround...")
 
1741
  # Method 1: Try SentenceTransformer first (often handles tokenizer issues better)
1742
  try:
1743
  logger.info("🔄 Attempting with SentenceTransformer wrapper...")
1744
+ sentence_model = load_model_with_flash_attention(teacher_model)
1745
 
1746
  # Extract components
1747
  model = sentence_model[0].auto_model
src/distiller/evaluate.py CHANGED
@@ -17,8 +17,12 @@ Usage:
17
  distiller evaluate [--use-beam] [--skip-benchmark] # Run evaluation locally or on Beam
18
  """
19
 
 
 
 
20
  import json
21
  import logging
 
22
  import time
23
  import traceback
24
  from pathlib import Path
@@ -29,7 +33,7 @@ import pandas as pd
29
  import psutil
30
  import torch
31
  import typer
32
- from beam import Volume, function
33
  from datasets import Dataset, load_dataset
34
  from sentence_transformers import SentenceTransformer
35
  from sklearn.metrics.pairwise import cosine_similarity
@@ -37,24 +41,25 @@ from tqdm import tqdm
37
 
38
  from .beam_utils import download_specific_evaluation_file
39
  from .config import (
40
- BEAM_ENV_SETTINGS,
41
  DEFAULT_EVALUATION_MODELS,
42
- GPU_NAME,
43
- IMAGE,
44
  codesearchnet_config,
45
  directories,
 
46
  get_safe_model_name,
47
  get_volume_config,
48
  languages_config,
49
  )
50
 
 
 
 
51
  logger = logging.getLogger(__name__)
52
 
53
  # =============================================================================
54
  # EVALUATION CONFIGURATION
55
  # =============================================================================
56
 
57
- BATCH_SIZE = 32
58
  LOCAL_EVALUATION_DIR = directories.evaluation_results
59
  LOCAL_BENCHMARK_DIR = directories.benchmark_results
60
  LOCAL_MODELS_DIR = directories.final
@@ -92,6 +97,200 @@ def complex_algorithm(data, config):
92
  }
93
 
94
 
 
 
95
  class PerformanceBenchmark:
96
  """Comprehensive performance benchmarking for embedding models."""
97
 
@@ -109,7 +308,7 @@ class PerformanceBenchmark:
109
  start_time = time.time()
110
 
111
  try:
112
- self.model = SentenceTransformer(self.model_path, device=self.device, trust_remote_code=True)
113
  load_time = time.time() - start_time
114
 
115
  logger.info(f"✅ Model loaded in {load_time:.2f}s on {self.device}")
@@ -321,7 +520,7 @@ class PerformanceBenchmark:
321
  logger.info(f" 📊 Testing on {device.upper()}")
322
 
323
  try:
324
- model = SentenceTransformer(self.model_path, device=device, trust_remote_code=True)
325
 
326
  # Warmup
327
  _ = model.encode(test_texts[:4], convert_to_tensor=False)
@@ -401,36 +600,71 @@ class CodeSearchNetEvaluator:
401
  """Load the embedding model."""
402
  logger.info(f"Loading model from {self.model_path}")
403
  try:
404
- self.model = SentenceTransformer(self.model_path, trust_remote_code=True)
405
  logger.info(f"Successfully loaded model: {self.model_name}")
406
  except Exception:
407
  logger.exception(f"Failed to load model from {self.model_path}")
408
  raise
409
 
410
  def encode_texts(self, texts: list[str], desc: str = "Encoding") -> np.ndarray:
411
- """Encode texts into embeddings with batching."""
412
  if self.model is None:
413
  msg = "Model not loaded"
414
  raise RuntimeError(msg)
415
 
416
  embeddings = []
417
- for i in tqdm(range(0, len(texts), BATCH_SIZE), desc=desc):
418
- batch = texts[i : i + BATCH_SIZE]
419
- batch_embeddings = self.model.encode(batch, convert_to_tensor=False, normalize_embeddings=True)
420
- embeddings.append(batch_embeddings)
 
 
 
 
 
 
 
421
 
422
  return np.vstack(embeddings)
423
 
424
- def evaluate_language(self, language: str, max_queries: int = 1000) -> dict[str, Any]:
425
  """Evaluate on a specific programming language."""
426
  logger.info(f"Evaluating on {language} language (max {max_queries} queries)")
427
 
428
  try:
429
- # Load test split for the language
 
430
  dataset = load_dataset(
431
  codesearchnet_config.dataset_name,
432
  language,
433
- split="test",
434
  trust_remote_code=True,
435
  )
436
 
@@ -438,17 +672,17 @@ class CodeSearchNetEvaluator:
438
  logger.error(f"Unexpected dataset type for {language}: {type(dataset)}")
439
  return {}
440
 
441
- # Sample queries for evaluation
442
- if len(dataset) > max_queries:
443
- rng = np.random.default_rng(42)
444
- indices = rng.choice(len(dataset), max_queries, replace=False)
445
- dataset = dataset.select(indices)
446
 
447
- queries = []
448
- codes = []
449
- query_ids = []
450
 
 
451
  for i, example in enumerate(dataset):
 
 
 
452
  doc_string = example.get("func_documentation_string", "").strip()
453
  code_string = example.get("func_code_string", "").strip()
454
 
@@ -461,8 +695,23 @@ class CodeSearchNetEvaluator:
461
  logger.warning(f"No valid query-code pairs found for {language}")
462
  return {}
463
 
 
 
 
 
 
 
464
  logger.info(f"Found {len(queries)} valid query-code pairs for {language}")
465
 
 
 
 
466
  # Encode queries and codes
467
  start_time = time.time()
468
  query_embeddings = self.encode_texts(queries, f"Encoding {language} queries")
@@ -548,7 +797,7 @@ class CodeSearchNetEvaluator:
548
  return 0.0
549
 
550
  def evaluate_all_languages(
551
- self, max_queries_per_lang: int = 1000, languages: list[str] | None = None
552
  ) -> dict[str, Any]:
553
  """Evaluate on all specified languages."""
554
  eval_languages = languages or languages_config.all
@@ -613,7 +862,7 @@ class ComprehensiveModelEvaluator:
613
 
614
  def run_comprehensive_evaluation(
615
  self,
616
- max_queries_per_lang: int = 1000,
617
  languages: list[str] | None = None,
618
  skip_benchmark: bool = False,
619
  ) -> dict[str, Any]:
@@ -901,7 +1150,7 @@ def create_comparison_report(all_results: list[dict[str, Any]], output_dir: str
901
 
902
  def run_evaluation(
903
  models: list[str],
904
- max_queries: int = 1000,
905
  languages: list[str] | None = None,
906
  use_beam: bool = False,
907
  skip_benchmark: bool = False,
@@ -962,7 +1211,7 @@ def run_evaluation(
962
 
963
  def _run_local_evaluation(
964
  models: list[str],
965
- max_queries: int = 1000,
966
  languages: list[str] | None = None,
967
  skip_benchmark: bool = False,
968
  ) -> list[dict[str, Any]]:
@@ -993,30 +1242,52 @@ def _run_local_evaluation(
993
  return results
994
 
995
 
996
- @function(
997
- gpu=GPU_NAME,
998
- volumes=[Volume(name=VOLUME_CONFIG.name, mount_path=VOLUME_CONFIG.mount_path)],
999
- image=IMAGE,
1000
- secrets=["HF_ACCESS_TOKEN"],
1001
- env=BEAM_ENV_SETTINGS,
1002
- timeout=3600 * 8, # 8 hours for comprehensive evaluation
1003
- )
1004
  def _beam_evaluate_single_model(
1005
  model_path: str,
1006
- max_queries: int = 1000,
1007
  languages: list[str] | None = None,
1008
  skip_benchmark: bool = False,
1009
  ) -> dict[str, Any]:
1010
  """Beam function to comprehensively evaluate a single model."""
1011
- model_name = Path(model_path).name
 
 
 
 
1012
 
 
1013
  logger.info(f"🚀 Beam comprehensive evaluation starting for {model_name}")
1014
 
 
 
 
 
 
1015
  try:
 
1016
  evaluator = ComprehensiveModelEvaluator(model_path, model_name)
 
 
 
1017
  results = evaluator.run_comprehensive_evaluation(max_queries, languages, skip_benchmark)
 
 
 
 
 
 
1018
 
1019
  # Save to Beam volume as single comprehensive file
 
1020
  volume_results_dir = Path(VOLUME_CONFIG.mount_path) / "evaluation_results"
1021
  volume_results_dir.mkdir(parents=True, exist_ok=True)
1022
 
@@ -1026,17 +1297,65 @@ def _beam_evaluate_single_model(
1026
  with result_file.open("w") as f:
1027
  json.dump(results, f, indent=2, default=str)
1028
 
1029
- logger.info(f"💾 Saved Beam comprehensive evaluation results for {model_name}")
 
 
1030
  return results
1031
 
1032
- except Exception:
 
 
 
 
 
 
1033
  logger.exception(f"❌ Beam comprehensive evaluation failed for {model_name}")
1034
- return {}
 
 
 
 
 
 
1035
 
1036
 
1037
  def _run_beam_evaluation(
1038
  models: list[str],
1039
- max_queries: int = 1000,
1040
  languages: list[str] | None = None,
1041
  skip_benchmark: bool = False,
1042
  ) -> list[dict[str, Any]]:
@@ -1054,6 +1373,13 @@ def _run_beam_evaluation(
1054
  result = _beam_evaluate_single_model.remote(model_path, max_queries, languages, skip_benchmark)
1055
 
1056
  if result:
 
 
 
 
 
 
 
1057
  # Download the comprehensive result file from Beam
1058
  success = download_specific_evaluation_file(
1059
  VOLUME_CONFIG.name,
@@ -1069,6 +1395,8 @@ def _run_beam_evaluation(
1069
  results.append(result)
1070
  else:
1071
  logger.warning(f"⚠️ Could not download results for {model_name}")
 
 
1072
 
1073
  except Exception:
1074
  logger.exception(f"❌ Beam comprehensive evaluation failed for {model_name}")
@@ -1086,7 +1414,7 @@ def main(
1086
  use_beam: bool = typer.Option(default=False, help="Use Beam for evaluation"),
1087
  skip_third_party: bool = typer.Option(default=False, help="Skip third-party models"),
1088
  skip_benchmark: bool = typer.Option(default=False, help="Skip performance benchmarking"),
1089
- max_queries: int = typer.Option(default=1000, help="Maximum queries per language"),
1090
  ) -> None:
1091
  """Main comprehensive evaluation function."""
1092
  logger.info("🚀 Starting comprehensive model evaluation (CodeSearchNet + Performance)")
 
17
  distiller evaluate [--use-beam] [--skip-benchmark] # Run evaluation locally or on Beam
18
  """
19
 
20
+ # Standard-library helpers used to check whether the optional flash_attn package is installed
21
+ import contextlib
22
+ import importlib.util
23
  import json
24
  import logging
25
+ import os
26
  import time
27
  import traceback
28
  from pathlib import Path
 
33
  import psutil
34
  import torch
35
  import typer
36
+ from beam import function
37
  from datasets import Dataset, load_dataset
38
  from sentence_transformers import SentenceTransformer
39
  from sklearn.metrics.pairwise import cosine_similarity
 
41
 
42
  from .beam_utils import download_specific_evaluation_file
43
  from .config import (
 
44
  DEFAULT_EVALUATION_MODELS,
 
 
45
  codesearchnet_config,
46
  directories,
47
+ get_evaluation_function_kwargs,
48
  get_safe_model_name,
49
  get_volume_config,
50
  languages_config,
51
  )
52
 
53
+ # Check if flash_attn is available and compatible
54
+ FLASH_ATTN_AVAILABLE = importlib.util.find_spec("flash_attn") is not None
55
+
56
  logger = logging.getLogger(__name__)
57
 
58
  # =============================================================================
59
  # EVALUATION CONFIGURATION
60
  # =============================================================================
61
 
62
+ BATCH_SIZE = 10
63
  LOCAL_EVALUATION_DIR = directories.evaluation_results
64
  LOCAL_BENCHMARK_DIR = directories.benchmark_results
65
  LOCAL_MODELS_DIR = directories.final
 
97
  }
98
 
99
 
100
+ def reset_cuda_state() -> None:
101
+ """Aggressively reset CUDA state after memory allocation errors."""
102
+ if not torch.cuda.is_available():
103
+ return
104
+
105
+ try:
106
+ # Clear all CUDA caches
107
+ torch.cuda.empty_cache()
108
+ torch.cuda.ipc_collect()
109
+ torch.cuda.reset_peak_memory_stats()
110
+
111
+ # Try to force garbage collection
112
+ import gc
113
+
114
+ gc.collect()
115
+
116
+ logger.info("🧹 CUDA state reset completed")
117
+ except Exception as e:
118
+ logger.warning(f"⚠️ Could not fully reset CUDA state: {e}")
119
+
120
+
121
+ def configure_flash_attention() -> dict[str, Any]:
122
+ """Configure flash attention settings and return model kwargs."""
123
+ model_kwargs: dict[str, Any] = {}
124
+
125
+ if not FLASH_ATTN_AVAILABLE:
126
+ logger.info("⚠️ Flash attention not available - using standard attention")
127
+ return model_kwargs
128
+
129
+ # Set environment variables for flash attention and CUDA memory management
130
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
131
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
132
+
133
+ # Check if we're on a compatible GPU
134
+ try:
135
+ if torch.cuda.is_available():
136
+ device_capability = torch.cuda.get_device_capability()
137
+ # Flash attention requires compute capability >= 7.5 (Turing, Ampere, Ada, Hopper)
138
+ if device_capability[0] >= 7 and (device_capability[0] > 7 or device_capability[1] >= 5):
139
+ logger.info("✅ Flash attention available - compatible GPU detected")
140
+ # For SentenceTransformer, we'll use environment variables to enable flash attention
141
+ os.environ["TRANSFORMERS_FLASH_ATTENTION"] = "1"
142
+ else:
143
+ logger.info(f"⚠️ GPU compute capability {device_capability} < 7.5 - flash attention disabled")
144
+ else:
145
+ logger.info("⚠️ No CUDA available - flash attention disabled")
146
+ except Exception as e:
147
+ logger.warning(f"⚠️ Failed to check GPU compatibility: {e} - flash attention disabled")
148
+
149
+ return model_kwargs
150
+
151
+
152
+ def load_model_with_flash_attention(model_path: str, device: str = "auto") -> SentenceTransformer:
153
+ """Load a SentenceTransformer model with flash attention if available."""
154
+ # Convert "auto" device to actual device
155
+ target_device = "cuda" if device == "auto" and torch.cuda.is_available() else device
156
+ if device == "auto" and not torch.cuda.is_available():
157
+ target_device = "cpu"
158
+
159
+ # Configure flash attention via environment variables
160
+ configure_flash_attention()
161
+
162
+ # Load model with standard SentenceTransformer initialization
163
+ logger.info(f"📂 Loading model: {Path(model_path).name}")
164
+
165
+ try:
166
+ # Try loading directly to target device first
167
+ model = SentenceTransformer(model_path, device=target_device, trust_remote_code=True)
168
+ logger.info(f"✅ Model loaded successfully on {target_device}")
169
+ return model
170
+ except (torch.OutOfMemoryError, RuntimeError) as oom_error:
171
+ # Handle both torch.OutOfMemoryError and RuntimeError (CUDA driver errors)
172
+ is_oom = isinstance(oom_error, torch.OutOfMemoryError) or "out of memory" in str(oom_error).lower()
173
+
174
+ if is_oom and target_device != "cpu":
175
+ logger.warning(f"⚠️ OOM loading directly to {target_device}, trying CPU first: {oom_error}")
176
+ try:
177
+ # Clear CUDA cache more aggressively after OOM
178
+ reset_cuda_state()
179
+
180
+ logger.info("🔄 Loading model on CPU first, then trying to move to GPU...")
181
+ model = SentenceTransformer(model_path, device="cpu", trust_remote_code=True)
182
+ logger.info("📦 Model loaded on CPU, attempting GPU transfer...")
183
+
184
+ # Try moving to GPU with additional error handling
185
+ try:
186
+ model = model.to(target_device)
187
+ logger.info(f"✅ Model successfully moved to {target_device}")
188
+ return model
189
+ except (RuntimeError, AssertionError) as gpu_move_error:
190
+ # Handle PyTorch internal errors and CUDA allocator issues
191
+ logger.warning(f"⚠️ GPU transfer failed: {gpu_move_error}")
192
+ if "INTERNAL ASSERT FAILED" in str(gpu_move_error) or "handles_" in str(gpu_move_error):
193
+ logger.warning("🔧 Detected CUDA allocator corruption, resetting and staying on CPU")
194
+ # Try to reset CUDA context
195
+ reset_cuda_state()
196
+ else:
197
+ # Re-raise unexpected GPU transfer errors
198
+ raise
199
+
200
+ logger.info("✅ Model will remain on CPU due to GPU memory issues")
201
+ return model
202
+
203
+ except Exception as fallback_error:
204
+ logger.warning(f"⚠️ CPU fallback failed: {fallback_error}, loading fresh on CPU")
205
+ # Clear any remaining CUDA state
206
+ reset_cuda_state()
207
+
208
+ model = SentenceTransformer(model_path, device="cpu", trust_remote_code=True)
209
+ logger.info("✅ Model loaded on CPU (all GPU attempts failed)")
210
+ return model
211
+ else:
212
+ # Re-raise if not OOM or already on CPU
213
+ raise
214
+ except ValueError as e:
215
+ if "'MaxSim' is not a valid SimilarityFunction" in str(e):
216
+ logger.warning(f"⚠️ Model {Path(model_path).name} uses unsupported MaxSim similarity function")
217
+ logger.info("🔧 Attempting workaround by loading with custom config...")
218
+
219
+ # Try loading with similarity function override
220
+ try:
221
+ # Load model components manually and override similarity function
222
+ import json
223
+ import tempfile
224
+ from pathlib import Path as PathLib
225
+
226
+ # Create temporary directory for modified config
227
+ with tempfile.TemporaryDirectory() as temp_dir:
228
+ temp_path = PathLib(temp_dir) / "temp_model"
229
+
230
+ # Download/copy model files
231
+ if model_path.startswith("http") or ("/" in model_path and not PathLib(model_path).exists()):
232
+ # It's a HuggingFace model ID
233
+ from huggingface_hub import snapshot_download
234
+
235
+ snapshot_download(model_path, local_dir=temp_path, ignore_patterns=["*.bin"])
236
+ else:
237
+ # It's a local path
238
+ import shutil
239
+
240
+ shutil.copytree(model_path, temp_path)
241
+
242
+ # Modify config to use supported similarity function
243
+ config_path = temp_path / "config_sentence_transformers.json"
244
+ if config_path.exists():
245
+ with config_path.open() as f:
246
+ config = json.load(f)
247
+
248
+ # Override similarity function to 'cosine' (supported)
249
+ if "similarity_fn_name" in config:
250
+ logger.info(
251
+ f"🔧 Changing similarity function from '{config['similarity_fn_name']}' to 'cosine'"
252
+ )
253
+ config["similarity_fn_name"] = "cosine"
254
+
255
+ with config_path.open("w") as f:
256
+ json.dump(config, f, indent=2)
257
+
258
+ # Load model with modified config
259
+ model = SentenceTransformer(str(temp_path), device=device, trust_remote_code=True)
260
+ logger.info("✅ Model loaded successfully with similarity function workaround")
261
+ return model
262
+
263
+ except Exception as workaround_error:
264
+ logger.warning(f"⚠️ Similarity function workaround failed: {workaround_error}")
265
+ logger.info("🔧 Attempting direct model component loading...")
266
+
267
+ # Last resort: try loading model components directly
268
+ try:
269
+ from sentence_transformers.models import Pooling, Transformer
270
+
271
+ # Load model components manually
272
+ logger.info("🔄 Loading model components directly...")
273
+
274
+ # Create SentenceTransformer components using model path
275
+ transformer = Transformer(model_path)
276
+ pooling = Pooling(transformer.get_word_embedding_dimension())
277
+
278
+ # Create SentenceTransformer with manual components
279
+ model = SentenceTransformer(modules=[transformer, pooling], device=device)
280
+ logger.info("✅ Model loaded successfully with direct component loading")
281
+ return model
282
+
283
+ except Exception as direct_error:
284
+ logger.warning(f"⚠️ Direct component loading failed: {direct_error}")
285
+ logger.exception(f"❌ All loading methods failed for {Path(model_path).name}")
286
+ raise e from direct_error
287
+ else:
288
+ raise
289
+ except Exception:
290
+ logger.exception(f"❌ Failed to load model {Path(model_path).name}")
291
+ raise
292
+
293
+
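
A usage sketch for the loader defined above (the model id is illustrative, not part of the diff); it is meant as a drop-in replacement that prefers the GPU and falls back to CPU on out-of-memory errors:

# Illustrative only: "BAAI/bge-m3" stands in for any evaluated model path or id.
model = load_model_with_flash_attention("BAAI/bge-m3", device="auto")
vectors = model.encode(
    ["def add(a, b): return a + b"],
    convert_to_tensor=False,
    normalize_embeddings=True,
)
print(vectors.shape)  # (1, embedding_dim)
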
294
  class PerformanceBenchmark:
295
  """Comprehensive performance benchmarking for embedding models."""
296
 
 
308
  start_time = time.time()
309
 
310
  try:
311
+ self.model = load_model_with_flash_attention(self.model_path, self.device)
312
  load_time = time.time() - start_time
313
 
314
  logger.info(f"✅ Model loaded in {load_time:.2f}s on {self.device}")
 
520
  logger.info(f" 📊 Testing on {device.upper()}")
521
 
522
  try:
523
+ model = load_model_with_flash_attention(self.model_path, device)
524
 
525
  # Warmup
526
  _ = model.encode(test_texts[:4], convert_to_tensor=False)
 
600
  """Load the embedding model."""
601
  logger.info(f"Loading model from {self.model_path}")
602
  try:
603
+ self.model = load_model_with_flash_attention(self.model_path)
604
  logger.info(f"Successfully loaded model: {self.model_name}")
605
  except Exception:
606
  logger.exception(f"Failed to load model from {self.model_path}")
607
  raise
608
 
609
  def encode_texts(self, texts: list[str], desc: str = "Encoding") -> np.ndarray:
610
+ """Encode texts into embeddings with batching and memory management."""
611
  if self.model is None:
612
  msg = "Model not loaded"
613
  raise RuntimeError(msg)
614
 
615
  embeddings = []
616
+ # Use smaller batch size to avoid OOM
617
+ effective_batch_size = min(BATCH_SIZE, 5) # Limit to 5 for large models
618
+
619
+ for i in tqdm(range(0, len(texts), effective_batch_size), desc=desc):
620
+ batch = texts[i : i + effective_batch_size]
621
+
622
+ try:
623
+ batch_embeddings = self.model.encode(batch, convert_to_tensor=False, normalize_embeddings=True)
624
+ embeddings.append(batch_embeddings)
625
+
626
+ # Clear CUDA cache periodically to prevent memory buildup
627
+ if torch.cuda.is_available() and i > 0 and i % (effective_batch_size * 4) == 0:
628
+ torch.cuda.empty_cache()
629
+
630
+ except (torch.OutOfMemoryError, RuntimeError) as e:
631
+ # Handle both torch.OutOfMemoryError and RuntimeError (CUDA driver errors)
632
+ is_oom = isinstance(e, torch.OutOfMemoryError) or "out of memory" in str(e).lower()
633
+
634
+ if is_oom:
635
+ logger.warning(
636
+ f"⚠️ OOM during encoding batch {i // effective_batch_size + 1}, trying smaller batch..."
637
+ )
638
+ # Try encoding one at a time
639
+ for single_text in batch:
640
+ try:
641
+ single_embedding = self.model.encode(
642
+ [single_text], convert_to_tensor=False, normalize_embeddings=True
643
+ )
644
+ embeddings.append(single_embedding)
645
+ if torch.cuda.is_available():
646
+ torch.cuda.empty_cache()
647
+ except (torch.OutOfMemoryError, RuntimeError) as single_e:
648
+ if isinstance(single_e, torch.OutOfMemoryError) or "out of memory" in str(single_e).lower():
649
+ logger.exception("❌ Cannot encode even single text, model too large for GPU")
650
+ raise
651
+ raise
652
+ else:
653
+ raise
654
 
655
  return np.vstack(embeddings)
656
 
657
+ def evaluate_language(self, language: str, max_queries: int = 100) -> dict[str, Any]:
658
  """Evaluate on a specific programming language."""
659
  logger.info(f"Evaluating on {language} language (max {max_queries} queries)")
660
 
661
  try:
662
+ # Load only a slice of the test split for the language to avoid loading the full dataset
663
+ logger.info(f"📥 Loading test split for {language}...")
664
  dataset = load_dataset(
665
  codesearchnet_config.dataset_name,
666
  language,
667
+ split=f"test[:{max_queries * 10}]", # Load 10x more than needed to ensure we get enough valid pairs
668
  trust_remote_code=True,
669
  )
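
For reference, the "test[:N]" slice syntax used above is the standard datasets split-slicing API; a standalone sketch (dataset id and language chosen for illustration) of what the call resolves to when max_queries is 100:

from datasets import load_dataset

# Illustrative only: the slice asks the datasets library to materialize just the
# first N rows of the test split rather than the whole split.
max_queries = 100
subset = load_dataset(
    "code-search-net/code_search_net",  # example dataset id (the value used elsewhere in this repository)
    "python",                           # example language
    split=f"test[:{max_queries * 10}]",
    trust_remote_code=True,
)
print(len(subset))  # at most 1000 examples
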
670
 
 
672
  logger.error(f"Unexpected dataset type for {language}: {type(dataset)}")
673
  return {}
674
 
675
+ logger.info(f"📊 Loaded {len(dataset)} examples from {language} test split")
 
 
 
 
676
 
677
+ queries: list[str] = []
678
+ codes: list[str] = []
679
+ query_ids: list[str] = []
680
 
681
+ # Process examples and stop once we have enough valid pairs
682
  for i, example in enumerate(dataset):
683
+ if len(queries) >= max_queries: # Stop once we have enough
684
+ break
685
+
686
  doc_string = example.get("func_documentation_string", "").strip()
687
  code_string = example.get("func_code_string", "").strip()
688
 
 
695
  logger.warning(f"No valid query-code pairs found for {language}")
696
  return {}
697
 
698
+ # Truncate to exactly max_queries if we have more
699
+ if len(queries) > max_queries:
700
+ queries = queries[:max_queries]
701
+ codes = codes[:max_queries]
702
+ query_ids = query_ids[:max_queries]
703
+
704
  logger.info(f"Found {len(queries)} valid query-code pairs for {language}")
705
 
706
+ # Check available memory before encoding
707
+ if torch.cuda.is_available():
708
+ free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
709
+ free_gb = free_memory / (1024**3)
710
+ logger.info(f"💾 Available GPU memory before encoding: {free_gb:.1f} GB")
711
+ if free_gb < 2.0: # Less than 2GB free
712
+ logger.warning(f"⚠️ Low GPU memory ({free_gb:.1f} GB), using conservative encoding")
713
+ torch.cuda.empty_cache()
714
+
715
  # Encode queries and codes
716
  start_time = time.time()
717
  query_embeddings = self.encode_texts(queries, f"Encoding {language} queries")
 
797
  return 0.0
798
 
799
  def evaluate_all_languages(
800
+ self, max_queries_per_lang: int = 100, languages: list[str] | None = None
801
  ) -> dict[str, Any]:
802
  """Evaluate on all specified languages."""
803
  eval_languages = languages or languages_config.all
 
862
 
863
  def run_comprehensive_evaluation(
864
  self,
865
+ max_queries_per_lang: int = 100,
866
  languages: list[str] | None = None,
867
  skip_benchmark: bool = False,
868
  ) -> dict[str, Any]:
 
1150
 
1151
  def run_evaluation(
1152
  models: list[str],
1153
+ max_queries: int = 100,
1154
  languages: list[str] | None = None,
1155
  use_beam: bool = False,
1156
  skip_benchmark: bool = False,
 
1211
 
1212
  def _run_local_evaluation(
1213
  models: list[str],
1214
+ max_queries: int = 100,
1215
  languages: list[str] | None = None,
1216
  skip_benchmark: bool = False,
1217
  ) -> list[dict[str, Any]]:
 
1242
  return results
1243
 
1244
 
1245
+ @function(**get_evaluation_function_kwargs())
 
 
 
 
 
 
 
1246
  def _beam_evaluate_single_model(
1247
  model_path: str,
1248
+ max_queries: int = 100,
1249
  languages: list[str] | None = None,
1250
  skip_benchmark: bool = False,
1251
  ) -> dict[str, Any]:
1252
  """Beam function to comprehensively evaluate a single model."""
1253
+ # Set CUDA memory settings BEFORE any CUDA operations
1254
+
1255
+ import os
1256
+
1257
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
1258
 
1259
+ model_name = Path(model_path).name
1260
  logger.info(f"🚀 Beam comprehensive evaluation starting for {model_name}")
1261
 
1262
+ # Clear CUDA cache if available
1263
+ try:
1264
+ import torch
1265
+
1266
+ if torch.cuda.is_available():
1267
+ torch.cuda.empty_cache()
1268
+ torch.cuda.reset_peak_memory_stats()
1269
+ logger.info(
1270
+ f"🧹 Cleared CUDA cache. Available memory: {torch.cuda.get_device_properties(0).total_memory // (1024**3)} GB"
1271
+ )
1272
+ except Exception as e:
1273
+ logger.warning(f"⚠️ Could not clear CUDA cache: {e}")
1274
+
1275
  try:
1276
+ logger.info("🔧 Creating ComprehensiveModelEvaluator...")
1277
  evaluator = ComprehensiveModelEvaluator(model_path, model_name)
1278
+ logger.info("✅ ComprehensiveModelEvaluator created successfully")
1279
+
1280
+ logger.info("🚀 Starting comprehensive evaluation...")
1281
  results = evaluator.run_comprehensive_evaluation(max_queries, languages, skip_benchmark)
1282
+ logger.info("✅ Comprehensive evaluation completed")
1283
+
1284
+ # Validate results
1285
+ if not results or "model_name" not in results:
1286
+ logger.error(f"❌ Invalid evaluation results for {model_name}: {results}")
1287
+ return {"error": "Invalid evaluation results", "model_name": model_name}
1288
 
1289
  # Save to Beam volume as single comprehensive file
1290
+ logger.info("💾 Saving results to Beam volume...")
1291
  volume_results_dir = Path(VOLUME_CONFIG.mount_path) / "evaluation_results"
1292
  volume_results_dir.mkdir(parents=True, exist_ok=True)
1293
 
 
1297
  with result_file.open("w") as f:
1298
  json.dump(results, f, indent=2, default=str)
1299
 
1300
+ logger.info(f"💾 Saved Beam comprehensive evaluation results for {model_name} to {result_file}")
1301
+ logger.info(f"🎯 Final results summary: {len(results.get('codesearch_languages', {}))} languages evaluated")
1302
+
1303
  return results
1304
 
1305
+ except (torch.OutOfMemoryError, RuntimeError, AssertionError) as e:
1306
+ # Handle CUDA errors including OOM, driver errors, and PyTorch internal assertion failures
1307
+ is_oom = isinstance(e, torch.OutOfMemoryError) or "out of memory" in str(e).lower()
1308
+ is_cuda_error = is_oom or "cuda" in str(e).lower() or "INTERNAL ASSERT FAILED" in str(e) or "handles_" in str(e)
1309
+
1310
+ if is_cuda_error:
1311
+ error_type = "CUDA OOM" if is_oom else "CUDA Error"
1312
+ logger.exception(f"❌ {error_type} during evaluation of {model_name}")
1313
+
1314
+ # Try to clear memory and reset CUDA state more aggressively
1315
+ with contextlib.suppress(Exception):
1316
+ reset_cuda_state()
1317
+
1318
+ return {
1319
+ "error": f"{error_type}: {e!s}",
1320
+ "model_name": model_name,
1321
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
1322
+ "evaluation_failed": True,
1323
+ "oom": is_oom,
1324
+ "cuda_error": True,
1325
+ }
1326
+ # Re-raise if not a CUDA-related error
1327
+ raise
1328
+ except Exception as e:
1329
  logger.exception(f"❌ Beam comprehensive evaluation failed for {model_name}")
1330
+ # Return error info in a structured way
1331
+ error_result = {
1332
+ "error": str(e),
1333
+ "model_name": model_name,
1334
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
1335
+ "evaluation_failed": True,
1336
+ }
1337
+
1338
+ # Try to save error result to volume
1339
+ try:
1340
+ volume_results_dir = Path(VOLUME_CONFIG.mount_path) / "evaluation_results"
1341
+ volume_results_dir.mkdir(parents=True, exist_ok=True)
1342
+
1343
+ safe_model_name = get_safe_model_name(model_name)
1344
+ error_file = volume_results_dir / f"error_eval_{safe_model_name}.json"
1345
+
1346
+ with error_file.open("w") as f:
1347
+ json.dump(error_result, f, indent=2, default=str)
1348
+
1349
+ logger.info(f"💾 Saved error info to {error_file}")
1350
+ except Exception:
1351
+ logger.exception("❌ Could not save error info")
1352
+
1353
+ return error_result
1354
 
1355
 
1356
  def _run_beam_evaluation(
1357
  models: list[str],
1358
+ max_queries: int = 100,
1359
  languages: list[str] | None = None,
1360
  skip_benchmark: bool = False,
1361
  ) -> list[dict[str, Any]]:
 
1373
  result = _beam_evaluate_single_model.remote(model_path, max_queries, languages, skip_benchmark)
1374
 
1375
  if result:
1376
+ # Check if this is an error result
1377
+ if result.get("evaluation_failed", False):
1378
+ logger.error(f"❌ Beam evaluation failed for {model_name}: {result.get('error', 'Unknown error')}")
1379
+ if result.get("oom", False):
1380
+ logger.error("💥 Out of memory error - model may be too large for available GPU")
1381
+ continue
1382
+
1383
  # Download the comprehensive result file from Beam
1384
  success = download_specific_evaluation_file(
1385
  VOLUME_CONFIG.name,
 
1395
  results.append(result)
1396
  else:
1397
  logger.warning(f"⚠️ Could not download results for {model_name}")
1398
+ else:
1399
+ logger.warning(f"⚠️ No result returned for {model_name}")
1400
 
1401
  except Exception:
1402
  logger.exception(f"❌ Beam comprehensive evaluation failed for {model_name}")
 
1414
  use_beam: bool = typer.Option(default=False, help="Use Beam for evaluation"),
1415
  skip_third_party: bool = typer.Option(default=False, help="Skip third-party models"),
1416
  skip_benchmark: bool = typer.Option(default=False, help="Skip performance benchmarking"),
1417
+ max_queries: int = typer.Option(default=100, help="Maximum queries per language"),
1418
  ) -> None:
1419
  """Main comprehensive evaluation function."""
1420
  logger.info("🚀 Starting comprehensive model evaluation (CodeSearchNet + Performance)")
train_code_classification.py DELETED
@@ -1,365 +0,0 @@
1
- #!/usr/bin/env python
2
- """
3
- Script to train a code classification model using CodeSearchNet dataset with Model2Vec.
4
-
5
- This script performs the following operations:
6
- 1. Downloads the Alibaba-NLP/gte-Qwen2-7B-instruct model
7
- 2. Optionally distills it using Model2Vec to create a smaller, faster static model
8
- 3. Trains a programming language classifier on CodeSearchNet dataset
9
- 4. Evaluates the classifier and saves the trained model
10
-
11
- Based on the official CodeSearchNet dataset: https://github.com/github/CodeSearchNet
12
- """
13
-
14
- import json
15
- import logging
16
- import re
17
- import time
18
- from pathlib import Path
19
- from time import perf_counter
20
- from typing import Any, cast
21
-
22
- from datasets import Dataset, DatasetDict, load_dataset # type: ignore [import]
23
- from model2vec.distill import distill
24
- from model2vec.train import StaticModelForClassification
25
-
26
- # =============================================================================
27
- # CONFIGURATION CONSTANTS
28
- # =============================================================================
29
-
30
- # Model Configuration
31
- MODEL_NAME = "Alibaba-NLP/gte-Qwen2-7B-instruct" # Source model to distill
32
- OUTPUT_DIR = "." # Directory to save the trained model
33
-
34
- # Distillation Configuration
35
- SKIP_DISTILLATION = False # Set to True to skip distillation and use existing model
36
- DISTILLED_MODEL_PATH = "." # Path to existing distilled model (if skipping distillation)
37
- PCA_DIMS = 256 # Dimensions for PCA reduction (smaller = faster but less accurate)
38
-
39
- # Dataset Configuration
40
- DATASET_NAME = "code-search-net/code_search_net" # CodeSearchNet dataset
41
- CLASSIFICATION_TASK = "language" # Task: classify programming language
42
- MAX_SAMPLES_PER_LANGUAGE = 5000 # Limit samples per language for balanced training
43
- MIN_CODE_LENGTH = 50 # Minimum code length in characters
44
- MAX_CODE_LENGTH = 2000 # Maximum code length in characters (for memory efficiency)
45
-
46
- # Text processing constants
47
- MAX_COMMENT_LENGTH = 200 # Maximum length for comment lines before truncation
48
-
49
- # Training Configuration
50
- MAX_EPOCHS = 30 # Maximum number of training epochs
51
- PATIENCE = 5 # Early stopping patience
52
- BATCH_SIZE = 32 # Training batch size
53
- LEARNING_RATE = 1e-3 # Learning rate
54
-
55
- # Saving Configuration
56
- SAVE_PIPELINE = True # Save as scikit-learn compatible pipeline
57
- SAVE_TO_HUB = False # Whether to push the model to HuggingFace Hub
58
- HUB_MODEL_ID = None # Model ID for HuggingFace Hub (if saving to hub)
59
-
60
- # =============================================================================
61
-
62
- # Configure logging
63
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
64
- logger = logging.getLogger(__name__)
65
-
66
-
67
- def clean_code_text(code: str) -> str:
68
- """Clean and normalize code text for better classification."""
69
- if not code:
70
- return ""
71
-
72
- # Remove excessive whitespace while preserving structure
73
- code = re.sub(r"\n\s*\n\s*\n", "\n\n", code) # Remove multiple empty lines
74
- code = re.sub(r" +", " ", code) # Replace multiple spaces with single space
75
-
76
- # Remove very long comments that might bias classification
77
- lines = code.split("\n")
78
- cleaned_lines = []
79
- for original_line in lines:
80
- line = original_line
81
- # Keep comment lines but limit their length
82
- if line.strip().startswith(("#", "//", "/*", "*", "--")) and len(line) > MAX_COMMENT_LENGTH:
83
- line = line[:MAX_COMMENT_LENGTH] + "..."
84
- cleaned_lines.append(line)
85
-
86
- return "\n".join(cleaned_lines)
87
-
88
-
89
- def load_codesearchnet_dataset() -> tuple[Dataset, Dataset, str, str]:
90
- """Load and preprocess the CodeSearchNet dataset for programming language classification."""
91
- logger.info("Loading CodeSearchNet dataset...")
92
-
93
- try:
94
- # Load the dataset with trust_remote_code=True
95
- logger.info("Downloading and loading CodeSearchNet data...")
96
- ds = cast(
97
- "DatasetDict",
98
- load_dataset(
99
- DATASET_NAME,
100
- trust_remote_code=True,
101
- # Load a reasonable sample for training
102
- ),
103
- )
104
-
105
- logger.info(f"Available splits: {list(ds.keys())}")
106
-
107
- # Use train/test splits if available, otherwise split the data
108
- if "train" in ds and "test" in ds:
109
- train_dataset = ds["train"]
110
- test_dataset = ds["test"]
111
- elif "train" in ds:
112
- # Split the train set
113
- split_ds = ds["train"].train_test_split(test_size=0.2, seed=42)
114
- train_dataset = split_ds["train"]
115
- test_dataset = split_ds["test"]
116
- else:
117
- # Use all data and split
118
- all_data = ds[next(iter(ds.keys()))]
119
- split_ds = all_data.train_test_split(test_size=0.2, seed=42)
120
- train_dataset = split_ds["train"]
121
- test_dataset = split_ds["test"]
122
-
123
- logger.info(f"Raw dataset sizes - Train: {len(train_dataset)}, Test: {len(test_dataset)}")
124
-
125
- # Filter and preprocess the data
126
- def filter_and_clean(dataset: Dataset) -> Dataset:
127
- # Filter examples with valid code and language
128
- filtered = dataset.filter(
129
- lambda x: (
130
- x["func_code_string"] is not None
131
- and x["language"] is not None
132
- and len(x["func_code_string"]) >= MIN_CODE_LENGTH
133
- and len(x["func_code_string"]) <= MAX_CODE_LENGTH
134
- and x["language"] in ["python", "java", "javascript", "go", "php", "ruby"]
135
- )
136
- )
137
-
138
- # Balance the dataset by limiting samples per language
139
- if len(filtered) > MAX_SAMPLES_PER_LANGUAGE * 6: # 6 languages
140
- # Group by language and sample
141
- language_samples: dict[str, list[dict[str, Any]]] = {}
142
- for example in filtered:
143
- lang = example["language"]
144
- if lang not in language_samples:
145
- language_samples[lang] = []
146
- if len(language_samples[lang]) < MAX_SAMPLES_PER_LANGUAGE:
147
- language_samples[lang].append(example)
148
-
149
- # Combine all samples
150
- balanced_examples = []
151
- for lang_examples in language_samples.values():
152
- balanced_examples.extend(lang_examples)
153
-
154
- # Convert back to dataset format
155
- if balanced_examples:
156
- filtered = Dataset.from_list(balanced_examples)
157
-
158
- # Clean the code text
159
- def clean_example(example: dict[str, Any]) -> dict[str, Any]:
160
- example["func_code_string"] = clean_code_text(example["func_code_string"])
161
- return example
162
-
163
- return filtered.map(clean_example)
164
-
165
- train_dataset = filter_and_clean(train_dataset)
166
- test_dataset = filter_and_clean(test_dataset)
167
-
168
- logger.info(f"Filtered dataset sizes - Train: {len(train_dataset)}, Test: {len(test_dataset)}")
169
-
170
- # Show language distribution
171
- if len(train_dataset) > 0:
172
- from collections import Counter
173
-
174
- train_lang_dist = Counter(train_dataset["language"])
175
- test_lang_dist = Counter(test_dataset["language"])
176
- logger.info(f"Training language distribution: {dict(train_lang_dist)}")
177
- logger.info(f"Test language distribution: {dict(test_lang_dist)}")
178
-
179
- return train_dataset, test_dataset, "func_code_string", "language"
180
-
181
- except Exception:
182
- logger.exception("Error loading CodeSearchNet dataset")
183
- raise
184
-
185
-
186
- def main() -> None:
187
- """Run the code classification training pipeline."""
188
- # Create output directory if it doesn't exist
189
- output_dir = Path(OUTPUT_DIR)
190
- output_dir.mkdir(parents=True, exist_ok=True)
191
-
192
- logger.info(f"Starting CodeSearchNet code classification pipeline for {MODEL_NAME}")
193
- logger.info(f"Classification task: {CLASSIFICATION_TASK}")
194
- logger.info(f"Trained model will be saved to {output_dir}")
195
-
196
- # Record start time for benchmarking
197
- total_start_time = time.time()
198
-
199
- try:
200
- # Step 1: Get the static model (either distill or load existing)
201
- static_model = None
202
-
203
- if SKIP_DISTILLATION:
204
- if DISTILLED_MODEL_PATH:
205
- logger.info(f"Loading existing distilled model from {DISTILLED_MODEL_PATH}")
206
- # Note: We'll create the classifier from pretrained instead
207
- else:
208
- logger.error("DISTILLED_MODEL_PATH must be specified when SKIP_DISTILLATION is True")
209
- return
210
- else:
211
- logger.info("Starting Model2Vec distillation...")
212
- distill_start_time = time.time()
213
-
214
- static_model = distill(
215
- model_name=MODEL_NAME,
216
- pca_dims=PCA_DIMS,
217
- )
218
-
219
- distill_time = time.time() - distill_start_time
220
- logger.info(f"Distillation completed in {distill_time:.2f} seconds")
221
-
222
- # Step 2: Create the classifier
223
- logger.info("Creating classifier...")
224
-
225
- if static_model is not None:
226
- # From a distilled model
227
- classifier = StaticModelForClassification.from_static_model(model=static_model)
228
- else:
229
- # From a pre-trained model path
230
- classifier = StaticModelForClassification.from_pretrained(model_name=DISTILLED_MODEL_PATH)
231
-
232
- # Step 3: Load the CodeSearchNet dataset
233
- train_dataset, test_dataset, text_column, label_column = load_codesearchnet_dataset()
234
-
235
- if len(train_dataset) == 0 or len(test_dataset) == 0:
236
- logger.error("No valid data found after filtering. Please check dataset configuration.")
237
- return
238
-
239
- logger.info(f"Training dataset size: {len(train_dataset)}")
240
- logger.info(f"Test dataset size: {len(test_dataset)}")
241
-
242
- # Get unique languages for reference
243
- unique_languages = sorted(set(train_dataset[label_column]))
244
- logger.info(f"Programming languages to classify: {unique_languages}")
245
-
246
- # Step 4: Train the classifier
247
- logger.info("Starting training...")
248
- train_start_time = perf_counter()
249
-
250
- classifier = classifier.fit(
251
- train_dataset[text_column],
252
- train_dataset[label_column],
253
- max_epochs=MAX_EPOCHS,
254
- batch_size=BATCH_SIZE,
255
- learning_rate=LEARNING_RATE,
256
- early_stopping_patience=PATIENCE,
257
- )
258
-
259
- train_time = perf_counter() - train_start_time
260
- logger.info(f"Training completed in {int(train_time)} seconds")
261
-
262
- # Step 5: Evaluate the classifier
263
- logger.info("Evaluating classifier...")
264
- eval_start_time = perf_counter()
265
-
266
- classification_report = classifier.evaluate(test_dataset[text_column], test_dataset[label_column])
267
-
268
- eval_time = perf_counter() - eval_start_time
269
- logger.info(f"Evaluation completed in {int(eval_time * 1000)} milliseconds")
270
- logger.info(f"Classification results:\n{classification_report}")
271
-
272
- # Step 6: Test with some examples
273
- logger.info("Testing with sample code snippets...")
274
-
275
- # Test examples for different languages
276
- test_examples = [
277
- 'def hello_world():\n print("Hello, World!")\n return True', # Python
278
- (
279
- "public class HelloWorld {\n"
280
- " public static void main(String[] args) {\n"
281
- ' System.out.println("Hello, World!");\n'
282
- " }\n"
283
- "}"
284
- ), # Java
285
- 'function helloWorld() {\n console.log("Hello, World!");\n return true;\n}', # JavaScript
286
- 'package main\n\nimport "fmt"\n\nfunc main() {\n fmt.Println("Hello, World!")\n}', # Go
287
- '<?php\nfunction hello_world() {\n echo "Hello, World!";\n return true;\n}\n?>', # PHP
288
- 'def hello_world\n puts "Hello, World!"\n true\nend', # Ruby
289
- ]
290
-
291
- predictions = classifier.predict(test_examples)
292
- for i, (code, pred) in enumerate(zip(test_examples, predictions, strict=False)):
293
- logger.info(f"Example {i + 1}: {pred}")
294
- logger.info(f"Code snippet: {code[:100]}...")
295
-
296
- # Step 7: Benchmark inference speed
297
- logger.info("Benchmarking inference speed...")
298
- inference_start_time = perf_counter()
299
- _ = classifier.predict(test_dataset[text_column][:100]) # Test on first 100 samples
300
- inference_time = perf_counter() - inference_start_time
301
- logger.info(f"Inference took {int(inference_time * 1000)} milliseconds for 100 code snippets on CPU")
302
-
303
- # Step 8: Save the model
304
- if SAVE_PIPELINE:
305
- logger.info("Converting to scikit-learn pipeline...")
306
- pipeline = classifier.to_pipeline()
307
-
308
- # Save locally
309
- pipeline_path = output_dir / "pipeline"
310
- pipeline.save_pretrained(str(pipeline_path))
311
- logger.info(f"Pipeline saved to {pipeline_path}")
312
-
313
- # Save additional metadata
314
- metadata = {
315
- "model_name": MODEL_NAME,
316
- "dataset": DATASET_NAME,
317
- "task": "programming_language_classification",
318
- "languages": unique_languages,
319
- "pca_dims": PCA_DIMS,
320
- "train_samples": len(train_dataset),
321
- "test_samples": len(test_dataset),
322
- }
323
-
324
- metadata_path = output_dir / "metadata.json"
325
- with metadata_path.open("w") as f:
326
- json.dump(metadata, f, indent=2)
327
- logger.info("Metadata saved to metadata.json")
328
-
329
- # Push to hub if requested
330
- if SAVE_TO_HUB and HUB_MODEL_ID:
331
- logger.info(f"Pushing pipeline to HuggingFace Hub as {HUB_MODEL_ID}")
332
- pipeline.push_to_hub(HUB_MODEL_ID)
333
- else:
334
- # Save the classifier directly
335
- classifier_path = output_dir / "classifier"
336
- classifier_path.mkdir(exist_ok=True)
337
-
338
- # Note: StaticModelForClassification might not have save_pretrained
339
- # We'll save the underlying static model and create instructions
340
- if static_model is not None:
341
- static_model.save_pretrained(str(classifier_path / "static_model"))
342
-
343
- logger.info(f"Classifier components saved to {classifier_path}")
344
-
345
- # Summary
346
- total_time = time.time() - total_start_time
347
- logger.info("=" * 60)
348
- logger.info("CODE CLASSIFICATION TRAINING COMPLETED SUCCESSFULLY!")
349
- logger.info(f"Total time: {total_time:.2f} seconds")
350
- if not SKIP_DISTILLATION:
351
- logger.info(f"Distillation time: {distill_time:.2f} seconds")
352
- logger.info(f"Training time: {int(train_time)} seconds")
353
- logger.info(f"Dataset: {DATASET_NAME}")
354
- logger.info("Task: Programming Language Classification")
355
- logger.info(f"Languages: {', '.join(unique_languages)}")
356
- logger.info(f"Model saved to: {output_dir}")
357
- logger.info("=" * 60)
358
-
359
- except Exception:
360
- logger.exception("Error during code classification training pipeline")
361
- raise
362
-
363
-
364
- if __name__ == "__main__":
365
- main()