Sarthak committed
Commit 454e47c · 1 Parent(s): ea0b2a0

feat: overhaul distiller package with unified CLI, enhanced evaluation, and modular structure

src/distiller/__init__.py CHANGED
@@ -1,7 +1,67 @@
- """Model2Vec Distillation Pipeline for gte-Qwen2-7B-instruct."""
-
- __version__ = "0.1.0"
-
- from .distill import beam_code_distillation, code_specialized_distillation
-
- __all__ = ["beam_code_distillation", "code_specialized_distillation"]
+ """
+ Distiller package for code-specialized embedding model distillation and evaluation.
+
+ This package provides a complete pipeline for:
+ 1. Distilling code-specialized embedding models using Model2Vec
+ 2. Comprehensive evaluation including CodeSearchNet and performance benchmarks
+ 3. Analysis and reporting of model performance
+
+ Main modules:
+ - distill: Model2Vec distillation with optional advanced training
+ - evaluate: Comprehensive evaluation (CodeSearchNet + performance benchmarks)
+ - analyze: Analysis and reporting tools
+ - config: Centralized configuration management
+ - beam_utils: Beam cloud utilities for distributed processing
+
+ Usage:
+     from distiller import distill, evaluate, analyze
+ """
+
+ from . import analyze, config, distill, evaluate
+ from .analyze import CodeSearchNetAnalyzer
+ from .config import (
+     BEAM_ENV_SETTINGS,
+     DEFAULT_EVALUATION_MODELS,
+     GPU_NAME,
+     IMAGE,
+     codesearchnet_config,
+     directories,
+     distillation_config,
+     get_volume_config,
+     languages_config,
+ )
+ from .distill import (
+     run_beam_distillation,
+     run_local_distillation,
+ )
+ from .evaluate import (
+     CodeSearchNetEvaluator,
+     ComprehensiveModelEvaluator,
+     run_evaluation,
+ )
+
+ __all__ = [
+     # Configuration
+     "BEAM_ENV_SETTINGS",
+     "DEFAULT_EVALUATION_MODELS",
+     "GPU_NAME",
+     "IMAGE",
+     # Main classes
+     "CodeSearchNetAnalyzer",
+     "CodeSearchNetEvaluator",
+     "ComprehensiveModelEvaluator",
+     # Modules
+     "analyze",
+     "codesearchnet_config",
+     "config",
+     "directories",
+     "distill",
+     "distillation_config",
+     "evaluate",
+     "get_volume_config",
+     "languages_config",
+     # Main functions
+     "run_beam_distillation",
+     "run_evaluation",
+     "run_local_distillation",
+ ]
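
A minimal usage sketch of the reorganized package API, based only on the exports declared above. The helper signatures are not shown in this diff, so the argument-free calls and the keyword arguments below are assumptions.

```python
# Hedged sketch: only the names come from __all__ above; call signatures are assumed.
from distiller import CodeSearchNetAnalyzer, run_evaluation, run_local_distillation
from distiller.config import DEFAULT_EVALUATION_MODELS, directories

print(DEFAULT_EVALUATION_MODELS)      # centralized model list from config
results = run_local_distillation()    # distill the configured teacher models locally
run_evaluation()                      # CodeSearchNet retrieval + performance benchmarks

# Analysis consumes the JSON files written by the evaluation step.
analyzer = CodeSearchNetAnalyzer(results_dir=str(directories.evaluation_results))
analyzer.load_results()
```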
src/distiller/__main__.py CHANGED
@@ -1,183 +1,57 @@
  """Main entry point for the distiller package."""

- import argparse
- import sys
-
-
- def main() -> None:
-     """Main entry point for the distiller package."""
-     parser = argparse.ArgumentParser(description="Model2Vec Code-Specialized Distillation Pipeline")
-     subparsers = parser.add_subparsers(dest="command", help="Available commands")
-
-     # Distillation command
-     distill_parser = subparsers.add_parser("distill", help="Run code-specialized model distillation")
-     distill_parser.add_argument("--model", default="Alibaba-NLP/gte-Qwen2-7B-instruct", help="Model to distill")
-     distill_parser.add_argument("--output-dir", default="gte_qwen2_m2v_code", help="Output directory")
-     distill_parser.add_argument("--pca-dims", type=int, default=512, help="PCA dimensions")
-     distill_parser.add_argument("--max-samples", type=int, default=50000, help="Max CodeSearchNet samples")
-     distill_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud GPU distillation")
-
-     # Simplified distillation command
-     simple_parser = subparsers.add_parser("distill-simple", help="Run simplified Model2Vec distillation (local)")
-     simple_parser.add_argument(
-         "--teacher", default="sentence-transformers/all-MiniLM-L6-v2", help="Teacher model to distill from"
-     )
-     simple_parser.add_argument("--output-dir", default="gte_qwen2_m2v_code_simplified", help="Output directory")
-     simple_parser.add_argument("--pca-dims", type=int, default=256, help="PCA dimensions")
-
-     # CodeSearchNet evaluation command
-     evaluate_parser = subparsers.add_parser("evaluate", help="Run CodeSearchNet evaluation on all default models")
-     evaluate_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud evaluation")
-
-     # CodeSearchNet evaluation command (simplified models only)
-     evaluate_simple_parser = subparsers.add_parser(
-         "evaluate-simple", help="Run CodeSearchNet evaluation on simplified models only"
-     )
-     evaluate_simple_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud evaluation")
-
-     # Analysis command
-     analysis_parser = subparsers.add_parser("analyze", help="Generate CodeSearchNet analysis report")
-     analysis_parser.add_argument("--results-dir", default="code_evaluation_results", help="Results directory")
-     analysis_parser.add_argument("--results-file", help="Single results file to analyze")
-     analysis_parser.add_argument("--model-name", default="gte_qwen2_m2v_code", help="Model name for report")
-     analysis_parser.add_argument("--output", default="README.md", help="Output report file")
-     analysis_parser.add_argument("--export-csv", help="Export comparison results to CSV")
-     analysis_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud analysis")
-
-     # Sync command
-     sync_parser = subparsers.add_parser("sync", help="Download files from Beam volume to local directory")
-     sync_parser.add_argument("--model-files", action="store_true", help="Download final model files")
-     sync_parser.add_argument(
-         "--analysis-files",
-         action="store_true",
-         help="Download analysis reports and charts",
-     )
-     sync_parser.add_argument("--all", action="store_true", help="Download all generated files")
-     sync_parser.add_argument("--output-dir", default=".", help="Local output directory")
-
-     # Benchmark command
-     benchmark_parser = subparsers.add_parser("benchmark", help="Run performance benchmarking on all default models")
-     benchmark_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud benchmarking")
-
-     # Benchmark command (simplified models only)
-     benchmark_simple_parser = subparsers.add_parser(
-         "benchmark-simple", help="Run performance benchmarking on simplified models only"
-     )
-     benchmark_simple_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud benchmarking")
-
-     args = parser.parse_args()
-
-     if args.command == "distill":
-         from .distill_simplified import run_local_distillation, beam_distill_all_teachers
-
-         if args.use_beam:
-             # Run on Beam
-             print("Running comprehensive teacher model distillation on Beam...")
-             results = beam_distill_all_teachers()
-         else:
-             # Run locally
-             print("Running comprehensive teacher model distillation locally...")
-             results = run_local_distillation()
-
-         print(f"✅ Distillation complete! Created {results['total_successful']} models")
-         print("📁 Models location: ./code_model2vec/final/")
-         print("\n✅ Created models:")
-         for model_name in results["successful_models"]:
-             model_info = results["all_results"][model_name]
-             print(f" • {model_name} (from {model_info['teacher_model']})")
-
-     elif args.command == "distill-simple":
-         from .distill_simplified import run_local_distillation
-
-         # Run simplified distillation for all teacher models locally
-         print("Running comprehensive teacher model distillation locally...")
-         results = run_local_distillation()
-         print(f"✅ Distillation complete! Created {results['total_successful']} models")
-         print("📁 Models location: ./code_model2vec/final/")
-         print("\n✅ Created models:")
-         for model_name in results["successful_models"]:
-             model_info = results["all_results"][model_name]
-             print(f" • {model_name} (from {model_info['teacher_model']})")
-
-     elif args.command == "evaluate":
-         from .evaluate import main as evaluate_main, run_local_evaluation
-
-         if args.use_beam:
-             # Run on Beam with all default models
-             print("Running comprehensive evaluation on Beam...")
-             evaluate_main()
-         else:
-             # Run locally with all default models
-             print("Running comprehensive evaluation locally...")
-             run_local_evaluation()
-
-     elif args.command == "evaluate-simple":
-         from .evaluate import evaluate_simplified_only, run_local_evaluation_simplified
-
-         if args.use_beam:
-             # Run on Beam with simplified models only
-             print("Running simplified model evaluation on Beam...")
-             evaluate_simplified_only()
-         else:
-             # Run locally with simplified models only
-             print("Running simplified model evaluation locally...")
-             run_local_evaluation_simplified()
-
-     elif args.command == "analyze":
-         from .analyze import main as analyze_main
-
-         # Run locally - Override sys.argv to pass arguments to the analyze script
-         sys.argv = ["analyze.py"]
-         if args.results_dir != "code_evaluation_results":
-             sys.argv.extend(["--results-dir", args.results_dir])
-         if args.results_file:
-             sys.argv.extend(["--results-file", args.results_file])
-         if args.model_name != "gte_qwen2_m2v_code":
-             sys.argv.extend(["--model-name", args.model_name])
-         if args.output != "README.md":
-             sys.argv.extend(["--output", args.output])
-         if args.export_csv:
-             sys.argv.extend(["--export-csv", args.export_csv])
-         analyze_main()
-
-     elif args.command == "sync":
-         from .sync import sync_files
-
-         # Run locally
-         sync_files(
-             model_files=args.model_files,
-             analysis_files=args.analysis_files,
-             all_files=args.all,
-             output_dir=args.output_dir,
-         )
-
-     elif args.command == "benchmark":
-         from .benchmark import main as benchmark_main, run_local_benchmark
-
-         if args.use_beam:
-             # Run on Beam with all default models
-             print("Running comprehensive benchmarking on Beam...")
-             benchmark_main()
-         else:
-             # Run locally with all default models
-             print("Running comprehensive benchmarking locally...")
-             run_local_benchmark()
-
-     elif args.command == "benchmark-simple":
-         from .benchmark import benchmark_simplified_only, run_local_benchmark_simplified
-
-         if args.use_beam:
-             # Run on Beam with simplified models only
-             print("Running simplified model benchmarking on Beam...")
-             benchmark_simplified_only()
-         else:
-             # Run locally with simplified models only
-             print("Running simplified model benchmarking locally...")
-             run_local_benchmark_simplified()
-
-     else:
-         parser.print_help()
-
+ from typing import Annotated
+
+ import typer
+
+ app = typer.Typer(
+     help="Model2Vec Code-Specialized Distillation Pipeline",
+     no_args_is_help=True,
+     context_settings={"help_option_names": ["-h", "--help"]},
+ )
+
+
+ @app.command()
+ def distill(
+     use_beam: Annotated[bool, typer.Option(help="Use Beam for distillation")] = False,
+     train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False,
+     teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None,
+     pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None,
+ ) -> None:
+     """Run unified Model2Vec distillation with optional training."""
+     from .distill import main as distill_main
+
+     # Call the distill main function with arguments
+     distill_main(use_beam, train, teacher_models, pca_dims)
+
+
+ @app.command()
+ def evaluate(
+     use_beam: Annotated[bool, typer.Option(help="Use Beam for evaluation")] = False,
+     skip_third_party: Annotated[bool, typer.Option(help="Skip third-party models")] = False,
+     skip_benchmark: Annotated[bool, typer.Option(help="Skip performance benchmarking")] = False,
+     max_queries: Annotated[int, typer.Option(help="Maximum queries per language")] = 1000,
+ ) -> None:
+     """Run CodeSearchNet evaluation on models."""
+     from .evaluate import main as evaluate_main
+
+     # Call the evaluate main function with arguments
+     evaluate_main(use_beam, skip_third_party, skip_benchmark, max_queries)
+
+
+ @app.command()
+ def analyze(
+     results_dir: Annotated[str | None, typer.Option(help="Results directory")] = None,
+     model_name: Annotated[str, typer.Option(help="Model name for analysis")] = "gte_qwen2_m2v_code (Ours)",
+     output: Annotated[str, typer.Option(help="Output report file")] = "REPORT.md",
+     export_csv: Annotated[str | None, typer.Option(help="Export results to CSV")] = None,
+ ) -> None:
+     """Generate comprehensive analysis reports."""
+     from .analyze import main as analyze_main
+
+     # Call the analyze main function with arguments
+     analyze_main(results_dir or "code_model2vec/evaluation_results", model_name, output, export_csv)


  if __name__ == "__main__":
-     main()
+     app()
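
The Typer app above replaces the old argparse dispatcher. A small sketch of driving it in-process with Typer's bundled test runner, roughly equivalent to running `python -m distiller distill --train` from a shell:

```python
# Sketch only: exercises the new CLI without spawning a subprocess.
from typer.testing import CliRunner

from distiller.__main__ import app

runner = CliRunner()
result = runner.invoke(app, ["distill", "--train", "--pca-dims", "256"])
print(result.exit_code, result.output)

# The same pattern covers the other subcommands, e.g.:
# runner.invoke(app, ["evaluate", "--skip-benchmark", "--max-queries", "500"])
# runner.invoke(app, ["analyze", "--output", "REPORT.md"])
```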
src/distiller/analyze.py CHANGED
@@ -23,7 +23,6 @@ Usage:
23
  distiller analyze --results-dir evaluation_results
24
  """
25
 
26
- import argparse
27
  import json
28
  import logging
29
  import time
@@ -35,6 +34,8 @@ import numpy as np
35
  import pandas as pd
36
  import seaborn as sns
37
 
 
 
38
  # Optional Plotly import with fallback
39
  PLOTLY_AVAILABLE = True
40
  try:
@@ -65,48 +66,140 @@ OUTPUT_DIR = Path("analysis_results")
65
  IMAGES_DIR = Path("analysis_charts")
66
  REPORT_FILE = Path("REPORT.md") # Changed from README.md
67
 
68
- # Local directories for results - updated for new structure
69
- DEFAULT_EVALUATION_DIR = "code_model2vec/evaluation_results"
70
- DEFAULT_BENCHMARK_DIR = "code_model2vec/benchmark_results"
71
 
72
  # CodeSearchNet Languages
73
  CODE_LANGUAGES = ["python", "javascript", "java", "php", "ruby", "go"]
74
 
75
  # Model name mapping from the default models in evaluate.py and benchmark.py
76
  MODEL_NAME_MAPPING = {
77
- # File names to display names
78
- "gte_qwen2_m2v_code": "gte_qwen2_m2v_code (Ours)",
79
- "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
80
- "codebert-base": "microsoft/codebert-base",
81
- "graphcodebert-base": "microsoft/graphcodebert-base",
82
- "CodeBERTa-small-v1": "huggingface/CodeBERTa-small-v1",
83
- "all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
84
- "all-MiniLM-L12-v2": "sentence-transformers/all-MiniLM-L12-v2",
85
- "potion-base-8M": "minishlab/potion-base-8M",
86
- "potion-retrieval-32M": "minishlab/potion-retrieval-32M",
87
- "codet5-base": "Salesforce/codet5-base",
 
 
88
  }
89
 
90
- # Reverse mapping for lookups
91
- DISPLAY_NAME_TO_FILE = {v: k for k, v in MODEL_NAME_MAPPING.items()}
92
 
93
  # Peer models for comparison (code-specialized models)
94
  PEER_MODELS = {
95
- "sentence-transformers/all-MiniLM-L6-v2": {"overall_ndcg": 0.25, "type": "General"},
96
- "microsoft/codebert-base": {"overall_ndcg": 0.32, "type": "Code-Specific"},
97
- "microsoft/graphcodebert-base": {"overall_ndcg": 0.35, "type": "Code-Specific"},
98
- "huggingface/CodeBERTa-small-v1": {"overall_ndcg": 0.28, "type": "Code-Specific"},
99
- "sentence-transformers/all-mpnet-base-v2": {"overall_ndcg": 0.27, "type": "General"},
 
 
 
100
  }
101
 
102
  # Model specifications for efficiency analysis
103
  MODEL_SPECS = {
104
- "sentence-transformers/all-MiniLM-L6-v2": {"parameters": 22.7, "size_mb": 90},
105
- "microsoft/codebert-base": {"parameters": 125.0, "size_mb": 500},
106
- "microsoft/graphcodebert-base": {"parameters": 125.0, "size_mb": 500},
107
- "huggingface/CodeBERTa-small-v1": {"parameters": 84.0, "size_mb": 340},
108
- "sentence-transformers/all-mpnet-base-v2": {"parameters": 109.0, "size_mb": 440},
109
- "Alibaba-NLP/gte-Qwen2-7B-instruct": {"parameters": 7000.0, "size_mb": 13000},
 
 
 
110
  }
111
 
112
  # Distilled model specifications
@@ -134,13 +227,12 @@ def setup_directories(base_path: Path | None = None) -> tuple[Path, Path, Path]:
134
  images_dir = base_path / "analysis_results" / "charts"
135
  reports_dir = base_path / "analysis_results" / "reports"
136
  else:
137
- output_dir = OUTPUT_DIR
138
- images_dir = IMAGES_DIR
139
- reports_dir = OUTPUT_DIR / "reports"
140
 
141
- output_dir.mkdir(parents=True, exist_ok=True)
142
  images_dir.mkdir(parents=True, exist_ok=True)
143
- reports_dir.mkdir(parents=True, exist_ok=True)
144
 
145
  return output_dir, images_dir, reports_dir
146
 
@@ -152,17 +244,94 @@ def extract_model_name_from_filename(filename: str) -> str:
152
 
153
  # Check if it's in our mapping
154
  if name in MODEL_NAME_MAPPING:
155
- return MODEL_NAME_MAPPING[name]
156
 
157
  # Try to find partial matches
158
- for file_key, display_name in MODEL_NAME_MAPPING.items():
159
  if file_key in name or name in file_key:
160
- return display_name
161
 
162
  # If no mapping found, return the cleaned name
163
  return name
164
 
165
 
 
 
 
166
  class CodeSearchNetAnalyzer:
167
  """Analyzer for CodeSearchNet evaluation results and performance benchmarks."""
168
 
@@ -182,33 +351,43 @@ class CodeSearchNetAnalyzer:
182
  self.benchmark_df: pd.DataFrame | None = None
183
 
184
  def load_benchmark_results(self) -> None:
185
- """Load benchmark results from local directory."""
186
- logger.info("📊 Loading benchmark results...")
187
 
188
- if not self.benchmark_dir.exists():
189
- logger.warning(f"Benchmark directory not found: {self.benchmark_dir}")
190
  return
191
 
192
- logger.info(f"🔍 Searching for benchmark files in: {self.benchmark_dir}")
193
- benchmark_files = list(self.benchmark_dir.glob("benchmark_*.json"))
194
- logger.info(f"📁 Found {len(benchmark_files)} benchmark files")
195
 
196
- for benchmark_file_path in benchmark_files:
 
 
 
 
 
 
 
 
 
197
  try:
198
- logger.info(f"📖 Loading: {benchmark_file_path.name}")
199
- with benchmark_file_path.open() as f:
200
  data = json.load(f)
 
201
  if data is not None:
202
- # Update model name with proper mapping
203
- original_name = data.get("model_name", "Unknown")
204
- mapped_name = extract_model_name_from_filename(benchmark_file_path.stem)
205
- data["model_name"] = mapped_name
206
- data["original_model_name"] = original_name
207
-
208
- self.benchmark_results.append(data)
209
- logger.info(f"✅ Successfully loaded: {mapped_name}")
 
 
210
  except (json.JSONDecodeError, KeyError) as e:
211
- logger.warning(f"❌ Failed to load {benchmark_file_path}: {e}")
212
 
213
  logger.info(f"📊 Total benchmark results loaded: {len(self.benchmark_results)}")
214
  if self.benchmark_results:
@@ -217,6 +396,34 @@ class CodeSearchNetAnalyzer:
217
 
218
  self._create_benchmark_dataframe()
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  def _create_benchmark_dataframe(self) -> None:
221
  """Create benchmark comparison DataFrame from results."""
222
  if not self.benchmark_results:
@@ -263,10 +470,10 @@ class CodeSearchNetAnalyzer:
263
  )
264
 
265
  # CPU vs GPU comparison
266
- for device in ["cpu", "cuda"]:
267
- if device in cpu_vs_gpu and "error" not in cpu_vs_gpu[device]:
268
  device_key = f"{device.upper()}_TextsPerSec"
269
- row[device_key] = cpu_vs_gpu[device].get("texts_per_second", 0)
270
 
271
  benchmark_data.append(row)
272
 
@@ -281,23 +488,31 @@ class CodeSearchNetAnalyzer:
281
  return
282
 
283
  logger.info(f"🔍 Searching for evaluation files in: {self.results_dir}")
284
- json_files = list(self.results_dir.glob("codesearchnet_eval_*.json"))
285
- logger.info(f"📁 Found {len(json_files)} evaluation files")
286
 
287
- for json_file in json_files:
 
 
 
 
 
 
 
 
 
288
  try:
289
  logger.info(f"📖 Loading: {json_file.name}")
290
  with json_file.open() as f:
291
  data = json.load(f)
292
  if data is not None:
293
- # Update model name with proper mapping
294
- original_name = data.get("model_name", "Unknown")
295
- mapped_name = extract_model_name_from_filename(json_file.stem)
296
- data["model_name"] = mapped_name
297
- data["original_model_name"] = original_name
298
-
299
- self.results.append(data)
300
- logger.info(f"✅ Successfully loaded: {mapped_name}")
 
301
  except (json.JSONDecodeError, KeyError) as e:
302
  logger.warning(f"❌ Failed to load {json_file}: {e}")
303
 
@@ -311,6 +526,32 @@ class CodeSearchNetAnalyzer:
311
  # Also load benchmark results
312
  self.load_benchmark_results()
313
 
 
 
 
314
  def _create_comparison_dataframe(self) -> None:
315
  """Create comparison DataFrame from results."""
316
  if not self.results:
@@ -453,7 +694,7 @@ class CodeSearchNetAnalyzer:
453
  if cpu_vs_gpu:
454
  print("🖥️ CPU vs GPU:")
455
  for device, metrics in cpu_vs_gpu.items():
456
- if "error" not in metrics:
457
  print(f" {device.upper()}: {metrics.get('texts_per_second', 0):.1f} texts/sec")
458
 
459
  # Memory efficiency
@@ -978,7 +1219,7 @@ class CodeSearchNetAnalyzer:
978
  # Safe conversion to float for pandas values
979
  score_value = pd.to_numeric(current_model_score, errors="coerce")
980
  scores.append(float(score_value) if not pd.isna(score_value) else 0.0)
981
- params.append(float(MODEL_SPECS[model_key].get("parameters", 100)))
982
  is_user_model.append(False)
983
 
984
  if not models:
@@ -1098,7 +1339,7 @@ class CodeSearchNetAnalyzer:
1098
 
1099
  # Create visualizations
1100
  logger.info("Generating visualizations...")
1101
- setup_directories()
1102
 
1103
  self.create_performance_radar_chart(main_model_name, language_scores)
1104
  comparison_chart = self.plot_model_comparison()
@@ -1163,21 +1404,14 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
1163
  overall_metrics = result.get("overall", {})
1164
 
1165
  # Extract teacher model name from model name
1166
- teacher = "Unknown"
1167
- if "all_MiniLM_L6_v2" in model_display:
1168
- teacher = "all-MiniLM-L6-v2"
1169
- elif "codebert_base" in model_display:
1170
- teacher = "codebert-base"
1171
- elif "graphcodebert_base" in model_display:
1172
- teacher = "graphcodebert-base"
1173
- elif "gte_Qwen2_7B_instruct" in model_display:
1174
- teacher = "gte-Qwen2-7B-instruct"
1175
- elif "all_mpnet_base_v2" in model_display:
1176
- teacher = "all-mpnet-base-v2"
1177
 
1178
  status = "🥇 Best" if rank == 1 else "🥈 2nd" if rank == 2 else "🥉 3rd" if rank == 3 else f"#{rank}"
1179
 
1180
- report += f"| {model_display} | {teacher} | {overall_metrics.get('ndcg@10', 0):.4f} | {overall_metrics.get('mrr', 0):.4f} | {overall_metrics.get('recall@5', 0):.4f} | {status} |\n"
 
 
 
1181
 
1182
  report += """
1183
 
@@ -1215,19 +1449,12 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
1215
  report += "### Individual Model Performance by Language\n\n"
1216
  for chart_model_name, chart_path in individual_radar_charts.items():
1217
  # Extract teacher name for cleaner display
1218
- teacher = "Unknown"
1219
- if "all_MiniLM_L6_v2" in chart_model_name:
1220
- teacher = "all-MiniLM-L6-v2"
1221
- elif "codebert_base" in chart_model_name:
1222
- teacher = "codebert-base"
1223
- elif "graphcodebert_base" in chart_model_name:
1224
- teacher = "graphcodebert-base"
1225
- elif "gte_Qwen2_7B_instruct" in chart_model_name:
1226
- teacher = "gte-Qwen2-7B-instruct"
1227
- elif "all_mpnet_base_v2" in chart_model_name:
1228
- teacher = "all-mpnet-base-v2"
1229
-
1230
- report += f"#### {chart_model_name} (Teacher: {teacher})\n\n"
1231
  report += f"![{chart_model_name} Radar Chart]({chart_path})\n\n"
1232
 
1233
  report += f"""
@@ -1324,7 +1551,7 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
1324
 
1325
  if language_scores:
1326
  report += "| Language | Best Model Performance | Average Performance | Language Difficulty |\n"
1327
- report += "|----------|------------------------|--------------------|--------------------||\n"
1328
 
1329
  for lang in sorted(language_scores.keys()):
1330
  # Find best performance for this language across all models
@@ -1358,16 +1585,8 @@ Based on the evaluation results across all simplified distillation models:
1358
  model_name = result["model_name"]
1359
  score = result.get("overall", {}).get("ndcg@10", 0)
1360
 
1361
- if "all_MiniLM_L6_v2" in model_name:
1362
- teacher_performance["all-MiniLM-L6-v2"] = score
1363
- elif "codebert_base" in model_name:
1364
- teacher_performance["codebert-base"] = score
1365
- elif "graphcodebert_base" in model_name:
1366
- teacher_performance["graphcodebert-base"] = score
1367
- elif "gte_Qwen2_7B_instruct" in model_name:
1368
- teacher_performance["gte-Qwen2-7B-instruct"] = score
1369
- elif "all_mpnet_base_v2" in model_name:
1370
- teacher_performance["all-mpnet-base-v2"] = score
1371
 
1372
  if teacher_performance:
1373
  best_teacher = max(teacher_performance.items(), key=lambda x: x[1])
@@ -1397,11 +1616,20 @@ Based on the evaluation results across all simplified distillation models:
1397
  - **Evaluation**: Retrieval of correct code for each documentation query
1398
 
1399
  ### Teacher Models Tested
1400
- - sentence-transformers/all-MiniLM-L6-v2 (proven baseline)
1401
- - microsoft/codebert-base (code-specialized)
1402
- - microsoft/graphcodebert-base (graph-aware code model)
1403
- - Alibaba-NLP/gte-Qwen2-7B-instruct (large instruction model)
1404
- - sentence-transformers/all-mpnet-base-v2 (general purpose)
 
 
 
 
 
 
 
 
 
1405
 
1406
  ### Distillation Method
1407
  - **Technique**: Model2Vec static embedding generation
@@ -1424,31 +1652,27 @@ Based on the evaluation results across all simplified distillation models:
1424
  logger.info(f"Results exported to {output_file}")
1425
 
1426
 
1427
- def main() -> None:
 
 
 
 
 
1428
  """Main analysis function."""
1429
- parser = argparse.ArgumentParser(description="Analyze CodeSearchNet evaluation results and performance benchmarks")
1430
- parser.add_argument("--results-dir", default=DEFAULT_EVALUATION_DIR, help="Evaluation results directory")
1431
- parser.add_argument("--benchmark-dir", default=DEFAULT_BENCHMARK_DIR, help="Benchmark results directory")
1432
- parser.add_argument("--model-name", default="gte_qwen2_m2v_code (Ours)", help="Model name for report")
1433
- parser.add_argument("--output", default="REPORT.md", help="Output report file")
1434
- parser.add_argument("--export-csv", help="Export comparison results to CSV")
1435
-
1436
- args = parser.parse_args()
1437
-
1438
- logger.info("Starting CodeSearchNet Analysis with Benchmark Integration")
1439
  logger.info("=" * 60)
1440
 
1441
  # Setup output directories
1442
  output_dir, images_dir, reports_dir = setup_directories()
1443
 
1444
- # Initialize analyzer with local directories
1445
  analyzer = CodeSearchNetAnalyzer(
1446
- results_dir=args.results_dir,
1447
- benchmark_dir=args.benchmark_dir,
1448
  images_dir=images_dir,
1449
  )
1450
 
1451
- # Load results (this will also load benchmark results)
1452
  analyzer.load_results()
1453
 
1454
  if not analyzer.results:
@@ -1463,33 +1687,32 @@ def main() -> None:
1463
  if analyzer.benchmark_results:
1464
  analyzer.analyze_benchmark_performance()
1465
  else:
1466
- logger.warning("No benchmark results found. Run benchmark.py first for complete analysis.")
1467
 
1468
  # Generate comprehensive report with benchmark integration
1469
- logger.info("Generating comprehensive report with benchmark data...")
1470
- report = analyzer.generate_comprehensive_report(args.model_name)
1471
 
1472
  # Save report
1473
- report_path = Path(args.output)
1474
  with report_path.open("w") as f:
1475
  f.write(report)
1476
 
1477
  # Export CSV if requested
1478
- if args.export_csv:
1479
- analyzer.export_results(args.export_csv)
1480
 
1481
  # Export benchmark CSV if available
1482
  if analyzer.benchmark_df is not None and not analyzer.benchmark_df.empty:
1483
- benchmark_csv = report_path.parent / f"{args.model_name}_benchmark_comparison.csv"
1484
  analyzer.benchmark_df.to_csv(benchmark_csv, index=False)
1485
  logger.info(f"📊 Benchmark comparison saved to: {benchmark_csv}")
1486
 
1487
- logger.info("✅ CodeSearchNet analysis with benchmarks complete!")
1488
  logger.info(f"📊 Report saved to: {report_path}")
1489
  logger.info(f"🖼️ Charts saved to: {images_dir}")
 
1490
 
1491
 
1492
  if __name__ == "__main__":
1493
- import argparse
1494
-
1495
  main()
 
23
  distiller analyze --results-dir evaluation_results
24
  """
25
 
 
26
  import json
27
  import logging
28
  import time
 
34
  import pandas as pd
35
  import seaborn as sns
36
 
37
+ from .config import directories
38
+
39
  # Optional Plotly import with fallback
40
  PLOTLY_AVAILABLE = True
41
  try:
 
66
  IMAGES_DIR = Path("analysis_charts")
67
  REPORT_FILE = Path("REPORT.md") # Changed from README.md
68
 
69
+ # Local directories for results - using standardized directories from config
70
+ DEFAULT_EVALUATION_DIR = directories.evaluation_results
71
+ DEFAULT_BENCHMARK_DIR = directories.benchmark_results
72
 
73
  # CodeSearchNet Languages
74
  CODE_LANGUAGES = ["python", "javascript", "java", "php", "ruby", "go"]
75
 
76
  # Model name mapping from the default models in evaluate.py and benchmark.py
77
  MODEL_NAME_MAPPING = {
78
+ # File names to display names and HuggingFace links
79
+ "all-MiniLM-L6-v2": {
80
+ "name": "sentence-transformers/all-MiniLM-L6-v2",
81
+ "link": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
82
+ },
83
+ "all-mpnet-base-v2": {
84
+ "name": "sentence-transformers/all-mpnet-base-v2",
85
+ "link": "https://huggingface.co/sentence-transformers/all-mpnet-base-v2",
86
+ },
87
+ "paraphrase-MiniLM-L6-v2": {
88
+ "name": "sentence-transformers/paraphrase-MiniLM-L6-v2",
89
+ "link": "https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2",
90
+ },
91
+ "codebert-base": {"name": "microsoft/codebert-base", "link": "https://huggingface.co/microsoft/codebert-base"},
92
+ "graphcodebert-base": {
93
+ "name": "microsoft/graphcodebert-base",
94
+ "link": "https://huggingface.co/microsoft/graphcodebert-base",
95
+ },
96
+ "CodeBERTa-small-v1": {
97
+ "name": "huggingface/CodeBERTa-small-v1",
98
+ "link": "https://huggingface.co/huggingface/CodeBERTa-small-v1",
99
+ },
100
+ "all-MiniLM-L12-v2": {
101
+ "name": "sentence-transformers/all-MiniLM-L12-v2",
102
+ "link": "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2",
103
+ },
104
+ "potion-base-8M": {"name": "minishlab/potion-base-8M", "link": "https://huggingface.co/minishlab/potion-base-8M"},
105
+ "potion-retrieval-32M": {
106
+ "name": "minishlab/potion-retrieval-32M",
107
+ "link": "https://huggingface.co/minishlab/potion-retrieval-32M",
108
+ },
109
+ "codet5-base": {"name": "Salesforce/codet5-base", "link": "https://huggingface.co/Salesforce/codet5-base"},
110
+ "gte-Qwen2-1.5B-instruct": {
111
+ "name": "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
112
+ "link": "https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct",
113
+ },
114
+ "bge-m3": {"name": "BAAI/bge-m3", "link": "https://huggingface.co/BAAI/bge-m3"},
115
+ "jina-embeddings-v3": {
116
+ "name": "jinaai/jina-embeddings-v3",
117
+ "link": "https://huggingface.co/jinaai/jina-embeddings-v3",
118
+ },
119
+ "nomic-embed-text-v2-moe": {
120
+ "name": "nomic-ai/nomic-embed-text-v2-moe",
121
+ "link": "https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe",
122
+ },
123
+ "Qodo-Embed-1-1.5B": {"name": "Qodo/Qodo-Embed-1-1.5B", "link": "https://huggingface.co/Qodo/Qodo-Embed-1-1.5B"},
124
+ "Reason-ModernColBERT": {
125
+ "name": "lightonai/Reason-ModernColBERT",
126
+ "link": "https://huggingface.co/lightonai/Reason-ModernColBERT",
127
+ },
128
+ "Linq-Embed-Mistral": {
129
+ "name": "Linq-AI-Research/Linq-Embed-Mistral",
130
+ "link": "https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral",
131
+ },
132
+ "bge-code-v1": {"name": "BAAI/bge-code-v1", "link": "https://huggingface.co/BAAI/bge-code-v1"},
133
+ "SFR-Embedding-Code-2B_R": {
134
+ "name": "Salesforce/SFR-Embedding-Code-2B_R",
135
+ "link": "https://huggingface.co/Salesforce/SFR-Embedding-Code-2B_R",
136
+ },
137
  }
138
 
139
+ # Reverse mapping for lookups - using just the names
140
+ DISPLAY_NAME_TO_FILE = {v["name"]: k for k, v in MODEL_NAME_MAPPING.items()}
141
 
142
  # Peer models for comparison (code-specialized models)
143
  PEER_MODELS = {
144
+ "sentence-transformers/all-MiniLM-L6-v2": {
145
+ "overall_ndcg": 0.25,
146
+ "type": "General",
147
+ "link": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
148
+ },
149
+ "microsoft/codebert-base": {
150
+ "overall_ndcg": 0.32,
151
+ "type": "Code-Specific",
152
+ "link": "https://huggingface.co/microsoft/codebert-base",
153
+ },
154
+ "microsoft/graphcodebert-base": {
155
+ "overall_ndcg": 0.35,
156
+ "type": "Code-Specific",
157
+ "link": "https://huggingface.co/microsoft/graphcodebert-base",
158
+ },
159
+ "huggingface/CodeBERTa-small-v1": {
160
+ "overall_ndcg": 0.28,
161
+ "type": "Code-Specific",
162
+ "link": "https://huggingface.co/huggingface/CodeBERTa-small-v1",
163
+ },
164
+ "sentence-transformers/all-mpnet-base-v2": {
165
+ "overall_ndcg": 0.27,
166
+ "type": "General",
167
+ "link": "https://huggingface.co/sentence-transformers/all-mpnet-base-v2",
168
+ },
169
  }
170
 
171
  # Model specifications for efficiency analysis
172
  MODEL_SPECS = {
173
+ "sentence-transformers/all-MiniLM-L6-v2": {
174
+ "parameters": 22.7,
175
+ "size_mb": 90,
176
+ "link": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
177
+ },
178
+ "microsoft/codebert-base": {
179
+ "parameters": 125.0,
180
+ "size_mb": 500,
181
+ "link": "https://huggingface.co/microsoft/codebert-base",
182
+ },
183
+ "microsoft/graphcodebert-base": {
184
+ "parameters": 125.0,
185
+ "size_mb": 500,
186
+ "link": "https://huggingface.co/microsoft/graphcodebert-base",
187
+ },
188
+ "huggingface/CodeBERTa-small-v1": {
189
+ "parameters": 84.0,
190
+ "size_mb": 340,
191
+ "link": "https://huggingface.co/huggingface/CodeBERTa-small-v1",
192
+ },
193
+ "sentence-transformers/all-mpnet-base-v2": {
194
+ "parameters": 109.0,
195
+ "size_mb": 440,
196
+ "link": "https://huggingface.co/sentence-transformers/all-mpnet-base-v2",
197
+ },
198
+ "Alibaba-NLP/gte-Qwen2-1.5B-instruct": {
199
+ "parameters": 1500.0,
200
+ "size_mb": 3000,
201
+ "link": "https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct",
202
+ },
203
  }
204
 
205
  # Distilled model specifications
 
227
  images_dir = base_path / "analysis_results" / "charts"
228
  reports_dir = base_path / "analysis_results" / "reports"
229
  else:
230
+ output_dir = Path() # Use current directory
231
+ images_dir = IMAGES_DIR # Use analysis_charts
232
+ reports_dir = Path() # Use current directory for reports
233
 
234
+ # Only create directories that we actually use
235
  images_dir.mkdir(parents=True, exist_ok=True)
 
236
 
237
  return output_dir, images_dir, reports_dir
238
 
 
244
 
245
  # Check if it's in our mapping
246
  if name in MODEL_NAME_MAPPING:
247
+ return MODEL_NAME_MAPPING[name]["name"]
248
 
249
  # Try to find partial matches
250
+ for file_key, model_info in MODEL_NAME_MAPPING.items():
251
  if file_key in name or name in file_key:
252
+ return model_info["name"]
253
 
254
  # If no mapping found, return the cleaned name
255
  return name
256
 
257
 
258
+ def get_model_link(model_name: str) -> str:
259
+ """Get HuggingFace link for a model."""
260
+ # First try direct lookup by file key
261
+ for model_info in MODEL_NAME_MAPPING.values():
262
+ if model_info["name"] == model_name:
263
+ return model_info["link"]
264
+
265
+ # Try partial matches
266
+ for model_info in MODEL_NAME_MAPPING.values():
267
+ if model_name.lower() in model_info["name"].lower() or model_info["name"].lower() in model_name.lower():
268
+ return model_info["link"]
269
+
270
+ # If no mapping found, construct link from model name
271
+ if "/" in model_name:
272
+ return f"https://huggingface.co/{model_name}"
273
+ return ""
274
+
275
+
276
+ def format_model_with_link(model_name: str) -> str:
277
+ """Format model name with markdown link."""
278
+ link = get_model_link(model_name)
279
+ if link:
280
+ return f"[{model_name}]({link})"
281
+ return model_name
282
+
283
+
284
+ def get_teacher_model_info(model_display_name: str) -> tuple[str, str]:
285
+ """Extract teacher model name and link from distilled model display name."""
286
+ # Mapping from model display patterns to teacher models
287
+ teacher_mapping = {
288
+ "all_MiniLM_L6_v2": (
289
+ "sentence-transformers/all-MiniLM-L6-v2",
290
+ "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
291
+ ),
292
+ "all_mpnet_base_v2": (
293
+ "sentence-transformers/all-mpnet-base-v2",
294
+ "https://huggingface.co/sentence-transformers/all-mpnet-base-v2",
295
+ ),
296
+ "paraphrase_MiniLM_L6_v2": (
297
+ "sentence-transformers/paraphrase-MiniLM-L6-v2",
298
+ "https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2",
299
+ ),
300
+ "codebert_base": ("microsoft/codebert-base", "https://huggingface.co/microsoft/codebert-base"),
301
+ "graphcodebert_base": ("microsoft/graphcodebert-base", "https://huggingface.co/microsoft/graphcodebert-base"),
302
+ "gte_Qwen2_1.5B_instruct": (
303
+ "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
304
+ "https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct",
305
+ ),
306
+ "bge_m3": ("BAAI/bge-m3", "https://huggingface.co/BAAI/bge-m3"),
307
+ "jina_embeddings_v3": ("jinaai/jina-embeddings-v3", "https://huggingface.co/jinaai/jina-embeddings-v3"),
308
+ "nomic_embed_text_v2_moe": (
309
+ "nomic-ai/nomic-embed-text-v2-moe",
310
+ "https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe",
311
+ ),
312
+ "Qodo_Embed_1_1.5B": ("Qodo/Qodo-Embed-1-1.5B", "https://huggingface.co/Qodo/Qodo-Embed-1-1.5B"),
313
+ "Reason_ModernColBERT": (
314
+ "lightonai/Reason-ModernColBERT",
315
+ "https://huggingface.co/lightonai/Reason-ModernColBERT",
316
+ ),
317
+ "Linq_Embed_Mistral": (
318
+ "Linq-AI-Research/Linq-Embed-Mistral",
319
+ "https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral",
320
+ ),
321
+ "bge_code_v1": ("BAAI/bge-code-v1", "https://huggingface.co/BAAI/bge-code-v1"),
322
+ "SFR_Embedding_Code_2B_R": (
323
+ "Salesforce/SFR-Embedding-Code-2B_R",
324
+ "https://huggingface.co/Salesforce/SFR-Embedding-Code-2B_R",
325
+ ),
326
+ }
327
+
328
+ for pattern, (teacher_name, teacher_link) in teacher_mapping.items():
329
+ if pattern in model_display_name:
330
+ return teacher_name, teacher_link
331
+
332
+ return "Unknown", ""
333
+
334
+
335
  class CodeSearchNetAnalyzer:
336
  """Analyzer for CodeSearchNet evaluation results and performance benchmarks."""
337
 
 
351
  self.benchmark_df: pd.DataFrame | None = None
352
 
353
  def load_benchmark_results(self) -> None:
354
+ """Load benchmark results from comprehensive evaluation files."""
355
+ logger.info("📊 Loading benchmark results from comprehensive evaluations...")
356
 
357
+ if not self.results_dir.exists():
358
+ logger.warning(f"Evaluation directory not found: {self.results_dir}")
359
  return
360
 
361
+ logger.info(f"🔍 Searching for comprehensive evaluation files in: {self.results_dir}")
 
 
362
 
363
+ # Look for both new comprehensive format and legacy formats
364
+ comprehensive_files = list(self.results_dir.glob("comprehensive_eval_*.json"))
365
+ legacy_files = list(self.results_dir.glob("codesearchnet_eval_*.json"))
366
+
367
+ all_files = comprehensive_files + legacy_files
368
+ logger.info(
369
+ f"📁 Found {len(all_files)} evaluation files ({len(comprehensive_files)} comprehensive, {len(legacy_files)} legacy)"
370
+ )
371
+
372
+ for eval_file_path in all_files:
373
  try:
374
+ logger.info(f"📖 Loading: {eval_file_path.name}")
375
+ with eval_file_path.open() as f:
376
  data = json.load(f)
377
+
378
  if data is not None:
379
+ if not isinstance(data, dict):
380
+ logger.warning(f"⚠️ Skipping {eval_file_path.name} (not a dict)")
381
+ continue
382
+
383
+ # Extract benchmark data if available
384
+ benchmark_data = self._extract_benchmark_data(data, eval_file_path)
385
+ if benchmark_data:
386
+ self.benchmark_results.append(benchmark_data)
387
+ logger.info(f"✅ Successfully loaded benchmark data: {benchmark_data['model_name']}")
388
+
389
  except (json.JSONDecodeError, KeyError) as e:
390
+ logger.warning(f"❌ Failed to load {eval_file_path}: {e}")
391
 
392
  logger.info(f"📊 Total benchmark results loaded: {len(self.benchmark_results)}")
393
  if self.benchmark_results:
 
396
 
397
  self._create_benchmark_dataframe()
398
 
399
+ def _extract_benchmark_data(self, data: dict, file_path: Path) -> dict[str, Any] | None:
400
+ """Extract benchmark data from comprehensive evaluation results."""
401
+ # Check if this evaluation contains benchmark data
402
+ if data.get("benchmark_skipped", False):
403
+ return None
404
+
405
+ # Check for benchmark fields
406
+ if not any(key in data for key in ["size_metrics", "speed_benchmarks", "memory_benchmarks", "cpu_vs_gpu"]):
407
+ return None
408
+
409
+ # Extract model name
410
+ original_name = data.get("model_name") or "Unknown"
411
+ mapped_name = extract_model_name_from_filename(
412
+ file_path.stem.replace("comprehensive_eval_", "").replace("codesearchnet_eval_", "")
413
+ )
414
+
415
+ # Create benchmark result structure
416
+ result: dict[str, Any] = {
417
+ "model_name": mapped_name,
418
+ "original_model_name": original_name,
419
+ "size_metrics": data.get("size_metrics", {}),
420
+ "speed_benchmarks": data.get("speed_benchmarks", {}),
421
+ "memory_benchmarks": data.get("memory_benchmarks", {}),
422
+ "cpu_vs_gpu": data.get("cpu_vs_gpu", {}),
423
+ }
424
+
425
+ return result
426
+
427
  def _create_benchmark_dataframe(self) -> None:
428
  """Create benchmark comparison DataFrame from results."""
429
  if not self.benchmark_results:
 
470
  )
471
 
472
  # CPU vs GPU comparison
473
+ for device, metrics in cpu_vs_gpu.items():
474
+ if isinstance(metrics, dict) and "error" not in metrics:
475
  device_key = f"{device.upper()}_TextsPerSec"
476
+ row[device_key] = metrics.get("texts_per_second", 0)
477
 
478
  benchmark_data.append(row)
479
 
 
488
  return
489
 
490
  logger.info(f"🔍 Searching for evaluation files in: {self.results_dir}")
 
 
491
 
492
+ # Look for both new comprehensive format and legacy formats
493
+ comprehensive_files = list(self.results_dir.glob("comprehensive_eval_*.json"))
494
+ legacy_files = list(self.results_dir.glob("codesearchnet_eval_*.json"))
495
+
496
+ all_files = comprehensive_files + legacy_files
497
+ logger.info(
498
+ f"📁 Found {len(all_files)} evaluation files ({len(comprehensive_files)} comprehensive, {len(legacy_files)} legacy)"
499
+ )
500
+
501
+ for json_file in all_files:
502
  try:
503
  logger.info(f"📖 Loading: {json_file.name}")
504
  with json_file.open() as f:
505
  data = json.load(f)
506
  if data is not None:
507
+ if not isinstance(data, dict):
508
+ logger.warning(f"⚠️ Skipping {json_file.name} (not a dict)")
509
+ continue
510
+
511
+ # Normalize data format for analysis
512
+ normalized_data = self._normalize_evaluation_data(data, json_file)
513
+ self.results.append(normalized_data)
514
+ logger.info(f"✅ Successfully loaded: {normalized_data['model_name']}")
515
+
516
  except (json.JSONDecodeError, KeyError) as e:
517
  logger.warning(f"❌ Failed to load {json_file}: {e}")
518
 
 
526
  # Also load benchmark results
527
  self.load_benchmark_results()
528
 
529
+ def _normalize_evaluation_data(self, data: dict, file_path: Path) -> dict[str, Any]:
530
+ """Normalize evaluation data to consistent format for analysis."""
531
+ # Extract model name
532
+ original_name = data.get("model_name", "Unknown")
533
+ file_stem = file_path.stem.replace("comprehensive_eval_", "").replace("codesearchnet_eval_", "")
534
+ mapped_name = extract_model_name_from_filename(file_stem)
535
+
536
+ # Handle comprehensive format (new)
537
+ if "codesearch_overall" in data and "codesearch_languages" in data:
538
+ result = {
539
+ "model_name": mapped_name,
540
+ "original_model_name": original_name,
541
+ "overall": data.get("codesearch_overall", {}),
542
+ "languages": data.get("codesearch_languages", {}),
543
+ }
544
+ # Handle legacy format (old codesearchnet_eval files)
545
+ else:
546
+ result = {
547
+ "model_name": mapped_name,
548
+ "original_model_name": original_name,
549
+ "overall": data.get("overall", {}),
550
+ "languages": data.get("languages", {}),
551
+ }
552
+
553
+ return result
554
+
555
  def _create_comparison_dataframe(self) -> None:
556
  """Create comparison DataFrame from results."""
557
  if not self.results:
 
694
  if cpu_vs_gpu:
695
  print("🖥️ CPU vs GPU:")
696
  for device, metrics in cpu_vs_gpu.items():
697
+ if isinstance(metrics, dict) and "error" not in metrics:
698
  print(f" {device.upper()}: {metrics.get('texts_per_second', 0):.1f} texts/sec")
699
 
700
  # Memory efficiency
 
1219
  # Safe conversion to float for pandas values
1220
  score_value = pd.to_numeric(current_model_score, errors="coerce")
1221
  scores.append(float(score_value) if not pd.isna(score_value) else 0.0)
1222
+ params.append(float(MODEL_SPECS[model_key].get("parameters", 100.0)))
1223
  is_user_model.append(False)
1224
 
1225
  if not models:
 
1339
 
1340
  # Create visualizations
1341
  logger.info("Generating visualizations...")
1342
+ output_dir, images_dir, reports_dir = setup_directories()
1343
 
1344
  self.create_performance_radar_chart(main_model_name, language_scores)
1345
  comparison_chart = self.plot_model_comparison()
 
1404
  overall_metrics = result.get("overall", {})
1405
 
1406
  # Extract teacher model name from model name
1407
+ teacher_name, teacher_link = get_teacher_model_info(model_display)
 
 
 
 
 
 
 
 
 
 
1408
 
1409
  status = "🥇 Best" if rank == 1 else "🥈 2nd" if rank == 2 else "🥉 3rd" if rank == 3 else f"#{rank}"
1410
 
1411
+ # Use linked teacher name if available
1412
+ teacher_display = f"[{teacher_name}]({teacher_link})" if teacher_link else teacher_name
1413
+
1414
+ report += f"| {model_display} | {teacher_display} | {overall_metrics.get('ndcg@10', 0):.4f} | {overall_metrics.get('mrr', 0):.4f} | {overall_metrics.get('recall@5', 0):.4f} | {status} |\n"
1415
 
1416
  report += """
1417
 
 
1449
  report += "### Individual Model Performance by Language\n\n"
1450
  for chart_model_name, chart_path in individual_radar_charts.items():
1451
  # Extract teacher name for cleaner display
1452
+ teacher_name, teacher_link = get_teacher_model_info(chart_model_name)
1453
+
1454
+ # Use linked teacher name if available
1455
+ teacher_display = f"[{teacher_name}]({teacher_link})" if teacher_link else teacher_name
1456
+
1457
+ report += f"#### {chart_model_name} (Teacher: {teacher_display})\n\n"
 
 
 
 
 
 
 
1458
  report += f"![{chart_model_name} Radar Chart]({chart_path})\n\n"
1459
 
1460
  report += f"""
 
1551
 
1552
  if language_scores:
1553
  report += "| Language | Best Model Performance | Average Performance | Language Difficulty |\n"
1554
+ report += "|----------|------------------------|--------------------|--------------------|\n"
1555
 
1556
  for lang in sorted(language_scores.keys()):
1557
  # Find best performance for this language across all models
 
1585
  model_name = result["model_name"]
1586
  score = result.get("overall", {}).get("ndcg@10", 0)
1587
 
1588
+ teacher_name, teacher_link = get_teacher_model_info(model_name)
1589
+ teacher_performance[teacher_name] = score
 
 
 
 
 
 
 
 
1590
 
1591
  if teacher_performance:
1592
  best_teacher = max(teacher_performance.items(), key=lambda x: x[1])
 
1616
  - **Evaluation**: Retrieval of correct code for each documentation query
1617
 
1618
  ### Teacher Models Tested
1619
+ - [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) (proven baseline)
1620
+ - [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) (general purpose)
1621
+ - [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) (paraphrase model)
1622
+ - [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base) (code-specialized)
1623
+ - [microsoft/graphcodebert-base](https://huggingface.co/microsoft/graphcodebert-base) (graph-aware code model)
1624
+ - [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct) (instruction model)
1625
+ - [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) (multilingual model)
1626
+ - [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3) (modern embedding model)
1627
+ - [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe) (mixture of experts)
1628
+ - [Qodo/Qodo-Embed-1-1.5B](https://huggingface.co/Qodo/Qodo-Embed-1-1.5B) (code-specialized)
1629
+ - [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT) (ColBERT architecture)
1630
+ - [Linq-AI-Research/Linq-Embed-Mistral](https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral) (Mistral-based)
1631
+ - [BAAI/bge-code-v1](https://huggingface.co/BAAI/bge-code-v1) (code-specialized BGE)
1632
+ - [Salesforce/SFR-Embedding-Code-2B_R](https://huggingface.co/Salesforce/SFR-Embedding-Code-2B_R) (large code model)
1633
 
1634
  ### Distillation Method
1635
  - **Technique**: Model2Vec static embedding generation
 
1652
  logger.info(f"Results exported to {output_file}")
1653
 
1654
 
1655
+ def main(
1656
+ results_dir: str = DEFAULT_EVALUATION_DIR,
1657
+ model_name: str = "code_model2vec_distilled_models",
1658
+ output: str = "REPORT.md",
1659
+ export_csv: str | None = None,
1660
+ ) -> None:
1661
  """Main analysis function."""
1662
+ logger.info("Starting CodeSearchNet Analysis with Integrated Benchmarks")
 
 
 
 
 
 
 
 
 
1663
  logger.info("=" * 60)
1664
 
1665
  # Setup output directories
1666
  output_dir, images_dir, reports_dir = setup_directories()
1667
 
1668
+ # Initialize analyzer with results directory (benchmarks are integrated)
1669
  analyzer = CodeSearchNetAnalyzer(
1670
+ results_dir=results_dir,
1671
+ benchmark_dir=None, # No longer needed - benchmarks are in comprehensive files
1672
  images_dir=images_dir,
1673
  )
1674
 
1675
+ # Load results (this will also load benchmark data from comprehensive files)
1676
  analyzer.load_results()
1677
 
1678
  if not analyzer.results:
 
1687
  if analyzer.benchmark_results:
1688
  analyzer.analyze_benchmark_performance()
1689
  else:
1690
+ logger.warning("No benchmark results found. Models may have been evaluated with --skip-benchmark flag.")
1691
 
1692
  # Generate comprehensive report with benchmark integration
1693
+ logger.info("Generating comprehensive report with integrated benchmark data...")
1694
+ report = analyzer.generate_comprehensive_report(model_name)
1695
 
1696
  # Save report
1697
+ report_path = Path(output)
1698
  with report_path.open("w") as f:
1699
  f.write(report)
1700
 
1701
  # Export CSV if requested
1702
+ if export_csv:
1703
+ analyzer.export_results(export_csv)
1704
 
1705
  # Export benchmark CSV if available
1706
  if analyzer.benchmark_df is not None and not analyzer.benchmark_df.empty:
1707
+ benchmark_csv = report_path.parent / f"{model_name}_benchmark_comparison.csv"
1708
  analyzer.benchmark_df.to_csv(benchmark_csv, index=False)
1709
  logger.info(f"📊 Benchmark comparison saved to: {benchmark_csv}")
1710
 
1711
+ logger.info("✅ CodeSearchNet analysis with integrated benchmarks complete!")
1712
  logger.info(f"📊 Report saved to: {report_path}")
1713
  logger.info(f"🖼️ Charts saved to: {images_dir}")
1714
+ logger.info(f"💾 Source: Comprehensive evaluation files in {results_dir}")
1715
 
1716
 
1717
  if __name__ == "__main__":
 
 
1718
  main()
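
The report generator now resolves teacher names and Hugging Face links through the helpers added above. A short sketch of their expected behaviour; the distilled-model display name below is a hypothetical example:

```python
# Sketch: format_model_with_link and get_teacher_model_info as defined in analyze.py above.
from distiller.analyze import format_model_with_link, get_teacher_model_info

print(format_model_with_link("microsoft/codebert-base"))
# -> [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base)

# "code_model2vec_all_MiniLM_L6_v2" is an assumed distilled-model display name.
teacher, link = get_teacher_model_info("code_model2vec_all_MiniLM_L6_v2")
print(teacher)  # sentence-transformers/all-MiniLM-L6-v2
print(link)     # https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
```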
src/distiller/beam_utils.py CHANGED
@@ -16,6 +16,7 @@ Features:
16
  import json
17
  import logging
18
  import shutil
 
19
  import time
20
  from pathlib import Path
21
  from typing import Any
@@ -204,7 +205,7 @@ class BeamVolumeManager:
204
 
205
 
206
  class BeamCheckpointManager:
207
- """Manager for checkpoint operations on Beam volumes."""
208
 
209
  def __init__(self, volume_manager: BeamVolumeManager, checkpoint_prefix: str = "checkpoints") -> None:
210
  """
@@ -216,14 +217,21 @@ class BeamCheckpointManager:
216
  """
217
  self.volume = volume_manager
218
  self.checkpoint_prefix = checkpoint_prefix
219
- self.checkpoint_dir = self.volume.mount_path / checkpoint_prefix
220
- self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
221
 
222
  def save_checkpoint(self, stage: str, data: dict[str, Any], step: int = 0) -> bool:
223
- """Save checkpoint to volume."""
224
  try:
 
225
  checkpoint_filename = f"{self.checkpoint_prefix}_{stage}_step_{step}.json"
226
- checkpoint_path = self.checkpoint_dir / checkpoint_filename
227
 
228
  with checkpoint_path.open("w") as f:
229
  json.dump(data, f, indent=2, default=str)
@@ -236,10 +244,11 @@ class BeamCheckpointManager:
236
  return False
237
 
238
  def load_checkpoint(self, stage: str, step: int = 0) -> dict[str, Any] | None:
239
- """Load checkpoint from volume."""
240
  try:
 
241
  checkpoint_filename = f"{self.checkpoint_prefix}_{stage}_step_{step}.json"
242
- checkpoint_path = self.checkpoint_dir / checkpoint_filename
243
 
244
  if checkpoint_path.exists():
245
  with checkpoint_path.open("r") as f:
@@ -257,11 +266,13 @@ class BeamCheckpointManager:
257
  def get_latest_checkpoint(self, stage: str) -> tuple[int, dict[str, Any]] | None:
258
  """Get the latest checkpoint for a stage."""
259
  try:
 
 
260
  # Find checkpoint files for this stage
261
  pattern = f"{self.checkpoint_prefix}_{stage}_step_*.json"
262
  stage_checkpoints: list[tuple[int, Path]] = []
263
 
264
- for checkpoint_file in self.checkpoint_dir.glob(pattern):
265
  try:
266
  # Extract step number from filename
267
  step_str = checkpoint_file.stem.replace(f"{self.checkpoint_prefix}_{stage}_step_", "")
@@ -290,11 +301,13 @@ class BeamCheckpointManager:
290
  def cleanup_old_checkpoints(self, stage: str, keep_latest: int = 3) -> list[str]:
291
  """Clean up old checkpoints, keeping only the latest N."""
292
  try:
 
 
293
  # Find checkpoint files for this stage
294
  pattern = f"{self.checkpoint_prefix}_{stage}_step_*.json"
295
  stage_checkpoints: list[tuple[int, Path]] = []
296
 
297
- for checkpoint_file in self.checkpoint_dir.glob(pattern):
298
  try:
299
  step_str = checkpoint_file.stem.replace(f"{self.checkpoint_prefix}_{stage}_step_", "")
300
  step = int(step_str)
@@ -329,29 +342,55 @@ class BeamCheckpointManager:
329
  """List all checkpoints, optionally filtered by stage."""
330
  try:
331
  checkpoints: list[dict[str, Any]] = []
332
- pattern = f"{self.checkpoint_prefix}_*.json"
333
 
334
- for checkpoint_file in self.checkpoint_dir.glob(pattern):
335
- # Parse checkpoint info
336
- name_parts = checkpoint_file.stem.split("_")
337
- if len(name_parts) >= 4:
338
- checkpoint_stage = name_parts[1]
339
- try:
340
- step = int(name_parts[3])
341
- except ValueError:
342
- step = 0
 
 
 
343
 
344
- if stage is None or checkpoint_stage == stage:
345
  stat = checkpoint_file.stat()
346
  checkpoints.append(
347
  {
348
- "stage": checkpoint_stage,
349
  "step": step,
350
  "filename": checkpoint_file.name,
351
  "size": f"{stat.st_size / 1024:.1f}KB",
352
  "modified": time.ctime(stat.st_mtime),
353
  }
354
  )
 
 
 
 
 
355
 
356
  return sorted(checkpoints, key=lambda x: (x["stage"], x["step"]))
357
 
@@ -747,6 +786,389 @@ def example_distillation_workflow() -> None:
747
  logger.info(f"Workspace info: {info}")
748
 
749
 
 
 
 
 
750
  if __name__ == "__main__":
751
  # Example usage
752
  logging.basicConfig(level=logging.INFO)
 
16
  import json
17
  import logging
18
  import shutil
19
+ import subprocess
20
  import time
21
  from pathlib import Path
22
  from typing import Any
 
205
 
206
 
207
  class BeamCheckpointManager:
208
+ """Manager for checkpoint operations on Beam volumes with stage-based organization."""
209
 
210
  def __init__(self, volume_manager: BeamVolumeManager, checkpoint_prefix: str = "checkpoints") -> None:
211
  """
 
217
  """
218
  self.volume = volume_manager
219
  self.checkpoint_prefix = checkpoint_prefix
220
+ self.checkpoint_base_dir = self.volume.mount_path / checkpoint_prefix
221
+ self.checkpoint_base_dir.mkdir(parents=True, exist_ok=True)
222
+
223
+ def _get_stage_dir(self, stage: str) -> Path:
224
+ """Return the stage-specific checkpoint directory, creating it if it does not exist."""
225
+ stage_dir = self.checkpoint_base_dir / stage
226
+ stage_dir.mkdir(parents=True, exist_ok=True)
227
+ return stage_dir
228
 
229
  def save_checkpoint(self, stage: str, data: dict[str, Any], step: int = 0) -> bool:
230
+ """Save a checkpoint to the volume in its stage-specific directory."""
231
  try:
232
+ stage_dir = self._get_stage_dir(stage)
233
  checkpoint_filename = f"{self.checkpoint_prefix}_{stage}_step_{step}.json"
234
+ checkpoint_path = stage_dir / checkpoint_filename
235
 
236
  with checkpoint_path.open("w") as f:
237
  json.dump(data, f, indent=2, default=str)
 
244
  return False
245
 
246
  def load_checkpoint(self, stage: str, step: int = 0) -> dict[str, Any] | None:
247
+ """Load a checkpoint from its stage-specific directory on the volume."""
248
  try:
249
+ stage_dir = self._get_stage_dir(stage)
250
  checkpoint_filename = f"{self.checkpoint_prefix}_{stage}_step_{step}.json"
251
+ checkpoint_path = stage_dir / checkpoint_filename
252
 
253
  if checkpoint_path.exists():
254
  with checkpoint_path.open("r") as f:
 
266
  def get_latest_checkpoint(self, stage: str) -> tuple[int, dict[str, Any]] | None:
267
  """Get the latest checkpoint for a stage."""
268
  try:
269
+ stage_dir = self._get_stage_dir(stage)
270
+
271
  # Find checkpoint files for this stage
272
  pattern = f"{self.checkpoint_prefix}_{stage}_step_*.json"
273
  stage_checkpoints: list[tuple[int, Path]] = []
274
 
275
+ for checkpoint_file in stage_dir.glob(pattern):
276
  try:
277
  # Extract step number from filename
278
  step_str = checkpoint_file.stem.replace(f"{self.checkpoint_prefix}_{stage}_step_", "")
 
301
  def cleanup_old_checkpoints(self, stage: str, keep_latest: int = 3) -> list[str]:
302
  """Clean up old checkpoints, keeping only the latest N."""
303
  try:
304
+ stage_dir = self._get_stage_dir(stage)
305
+
306
  # Find checkpoint files for this stage
307
  pattern = f"{self.checkpoint_prefix}_{stage}_step_*.json"
308
  stage_checkpoints: list[tuple[int, Path]] = []
309
 
310
+ for checkpoint_file in stage_dir.glob(pattern):
311
  try:
312
  step_str = checkpoint_file.stem.replace(f"{self.checkpoint_prefix}_{stage}_step_", "")
313
  step = int(step_str)
 
342
  """List all checkpoints, optionally filtered by stage."""
343
  try:
344
  checkpoints: list[dict[str, Any]] = []
 
345
 
346
+ if stage:
347
+ # List checkpoints for specific stage
348
+ stage_dir = self._get_stage_dir(stage)
349
+ pattern = f"{self.checkpoint_prefix}_{stage}_*.json"
350
+
351
+ for checkpoint_file in stage_dir.glob(pattern):
352
+ name_parts = checkpoint_file.stem.split("_")
353
+ if len(name_parts) >= 4:
354
+ try:
355
+ step = int(name_parts[3])
356
+ except ValueError:
357
+ step = 0
358
 
 
359
  stat = checkpoint_file.stat()
360
  checkpoints.append(
361
  {
362
+ "stage": stage,
363
  "step": step,
364
  "filename": checkpoint_file.name,
365
  "size": f"{stat.st_size / 1024:.1f}KB",
366
  "modified": time.ctime(stat.st_mtime),
367
  }
368
  )
369
+ else:
370
+ # List checkpoints for all stages
371
+ for stage_dir in self.checkpoint_base_dir.iterdir():
372
+ if stage_dir.is_dir():
373
+ stage_name = stage_dir.name
374
+ pattern = f"{self.checkpoint_prefix}_{stage_name}_*.json"
375
+
376
+ for checkpoint_file in stage_dir.glob(pattern):
377
+ name_parts = checkpoint_file.stem.split("_")
378
+ if len(name_parts) >= 4:
379
+ try:
380
+ step = int(name_parts[3])
381
+ except ValueError:
382
+ step = 0
383
+
384
+ stat = checkpoint_file.stat()
385
+ checkpoints.append(
386
+ {
387
+ "stage": stage_name,
388
+ "step": step,
389
+ "filename": checkpoint_file.name,
390
+ "size": f"{stat.st_size / 1024:.1f}KB",
391
+ "modified": time.ctime(stat.st_mtime),
392
+ }
393
+ )
394
 
395
  return sorted(checkpoints, key=lambda x: (x["stage"], x["step"]))
396
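The net effect of this change is a per-stage directory layout on the volume. A minimal sketch of the resulting path convention, assuming the default prefix and the mount path used elsewhere in this repository (illustrative only; it simply mirrors the _get_stage_dir/save_checkpoint logic above):

# Layout: <mount_path>/<prefix>/<stage>/<prefix>_<stage>_step_<step>.json
from pathlib import Path

mount_path = Path("./gte_qwen2_m2v_code")  # assumed volume mount path
checkpoint_prefix = "checkpoints"          # default prefix

def checkpoint_path(stage: str, step: int = 0) -> Path:
    # Create the stage directory on demand, then build the checkpoint filename.
    stage_dir = mount_path / checkpoint_prefix / stage
    stage_dir.mkdir(parents=True, exist_ok=True)
    return stage_dir / f"{checkpoint_prefix}_{stage}_step_{step}.json"

print(checkpoint_path("distillation", step=3))
# gte_qwen2_m2v_code/checkpoints/distillation/checkpoints_distillation_step_3.json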
 
 
786
  logger.info(f"Workspace info: {info}")
787
 
788
 
789
+ def download_evaluation_results_from_beam(
790
+ volume_name: str,
791
+ remote_results_dir: str = "evaluation_results",
792
+ local_results_dir: str = "code_model2vec/evaluation_results",
793
+ ) -> bool:
794
+ """
795
+ Download evaluation result files from Beam volume to local directory using beam cp.
796
+
797
+ Args:
798
+ volume_name: Name of the Beam volume
799
+ remote_results_dir: Directory path in the Beam volume containing results
800
+ local_results_dir: Local directory to download results to
801
+
802
+ Returns:
803
+ True if download successful, False otherwise
804
+ """
805
+ try:
806
+ local_path = Path(local_results_dir)
807
+ local_path.mkdir(parents=True, exist_ok=True)
808
+
809
+ # Use beam cp to download individual JSON files
810
+ remote_path = f"{volume_name}:{remote_results_dir}"
811
+
812
+ # First, list files in the remote directory
813
+ list_cmd = ["beam", "cp", "-r", "--list-only", remote_path]
814
+ try:
815
+ result = subprocess.run(list_cmd, capture_output=True, text=True, check=True) # noqa: S603
816
+ remote_files = [line.strip() for line in result.stdout.split("\n") if line.strip().endswith(".json")]
817
+ except subprocess.CalledProcessError:
818
+ logger.warning(f"Could not list files in {remote_path}")
819
+ remote_files = []
820
+
821
+ # Download each JSON file individually
822
+ downloaded_files = []
823
+ for file_name in remote_files:
824
+ if file_name.endswith(".json"):
825
+ remote_file_path = f"{volume_name}:{remote_results_dir}/{file_name}"
826
+ local_file_path = local_path / file_name
827
+
828
+ try:
829
+ download_cmd = ["beam", "cp", remote_file_path, str(local_file_path)]
830
+ subprocess.run(download_cmd, check=True, capture_output=True) # noqa: S603
831
+ downloaded_files.append(file_name)
832
+ logger.info(f"📥 Downloaded: {file_name}")
833
+
834
+ # Delete the file from Beam volume after successful download
835
+ delete_cmd = ["beam", "rm", remote_file_path]
836
+ try:
837
+ subprocess.run(delete_cmd, check=True, capture_output=True) # noqa: S603
838
+ logger.info(f"🗑️ Deleted from volume: {file_name}")
839
+ except subprocess.CalledProcessError as e:
840
+ logger.warning(f"⚠️ Could not delete {file_name} from volume: {e}")
841
+
842
+ except subprocess.CalledProcessError as e:
843
+ logger.warning(f"⚠️ Failed to download {file_name}: {e}")
844
+
845
+ if downloaded_files:
846
+ logger.info(f"✅ Downloaded {len(downloaded_files)} evaluation result files")
847
+ return True
848
+ logger.info("ℹ️ No new evaluation files to download")
849
+ return True
850
+
851
+ except Exception:
852
+ logger.exception("❌ Error downloading evaluation results from Beam")
853
+ return False
854
+
855
+
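A minimal usage sketch for the helper above; the volume name is an assumption and the directory arguments are just the defaults (note that the helper deletes each file from the volume after a successful download):

from distiller.beam_utils import download_evaluation_results_from_beam  # assumed import path

ok = download_evaluation_results_from_beam(
    volume_name="gte_qwen2_m2v_code",  # assumed volume name
    remote_results_dir="evaluation_results",
    local_results_dir="code_model2vec/evaluation_results",
)
if not ok:
    print("Download failed; check that the beam CLI is installed and authenticated.")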
856
+ def download_specific_evaluation_file(
857
+ volume_name: str,
858
+ model_name: str,
859
+ remote_results_dir: str = "evaluation_results",
860
+ local_results_dir: str = "code_model2vec/evaluation_results",
861
+ file_prefix: str = "codesearchnet_eval",
862
+ ) -> bool:
863
+ """
864
+ Download a specific evaluation or benchmark result file from Beam volume.
865
+
866
+ Args:
867
+ volume_name: Name of the Beam volume
868
+ model_name: Name of the model whose results to download
869
+ remote_results_dir: Directory path in the Beam volume containing results
870
+ local_results_dir: Local directory to download results to
871
+ file_prefix: Prefix for the file (e.g., 'codesearchnet_eval', 'benchmark')
872
+
873
+ Returns:
874
+ True if download successful, False otherwise
875
+ """
876
+ try:
877
+ local_path = Path(local_results_dir)
878
+ local_path.mkdir(parents=True, exist_ok=True)
879
+
880
+ # Generate filename following the pattern
881
+ safe_model_name = model_name.replace("/", "_")
882
+ filename = f"{file_prefix}_{safe_model_name}.json"
883
+
884
+ remote_file_path = f"{volume_name}:{remote_results_dir}/{filename}"
885
+ local_file_path = local_path / filename
886
+
887
+ # Download the specific file
888
+ download_cmd = ["beam", "cp", remote_file_path, str(local_file_path)]
889
+ subprocess.run(download_cmd, check=True, capture_output=True) # noqa: S603
890
+
891
+ logger.info(f"📥 Downloaded {file_prefix} results for {model_name}")
892
+
893
+ # Delete the file from Beam volume after successful download
894
+ delete_cmd = ["beam", "rm", remote_file_path]
895
+ try:
896
+ subprocess.run(delete_cmd, check=True, capture_output=True) # noqa: S603
897
+ logger.info(f"🗑️ Deleted {file_prefix} results for {model_name} from volume")
898
+ except subprocess.CalledProcessError as e:
899
+ logger.warning(f"⚠️ Could not delete {filename} from volume: {e}")
900
+
901
+ return True
902
+
903
+ except subprocess.CalledProcessError:
904
+ logger.warning(f"⚠️ No {file_prefix} results found for {model_name} on Beam")
905
+ return False
906
+ except Exception:
907
+ logger.exception(f"❌ Error downloading {file_prefix} results for {model_name}")
908
+ return False
909
+
910
+
911
+ def download_model_from_beam(
912
+ volume_name: str,
913
+ model_name: str,
914
+ local_dir: str,
915
+ ) -> bool:
916
+ """
917
+ Download a model from Beam volume to local directory.
918
+
919
+ Args:
920
+ volume_name: Name of the Beam volume
921
+ model_name: Name of the model to download
922
+ local_dir: Local directory to download model to
923
+
924
+ Returns:
925
+ True if download successful, False otherwise
926
+ """
927
+ try:
928
+ local_path = Path(local_dir)
929
+ local_path.mkdir(parents=True, exist_ok=True)
930
+
931
+ # Use beam cp to download the model directory
932
+ remote_path = f"{volume_name}:models/{model_name}"
933
+ local_model_path = local_path / model_name
934
+
935
+ download_cmd = ["beam", "cp", "-r", remote_path, str(local_model_path)]
936
+ subprocess.run(download_cmd, check=True, capture_output=True) # noqa: S603
937
+
938
+ logger.info(f"📥 Downloaded model {model_name} from Beam to {local_dir}")
939
+ return True
940
+
941
+ except subprocess.CalledProcessError as e:
942
+ logger.warning(f"⚠️ Failed to download model {model_name} from Beam: {e}")
943
+ return False
944
+ except Exception:
945
+ logger.exception(f"❌ Error downloading model {model_name} from Beam")
946
+ return False
947
+
948
+
949
+ def upload_model_to_beam(
950
+ volume_name: str,
951
+ model_name: str,
952
+ local_dir: str,
953
+ ) -> bool:
954
+ """
955
+ Upload a model from local directory to Beam volume.
956
+
957
+ Args:
958
+ volume_name: Name of the Beam volume
959
+ model_name: Name for the model on Beam
960
+ local_dir: Local directory containing the model
961
+
962
+ Returns:
963
+ True if upload successful, False otherwise
964
+ """
965
+ try:
966
+ local_path = Path(local_dir)
967
+ if not local_path.exists():
968
+ logger.error(f"❌ Local model directory does not exist: {local_dir}")
969
+ return False
970
+
971
+ # Use beam cp to upload the model directory
972
+ remote_path = f"{volume_name}:models/{model_name}"
973
+
974
+ upload_cmd = ["beam", "cp", "-r", str(local_path), remote_path]
975
+ subprocess.run(upload_cmd, check=True, capture_output=True) # noqa: S603
976
+
977
+ logger.info(f"📤 Uploaded model {model_name} to Beam from {local_dir}")
978
+ return True
979
+
980
+ except subprocess.CalledProcessError as e:
981
+ logger.warning(f"⚠️ Failed to upload model {model_name} to Beam: {e}")
982
+ return False
983
+ except Exception:
984
+ logger.exception(f"❌ Error uploading model {model_name} to Beam")
985
+ return False
986
+
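Together with download_model_from_beam above, this gives a simple push/pull workflow for model directories stored under models/<name> on the volume. A hedged round-trip sketch (volume name, model name, and local paths are illustrative):

from distiller.beam_utils import download_model_from_beam, upload_model_to_beam  # assumed import path

VOLUME = "gte_qwen2_m2v_code"  # assumed volume name
# Push a locally distilled model to the volume...
upload_model_to_beam(VOLUME, "code_model2vec_example", "code_model2vec/final/code_model2vec_example")
# ...and pull it back elsewhere (e.g., on another machine or in a later job).
download_model_from_beam(VOLUME, "code_model2vec_example", "downloaded_models")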
987
+
988
+ def download_checkpoints_from_beam(
989
+ volume_name: str,
990
+ stage: str | None = None,
991
+ remote_checkpoints_dir: str = "checkpoints",
992
+ local_checkpoints_dir: str = "code_model2vec/checkpoints",
993
+ ) -> bool:
994
+ """
995
+ Download checkpoint files from Beam volume to local directory.
996
+
997
+ Args:
998
+ volume_name: Name of the Beam volume
999
+ stage: Specific stage to download (e.g., 'distillation', 'training'), or None for all
1000
+ remote_checkpoints_dir: Directory path in the Beam volume containing checkpoints
1001
+ local_checkpoints_dir: Local directory to download checkpoints to
1002
+
1003
+ Returns:
1004
+ True if download successful, False otherwise
1005
+ """
1006
+ try:
1007
+ local_path = Path(local_checkpoints_dir)
1008
+ local_path.mkdir(parents=True, exist_ok=True)
1009
+
1010
+ # Pre-create the local stage directory when a specific stage is requested

1011
+ if stage:
1012
+ local_stage_dir = local_path / stage
1013
+ local_stage_dir.mkdir(parents=True, exist_ok=True)
1014
+ else:
1015
+ pass
1016
+
1017
+ # Use beam cp to download checkpoint files
1018
+ remote_path = f"{volume_name}:{remote_checkpoints_dir}"
1019
+
1020
+ # First, try to list files
1021
+ list_cmd = ["beam", "cp", "-r", "--list-only", remote_path]
1022
+ try:
1023
+ result = subprocess.run(list_cmd, capture_output=True, text=True, check=True) # noqa: S603
1024
+ remote_files = [
1025
+ line.strip()
1026
+ for line in result.stdout.split("\n")
1027
+ if line.strip().endswith(".json") and "checkpoints_" in line.strip()
1028
+ ]
1029
+ except subprocess.CalledProcessError:
1030
+ logger.warning(f"Could not list checkpoint files in {remote_path}")
1031
+ remote_files = []
1032
+
1033
+ # Filter by stage if specified
1034
+ if stage:
1035
+ remote_files = [f for f in remote_files if f"checkpoints_{stage}_" in f]
1036
+
1037
+ # Download each checkpoint file
1038
+ downloaded_files = []
1039
+ for file_name in remote_files:
1040
+ remote_file_path = f"{volume_name}:{remote_checkpoints_dir}/{file_name}"
1041
+
1042
+ # Determine local subdirectory based on checkpoint stage
1043
+ file_stage = file_name.split("_")[1] if "_" in file_name else "unknown"
1044
+ local_stage_dir = local_path / file_stage
1045
+ local_stage_dir.mkdir(parents=True, exist_ok=True)
1046
+ local_file_path = local_stage_dir / file_name
1047
+
1048
+ try:
1049
+ download_cmd = ["beam", "cp", remote_file_path, str(local_file_path)]
1050
+ subprocess.run(download_cmd, check=True, capture_output=True) # noqa: S603
1051
+ downloaded_files.append(file_name)
1052
+ logger.info(f"📥 Downloaded checkpoint: {file_name}")
1053
+
1054
+ except subprocess.CalledProcessError as e:
1055
+ logger.warning(f"⚠️ Failed to download checkpoint {file_name}: {e}")
1056
+
1057
+ if downloaded_files:
1058
+ logger.info(f"✅ Downloaded {len(downloaded_files)} checkpoint files")
1059
+ return True
1060
+ logger.info("ℹ️ No new checkpoint files to download")
1061
+ return True
1062
+
1063
+ except Exception:
1064
+ logger.exception("❌ Error downloading checkpoints from Beam")
1065
+ return False
1066
+
1067
+
1068
+ def upload_checkpoints_to_beam(
1069
+ volume_name: str,
1070
+ stage: str | None = None,
1071
+ local_checkpoints_dir: str = "code_model2vec/checkpoints",
1072
+ remote_checkpoints_dir: str = "checkpoints",
1073
+ ) -> bool:
1074
+ """
1075
+ Upload checkpoint files from local directory to Beam volume.
1076
+
1077
+ Args:
1078
+ volume_name: Name of the Beam volume
1079
+ stage: Specific stage to upload (e.g., 'distillation', 'training'), or None for all
1080
+ local_checkpoints_dir: Local directory containing checkpoints
1081
+ remote_checkpoints_dir: Directory path in the Beam volume to store checkpoints
1082
+
1083
+ Returns:
1084
+ True if upload successful, False otherwise
1085
+ """
1086
+ try:
1087
+ local_path = Path(local_checkpoints_dir)
1088
+ if not local_path.exists():
1089
+ logger.warning(f"⚠️ Local checkpoints directory does not exist: {local_checkpoints_dir}")
1090
+ return True # Not an error - no checkpoints to upload
1091
+
1092
+ # Find checkpoint files to upload
1093
+ if stage:
1094
+ # Look in the stage subdirectory
1095
+ stage_dir = local_path / stage
1096
+ checkpoint_files = list(stage_dir.glob(f"checkpoints_{stage}_*.json")) if stage_dir.exists() else []
1097
+ else:
1098
+ # Look for all checkpoint files in all subdirectories
1099
+ checkpoint_files = []
1100
+ for subdir in local_path.iterdir():
1101
+ if subdir.is_dir():
1102
+ checkpoint_files.extend(subdir.glob("checkpoints_*.json"))
1103
+
1104
+ if not checkpoint_files:
1105
+ logger.info(f"ℹ️ No checkpoint files found to upload for stage: {stage or 'all'}")
1106
+ return True
1107
+
1108
+ # Upload each checkpoint file
1109
+ uploaded_files = []
1110
+ for checkpoint_file in checkpoint_files:
1111
+ remote_file_path = f"{volume_name}:{remote_checkpoints_dir}/{checkpoint_file.name}"
1112
+
1113
+ try:
1114
+ upload_cmd = ["beam", "cp", str(checkpoint_file), remote_file_path]
1115
+ subprocess.run(upload_cmd, check=True, capture_output=True) # noqa: S603
1116
+ uploaded_files.append(checkpoint_file.name)
1117
+ logger.info(f"📤 Uploaded checkpoint: {checkpoint_file.name}")
1118
+
1119
+ except subprocess.CalledProcessError as e:
1120
+ logger.warning(f"⚠️ Failed to upload checkpoint {checkpoint_file.name}: {e}")
1121
+
1122
+ if uploaded_files:
1123
+ logger.info(f"✅ Uploaded {len(uploaded_files)} checkpoint files")
1124
+ return True
1125
+ return False
1126
+
1127
+ except Exception:
1128
+ logger.exception("❌ Error uploading checkpoints to Beam")
1129
+ return False
1130
+
1131
+
1132
+ def sync_checkpoints_from_beam(
1133
+ volume_name: str,
1134
+ stage: str,
1135
+ local_checkpoints_dir: str = "code_model2vec/checkpoints",
1136
+ ) -> bool:
1137
+ """
1138
+ Sync specific stage checkpoints from Beam to local directory.
1139
+
1140
+ Args:
1141
+ volume_name: Name of the Beam volume
1142
+ stage: Stage to sync (e.g., 'distillation', 'training')
1143
+ local_checkpoints_dir: Local directory for checkpoints
1144
+
1145
+ Returns:
1146
+ True if sync successful, False otherwise
1147
+ """
1148
+ logger.info(f"🔄 Syncing {stage} checkpoints from Beam...")
1149
+ return download_checkpoints_from_beam(volume_name, stage, "checkpoints", local_checkpoints_dir)
1150
+
1151
+
1152
+ def sync_checkpoints_to_beam(
1153
+ volume_name: str,
1154
+ stage: str,
1155
+ local_checkpoints_dir: str = "code_model2vec/checkpoints",
1156
+ ) -> bool:
1157
+ """
1158
+ Sync specific stage checkpoints from local directory to Beam.
1159
+
1160
+ Args:
1161
+ volume_name: Name of the Beam volume
1162
+ stage: Stage to sync (e.g., 'distillation', 'training')
1163
+ local_checkpoints_dir: Local directory containing checkpoints
1164
+
1165
+ Returns:
1166
+ True if sync successful, False otherwise
1167
+ """
1168
+ logger.info(f"🔄 Syncing {stage} checkpoints to Beam...")
1169
+ return upload_checkpoints_to_beam(volume_name, stage, local_checkpoints_dir, "checkpoints")
1170
+
1171
+
1172
  if __name__ == "__main__":
1173
  # Example usage
1174
  logging.basicConfig(level=logging.INFO)
src/distiller/benchmark.py DELETED
@@ -1,1181 +0,0 @@
1
- """
2
- Operational Performance Benchmarking for Embedding Models.
3
-
4
- This module benchmarks embedding models on operational metrics like:
5
- - Inference speed (latency and throughput)
6
- - Memory efficiency (RAM and GPU usage)
7
- - Model size and storage requirements
8
- - Scalability with batch size
9
- - CPU vs GPU performance
10
- """
11
-
12
- import gc
13
- import json
14
- import logging
15
- import os
16
- import time
17
- from pathlib import Path
18
- from typing import Any
19
-
20
- import pandas as pd
21
- import psutil
22
- import torch
23
- from beam import GpuType, Image, Volume, function
24
- from sentence_transformers import SentenceTransformer
25
-
26
- from .beam_utils import (
27
- BeamCheckpointManager,
28
- BeamEvaluationManager,
29
- create_beam_utilities,
30
- )
31
-
32
- logger = logging.getLogger(__name__)
33
-
34
- # =============================================================================
35
- # BEAM CONFIGURATION
36
- # =============================================================================
37
-
38
- GPU_NAME = GpuType.A100_40
39
- VOLUME_NAME = "gte_qwen2_m2v_code" # Same volume as distill.py and evaluate.py
40
- VOLUME_PATH = "./gte_qwen2_m2v_code" # Same mount path as distill.py and evaluate.py
41
- BENCHMARK_RESULTS_DIR = "benchmark_results" # Subdirectory within volume
42
- BENCHMARK_CACHE_DIR = "benchmark_cache" # Cache for models
43
-
44
- IMAGE = Image(python_version="python3.12").add_python_packages(
45
- [
46
- "torch>=2.7.0",
47
- "transformers>=4.40.0",
48
- "datasets>=3.2.0",
49
- "sentence-transformers>=4.1.0",
50
- "model2vec[train]>=0.5.0",
51
- "numpy>=1.26.4",
52
- "scikit-learn>=1.6.1",
53
- "pandas>=2.0.0",
54
- "tqdm>=4.65.0",
55
- "psutil>=5.9.0",
56
- ]
57
- )
58
-
59
- # =============================================================================
60
- # CONFIGURATION
61
- # =============================================================================
62
-
63
- DEFAULT_OUTPUT_DIR = "benchmark_results" # Local fallback directory
64
-
65
- # Default models to benchmark (can be overridden via command line)
66
- DEFAULT_BENCHMARK_MODELS = [
67
- # Your distilled model (local files in Beam volume root)
68
- "gte_qwen2_m2v_code", # This will be resolved to VOLUME_PATH in Beam
69
- # Established Code Models
70
- "sentence-transformers/all-MiniLM-L6-v2",
71
- "microsoft/codebert-base",
72
- "microsoft/graphcodebert-base",
73
- "huggingface/CodeBERTa-small-v1",
74
- "sentence-transformers/all-mpnet-base-v2",
75
- "sentence-transformers/all-MiniLM-L12-v2",
76
- # Model2Vec & Efficiency Models (Direct Competitors)
77
- "minishlab/potion-base-8M",
78
- "minishlab/potion-retrieval-32M",
79
- # Small Transformer-Based Code Models
80
- "Salesforce/codet5-base",
81
- ]
82
-
83
- # =============================================================================
84
- # CHECKPOINT CONFIGURATION
85
- # =============================================================================
86
-
87
- # Prevent conflicts with other modules by using unique prefixes
88
- BENCHMARK_CHECKPOINT_PREFIX = "benchmark_checkpoints"
89
- MODEL_CACHE_PREFIX = "model_cache"
90
-
91
- # Sample texts for benchmarking (various lengths)
92
- BENCHMARK_TEXTS = {
93
- "short": [
94
- "def add(a, b): return a + b",
95
- "function multiply(x, y) { return x * y; }",
96
- "class Calculator { public int subtract(int a, int b) { return a - b; } }",
97
- ]
98
- * 100, # 300 short texts
99
- "medium": [
100
- "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
101
- "function quickSort(arr) {\n if (arr.length <= 1) return arr;\n const pivot = arr[arr.length - 1];\n const left = [], right = [];\n for (let i = 0; i < arr.length - 1; i++) {\n if (arr[i] < pivot) left.push(arr[i]);\n else right.push(arr[i]);\n }\n return [...quickSort(left), pivot, ...quickSort(right)];\n}",
102
- ]
103
- * 50, # 100 medium texts
104
- "long": [
105
- """
106
- def complex_algorithm(data, config):
107
- '''
108
- Complex data processing algorithm with multiple steps.
109
-
110
- Args:
111
- data: Input data structure
112
- config: Configuration parameters
113
-
114
- Returns:
115
- Processed results
116
- '''
117
- results = []
118
-
119
- # Step 1: Data validation
120
- if not isinstance(data, (list, tuple)):
121
- raise ValueError("Data must be list or tuple")
122
-
123
- # Step 2: Preprocessing
124
- processed_data = []
125
- for item in data:
126
- if config.get('normalize', False):
127
- item = normalize_item(item)
128
- if config.get('filter', False):
129
- if not filter_item(item, config['filter_criteria']):
130
- continue
131
- processed_data.append(item)
132
-
133
- # Step 3: Main processing
134
- for item in processed_data:
135
- result = process_item(item, config)
136
- if result is not None:
137
- results.append(result)
138
-
139
- # Step 4: Post-processing
140
- if config.get('sort', False):
141
- results.sort(key=lambda x: x.get('score', 0), reverse=True)
142
-
143
- return results
144
- """.strip(),
145
- ]
146
- * 20, # 20 long texts
147
- }
148
-
149
-
150
- class PerformanceBenchmark:
151
- """Comprehensive performance benchmarking for embedding models."""
152
-
153
- def __init__(
154
- self,
155
- model_path: str,
156
- model_name: str | None = None,
157
- checkpoint_manager: BeamCheckpointManager | None = None,
158
- eval_manager: BeamEvaluationManager | None = None,
159
- ) -> None:
160
- """Initialize benchmarker with model and optional Beam utilities."""
161
- self.model_path = model_path
162
- self.model_name = model_name or Path(model_path).name
163
- self.model: SentenceTransformer | None = None
164
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
165
- self.results: dict[str, Any] = {}
166
- self.checkpoint_manager = checkpoint_manager
167
- self.eval_manager = eval_manager
168
-
169
- def load_model(self) -> None:
170
- """Load the embedding model."""
171
- logger.info(f"Loading model from {self.model_path}")
172
- start_time = time.time()
173
-
174
- try:
175
- self.model = SentenceTransformer(self.model_path, device=self.device, trust_remote_code=True)
176
- load_time = time.time() - start_time
177
-
178
- logger.info(f"✅ Model loaded in {load_time:.2f}s on {self.device}")
179
- self.results["model_load_time"] = load_time
180
-
181
- except Exception:
182
- logger.exception("❌ Failed to load model")
183
- raise
184
-
185
- def measure_model_size(self) -> dict[str, float]:
186
- """Measure model size metrics."""
187
- logger.info("📏 Measuring model size...")
188
-
189
- size_metrics = {}
190
-
191
- # Disk size - handle both local paths and HuggingFace models
192
- try:
193
- if Path(self.model_path).is_dir():
194
- # Local directory - calculate size of model files only
195
- model_extensions = {".safetensors", ".bin", ".json", ".txt", ".tokenizer"}
196
- total_size = 0
197
- model_dir = Path(self.model_path)
198
-
199
- for file_path in model_dir.rglob("*"):
200
- if file_path.is_file() and (
201
- file_path.suffix.lower() in model_extensions
202
- or file_path.name.lower() in {"config.json", "tokenizer.json", "modules.json", "README.md"}
203
- ):
204
- total_size += file_path.stat().st_size
205
-
206
- size_metrics["disk_size_mb"] = total_size / (1024 * 1024)
207
- elif Path(self.model_path).is_file():
208
- # Single file
209
- total_size = Path(self.model_path).stat().st_size
210
- size_metrics["disk_size_mb"] = total_size / (1024 * 1024)
211
- else:
212
- # HuggingFace model - estimate from cache if available
213
- from transformers import AutoConfig
214
-
215
- try:
216
- config = AutoConfig.from_pretrained(self.model_path)
217
- # Estimate size based on parameters (rough approximation)
218
- if hasattr(config, "hidden_size") and hasattr(config, "num_hidden_layers"):
219
- # Rough estimation for transformer models
220
- estimated_params = config.hidden_size * config.num_hidden_layers * 1000 # Very rough
221
- size_metrics["disk_size_mb"] = estimated_params * 4 / (1024 * 1024) # 4 bytes per float32
222
- else:
223
- size_metrics["disk_size_mb"] = 0 # Unknown
224
- except Exception:
225
- logger.warning(f"Could not determine disk size for HuggingFace model: {self.model_path}")
226
- size_metrics["disk_size_mb"] = 0 # Unknown
227
- except Exception as e:
228
- logger.warning(f"Could not determine disk size: {e}")
229
- size_metrics["disk_size_mb"] = 0
230
-
231
- # Model parameters (if accessible)
232
- try:
233
- if self.model is not None and hasattr(self.model, "modules"):
234
- total_params = sum(p.numel() for p in self.model.parameters())
235
- size_metrics["parameters_millions"] = total_params / 1_000_000
236
-
237
- # Try to get embedding dimension from model config
238
- try:
239
- # Use the public modules() method instead of private _modules
240
- modules = list(self.model.modules())
241
- if len(modules) > 1: # modules[0] is usually the entire model, modules[1] is first submodule
242
- first_module = modules[1]
243
- if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "config"):
244
- config = first_module.auto_model.config
245
- if hasattr(config, "hidden_size"):
246
- size_metrics["embedding_dim"] = config.hidden_size
247
- elif hasattr(config, "model_dim"):
248
- size_metrics["embedding_dim"] = config.model_dim
249
- except Exception as e:
250
- logger.debug(
251
- f"Could not extract embedding dimension from model config: {e}"
252
- ) # Silently continue if this method fails
253
-
254
- # For Model2Vec static models
255
- elif self.model is not None and hasattr(self.model, "embedding"):
256
- # Handle both tensor and numpy array embeddings
257
- embedding = self.model.embedding
258
- if hasattr(embedding, "shape"):
259
- vocab_size, embedding_dim = embedding.shape # type: ignore[misc]
260
- total_params = vocab_size * embedding_dim
261
- size_metrics["parameters_millions"] = total_params / 1_000_000
262
- size_metrics["vocab_size"] = vocab_size
263
- size_metrics["embedding_dim"] = embedding_dim
264
- else:
265
- logger.warning("Could not determine embedding shape for Model2Vec model")
266
-
267
- # Alternative method: get embedding dimension from a test encoding
268
- if "embedding_dim" not in size_metrics and self.model is not None:
269
- try:
270
- test_embedding = self.model.encode(["test"], convert_to_tensor=False)
271
- if hasattr(test_embedding, "shape") and len(test_embedding.shape) >= 2:
272
- size_metrics["embedding_dim"] = test_embedding.shape[1]
273
- elif (
274
- isinstance(test_embedding, (list, tuple))
275
- and len(test_embedding) > 0
276
- and hasattr(test_embedding[0], "__len__")
277
- ):
278
- size_metrics["embedding_dim"] = len(test_embedding[0])
279
- except Exception as e:
280
- logger.warning(f"Could not determine embedding dimension: {e}")
281
-
282
- except Exception as e:
283
- logger.warning(f"Could not determine parameter count: {e}")
284
-
285
- # Memory footprint
286
- if self.device == "cuda" and torch.cuda.is_available():
287
- torch.cuda.empty_cache()
288
- size_metrics["gpu_memory_mb"] = torch.cuda.memory_allocated() / (1024 * 1024)
289
-
290
- # RAM usage (approximate)
291
- process = psutil.Process(os.getpid())
292
- size_metrics["ram_usage_mb"] = process.memory_info().rss / (1024 * 1024)
293
-
294
- self.results["size_metrics"] = size_metrics
295
- return size_metrics
296
-
297
- def benchmark_inference_speed(self, batch_sizes: list[int] | None = None) -> dict[str, Any]:
298
- """Benchmark inference speed across different batch sizes."""
299
- if batch_sizes is None:
300
- batch_sizes = [1, 8, 16, 32, 64, 128]
301
- logger.info("⚡ Benchmarking inference speed...")
302
-
303
- if self.model is None:
304
- self.load_model()
305
-
306
- if self.model is None:
307
- msg = "Failed to load model"
308
- raise RuntimeError(msg)
309
-
310
- speed_results = {}
311
- text_lengths = ["short", "medium", "long"]
312
-
313
- for text_length in text_lengths:
314
- logger.info(f" 📝 Testing {text_length} texts...")
315
- texts = BENCHMARK_TEXTS[text_length]
316
-
317
- length_results = {}
318
-
319
- for batch_size in batch_sizes:
320
- if batch_size > len(texts):
321
- continue
322
-
323
- logger.info(f" 🔄 Batch size: {batch_size}")
324
-
325
- # Prepare batch
326
- batch_texts = texts[:batch_size]
327
-
328
- # Warmup
329
- if self.device == "cuda":
330
- torch.cuda.synchronize()
331
- _ = self.model.encode(batch_texts[: min(2, batch_size)], convert_to_tensor=False)
332
-
333
- # Clear cache
334
- if self.device == "cuda":
335
- torch.cuda.empty_cache()
336
- torch.cuda.synchronize()
337
-
338
- # Measure inference time
339
- start_time = time.perf_counter()
340
-
341
- embeddings = self.model.encode(batch_texts, convert_to_tensor=False, show_progress_bar=False)
342
-
343
- if self.device == "cuda":
344
- torch.cuda.synchronize()
345
-
346
- end_time = time.perf_counter()
347
-
348
- # Calculate metrics
349
- total_time = end_time - start_time
350
- time_per_text = total_time / batch_size
351
- texts_per_second = batch_size / total_time
352
-
353
- # Estimate tokens (rough approximation)
354
- avg_tokens = sum(len(text.split()) for text in batch_texts) / batch_size
355
- total_tokens = avg_tokens * batch_size
356
- tokens_per_second = total_tokens / total_time
357
-
358
- length_results[f"batch_{batch_size}"] = {
359
- "total_time_ms": total_time * 1000,
360
- "time_per_text_ms": time_per_text * 1000,
361
- "texts_per_second": texts_per_second,
362
- "tokens_per_second": tokens_per_second,
363
- "avg_tokens_per_text": avg_tokens,
364
- "embedding_shape": embeddings.shape
365
- if hasattr(embeddings, "shape")
366
- else f"({len(embeddings)}, {len(embeddings[0]) if embeddings else 0})",
367
- }
368
-
369
- speed_results[text_length] = length_results
370
-
371
- self.results["speed_benchmarks"] = speed_results
372
- return speed_results
373
-
374
- def benchmark_memory_scaling(self, batch_sizes: list[int] | None = None) -> dict[str, Any]:
375
- """Benchmark memory usage across batch sizes."""
376
- if batch_sizes is None:
377
- batch_sizes = [1, 8, 16, 32, 64, 128, 256]
378
- logger.info("💾 Benchmarking memory scaling...")
379
-
380
- if self.model is None:
381
- self.load_model()
382
-
383
- if self.model is None:
384
- msg = "Failed to load model"
385
- raise RuntimeError(msg)
386
-
387
- memory_results: dict[str, Any] = {}
388
- texts = BENCHMARK_TEXTS["medium"]
389
-
390
- baseline_memory = 0
391
- if self.device == "cuda":
392
- torch.cuda.empty_cache()
393
- baseline_memory = torch.cuda.memory_allocated()
394
-
395
- for batch_size in batch_sizes:
396
- if batch_size > len(texts):
397
- continue
398
-
399
- logger.info(f" 📊 Testing batch size: {batch_size}")
400
-
401
- # Clear cache
402
- if self.device == "cuda":
403
- torch.cuda.empty_cache()
404
- gc.collect()
405
-
406
- batch_texts = texts[:batch_size]
407
-
408
- # Measure memory before
409
- if self.device == "cuda":
410
- torch.cuda.memory_allocated()
411
-
412
- # Run inference
413
- try:
414
- embeddings = self.model.encode(
415
- batch_texts,
416
- convert_to_tensor=self.device == "cuda",
417
- show_progress_bar=False,
418
- )
419
-
420
- # Measure memory after
421
- memory_after = 0
422
- if self.device == "cuda":
423
- memory_after = torch.cuda.max_memory_allocated()
424
- torch.cuda.reset_peak_memory_stats()
425
-
426
- memory_used_mb = (memory_after - baseline_memory) / (1024 * 1024)
427
- memory_per_text_mb = memory_used_mb / batch_size if batch_size > 0 else 0
428
-
429
- memory_results[f"batch_{batch_size}"] = {
430
- "memory_used_mb": memory_used_mb,
431
- "memory_per_text_mb": memory_per_text_mb,
432
- "baseline_memory_mb": baseline_memory / (1024 * 1024),
433
- "peak_memory_mb": memory_after / (1024 * 1024),
434
- }
435
-
436
- # Clean up
437
- del embeddings
438
-
439
- except torch.cuda.OutOfMemoryError:
440
- logger.warning(f"❌ OOM at batch size {batch_size}")
441
- memory_results[f"batch_{batch_size}"] = {"oom": True}
442
- break
443
- except Exception as e:
444
- logger.warning(f"❌ Error at batch size {batch_size}: {e}")
445
- memory_results[f"batch_{batch_size}"] = {"error": str(e)}
446
-
447
- self.results["memory_benchmarks"] = memory_results
448
- return memory_results
449
-
450
- def benchmark_cpu_vs_gpu(self) -> dict[str, Any]:
451
- """Compare CPU vs GPU performance."""
452
- logger.info("🖥️ Benchmarking CPU vs GPU performance...")
453
-
454
- comparison_results = {}
455
- test_texts = BENCHMARK_TEXTS["medium"][:32] # Fixed batch size
456
-
457
- devices = ["cpu"]
458
- if torch.cuda.is_available():
459
- devices.append("cuda")
460
-
461
- for device in devices:
462
- logger.info(f" 🔄 Testing on {device}")
463
-
464
- # Load model on device
465
- try:
466
- model = SentenceTransformer(self.model_path, device=device)
467
-
468
- # Warmup
469
- _ = model.encode(test_texts[:2], convert_to_tensor=False)
470
-
471
- # Benchmark
472
- start_time = time.perf_counter()
473
- embeddings = model.encode(test_texts, convert_to_tensor=False, show_progress_bar=False)
474
- end_time = time.perf_counter()
475
-
476
- total_time = end_time - start_time
477
-
478
- comparison_results[device] = {
479
- "total_time_ms": total_time * 1000,
480
- "texts_per_second": len(test_texts) / total_time,
481
- "time_per_text_ms": (total_time / len(test_texts)) * 1000,
482
- "embedding_shape": embeddings.shape
483
- if hasattr(embeddings, "shape")
484
- else f"({len(embeddings)}, {len(embeddings[0]) if embeddings else 0})",
485
- }
486
-
487
- del model
488
- if device == "cuda":
489
- torch.cuda.empty_cache()
490
-
491
- except Exception as e:
492
- logger.warning(f"❌ Failed on {device}: {e}")
493
- comparison_results[device] = {"error": str(e)}
494
-
495
- self.results["cpu_vs_gpu"] = comparison_results
496
- return comparison_results
497
-
498
- def run_comprehensive_benchmark(self) -> dict[str, Any]:
499
- """Run all benchmarks and return comprehensive results."""
500
- logger.info(f"🚀 Starting comprehensive benchmark for {self.model_name}")
501
-
502
- # Load model
503
- self.load_model()
504
-
505
- # Run all benchmarks
506
- self.measure_model_size()
507
- self.benchmark_inference_speed()
508
- self.benchmark_memory_scaling()
509
- self.benchmark_cpu_vs_gpu()
510
-
511
- # Add metadata
512
- self.results["model_name"] = self.model_name
513
- self.results["model_path"] = self.model_path
514
- self.results["device"] = self.device
515
- self.results["torch_version"] = torch.__version__
516
- self.results["cuda_available"] = torch.cuda.is_available()
517
-
518
- if torch.cuda.is_available():
519
- self.results["gpu_name"] = torch.cuda.get_device_name(0)
520
- self.results["gpu_memory_gb"] = torch.cuda.get_device_properties(0).total_memory / (1024**3)
521
-
522
- # System info
523
- self.results["cpu_count"] = psutil.cpu_count()
524
- self.results["ram_gb"] = psutil.virtual_memory().total / (1024**3)
525
-
526
- logger.info("✅ Comprehensive benchmark completed!")
527
- return self.results
528
-
529
- def save_results(self, output_file: str) -> None:
530
- """Save benchmark results to JSON file."""
531
- output_path = Path(output_file)
532
- output_path.parent.mkdir(parents=True, exist_ok=True)
533
-
534
- with output_path.open("w") as f:
535
- json.dump(self.results, f, indent=2, default=str)
536
-
537
- logger.info(f"📄 Results saved to {output_path}")
538
-
539
- def print_summary(self) -> None:
540
- """Print a summary of benchmark results."""
541
- if not self.results:
542
- logger.warning("No results to summarize")
543
- return
544
-
545
- print(f"\n{'=' * 60}")
546
- print(f"Performance Benchmark Summary: {self.model_name}")
547
- print(f"{'=' * 60}")
548
-
549
- # Model size
550
- if "size_metrics" in self.results:
551
- size = self.results["size_metrics"]
552
- print("\n📏 Model Size:")
553
- print(f" Disk Size: {size.get('disk_size_mb', 0):.1f} MB")
554
- if "parameters_millions" in size:
555
- print(f" Parameters: {size['parameters_millions']:.1f}M")
556
- if "embedding_dim" in size:
557
- print(f" Embedding Dim: {size['embedding_dim']}")
558
-
559
- # Speed summary
560
- if "speed_benchmarks" in self.results:
561
- speed = self.results["speed_benchmarks"]
562
- print("\n⚡ Speed (medium texts, batch 32):")
563
- if "medium" in speed and "batch_32" in speed["medium"]:
564
- batch_32 = speed["medium"]["batch_32"]
565
- print(f" Throughput: {batch_32['texts_per_second']:.1f} texts/sec")
566
- print(f" Latency: {batch_32['time_per_text_ms']:.1f} ms/text")
567
- print(f" Token Speed: {batch_32['tokens_per_second']:.0f} tokens/sec")
568
-
569
- # CPU vs GPU
570
- if "cpu_vs_gpu" in self.results:
571
- comparison = self.results["cpu_vs_gpu"]
572
- print("\n🖥️ CPU vs GPU:")
573
- for device, metrics in comparison.items():
574
- if "error" not in metrics:
575
- print(f" {device.upper()}: {metrics['texts_per_second']:.1f} texts/sec")
576
-
577
- print()
578
-
579
-
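(For historical reference, the class removed here was driven roughly as follows; the model id and output path are illustrative.)

bench = PerformanceBenchmark("minishlab/potion-base-8M", "potion-base-8M")
bench.run_comprehensive_benchmark()  # size, speed, memory scaling, CPU vs GPU
bench.print_summary()
bench.save_results("benchmark_results/benchmark_potion-base-8M.json")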
580
- def run_benchmark(
581
- model_path: str | list[str],
582
- model_name: str | None = None,
583
- output: str = "benchmark_results.json",
584
- quick: bool = False,
585
- compare_models: list[str] | None = None,
586
- ) -> None:
587
- """Run benchmark for one or multiple models with comparison."""
588
- # Handle both single model and multiple models
589
- models_to_benchmark = [model_path] if isinstance(model_path, str) else model_path
590
-
591
- if compare_models:
592
- models_to_benchmark.extend(compare_models)
593
-
594
- all_results = []
595
-
596
- for i, model in enumerate(models_to_benchmark):
597
- current_model_name = model_name if i == 0 else Path(model).name
598
-
599
- print(f"\n{'=' * 60}")
600
- print(f"Benchmarking Model {i + 1}/{len(models_to_benchmark)}: {current_model_name}")
601
- print(f"{'=' * 60}")
602
-
603
- try:
604
- benchmarker = PerformanceBenchmark(model, current_model_name)
605
-
606
- if quick:
607
- # Quick benchmark
608
- benchmarker.load_model()
609
- benchmarker.measure_model_size()
610
- benchmarker.benchmark_inference_speed([1, 16, 32])
611
- else:
612
- # Comprehensive benchmark
613
- benchmarker.run_comprehensive_benchmark()
614
-
615
- all_results.append(benchmarker.results)
616
- benchmarker.print_summary()
617
-
618
- except Exception:
619
- logger.exception(f"❌ Failed to benchmark {current_model_name}")
620
- continue
621
-
622
- # Save individual results
623
- output_dir = Path(output).parent if Path(output).suffix else Path(output)
624
- output_dir.mkdir(parents=True, exist_ok=True)
625
-
626
- for results in all_results:
627
- model_name_safe = "".join(c for c in results["model_name"] if c.isalnum() or c in ("-", "_", "."))
628
- output_path = output_dir / f"benchmark_{model_name_safe}.json"
629
-
630
- with output_path.open("w") as f:
631
- json.dump(results, f, indent=2, default=str)
632
-
633
- logger.info(f"📄 Results saved to {output_path}")
634
-
635
- # Create comparison if multiple models
636
- if len(all_results) > 1:
637
- create_benchmark_comparison(all_results, str(output_dir / "benchmark_comparison.json"))
638
-
639
- print(f"\n✅ Benchmark complete! Results saved to {output_dir}")
640
-
641
-
642
- def create_benchmark_comparison(all_results: list[dict[str, Any]], output_path: str) -> None:
643
- """Create a comparison report for multiple benchmark results."""
644
- print(f"\n{'=' * 80}")
645
- print("Performance Benchmark Comparison")
646
- print(f"{'=' * 80}")
647
-
648
- comparison_data = []
649
-
650
- for results in all_results:
651
- model_name = results.get("model_name", "Unknown")
652
- size_metrics = results.get("size_metrics", {})
653
- speed_benchmarks = results.get("speed_benchmarks", {})
654
- cpu_vs_gpu = results.get("cpu_vs_gpu", {})
655
-
656
- # Extract key metrics
657
- row = {
658
- "Model": model_name,
659
- "Disk Size (MB)": size_metrics.get("disk_size_mb", 0),
660
- "Parameters (M)": size_metrics.get("parameters_millions", 0),
661
- "Embedding Dim": size_metrics.get("embedding_dim", 0),
662
- }
663
-
664
- # Speed metrics (medium texts, batch 32)
665
- if "medium" in speed_benchmarks and "batch_32" in speed_benchmarks["medium"]:
666
- batch_32 = speed_benchmarks["medium"]["batch_32"]
667
- row.update(
668
- {
669
- "Throughput (texts/sec)": batch_32.get("texts_per_second", 0),
670
- "Latency (ms/text)": batch_32.get("time_per_text_ms", 0),
671
- "Token Speed (tokens/sec)": batch_32.get("tokens_per_second", 0),
672
- }
673
- )
674
-
675
- # CPU vs GPU comparison
676
- for device in ["cpu", "cuda"]:
677
- if device in cpu_vs_gpu and "error" not in cpu_vs_gpu[device]:
678
- row[f"{device.upper()} Speed (texts/sec)"] = cpu_vs_gpu[device].get("texts_per_second", 0)
679
-
680
- comparison_data.append(row)
681
-
682
- # Create DataFrame and save
683
- df = pd.DataFrame(comparison_data)
684
-
685
- # Sort by throughput (descending)
686
- if "Throughput (texts/sec)" in df.columns:
687
- df = df.sort_values("Throughput (texts/sec)", ascending=False)
688
-
689
- # Print comparison table
690
- print(df.to_string(index=False, float_format="%.2f"))
691
-
692
- # Save comparison results
693
- comparison_summary = {
694
- "comparison_table": df.to_dict(orient="records"),
695
- "summary": {
696
- "fastest_model": df.iloc[0]["Model"] if len(df) > 0 else None,
697
- "smallest_model": df.loc[df["Disk Size (MB)"].idxmin()]["Model"] if len(df) > 0 else None,
698
- "most_efficient": df.loc[df["Throughput (texts/sec)"].idxmax()]["Model"]
699
- if "Throughput (texts/sec)" in df.columns and len(df) > 0
700
- else None,
701
- },
702
- "timestamp": time.time(),
703
- }
704
-
705
- with Path(output_path).open("w") as f:
706
- json.dump(comparison_summary, f, indent=2, default=str)
707
-
708
- print(f"\n📊 Comparison saved to {output_path}")
709
-
710
-
711
- def save_benchmark_results(
712
- results: dict[str, Any],
713
- output_dir: str,
714
- model_name: str,
715
- volume_results_dir: Path | None = None,
716
- ) -> None:
717
- """Save benchmark results to JSON file with Beam volume support."""
718
- # Save to Beam volume if available
719
- if volume_results_dir:
720
- volume_output_path = volume_results_dir / f"benchmark_{model_name}.json"
721
- try:
722
- with volume_output_path.open("w") as f:
723
- json.dump(results, f, indent=2, default=str)
724
- logger.info(f"💾 Results saved to Beam volume: {volume_output_path}")
725
- except Exception as e:
726
- logger.warning(f"⚠️ Failed to save to Beam volume: {e}")
727
-
728
- # Always save local backup
729
- output_path = Path(output_dir)
730
- output_path.mkdir(parents=True, exist_ok=True)
731
-
732
- # Clean model name for filename
733
- safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
734
- filename = f"benchmark_{safe_name}.json"
735
- filepath = output_path / filename
736
-
737
- with filepath.open("w") as f:
738
- json.dump(results, f, indent=2, default=str)
739
-
740
- logger.info(f"📄 Local backup saved to {filepath}")
741
-
742
-
743
- def beam_benchmark_models(
744
- models: list[str],
745
- quick: bool = False,
746
- output_dir: str = DEFAULT_OUTPUT_DIR,
747
- volume_name: str = VOLUME_NAME,
748
- mount_path: str = VOLUME_PATH,
749
- ) -> list[dict[str, Any]]:
750
- """Main benchmarking function for Beam execution with checkpoint support."""
751
- logger.info("🚀 Starting Beam-powered performance benchmarking")
752
- logger.info(f"📊 Benchmarking {len(models)} models")
753
-
754
- # Initialize Beam utilities
755
- volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(volume_name, mount_path)
756
-
757
- # Create benchmark results directory in volume
758
- results_dir = Path(mount_path) / BENCHMARK_RESULTS_DIR
759
- results_dir.mkdir(parents=True, exist_ok=True)
760
-
761
- logger.info(f"📁 Using Beam volume: {volume_name} at {mount_path}")
762
- logger.info(f"💾 Benchmark results directory: {results_dir}")
763
-
764
- all_results = []
765
- skipped_models = []
766
-
767
- for model_path in models:
768
- model_name = Path(model_path).name if model_path != str(Path(mount_path)) else "gte_qwen2_m2v_code"
769
-
770
- # Check if this model has already been benchmarked (except for trained model)
771
- is_trained_model = model_path == str(Path(mount_path)) or model_name == "gte_qwen2_m2v_code"
772
-
773
- if not is_trained_model:
774
- # Check for existing benchmark results
775
- existing_result_file = results_dir / f"benchmark_{model_name}.json"
776
- if existing_result_file.exists():
777
- logger.info(f"✅ Model {model_name} already benchmarked - loading existing results")
778
- try:
779
- with existing_result_file.open("r") as f:
780
- existing_results = json.load(f)
781
- all_results.append(existing_results)
782
- skipped_models.append(model_name)
783
- continue
784
- except Exception as e:
785
- logger.warning(f"⚠️ Failed to load existing results for {model_name}: {e}")
786
- # Continue with benchmarking if loading fails
787
-
788
- logger.info(f"\n{'=' * 60}")
789
- logger.info(f"🔍 Benchmarking model: {model_name}")
790
- logger.info(f"📂 Path: {model_path}")
791
- if is_trained_model:
792
- logger.info("🎯 Trained model - always re-benchmark")
793
- logger.info(f"{'=' * 60}")
794
-
795
- try:
796
- # Distinguish between local paths and HuggingFace model names
797
- is_huggingface_model = (
798
- "/" in model_path and not model_path.startswith("/") and not Path(model_path).exists()
799
- )
800
-
801
- if is_huggingface_model:
802
- # This is a HuggingFace model name - pass directly to benchmarker
803
- logger.info(f"📥 Loading HuggingFace model: {model_path}")
804
- benchmarker = PerformanceBenchmark(
805
- model_path,
806
- model_name,
807
- checkpoint_manager=checkpoint_mgr,
808
- eval_manager=eval_mgr,
809
- )
810
- else:
811
- # This is a local path - check if it exists in Beam volume
812
- actual_model_path = model_path # Default to original path
813
- if not Path(model_path).exists() and not model_path.startswith("/"):
814
- # Try to load from Beam volume
815
- local_model_path = Path(mount_path) / model_name
816
- logger.info(f"🔍 Trying to load {model_name} from Beam volume: {local_model_path}")
817
- if local_model_path.exists():
818
- actual_model_path = str(local_model_path)
819
- logger.info(f"✅ Found model in Beam volume: {actual_model_path}")
820
- else:
821
- # Try in root of volume (for your trained model)
822
- root_model_path = Path(mount_path)
823
- if (root_model_path / "config.json").exists():
824
- actual_model_path = str(root_model_path)
825
- logger.info(f"✅ Found model in Beam volume root: {actual_model_path}")
826
- else:
827
- logger.warning(f"⚠️ Model not found locally or in Beam volume: {model_name}")
828
- continue
829
-
830
- benchmarker = PerformanceBenchmark(
831
- actual_model_path,
832
- model_name,
833
- checkpoint_manager=checkpoint_mgr,
834
- eval_manager=eval_mgr,
835
- )
836
-
837
- # Run benchmarking
838
- if quick:
839
- # Quick benchmark
840
- benchmarker.load_model()
841
- benchmarker.measure_model_size()
842
- benchmarker.benchmark_inference_speed([1, 16, 32])
843
- else:
844
- # Comprehensive benchmark
845
- benchmarker.run_comprehensive_benchmark()
846
-
847
- # Save results with Beam support
848
- save_benchmark_results(benchmarker.results, output_dir, model_name, results_dir)
849
-
850
- # Print summary
851
- benchmarker.print_summary()
852
-
853
- all_results.append(benchmarker.results)
854
-
855
- except Exception:
856
- logger.exception(f"❌ Failed to benchmark {model_name}")
857
- continue
858
-
859
- # Create comparison report in Beam volume
860
- if len(all_results) > 1:
861
- comparison_dir = results_dir / "comparisons"
862
- comparison_dir.mkdir(parents=True, exist_ok=True)
863
- create_benchmark_comparison(all_results, str(comparison_dir / "benchmark_comparison.json"))
864
- logger.info(f"📊 Comparison report saved to Beam volume: {comparison_dir}")
865
-
866
- # Log summary of what was done
867
- newly_benchmarked = len(all_results) - len(skipped_models)
868
- logger.info("\n✅ Beam benchmarking complete!")
869
- logger.info(f"📊 Newly benchmarked: {newly_benchmarked} models")
870
- logger.info(f"⏭️ Skipped (already done): {len(skipped_models)} models")
871
- logger.info(f"📁 Total results: {len(all_results)} models")
872
- logger.info(f"💾 Results available in Beam volume: {volume_name}")
873
-
874
- if skipped_models:
875
- logger.info(f"⏭️ Skipped models: {', '.join(skipped_models)}")
876
-
877
- return all_results
878
-
879
-
880
- @function(
881
- gpu=GPU_NAME,
882
- volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
883
- image=IMAGE,
884
- secrets=["HF_ACCESS_TOKEN"],
885
- env={
886
- "TOKENIZERS_PARALLELISM": "false",
887
- "CUDA_LAUNCH_BLOCKING": "0",
888
- },
889
- timeout=3600 * 4, # 4 hours for benchmarking all models
890
- )
891
- def main() -> None:
892
- """Main benchmarking function - runs all default models on Beam."""
893
- logger.info("🚀 Starting comprehensive performance benchmarking on Beam")
894
-
895
- # Use default models but replace the local model path with Beam volume path
896
- models = DEFAULT_BENCHMARK_MODELS.copy()
897
-
898
- # Replace "gte_qwen2_m2v_code" with actual Beam volume path
899
- for i, model in enumerate(models):
900
- if model == "gte_qwen2_m2v_code":
901
- models[i] = str(Path(VOLUME_PATH)) # Use the Beam volume root
902
- logger.info(f"🎯 Using trained model from Beam volume: {models[i]}")
903
-
904
- # Discover simplified distillation models
905
- logger.info("🔍 Discovering simplified distillation models...")
906
- discovered_models = discover_simplified_models(".")
907
-
908
- # Add discovered models
909
- if discovered_models:
910
- logger.info(f"✅ Found {len(discovered_models)} simplified models:")
911
- for model_path in discovered_models:
912
- models.append(model_path)
913
- logger.info(f" 📁 {model_path}")
914
- else:
915
- logger.warning("⚠️ No simplified distillation models found")
916
-
917
- logger.info(f"📊 Benchmarking {len(models)} models:")
918
- for i, model in enumerate(models, 1):
919
- logger.info(f" {i}. {model}")
920
-
921
- logger.info("\n💡 Checkpoint Info:")
922
- logger.info(" - Already benchmarked models will be skipped")
923
- logger.info(" - Your trained model will always be re-benchmarked")
924
- logger.info(" - Results are saved persistently to Beam volume")
925
-
926
- # Run comprehensive benchmark using Beam utilities
927
- results = beam_benchmark_models(
928
- models=models,
929
- quick=True, # Use quick benchmark for efficiency
930
- output_dir=str(Path(VOLUME_PATH) / BENCHMARK_RESULTS_DIR),
931
- volume_name=VOLUME_NAME,
932
- mount_path=VOLUME_PATH,
933
- )
934
-
935
- # Print final summary
936
- print("\n🎯 Benchmarking Summary:")
937
- print(f"📊 Total models processed: {len(results)}")
938
- print(f"💾 Results saved to Beam volume: {VOLUME_NAME}")
939
- print(f"📁 Directory: {BENCHMARK_RESULTS_DIR}")
940
- print("\n🔍 To view analysis:")
941
- print(" beam run src.distiller.analyze:beam_analysis")
942
- print("\n📈 To run benchmarks again:")
943
- print(" distiller benchmark (will skip already completed models)")
944
-
945
-
946
- def discover_simplified_models(base_path: str = ".") -> list[str]:
947
- """
948
- Discover all simplified distillation models in the correct directory.
949
-
950
- Looks for directories matching the pattern: ./code_model2vec/final/code_model2vec_*
951
- """
952
- discovered_models: list[str] = []
953
-
954
- # Look in the correct location where distill_simplified.py saves models
955
- models_dir = Path(base_path) / "code_model2vec" / "final"
956
-
957
- if not models_dir.exists():
958
- logger.warning(f"Models directory not found: {models_dir}")
959
- return discovered_models
960
-
961
- # Look for simplified model directories with the updated pattern
962
- pattern = "code_model2vec_*"
963
- for model_dir in models_dir.glob(pattern):
964
- if model_dir.is_dir() and (model_dir / "config.json").exists():
965
- discovered_models.append(str(model_dir))
966
- logger.info(f"🔍 Discovered simplified model: {model_dir}")
967
-
968
- # Sort alphabetically for consistent ordering
969
- discovered_models.sort()
970
-
971
- return discovered_models
972
-
973
-
974
- @function(
975
- gpu=GPU_NAME,
976
- volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
977
- image=IMAGE,
978
- secrets=["HF_ACCESS_TOKEN"],
979
- env={
980
- "TOKENIZERS_PARALLELISM": "false",
981
- "CUDA_LAUNCH_BLOCKING": "0",
982
- },
983
- timeout=3600 * 3, # 3 hours for simplified models only
984
- )
985
- def benchmark_simplified_only() -> None:
986
- """Benchmark only simplified distillation models, skipping 3rd party models."""
987
- logger.info("🚀 Starting simplified distillation models benchmarking on Beam")
988
- logger.info("⏭️ Skipping 3rd party models - benchmarking only simplified distillation models")
989
-
990
- # Discover simplified distillation models
991
- logger.info("🔍 Discovering simplified distillation models...")
992
- discovered_models = discover_simplified_models(".")
993
-
994
- if not discovered_models:
995
- logger.error("❌ No simplified distillation models found! Run distill-simple first.")
996
- return
997
-
998
- logger.info(f"✅ Found {len(discovered_models)} simplified models:")
999
- for model_path in discovered_models:
1000
- logger.info(f" 📁 {model_path}")
1001
-
1002
- logger.info("\n💡 Checkpoint Info:")
1003
- logger.info(" - Already benchmarked models will be skipped")
1004
- logger.info(" - Results are saved persistently to Beam volume")
1005
-
1006
- # Run comprehensive benchmark using Beam utilities
1007
- results = beam_benchmark_models(
1008
- models=discovered_models,
1009
- quick=True, # Use quick benchmark for efficiency
1010
- output_dir=str(Path(VOLUME_PATH) / BENCHMARK_RESULTS_DIR),
1011
- volume_name=VOLUME_NAME,
1012
- mount_path=VOLUME_PATH,
1013
- )
1014
-
1015
- # Print final summary
1016
- print("\n🎯 Simplified Benchmarking Summary:")
1017
- print(f"📊 Total simplified models processed: {len(results)}")
1018
- print(f"💾 Results saved to Beam volume: {VOLUME_NAME}")
1019
- print(f"📁 Directory: {BENCHMARK_RESULTS_DIR}")
1020
- print("⏭️ 3rd party models were skipped")
1021
- print("\n🔍 To view analysis:")
1022
- print(" distiller analyze")
1023
- print("\n📈 To run full benchmarks (including 3rd party):")
1024
- print(" distiller benchmark")
1025
-
1026
-
1027
- def run_local_benchmark(
1028
- models: list[str] | None = None,
1029
- quick: bool = False,
1030
- output_dir: str = DEFAULT_OUTPUT_DIR,
1031
- ) -> list[dict[str, Any]]:
1032
- """Main benchmarking function for local execution without Beam utilities."""
1033
- logger.info("🖥️ Running performance benchmarking locally")
1034
-
1035
- if models is None:
1036
- models = DEFAULT_BENCHMARK_MODELS.copy()
1037
-
1038
- # Replace "gte_qwen2_m2v_code" with a reasonable local path
1039
- for i, model in enumerate(models):
1040
- if model == "gte_qwen2_m2v_code":
1041
- # Look for local trained model
1042
- local_model_paths = [
1043
- "./gte_qwen2_m2v_code",
1044
- "./models/gte_qwen2_m2v_code",
1045
- "./output/gte_qwen2_m2v_code",
1046
- ]
1047
- found = False
1048
- for local_path in local_model_paths:
1049
- if Path(local_path).exists():
1050
- models[i] = local_path
1051
- logger.info(f"🎯 Found local trained model: {local_path}")
1052
- found = True
1053
- break
1054
- if not found:
1055
- logger.warning("⚠️ Local trained model not found, skipping")
1056
- models.pop(i)
1057
-
1058
- # Discover simplified distillation models
1059
- logger.info("🔍 Discovering simplified distillation models...")
1060
- discovered_models = discover_simplified_models(".")
1061
-
1062
- # Add discovered models
1063
- if discovered_models:
1064
- logger.info(f"✅ Found {len(discovered_models)} simplified models:")
1065
- for model_path in discovered_models:
1066
- models.append(model_path)
1067
- logger.info(f" 📁 {model_path}")
1068
- else:
1069
- logger.warning("⚠️ No simplified distillation models found")
1070
-
1071
- logger.info(f"📊 Benchmarking {len(models)} models")
1072
- logger.info(f"📁 Using local output directory: {output_dir}")
1073
-
1074
- # Create local output directory
1075
- output_path = Path(output_dir)
1076
- output_path.mkdir(parents=True, exist_ok=True)
1077
-
1078
- all_results = []
1079
- skipped_models = []
1080
-
1081
- for model_path in models:
1082
- model_name = Path(model_path).name
1083
-
1084
- # Check for existing benchmark results locally
1085
- safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
1086
- result_file = output_path / f"benchmark_{safe_name}.json"
1087
-
1088
- if result_file.exists():
1089
- logger.info(f"✅ Model {model_name} already benchmarked - loading existing results")
1090
- try:
1091
- with result_file.open("r") as f:
1092
- existing_results = json.load(f)
1093
- all_results.append(existing_results)
1094
- skipped_models.append(model_name)
1095
- continue
1096
- except Exception as e:
1097
- logger.warning(f"⚠️ Failed to load existing results for {model_name}: {e}")
1098
-
1099
- logger.info(f"\n{'=' * 60}")
1100
- logger.info(f"🔍 Benchmarking model: {model_name}")
1101
- logger.info(f"📂 Path: {model_path}")
1102
- logger.info(f"{'=' * 60}")
1103
-
1104
- try:
1105
- # Create benchmarker without Beam utilities
1106
- benchmarker = PerformanceBenchmark(
1107
- model_path,
1108
- model_name,
1109
- checkpoint_manager=None, # No checkpointing for local benchmarking
1110
- eval_manager=None,
1111
- )
1112
-
1113
- # Run benchmarking
1114
- if quick:
1115
- # Quick benchmark
1116
- benchmarker.load_model()
1117
- benchmarker.measure_model_size()
1118
- benchmarker.benchmark_inference_speed([1, 16, 32])
1119
- else:
1120
- # Comprehensive benchmark
1121
- benchmarker.run_comprehensive_benchmark()
1122
-
1123
- # Save results locally only
1124
- save_benchmark_results(benchmarker.results, output_dir, model_name, volume_results_dir=None)
1125
-
1126
- # Print summary
1127
- benchmarker.print_summary()
1128
-
1129
- all_results.append(benchmarker.results)
1130
-
1131
- except Exception:
1132
- logger.exception(f"❌ Failed to benchmark {model_name}")
1133
- continue
1134
-
1135
- # Create comparison report locally
1136
- if len(all_results) > 1:
1137
- create_benchmark_comparison(all_results, str(output_path / "benchmark_comparison.json"))
1138
- logger.info(f"📊 Comparison report saved locally: {output_dir}")
1139
-
1140
- # Log summary
1141
- newly_benchmarked = len(all_results) - len(skipped_models)
1142
- logger.info("\n✅ Local benchmarking complete!")
1143
- logger.info(f"📊 Newly benchmarked: {newly_benchmarked} models")
1144
- logger.info(f"⏭️ Skipped (already done): {len(skipped_models)} models")
1145
- logger.info(f"📁 Total results: {len(all_results)} models")
1146
- logger.info(f"💾 Results available locally: {output_dir}")
1147
-
1148
- if skipped_models:
1149
- logger.info(f"⏭️ Skipped models: {', '.join(skipped_models)}")
1150
-
1151
- return all_results
1152
-
1153
-
1154
- def run_local_benchmark_simplified(
1155
- quick: bool = False,
1156
- output_dir: str = DEFAULT_OUTPUT_DIR,
1157
- ) -> list[dict[str, Any]]:
1158
- """Local benchmarking function for simplified models only."""
1159
- logger.info("🖥️ Running simplified model benchmarking locally")
1160
-
1161
- # Discover simplified distillation models only
1162
- logger.info("🔍 Discovering simplified distillation models...")
1163
- discovered_models = discover_simplified_models(".")
1164
-
1165
- if not discovered_models:
1166
- logger.error("❌ No simplified distillation models found! Run 'distiller distill-simple' first.")
1167
- return []
1168
-
1169
- logger.info(f"✅ Found {len(discovered_models)} simplified models:")
1170
- for model_path in discovered_models:
1171
- logger.info(f" 📁 {model_path}")
1172
-
1173
- return run_local_benchmark(
1174
- models=discovered_models,
1175
- quick=quick,
1176
- output_dir=output_dir,
1177
- )
1178
-
1179
-
1180
- if __name__ == "__main__":
1181
- main()
 
 
src/distiller/config.py ADDED
@@ -0,0 +1,339 @@
 
 
1
+ """
2
+ Shared configuration for the distiller package.
3
+
4
+ This module centralizes all configuration constants, default values, and common
5
+ settings used across distillation, evaluation, and benchmarking modules.
6
+ """
7
+
8
+ import logging
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ from beam import GpuType, Image
13
+ from pydantic import BaseModel
14
+
15
+ # =============================================================================
16
+ # LOGGING CONFIGURATION
17
+ # =============================================================================
18
+
19
+
20
+ def setup_logging(level: int = logging.INFO) -> None:
21
+ """Set up consistent logging across the package."""
22
+ log_dir = Path("logs")
23
+ log_dir.mkdir(parents=True, exist_ok=True)
24
+ log_path = log_dir / "distiller.log"
25
+ logging.basicConfig(
26
+ level=level,
27
+ format="%(asctime)s - %(levelname)s - %(message)s",
28
+ handlers=[logging.StreamHandler(), logging.FileHandler(log_path, mode="a")],
29
+ )
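
Usage sketch (assuming the package is importable as distiller): call setup_logging once at startup, then use standard loggers; records go to the console and to logs/distiller.log.

    import logging

    from distiller.config import setup_logging

    # Configure console + file handlers once, before any other distiller call.
    setup_logging()  # or setup_logging(logging.DEBUG) for verbose runs
    logging.getLogger(__name__).info("distiller logging initialised")
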
30
+
31
+
32
+ # =============================================================================
33
+ # BEAM CLOUD CONFIGURATION
34
+ # =============================================================================
35
+
36
+ # Beam execution settings
37
+ GPU_NAME = GpuType.A100_40
38
+
39
+
40
+ # Volume configurations for different workflows
41
+ class VolumeConfig(BaseModel):
42
+ """Volume configuration container."""
43
+
44
+ name: str
45
+ mount_path: str
46
+ description: str = ""
47
+
48
+
49
+ # Define volume configurations - code_model2vec is the primary volume for all workflows
50
+ VOLUMES: dict[str, VolumeConfig] = {
51
+ "primary": VolumeConfig(
52
+ name="code_model2vec",
53
+ mount_path="./code_model2vec",
54
+ description="Primary volume for all distillation models, evaluations, benchmarks, and checkpoints",
55
+ ),
56
+ # Legacy volume name mapping for backwards compatibility
57
+ "simplified": VolumeConfig(
58
+ name="code_model2vec",
59
+ mount_path="./code_model2vec",
60
+ description="Primary volume for all distillation models, evaluations, benchmarks, and checkpoints",
61
+ ),
62
+ }
63
+
64
+ # Default volume name for all workflows
65
+ DEFAULT_VOLUME = "primary"
66
+
67
+ # Beam environment settings
68
+ BEAM_ENV_SETTINGS: dict[str, str] = {
69
+ "TOKENIZERS_PARALLELISM": "false",
70
+ "CUDA_LAUNCH_BLOCKING": "0",
71
+ "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:512",
72
+ "TORCH_CUDNN_V8_API_ENABLED": "1",
73
+ }
74
+
75
+ # Common Python packages for Beam images
76
+ COMMON_PACKAGES: list[str] = [
77
+ "torch>=2.7.0",
78
+ "transformers>=4.40.0",
79
+ "datasets>=3.2.0",
80
+ "sentence-transformers>=4.1.0",
81
+ "model2vec[train]>=0.5.0",
82
+ "numpy>=1.26.4",
83
+ "scikit-learn>=1.6.1",
84
+ "pandas>=2.0.0",
85
+ "tqdm>=4.65.0",
86
+ "plotly>=5.0.0",
87
+ "matplotlib>=3.7.0",
88
+ "seaborn>=0.12.0",
89
+ ]
90
+
91
+ # Create common Beam image
92
+ IMAGE = Image(python_version="python3.12").add_python_packages(COMMON_PACKAGES)
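
These constants are consumed by the Beam-decorated entry points elsewhere in the package; a hedged sketch of that wiring (the task body is illustrative, the decorator arguments mirror the pattern used in this repository):

    from beam import Volume, function

    from distiller.config import BEAM_ENV_SETTINGS, GPU_NAME, IMAGE, get_volume_config

    volume = get_volume_config()

    @function(
        gpu=GPU_NAME,
        image=IMAGE,
        volumes=[Volume(name=volume.name, mount_path=volume.mount_path)],
        env=BEAM_ENV_SETTINGS,
        timeout=3600,
    )
    def remote_task() -> str:
        # Runs on Beam with the shared GPU, image, environment, and volume settings.
        return "ok"
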
93
+
94
+ # =============================================================================
95
+ # MODEL CONFIGURATION
96
+ # =============================================================================
97
+
98
+ # Teacher model configurations
99
+ TEACHER_MODELS: list[str] = [
100
+ "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
101
+ "BAAI/bge-m3",
102
+ "jinaai/jina-embeddings-v3",
103
+ "lightonai/Reason-ModernColBERT",
104
+ "Linq-AI-Research/Linq-Embed-Mistral",
105
+ "microsoft/codebert-base",
106
+ "microsoft/graphcodebert-base",
107
+ "nomic-ai/nomic-embed-text-v2-moe",
108
+ "Qodo/Qodo-Embed-1-1.5B",
109
+ "sentence-transformers/all-MiniLM-L6-v2",
110
+ "sentence-transformers/all-mpnet-base-v2",
111
+ "sentence-transformers/paraphrase-MiniLM-L6-v2",
112
+ "nomic-ai/nomic-embed-code",
113
+ "nomic-ai/CodeRankEmbed",
114
+ ]
115
+
116
+ # Default evaluation models for comparison
117
+ DEFAULT_EVALUATION_MODELS: list[str] = [
118
+ "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
119
+ "BAAI/bge-m3",
120
+ "huggingface/CodeBERTa-small-v1",
121
+ "jinaai/jina-embeddings-v3",
122
+ "lightonai/Reason-ModernColBERT",
123
+ "Linq-AI-Research/Linq-Embed-Mistral",
124
+ "microsoft/codebert-base",
125
+ "microsoft/graphcodebert-base",
126
+ "minishlab/potion-base-8M",
127
+ "minishlab/potion-retrieval-32M",
128
+ "nomic-ai/nomic-embed-text-v2-moe",
129
+ "Qodo/Qodo-Embed-1-1.5B",
130
+ "Salesforce/codet5-base",
131
+ "sentence-transformers/all-MiniLM-L12-v2",
132
+ "sentence-transformers/all-MiniLM-L6-v2",
133
+ "sentence-transformers/all-mpnet-base-v2",
134
+ "sentence-transformers/paraphrase-MiniLM-L6-v2",
135
+ "nvidia/NV-Embed-v2",
136
+ "nomic-ai/nomic-embed-code",
137
+ "nomic-ai/CodeRankEmbed",
138
+ ]
139
+
140
+
141
+ # Model2Vec distillation parameters
142
+ class DistillationConfig(BaseModel):
143
+ """Configuration for Model2Vec distillation parameters."""
144
+
145
+ # Teacher models for distillation
146
+ code_teacher_models: list[str] = TEACHER_MODELS
147
+
148
+ # Basic distillation parameters
149
+ optimal_pca_dims: int = 256
150
+ sif_coefficient: float = 1e-3
151
+ apply_zipf: bool = True
152
+
153
+ # Training parameters (used when --train flag is enabled)
154
+ training_epochs: int = 2
155
+ learning_rate: float = 1e-4
156
+ batch_size: int = 32
157
+ max_training_samples: int = 50000
158
+ teacher_model_config: dict[str, Any] = {}
159
+
160
+
161
+ distillation_config = DistillationConfig()
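
Because these settings are plain Pydantic models, callers can read the shared instance or build a one-off variant without mutating global state; a brief sketch (override values illustrative):

    from distiller.config import DistillationConfig, distillation_config

    # Read the shared defaults used by the distillation pipeline.
    print(distillation_config.optimal_pca_dims)      # 256
    print(distillation_config.max_training_samples)  # 50000

    # Derive an experiment-specific configuration instead of editing the global instance.
    experiment_config = DistillationConfig(optimal_pca_dims=512, training_epochs=4)
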
162
+
163
+
164
+ # =============================================================================
165
+ # DATASET CONFIGURATION
166
+ # =============================================================================
167
+
168
+
169
+ # Languages used for dataset loading and CodeSearchNet evaluation
170
+ class LanguagesConfig(BaseModel):
171
+ """Configuration for languages used in evaluation."""
172
+
173
+ all: list[str] = [
174
+ "python",
175
+ "java",
176
+ "javascript",
177
+ "php",
178
+ "ruby",
179
+ "go",
180
+ ]
181
+
182
+
183
+ languages_config = LanguagesConfig()
184
+
185
+
186
+ # CodeSearchNet evaluation settings (evaluation_languages defaults to languages_config.all)
187
+ class CodeSearchNetConfig(BaseModel):
188
+ """Configuration for CodeSearchNet evaluation settings."""
189
+
190
+ dataset_name: str = "code_search_net"
191
+ evaluation_languages: list[str] = languages_config.all
192
+ max_queries_per_language: int = 1000
193
+ similarity_threshold: float = 0.7
194
+ evaluation_metrics: list[str] = ["ndcg@1", "ndcg@5", "ndcg@10", "mrr", "recall@1", "recall@5", "recall@10"]
195
+
196
+
197
+ codesearchnet_config = CodeSearchNetConfig()
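
A short sketch of how an evaluation loop might consume these settings (the print is a stand-in for the real evaluation call):

    from distiller.config import codesearchnet_config

    for language in codesearchnet_config.evaluation_languages:
        # One retrieval evaluation per language, capped at max_queries_per_language queries.
        print(
            f"{codesearchnet_config.dataset_name} [{language}]: "
            f"up to {codesearchnet_config.max_queries_per_language} queries, "
            f"metrics {codesearchnet_config.evaluation_metrics}"
        )
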
198
+
199
+ # Training dataset configurations
200
+ TRAINING_DATASETS: dict[str, str] = {
201
+ "codesearchnet": "sentence-transformers/codesearchnet",
202
+ "code_search_net": "code_search_net",
203
+ }
204
+
205
+ # =============================================================================
206
+ # OUTPUT DIRECTORY CONFIGURATION
207
+ # =============================================================================
208
+
209
+
210
+ # Standardized directory structure within code_model2vec
211
+ class StandardDirectories(BaseModel):
212
+ """Standardized directory structure for code_model2vec workspace."""
213
+
214
+ # Root directory
215
+ root: str = "code_model2vec"
216
+
217
+ # Model directories
218
+ base: str = "code_model2vec/base" # Basic distilled models
219
+ final: str = "code_model2vec/final" # Final trained models
220
+ models: str = "code_model2vec/models" # Legacy/alternative models location
221
+
222
+ # Results directories
223
+ evaluation_results: str = "code_model2vec/evaluation_results"
224
+ benchmark_results: str = "code_model2vec/benchmark_results"
225
+ analysis_results: str = "code_model2vec/analysis_results"
226
+
227
+ # Working directories
228
+ checkpoints: str = "code_model2vec/checkpoints"
229
+ cache: str = "code_model2vec/cache"
230
+ temp: str = "code_model2vec/temp"
231
+
232
+
233
+ # Create global instance
234
+ directories = StandardDirectories()
235
+
236
+
237
+ # Legacy OutputDirs for backwards compatibility
238
+ class OutputDirs(BaseModel):
239
+ """Base output directory structure for storing models, checkpoints, and results."""
240
+
241
+ base: str = "base"
242
+ models: str = "final"
243
+ checkpoints: str = "checkpoints"
244
+ evaluation_results: str = "evaluation_results"
245
+ benchmark_results: str = "benchmark_results"
246
+ analysis_results: str = "analysis_results"
247
+ cache: str = "cache"
248
+
249
+
250
+ output_dirs = OutputDirs()
251
+
252
+
253
+ # File naming patterns
254
+ class FilenamePatterns(BaseModel):
255
+ """File naming patterns for evaluation, benchmark, checkpoint, and model files."""
256
+
257
+ evaluation: str = "codesearchnet_eval_{model_name}.json"
258
+ benchmark: str = "benchmark_{model_name}.json"
259
+ checkpoint: str = "checkpoints_{stage}_step_{step}.json"
260
+ model: str = "{teacher_model}_{dims}d"
261
+
262
+
263
+ filename_patterns = FilenamePatterns()
264
+
265
+ # =============================================================================
266
+ # ANALYSIS AND VISUALIZATION
267
+ # =============================================================================
268
+
269
+
270
+ # Chart configuration
271
+ class ChartConfig(BaseModel):
272
+ """Chart configuration for analysis and visualization."""
273
+
274
+ figsize: tuple[int, int] = (12, 8)
275
+ dpi: int = 300
276
+ style: str = "whitegrid"
277
+ color_palette: str = "Set2"
278
+ save_formats: list[str] = ["png", "pdf"]
279
+
280
+
281
+ chart_config = ChartConfig()
282
+
283
+
284
+ # Performance thresholds for analysis
285
+ class PerformanceThresholds(BaseModel):
286
+ """Performance thresholds for analysis results."""
287
+
288
+ excellent: float = 0.7
289
+ good: float = 0.5
290
+ fair: float = 0.3
291
+ poor: float = 0.1
292
+
293
+
294
+ performance_thresholds = PerformanceThresholds()
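
For illustration, a hypothetical rate_score helper shows how an aggregate metric such as mean NDCG@10 maps onto these buckets:

    from distiller.config import performance_thresholds

    def rate_score(score: float) -> str:
        """Hypothetical helper: bucket an aggregate metric into a qualitative label."""
        if score >= performance_thresholds.excellent:
            return "excellent"
        if score >= performance_thresholds.good:
            return "good"
        if score >= performance_thresholds.fair:
            return "fair"
        if score >= performance_thresholds.poor:
            return "poor"
        return "very poor"

    print(rate_score(0.62))  # -> "good"
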
295
+
296
+ # =============================================================================
297
+ # HELPER FUNCTIONS
298
+ # =============================================================================
299
+
300
+
301
+ def get_volume_config() -> VolumeConfig:
302
+ """Get volume configuration for any workflow - always returns the primary code_model2vec volume."""
303
+ return VOLUMES["primary"]
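
A small sketch confirming that every workflow, including the legacy "simplified" alias, resolves to the same primary volume:

    from distiller.config import VOLUMES, get_volume_config

    volume = get_volume_config()
    print(volume.name, volume.mount_path)             # code_model2vec ./code_model2vec
    print(VOLUMES["simplified"].name == volume.name)  # True: the alias maps to the primary volume
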
304
+
305
+
306
+ def get_output_path(base_path: str | Path, output_type: str) -> Path:
307
+ """Get standardized output path for different types of outputs."""
308
+ base = Path(base_path)
309
+ if hasattr(output_dirs, output_type):
310
+ return base / getattr(output_dirs, output_type)
311
+ return base / output_type
312
+
313
+
314
+ def get_standard_directory(dir_type: str) -> str:
315
+ """Get standardized directory path for any directory type."""
316
+ if hasattr(directories, dir_type):
317
+ return getattr(directories, dir_type)
318
+ # Default to relative path within code_model2vec
319
+ return f"code_model2vec/{dir_type}"
320
+
321
+
322
+ def ensure_checkpoint_directory(stage: str) -> str:
323
+ """Ensure checkpoint directory exists for a specific stage and return the path."""
324
+ checkpoint_dir = f"{directories.checkpoints}/{stage}"
325
+ Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
326
+ return checkpoint_dir
327
+
328
+
329
+ def format_filename(pattern_key: str, **kwargs: Any) -> str:
330
+ """Format filename using predefined patterns."""
331
+ if hasattr(filename_patterns, pattern_key):
332
+ return getattr(filename_patterns, pattern_key).format(**kwargs)
333
+ msg = f"Unknown filename pattern: {pattern_key}"
334
+ raise ValueError(msg)
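
A quick usage sketch for the filename helper (the model name is illustrative):

    from distiller.config import format_filename

    print(format_filename("evaluation", model_name="potion-base-8M"))
    # -> codesearchnet_eval_potion-base-8M.json
    print(format_filename("checkpoint", stage="dataset", step=0))
    # -> checkpoints_dataset_step_0.json
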
335
+
336
+
337
+ def get_safe_model_name(model_name: str) -> str:
338
+ """Convert model name to filesystem-safe name."""
339
+ return "".join(c for c in model_name if c.isalnum() or c in ("-", "_", ".")).replace("/", "_")
src/distiller/distill.py CHANGED
@@ -1,31 +1,35 @@
1
  """
2
- Code-Specialized Model2Vec Distillation Script with Checkpoint Support.
3
 
4
- This script implements a focused approach for creating code-specialized embeddings
5
- using Model2Vec distillation with one additional training round on code-specific tasks.
6
 
7
  Features:
8
- - Incremental checkpoint saving
9
- - Resume from previous progress
10
- - Persistent storage of embeddings and models
11
- - Robust error handling and recovery
12
- - Smart checkpoint validation for parameter compatibility
13
-
14
- Approach:
15
- 1. Basic Model2Vec distillation with optimized parameters
16
- 2. Single code specialization round using sentence-transformers/codesearchnet dataset
 
 
 
 
17
  """
18
 
19
  import json
20
  import logging
21
- import os
22
  import time
23
  from pathlib import Path
24
- from typing import Any
25
 
26
  import numpy as np
27
  import torch
28
- from beam import GpuType, Image, Volume, function
 
29
  from datasets import load_dataset
30
  from model2vec.distill import distill
31
  from model2vec.train.base import FinetunableStaticModel, TextDataset
@@ -35,239 +39,474 @@ from torch import nn, optim
35
 
36
  from .beam_utils import (
37
  BeamCheckpointManager,
38
- BeamModelManager,
39
  create_beam_utilities,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  )
41
 
42
  # =============================================================================
43
- # CODE-FOCUSED CONFIGURATION
44
  # =============================================================================
45
 
46
- # Model Configuration
47
- MODEL_NAME = "Alibaba-NLP/gte-Qwen2-7B-instruct"
48
- OUTPUT_DIR = "gte_qwen2_m2v_code"
49
- CHECKPOINT_DIR = "gte_qwen2_m2v_code/checkpoints"
50
-
51
- # Code-optimized parameters
52
- PCA_DIMS = 512 # Higher dims for code complexity
53
- TRAINING_EPOCHS = 2
54
- LEARNING_RATE = 1e-4
55
- BATCH_SIZE = 32
56
- REGULARIZATION_WEIGHT = 0.01
57
-
58
- # CodeSearchNet dataset configuration
59
- CODESEARCHNET_DATASET = "sentence-transformers/codesearchnet"
60
- MAX_TRAINING_SAMPLES = 50000 # Limit for manageable training time
61
-
62
- # Checkpoint configuration
63
- CHECKPOINT_INTERVAL = 1000 # Save every N samples
64
- EMBEDDINGS_BATCH_SIZE = 100 # Save embeddings in smaller batches
65
-
66
- # OPTIMIZED TEACHER MODEL CONFIGURATION FOR 40GB VRAM
67
- TEACHER_MODEL_CONFIG: dict[str, Any] = {
68
- "batch_size": 12, # Slightly reduced due to float32 memory usage
69
- "precision": "float32", # Use float32 for quality preservation
70
- "max_seq_length": 8192, # Reduce from 32k default for better performance
71
- "device_map": "auto", # Automatic device placement
72
- "torch_dtype": torch.float32, # Use float32 for quality preservation
73
- "trust_remote_code": True,
74
- "use_flash_attention": True, # Try to enable flash attention if available
75
- "attn_implementation": "flash_attention_2", # Use flash attention 2 if available
76
- }
77
 
78
  # =============================================================================
79
- # BEAM CONFIGURATION
80
  # =============================================================================
81
 
82
- GPU_NAME = GpuType.A100_40
83
- VOLUME_NAME = "gte_qwen2_m2v_code"
84
- VOLUME_PATH = "./gte_qwen2_m2v_code"
85
- IMAGE = Image(python_version="python3.12").add_python_packages(
86
- [
87
- "torch>=2.7.0", # Install torch first
88
- "transformers>=4.40.0", # Latest transformers with flash attention support
89
- "accelerate>=1.7.0",
90
- "datasets>=3.2.0",
91
- "model2vec[train]>=0.5.0",
92
- "numpy>=1.26.4",
93
- "scikit-learn>=1.6.1",
94
- "sentence-transformers>=4.1.0",
95
- ]
96
- )
97
 
98
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
99
- logger = logging.getLogger(__name__)
 
 
 
100
 
 
 
 
 
 
101
 
102
- def get_current_config_hash() -> str:
 
 
 
 
 
 
 
103
  """Generate a hash of current configuration parameters for checkpoint validation."""
104
  import hashlib
105
 
106
  config_params = {
107
- "model_name": MODEL_NAME,
108
- "pca_dims": PCA_DIMS,
109
- "precision": TEACHER_MODEL_CONFIG["precision"],
110
- "torch_dtype": str(TEACHER_MODEL_CONFIG["torch_dtype"]),
111
- "max_samples": MAX_TRAINING_SAMPLES,
112
- "codesearchnet_dataset": CODESEARCHNET_DATASET,
113
  }
114
 
 
 
 
 
 
 
 
 
 
115
  config_str = str(sorted(config_params.items()))
116
  return hashlib.md5(config_str.encode()).hexdigest()[:12] # noqa: S324
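
This hash ties saved checkpoints to the parameters that produced them; a minimal sketch of the gating idea (the helper name and checkpoint layout are illustrative, not the module's actual API):

    from typing import Any

    def checkpoint_is_reusable(checkpoint_data: dict[str, Any], current_hash: str) -> bool:
        """Reuse a checkpoint only when it was written under the configuration now in effect."""
        return checkpoint_data.get("config_hash") == current_hash
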
117
 
118
 
119
- def validate_checkpoint_compatibility(checkpoint_data: dict[str, Any]) -> bool:
120
- """
121
- Validate if checkpoint is compatible with current configuration.
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
- Args:
124
- checkpoint_data: Checkpoint data dictionary
 
125
 
126
- Returns:
127
- True if compatible, False otherwise
128
- """
129
- current_hash = get_current_config_hash()
130
- checkpoint_hash = checkpoint_data.get("config_hash", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- if checkpoint_hash != current_hash:
133
- logger.warning(f"Configuration mismatch: current={current_hash}, checkpoint={checkpoint_hash}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  return False
135
 
136
- # Additional validation checks
137
- checkpoint_config = checkpoint_data.get("config", {})
138
 
139
- # Check critical parameters
140
- if checkpoint_config.get("pca_dims") != PCA_DIMS:
141
- logger.warning(f"PCA dimensions mismatch: current={PCA_DIMS}, checkpoint={checkpoint_config.get('pca_dims')}")
 
 
 
 
142
  return False
143
 
144
- if checkpoint_config.get("precision") != TEACHER_MODEL_CONFIG["precision"]:
145
- logger.warning(
146
- f"Precision mismatch: current={TEACHER_MODEL_CONFIG['precision']}, checkpoint={checkpoint_config.get('precision')}"
147
- )
 
 
 
 
 
 
 
148
  return False
149
 
150
- if checkpoint_config.get("max_samples") != MAX_TRAINING_SAMPLES:
151
- logger.warning(
152
- f"Max samples mismatch: current={MAX_TRAINING_SAMPLES}, checkpoint={checkpoint_config.get('max_samples')}"
153
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  return False
155
 
156
- logger.info("✅ Checkpoint configuration is compatible")
157
- return True
 
 
 
 
 
 
158
 
159
 
160
- def create_checkpoint_data(stage: str, data: dict[str, Any], step: int = 0) -> dict[str, Any]:
 
 
 
 
 
161
  """
162
- Create checkpoint data with configuration metadata.
163
 
164
  Args:
165
- stage: Checkpoint stage name
166
- data: Core checkpoint data
167
- step: Step number
 
168
 
169
  Returns:
170
- Enhanced checkpoint data with configuration
171
  """
172
- return {
173
- "config_hash": get_current_config_hash(),
174
- "config": {
175
- "model_name": MODEL_NAME,
176
- "pca_dims": PCA_DIMS,
177
- "precision": TEACHER_MODEL_CONFIG["precision"],
178
- "torch_dtype": str(TEACHER_MODEL_CONFIG["torch_dtype"]),
179
- "max_samples": MAX_TRAINING_SAMPLES,
180
- "codesearchnet_dataset": CODESEARCHNET_DATASET,
181
- },
182
- "stage": stage,
183
- "step": step,
184
- "timestamp": time.time(),
185
- "data": data,
186
- }
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
- def load_codesearchnet_dataset_with_resume(
190
- max_samples: int = MAX_TRAINING_SAMPLES,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  checkpoint_manager: BeamCheckpointManager | None = None,
192
  ) -> list[str]:
193
- """Load and format the sentence-transformers/codesearchnet dataset with resume capability."""
194
- logger.info(f"Loading CodeSearchNet dataset from {CODESEARCHNET_DATASET}")
 
 
 
195
  logger.info(f"Limiting to {max_samples} samples for training efficiency")
 
 
 
 
 
196
 
197
- # Check for existing dataset checkpoint with validation
198
  if checkpoint_manager:
199
  checkpoint_data = checkpoint_manager.load_checkpoint("dataset", 0)
200
  if checkpoint_data:
201
- if validate_checkpoint_compatibility(checkpoint_data):
202
- texts = checkpoint_data.get("data", {}).get("texts", [])
203
- if len(texts) >= max_samples:
204
- logger.info(f"✅ Resumed dataset loading: {len(texts)} texts from checkpoint")
205
- return texts[:max_samples]
206
- logger.info(f"📋 Partial dataset found: {len(texts)} texts, continuing from there")
207
- start_from = len(texts)
208
- else:
209
- logger.warning("🔄 Incompatible dataset checkpoint found, starting fresh")
210
- # Clean up incompatible checkpoint
211
- checkpoint_manager.cleanup_old_checkpoints("dataset", keep_latest=0)
212
- texts = []
213
- start_from = 0
214
- else:
215
- texts = []
216
- start_from = 0
217
- else:
218
- texts = []
219
- start_from = 0
220
 
221
  try:
222
- # Load the dataset
223
- dataset = load_dataset(CODESEARCHNET_DATASET, split="train", streaming=True)
 
 
224
 
225
- # Skip to where we left off
226
- dataset_iter = iter(dataset)
227
- for _ in range(start_from):
228
- try:
229
- next(dataset_iter)
230
- except StopIteration:
231
- break
232
 
233
- for i, example in enumerate(dataset_iter, start=start_from):
234
- if len(texts) >= max_samples:
 
 
 
 
235
  break
236
 
237
- comment = example.get("comment", "").strip()
238
- code = example.get("code", "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
- if comment and code and len(comment) > 10 and len(code) > 50:
241
- # Format as comment-code pair for training
242
- text = f"Comment: {comment}\nCode:\n{code}"
243
 
244
- # Ensure reasonable length
245
- if len(text) <= 2048: # Reasonable limit for embedding models
246
- texts.append(text)
247
 
248
- # Save checkpoint periodically
249
- if checkpoint_manager and (i + 1) % CHECKPOINT_INTERVAL == 0:
250
- checkpoint_data = create_checkpoint_data("dataset", {"texts": texts}, 0)
251
- checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0)
252
- logger.info(f"💾 Saved dataset checkpoint: {len(texts)} texts collected")
253
 
254
- if (i + 1) % 10000 == 0:
255
- logger.info(f"Processed {i + 1} examples, collected {len(texts)} valid pairs")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
  # Final checkpoint save
258
  if checkpoint_manager:
259
- checkpoint_data = create_checkpoint_data("dataset", {"texts": texts}, 0)
 
 
 
 
 
 
260
  checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0)
261
 
262
- logger.info(f"Successfully loaded {len(texts)} code-comment pairs from CodeSearchNet")
263
- return texts
264
 
265
  except Exception:
266
  logger.exception("Error loading CodeSearchNet dataset")
267
  return texts # Return what we have so far
268
 
269
 
270
- def generate_teacher_embeddings_with_checkpoints(
271
  teacher_model: SentenceTransformer,
272
  texts: list[str],
273
  checkpoint_manager: BeamCheckpointManager | None = None,
@@ -275,13 +514,11 @@ def generate_teacher_embeddings_with_checkpoints(
275
  """Generate teacher embeddings for code training with checkpoint support."""
276
  logger.info(f"Generating teacher embeddings for {len(texts)} texts...")
277
 
278
- # Check for existing embeddings checkpoint using torch.save format
279
- final_embeddings = None
280
-
281
  if checkpoint_manager:
282
- # Try to load complete embeddings tensor directly
283
- embeddings_path = Path(VOLUME_PATH) / "embeddings_cache.pt"
284
- config_path = Path(VOLUME_PATH) / "embeddings_config.json"
285
 
286
  if embeddings_path.exists() and config_path.exists():
287
  try:
@@ -289,118 +526,78 @@ def generate_teacher_embeddings_with_checkpoints(
289
  with config_path.open("r") as f:
290
  config_data = json.load(f)
291
 
292
- # Create a dummy checkpoint data structure for validation
293
- checkpoint_data = {
294
- "config_hash": config_data.get("config_hash"),
295
- "config": config_data.get("config", {}),
296
- }
297
-
298
- if validate_checkpoint_compatibility(checkpoint_data):
299
  # Load the embeddings tensor
300
  final_embeddings = torch.load(embeddings_path, map_location="cpu")
301
  num_expected = config_data.get("num_texts", len(texts))
302
 
303
  if final_embeddings.shape[0] >= num_expected:
304
- logger.info(
305
- f"✅ Loaded complete embeddings from cache ({final_embeddings.shape[0]} embeddings)"
306
- )
307
- return final_embeddings[: len(texts)] # Return only the needed amount
308
- logger.info(
309
- f"⚠️ Cached embeddings incomplete ({final_embeddings.shape[0]}/{num_expected}), regenerating"
310
- )
311
- final_embeddings = None
312
- else:
313
- logger.warning("🔄 Incompatible embeddings cache found, regenerating")
314
- final_embeddings = None
315
  except Exception as e:
316
  logger.warning(f"Failed to load embeddings cache: {e}, regenerating...")
317
- final_embeddings = None
318
-
319
- # If we have complete embeddings, return them
320
- if final_embeddings is not None:
321
- return final_embeddings
322
 
323
  # Generate embeddings from scratch
324
  logger.info("Generating fresh teacher embeddings...")
325
 
326
- # Use optimized batch size for large models with proper type casting
327
- batch_size_raw = TEACHER_MODEL_CONFIG["batch_size"]
328
- current_batch_size: int = batch_size_raw if isinstance(batch_size_raw, int) else 16
329
- logger.info(f"Using optimized batch size: {current_batch_size} for 40GB VRAM (7B model)")
330
-
331
  embeddings_list = []
332
 
333
- for i in range(0, len(texts), current_batch_size):
334
- batch_texts = texts[i : i + current_batch_size]
335
 
336
  try:
337
- # Use optimized encoding with convert_to_tensor=True for efficiency
338
  batch_embeddings = teacher_model.encode(
339
  batch_texts,
340
  convert_to_tensor=True,
341
- batch_size=current_batch_size,
342
- show_progress_bar=False, # Reduce overhead
343
- normalize_embeddings=True, # Pre-normalize for efficiency
344
  )
345
  embeddings_list.append(batch_embeddings)
346
 
347
- if i % (current_batch_size * 10) == 0:
348
  logger.info(f"Generated embeddings for {i + len(batch_texts)}/{len(texts)} texts")
349
 
350
  except torch.cuda.OutOfMemoryError:
351
- logger.warning(
352
- f"GPU OOM with batch size {current_batch_size}, reducing to {max(1, current_batch_size // 2)}"
353
- )
354
-
355
- # Clear cache and reduce batch size
356
- if torch.cuda.is_available():
357
- torch.cuda.empty_cache()
358
-
359
- current_batch_size = max(1, current_batch_size // 2)
360
 
361
  # Retry with smaller batch size
362
- batch_texts = texts[i : i + current_batch_size]
363
  batch_embeddings = teacher_model.encode(
364
  batch_texts,
365
  convert_to_tensor=True,
366
- batch_size=current_batch_size,
367
  show_progress_bar=False,
368
  normalize_embeddings=True,
369
  )
370
  embeddings_list.append(batch_embeddings)
371
 
372
- logger.info(f"Successfully processed batch with reduced size {current_batch_size}")
373
-
374
- # Combine all embeddings and force fp32 precision
375
  teacher_embeddings = torch.cat(embeddings_list, dim=0)
376
 
377
- # Ensure teacher embeddings are in fp32 for maximum quality
378
  if teacher_embeddings.dtype != torch.float32:
379
- logger.info(f"Converting teacher embeddings from {teacher_embeddings.dtype} to fp32")
380
  teacher_embeddings = teacher_embeddings.to(torch.float32)
381
 
382
  logger.info(f"Generated {teacher_embeddings.shape[0]} teacher embeddings in {teacher_embeddings.dtype}")
383
 
384
- # Save embeddings cache using torch.save for future runs
385
  if checkpoint_manager:
386
  try:
387
- embeddings_path = Path(VOLUME_PATH) / "embeddings_cache.pt"
388
- config_path = Path(VOLUME_PATH) / "embeddings_config.json"
 
389
 
390
  # Save embeddings tensor
391
  torch.save(teacher_embeddings, embeddings_path)
392
 
393
  # Save configuration
394
  config_data = {
395
- "config_hash": get_current_config_hash(),
396
- "config": {
397
- "model_name": MODEL_NAME,
398
- "pca_dims": PCA_DIMS,
399
- "precision": TEACHER_MODEL_CONFIG["precision"],
400
- "torch_dtype": str(TEACHER_MODEL_CONFIG["torch_dtype"]),
401
- "max_samples": MAX_TRAINING_SAMPLES,
402
- "codesearchnet_dataset": CODESEARCHNET_DATASET,
403
- },
404
  "num_texts": len(texts),
405
  "embedding_shape": list(teacher_embeddings.shape),
406
  "timestamp": time.time(),
@@ -417,890 +614,953 @@ def generate_teacher_embeddings_with_checkpoints(
417
  return teacher_embeddings
418
 
419
 
420
- def refine_with_code_training(
421
  student_model: Any,
422
- training_texts: list[str],
423
- teacher_embeddings: torch.Tensor,
424
- epochs: int = 2,
425
  checkpoint_manager: BeamCheckpointManager | None = None,
426
- model_manager: BeamModelManager | None = None,
427
  ) -> Any:
428
- """Refine the student model with code-specific training."""
429
- logger.info(f"Starting code specialization training for {epochs} epochs...")
430
 
431
- # Validate input parameters
432
- if student_model is None:
433
- logger.error("student_model is None - cannot proceed with code training")
434
- msg = "student_model cannot be None"
435
- raise ValueError(msg)
436
 
437
- if not hasattr(student_model, "embedding"):
438
- logger.error(f"student_model of type {type(student_model)} does not have 'embedding' attribute")
439
- msg = f"student_model must have 'embedding' attribute, got {type(student_model)}"
440
- raise ValueError(msg)
441
 
442
- logger.info(f"Student model type: {type(student_model)}")
443
- logger.info(f"Student model embedding shape: {student_model.embedding.shape}")
444
 
445
- try:
446
- # Force fp32 precision throughout for maximum quality
447
- target_dtype = torch.float32
448
- logger.info("🎯 Enforcing fp32 precision throughout for maximum quality")
449
 
450
- # Detect student model dtype for logging purposes
451
- student_dtype = student_model.embedding.dtype
452
- logger.info(f"Student model original embedding dtype: {student_dtype}")
453
 
454
- # Force teacher embeddings to fp32 if not already
455
- if teacher_embeddings.dtype != target_dtype:
456
- logger.info(f"Converting teacher embeddings from {teacher_embeddings.dtype} to {target_dtype}")
457
- teacher_embeddings = teacher_embeddings.to(target_dtype)
 
458
 
459
- # Get dimensions
460
- student_embedding_dim = student_model.embedding.shape[1]
461
- teacher_embedding_dim = teacher_embeddings.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
 
463
- logger.info(f"Student dims: {student_embedding_dim}, Teacher dims: {teacher_embedding_dim}")
 
 
464
 
465
- # Project teacher embeddings if needed with high-precision PCA
466
- if teacher_embedding_dim != student_embedding_dim:
467
- from sklearn.decomposition import PCA
468
 
469
- logger.info("Performing high-precision PCA projection for quality preservation...")
470
- pca = PCA(n_components=student_embedding_dim)
471
 
472
- # Use float64 for PCA computation to maximize precision
473
- teacher_embeddings_np = teacher_embeddings.cpu().numpy().astype(np.float64)
474
- teacher_embeddings_projected = pca.fit_transform(teacher_embeddings_np)
475
 
476
- # Convert back to fp32 (always use fp32, never fp16)
477
- teacher_embeddings = torch.tensor(
478
- teacher_embeddings_projected.astype(np.float32),
479
- dtype=target_dtype,
480
- )
481
- logger.info(f"PCA projection completed: {teacher_embeddings.shape} with dtype {target_dtype}")
482
- logger.info(
483
- f"PCA preserved variance ratio: {pca.explained_variance_ratio_[:5].sum():.4f} (first 5 components)"
484
- )
485
 
486
- # Create trainable model
487
- trainable_model = FinetunableStaticModel.from_static_model(
488
- model=student_model,
489
- out_dim=student_embedding_dim,
490
- )
491
 
492
- # Force ALL model parameters to fp32 to ensure no precision loss
493
- trainable_model = trainable_model.float()
 
494
 
495
- # Additional explicit conversion of embedding weights to fp32
496
- if hasattr(trainable_model, "embeddings") and hasattr(trainable_model.embeddings, "weight"):
497
- trainable_model.embeddings.weight.data = trainable_model.embeddings.weight.data.to(target_dtype)
498
 
499
- # Verify final model dtype after model2vec patch fix
500
- actual_model_dtype = None
501
- for param in trainable_model.parameters():
502
- actual_model_dtype = param.dtype
503
- break
504
 
505
- logger.info(f"Model parameter dtype: {actual_model_dtype}")
506
- logger.info(f"Embedding weight dtype: {trainable_model.embeddings.weight.dtype}")
507
 
508
- # Ensure teacher embeddings are definitely in fp32
509
- teacher_embeddings = teacher_embeddings.to(target_dtype)
510
- logger.info(f"Final teacher embeddings dtype: {teacher_embeddings.dtype}")
511
- logger.info(f"Final model parameter dtype: {actual_model_dtype}")
 
512
 
513
- # Verify we're using fp32 throughout
514
- if teacher_embeddings.dtype != target_dtype:
515
- logger.warning(f"⚠️ Teacher embeddings not in {target_dtype}: {teacher_embeddings.dtype}")
516
- if actual_model_dtype != target_dtype:
517
- logger.warning(f"⚠️ Model parameters not in {target_dtype}: {actual_model_dtype}")
518
 
519
- logger.info("✅ Confirmed fp32 precision throughout the training pipeline")
 
 
 
 
520
 
521
- # Tokenize texts
522
- tokenized_texts = []
523
- for text in training_texts:
524
- tokens = trainable_model.tokenize([text])
525
- if tokens.shape[1] > 0:
526
- tokenized_texts.append(tokens[0].tolist())
527
 
528
- # Prepare training data with explicit fp32 casting
529
- targets = teacher_embeddings[: len(tokenized_texts)]
530
 
531
- # Force targets to fp32 to maintain maximum precision
532
- targets = targets.to(target_dtype)
533
- logger.info(f"Cast targets to fp32: {targets.dtype}")
534
 
535
- train_texts, val_texts, train_targets, val_targets = train_test_split(
536
- tokenized_texts, targets, test_size=0.2, random_state=42
537
- )
538
 
539
- logger.info(f"Train targets dtype: {train_targets.dtype}")
540
- logger.info(f"Val targets dtype: {val_targets.dtype}")
541
 
542
- # Training setup
543
- train_dataset = TextDataset(train_texts, train_targets)
544
- val_dataset = TextDataset(val_texts, val_targets)
 
 
 
 
 
 
 
 
 
 
 
 
545
 
546
- optimizer = optim.Adam(trainable_model.parameters(), lr=LEARNING_RATE)
547
- mse_loss = nn.MSELoss()
 
548
 
549
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
550
 
551
- try:
552
- trainable_model = trainable_model.to(device)
553
- logger.info(f"Training on {device}")
554
- except torch.cuda.OutOfMemoryError:
555
- logger.warning("GPU OOM loading training model, using CPU")
556
- device = torch.device("cpu")
557
- trainable_model = trainable_model.to(device)
558
- if torch.cuda.is_available():
559
- torch.cuda.empty_cache()
560
 
561
- # Adaptive batch size for training
562
- adaptive_batch_size = BATCH_SIZE
 
 
 
 
 
 
563
 
564
- # Quality monitoring: compute embedding similarity before training
565
- logger.info("🔍 Quality monitoring: Computing pre-training teacher-student similarity...")
566
- trainable_model.eval()
567
- with torch.no_grad():
568
- # Take a small sample of texts for quality measurement
569
- sample_texts = training_texts[: min(5, len(training_texts))]
570
- sample_tokens = trainable_model.tokenize(sample_texts)
571
- sample_tokens = sample_tokens.to(device)
572
-
573
- _, student_embeddings_before = trainable_model(sample_tokens)
574
- sample_teacher_embeddings = targets[: len(sample_texts)].to(device)
575
-
576
- # Compute average cosine similarity
577
- similarities_before = []
578
- for i in range(len(sample_texts)):
579
- sim = torch.cosine_similarity(
580
- student_embeddings_before[i].unsqueeze(0),
581
- sample_teacher_embeddings[i].unsqueeze(0),
582
- ).item()
583
- similarities_before.append(sim)
584
-
585
- avg_similarity_before = np.mean(similarities_before)
586
- logger.info(f"📊 Pre-training average teacher-student similarity: {avg_similarity_before:.4f}")
587
-
588
- # Training loop with validation
589
- for epoch in range(epochs):
590
- # Training phase
591
- trainable_model.train()
592
-
593
- # Try with current batch size, reduce if OOM
594
- train_successful = False
595
- while not train_successful and adaptive_batch_size >= 1:
596
- try:
597
- train_loader = train_dataset.to_dataloader(shuffle=True, batch_size=adaptive_batch_size)
598
-
599
- epoch_loss = 0.0
600
- num_batches = 0
601
-
602
- for batch_idx, (tokens, targets_batch) in enumerate(train_loader):
603
- batch_tokens = tokens.to(device)
604
- batch_targets = targets_batch.to(device)
605
-
606
- optimizer.zero_grad()
607
- _, student_embeddings = trainable_model(batch_tokens)
608
-
609
- # Debug dtype information on first batch
610
- if batch_idx == 0:
611
- logger.info(
612
- f"Batch {batch_idx}: tokens shape {batch_tokens.shape}, dtype {batch_tokens.dtype}"
613
- )
614
- logger.info(
615
- f"Batch {batch_idx}: targets shape {batch_targets.shape}, dtype {batch_targets.dtype}"
616
- )
617
- logger.info(
618
- f"Batch {batch_idx}: student_embeddings shape {student_embeddings.shape}, dtype {student_embeddings.dtype}"
619
- )
620
-
621
- # Force both tensors to fp32 to avoid any precision loss
622
- if student_embeddings.dtype != target_dtype:
623
- logger.warning(
624
- f"Student embeddings not in fp32: {student_embeddings.dtype}, converting to fp32"
625
- )
626
- student_embeddings = student_embeddings.to(target_dtype)
627
- if batch_targets.dtype != target_dtype:
628
- logger.info(f"Converting targets from {batch_targets.dtype} to fp32")
629
- batch_targets = batch_targets.to(target_dtype)
630
-
631
- try:
632
- loss = mse_loss(student_embeddings, batch_targets)
633
- loss.backward()
634
- optimizer.step()
635
- except RuntimeError as e:
636
- if "expected scalar type" in str(e):
637
- logger.exception("Dtype mismatch error occurred:")
638
- logger.exception(
639
- f"student_embeddings: {student_embeddings.shape}, {student_embeddings.dtype}"
640
- )
641
- logger.exception(f"batch_targets: {batch_targets.shape}, {batch_targets.dtype}")
642
- logger.exception(
643
- f"MSE loss input dtypes: {student_embeddings.dtype} vs {batch_targets.dtype}"
644
- )
645
- # Force explicit casting to fp32 for maximum precision
646
- batch_targets = batch_targets.to(target_dtype)
647
- student_embeddings = student_embeddings.to(target_dtype)
648
- logger.info("Emergency dtype fix: forced both to fp32")
649
- loss = mse_loss(student_embeddings, batch_targets)
650
- loss.backward()
651
- optimizer.step()
652
- else:
653
- raise
654
-
655
- epoch_loss += loss.item()
656
- num_batches += 1
657
-
658
- # Save training checkpoint periodically
659
- if checkpoint_manager and batch_idx % 100 == 0:
660
- training_state = {
661
- "epoch": epoch,
662
- "batch": batch_idx,
663
- "model_state": trainable_model.state_dict(),
664
- "optimizer_state": optimizer.state_dict(),
665
- "loss": epoch_loss / max(1, num_batches),
666
- }
667
- checkpoint_data = create_checkpoint_data("training", training_state, epoch)
668
- checkpoint_manager.save_checkpoint("training", checkpoint_data, epoch)
669
-
670
- train_successful = True
671
-
672
- except torch.cuda.OutOfMemoryError:
673
- logger.warning(
674
- f"Training OOM with batch size {adaptive_batch_size}, reducing to {adaptive_batch_size // 2}"
675
- )
676
- adaptive_batch_size = max(1, adaptive_batch_size // 2)
677
- if torch.cuda.is_available():
678
- torch.cuda.empty_cache()
679
 
680
- if not train_successful:
681
- logger.error("Unable to train even with batch size 1, skipping training")
682
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
683
 
684
- avg_train_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
685
 
686
- # Validation phase
687
- trainable_model.eval()
688
- val_loader = val_dataset.to_dataloader(shuffle=False, batch_size=adaptive_batch_size)
689
- val_loss = 0.0
690
- val_batches = 0
 
 
 
691
 
692
- with torch.no_grad():
693
- for tokens, targets_batch in val_loader:
694
- batch_tokens = tokens.to(device)
695
- batch_targets = targets_batch.to(device)
 
 
 
 
 
 
 
 
696
 
697
- _, student_embeddings = trainable_model(batch_tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
698
 
699
- # Force both tensors to fp32 to avoid any precision loss in validation
700
- if student_embeddings.dtype != target_dtype:
701
- student_embeddings = student_embeddings.to(target_dtype)
702
- if batch_targets.dtype != target_dtype:
703
- batch_targets = batch_targets.to(target_dtype)
 
 
704
 
705
- loss = mse_loss(student_embeddings, batch_targets)
706
- val_loss += loss.item()
707
- val_batches += 1
708
 
709
- avg_val_loss = val_loss / val_batches if val_batches > 0 else 0.0
 
 
710
 
711
- logger.info(
712
- f"Epoch {epoch + 1}/{epochs} - Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}, Batch Size: {adaptive_batch_size}"
713
- )
714
 
715
- # Save epoch checkpoint
716
- if checkpoint_manager:
717
- epoch_state = {
718
- "epoch": epoch + 1,
719
- "model_state": trainable_model.state_dict(),
720
- "optimizer_state": optimizer.state_dict(),
721
- "train_loss": avg_train_loss,
722
- "val_loss": avg_val_loss,
 
 
 
 
 
723
  }
724
- checkpoint_data = create_checkpoint_data("epoch", epoch_state, epoch + 1)
725
- checkpoint_manager.save_checkpoint("epoch", checkpoint_data, epoch + 1)
726
-
727
- # Quality monitoring: compute embedding similarity after training
728
- logger.info("🔍 Quality monitoring: Computing post-training teacher-student similarity...")
729
- trainable_model.eval()
730
- with torch.no_grad():
731
- # Use the same sample texts as before
732
- sample_texts = training_texts[: min(5, len(training_texts))]
733
- sample_tokens = trainable_model.tokenize(sample_texts)
734
- sample_tokens = sample_tokens.to(device)
735
-
736
- _, student_embeddings_after = trainable_model(sample_tokens)
737
- sample_teacher_embeddings = targets[: len(sample_texts)].to(device)
738
-
739
- # Compute average cosine similarity
740
- similarities_after = []
741
- for i in range(len(sample_texts)):
742
- sim = torch.cosine_similarity(
743
- student_embeddings_after[i].unsqueeze(0),
744
- sample_teacher_embeddings[i].unsqueeze(0),
745
- ).item()
746
- similarities_after.append(sim)
747
-
748
- avg_similarity_after = np.mean(similarities_after)
749
- logger.info(f"📊 Post-training average teacher-student similarity: {avg_similarity_after:.4f}")
750
-
751
- # Quality assessment
752
- quality_change = avg_similarity_after - avg_similarity_before
753
- logger.info(f"📈 Quality change: {quality_change:+.4f}")
754
-
755
- if abs(quality_change) < 0.01:
756
- logger.info("✅ Quality well preserved during training!")
757
- elif quality_change > 0:
758
- logger.info("✅ Quality improved during training!")
759
- else:
760
- logger.warning(f"⚠️ Quality degraded by {abs(quality_change):.4f} during training")
761
 
762
- # Convert back to static model
763
- refined_model = trainable_model.to_static_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
764
 
765
- # Save final refined model to beam volume
766
- if model_manager:
767
- # Save to temporary local directory first
768
- temp_refined_path = Path("./temp_refined_save")
769
- temp_refined_path.mkdir(exist_ok=True)
770
- refined_model.save_pretrained(str(temp_refined_path))
771
 
772
- # Upload to beam volume
773
- model_manager.save_model("refined_model", str(temp_refined_path))
 
774
 
775
- # Clean up temp directory
776
- import shutil
 
 
 
 
 
 
 
777
 
778
- shutil.rmtree(temp_refined_path, ignore_errors=True)
 
 
 
 
 
 
 
 
 
779
 
780
- logger.info("💾 Saved refined model to beam volume")
781
 
782
- logger.info("Code specialization training completed")
783
- return refined_model
 
 
 
 
 
 
 
784
 
785
  except Exception as e:
786
- logger.warning(f"Code training failed: {e}")
787
- return student_model
 
 
 
 
 
788
 
789
 
790
- def apply_regularization(model: Any, weight: float = 0.01) -> Any:
791
- """Apply light regularization with overflow protection."""
792
- # Validate input
793
- if model is None:
794
- logger.error("Cannot apply regularization: model is None")
795
- msg = "model cannot be None"
796
- raise ValueError(msg)
797
 
798
- if not hasattr(model, "embedding"):
799
- logger.error(f"Cannot apply regularization: model of type {type(model)} does not have 'embedding' attribute")
800
- msg = f"model must have 'embedding' attribute, got {type(model)}"
801
- raise ValueError(msg)
802
 
803
- logger.info(f"Applying regularization to model of type: {type(model)}")
 
 
 
 
 
 
 
804
 
805
- try:
806
- embeddings = model.embedding.copy()
807
 
808
- # Check for extreme values and clip if necessary
809
- max_val = np.abs(embeddings).max()
810
- if max_val > 1e6: # Clip extremely large values
811
- logger.warning(f"Large embedding values detected (max: {max_val:.2e}), clipping to prevent overflow")
812
- embeddings = np.clip(embeddings, -1e6, 1e6)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
813
 
814
- # Apply regularization
815
- regularized_embeddings = embeddings * (1.0 - weight)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
816
 
817
- # Stable normalization to prevent overflow
818
- norms = np.linalg.norm(regularized_embeddings, axis=1, keepdims=True)
 
 
 
819
 
820
- # Handle zero norms and potential overflow
821
- norms = np.where(norms == 0, 1, norms)
822
- norms = np.where(norms > 1e6, 1e6, norms) # Prevent extremely large norms
823
 
824
- regularized_embeddings = regularized_embeddings / norms
825
 
826
- # Create new model
827
- from model2vec.model import StaticModel
828
 
829
- regularized_model = StaticModel(
830
- vectors=regularized_embeddings,
831
- tokenizer=model.tokenizer,
832
- config=model.config,
833
- base_model_name=model.base_model_name,
834
- language=model.language,
835
- normalize=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
836
  )
837
 
838
- logger.info("Regularization applied successfully")
839
- return regularized_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
840
 
841
  except Exception as e:
842
- logger.warning(f"Regularization failed: {e}")
843
- return model
 
 
 
 
 
 
 
 
844
 
845
 
846
- def load_teacher_model_with_cache(
847
- model_name: str,
848
- output_dir: str,
849
- device: str = "cuda",
850
- resume: bool = True,
851
- ) -> SentenceTransformer:
852
- """Load teacher model with local caching to avoid re-downloading."""
853
- cache_dir = Path(output_dir) / "teacher_model_cache"
854
-
855
- # Check if cached model exists
856
- if resume and cache_dir.exists():
857
- try:
858
- logger.info(f"Loading cached teacher model from {cache_dir}")
859
- teacher_model = SentenceTransformer(str(cache_dir), device=device)
860
 
861
- # Set optimized sequence length
862
- max_seq_len = TEACHER_MODEL_CONFIG.get("max_seq_length", 8192)
863
- if isinstance(max_seq_len, int):
864
- teacher_model.max_seq_length = max_seq_len
865
 
866
- logger.info("Successfully loaded cached teacher model")
867
- return teacher_model
868
- except Exception as e:
869
- logger.warning(f"Failed to load cached teacher model: {e}")
870
- logger.info("Will download fresh model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
871
 
872
- # Download and cache the model
873
- logger.info(f"Downloading teacher model {model_name} (this may take a while)")
874
 
875
- # Prepare model kwargs with flash attention
876
- model_kwargs = {
877
- "torch_dtype": TEACHER_MODEL_CONFIG["torch_dtype"],
878
- "device_map": TEACHER_MODEL_CONFIG["device_map"],
 
 
 
 
 
 
 
879
  }
880
 
881
- # Try to add flash attention if available
882
- if TEACHER_MODEL_CONFIG.get("use_flash_attention", False):
883
- try:
884
- model_kwargs["attn_implementation"] = TEACHER_MODEL_CONFIG["attn_implementation"]
885
- logger.info("Flash Attention 2 enabled")
886
- except Exception as e:
887
- logger.warning(f"Flash Attention not available, using default attention: {e}")
888
 
889
- try:
890
- teacher_model = SentenceTransformer(
891
- model_name,
892
- device=device,
893
- trust_remote_code=bool(TEACHER_MODEL_CONFIG["trust_remote_code"]),
894
- model_kwargs=model_kwargs,
895
- )
896
- except ImportError as e:
897
- if "flash_attn" in str(e):
898
- logger.warning("Flash Attention 2 not available, falling back to default attention")
899
- # Remove flash attention from model_kwargs and retry
900
- model_kwargs_fallback = {k: v for k, v in model_kwargs.items() if k != "attn_implementation"}
901
- teacher_model = SentenceTransformer(
902
- model_name,
903
- device=device,
904
- trust_remote_code=bool(TEACHER_MODEL_CONFIG["trust_remote_code"]),
905
- model_kwargs=model_kwargs_fallback,
906
- )
907
- else:
908
- raise
909
 
910
- # Set optimized sequence length
911
- max_seq_len = TEACHER_MODEL_CONFIG.get("max_seq_length", 8192)
912
- if isinstance(max_seq_len, int):
913
- teacher_model.max_seq_length = max_seq_len
914
- logger.info(f"Set max_seq_length to {max_seq_len} for better performance")
915
 
916
- # Cache the model for future use
 
 
917
  try:
918
- cache_dir.mkdir(parents=True, exist_ok=True)
919
- teacher_model.save(str(cache_dir))
920
- logger.info(f"Cached teacher model to {cache_dir}")
921
  except Exception as e:
922
- logger.warning(f"Failed to cache teacher model: {e}")
923
- # Continue without caching
924
 
925
- return teacher_model
926
 
927
 
928
- def code_specialized_distillation(
929
- model_name: str = MODEL_NAME,
930
- output_dir: str = OUTPUT_DIR,
931
- pca_dims: int = PCA_DIMS,
932
- max_samples: int = MAX_TRAINING_SAMPLES,
933
- resume: bool = True,
934
  ) -> Any:
935
- """Main code-specialized distillation function using CodeSearchNet dataset with checkpoint support."""
936
  output_path = Path(output_dir)
937
  output_path.mkdir(parents=True, exist_ok=True)
938
 
939
- # Initialize Beam utilities
940
- volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(VOLUME_NAME, VOLUME_PATH)
941
-
942
- logger.info(f"Starting code-specialized distillation of {model_name}")
943
- logger.info(f"Using CodeSearchNet dataset: {CODESEARCHNET_DATASET}")
944
- logger.info(f"Resume mode: {resume}")
945
-
946
- # GPU Diagnostics
947
- logger.info("=== GPU DIAGNOSTICS ===")
948
- logger.info(f"CUDA available: {torch.cuda.is_available()}")
949
- if torch.cuda.is_available():
950
- logger.info(f"CUDA version: {torch.version.cuda}")
951
- logger.info(f"GPU count: {torch.cuda.device_count()}")
952
- for i in range(torch.cuda.device_count()):
953
- gpu_name = torch.cuda.get_device_name(i)
954
- gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
955
- logger.info(f"GPU {i}: {gpu_name} ({gpu_memory:.1f} GB)")
956
-
957
- # Current GPU memory
958
- current_device = torch.cuda.current_device()
959
- allocated = torch.cuda.memory_allocated(current_device) / 1024**3
960
- total = torch.cuda.get_device_properties(current_device).total_memory / 1024**3
961
- logger.info(f"Current GPU {current_device}: {allocated:.2f}GB allocated, {total:.1f}GB total")
962
- else:
963
- logger.warning("CUDA not available - will use CPU (much slower)")
964
- logger.info("======================")
965
 
966
  start_time = time.time()
967
 
968
- # Step 1: Basic Model2Vec distillation with checkpoint support
969
- logger.info("Step 1: Basic Model2Vec distillation...")
 
 
970
 
971
- # Check for existing distilled model in beam volume
972
- m2v_model = None
973
- if resume:
974
- # Check if model files exist directly in the volume root
975
- try:
976
- # Try to load from the volume root where the model was successfully saved
977
- volume_root_path = Path(VOLUME_PATH)
978
- if (volume_root_path / "config.json").exists() and (volume_root_path / "model.safetensors").exists():
979
- logger.info("✅ Found existing model files in volume root")
980
- from model2vec.model import StaticModel
981
 
982
- m2v_model = StaticModel.from_pretrained(str(volume_root_path))
983
- logger.info("✅ Successfully loaded existing distilled model from volume")
984
- else:
985
- logger.info("No existing model files found in volume root")
986
- except Exception as e:
987
- logger.warning(f"Failed to load existing model from volume: {e}")
988
- m2v_model = None
989
 
990
- if m2v_model is None:
991
- # Clear GPU cache before starting
992
- if torch.cuda.is_available():
993
- torch.cuda.empty_cache()
994
- current_device = torch.cuda.current_device()
995
- allocated = torch.cuda.memory_allocated(current_device) / 1024**3
996
- total = torch.cuda.get_device_properties(current_device).total_memory / 1024**3
997
- logger.info(f"GPU memory before distillation: {allocated:.2f}GB allocated / {total:.1f}GB total")
998
- else:
999
- logger.info("Using CPU for distillation")
1000
 
1001
- try:
1002
- m2v_model = distill(
1003
- model_name=model_name,
1004
- pca_dims=pca_dims,
1005
- apply_zipf=None,
1006
- sif_coefficient=1e-4,
1007
  trust_remote_code=True,
 
 
1008
  )
1009
- logger.info("Basic distillation completed with preserved precision")
1010
 
1011
- # Validate the distilled model
1012
- if m2v_model is None:
1013
- msg = "Distillation returned None - this should not happen"
1014
- raise ValueError(msg) from None
1015
-
1016
- logger.info(f"Distilled model type: {type(m2v_model)}")
1017
- logger.info(f"Distilled model has embedding attribute: {hasattr(m2v_model, 'embedding')}")
 
1018
 
1019
- # Save the base distilled model - DISABLED due to recursive directory bug
1020
- # model_mgr.save_model("base_distilled_model", str(output_path))
1021
 
1022
- except torch.cuda.OutOfMemoryError:
1023
- logger.warning("GPU OOM during distillation, clearing cache and retrying...")
1024
- torch.cuda.empty_cache()
1025
 
1026
- # Force CPU-only distillation if GPU fails
1027
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
1028
 
1029
- logger.info("Retrying distillation on CPU...")
1030
- m2v_model = distill(
1031
- model_name=model_name,
1032
- pca_dims=pca_dims,
1033
- apply_zipf=None,
1034
- sif_coefficient=1e-4,
1035
  trust_remote_code=True,
 
1036
  )
1037
- logger.info("Basic distillation completed on CPU")
1038
 
1039
- # Validate the distilled model
1040
- if m2v_model is None:
1041
- msg = "CPU distillation returned None - this should not happen"
1042
- raise ValueError(msg) from None
1043
 
1044
- logger.info(f"CPU distilled model type: {type(m2v_model)}")
1045
- logger.info(f"CPU distilled model has embedding attribute: {hasattr(m2v_model, 'embedding')}")
 
1046
 
1047
- # Save the base distilled model - DISABLED due to recursive directory bug
1048
- # model_mgr.save_model("base_distilled_model", str(output_path))
1049
 
1050
- except Exception:
1051
- logger.exception("Distillation failed with error")
1052
- raise
1053
 
1054
- # Validate m2v_model before proceeding
1055
- if m2v_model is None:
1056
- msg = "m2v_model is None after distillation step - cannot proceed"
1057
- raise ValueError(msg)
1058
 
1059
- # Step 2: Load CodeSearchNet training data with resume
1060
- logger.info("Step 2: Loading CodeSearchNet training data...")
1061
- code_texts = load_codesearchnet_dataset_with_resume(max_samples, checkpoint_mgr)
1062
 
1063
- if not code_texts:
1064
- logger.warning("No code training data available, skipping code specialization")
1065
- else:
1066
- logger.info("Step 3: Code specialization training...")
1067
-
1068
- # Check for existing refined model
1069
- if resume:
1070
- # Check if refined model exists in beam volume
1071
- models = model_mgr.list_models()
1072
- refined_model_exists = any(model["name"] == "refined_model" for model in models)
1073
-
1074
- if refined_model_exists:
1075
- # Download model to local path for loading
1076
- temp_model_path = Path("./temp_refined_model")
1077
- if model_mgr.load_model("refined_model", temp_model_path):
1078
- try:
1079
- from model2vec.model import StaticModel
1080
-
1081
- refined_model = StaticModel.from_pretrained(str(temp_model_path / "refined_model"))
1082
- logger.info("✅ Resumed from existing refined model")
1083
- m2v_model = refined_model
1084
- # Clean up temp directory
1085
- import shutil
1086
-
1087
- shutil.rmtree(temp_model_path, ignore_errors=True)
1088
- except Exception as e:
1089
- logger.warning(f"Failed to load existing refined model: {e}")
1090
- refined_model = None
1091
- # Clean up temp directory
1092
- import shutil
1093
-
1094
- shutil.rmtree(temp_model_path, ignore_errors=True)
1095
- else:
1096
- refined_model = None
1097
- else:
1098
- refined_model = None
1099
-
1100
- if refined_model is None:
1101
- # Load teacher model with memory management
1102
- try:
1103
- device = "cuda" if torch.cuda.is_available() else "cpu"
1104
- logger.info(f"Loading teacher model on {device} with optimized settings")
1105
- logger.info(
1106
- f"Using precision: {TEACHER_MODEL_CONFIG['precision']}, batch_size: {TEACHER_MODEL_CONFIG['batch_size']}"
1107
- )
1108
- logger.info("Attempting to enable Flash Attention 2 for maximum performance")
1109
 
1110
- teacher_model = load_teacher_model_with_cache(model_name, output_dir, device=device, resume=resume)
1111
 
1112
- # Generate teacher embeddings with checkpoints
1113
- teacher_embeddings = generate_teacher_embeddings_with_checkpoints(
1114
- teacher_model, code_texts, checkpoint_mgr
1115
- )
1116
 
1117
- # Refine with code training
1118
- m2v_model = refine_with_code_training(
1119
- m2v_model,
1120
- code_texts,
1121
- teacher_embeddings,
1122
- epochs=TRAINING_EPOCHS,
1123
- checkpoint_manager=checkpoint_mgr,
1124
- model_manager=model_mgr,
1125
- )
1126
 
1127
- del teacher_model
1128
- if torch.cuda.is_available():
1129
- torch.cuda.empty_cache()
1130
-
1131
- except torch.cuda.OutOfMemoryError:
1132
- logger.warning("GPU OOM during code training, falling back to CPU...")
1133
-
1134
- if torch.cuda.is_available():
1135
- torch.cuda.empty_cache()
1136
-
1137
- # Force CPU for teacher model with optimized settings (no flash attention on CPU)
1138
- try:
1139
- teacher_model = load_teacher_model_with_cache(
1140
- model_name, output_dir, device="cpu", resume=resume
1141
- )
1142
- except ImportError as e:
1143
- if "flash_attn" in str(e):
1144
- logger.warning("Flash Attention 2 not available on CPU, using default attention")
1145
- # Fallback without any special attention implementation
1146
- teacher_model = load_teacher_model_with_cache(
1147
- model_name, output_dir, device="cpu", resume=resume
1148
- )
1149
- else:
1150
- raise
1151
-
1152
- # Generate teacher embeddings on CPU with checkpoints
1153
- teacher_embeddings = generate_teacher_embeddings_with_checkpoints(
1154
- teacher_model, code_texts, checkpoint_mgr
1155
- )
1156
 
1157
- # Refine with code training on CPU
1158
- m2v_model = refine_with_code_training(
1159
- m2v_model,
1160
- code_texts,
1161
- teacher_embeddings,
1162
- epochs=TRAINING_EPOCHS,
1163
- checkpoint_manager=checkpoint_mgr,
1164
- model_manager=model_mgr,
1165
- )
1166
 
1167
- del teacher_model
1168
- else:
1169
- # Fresh training without resume
1170
- try:
1171
- device = "cuda" if torch.cuda.is_available() else "cpu"
1172
- logger.info(f"Loading teacher model on {device} with optimized settings")
1173
- logger.info(
1174
- f"Using precision: {TEACHER_MODEL_CONFIG['precision']}, batch_size: {TEACHER_MODEL_CONFIG['batch_size']}"
1175
- )
1176
- logger.info("Attempting to enable Flash Attention 2 for maximum performance")
1177
 
1178
- teacher_model = load_teacher_model_with_cache(model_name, output_dir, device=device, resume=resume)
1179
 
1180
- # Generate teacher embeddings with checkpoints
1181
- teacher_embeddings = generate_teacher_embeddings_with_checkpoints(
1182
- teacher_model, code_texts, checkpoint_mgr
1183
- )
1184
 
1185
- # Refine with code training
1186
- m2v_model = refine_with_code_training(
1187
- m2v_model,
1188
- code_texts,
1189
- teacher_embeddings,
1190
- epochs=TRAINING_EPOCHS,
1191
- checkpoint_manager=checkpoint_mgr,
1192
- model_manager=model_mgr,
1193
- )
1194
 
1195
- del teacher_model
1196
- if torch.cuda.is_available():
1197
- torch.cuda.empty_cache()
1198
 
1199
- except torch.cuda.OutOfMemoryError:
1200
- logger.warning("GPU OOM during code training, falling back to CPU...")
 
 
1201
 
1202
- if torch.cuda.is_available():
1203
- torch.cuda.empty_cache()
1204
 
1205
- # Force CPU for teacher model with optimized settings (no flash attention on CPU)
1206
- try:
1207
- teacher_model = load_teacher_model_with_cache(model_name, output_dir, device="cpu", resume=resume)
1208
- except ImportError as e:
1209
- if "flash_attn" in str(e):
1210
- logger.warning("Flash Attention 2 not available on CPU, using default attention")
1211
- # Fallback without any special attention implementation
1212
- teacher_model = load_teacher_model_with_cache(
1213
- model_name, output_dir, device="cpu", resume=resume
1214
- )
1215
- else:
1216
- raise
1217
-
1218
- # Generate teacher embeddings on CPU with checkpoints
1219
- teacher_embeddings = generate_teacher_embeddings_with_checkpoints(
1220
- teacher_model, code_texts, checkpoint_mgr
1221
- )
1222
 
1223
- # Refine with code training on CPU
1224
- m2v_model = refine_with_code_training(
1225
- m2v_model,
1226
- code_texts,
1227
- teacher_embeddings,
1228
- epochs=TRAINING_EPOCHS,
1229
- checkpoint_manager=checkpoint_mgr,
1230
- model_manager=model_mgr,
1231
- )
1232
 
1233
- del teacher_model
 
 
1234
 
1235
- # Step 4: Light regularization
1236
- logger.info("Step 4: Applying regularization...")
1237
- m2v_model = apply_regularization(m2v_model, REGULARIZATION_WEIGHT)
 
1238
 
1239
- # Save final model
1240
- logger.info("Saving code-specialized model...")
1241
 
1242
- # Final validation before saving
1243
- if m2v_model is None:
1244
- msg = "Cannot save model: m2v_model is None"
1245
- raise ValueError(msg)
1246
 
1247
- if not hasattr(m2v_model, "save_pretrained"):
1248
- msg = f"Cannot save model: m2v_model of type {type(m2v_model)} does not have save_pretrained method"
1249
- raise ValueError(msg)
1250
 
1251
- logger.info(f"Final model type: {type(m2v_model)}")
1252
- logger.info(f"Final model has embedding attribute: {hasattr(m2v_model, 'embedding')}")
1253
 
1254
- m2v_model.save_pretrained(str(output_path))
1255
 
1256
- # Save final model to beam volume as well - DISABLED due to recursive directory bug
1257
- # model_mgr.save_model("final_model", str(output_path))
 
1258
 
1259
- total_time = time.time() - start_time
1260
- logger.info(f"Code-specialized distillation completed in {total_time:.2f} seconds")
1261
 
1262
- return m2v_model
 
1263
1264
 
1265
- @function(
1266
- gpu=GPU_NAME,
1267
- volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
1268
- image=IMAGE,
1269
- secrets=["HF_ACCESS_TOKEN"],
1270
- env={
1271
- "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:512",
1272
- "TOKENIZERS_PARALLELISM": "false",
1273
- "CUDA_LAUNCH_BLOCKING": "0", # Allow async CUDA operations
1274
- "TORCH_CUDNN_V8_API_ENABLED": "1", # Enable optimized cuDNN
1275
- "OMP_NUM_THREADS": "8", # Limit CPU threads for better GPU utilization
1276
- },
1277
- timeout=3600 * 12, # 12 hours
1278
- )
1279
- def beam_code_distillation(
1280
- model_name: str = MODEL_NAME,
1281
- output_dir: str = OUTPUT_DIR,
1282
- pca_dims: int = PCA_DIMS,
1283
- max_samples: int = MAX_TRAINING_SAMPLES,
1284
- resume: bool = True,
1285
- ) -> Any:
1286
- # Apply all patches from the patches directory
1287
- try:
1288
- from .patch_utils import apply_all_patches
1289
 
1290
- logger.info("Applying all patches from patches directory...")
1291
- patches_applied = apply_all_patches()
1292
- logger.info(f"Successfully applied {patches_applied} patches")
1293
- except Exception as e:
1294
- logger.warning(f"Failed to apply patches: {e}. Continuing without patches.")
1295
-
1296
- return code_specialized_distillation(
1297
- model_name=model_name,
1298
- output_dir=output_dir,
1299
- pca_dims=pca_dims,
1300
- max_samples=max_samples,
1301
- resume=resume,
1302
- )
1303
 
1304
 
1305
  if __name__ == "__main__":
1306
- code_specialized_distillation()
 
1
  """
2
+ Unified Code-Specialized Model2Vec Distillation Script.
3
 
4
+ This script provides a unified approach for creating code-specialized embeddings
5
+ using Model2Vec distillation with optional code-specific training.
6
 
7
  Features:
8
+ - Basic distillation (default): Simple Model2Vec distillation
9
+ - Advanced training (--train flag): Additional CodeSearchNet fine-tuning
10
+ - Checkpoint support with Beam sync utilities
11
+ - Multi-teacher model processing
12
+ - Smart resume capabilities
13
+ - Hierarchical storage: base → final
14
+
15
+ Directory Structure:
16
+ - code_model2vec/base: Basic distilled models (first step)
17
+ - code_model2vec/final: Final models (copied from base or after training)
18
+
19
+ Usage:
20
+ distiller distill [--use-beam] [--train] # Basic distillation or with training
21
  """
22
 
23
  import json
24
  import logging
 
25
  import time
26
  from pathlib import Path
27
+ from typing import Annotated, Any
28
 
29
  import numpy as np
30
  import torch
31
+ import typer
32
+ from beam import Volume, function
33
  from datasets import load_dataset
34
  from model2vec.distill import distill
35
  from model2vec.train.base import FinetunableStaticModel, TextDataset
 
39
 
40
  from .beam_utils import (
41
  BeamCheckpointManager,
 
42
  create_beam_utilities,
43
+ download_model_from_beam,
44
+ sync_checkpoints_from_beam,
45
+ sync_checkpoints_to_beam,
46
+ upload_model_to_beam,
47
+ )
48
+ from .config import (
49
+ BEAM_ENV_SETTINGS,
50
+ GPU_NAME,
51
+ IMAGE,
52
+ codesearchnet_config,
53
+ directories,
54
+ distillation_config,
55
+ get_volume_config,
56
+ languages_config,
57
  )
58
 
59
  # =============================================================================
60
+ # CONFIGURATION
61
  # =============================================================================
62
 
63
+ VOLUME_CONFIG = get_volume_config()
64
+ LOCAL_BASE_DIR = directories.base
65
+ LOCAL_FINAL_DIR = directories.final
66
+
67
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
68
+ logger = logging.getLogger(__name__)
69
+
70
+ # Teacher models for distillation
71
+ DEFAULT_TEACHER_MODELS = list(distillation_config.code_teacher_models)
72
 
73
  # =============================================================================
74
+ # UTILITY FUNCTIONS
75
  # =============================================================================
76
77
 
78
+ def apply_local_patches() -> bool:
79
+ """Apply patches locally without requiring Beam utilities."""
80
+ try:
81
+ try:
82
+ from .patch_utils import apply_all_patches
83
 
84
+ patches_applied = apply_all_patches()
85
+ logger.info(f"Successfully applied {patches_applied} patches via patch_utils")
86
+ return True
87
+ except ImportError:
88
+ logger.warning("patch_utils not available, skipping patches")
89
 
90
+ return False
91
+
92
+ except Exception as e:
93
+ logger.warning(f"Failed to apply patches: {e}")
94
+ return False
95
+
96
+
97
+ def get_current_config_hash(enable_training: bool) -> str:
98
  """Generate a hash of current configuration parameters for checkpoint validation."""
99
  import hashlib
100
 
101
  config_params = {
102
+ "pca_dims": distillation_config.optimal_pca_dims,
103
+ "sif_coefficient": distillation_config.sif_coefficient,
104
+ "apply_zipf": distillation_config.apply_zipf,
105
+ "enable_training": enable_training,
 
 
106
  }
107
 
108
+ if enable_training:
109
+ config_params.update(
110
+ {
111
+ "training_epochs": distillation_config.training_epochs,
112
+ "learning_rate": distillation_config.learning_rate,
113
+ "max_samples": distillation_config.max_training_samples,
114
+ }
115
+ )
116
+
117
  config_str = str(sorted(config_params.items()))
118
  return hashlib.md5(config_str.encode()).hexdigest()[:12] # noqa: S324
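A minimal standalone sketch of how this hash is meant to gate checkpoint reuse; the parameter values shown are placeholders, not the actual `distillation_config` settings.

import hashlib

def config_hash(params: dict) -> str:
    # Same recipe as get_current_config_hash: deterministic ordering, short md5 digest.
    return hashlib.md5(str(sorted(params.items())).encode()).hexdigest()[:12]

base = {"pca_dims": 256, "sif_coefficient": 1e-4, "apply_zipf": True, "enable_training": False}
saved = config_hash(base)
current = config_hash({**base, "enable_training": True})
# A checkpoint is only reused when the stored hash matches the current one,
# so toggling --train invalidates checkpoints produced without training.
assert saved != current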
119
 
120
 
121
+ def check_existing_base_model(teacher_name: str) -> str | None:
122
+ """Check if base distilled model already exists locally."""
123
+ base_dir = Path(LOCAL_BASE_DIR)
124
+ model_dir = base_dir / f"code_model2vec_{teacher_name}"
125
+
126
+ if model_dir.exists():
127
+ # Check for essential model files
128
+ has_config = (model_dir / "config.json").exists()
129
+ has_model_file = any(
130
+ [
131
+ (model_dir / "model.safetensors").exists(),
132
+ (model_dir / "model.bin").exists(),
133
+ (model_dir / "pytorch_model.bin").exists(),
134
+ ]
135
+ )
136
 
137
+ if has_config and has_model_file:
138
+ logger.info(f"✅ Found existing base model: {teacher_name}")
139
+ return str(model_dir)
140
 
141
+ return None
142
+
143
+
144
+ def check_existing_final_model(teacher_name: str, enable_training: bool = False) -> str | None:
145
+ """Check if final model already exists locally."""
146
+ final_dir = Path(LOCAL_FINAL_DIR)
147
+
148
+ # Add suffix for trained models
149
+ model_name = f"code_model2vec_{teacher_name}"
150
+ if enable_training:
151
+ model_name += "_fine_tuned"
152
+ model_dir = final_dir / model_name
153
+
154
+ if model_dir.exists():
155
+ # Check for essential model files
156
+ has_config = (model_dir / "config.json").exists()
157
+ has_model_file = any(
158
+ [
159
+ (model_dir / "model.safetensors").exists(),
160
+ (model_dir / "model.bin").exists(),
161
+ (model_dir / "pytorch_model.bin").exists(),
162
+ ]
163
+ )
164
+
165
+ if has_config and has_model_file:
166
+ logger.info(f"✅ Found existing final model: {teacher_name}{'_fine_tuned' if enable_training else ''}")
167
+ return str(model_dir)
168
 
169
+ return None
170
+
171
+
172
+ def copy_base_to_final(teacher_name: str, enable_training: bool = False) -> bool:
173
+ """Copy base model to final directory."""
174
+ import shutil
175
+
176
+ base_path = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
177
+
178
+ # Add suffix for trained models
179
+ final_model_name = f"code_model2vec_{teacher_name}"
180
+ if enable_training:
181
+ final_model_name += "_fine_tuned"
182
+ final_path = Path(LOCAL_FINAL_DIR) / final_model_name
183
+
184
+ try:
185
+ final_path.parent.mkdir(parents=True, exist_ok=True)
186
+ if final_path.exists():
187
+ shutil.rmtree(final_path)
188
+ shutil.copytree(base_path, final_path)
189
+ logger.info(f"📁 Copied {teacher_name} from base to final{'_fine_tuned' if enable_training else ''}")
190
+ return True
191
+ except Exception:
192
+ logger.exception(f"❌ Failed to copy {teacher_name} to final{'_fine_tuned' if enable_training else ''}")
193
  return False
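For reference, a small illustration of the directory naming convention these helpers rely on; the root directories actually come from `directories.base` / `directories.final`, and the teacher id is only an example.

from pathlib import Path

teacher = "BAAI/bge-code-v1"                              # example teacher model id
teacher_name = teacher.split("/")[-1].replace("-", "_")   # -> "bge_code_v1"
base = Path("code_model2vec/base") / f"code_model2vec_{teacher_name}"
final = Path("code_model2vec/final") / f"code_model2vec_{teacher_name}_fine_tuned"  # suffix added only when --train is used
print(base)
print(final)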
194
 
 
 
195
 
196
+ def sync_model_from_beam(
197
+ teacher_name: str,
198
+ target_dir: str,
199
+ use_beam_utilities: bool = False,
200
+ ) -> bool:
201
+ """Sync model from Beam volume to local directory."""
202
+ if not use_beam_utilities:
203
  return False
204
 
205
+ try:
206
+ target_path = Path(target_dir)
207
+ target_path.mkdir(parents=True, exist_ok=True)
208
+
209
+ beam_model_name = f"{teacher_name}_model"
210
+ success = download_model_from_beam(VOLUME_CONFIG.name, beam_model_name, str(target_path))
211
+
212
+ if success:
213
+ logger.info(f"📥 Synced {teacher_name} from Beam to {target_dir}")
214
+ return True
215
+ logger.warning(f"⚠️ Failed to sync {teacher_name} from Beam")
216
  return False
217
 
218
+ except Exception as e:
219
+ logger.warning(f"Failed to sync {teacher_name} from Beam: {e}")
220
+ return False
221
+
222
+
223
+ def sync_model_to_beam(
224
+ teacher_name: str,
225
+ source_dir: str,
226
+ use_beam_utilities: bool = False,
227
+ ) -> bool:
228
+ """Sync model from local directory to Beam volume."""
229
+ if not use_beam_utilities:
230
+ return False
231
+
232
+ try:
233
+ beam_model_name = f"{teacher_name}_model"
234
+ success = upload_model_to_beam(VOLUME_CONFIG.name, beam_model_name, source_dir)
235
+
236
+ if success:
237
+ logger.info(f"📤 Synced {teacher_name} to Beam from {source_dir}")
238
+ return True
239
+ logger.warning(f"⚠️ Failed to sync {teacher_name} to Beam")
240
  return False
241
 
242
+ except Exception as e:
243
+ logger.warning(f"Failed to sync {teacher_name} to Beam: {e}")
244
+ return False
245
+
246
+
247
+ # =============================================================================
248
+ # DISTILLATION FUNCTIONS
249
+ # =============================================================================
250
 
251
 
252
+ def simple_distillation(
253
+ teacher_model: str,
254
+ output_dir: str,
255
+ pca_dims: int | None = None,
256
+ retry_with_cache_clear: bool = False,
257
+ ) -> Any:
258
  """
259
+ Perform simple Model2Vec distillation without additional training.
260
 
261
  Args:
262
+ teacher_model: Name of teacher model
263
+ output_dir: Output directory for the distilled model
264
+ pca_dims: PCA dimensions (uses config default if None)
265
+ retry_with_cache_clear: Whether this is a retry after clearing cache
266
 
267
  Returns:
268
+ Distilled model or None if failed
269
  """
270
+ if pca_dims is None:
271
+ pca_dims = int(distillation_config.optimal_pca_dims)
272
 
273
+ output_path = Path(output_dir)
274
+ output_path.mkdir(parents=True, exist_ok=True)
275
+
276
+ retry_suffix = " (retry after cache clear)" if retry_with_cache_clear else ""
277
+ logger.info(f"🔄 Simple distillation{retry_suffix}: {teacher_model} → {output_dir}")
278
+ logger.info(f"📊 PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}")
279
+
280
+ start_time = time.time()
281
+
282
+ try:
283
+ # Perform distillation with optimal parameters
284
+ model = distill(
285
+ model_name=teacher_model,
286
+ pca_dims=int(pca_dims),
287
+ apply_zipf=bool(distillation_config.apply_zipf),
288
+ sif_coefficient=float(distillation_config.sif_coefficient),
289
+ trust_remote_code=True,
290
+ )
291
+
292
+ logger.info("✅ Core distillation completed successfully")
293
+
294
+ # Save the model
295
+ model.save_pretrained(str(output_path))
296
+ logger.info(f"💾 Model saved to {output_path}")
297
+
298
+ # Log model info
299
+ logger.info(f"Model type: {type(model)}")
300
+ if hasattr(model, "embedding"):
301
+ logger.info(f"Embedding shape: {model.embedding.shape}")
302
+ logger.info(f"Embedding dtype: {model.embedding.dtype}")
303
 
304
+ total_time = time.time() - start_time
305
+ logger.info(f"🎉 Simple distillation completed in {total_time:.2f} seconds")
306
+ return model
307
+
308
+ except ValueError as e:
309
+ if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
310
+ logger.warning(f"⚠️ Token-vector mismatch with {teacher_model} - this is a Model2Vec library issue")
311
+ logger.warning(f"Error details: {e}")
312
+ logger.warning("💡 This model has incompatible tokenization. Skipping...")
313
+ return None
314
+ if "weight is on the meta device" in str(e):
315
+ logger.warning(f"⚠️ Device placement issue with {teacher_model} - model weights on meta device")
316
+ logger.warning(f"Error details: {e}")
317
+ logger.warning("💡 This model has device placement issues. Skipping...")
318
+ return None
319
+ raise
320
+ except AttributeError as e:
321
+ if "backend_tokenizer" in str(e):
322
+ logger.warning(f"⚠️ Tokenizer compatibility issue with {teacher_model}")
323
+ logger.warning(f"Error details: {e}")
324
+ logger.warning("💡 This model's tokenizer is incompatible with Model2Vec. Skipping...")
325
+ return None
326
+ raise
327
+ except FileNotFoundError as e:
328
+ if "transformers_modules" in str(e) or "xlm_padding.py" in str(e):
329
+ logger.warning(f"⚠️ Missing custom model files for {teacher_model}")
330
+ logger.warning(f"Error details: {e}")
331
+
332
+ # Try clearing cache and retrying once
333
+ if not retry_with_cache_clear:
334
+ logger.info("🔧 Attempting to clear cache and retry...")
335
+ if clear_model_cache(teacher_model):
336
+ logger.info("🔄 Retrying distillation after cache clear...")
337
+ return simple_distillation(teacher_model, output_dir, pca_dims, retry_with_cache_clear=True)
338
+
339
+ logger.warning("💡 This model has missing dependencies. Manual intervention may be required.")
340
+ return None
341
+ raise
342
+ except Exception:
343
+ logger.exception(f"❌ Simple distillation failed for {teacher_model}")
344
+ return None
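A hedged usage sketch for the helper above; it assumes the package is importable as `distiller.distill`, and the teacher shown is only an example rather than one of the configured defaults.

from distiller.distill import simple_distillation

model = simple_distillation(
    teacher_model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative teacher
    output_dir="code_model2vec/base/code_model2vec_all_MiniLM_L6_v2",
    pca_dims=256,  # illustrative; the config default is used when None
)
if model is not None:
    # The returned Model2Vec StaticModel can embed code/text directly.
    print(model.encode(["def add(a, b): return a + b"]).shape)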
345
+
346
+
347
+ def load_codesearchnet_dataset(
348
+ max_samples: int | None = None,
349
  checkpoint_manager: BeamCheckpointManager | None = None,
350
  ) -> list[str]:
351
+ """Load and format the CodeSearchNet dataset for training with balanced language distribution."""
352
+ if max_samples is None:
353
+ max_samples = int(distillation_config.max_training_samples)
354
+
355
+ logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
356
  logger.info(f"Limiting to {max_samples} samples for training efficiency")
357
+ logger.info(f"Languages: {', '.join(languages_config.all)}")
358
+
359
+ # Check for existing dataset checkpoint
360
+ texts = []
361
+ start_from = 0
362
 
 
363
  if checkpoint_manager:
364
  checkpoint_data = checkpoint_manager.load_checkpoint("dataset", 0)
365
  if checkpoint_data:
366
+ cached_texts = checkpoint_data.get("data", {}).get("texts", [])
367
+ if len(cached_texts) >= max_samples:
368
+ logger.info(f"✅ Resumed dataset loading: {len(cached_texts)} texts from checkpoint")
369
+ return cached_texts[:max_samples]
370
+ logger.info(f"📋 Partial dataset found: {len(cached_texts)} texts, continuing...")
371
+ texts = cached_texts
372
+ start_from = len(texts)
373
 
374
  try:
375
+ # Calculate samples per language for balanced distribution
376
+ num_languages = len(languages_config.all)
377
+ samples_per_language = max_samples // num_languages
378
+ remaining_samples = max_samples % num_languages
379
 
380
+ logger.info(f"📊 Target distribution: {samples_per_language} samples per language")
381
+ if remaining_samples > 0:
382
+ logger.info(f"📊 Extra {remaining_samples} samples will be distributed to first languages")
383
 
384
+ # Load training data from each language separately for balanced distribution
385
+ language_texts: dict[str, list[str]] = {}
386
+ total_collected = len(texts)
387
+
388
+ for i, language in enumerate(languages_config.all):
389
+ if total_collected >= max_samples:
390
  break
391
 
392
+ logger.info(f"🔍 Loading {language} training data...")
393
+
394
+ # Determine how many samples to collect for this language
395
+ target_for_lang = samples_per_language
396
+ if i < remaining_samples: # Distribute extra samples to first languages
397
+ target_for_lang += 1
398
+
399
+ # Skip if we already have enough from this language
400
+ if language in language_texts and len(language_texts[language]) >= target_for_lang:
401
+ continue
402
+
403
+ try:
404
+ # Load training split for the specific language (same format as evaluate.py)
405
+ dataset = load_dataset(
406
+ codesearchnet_config.dataset_name,
407
+ language,
408
+ split="train",
409
+ trust_remote_code=True,
410
+ )
411
+
412
+ lang_texts: list[str] = []
413
+ processed_count = 0
414
 
415
+ for processed_count, example in enumerate(dataset, 1):
416
+ if len(lang_texts) >= target_for_lang:
417
+ break
418
 
419
+ # Use same field names as evaluate.py
420
+ doc_string = example.get("func_documentation_string", "").strip()
421
+ code_string = example.get("func_code_string", "").strip()
422
 
423
+ if doc_string and code_string and len(doc_string.split()) >= 3 and len(code_string) > 50:
424
+ # Format as documentation-code pair for training (same as evaluate.py)
425
+ text = f"Documentation: {doc_string}\nCode:\n{code_string}"
 
 
426
 
427
+ # Ensure reasonable length for embedding models
428
+ if len(text) <= 2048:
429
+ lang_texts.append(text)
430
+
431
+ if processed_count % 5000 == 0:
432
+ logger.info(f" {language}: processed {processed_count}, collected {len(lang_texts)}")
433
+
434
+ language_texts[language] = lang_texts
435
+ total_collected += len(lang_texts)
436
+ logger.info(f"✅ {language}: collected {len(lang_texts)} samples")
437
+
438
+ except Exception as e:
439
+ logger.warning(f"⚠️ Failed to load {language} data: {e}")
440
+ continue
441
+
442
+ # Combine all language texts in a balanced way
443
+ combined_texts = []
444
+
445
+ # Add existing texts first (from checkpoint)
446
+ if start_from > 0:
447
+ combined_texts = texts[:start_from]
448
+
449
+ # Interleave texts from different languages for better training distribution
450
+ max_lang_samples = max(len(lang_texts) for lang_texts in language_texts.values()) if language_texts else 0
451
+
452
+ for sample_idx in range(max_lang_samples):
453
+ for language in languages_config.all:
454
+ if len(combined_texts) >= max_samples:
455
+ break
456
+
457
+ if language in language_texts and sample_idx < len(language_texts[language]):
458
+ combined_texts.append(language_texts[language][sample_idx])
459
+
460
+ if len(combined_texts) >= max_samples:
461
+ break
462
+
463
+ # Truncate to exact max_samples
464
+ combined_texts = combined_texts[:max_samples]
465
+
466
+ # Log final distribution
467
+ logger.info("📊 Final dataset distribution:")
468
+ lang_counts: dict[str, int] = {}
469
+ for text in combined_texts:
470
+ # Simple heuristic to identify language from code patterns
471
+ if "def " in text and ":" in text:
472
+ lang_counts["python"] = lang_counts.get("python", 0) + 1
473
+ elif "function " in text and "{" in text:
474
+ lang_counts["javascript"] = lang_counts.get("javascript", 0) + 1
475
+ elif "public " in text and "class " in text:
476
+ lang_counts["java"] = lang_counts.get("java", 0) + 1
477
+ elif "<?php" in text or "$" in text:
478
+ lang_counts["php"] = lang_counts.get("php", 0) + 1
479
+ elif "func " in text and "end" in text:
480
+ lang_counts["ruby"] = lang_counts.get("ruby", 0) + 1
481
+ elif "func " in text and "}" in text:
482
+ lang_counts["go"] = lang_counts.get("go", 0) + 1
483
+ else:
484
+ lang_counts["other"] = lang_counts.get("other", 0) + 1
485
+
486
+ for lang, count in lang_counts.items():
487
+ percentage = (count / len(combined_texts)) * 100
488
+ logger.info(f" {lang}: {count} samples ({percentage:.1f}%)")
489
 
490
  # Final checkpoint save
491
  if checkpoint_manager:
492
+ checkpoint_data = {
493
+ "config_hash": get_current_config_hash(enable_training=True),
494
+ "stage": "dataset",
495
+ "step": 0,
496
+ "timestamp": time.time(),
497
+ "data": {"texts": combined_texts},
498
+ }
499
  checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0)
500
 
501
+ logger.info(f"Successfully loaded {len(combined_texts)} balanced code-documentation pairs from CodeSearchNet")
502
+ return combined_texts
503
 
504
  except Exception:
505
  logger.exception("Error loading CodeSearchNet dataset")
506
  return texts # Return what we have so far
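A short sketch of the balanced-sampling arithmetic and the text format this loader produces; the language list and sample budget below are assumptions standing in for `languages_config.all` and the configured maximum.

languages = ["python", "javascript", "java", "php", "ruby", "go"]
max_samples = 50_000
per_language, remainder = divmod(max_samples, len(languages))
# The first `remainder` languages receive one extra sample each.
targets = {lang: per_language + (1 if i < remainder else 0) for i, lang in enumerate(languages)}
print(targets)  # {'python': 8334, 'javascript': 8334, 'java': 8333, 'php': 8333, 'ruby': 8333, 'go': 8333}

# Each kept example is formatted as a documentation/code pair:
doc = "Add two numbers and return the sum."
code = "def add(a, b):\n    return a + b"
print(f"Documentation: {doc}\nCode:\n{code}")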
507
 
508
 
509
+ def generate_teacher_embeddings(
510
  teacher_model: SentenceTransformer,
511
  texts: list[str],
512
  checkpoint_manager: BeamCheckpointManager | None = None,
 
514
  """Generate teacher embeddings for code training with checkpoint support."""
515
  logger.info(f"Generating teacher embeddings for {len(texts)} texts...")
516
 
517
+ # Check for existing embeddings checkpoint
 
 
518
  if checkpoint_manager:
519
+ volume_path = Path(VOLUME_CONFIG.mount_path)
520
+ embeddings_path = volume_path / "embeddings_cache.pt"
521
+ config_path = volume_path / "embeddings_config.json"
522
 
523
  if embeddings_path.exists() and config_path.exists():
524
  try:
 
526
  with config_path.open("r") as f:
527
  config_data = json.load(f)
528
 
529
+ current_hash = get_current_config_hash(enable_training=True)
530
+ if config_data.get("config_hash") == current_hash:
531
  # Load the embeddings tensor
532
  final_embeddings = torch.load(embeddings_path, map_location="cpu")
533
  num_expected = config_data.get("num_texts", len(texts))
534
 
535
  if final_embeddings.shape[0] >= num_expected:
536
+ logger.info(f"✅ Loaded embeddings from cache ({final_embeddings.shape[0]} embeddings)")
537
+ return final_embeddings[: len(texts)]
538
+
 
 
 
 
 
 
 
 
539
  except Exception as e:
540
  logger.warning(f"Failed to load embeddings cache: {e}, regenerating...")
541
 
542
  # Generate embeddings from scratch
543
  logger.info("Generating fresh teacher embeddings...")
544
 
545
+ batch_size = int(distillation_config.teacher_model_config.get("batch_size", 16))
546
  embeddings_list = []
547
 
548
+ for i in range(0, len(texts), batch_size):
549
+ batch_texts = texts[i : i + batch_size]
550
 
551
  try:
 
552
  batch_embeddings = teacher_model.encode(
553
  batch_texts,
554
  convert_to_tensor=True,
555
+ batch_size=batch_size,
556
+ show_progress_bar=False,
557
+ normalize_embeddings=True,
558
  )
559
  embeddings_list.append(batch_embeddings)
560
 
561
+ if i % (batch_size * 10) == 0:
562
  logger.info(f"Generated embeddings for {i + len(batch_texts)}/{len(texts)} texts")
563
 
564
  except torch.cuda.OutOfMemoryError:
565
+ logger.warning(f"GPU OOM with batch size {batch_size}, reducing...")
566
+ torch.cuda.empty_cache()
567
+ batch_size = max(1, batch_size // 2)
568
 
569
  # Retry with smaller batch size
 
570
  batch_embeddings = teacher_model.encode(
571
  batch_texts,
572
  convert_to_tensor=True,
573
+ batch_size=batch_size,
574
  show_progress_bar=False,
575
  normalize_embeddings=True,
576
  )
577
  embeddings_list.append(batch_embeddings)
578
 
579
+ # Combine all embeddings
 
 
580
  teacher_embeddings = torch.cat(embeddings_list, dim=0)
581
 
582
+ # Ensure fp32 precision
583
  if teacher_embeddings.dtype != torch.float32:
 
584
  teacher_embeddings = teacher_embeddings.to(torch.float32)
585
 
586
  logger.info(f"Generated {teacher_embeddings.shape[0]} teacher embeddings in {teacher_embeddings.dtype}")
587
 
588
+ # Save embeddings cache for future runs
589
  if checkpoint_manager:
590
  try:
591
+ volume_path = Path(VOLUME_CONFIG.mount_path)
592
+ embeddings_path = volume_path / "embeddings_cache.pt"
593
+ config_path = volume_path / "embeddings_config.json"
594
 
595
  # Save embeddings tensor
596
  torch.save(teacher_embeddings, embeddings_path)
597
 
598
  # Save configuration
599
  config_data = {
600
+ "config_hash": get_current_config_hash(enable_training=True),
601
  "num_texts": len(texts),
602
  "embedding_shape": list(teacher_embeddings.shape),
603
  "timestamp": time.time(),
 
614
  return teacher_embeddings
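The OOM handling above halves the batch size and retries; a generic standalone sketch of that pattern, not tied to this module's teacher model (any `encode_fn` that returns tensors will do).

import torch

def encode_with_backoff(encode_fn, texts: list[str], batch_size: int = 16) -> torch.Tensor:
    """Encode in batches, halving the batch size whenever CUDA runs out of memory."""
    chunks = []
    i = 0
    while i < len(texts):
        batch = texts[i : i + batch_size]
        try:
            chunks.append(encode_fn(batch))
            i += batch_size
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            batch_size = max(1, batch_size // 2)  # retry the same batch with a smaller size
    return torch.cat(chunks, dim=0)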
615
 
616
 
617
+ def advanced_training(
618
  student_model: Any,
619
+ teacher_model: SentenceTransformer,
 
 
620
  checkpoint_manager: BeamCheckpointManager | None = None,
 
621
  ) -> Any:
622
+ """Perform advanced code specialization training."""
623
+ logger.info("🎓 Starting advanced code specialization training...")
624
 
625
+ # Load CodeSearchNet training data
626
+ training_texts = load_codesearchnet_dataset(checkpoint_manager=checkpoint_manager)
 
 
 
627
 
628
+ if not training_texts:
629
+ logger.warning("No training data available, skipping advanced training")
630
+ return student_model
 
631
 
632
+ # Generate teacher embeddings
633
+ teacher_embeddings = generate_teacher_embeddings(teacher_model, training_texts, checkpoint_manager)
634
 
635
+ # Create trainable model
636
+ student_embedding_dim = student_model.embedding.shape[1]
637
+ teacher_embedding_dim = teacher_embeddings.shape[1]
 
638
 
639
+ # Project teacher embeddings if needed
640
+ if teacher_embedding_dim != student_embedding_dim:
641
+ from sklearn.decomposition import PCA
642
 
643
+ logger.info("Performing PCA projection for dimension matching...")
644
+ pca = PCA(n_components=student_embedding_dim)
645
+ teacher_embeddings_np = teacher_embeddings.cpu().numpy().astype(np.float64)
646
+ teacher_embeddings_projected = pca.fit_transform(teacher_embeddings_np)
647
+ teacher_embeddings = torch.tensor(teacher_embeddings_projected.astype(np.float32), dtype=torch.float32)
648
 
649
+ # Create trainable model
650
+ trainable_model = FinetunableStaticModel.from_static_model(
651
+ model=student_model,
652
+ out_dim=student_embedding_dim,
653
+ )
654
+ trainable_model = trainable_model.float()
655
+
656
+ # Tokenize texts
657
+ tokenized_texts = []
658
+ for text in training_texts:
659
+ tokens = trainable_model.tokenize([text])
660
+ if tokens.shape[1] > 0:
661
+ tokenized_texts.append(tokens[0].tolist())
662
+
663
+ # Prepare training data
664
+ targets = teacher_embeddings[: len(tokenized_texts)].to(torch.float32)
665
+ train_texts, val_texts, train_targets, val_targets = train_test_split(
666
+ tokenized_texts, targets, test_size=0.2, random_state=42
667
+ )
668
 
669
+ # Training setup
670
+ train_dataset = TextDataset(train_texts, train_targets)
671
+ val_dataset = TextDataset(val_texts, val_targets)
672
 
673
+ optimizer = optim.Adam(trainable_model.parameters(), lr=float(distillation_config.learning_rate))
674
+ mse_loss = nn.MSELoss()
 
675
 
676
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
677
+ trainable_model = trainable_model.to(device)
678
 
679
+ batch_size = int(distillation_config.batch_size)
680
+ epochs = int(distillation_config.training_epochs)
 
681
 
682
+ # Training loop
683
+ for epoch in range(epochs):
684
+ trainable_model.train()
 
 
 
 
 
 
685
 
686
+ try:
687
+ train_loader = train_dataset.to_dataloader(shuffle=True, batch_size=batch_size)
688
+ epoch_loss = 0.0
689
+ num_batches = 0
 
690
 
691
+ for _batch_idx, (tokens, targets_batch) in enumerate(train_loader):
692
+ batch_tokens = tokens.to(device)
693
+ batch_targets = targets_batch.to(device).to(torch.float32)
694
 
695
+ optimizer.zero_grad()
696
+ _, student_embeddings = trainable_model(batch_tokens)
697
+ student_embeddings = student_embeddings.to(torch.float32)
698
 
699
+ loss = mse_loss(student_embeddings, batch_targets)
700
+ loss.backward()
701
+ optimizer.step()
 
 
702
 
703
+ epoch_loss += loss.item()
704
+ num_batches += 1
705
 
706
+ except torch.cuda.OutOfMemoryError:
707
+ logger.warning(f"Training OOM with batch size {batch_size}, reducing...")
708
+ batch_size = max(1, batch_size // 2)
709
+ torch.cuda.empty_cache()
710
+ continue
711
 
712
+ avg_train_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
713
 
714
+ # Validation
715
+ trainable_model.eval()
716
+ val_loader = val_dataset.to_dataloader(shuffle=False, batch_size=batch_size)
717
+ val_loss = 0.0
718
+ val_batches = 0
719
 
720
+ with torch.no_grad():
721
+ for tokens, targets_batch in val_loader:
722
+ batch_tokens = tokens.to(device)
723
+ batch_targets = targets_batch.to(device).to(torch.float32)
 
 
724
 
725
+ _, student_embeddings = trainable_model(batch_tokens)
726
+ student_embeddings = student_embeddings.to(torch.float32)
727
 
728
+ loss = mse_loss(student_embeddings, batch_targets)
729
+ val_loss += loss.item()
730
+ val_batches += 1
731
 
732
+ avg_val_loss = val_loss / val_batches if val_batches > 0 else 0.0
 
 
733
 
734
+ logger.info(f"Epoch {epoch + 1}/{epochs} - Train: {avg_train_loss:.6f}, Val: {avg_val_loss:.6f}")
 
735
 
736
+ # Save checkpoint
737
+ if checkpoint_manager:
738
+ checkpoint_data = {
739
+ "config_hash": get_current_config_hash(enable_training=True),
740
+ "stage": "training",
741
+ "step": epoch + 1,
742
+ "timestamp": time.time(),
743
+ "data": {
744
+ "model_state": trainable_model.state_dict(),
745
+ "optimizer_state": optimizer.state_dict(),
746
+ "train_loss": avg_train_loss,
747
+ "val_loss": avg_val_loss,
748
+ },
749
+ }
750
+ checkpoint_manager.save_checkpoint("training", checkpoint_data, epoch + 1)
751
 
752
+ # Convert back to static model
753
+ refined_model = trainable_model.to_static_model()
754
+ logger.info("✅ Advanced training completed")
755
 
756
+ return refined_model
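A compact sketch of the two core steps in this training path, PCA projection for dimension matching followed by an MSE distillation loss; random tensors stand in for real teacher and student embeddings.

import numpy as np
import torch
from sklearn.decomposition import PCA

teacher = np.random.randn(512, 1024).astype(np.float64)  # stand-in teacher embeddings
student_dim = 256

# Match dimensions the same way advanced_training does: PCA-project the teacher vectors.
pca = PCA(n_components=student_dim)
targets = torch.tensor(pca.fit_transform(teacher).astype(np.float32))

student = torch.randn(512, student_dim, requires_grad=True)  # stand-in for student outputs
loss = torch.nn.functional.mse_loss(student, targets)
loss.backward()
print(f"MSE distillation loss: {loss.item():.4f}")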
757
758
 
759
+ def distill_single_teacher(
760
+ teacher_model: str,
761
+ enable_training: bool = False,
762
+ use_beam_utilities: bool = False,
763
+ pca_dims: int | None = None,
764
+ ) -> dict[str, Any]:
765
+ """
766
+ Distill a single teacher model with optional training.
767
 
768
+ Args:
769
+ teacher_model: Name of teacher model
770
+ enable_training: Whether to enable advanced training
771
+ use_beam_utilities: Whether to use Beam utilities
772
+ pca_dims: PCA dimensions
773
 
774
+ Returns:
775
+ Dictionary with distillation results
776
+ """
777
+ teacher_name = teacher_model.split("/")[-1].replace("-", "_")
778
+ base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
779
+
780
+ # Add suffix for trained models
781
+ final_model_name = f"code_model2vec_{teacher_name}"
782
+ if enable_training:
783
+ final_model_name += "_fine_tuned"
784
+ final_dir = Path(LOCAL_FINAL_DIR) / final_model_name
785
+
786
+ logger.info(f"\n{'=' * 60}")
787
+ logger.info(f"🔄 Processing teacher model: {teacher_model}")
788
+ logger.info(f"📁 Teacher name: {teacher_name}")
789
+ logger.info(f"🎓 Training enabled: {enable_training}")
790
+ logger.info(f"{'=' * 60}")
791
+
792
+ # Check model compatibility first
793
+ is_compatible, warning_msg = check_model_compatibility(teacher_model)
794
+ if not is_compatible:
795
+ logger.warning(f"⚠️ Known compatibility issue: {warning_msg}")
796
+ logger.info("🔧 Attempting distillation anyway, but may fail...")
797
+
798
+ # Try model-specific workarounds
799
+ workaround_type = try_model_workarounds(teacher_model)
800
+ # Don't skip if we have a workaround - we'll use it later
801
 
802
+ start_time = time.time()
803
 
804
+ # Initialize Beam utilities if requested
805
+ checkpoint_mgr = None
806
+ model_mgr = None
807
+ if use_beam_utilities:
808
+ try:
809
+ _, checkpoint_mgr, model_mgr, _ = create_beam_utilities(VOLUME_CONFIG.name, VOLUME_CONFIG.mount_path)
810
+ except Exception as e:
811
+ logger.warning(f"Failed to initialize Beam utilities: {e}")
812
 
813
+ try:
814
+ # Step 1: Check for existing final model
815
+ existing_final = check_existing_final_model(teacher_name, enable_training)
816
+ if existing_final:
817
+ logger.info(f"✅ Final model already exists: {teacher_name}{'_fine_tuned' if enable_training else ''}")
818
+ return {
819
+ "teacher_model": teacher_model,
820
+ "teacher_name": teacher_name,
821
+ "status": "skipped_existing_final",
822
+ "final_path": existing_final,
823
+ "distillation_time": 0.0,
824
+ }
825
 
826
+ # Step 1.5: Sync existing checkpoints from Beam if using Beam utilities
827
+ if use_beam_utilities and checkpoint_mgr:
828
+ logger.info(f"🔄 Syncing existing checkpoints for {teacher_name}...")
829
+ sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints)
830
+ if enable_training:
831
+ sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints)
832
+
833
+ # Step 2: Check for existing base model or create it
834
+ existing_base = check_existing_base_model(teacher_name)
835
+ base_model = None
836
+
837
+ if existing_base:
838
+ logger.info(f"✅ Found existing base model: {teacher_name}")
839
+ if enable_training:
840
+ # Load base model for training
841
+ from model2vec.model import StaticModel
842
 
843
+ base_model = StaticModel.from_pretrained(existing_base)
844
+ elif use_beam_utilities:
845
+ synced = sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities)
846
+ if synced:
847
+ existing_base = str(base_dir)
848
+ if enable_training:
849
+ from model2vec.model import StaticModel
850
 
851
+ base_model = StaticModel.from_pretrained(existing_base)
 
 
852
 
853
+ if not existing_base:
854
+ # Perform simple distillation to create base model
855
+ logger.info(f"🔄 Creating base model for {teacher_name}")
856
 
857
+ # Check if we need specialized distillation
858
+ workaround_type = try_model_workarounds(teacher_model)
 
859
 
860
+ if workaround_type == "salesforce":
861
+ base_model = salesforce_model_distillation(teacher_model, str(base_dir), pca_dims)
862
+ elif workaround_type == "baai":
863
+ base_model = baai_bge_model_distillation(teacher_model, str(base_dir), pca_dims)
864
+ else:
865
+ base_model = simple_distillation(teacher_model, str(base_dir), pca_dims)
866
+
867
+ if base_model is None:
868
+ return {
869
+ "teacher_model": teacher_model,
870
+ "teacher_name": teacher_name,
871
+ "status": "failed_base_distillation",
872
+ "error": "Simple distillation failed",
873
  }
874
 
875
+ # Sync base model and checkpoints to Beam
876
+ if use_beam_utilities:
877
+ sync_model_to_beam(teacher_name, str(base_dir), use_beam_utilities)
878
+ if checkpoint_mgr:
879
+ sync_checkpoints_to_beam(
880
+ VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints
881
+ )
882
+
883
+ existing_base = str(base_dir)
884
+
885
+ # Step 3: Handle final model creation
886
+ if enable_training and base_model is not None:
887
+ # Perform advanced training
888
+ logger.info(f"🎓 Starting advanced training for {teacher_name}")
889
+
890
+ # Load teacher model for training
891
+ device = "cuda" if torch.cuda.is_available() else "cpu"
892
+ teacher_st_model = SentenceTransformer(teacher_model, device=device, trust_remote_code=True)
893
 
894
+ # Perform advanced training
895
+ final_model = advanced_training(base_model, teacher_st_model, checkpoint_mgr)
896
 
897
+ # Save final model
898
+ final_dir.mkdir(parents=True, exist_ok=True)
899
+ final_model.save_pretrained(str(final_dir))
900
 
901
+ # Sync final model and training checkpoints to Beam
902
+ if use_beam_utilities:
903
+ sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities)
904
+ if checkpoint_mgr:
905
+ sync_checkpoints_to_beam(VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints)
906
+
907
+ del teacher_st_model
908
+ if torch.cuda.is_available():
909
+ torch.cuda.empty_cache()
910
 
911
+ else:
912
+ # Copy base to final (no training)
913
+ logger.info(f"📁 Copying base to final for {teacher_name}")
914
+ if not copy_base_to_final(teacher_name, enable_training):
915
+ return {
916
+ "teacher_model": teacher_model,
917
+ "teacher_name": teacher_name,
918
+ "status": "failed_copy_to_final",
919
+ "error": "Failed to copy base to final",
920
+ }
921
 
922
+ total_time = time.time() - start_time
923
 
924
+ return {
925
+ "teacher_model": teacher_model,
926
+ "teacher_name": teacher_name,
927
+ "status": "success",
928
+ "enable_training": enable_training,
929
+ "base_path": existing_base,
930
+ "final_path": str(final_dir),
931
+ "distillation_time": total_time,
932
+ }
933
 
934
  except Exception as e:
935
+ logger.exception(f"❌ Failed to process {teacher_model}")
936
+ return {
937
+ "teacher_model": teacher_model,
938
+ "teacher_name": teacher_name,
939
+ "status": "failed",
940
+ "error": str(e),
941
+ }
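The per-teacher result dictionaries above use a small set of status strings; a sketch of how a caller might tally them (the grouping helper itself is illustrative).

from collections import Counter

def summarize(results: list[dict]) -> Counter:
    # "success" and every "skipped_*" status count as usable models.
    return Counter(
        "usable" if r["status"] == "success" or r["status"].startswith("skipped") else "failed"
        for r in results
    )

print(summarize([
    {"status": "success"},
    {"status": "skipped_existing_final"},
    {"status": "failed_base_distillation"},
]))  # Counter({'usable': 2, 'failed': 1})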
942
 
943
 
944
+ # =============================================================================
945
+ # MAIN EXECUTION FUNCTIONS
946
+ # =============================================================================
947
948
 
949
+ def run_local_distillation(
950
+ teacher_models: list[str] | None = None,
951
+ enable_training: bool = False,
952
+ pca_dims: int | None = None,
953
+ clear_cache: bool = False,
954
+ ) -> dict[str, Any]:
955
+ """Run distillation locally."""
956
+ logger.info("🖥️ Running distillation locally")
957
 
958
+ if teacher_models is None:
959
+ teacher_models = DEFAULT_TEACHER_MODELS
960
 
961
+ # Apply patches
962
+ patch_success = apply_local_patches()
963
+ if patch_success:
964
+ logger.info("✅ Successfully applied patches")
965
+ else:
966
+ logger.warning("⚠️ Failed to apply patches - some models may fail")
967
+
968
+ results = {}
969
+ successful_models = []
970
+
971
+ logger.info("🚀 Starting distillation workflow")
972
+ logger.info(f"📊 Processing {len(teacher_models)} teacher models")
973
+ logger.info(f"🎓 Training enabled: {enable_training}")
974
+
975
+ # Use default models if none specified
976
+ models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS
977
+
978
+ logger.info(f"📊 Teacher models to process: {len(models_to_distill)}")
979
+ for i, model in enumerate(models_to_distill, 1):
980
+ logger.info(f" {i}. {model}")
981
+
982
+ # Clear cache for problematic models if requested
983
+ if clear_cache:
984
+ logger.info("🧹 Clearing cache for known problematic models...")
985
+ problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"]
986
+ for model in problematic_models:
987
+ if model in models_to_distill:
988
+ clear_model_cache(model)
989
+
990
+ for teacher_model in models_to_distill:
991
+ result = distill_single_teacher(
992
+ teacher_model=teacher_model,
993
+ enable_training=enable_training,
994
+ use_beam_utilities=False,
995
+ pca_dims=pca_dims,
996
+ )
997
 
998
+ teacher_name = result["teacher_name"]
999
+ results[teacher_name] = result
1000
+
1001
+ if result["status"] == "success" or result["status"].startswith("skipped"):
1002
+ successful_models.append(teacher_name)
1003
+
1004
+ # Summary
1005
+ logger.info("\n🏆 DISTILLATION WORKFLOW COMPLETE!")
1006
+ logger.info(f"📊 Successful models: {len(successful_models)}")
1007
+ logger.info(f"🎓 Training mode: {'Enabled' if enable_training else 'Basic distillation only'}")
1008
+
1009
+ for model_name in successful_models:
1010
+ result = results[model_name]
1011
+ logger.info(f"✅ {model_name}: {result['teacher_model']}")
1012
+
1013
+ # Save results summary
1014
+ results_summary = {
1015
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
1016
+ "enable_training": enable_training,
1017
+ "successful_models": successful_models,
1018
+ "all_results": results,
1019
+ "total_successful": len(successful_models),
1020
+ "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
1021
+ }
1022
 
1023
+ # Save results to file
1024
+ results_file = Path(LOCAL_BASE_DIR).parent / "distillation_results.json"
1025
+ results_file.parent.mkdir(parents=True, exist_ok=True)
1026
+ with results_file.open("w") as f:
1027
+ json.dump(results_summary, f, indent=2)
1028
 
1029
+ logger.info(f"📊 Results summary saved to: {results_file}")
 
 
1030
 
1031
+ return results_summary
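A hedged example of driving this workflow programmatically instead of through the CLI; it assumes the package is installed, and the teacher list and PCA override are illustrative.

from distiller.distill import run_local_distillation

summary = run_local_distillation(
    teacher_models=["sentence-transformers/all-MiniLM-L6-v2"],  # omit to use the configured defaults
    enable_training=False,  # basic distillation only
    pca_dims=256,           # illustrative override
)
print(f"{summary['total_successful']} of {summary['total_attempted']} teachers distilled")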
1032
 
 
 
1033
 
1034
+ @function(
1035
+ gpu=GPU_NAME,
1036
+ volumes=[Volume(name=VOLUME_CONFIG.name, mount_path=VOLUME_CONFIG.mount_path)],
1037
+ image=IMAGE,
1038
+ secrets=["HF_ACCESS_TOKEN"],
1039
+ env=BEAM_ENV_SETTINGS,
1040
+ timeout=3600 * 12, # 12 hours
1041
+ )
1042
+ def _beam_distill_models(
1043
+ teacher_models: list[str] | None = None,
1044
+ enable_training: bool = False,
1045
+ pca_dims: int | None = None,
1046
+ clear_cache: bool = False,
1047
+ ) -> dict[str, Any]:
1048
+ """Internal Beam function for distillation."""
1049
+ logger.info("☁️ Running distillation on Beam")
1050
+
1051
+ # Apply patches
1052
+ patch_success = apply_local_patches()
1053
+ if patch_success:
1054
+ logger.info("✅ Successfully applied patches")
1055
+ else:
1056
+ logger.warning("⚠️ Failed to apply patches - some models may fail")
1057
+
1058
+ if teacher_models is None:
1059
+ teacher_models = DEFAULT_TEACHER_MODELS
1060
+
1061
+ # Clear cache for problematic models if requested
1062
+ if clear_cache:
1063
+ logger.info("🧹 Clearing cache for known problematic models...")
1064
+ problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"]
1065
+ for model in problematic_models:
1066
+ if model in teacher_models:
1067
+ clear_model_cache(model)
1068
+
1069
+ results = {}
1070
+ successful_models = []
1071
+
1072
+ logger.info("🚀 Starting Beam distillation workflow")
1073
+ logger.info(f"📊 Processing {len(teacher_models)} teacher models")
1074
+ logger.info(f"🎓 Training enabled: {enable_training}")
1075
+
1076
+ # Use default models if none specified
1077
+ models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS
1078
+
1079
+ logger.info(f"📊 Teacher models to process: {len(models_to_distill)}")
1080
+ for i, model in enumerate(models_to_distill, 1):
1081
+ logger.info(f" {i}. {model}")
1082
+
1083
+ for teacher_model in models_to_distill:
1084
+ result = distill_single_teacher(
1085
+ teacher_model=teacher_model,
1086
+ enable_training=enable_training,
1087
+ use_beam_utilities=True,
1088
+ pca_dims=pca_dims,
1089
  )
1090
 
1091
+ teacher_name = result["teacher_name"]
1092
+ results[teacher_name] = result
1093
+
1094
+ if result["status"] == "success" or result["status"].startswith("skipped"):
1095
+ successful_models.append(teacher_name)
1096
+
1097
+ # Summary
1098
+ logger.info("\n🏆 BEAM DISTILLATION WORKFLOW COMPLETE!")
1099
+ logger.info(f"📊 Successful models: {len(successful_models)}")
1100
+
1101
+ # Save results to Beam volume
1102
+ volume_path = Path(VOLUME_CONFIG.mount_path)
1103
+ results_summary = {
1104
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
1105
+ "enable_training": enable_training,
1106
+ "successful_models": successful_models,
1107
+ "all_results": results,
1108
+ "total_successful": len(successful_models),
1109
+ "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
1110
+ }
1111
+
1112
+ results_file = volume_path / "distillation_results.json"
1113
+ with results_file.open("w") as f:
1114
+ json.dump(results_summary, f, indent=2)
1115
+
1116
+ logger.info(f"📊 Beam results saved to: {results_file}")
1117
+
1118
+ return results_summary
1119
+
1120
+
1121
+ def run_beam_distillation(
1122
+ teacher_models: list[str] | None = None,
1123
+ enable_training: bool = False,
1124
+ pca_dims: int | None = None,
1125
+ clear_cache: bool = False,
1126
+ ) -> dict[str, Any]:
1127
+ """Run distillation on Beam and sync results."""
1128
+ logger.info("☁️ Running distillation on Beam with local sync")
1129
+
1130
+ try:
1131
+ # Run distillation on Beam
1132
+ results = _beam_distill_models.remote(teacher_models, enable_training, pca_dims, clear_cache)
1133
+
1134
+ # Check if Beam execution was successful
1135
+ if not results:
1136
+ logger.error("❌ Beam execution failed or returned no results")
1137
+ return {
1138
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
1139
+ "enable_training": enable_training,
1140
+ "successful_models": [],
1141
+ "all_results": {},
1142
+ "total_successful": 0,
1143
+ "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
1144
+ "error": "Beam execution failed",
1145
+ }
1146
+
1147
+ # Sync models back to local directories
1148
+ if results.get("successful_models"):
1149
+ logger.info("📥 Syncing models from Beam to local directories...")
1150
+
1151
+ for teacher_name in results["successful_models"]:
1152
+ # Sync base model
1153
+ base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
1154
+ sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities=True)
1155
+
1156
+ # Sync final model if training was enabled
1157
+ if enable_training:
1158
+ final_dir = Path(LOCAL_FINAL_DIR) / f"code_model2vec_{teacher_name}"
1159
+ sync_model_from_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities=True)
1160
+ else:
1161
+ # Copy base to final
1162
+ copy_base_to_final(teacher_name, enable_training)
1163
+
1164
+ logger.info("✅ All models synced from Beam")
1165
+
1166
+ return results
1167
 
1168
  except Exception as e:
1169
+ logger.exception("❌ Beam distillation failed with exception")
1170
+ return {
1171
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
1172
+ "enable_training": enable_training,
1173
+ "successful_models": [],
1174
+ "all_results": {},
1175
+ "total_successful": 0,
1176
+ "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
1177
+ "error": str(e),
1178
+ }
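For orientation, a minimal sketch of driving these two entry points programmatically instead of through the CLI below; it assumes the package is importable as distiller, and the teacher name is purely illustrative, not part of this commit:

# Hypothetical usage sketch (not part of the diff).
from distiller import run_beam_distillation, run_local_distillation

summary = run_local_distillation(
    teacher_models=["sentence-transformers/all-MiniLM-L6-v2"],
    enable_training=False,  # basic Model2Vec distillation, no CodeSearchNet fine-tuning
)
print(f"{summary['total_successful']}/{summary['total_attempted']} models distilled")

# Same workflow on Beam, with models synced back to LOCAL_BASE_DIR / LOCAL_FINAL_DIR:
# summary = run_beam_distillation(enable_training=True, clear_cache=True)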
1179
 
1180
 
1181
+ # =============================================================================
1182
+ # CLI INTERFACE
1183
+ # =============================================================================
 
 
 
1184
 
 
 
 
 
1185
 
1186
+ def main(
1187
+ use_beam: Annotated[bool, typer.Option(help="Use Beam for distillation")] = False,
1188
+ train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False,
1189
+ teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None,
1190
+ pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None,
1191
+ clear_cache: Annotated[
1192
+ bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation")
1193
+ ] = False,
1194
+ ) -> None:
1195
+ """Unified distillation command with optional training."""
1196
+ logger.info("🚀 Starting unified Model2Vec distillation workflow")
1197
+ logger.info(f"🎓 Training mode: {'Advanced (CodeSearchNet fine-tuning)' if train else 'Basic distillation only'}")
1198
+ logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")
1199
+
1200
+ # Use default models if none specified
1201
+ models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS
1202
+
1203
+ logger.info(f"📊 Teacher models to process: {len(models_to_distill)}")
1204
+ for i, model in enumerate(models_to_distill, 1):
1205
+ logger.info(f" {i}. {model}")
1206
+
1207
+ # Clear cache for problematic models if requested
1208
+ if clear_cache:
1209
+ logger.info("🧹 Clearing cache for known problematic models...")
1210
+ problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"]
1211
+ for model in problematic_models:
1212
+ if model in models_to_distill:
1213
+ clear_model_cache(model)
1214
+
1215
+ # Run distillation workflow
1216
+ if use_beam:
1217
+ results = run_beam_distillation(
1218
+ teacher_models=models_to_distill,
1219
+ enable_training=train,
1220
+ pca_dims=pca_dims,
1221
+ clear_cache=clear_cache,
1222
+ )
1223
+ else:
1224
+ results = run_local_distillation(
1225
+ teacher_models=models_to_distill,
1226
+ enable_training=train,
1227
+ pca_dims=pca_dims,
1228
+ clear_cache=clear_cache,
1229
+ )
1230
+
1231
+ # Handle case where results might be None or invalid
1232
+ if not results or not isinstance(results, dict):
1233
+ logger.error("❌ Distillation workflow failed - no valid results returned")
1234
+ results = {
1235
+ "total_successful": 0,
1236
+ "total_attempted": len(models_to_distill),
1237
+ "error": "Workflow failed",
1238
+ }
1239
+
1240
+ # Final summary
1241
+ successful_count = results.get("total_successful", 0)
1242
+ total_attempted = results.get("total_attempted", 0)
1243
+
1244
+ logger.info("\n🎉 UNIFIED DISTILLATION WORKFLOW COMPLETED!")
1245
+ logger.info(f"📊 Successfully processed: {successful_count}/{total_attempted} models")
1246
+ logger.info(f"📁 Base models saved to: {LOCAL_BASE_DIR}")
1247
+ logger.info(f"📁 Final models saved to: {LOCAL_FINAL_DIR}")
1248
+
1249
+ if train:
1250
+ logger.info("🎓 Advanced training was enabled - models include CodeSearchNet specialization")
1251
+ else:
1252
+ logger.info("📖 Basic distillation only - use --train flag to enable advanced training")
1253
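A hedged sketch of invoking this Typer command from a shell, written as comments because the console-script wiring is outside this hunk; the module path is an assumption and the flag names are the ones Typer derives from the parameters above:

# Hypothetical CLI invocations (module path assumed to be distiller.distill):
#   python -m distiller.distill                      # local, basic distillation of the default teachers
#   python -m distiller.distill --use-beam --train   # run on Beam with CodeSearchNet fine-tuning
#   python -m distiller.distill --pca-dims 256 --teacher-models sentence-transformers/all-MiniLM-L6-v2
#   python -m distiller.distill --clear-cache        # purge caches of known-problematic models first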
 
 
 
1254
 
1255
+ def check_model_compatibility(teacher_model: str) -> tuple[bool, str | None]:
1256
+ """
1257
+ Check if a model has known compatibility issues with Model2Vec.
1258
+
1259
+ Returns:
1260
+ Tuple of (is_compatible, warning_message)
1261
+ """
1262
+ known_incompatible = {
1263
+ "BAAI/bge-code-v1": "Qwen2Tokenizer lacks backend_tokenizer attribute",
1264
+ "jinaai/jina-embeddings-v3": "Missing custom transformers module dependencies",
1265
+ "Salesforce/SFR-Embedding-Code-2B_R": "Device placement issues with meta tensors",
1266
  }
1267
 
1268
+ if teacher_model in known_incompatible:
1269
+ return False, known_incompatible[teacher_model]
 
 
 
 
 
1270
 
1271
+ # Check for model families that might have issues
1272
+ if "qwen2" in teacher_model.lower() and "bge" in teacher_model.lower():
1273
+ return False, "BGE models with Qwen2 tokenizers may have compatibility issues"
1274
+
1275
+ if "jina" in teacher_model.lower() and "embeddings-v3" in teacher_model.lower():
1276
+ return False, "Jina embeddings v3 models may have missing dependencies"
1277
+
1278
+ if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower():
1279
+ return False, "Salesforce SFR embedding models may have device placement issues"
 
 
 
1280
 
1281
+ return True, None
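A minimal sketch of how a caller might gate distillation on this check; the real call site (distill_single_teacher) is not shown in this hunk, so the wiring below is illustrative only:

# Illustrative only -- assumes a teacher_model string is in scope.
compatible, warning = check_model_compatibility(teacher_model)
if not compatible:
    logger.warning(f"⚠️ {teacher_model}: {warning} - trying a specialized workaround instead")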
 
 
 
 
1282
 
1283
+
1284
+ def clear_model_cache(model_name: str) -> bool:
1285
+ """Clear HuggingFace cache for a specific model."""
1286
  try:
1287
+ import shutil
1288
+ from pathlib import Path
1289
+
1290
+ # Get HuggingFace cache directory
1291
+ cache_dir = Path.home() / ".cache" / "huggingface"
1292
+
1293
+ # Find model-specific cache directories
1294
+ model_slug = model_name.replace("/", "--")
1295
+
1296
+ # Clear transformers cache
1297
+ transformers_cache = cache_dir / "transformers" / model_slug
1298
+ if transformers_cache.exists():
1299
+ shutil.rmtree(transformers_cache)
1300
+ logger.info(f"🗑️ Cleared transformers cache for {model_name}")
1301
+
1302
+ # Clear hub cache
1303
+ hub_cache = cache_dir / "hub" / f"models--{model_slug}"
1304
+ if hub_cache.exists():
1305
+ shutil.rmtree(hub_cache)
1306
+ logger.info(f"🗑️ Cleared hub cache for {model_name}")
1307
+
1308
+ # Clear modules cache
1309
+ modules_cache = cache_dir / "modules" / "transformers_modules" / model_name.split("/")[0]
1310
+ if modules_cache.exists():
1311
+ shutil.rmtree(modules_cache)
1312
+ logger.info(f"🗑️ Cleared modules cache for {model_name}")
1313
+
1314
+ return True
1315
+
1316
  except Exception as e:
1317
+ logger.warning(f"Failed to clear cache for {model_name}: {e}")
1318
+ return False
1319
+
1320
+
1321
+ def try_model_workarounds(teacher_model: str) -> str | None:
1322
+ """
1323
+ Try specific workarounds for problematic models.
1324
+
1325
+ Returns:
1326
+ The type of workaround needed ("salesforce", "baai", etc.) or None if no workaround available
1327
+ """
1328
+ if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower():
1329
+ logger.info("🔧 Salesforce SFR model detected - will use specialized distillation")
1330
+ return "salesforce"
1331
+
1332
+ if "baai" in teacher_model.lower() and ("bge-code" in teacher_model.lower() or "bge-m3" in teacher_model.lower()):
1333
+ logger.info("🔧 BAAI BGE model detected - will use specialized distillation")
1334
+ return "baai"
1335
 
1336
+ return None
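For context, a hedged sketch of dispatching on the returned workaround type to the specialized paths defined below; the actual dispatch lives in distill_single_teacher, which is not part of this hunk, so the names in scope are assumed:

# Illustrative dispatch -- assumes teacher_model, output_dir and pca_dims are in scope.
from model2vec.distill import distill  # standard Model2Vec path

workaround = try_model_workarounds(teacher_model)
if workaround == "salesforce":
    model = salesforce_model_distillation(teacher_model, output_dir, pca_dims)
elif workaround == "baai":
    model = baai_bge_model_distillation(teacher_model, output_dir, pca_dims)
else:
    model = distill(model_name=teacher_model, pca_dims=pca_dims, trust_remote_code=True)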
1337
 
1338
 
1339
+ def salesforce_model_distillation(
1340
+ teacher_model: str,
1341
+ output_dir: str,
1342
+ pca_dims: int | None = None,
 
 
1343
  ) -> Any:
1344
+ """Special distillation function for Salesforce SFR models that handles device placement issues."""
1345
+ if pca_dims is None:
1346
+ pca_dims = int(distillation_config.optimal_pca_dims)
1347
+
1348
  output_path = Path(output_dir)
1349
  output_path.mkdir(parents=True, exist_ok=True)
1350
 
1351
+ logger.info(f"🔄 Salesforce-specific distillation: {teacher_model} → {output_dir}")
1352
+ logger.info(f"📊 PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}")
 
 
 
1353
 
1354
  start_time = time.time()
1355
 
1356
+ try:
1357
+ import torch
1358
+ from sentence_transformers import SentenceTransformer
1359
+ from transformers import AutoModel, AutoTokenizer
1360
 
1361
+ # Enhanced custom model loading for Salesforce models
1362
+ logger.info("🔧 Loading model with enhanced device settings...")
 
 
 
 
 
 
 
 
1363
 
1364
+ # Method 1: Try with to_empty() for meta tensor handling
1365
+ try:
1366
+ logger.info("🔄 Attempting with to_empty() method...")
 
 
 
 
1367
 
1368
+ # Load tokenizer first
1369
+ tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True)
 
 
 
 
 
 
 
 
1370
 
1371
+ # Load model with meta device initially
1372
+ model = AutoModel.from_pretrained(
1373
+ teacher_model,
 
 
 
1374
  trust_remote_code=True,
1375
+ torch_dtype=torch.float16,
1376
+ device_map="meta", # Load on meta device first
1377
  )
 
1378
 
1379
+ # Move from meta to actual device using to_empty()
1380
+ if torch.cuda.is_available():
1381
+ device = torch.device("cuda")
1382
+ # Create empty tensors on target device and copy weights
1383
+ model = model.to_empty(device=device)
1384
+ else:
1385
+ device = torch.device("cpu")
1386
+ model = model.to_empty(device=device)
1387
 
1388
+ # Ensure model is in the right dtype
1389
+ model = model.to(torch.float16 if torch.cuda.is_available() else torch.float32)
1390
 
1391
+ logger.info("✅ Successfully loaded with to_empty() method")
 
 
1392
 
1393
+ except Exception as e:
1394
+ logger.warning(f"to_empty() method failed: {e}")
1395
 
1396
+ # Method 2: Try SentenceTransformer with specific settings
1397
+ logger.info("🔄 Falling back to SentenceTransformer method...")
1398
+ sentence_model = SentenceTransformer(
1399
+ teacher_model,
 
 
1400
  trust_remote_code=True,
1401
+ device="cpu", # Force CPU loading first
1402
  )
 
1403
 
1404
+ # Move to GPU if available
1405
+ if torch.cuda.is_available():
1406
+ sentence_model = sentence_model.to("cuda")
 
1407
 
1408
+ # Extract components
1409
+ model = sentence_model[0].auto_model
1410
+ tokenizer = sentence_model.tokenizer
1411
 
1412
+ logger.info("✅ Successfully loaded with SentenceTransformer method")
 
1413
 
1414
+ # Now use Model2Vec's distill_from_model function directly
1415
+ from model2vec.distill.distillation import distill_from_model
 
1416
 
1417
+ distilled_model = distill_from_model(
1418
+ model=model,
1419
+ tokenizer=tokenizer,
1420
+ pca_dims=int(pca_dims),
1421
+ apply_zipf=bool(distillation_config.apply_zipf),
1422
+ sif_coefficient=float(distillation_config.sif_coefficient),
1423
+ )
1424
 
1425
+ logger.info("✅ Core distillation completed successfully")
 
 
1426
 
1427
+ # Save the model
1428
+ distilled_model.save_pretrained(str(output_path))
1429
+ logger.info(f"💾 Model saved to {output_path}")
 
 
 
1430
 
1431
+ # Log model info
1432
+ logger.info(f"Model type: {type(distilled_model)}")
1433
+ if hasattr(distilled_model, "embedding"):
1434
+ logger.info(f"Embedding shape: {distilled_model.embedding.shape}")
1435
+ logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}")
1436
 
1437
+ total_time = time.time() - start_time
1438
+ logger.info(f"🎉 Salesforce distillation completed in {total_time:.2f} seconds")
 
 
1439
 
1440
+ # Clean up
1441
+ if "sentence_model" in locals():
1442
+ del sentence_model
1443
+ del model
1444
+ if torch.cuda.is_available():
1445
+ torch.cuda.empty_cache()
 
 
 
1446
 
1447
+ return distilled_model
 
 
 
1448
 
1449
+ except Exception:
1450
+ logger.exception(f"❌ Salesforce-specific distillation failed for {teacher_model}")
1451
+ return None
 
 
 
 
 
 
1452
 
 
 
1453
 
1454
+ def baai_bge_model_distillation(
1455
+ teacher_model: str,
1456
+ output_dir: str,
1457
+ pca_dims: int | None = None,
1458
+ ) -> Any:
1459
+ """Special distillation function for BAAI BGE models that handles Qwen2Tokenizer compatibility issues."""
1460
+ if pca_dims is None:
1461
+ pca_dims = int(distillation_config.optimal_pca_dims)
1462
 
1463
+ output_path = Path(output_dir)
1464
+ output_path.mkdir(parents=True, exist_ok=True)
 
 
1465
 
1466
+ logger.info(f"🔄 BAAI BGE-specific distillation: {teacher_model} → {output_dir}")
1467
+ logger.info(f"📊 PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}")
 
 
 
 
 
 
 
1468
 
1469
+ start_time = time.time()
 
 
1470
 
1471
+ try:
1472
+ import torch
1473
+ from sentence_transformers import SentenceTransformer
1474
+ from transformers import AutoModel, AutoTokenizer
1475
 
1476
+ logger.info("🔧 Loading BAAI model with tokenizer workaround...")
 
1477
 
1478
+ # Try multiple approaches for BAAI models
1479
+ success = False
 
 
 
1480
 
1481
+ # Method 1: Try SentenceTransformer first (often handles tokenizer issues better)
1482
+ try:
1483
+ logger.info("🔄 Attempting with SentenceTransformer wrapper...")
1484
+ sentence_model = SentenceTransformer(teacher_model, trust_remote_code=True)
 
 
 
 
 
1485
 
1486
+ # Extract components
1487
+ model = sentence_model[0].auto_model
1488
+ tokenizer = sentence_model.tokenizer
1489
 
1490
+ # Test if tokenizer works by encoding a simple text
1491
+ test_encoding = tokenizer.encode("test", return_tensors="pt")
1492
+ logger.info("✅ SentenceTransformer method successful")
1493
+ success = True
1494
 
1495
+ except Exception as e:
1496
+ logger.warning(f"SentenceTransformer method failed: {e}")
1497
 
1498
+ # Method 2: Try direct loading with tokenizer replacement
1499
+ try:
1500
+ logger.info("🔄 Attempting with tokenizer replacement...")
1501
+ from transformers import BertTokenizerFast
1502
 
1503
+ # Load model directly
1504
+ model = AutoModel.from_pretrained(teacher_model, trust_remote_code=True)
 
1505
 
1506
+ # Try to use a compatible tokenizer instead
1507
+ try:
1508
+ # First try the original tokenizer
1509
+ tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True)
1510
+ except Exception:
1511
+ # Fall back to BERT tokenizer for BGE models
1512
+ logger.info("🔄 Falling back to BERT tokenizer...")
1513
+ tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
1514
+
1515
+ logger.info("✅ Tokenizer replacement method successful")
1516
+ success = True
1517
+
1518
+ except Exception as e2:
1519
+ logger.warning(f"Tokenizer replacement method failed: {e2}")
1520
+
1521
+ if not success:
1522
+ logger.error("❌ All BAAI model loading methods failed")
1523
+ return None
1524
+
1525
+ # Now use Model2Vec's distill_from_model function directly
1526
+ from model2vec.distill.distillation import distill_from_model
1527
+
1528
+ distilled_model = distill_from_model(
1529
+ model=model,
1530
+ tokenizer=tokenizer,
1531
+ pca_dims=int(pca_dims),
1532
+ apply_zipf=bool(distillation_config.apply_zipf),
1533
+ sif_coefficient=float(distillation_config.sif_coefficient),
1534
+ )
1535
 
1536
+ logger.info("✅ Core distillation completed successfully")
1537
 
1538
+ # Save the model
1539
+ distilled_model.save_pretrained(str(output_path))
1540
+ logger.info(f"💾 Model saved to {output_path}")
1541
 
1542
+ # Log model info
1543
+ logger.info(f"Model type: {type(distilled_model)}")
1544
+ if hasattr(distilled_model, "embedding"):
1545
+ logger.info(f"Embedding shape: {distilled_model.embedding.shape}")
1546
+ logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}")
1547
 
1548
+ total_time = time.time() - start_time
1549
+ logger.info(f"🎉 BAAI BGE distillation completed in {total_time:.2f} seconds")
1550
 
1551
+ # Clean up
1552
+ if "sentence_model" in locals():
1553
+ del sentence_model
1554
+ del model
1555
+ if torch.cuda.is_available():
1556
+ torch.cuda.empty_cache()
1557
 
1558
+ return distilled_model
 
 
 
1559
 
1560
+ except Exception:
1561
+ logger.exception(f"❌ BAAI BGE-specific distillation failed for {teacher_model}")
1562
+ return None
 
 
 
1563
 
1564
 
1565
  if __name__ == "__main__":
1566
+ typer.run(main)
src/distiller/distill_simplified.py DELETED
@@ -1,413 +0,0 @@
1
- """
2
- Simplified Code-Specialized Model2Vec Distillation Script.
3
-
4
- This script implements a focused, simplified approach for creating code-specialized embeddings
5
- using only the core Model2Vec distillation without additional fine-tuning that may degrade quality.
6
-
7
- Can run locally or on Beam with the --use-beam flag.
8
- """
9
-
10
- import argparse
11
- import json
12
- import logging
13
- import sys
14
- import time
15
- from pathlib import Path
16
- from typing import Any
17
-
18
- from beam import GpuType, Image, Volume, function
19
- from model2vec.distill import distill
20
-
21
- # =============================================================================
22
- # SIMPLIFIED CONFIGURATION
23
- # =============================================================================
24
-
25
- # Use a code-specialized teacher model instead of general instruction model
26
- # Ordered by success likelihood and performance:
27
- CODE_TEACHER_MODELS = [
28
- "sentence-transformers/all-MiniLM-L6-v2",
29
- "sentence-transformers/all-mpnet-base-v2",
30
- "microsoft/codebert-base",
31
- "microsoft/graphcodebert-base",
32
- "sentence-transformers/paraphrase-MiniLM-L6-v2",
33
- "Alibaba-NLP/gte-Qwen2-7B-instruct",
34
- ]
35
-
36
- OUTPUT_BASE_DIR = "code_model2vec"
37
-
38
- # Optimal Model2Vec parameters based on successful models
39
- OPTIMAL_PCA_DIMS = 256 # Match other successful Model2Vec models
40
- SIF_COEFFICIENT = 1e-3 # Slightly higher than default for code specialization
41
- APPLY_ZIPF = True # Enable Zipf weighting for better word importance
42
-
43
- # =============================================================================
44
- # BEAM CONFIGURATION
45
- # =============================================================================
46
-
47
- GPU_NAME = GpuType.A100_40
48
- VOLUME_NAME = "code_model2vec"
49
- VOLUME_PATH = "./code_model2vec"
50
- IMAGE = Image(python_version="python3.12").add_python_packages(
51
- [
52
- "torch>=2.7.0", # Install torch first
53
- "transformers>=4.40.0", # Latest transformers with flash attention support
54
- "lightning>=2.5.1.post0",
55
- "model2vec[train]>=0.5.0",
56
- "numpy>=1.26.4",
57
- "scikit-learn>=1.6.1",
58
- "sentence-transformers>=4.1.0",
59
- "datasets>=3.2.0", # For evaluation
60
- "pandas>=2.0.0",
61
- "tqdm>=4.65.0",
62
- ]
63
- )
64
-
65
- # =============================================================================
66
-
67
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
68
- logger = logging.getLogger(__name__)
69
-
70
- # Add beam utilities for proper model persistence
71
- try:
72
- from .beam_utils import (
73
- create_beam_utilities,
74
- )
75
-
76
- BEAM_UTILS_AVAILABLE = True
77
- except ImportError:
78
- print("Beam utilities not available - models will only be saved locally")
79
- BEAM_UTILS_AVAILABLE = False
80
-
81
-
82
- def apply_local_patches() -> bool:
83
- """Apply patches locally without requiring Beam utilities."""
84
- try:
85
- # Try using patch_utils if available
86
- try:
87
- from .patch_utils import apply_all_patches
88
-
89
- patches_applied = apply_all_patches()
90
- logger.info(f"Successfully applied {patches_applied} patches via patch_utils")
91
- return True
92
- except ImportError:
93
- logger.warning("patch_utils not available, trying direct patching")
94
-
95
- return False
96
-
97
- except Exception as e:
98
- logger.warning(f"Failed to apply patches: {e}")
99
- return False
100
-
101
-
102
- def simplified_code_distillation(
103
- teacher_model: str,
104
- output_dir: str,
105
- pca_dims: int = OPTIMAL_PCA_DIMS,
106
- ) -> Any:
107
- """
108
- Simplified code-specialized distillation using only core Model2Vec.
109
-
110
- This approach:
111
- 1. Uses a teacher model that already performs well on code tasks
112
- 2. Applies optimal Model2Vec parameters
113
- 3. Avoids additional training that may degrade quality
114
- """
115
- output_path = Path(output_dir)
116
- output_path.mkdir(parents=True, exist_ok=True)
117
-
118
- logger.info(f"Starting simplified distillation from {teacher_model}")
119
- logger.info(f"Target dimensions: {pca_dims}")
120
- logger.info(f"SIF coefficient: {SIF_COEFFICIENT}")
121
- logger.info(f"Zipf weighting: {APPLY_ZIPF}")
122
-
123
- start_time = time.time()
124
-
125
- try:
126
- # Perform distillation with optimal parameters
127
- model = distill(
128
- model_name=teacher_model,
129
- pca_dims=pca_dims,
130
- apply_zipf=APPLY_ZIPF,
131
- sif_coefficient=SIF_COEFFICIENT,
132
- trust_remote_code=True,
133
- )
134
-
135
- logger.info("✅ Core distillation completed successfully")
136
-
137
- # Save the model
138
- model.save_pretrained(str(output_path))
139
- logger.info(f"💾 Model saved to {output_path}")
140
-
141
- # Log model info
142
- logger.info(f"Model type: {type(model)}")
143
- if hasattr(model, "embedding"):
144
- logger.info(f"Embedding shape: {model.embedding.shape}")
145
- logger.info(f"Embedding dtype: {model.embedding.dtype}")
146
-
147
- total_time = time.time() - start_time
148
- logger.info(f"🎉 Simplified distillation completed in {total_time:.2f} seconds")
149
- return model
150
-
151
- except ValueError as e:
152
- if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
153
- logger.warning(f"⚠️ Token-vector mismatch with {teacher_model} - this is a Model2Vec library issue")
154
- logger.warning(f"Error details: {e}")
155
- logger.warning("💡 This model has incompatible tokenization. Skipping...")
156
- return None
157
- raise
158
- except Exception:
159
- logger.exception("❌ Distillation failed")
160
- return None
161
-
162
-
163
- def core_distill_all_teachers(use_beam_utilities: bool = False) -> dict[str, Any]:
164
- """
165
- Core logic for distilling all teacher models.
166
-
167
- Args:
168
- use_beam_utilities: Whether to use Beam utilities for persistence
169
-
170
- Returns:
171
- Dictionary with distillation results
172
- """
173
- # Apply patches
174
- logger.info("Applying all patches...")
175
- patch_success = apply_local_patches()
176
- if patch_success:
177
- logger.info("Successfully applied patches")
178
- else:
179
- logger.warning("Failed to apply patches - Microsoft models may fail")
180
-
181
- # Initialize Beam utilities if requested and available
182
- volume_mgr = None
183
- model_mgr = None
184
- if use_beam_utilities and BEAM_UTILS_AVAILABLE:
185
- try:
186
- volume_mgr, _, model_mgr, _ = create_beam_utilities(VOLUME_NAME, VOLUME_PATH)
187
- logger.info("✅ Beam utilities initialized for model persistence")
188
- except Exception as e:
189
- logger.warning(f"Failed to initialize Beam utilities: {e}")
190
- model_mgr = None
191
-
192
- results = {}
193
- successful_models = []
194
-
195
- logger.info("🚀 Starting comprehensive teacher model distillation")
196
- logger.info(f"📊 Processing {len(CODE_TEACHER_MODELS)} teacher models")
197
-
198
- # Determine output base path
199
- base_output_path = VOLUME_PATH if use_beam_utilities else OUTPUT_BASE_DIR
200
-
201
- for teacher_model in CODE_TEACHER_MODELS:
202
- try:
203
- # Create output directory name based on teacher model
204
- teacher_name = teacher_model.split("/")[-1].replace("-", "_")
205
- output_dir = f"{base_output_path}/final/code_model2vec_{teacher_name}"
206
-
207
- logger.info(f"\n{'=' * 60}")
208
- logger.info(f"🔄 Processing teacher model: {teacher_model}")
209
- logger.info(f"📁 Output directory: {output_dir}")
210
- logger.info(f"{'=' * 60}")
211
-
212
- # Check if model already exists
213
- output_path = Path(output_dir)
214
- if output_path.exists():
215
- # Check for essential model files
216
- has_config = (output_path / "config.json").exists()
217
- has_model_file = any(
218
- [
219
- (output_path / "model.safetensors").exists(),
220
- (output_path / "model.bin").exists(),
221
- (output_path / "pytorch_model.bin").exists(),
222
- ]
223
- )
224
-
225
- if has_config and has_model_file:
226
- logger.info(f"✅ Model {teacher_name} already exists - skipping distillation")
227
-
228
- # Still record it as successful
229
- model_info = {
230
- "teacher_model": teacher_model,
231
- "output_dir": output_dir,
232
- "teacher_name": teacher_name,
233
- "distillation_time": 0.0,
234
- "status": "skipped_existing",
235
- }
236
-
237
- results[teacher_name] = model_info
238
- successful_models.append(teacher_name)
239
- logger.info(f"📁 Using existing model at: {output_dir}")
240
- continue
241
-
242
- # Perform distillation
243
- start_time = time.time()
244
- model = simplified_code_distillation(
245
- teacher_model=teacher_model,
246
- output_dir=output_dir,
247
- )
248
- distill_time = time.time() - start_time
249
-
250
- if model is not None:
251
- logger.info(f"✅ Distillation successful for {teacher_model}")
252
-
253
- # Save to Beam volume for persistence if available
254
- if model_mgr:
255
- try:
256
- # Save model to beam volume with teacher-specific name
257
- beam_model_name = f"{teacher_name}_model"
258
- model_mgr.save_model(beam_model_name, output_dir)
259
- logger.info(f"💾 Saved {teacher_name} to Beam volume as {beam_model_name}")
260
- except Exception as e:
261
- logger.warning(f"Failed to save {teacher_name} to Beam volume: {e}")
262
-
263
- # Store results
264
- model_info = {
265
- "teacher_model": teacher_model,
266
- "output_dir": output_dir,
267
- "teacher_name": teacher_name,
268
- "distillation_time": distill_time,
269
- "status": "success",
270
- }
271
-
272
- results[teacher_name] = model_info
273
- successful_models.append(teacher_name)
274
-
275
- logger.info(f"💾 Model saved to: {output_dir}")
276
-
277
- except Exception as e:
278
- logger.exception(f"❌ Failed with {teacher_model}")
279
- results[teacher_model.split("/")[-1]] = {
280
- "teacher_model": teacher_model,
281
- "status": "failed",
282
- "error": str(e),
283
- }
284
- continue
285
-
286
- # Summary
287
- if successful_models:
288
- logger.info("\n🏆 DISTILLATION COMPLETE!")
289
- logger.info(f"📊 Successful models: {len(successful_models)}")
290
-
291
- for model_name in successful_models:
292
- model_info = results[model_name]
293
- logger.info(f"✅ {model_name}: {model_info['teacher_model']}")
294
-
295
- # Save comprehensive results
296
- results_summary = {
297
- "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
298
- "successful_models": successful_models,
299
- "all_results": results,
300
- "total_successful": len(successful_models),
301
- "total_attempted": len(CODE_TEACHER_MODELS),
302
- }
303
-
304
- # Save results to file
305
- results_file = Path(f"{base_output_path}/distillation_results.json")
306
- results_file.parent.mkdir(parents=True, exist_ok=True)
307
- with results_file.open("w") as f:
308
- json.dump(results_summary, f, indent=2)
309
-
310
- logger.info(f"📊 Results summary saved to: {results_file}")
311
-
312
- return results_summary
313
-
314
- logger.error("❌ No models succeeded")
315
- msg = "All teacher models failed distillation"
316
- raise RuntimeError(msg)
317
-
318
-
319
- def run_local_distillation() -> dict[str, Any]:
320
- """Run distillation locally without Beam."""
321
- logger.info("🖥️ Running simplified distillation locally")
322
- return core_distill_all_teachers(use_beam_utilities=False)
323
-
324
-
325
- @function(
326
- gpu=GPU_NAME,
327
- volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
328
- image=IMAGE,
329
- secrets=["HF_ACCESS_TOKEN"],
330
- env={
331
- "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:512",
332
- "TOKENIZERS_PARALLELISM": "false",
333
- "CUDA_LAUNCH_BLOCKING": "0", # Allow async CUDA operations
334
- "TORCH_CUDNN_V8_API_ENABLED": "1", # Enable optimized cuDNN
335
- },
336
- timeout=3600 * 12, # 12 hours
337
- )
338
- def beam_distill_all_teachers() -> dict[str, Any]:
339
- """
340
- Beam version: Try all teacher models and create distilled models from each.
341
-
342
- Returns information about all models that were successfully created.
343
- """
344
- logger.info("☁️ Running simplified distillation on Beam")
345
- return core_distill_all_teachers(use_beam_utilities=True)
346
-
347
-
348
- def main() -> None:
349
- """Main function with argument parsing."""
350
- global OUTPUT_BASE_DIR # Declare global at the top # noqa: PLW0603
351
-
352
- parser = argparse.ArgumentParser(
353
- description="Simplified Code-Specialized Model2Vec Distillation",
354
- formatter_class=argparse.RawDescriptionHelpFormatter,
355
- epilog="""
356
- Examples:
357
- python -m src.distiller.distill_simplified # Run locally
358
- python -m src.distiller.distill_simplified --use-beam # Run on Beam
359
- distiller distill-simple # CLI shortcut (runs on Beam)
360
- """,
361
- )
362
-
363
- parser.add_argument(
364
- "--use-beam",
365
- action="store_true",
366
- help="Run on Beam instead of locally",
367
- )
368
-
369
- parser.add_argument(
370
- "--output-dir",
371
- type=str,
372
- default=OUTPUT_BASE_DIR,
373
- help=f"Output directory for models (default: {OUTPUT_BASE_DIR})",
374
- )
375
-
376
- args = parser.parse_args()
377
-
378
- # Update output directory if specified
379
- if args.output_dir != OUTPUT_BASE_DIR:
380
- OUTPUT_BASE_DIR = args.output_dir
381
-
382
- try:
383
- if args.use_beam:
384
- logger.info("🚀 Starting Beam execution...")
385
- results = beam_distill_all_teachers()
386
- else:
387
- logger.info("🖥️ Starting local execution...")
388
- results = run_local_distillation()
389
-
390
- # Print final summary
391
- print("\n🎉 Distillation complete!")
392
- print(f"📊 Successfully created {results['total_successful']} models")
393
-
394
- if args.use_beam:
395
- print(f"📁 Models location: {VOLUME_PATH}/final/")
396
- else:
397
- print(f"📁 Models location: {OUTPUT_BASE_DIR}/final/")
398
-
399
- print("\n✅ Created models:")
400
- for model_name in results["successful_models"]:
401
- model_info = results["all_results"][model_name]
402
- print(f" • {model_name} (from {model_info['teacher_model']})")
403
-
404
- except KeyboardInterrupt:
405
- logger.info("🛑 Distillation interrupted by user")
406
- sys.exit(1)
407
- except Exception:
408
- logger.exception("❌ Distillation failed with error")
409
- sys.exit(1)
410
-
411
-
412
- if __name__ == "__main__":
413
- main()
 
 
 
 
 
src/distiller/evaluate.py CHANGED
@@ -1,130 +1,405 @@
1
  """
2
- CodeSearchNet Evaluation Script for Code-Specialized Embedding Models.
3
 
4
- This script evaluates embedding models on code search tasks using the CodeSearchNet
5
- dataset and methodology. It implements the same evaluation approach as the original
6
- CodeSearchNet challenge, including NDCG and other information retrieval metrics.
 
 
 
 
 
 
 
 
7
 
8
  Usage:
9
- distiller evaluate # Run evaluation on all default models with Beam
10
  """
11
 
12
  import json
13
  import logging
14
  import time
 
15
  from pathlib import Path
16
  from typing import Any
17
 
18
  import numpy as np
19
  import pandas as pd
20
- from beam import GpuType, Image, Volume, function
 
 
 
21
  from datasets import Dataset, load_dataset
22
  from sentence_transformers import SentenceTransformer
23
  from sklearn.metrics.pairwise import cosine_similarity
24
  from tqdm import tqdm
25
 
26
- from .beam_utils import (
27
- BeamCheckpointManager,
28
- BeamEvaluationManager,
29
- create_beam_utilities,
 
 
 
 
 
 
 
30
  )
31
 
32
- # Configure logging
33
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
34
  logger = logging.getLogger(__name__)
35
 
36
  # =============================================================================
37
- # BEAM CONFIGURATION
38
  # =============================================================================
39
 
40
- GPU_NAME = GpuType.A100_40
41
- VOLUME_NAME = "code_model2vec" # Same volume as distill_simplified.py
42
- VOLUME_PATH = "./code_model2vec" # Same mount path as distill_simplified.py
43
- EVALUATION_RESULTS_DIR = "evaluation_results" # Subdirectory within volume
44
- EVALUATION_CACHE_DIR = "evaluation_cache" # Cache for datasets and models
45
-
46
- IMAGE = Image(python_version="python3.12").add_python_packages(
47
- [
48
- "torch>=2.7.0",
49
- "transformers>=4.40.0",
50
- "datasets>=3.2.0",
51
- "sentence-transformers>=4.1.0",
52
- "model2vec[train]>=0.5.0",
53
- "numpy>=1.26.4",
54
- "scikit-learn>=1.6.1",
55
- "pandas>=2.0.0",
56
- "tqdm>=4.65.0",
57
- ]
58
- )
59
 
60
  # =============================================================================
61
- # CONFIGURATION
62
  # =============================================================================
63
 
64
- CODESEARCHNET_EVAL_DATASET = "code_search_net"
65
- BATCH_SIZE = 32
66
- DEFAULT_OUTPUT_DIR = "code_evaluation_results" # Local fallback directory
67
- EVALUATION_LANGUAGES = ["python", "javascript", "java", "php", "ruby", "go"]
68
-
69
- # Default models to evaluate (can be overridden via command line)
70
- DEFAULT_EVALUATION_MODELS = [
71
- # Established Code Models
72
- "sentence-transformers/all-MiniLM-L6-v2",
73
- "microsoft/codebert-base",
74
- "microsoft/graphcodebert-base",
75
- "huggingface/CodeBERTa-small-v1",
76
- "sentence-transformers/all-mpnet-base-v2",
77
- "sentence-transformers/all-MiniLM-L12-v2",
78
- # Model2Vec & Efficiency Models (Direct Competitors)
79
- "minishlab/potion-base-8M",
80
- "minishlab/potion-retrieval-32M",
81
- # Small Transformer-Based Code Models
82
- "Salesforce/codet5-base",
83
- ]
 
 
 
 
 
 
84
 
85
- # =============================================================================
86
- # CHECKPOINT CONFIGURATION
87
- # =============================================================================
88
 
89
- # Prevent conflicts with distill.py checkpoints by using different prefixes
90
- EVAL_CHECKPOINT_PREFIX = "evaluation_checkpoints"
91
- DATASET_CHECKPOINT_PREFIX = "dataset_cache"
92
- MODEL_CACHE_PREFIX = "model_cache"
93
 
94
- # =============================================================================
95
- # CORE EVALUATION CLASSES
96
- # =============================================================================
 
 
 
 
 
 
97
 
98
 
99
  class CodeSearchNetEvaluator:
100
  """Evaluator for CodeSearchNet-style code search tasks."""
101
 
102
- def __init__(
103
- self,
104
- model_path: str,
105
- model_name: str | None = None,
106
- checkpoint_manager: BeamCheckpointManager | None = None,
107
- eval_manager: BeamEvaluationManager | None = None,
108
- ) -> None:
109
- """Initialize the evaluator with a model and optional Beam utilities."""
110
  self.model_path = model_path
111
  self.model_name = model_name or Path(model_path).name
112
  self.model: SentenceTransformer | None = None
113
- self.checkpoint_manager = checkpoint_manager
114
- self.eval_manager = eval_manager
115
  self._load_model()
116
 
117
  def _load_model(self) -> None:
118
- """Load the embedding model with caching support."""
119
  logger.info(f"Loading model from {self.model_path}")
120
-
121
- # Check if we have a cached evaluation result for this model
122
- if self.eval_manager:
123
- cached_result = self.eval_manager.load_evaluation_results(self.model_name)
124
- if cached_result:
125
- logger.info(f"✅ Found cached evaluation results for {self.model_name}")
126
- # Note: We still need to load the model for new evaluations
127
-
128
  try:
129
  self.model = SentenceTransformer(self.model_path, trust_remote_code=True)
130
  logger.info(f"Successfully loaded model: {self.model_name}")
@@ -139,7 +414,6 @@ class CodeSearchNetEvaluator:
139
  raise RuntimeError(msg)
140
 
141
  embeddings = []
142
-
143
  for i in tqdm(range(0, len(texts), BATCH_SIZE), desc=desc):
144
  batch = texts[i : i + BATCH_SIZE]
145
  batch_embeddings = self.model.encode(batch, convert_to_tensor=False, normalize_embeddings=True)
@@ -148,33 +422,25 @@ class CodeSearchNetEvaluator:
148
  return np.vstack(embeddings)
149
 
150
  def evaluate_language(self, language: str, max_queries: int = 1000) -> dict[str, Any]:
151
- """Evaluate on a specific programming language with checkpoint support."""
152
  logger.info(f"Evaluating on {language} language (max {max_queries} queries)")
153
 
154
- # Check for existing evaluation checkpoint
155
- if self.checkpoint_manager:
156
- cached_result = self.checkpoint_manager.load_checkpoint(f"{EVAL_CHECKPOINT_PREFIX}_{language}", 0)
157
- if cached_result and cached_result.get("data", {}).get("model_name") == self.model_name:
158
- logger.info(f"✅ Resuming from cached {language} evaluation")
159
- return cached_result.get("data", {})
160
-
161
  try:
162
  # Load test split for the language
163
  dataset = load_dataset(
164
- CODESEARCHNET_EVAL_DATASET,
165
  language,
166
  split="test",
167
  trust_remote_code=True,
168
  )
169
 
170
- # Ensure we have a Dataset object
171
  if not isinstance(dataset, Dataset):
172
  logger.error(f"Unexpected dataset type for {language}: {type(dataset)}")
173
  return {}
174
 
175
- # Sample queries for evaluation (to make it manageable)
176
  if len(dataset) > max_queries:
177
- rng = np.random.default_rng(42) # Use seeded generator for reproducibility
178
  indices = rng.choice(len(dataset), max_queries, replace=False)
179
  dataset = dataset.select(indices)
180
 
@@ -198,642 +464,680 @@ class CodeSearchNetEvaluator:
198
  logger.info(f"Found {len(queries)} valid query-code pairs for {language}")
199
 
200
  # Encode queries and codes
 
201
  query_embeddings = self.encode_texts(queries, f"Encoding {language} queries")
202
- code_embeddings = self.encode_texts(codes, f"Encoding {language} codes")
 
203
 
204
- # Compute similarities
205
  similarities = cosine_similarity(query_embeddings, code_embeddings)
206
-
207
- # Evaluate retrieval metrics
208
  metrics = self._compute_retrieval_metrics(similarities)
209
 
210
- result = {
 
211
  "language": language,
 
212
  "num_queries": len(queries),
 
213
  "metrics": metrics,
214
- "model_name": self.model_name,
215
  }
216
 
217
- # Save checkpoint
218
- if self.checkpoint_manager:
219
- checkpoint_data = {
220
- "data": result,
221
- "timestamp": time.time(),
222
- "config": {
223
- "language": language,
224
- "max_queries": max_queries,
225
- "model_name": self.model_name,
226
- },
227
- }
228
- self.checkpoint_manager.save_checkpoint(f"{EVAL_CHECKPOINT_PREFIX}_{language}", checkpoint_data, 0)
229
- logger.info(f"💾 Saved {language} evaluation checkpoint")
230
-
231
- return result
232
 
233
  except Exception:
234
- logger.exception(f"Error evaluating {language}")
235
  return {}
236
 
237
  def _compute_retrieval_metrics(self, similarities: np.ndarray) -> dict[str, float]:
238
- """Compute retrieval metrics like NDCG, MRR, etc."""
239
- num_queries = similarities.shape[0]
240
 
241
- # For each query, the correct code is at the same index (diagonal)
242
- ranks = []
 
 
 
243
  reciprocal_ranks = []
244
- ndcg_scores = []
245
-
246
- for i in range(num_queries):
247
- # Get similarity scores for query i
248
- scores = similarities[i]
249
-
250
- # Rank all codes by similarity to query i
251
- ranked_indices = np.argsort(scores)[::-1] # Descending order
252
-
253
- # Find rank of the correct code (index i)
254
- correct_rank = np.where(ranked_indices == i)[0][0] + 1 # 1-indexed
255
- ranks.append(correct_rank)
256
- reciprocal_ranks.append(1.0 / correct_rank)
257
-
258
- # Compute NDCG@10
259
- ndcg_scores.append(self._compute_ndcg(ranked_indices, i, k=10))
260
-
261
- return {
262
- "mrr": float(np.mean(reciprocal_ranks)),
263
- "ndcg@1": float(
264
- np.mean([self._compute_ndcg(np.argsort(similarities[i])[::-1], i, k=1) for i in range(num_queries)])
265
- ),
266
- "ndcg@5": float(
267
- np.mean([self._compute_ndcg(np.argsort(similarities[i])[::-1], i, k=5) for i in range(num_queries)])
268
- ),
269
- "ndcg@10": float(np.mean(ndcg_scores)),
270
- "recall@1": float(np.mean([1.0 if rank == 1 else 0.0 for rank in ranks])),
271
- "recall@5": float(np.mean([1.0 if rank <= 5 else 0.0 for rank in ranks])),
272
- "recall@10": float(np.mean([1.0 if rank <= 10 else 0.0 for rank in ranks])),
273
- "mean_rank": float(np.mean(ranks)),
274
- "median_rank": float(np.median(ranks)),
275
- }
276
 
277
  def _compute_ndcg(self, ranked_indices: np.ndarray, correct_idx: int, k: int) -> float:
278
  """Compute NDCG@k for a single query."""
279
- if k == 0:
280
- return 0.0
281
-
282
- # Find position of correct item in top-k
283
- top_k = ranked_indices[:k]
284
- if correct_idx in top_k:
285
- position = np.where(top_k == correct_idx)[0][0]
286
- return 1.0 / np.log2(position + 2) # +2 because log2(1) is 0
287
  return 0.0
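For intuition: each query has exactly one relevant item (its paired code snippet on the diagonal of the similarity matrix), so the retrieval metrics computed by this evaluator reduce to simple functions of that item's 1-indexed rank. A standalone sketch, not part of the diff:

import numpy as np

def single_query_metrics(rank: int, k: int = 10) -> dict[str, float]:
    """Contribution of one query whose only relevant item lands at `rank` (1-indexed)."""
    return {
        "reciprocal_rank": 1.0 / rank,
        f"ndcg@{k}": 1.0 / np.log2(rank + 1) if rank <= k else 0.0,
        f"recall@{k}": 1.0 if rank <= k else 0.0,
    }

print(single_query_metrics(3))  # reciprocal_rank ~= 0.333, ndcg@10 = 1/log2(4) = 0.5, recall@10 = 1.0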
288
 
289
  def evaluate_all_languages(
290
  self, max_queries_per_lang: int = 1000, languages: list[str] | None = None
291
  ) -> dict[str, Any]:
292
- """Evaluate on all supported programming languages with comprehensive result saving."""
293
- if languages is None:
294
- languages = EVALUATION_LANGUAGES
295
-
296
- logger.info(f"Starting evaluation on all languages for model: {self.model_name}")
297
 
298
- # Check for existing comprehensive evaluation results
299
- if self.eval_manager:
300
- cached_comprehensive = self.eval_manager.load_evaluation_results(self.model_name)
301
- if cached_comprehensive:
302
- logger.info(f"✅ Found comprehensive cached evaluation for {self.model_name}")
303
- return cached_comprehensive
304
 
305
  start_time = time.time()
306
-
307
- results: dict[str, Any] = {
308
  "model_name": self.model_name,
309
  "model_path": self.model_path,
310
- "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
311
  "languages": {},
312
  "overall": {},
 
313
  }
 
314
 
315
- all_metrics = []
 
 
 
 
316
 
317
- for language in languages:
318
- logger.info(f"Evaluating {language}...")
319
  lang_results = self.evaluate_language(language, max_queries_per_lang)
320
-
321
  if lang_results:
322
- results["languages"][language] = lang_results
323
- all_metrics.append(lang_results["metrics"])
324
- else:
325
- logger.warning(f"Skipping {language} due to evaluation error")
326
 
327
- # Compute overall metrics (average across languages)
328
- if all_metrics:
 
 
329
  overall_metrics = {}
330
- for metric_name in all_metrics[0]:
331
- values = [m[metric_name] for m in all_metrics if metric_name in m]
332
- overall_metrics[metric_name] = np.mean(values)
 
 
333
 
334
  results["overall"] = overall_metrics
335
 
336
  total_time = time.time() - start_time
337
  results["evaluation_time_seconds"] = total_time
338
 
339
- # Save comprehensive results to Beam volume
340
- if self.eval_manager:
341
- self.eval_manager.save_evaluation_results(self.model_name, results)
342
- logger.info("💾 Saved comprehensive evaluation results to Beam volume")
343
-
344
  logger.info(f"Evaluation completed in {total_time:.2f} seconds")
345
  return results
346
 
347
 
348
- def load_peer_models(peers_file: str) -> list[tuple[str, str]]:
349
- """Load peer models from CSV file."""
350
- try:
351
- df = pd.read_csv(peers_file)
352
- models = []
353
- for _, row in df.iterrows():
354
- model_name = row.get("model_name", row.get("Model", ""))
355
- model_path = row.get("model_path", row.get("Path", model_name))
356
- if model_name:
357
- models.append((model_name, model_path))
358
- logger.info(f"Loaded {len(models)} peer models from {peers_file}")
359
- return models
360
- except Exception:
361
- logger.exception("Error loading peer models from {peers_file}")
362
- return []
363
 
 
 
 
364
 
365
- def save_results(
366
- results: dict[str, Any],
367
- output_dir: str,
368
- model_name: str,
369
- eval_manager: BeamEvaluationManager | None = None,
370
- volume_results_dir: Path | None = None,
371
- ) -> None:
372
- """Save evaluation results to JSON file with Beam volume support."""
373
- # Save to Beam volume if available
374
- if volume_results_dir:
375
- volume_output_path = volume_results_dir / f"codesearchnet_eval_{model_name}.json"
376
  try:
377
- with volume_output_path.open("w") as f:
378
- json.dump(results, f, indent=2, default=str)
379
- logger.info(f"💾 Results saved to Beam volume: {volume_output_path}")
 
 
 
 
380
  except Exception as e:
381
- logger.warning(f"⚠️ Failed to save to Beam volume: {e}")
 
382
 
383
- # Also try eval_manager if available (for compatibility)
384
- if eval_manager:
385
- success = eval_manager.save_evaluation_results(model_name, results)
386
- if success:
387
- logger.info(f"💾 Results also saved via eval_manager for {model_name}")
388
- else:
389
- logger.warning(f"⚠️ Failed to save via eval_manager for {model_name}")
390
 
391
- # Always save local backup
392
- output_path = Path(output_dir)
393
- output_path.mkdir(parents=True, exist_ok=True)
394
 
395
- # Clean model name for filename
396
- safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
397
- filename = f"codesearchnet_eval_{safe_name}.json"
398
- filepath = output_path / filename
 
399
 
400
- with Path(filepath).open("w") as f:
401
- json.dump(results, f, indent=2, default=str)
 
 
 
 
 
402
 
403
- logger.info(f"📄 Local backup saved to {filepath}")
404
 
 
 
 
405
 
406
- def print_results_summary(results: dict[str, Any]) -> None:
407
- """Print a summary of evaluation results."""
408
- model_name = results["model_name"]
409
- overall = results.get("overall", {})
410
 
411
- print(f"\n{'=' * 60}")
412
- print(f"CodeSearchNet Evaluation Results: {model_name}")
413
- print(f"{'=' * 60}")
 
 
 
 
414
 
 
 
 
415
  if overall:
416
- print("\nOverall Metrics (averaged across languages):")
417
- print(f" MRR: {overall.get('mrr', 0):.4f}")
418
- print(f" NDCG@1: {overall.get('ndcg@1', 0):.4f}")
419
- print(f" NDCG@5: {overall.get('ndcg@5', 0):.4f}")
420
- print(f" NDCG@10: {overall.get('ndcg@10', 0):.4f}")
421
- print(f" Recall@1: {overall.get('recall@1', 0):.4f}")
422
- print(f" Recall@5: {overall.get('recall@5', 0):.4f}")
423
- print(f" Recall@10: {overall.get('recall@10', 0):.4f}")
424
-
425
- print("\nPer-Language Results:")
426
- for lang, lang_results in results.get("languages", {}).items():
427
- metrics = lang_results.get("metrics", {})
428
- print(
429
- f" {lang:12s}: MRR={metrics.get('mrr', 0):.3f}, "
430
- f"NDCG@10={metrics.get('ndcg@10', 0):.3f}, "
431
- f"Recall@5={metrics.get('recall@5', 0):.3f}"
432
- )
433
-
434
-
435
- def create_comparison_report(all_results: list[dict[str, Any]], output_dir: str) -> None:
436
- """Create a comparison report across all evaluated models."""
 
 
 
 
 
 
 
 
 
437
  if not all_results:
438
  return
439
 
 
 
 
440
  output_path = Path(output_dir)
 
441
 
442
- # Create comparison DataFrame
443
- comparison_data = []
444
- for results in all_results:
445
- overall = results.get("overall", {})
446
- row = {
447
- "Model": results["model_name"],
448
- "MRR": overall.get("mrr", 0),
449
- "NDCG@1": overall.get("ndcg@1", 0),
450
- "NDCG@5": overall.get("ndcg@5", 0),
451
- "NDCG@10": overall.get("ndcg@10", 0),
452
- "Recall@1": overall.get("recall@1", 0),
453
- "Recall@5": overall.get("recall@5", 0),
454
- "Recall@10": overall.get("recall@10", 0),
455
- "Mean Rank": overall.get("mean_rank", 0),
456
- }
457
- comparison_data.append(row)
 
 
 
 
458
 
459
- df = pd.DataFrame(comparison_data)
460
- df = df.sort_values("NDCG@10", ascending=False) # Sort by NDCG@10
461
 
462
- # Save to CSV
463
- csv_path = output_path / "codesearchnet_comparison.csv"
464
- df.to_csv(csv_path, index=False, float_format="%.4f")
465
- logger.info(f"Comparison report saved to {csv_path}")
466
 
467
- # Print comparison table
468
- print(f"\n{'=' * 80}")
469
- print("CodeSearchNet Model Comparison")
470
- print(f"{'=' * 80}")
471
- print(df.to_string(index=False, float_format="%.4f"))
472
 
473
 
474
- def beam_evaluate_models(
 
 
 
 
 
475
  models: list[str],
476
  max_queries: int = 1000,
477
  languages: list[str] | None = None,
478
- output_dir: str = DEFAULT_OUTPUT_DIR,
479
- volume_name: str = VOLUME_NAME,
480
- mount_path: str = VOLUME_PATH,
481
  ) -> list[dict[str, Any]]:
482
- """Main evaluation function for Beam execution with checkpoint support."""
483
- logger.info("🚀 Starting Beam-powered CodeSearchNet evaluation")
484
- logger.info(f"📊 Evaluating {len(models)} models on {len(languages or EVALUATION_LANGUAGES)} languages")
 
 
 
 
 
 
485
 
486
- # Initialize Beam utilities
487
- volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(volume_name, mount_path)
 
488
 
489
- # Create evaluation results directory in volume
490
- results_dir = Path(mount_path) / EVALUATION_RESULTS_DIR
491
- results_dir.mkdir(parents=True, exist_ok=True)
 
 
 
492
 
493
- logger.info(f"📁 Using Beam volume: {volume_name} at {mount_path}")
494
- logger.info(f"💾 Evaluation results directory: {results_dir}")
 
495
 
496
- all_results = []
497
- skipped_models = []
 
 
 
 
 
 
 
 
498
 
 
 
 
499
  for model_path in models:
500
  model_name = Path(model_path).name
501
 
502
- # Check for existing evaluation results
503
- existing_result_file = results_dir / f"codesearchnet_eval_{model_name}.json"
504
- if existing_result_file.exists():
505
- logger.info(f"✅ Model {model_name} already evaluated - loading existing results")
506
- try:
507
- with existing_result_file.open("r") as f:
508
- existing_results = json.load(f)
509
- all_results.append(existing_results)
510
- skipped_models.append(model_name)
511
- continue
512
- except Exception as e:
513
- logger.warning(f"⚠️ Failed to load existing results for {model_name}: {e}")
514
- # Continue with evaluation if loading fails
515
-
516
  logger.info(f"\n{'=' * 60}")
517
  logger.info(f"🔍 Evaluating model: {model_name}")
518
- logger.info(f"📂 Path: {model_path}")
519
  logger.info(f"{'=' * 60}")
520
 
521
  try:
522
- # Distinguish between local paths and HuggingFace model names
523
- is_huggingface_model = (
524
- "/" in model_path and not model_path.startswith("/") and not Path(model_path).exists()
525
- )
526
 
527
- if is_huggingface_model:
528
- # This is a HuggingFace model name - pass directly to evaluator
529
- logger.info(f"📥 Loading HuggingFace model: {model_path}")
530
- evaluator = CodeSearchNetEvaluator(
531
- model_path,
532
- model_name,
533
- checkpoint_manager=checkpoint_mgr,
534
- eval_manager=eval_mgr,
535
- )
536
- else:
537
- # This is a local path - check if it exists in Beam volume
538
- actual_model_path = model_path # Default to original path
539
- if not Path(model_path).exists() and not model_path.startswith("/"):
540
- # Try to load from Beam volume
541
- local_model_path = Path(mount_path) / MODEL_CACHE_PREFIX / model_name
542
- logger.info(f"🔍 Trying to load {model_name} from Beam volume: {local_model_path}")
543
- if model_mgr.load_model(model_name, local_model_path.parent):
544
- actual_model_path = str(local_model_path)
545
- logger.info(f"✅ Loaded model from Beam volume: {actual_model_path}")
546
- else:
547
- logger.warning(f"⚠️ Model not found locally or in Beam volume: {model_name}")
548
- continue
549
-
550
- evaluator = CodeSearchNetEvaluator(
551
- actual_model_path,
552
- model_name,
553
- checkpoint_manager=checkpoint_mgr,
554
- eval_manager=eval_mgr,
555
- )
556
-
557
- results = evaluator.evaluate_all_languages(max_queries, languages)
558
-
559
- # Save results with Beam support
560
- save_results(results, output_dir, model_name, eval_mgr, results_dir)
561
-
562
- # Print summary
563
- print_results_summary(results)
564
-
565
- all_results.append(results)
566
 
567
  except Exception:
568
  logger.exception(f"❌ Failed to evaluate {model_name}")
569
  continue
570
 
571
- # Create comparison report in Beam volume
572
- if len(all_results) > 1:
573
- comparison_dir = Path(mount_path) / EVALUATION_RESULTS_DIR / "comparisons"
574
- comparison_dir.mkdir(parents=True, exist_ok=True)
575
- create_comparison_report(all_results, str(comparison_dir))
576
- logger.info(f"📊 Comparison report saved to Beam volume: {comparison_dir}")
577
-
578
- # Log summary of what was done
579
- newly_evaluated = len(all_results) - len(skipped_models)
580
- logger.info("\n✅ Beam evaluation complete!")
581
- logger.info(f"📊 Newly evaluated: {newly_evaluated} models")
582
- logger.info(f"⏭️ Skipped (already done): {len(skipped_models)} models")
583
- logger.info(f"📁 Total results: {len(all_results)} models")
584
- logger.info(f"💾 Results available in Beam volume: {volume_name}")
585
-
586
- if skipped_models:
587
- logger.info(f"⏭️ Skipped models: {', '.join(skipped_models)}")
588
-
589
- return all_results
590
 
591
 
592
  @function(
593
  gpu=GPU_NAME,
594
- volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
595
  image=IMAGE,
596
  secrets=["HF_ACCESS_TOKEN"],
597
- env={
598
- "TOKENIZERS_PARALLELISM": "false",
599
- "CUDA_LAUNCH_BLOCKING": "0",
600
- },
601
- timeout=3600 * 6, # 6 hours for evaluation
602
  )
603
- def main(skip_third_party: bool = False) -> None:
604
- """Main evaluation function - runs all default models on Beam."""
605
- logger.info("🚀 Starting comprehensive CodeSearchNet evaluation on Beam")
606
-
607
- # Use default models or skip them based on flag
608
- if skip_third_party:
609
- logger.info("⏭️ Skipping 3rd party models - evaluating only simplified distillation models")
610
- models = []
611
- else:
612
- logger.info("📊 Including 3rd party peer models for comparison")
613
- models = DEFAULT_EVALUATION_MODELS.copy()
614
-
615
- # Discover simplified distillation models in the current directory
616
- logger.info("🔍 Discovering simplified distillation models...")
617
- discovered_models = discover_simplified_models(".")
618
-
619
- # Add discovered models (they're already sorted alphabetically)
620
- if discovered_models:
621
- logger.info(f"✅ Found {len(discovered_models)} simplified models:")
622
- for model_path in discovered_models:
623
- models.append(model_path)
624
- logger.info(f" 📁 {model_path}")
625
- else:
626
- logger.warning("⚠️ No simplified distillation models found")
627
- if skip_third_party:
628
- logger.error("❌ No models to evaluate! Either create simplified models or include 3rd party models.")
629
- return
630
 
631
- logger.info(f"📊 Evaluating {len(models)} models:")
632
- for i, model in enumerate(models, 1):
633
- logger.info(f" {i}. {model}")
634
 
635
- logger.info("\n💡 Checkpoint Info:")
636
- logger.info(" - Already evaluated models will be skipped")
637
- logger.info(" - Results are saved persistently to Beam volume")
638
 
639
- # Run comprehensive evaluation using Beam utilities
640
- results = beam_evaluate_models(
641
- models=models,
642
- max_queries=1000,
643
- languages=EVALUATION_LANGUAGES,
644
- output_dir=str(Path(VOLUME_PATH) / EVALUATION_RESULTS_DIR),
645
- volume_name=VOLUME_NAME,
646
- mount_path=VOLUME_PATH,
647
- )
648
 
649
- # Print final summary
650
- print("\n🎯 Evaluation Summary:")
651
- print(f"📊 Total models processed: {len(results)}")
652
- print(f"💾 Results saved to Beam volume: {VOLUME_NAME}")
653
- print(f"📁 Directory: {EVALUATION_RESULTS_DIR}")
654
- if skip_third_party:
655
- print("⏭️ 3rd party models were skipped")
656
- print("\n🔍 To view analysis:")
657
- print(" beam run src.distiller.analyze:beam_analysis")
658
- print("\n📈 To run evaluations again:")
659
- print(" distiller evaluate (will skip already completed models)")
660
- print(" distiller evaluate --skip-third-party (evaluate only simplified models)")
661
-
662
-
663
- def discover_simplified_models(base_path: str = ".") -> list[str]:
664
- """
665
- Discover all simplified distillation models in the correct directory.
666
-
667
- Looks for directories matching the pattern: ./code_model2vec/final/code_model2vec_*
668
- """
669
- discovered_models: list[str] = []
670
-
671
- # Look in the correct location where distill_simplified.py saves models
672
- models_dir = Path(base_path) / "code_model2vec" / "final"
673
-
674
- if not models_dir.exists():
675
- logger.warning(f"Models directory not found: {models_dir}")
676
- return discovered_models
677
-
678
- # Look for simplified model directories with the updated pattern
679
- pattern = "code_model2vec_*"
680
- for model_dir in models_dir.glob(pattern):
681
- if model_dir.is_dir() and (model_dir / "config.json").exists():
682
- discovered_models.append(str(model_dir))
683
- logger.info(f"🔍 Discovered simplified model: {model_dir}")
684
-
685
- # Sort alphabetically for consistent ordering
686
- discovered_models.sort()
687
 
688
- return discovered_models
 
689
 
 
 
690
 
691
- @function(
692
- gpu=GPU_NAME,
693
- volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
694
- image=IMAGE,
695
- secrets=["HF_ACCESS_TOKEN"],
696
- env={
697
- "TOKENIZERS_PARALLELISM": "false",
698
- "CUDA_LAUNCH_BLOCKING": "0",
699
- },
700
- timeout=3600 * 6, # 6 hours for evaluation
701
- )
702
- def evaluate_simplified_only() -> None:
703
- """Evaluate only simplified distillation models, skipping 3rd party models."""
704
- main(skip_third_party=True)
705
 
706
 
707
- def run_local_evaluation(
708
- models: list[str] | None = None,
709
  max_queries: int = 1000,
710
  languages: list[str] | None = None,
711
- output_dir: str = DEFAULT_OUTPUT_DIR,
712
  ) -> list[dict[str, Any]]:
713
- """Main evaluation function for local execution without Beam utilities."""
714
- logger.info("🖥️ Running CodeSearchNet evaluation locally")
715
-
716
- if models is None:
717
- models = DEFAULT_EVALUATION_MODELS.copy()
718
-
719
- # Discover simplified distillation models in the current directory
720
- logger.info("🔍 Discovering simplified distillation models...")
721
- discovered_models = discover_simplified_models(".")
722
-
723
- # Add discovered models
724
- if discovered_models:
725
- logger.info(f"✅ Found {len(discovered_models)} simplified models:")
726
- for model_path in discovered_models:
727
- models.append(model_path)
728
- logger.info(f" 📁 {model_path}")
729
- else:
730
- logger.warning("⚠️ No simplified distillation models found")
731
-
732
- if languages is None:
733
- languages = EVALUATION_LANGUAGES
734
-
735
- logger.info(f"📊 Evaluating {len(models)} models on {len(languages)} languages")
736
- logger.info(f"📁 Using local output directory: {output_dir}")
737
-
738
- # Create local output directory
739
- output_path = Path(output_dir)
740
- output_path.mkdir(parents=True, exist_ok=True)
741
-
742
- all_results = []
743
- skipped_models = []
744
 
 
745
  for model_path in models:
746
  model_name = Path(model_path).name
747
 
748
- # Check for existing evaluation results locally
749
- safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
750
- result_file = output_path / f"codesearchnet_eval_{safe_name}.json"
751
-
752
- if result_file.exists():
753
- logger.info(f"✅ Model {model_name} already evaluated - loading existing results")
754
- try:
755
- with result_file.open("r") as f:
756
- existing_results = json.load(f)
757
- all_results.append(existing_results)
758
- skipped_models.append(model_name)
759
- continue
760
- except Exception as e:
761
- logger.warning(f"⚠️ Failed to load existing results for {model_name}: {e}")
762
-
763
- logger.info(f"\n{'=' * 60}")
764
- logger.info(f"🔍 Evaluating model: {model_name}")
765
- logger.info(f"📂 Path: {model_path}")
766
- logger.info(f"{'=' * 60}")
767
 
768
  try:
769
- # Create evaluator without Beam utilities (no checkpointing)
770
- evaluator = CodeSearchNetEvaluator(
771
- model_path,
772
- model_name,
773
- checkpoint_manager=None, # No checkpointing for local evaluation
774
- eval_manager=None,
775
- )
776
-
777
- results = evaluator.evaluate_all_languages(max_queries, languages)
778
 
779
- # Save results locally only
780
- save_results(results, output_dir, model_name, eval_manager=None, volume_results_dir=None)
781
-
782
- # Print summary
783
- print_results_summary(results)
 
 
 
 
784
 
785
- all_results.append(results)
 
 
 
 
 
786
 
787
  except Exception:
788
- logger.exception(f"❌ Failed to evaluate {model_name}")
789
  continue
790
 
791
- # Create comparison report locally
792
- if len(all_results) > 1:
793
- create_comparison_report(all_results, output_dir)
794
- logger.info(f"📊 Comparison report saved locally: {output_dir}")
795
 
796
- # Log summary
797
- newly_evaluated = len(all_results) - len(skipped_models)
798
- logger.info("\n✅ Local evaluation complete!")
799
- logger.info(f"📊 Newly evaluated: {newly_evaluated} models")
800
- logger.info(f"⏭️ Skipped (already done): {len(skipped_models)} models")
801
- logger.info(f"📁 Total results: {len(all_results)} models")
802
- logger.info(f"💾 Results available locally: {output_dir}")
803
 
804
- if skipped_models:
805
- logger.info(f"⏭️ Skipped models: {', '.join(skipped_models)}")
 
806
 
807
- return all_results
808
 
 
 
 
 
 
 
 
 
809
 
810
- def run_local_evaluation_simplified(
811
- max_queries: int = 1000,
812
- languages: list[str] | None = None,
813
- output_dir: str = DEFAULT_OUTPUT_DIR,
814
- ) -> list[dict[str, Any]]:
815
- """Local evaluation function for simplified models only."""
816
- logger.info("🖥️ Running simplified model evaluation locally")
817
 
818
- # Discover simplified distillation models only
819
- logger.info("🔍 Discovering simplified distillation models...")
820
- discovered_models = discover_simplified_models(".")
 
 
 
821
 
822
- if not discovered_models:
823
- logger.error("❌ No simplified distillation models found! Run 'distiller distill-simple' first.")
824
- return []
825
 
826
- logger.info(f"✅ Found {len(discovered_models)} simplified models:")
827
- for model_path in discovered_models:
828
- logger.info(f" 📁 {model_path}")
 
 
 
 
 
 
 
 
 
 
 
829
 
830
- return run_local_evaluation(
831
- models=discovered_models,
 
 
 
 
 
832
  max_queries=max_queries,
833
- languages=languages,
834
- output_dir=output_dir,
 
835
  )
836
 
 
 
 
 
 
837
 
838
  if __name__ == "__main__":
839
- main()
 
1
  """
2
+ Comprehensive Model Evaluation Script for Code-Specialized Embedding Models.
3
 
4
+ This script evaluates embedding models on both task performance and operational metrics:
5
+
6
+ Task Performance:
7
+ - CodeSearchNet evaluation (NDCG, MRR, Recall metrics)
8
+ - Code search accuracy across programming languages
9
+
10
+ Operational Performance:
11
+ - Inference speed (latency and throughput)
12
+ - Memory efficiency (RAM and GPU usage)
13
+ - Model size and storage requirements
14
+ - CPU vs GPU performance scaling
15
 
16
  Usage:
17
+ distiller evaluate [--use-beam] [--skip-benchmark] # Run evaluation locally or on Beam
18
  """
19
 
20
  import json
21
  import logging
22
  import time
23
+ import traceback
24
  from pathlib import Path
25
  from typing import Any
26
 
27
  import numpy as np
28
  import pandas as pd
29
+ import psutil
30
+ import torch
31
+ import typer
32
+ from beam import Volume, function
33
  from datasets import Dataset, load_dataset
34
  from sentence_transformers import SentenceTransformer
35
  from sklearn.metrics.pairwise import cosine_similarity
36
  from tqdm import tqdm
37
 
38
+ from .beam_utils import download_specific_evaluation_file
39
+ from .config import (
40
+ BEAM_ENV_SETTINGS,
41
+ DEFAULT_EVALUATION_MODELS,
42
+ GPU_NAME,
43
+ IMAGE,
44
+ codesearchnet_config,
45
+ directories,
46
+ get_safe_model_name,
47
+ get_volume_config,
48
+ languages_config,
49
  )
50
 
 
 
51
  logger = logging.getLogger(__name__)
52
 
53
  # =============================================================================
54
+ # EVALUATION CONFIGURATION
55
  # =============================================================================
56
 
57
+ BATCH_SIZE = 32
58
+ LOCAL_EVALUATION_DIR = directories.evaluation_results
59
+ LOCAL_BENCHMARK_DIR = directories.benchmark_results
60
+ LOCAL_MODELS_DIR = directories.final
61
+ VOLUME_CONFIG = get_volume_config()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  # =============================================================================
64
+ # CORE EVALUATION CLASSES
65
  # =============================================================================
66
 
67
+ # Sample texts for benchmarking (various lengths)
68
+ BENCHMARK_TEXTS = {
69
+ "short": [
70
+ "def add(a, b): return a + b",
71
+ "function multiply(x, y) { return x * y; }",
72
+ "class Calculator { public int subtract(int a, int b) { return a - b; } }",
73
+ ]
74
+ * 100, # 300 short texts
75
+ "medium": [
76
+ "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
77
+ "function quickSort(arr) {\n if (arr.length <= 1) return arr;\n const pivot = arr[arr.length - 1];\n const left = [], right = [];\n for (let i = 0; i < arr.length - 1; i++) {\n if (arr[i] < pivot) left.push(arr[i]);\n else right.push(arr[i]);\n }\n return [...quickSort(left), pivot, ...quickSort(right)];\n}",
78
+ ]
79
+ * 50, # 100 medium texts
80
+ "long": [
81
+ """
82
+ def complex_algorithm(data, config):
83
+ '''
84
+ Complex data processing algorithm with multiple steps.
85
+ '''
86
+ results = []
87
+ # Data validation and processing steps...
88
+ return results
89
+ """.strip(),
90
+ ]
91
+ * 20, # 20 long texts
92
+ }
93
 
 
 
 
94
 
95
+ class PerformanceBenchmark:
96
+ """Comprehensive performance benchmarking for embedding models."""
 
 
97
 
98
+ def __init__(self, model_path: str, model_name: str | None = None) -> None:
99
+ """Initialize benchmarker with model."""
100
+ self.model_path = model_path
101
+ self.model_name = model_name or Path(model_path).name
102
+ self.model: SentenceTransformer | None = None
103
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
104
+ self.results: dict[str, Any] = {}
105
+
106
+ def load_model(self) -> None:
107
+ """Load the embedding model."""
108
+ logger.info(f"Loading model from {self.model_path}")
109
+ start_time = time.time()
110
+
111
+ try:
112
+ self.model = SentenceTransformer(self.model_path, device=self.device, trust_remote_code=True)
113
+ load_time = time.time() - start_time
114
+
115
+ logger.info(f"✅ Model loaded in {load_time:.2f}s on {self.device}")
116
+ self.results["model_load_time"] = load_time
117
+
118
+ except Exception:
119
+ logger.exception("❌ Failed to load model")
120
+ self.results["error"] = traceback.format_exc()
121
+ raise
122
+
123
+ def measure_model_size(self) -> dict[str, float]:
124
+ """Measure model size metrics."""
125
+ logger.info("📏 Measuring model size...")
126
+
127
+ size_metrics: dict[str, Any] = {}
128
+
129
+ # Disk size
130
+ try:
131
+ if Path(self.model_path).is_dir():
132
+ # Local directory - calculate size of model files only
133
+ model_extensions = {".safetensors", ".bin", ".json", ".txt", ".tokenizer"}
134
+ total_size = 0
135
+ model_dir = Path(self.model_path)
136
+
137
+ for file_path in model_dir.rglob("*"):
138
+ if file_path.is_file() and (
139
+ file_path.suffix.lower() in model_extensions or "tokenizer" in file_path.name.lower()
140
+ ):
141
+ total_size += file_path.stat().st_size
142
+
143
+ size_metrics["disk_size_mb"] = total_size / (1024 * 1024)
144
+ # HuggingFace model - estimate based on model parameters
145
+ elif self.model is not None:
146
+ param_count = sum(p.numel() for p in self.model.parameters())
147
+ # Rough estimate: 4 bytes per parameter (float32)
148
+ estimated_size = param_count * 4
149
+ size_metrics["disk_size_mb"] = estimated_size / (1024 * 1024)
150
+ else:
151
+ size_metrics["disk_size_mb"] = 0.0
152
+
153
+ except Exception as e:
154
+ logger.warning(f"⚠️ Could not calculate disk size: {e}")
155
+ size_metrics["disk_size_mb"] = 0.0
156
+
157
+ # Memory size (if model is loaded)
158
+ if self.model is not None:
159
+ try:
160
+ # Parameter count
161
+ param_count = sum(p.numel() for p in self.model.parameters())
162
+ size_metrics["parameter_count"] = param_count
163
+ size_metrics["parameters_millions"] = param_count / 1e6
164
+
165
+ # Memory usage estimate
166
+ param_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
167
+ buffer_size = sum(b.numel() * b.element_size() for b in self.model.buffers())
168
+ size_metrics["memory_size_mb"] = (param_size + buffer_size) / (1024 * 1024)
169
+ size_metrics["ram_usage_mb"] = size_metrics["memory_size_mb"]
170
+
171
+ # GPU memory if using CUDA
172
+ if self.device == "cuda" and torch.cuda.is_available():
173
+ size_metrics["gpu_memory_mb"] = torch.cuda.memory_allocated() / (1024 * 1024)
174
+ size_metrics["gpu_name"] = torch.cuda.get_device_name(0)
175
+
176
+ # Embedding dimension if available
177
+ if hasattr(self.model, "get_sentence_embedding_dimension"):
178
+ size_metrics["embedding_dim"] = self.model.get_sentence_embedding_dimension()
179
+
180
+ except Exception as e:
181
+ logger.warning(f"⚠️ Could not calculate memory size: {e}")
182
+
183
+ # Update results
184
+ self.results["size_metrics"] = size_metrics
185
+ return size_metrics
186
+
187
+ def benchmark_inference_speed(self, batch_sizes: list[int] | None = None) -> dict[str, Any]:
188
+ """Benchmark inference speed with different batch sizes."""
189
+ if batch_sizes is None:
190
+ batch_sizes = [1, 8, 16, 32]
191
+
192
+ logger.info(f"⚡ Benchmarking inference speed with batch sizes: {batch_sizes}")
193
+
194
+ if self.model is None:
195
+ self.load_model()
196
+
197
+ speed_results: dict[str, Any] = {"medium": {}}
198
+
199
+ # Use medium-length texts for speed testing
200
+ test_texts = BENCHMARK_TEXTS["medium"]
201
+
202
+ for batch_size in batch_sizes:
203
+ logger.info(f" 📊 Testing batch size: {batch_size}")
204
+
205
+ # Prepare batch
206
+ batch = (
207
+ test_texts[:batch_size]
208
+ if batch_size <= len(test_texts)
209
+ else test_texts * ((batch_size // len(test_texts)) + 1)
210
+ )
211
+ batch = batch[:batch_size]
212
+
213
+ # Warmup
214
+ if self.model is not None:
215
+ _ = self.model.encode(batch[: min(4, len(batch))], convert_to_tensor=False)
216
+
217
+ # Benchmark multiple runs
218
+ latencies = []
219
+ num_runs = max(3, 20 // batch_size) # More runs for smaller batches
220
+
221
+ for _ in range(num_runs):
222
+ start_time = time.time()
223
+ if self.model is not None:
224
+ _ = self.model.encode(batch, convert_to_tensor=False, normalize_embeddings=True)
225
+ end_time = time.time()
226
+ latencies.append(end_time - start_time)
227
+
228
+ # Calculate metrics
229
+ avg_latency = sum(latencies) / len(latencies)
230
+ throughput = batch_size / avg_latency
231
+ time_per_text_ms = (avg_latency / batch_size) * 1000
232
+
233
+ batch_key = f"batch_{batch_size}"
234
+ speed_results["medium"][batch_key] = {
235
+ "time_per_text_ms": time_per_text_ms,
236
+ "texts_per_second": throughput,
237
+ "tokens_per_second": throughput * 50, # Estimate 50 tokens per text
238
+ }
239
+
240
+ logger.info(f" ⚡ Latency: {avg_latency:.3f}s, Throughput: {throughput:.1f} texts/sec")
241
+
242
+ # Update results
243
+ self.results["speed_benchmarks"] = speed_results
244
+ return speed_results
245
+
246
+ def benchmark_memory_scaling(self, batch_sizes: list[int] | None = None) -> dict[str, Any]:
247
+ """Benchmark memory usage scaling with batch size."""
248
+ if batch_sizes is None:
249
+ batch_sizes = [1, 8, 16, 32]
250
+
251
+ logger.info(f"🧠 Benchmarking memory scaling with batch sizes: {batch_sizes}")
252
+
253
+ if self.model is None:
254
+ self.load_model()
255
+
256
+ memory_results: dict[str, Any] = {}
257
+ test_texts = BENCHMARK_TEXTS["medium"]
258
+
259
+ for batch_size in batch_sizes:
260
+ logger.info(f" 📊 Testing memory with batch size: {batch_size}")
261
+
262
+ # Prepare batch
263
+ batch = (
264
+ test_texts[:batch_size]
265
+ if batch_size <= len(test_texts)
266
+ else test_texts * ((batch_size // len(test_texts)) + 1)
267
+ )
268
+ batch = batch[:batch_size]
269
+
270
+ # Clear GPU cache if using CUDA
271
+ if torch.cuda.is_available():
272
+ torch.cuda.empty_cache()
273
+ torch.cuda.reset_peak_memory_stats()
274
+
275
+ try:
276
+ # Run inference
277
+ if self.model is not None:
278
+ _ = self.model.encode(batch, convert_to_tensor=False)
279
+
280
+ # Measure peak memory
281
+ if torch.cuda.is_available():
282
+ peak_memory = torch.cuda.max_memory_allocated() / (1024 * 1024)
283
+ memory_per_text = peak_memory / batch_size
284
+ else:
285
+ # Use psutil for CPU memory (less accurate)
286
+ peak_memory = psutil.virtual_memory().used / (1024 * 1024)
287
+ memory_per_text = 0 # Can't accurately measure per-text on CPU
288
+
289
+ batch_key = f"batch_{batch_size}"
290
+ memory_results[batch_key] = {
291
+ "memory_used_mb": peak_memory,
292
+ "memory_per_text_mb": memory_per_text,
293
+ "oom": False,
294
+ }
295
+
296
+ logger.info(f" 🧠 Peak memory: {peak_memory:.1f}MB, Per text: {memory_per_text:.2f}MB")
297
+
298
+ except Exception as e:
299
+ logger.warning(f"⚠️ Memory benchmark failed for batch {batch_size}: {e}")
300
+ batch_key = f"batch_{batch_size}"
301
+ memory_results[batch_key] = {
302
+ "oom": True,
303
+ "error": str(e),
304
+ }
305
+
306
+ self.results["memory_benchmarks"] = memory_results
307
+ return memory_results
308
+
309
+ def benchmark_cpu_vs_gpu(self) -> dict[str, Any]:
310
+ """Compare CPU vs GPU performance."""
311
+ logger.info("⚖️ Benchmarking CPU vs GPU performance")
312
+
313
+ if not torch.cuda.is_available():
314
+ logger.warning("⚠️ CUDA not available - skipping GPU benchmark")
315
+ return {}
316
+
317
+ comparison_results: dict[str, Any] = {}
318
+ test_texts = BENCHMARK_TEXTS["medium"][:16] # Use 16 texts for comparison
319
+
320
+ for device in ["cpu", "cuda"]:
321
+ logger.info(f" 📊 Testing on {device.upper()}")
322
+
323
+ try:
324
+ model = SentenceTransformer(self.model_path, device=device, trust_remote_code=True)
325
+
326
+ # Warmup
327
+ _ = model.encode(test_texts[:4], convert_to_tensor=False)
328
+
329
+ # Benchmark
330
+ start_time = time.time()
331
+ _ = model.encode(test_texts, convert_to_tensor=False, normalize_embeddings=True)
332
+ end_time = time.time()
333
+
334
+ latency = end_time - start_time
335
+ throughput = len(test_texts) / latency
336
+
337
+ comparison_results[device] = {
338
+ "texts_per_second": throughput,
339
+ }
340
+
341
+ logger.info(f" ⚡ {device.upper()}: {latency:.3f}s, {throughput:.1f} texts/sec")
342
+
343
+ # Clean up
344
+ del model
345
+ if device == "cuda":
346
+ torch.cuda.empty_cache()
347
+
348
+ except Exception as e:
349
+ logger.warning(f"⚠️ Failed to benchmark {device}: {e}")
350
+ comparison_results[device] = {"error": str(e)}
351
+
352
+ # Calculate speedup
353
+ if "cpu" in comparison_results and "cuda" in comparison_results:
354
+ cpu_throughput = comparison_results["cpu"].get("texts_per_second", 0)
355
+ gpu_throughput = comparison_results["cuda"].get("texts_per_second", 0)
356
+ if cpu_throughput > 0:
357
+ speedup = gpu_throughput / cpu_throughput
358
+ comparison_results["gpu_speedup"] = speedup
359
+ logger.info(f" 🚀 GPU Speedup: {speedup:.1f}x")
360
+
361
+ self.results["cpu_vs_gpu"] = comparison_results
362
+ return comparison_results
363
+
364
+ def run_comprehensive_benchmark(self) -> dict[str, Any]:
365
+ """Run all benchmarks and return comprehensive results."""
366
+ logger.info(f"🏁 Starting comprehensive benchmark for {self.model_name}")
367
+
368
+ # Model information
369
+ self.results["model_name"] = self.model_name
370
+ self.results["model_path"] = self.model_path
371
+ self.results["timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S")
372
+
373
+ # Run all benchmarks
374
+ try:
375
+ self.load_model()
376
+ self.measure_model_size()
377
+ self.benchmark_inference_speed([1, 8, 16, 32])
378
+ self.benchmark_memory_scaling([1, 8, 16, 32])
379
+ self.benchmark_cpu_vs_gpu()
380
+
381
+ logger.info(f"✅ Comprehensive benchmark completed for {self.model_name}")
382
+
383
+ except Exception:
384
+ logger.exception(f"❌ Benchmark failed for {self.model_name}")
385
+ self.results["error"] = traceback.format_exc()
386
+
387
+ return self.results
388
 
389
 
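For reference, a minimal usage sketch of the benchmarker above (illustrative only; the model path is a placeholder and the nested result keys follow the defaults used in `run_comprehensive_benchmark`):

```python
from distiller.evaluate import PerformanceBenchmark

# Placeholder path: any SentenceTransformer-compatible model directory works here.
bench = PerformanceBenchmark("./code_model2vec/final/my_distilled_model")
results = bench.run_comprehensive_benchmark()

print(results.get("size_metrics", {}).get("parameters_millions"))
# Speed results are keyed by text length ("medium") and batch size ("batch_32").
speed = results.get("speed_benchmarks", {}).get("medium", {}).get("batch_32", {})
print(speed.get("texts_per_second"))
```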
390
  class CodeSearchNetEvaluator:
391
  """Evaluator for CodeSearchNet-style code search tasks."""
392
 
393
+ def __init__(self, model_path: str, model_name: str | None = None) -> None:
394
+ """Initialize the evaluator with a model."""
 
 
 
 
 
 
395
  self.model_path = model_path
396
  self.model_name = model_name or Path(model_path).name
397
  self.model: SentenceTransformer | None = None
 
 
398
  self._load_model()
399
 
400
  def _load_model(self) -> None:
401
+ """Load the embedding model."""
402
  logger.info(f"Loading model from {self.model_path}")
 
 
 
 
 
 
 
 
403
  try:
404
  self.model = SentenceTransformer(self.model_path, trust_remote_code=True)
405
  logger.info(f"Successfully loaded model: {self.model_name}")
 
414
  raise RuntimeError(msg)
415
 
416
  embeddings = []
 
417
  for i in tqdm(range(0, len(texts), BATCH_SIZE), desc=desc):
418
  batch = texts[i : i + BATCH_SIZE]
419
  batch_embeddings = self.model.encode(batch, convert_to_tensor=False, normalize_embeddings=True)
 
422
  return np.vstack(embeddings)
423
 
424
  def evaluate_language(self, language: str, max_queries: int = 1000) -> dict[str, Any]:
425
+ """Evaluate on a specific programming language."""
426
  logger.info(f"Evaluating on {language} language (max {max_queries} queries)")
427
 
 
 
 
 
 
 
 
428
  try:
429
  # Load test split for the language
430
  dataset = load_dataset(
431
+ codesearchnet_config.dataset_name,
432
  language,
433
  split="test",
434
  trust_remote_code=True,
435
  )
436
 
 
437
  if not isinstance(dataset, Dataset):
438
  logger.error(f"Unexpected dataset type for {language}: {type(dataset)}")
439
  return {}
440
 
441
+ # Sample queries for evaluation
442
  if len(dataset) > max_queries:
443
+ rng = np.random.default_rng(42)
444
  indices = rng.choice(len(dataset), max_queries, replace=False)
445
  dataset = dataset.select(indices)
446
 
 
464
  logger.info(f"Found {len(queries)} valid query-code pairs for {language}")
465
 
466
  # Encode queries and codes
467
+ start_time = time.time()
468
  query_embeddings = self.encode_texts(queries, f"Encoding {language} queries")
469
+ code_embeddings = self.encode_texts(codes, f"Encoding {language} code")
470
+ encoding_time = time.time() - start_time
471
 
472
+ # Compute similarities and metrics
473
  similarities = cosine_similarity(query_embeddings, code_embeddings)
 
 
474
  metrics = self._compute_retrieval_metrics(similarities)
475
 
476
+ # Prepare results
477
+ results = {
478
  "language": language,
479
+ "model_name": self.model_name,
480
  "num_queries": len(queries),
481
+ "encoding_time_seconds": encoding_time,
482
  "metrics": metrics,
 
483
  }
484
 
485
+ logger.info(f"✅ {language} evaluation completed in {encoding_time:.2f}s")
486
+ return results
 
 
 
 
 
 
 
 
 
 
 
 
 
487
 
488
  except Exception:
489
+ logger.exception(f"❌ Failed to evaluate {language}")
490
  return {}
491
 
492
  def _compute_retrieval_metrics(self, similarities: np.ndarray) -> dict[str, float]:
493
+ """Compute retrieval metrics from similarity matrix."""
494
+ n_queries = similarities.shape[0]
495
 
496
+ # For each query, the correct code is at the same index
497
+ correct_indices = np.arange(n_queries)
498
+
499
+ # Rank all codes for each query
500
+ ranked_indices = np.argsort(similarities, axis=1)[:, ::-1]
501
+
502
+ metrics = {}
503
+
504
+ # Compute metrics for different k values
505
+ for k in [1, 5, 10]:
506
+ if k <= similarities.shape[1]:
507
+ # Recall@k
508
+ recall_k = np.mean([correct_indices[i] in ranked_indices[i, :k] for i in range(n_queries)])
509
+ metrics[f"recall@{k}"] = recall_k
510
+
511
+ # NDCG@k
512
+ ndcg_k = np.mean(
513
+ [self._compute_ndcg(ranked_indices[i], correct_indices[i], k) for i in range(n_queries)]
514
+ )
515
+ metrics[f"ndcg@{k}"] = ndcg_k
516
+
517
+ # Mean Reciprocal Rank
518
  reciprocal_ranks = []
519
+ for i in range(n_queries):
520
+ rank = np.where(ranked_indices[i] == correct_indices[i])[0]
521
+ if len(rank) > 0:
522
+ reciprocal_ranks.append(1.0 / (rank[0] + 1))
523
+ else:
524
+ reciprocal_ranks.append(0.0)
525
+
526
+ metrics["mrr"] = np.mean(reciprocal_ranks)
527
+
528
+ # Add mean rank and median rank
529
+ mean_ranks = []
530
+ for i in range(n_queries):
531
+ rank = np.where(ranked_indices[i] == correct_indices[i])[0]
532
+ if len(rank) > 0:
533
+ mean_ranks.append(rank[0] + 1) # 1-indexed
534
+ else:
535
+ mean_ranks.append(similarities.shape[1]) # Worst possible rank
536
+
537
+ metrics["mean_rank"] = np.mean(mean_ranks)
538
+ metrics["median_rank"] = np.median(mean_ranks)
539
+
540
+ # Ensure all values are float
541
+ return {k: float(v) for k, v in metrics.items()}
 
 
 
 
 
 
 
 
 
542
 
543
  def _compute_ndcg(self, ranked_indices: np.ndarray, correct_idx: int, k: int) -> float:
544
  """Compute NDCG@k for a single query."""
545
+ if correct_idx in ranked_indices[:k]:
546
+ rank = np.where(ranked_indices[:k] == correct_idx)[0][0]
547
+ return 1.0 / np.log2(rank + 2)
 
 
 
 
 
548
  return 0.0
549
 
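A small worked example of the two metric helpers above (illustrative, not part of the commit): with exactly one relevant snippet per query the ideal DCG is 1, so NDCG@k collapses to 1/log2(rank + 2) with a 0-based rank, while MRR uses 1/(rank + 1):

```python
import numpy as np

sims = np.array([
    [0.9, 0.4],  # query 0: its own code (index 0) ranks 1st
    [0.7, 0.6],  # query 1: its own code (index 1) ranks 2nd
])
ranked = np.argsort(sims, axis=1)[:, ::-1]                       # [[0, 1], [0, 1]]
ranks = [int(np.where(ranked[i] == i)[0][0]) for i in range(2)]  # [0, 1] (0-based)

mrr = np.mean([1.0 / (r + 1) for r in ranks])               # (1.0 + 0.5) / 2 = 0.75
ndcg_at_5 = np.mean([1.0 / np.log2(r + 2) for r in ranks])  # (1.0 + 0.631) / 2 ≈ 0.815
recall_at_1 = np.mean([r < 1 for r in ranks])               # 0.5
```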
550
  def evaluate_all_languages(
551
  self, max_queries_per_lang: int = 1000, languages: list[str] | None = None
552
  ) -> dict[str, Any]:
553
+ """Evaluate on all specified languages."""
554
+ eval_languages = languages or languages_config.all
 
 
 
555
 
556
+ logger.info(f"🚀 Starting evaluation on {len(eval_languages)} languages")
557
+ logger.info(f"📊 Model: {self.model_name}")
558
+ logger.info(f"🔢 Max queries per language: {max_queries_per_lang}")
 
 
 
559
 
560
  start_time = time.time()
561
+ results = {
 
562
  "model_name": self.model_name,
563
  "model_path": self.model_path,
 
564
  "languages": {},
565
  "overall": {},
566
+ "evaluation_time_seconds": 0,
567
  }
568
+ languages_dict: dict[str, Any] = {}
569
 
570
+ # Evaluate each language
571
+ for language in eval_languages:
572
+ logger.info(f"\n{'=' * 50}")
573
+ logger.info(f"🔍 Evaluating {language}")
574
+ logger.info(f"{'=' * 50}")
575
 
 
 
576
  lang_results = self.evaluate_language(language, max_queries_per_lang)
 
577
  if lang_results:
578
+ languages_dict[language] = lang_results
 
 
 
579
 
580
+ results["languages"] = languages_dict
581
+
582
+ # Compute overall metrics
583
+ if languages_dict:
584
  overall_metrics = {}
585
+ metric_names = list(next(iter(languages_dict.values()))["metrics"].keys())
586
+
587
+ for metric in metric_names:
588
+ values = [languages_dict[lang]["metrics"][metric] for lang in languages_dict]
589
+ overall_metrics[metric] = np.mean(values)
590
 
591
  results["overall"] = overall_metrics
592
 
593
  total_time = time.time() - start_time
594
  results["evaluation_time_seconds"] = total_time
595
 
 
 
 
 
 
596
  logger.info(f"Evaluation completed in {total_time:.2f} seconds")
597
  return results
598
 
599
 
600
+ class ComprehensiveModelEvaluator:
601
+ """Combined evaluator for both task performance and operational benchmarks."""
 
 
 
 
 
 
 
 
 
 
 
 
 
602
 
603
+ def __init__(self, model_path: str, model_name: str | None = None) -> None:
604
+ """Initialize the comprehensive evaluator with a model."""
605
+ self.model_path = model_path
606
+ self.model_name = model_name or Path(model_path).name
607
+
608
+ # Initialize sub-evaluators
609
+ self.codesearch_evaluator = CodeSearchNetEvaluator(model_path, model_name)
610
+ self.performance_benchmarker = PerformanceBenchmark(model_path, model_name)
611
+
612
+ self.results: dict[str, Any] = {}
613
+
614
+ def run_comprehensive_evaluation(
615
+ self,
616
+ max_queries_per_lang: int = 1000,
617
+ languages: list[str] | None = None,
618
+ skip_benchmark: bool = False,
619
+ ) -> dict[str, Any]:
620
+ """Run both CodeSearchNet evaluation and performance benchmarking."""
621
+ logger.info(f"🚀 Starting comprehensive evaluation for {self.model_name}")
622
+ start_time = time.time()
623
+
624
+ # Initialize results structure
625
+ self.results = {
626
+ "model_name": self.model_name,
627
+ "model_path": self.model_path,
628
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
629
+ "evaluation_time_seconds": 0,
630
+ }
631
 
 
 
 
 
 
 
 
 
 
 
 
632
  try:
633
+ # 1. Run CodeSearchNet evaluation
634
+ logger.info("🔍 Running CodeSearchNet task evaluation...")
635
+ codesearch_results = self.codesearch_evaluator.evaluate_all_languages(max_queries_per_lang, languages)
636
+
637
+ # Extract CodeSearchNet metrics
638
+ self.results.update(
639
+ {
640
+ "codesearch_languages": codesearch_results.get("languages", {}),
641
+ "codesearch_overall": codesearch_results.get("overall", {}),
642
+ }
643
+ )
644
+
645
+ # 2. Run performance benchmarking (unless skipped)
646
+ if not skip_benchmark:
647
+ logger.info("⚡ Running operational performance benchmarking...")
648
+ benchmark_results = self.performance_benchmarker.run_comprehensive_benchmark()
649
+
650
+ # Extract benchmark metrics
651
+ self.results.update(
652
+ {
653
+ "size_metrics": benchmark_results.get("size_metrics", {}),
654
+ "speed_benchmarks": benchmark_results.get("speed_benchmarks", {}),
655
+ "memory_benchmarks": benchmark_results.get("memory_benchmarks", {}),
656
+ "cpu_vs_gpu": benchmark_results.get("cpu_vs_gpu", {}),
657
+ }
658
+ )
659
+ else:
660
+ logger.info("⏭️ Skipping performance benchmarking")
661
+ self.results["benchmark_skipped"] = True
662
+
663
  except Exception as e:
664
+ logger.exception(f"❌ Comprehensive evaluation failed for {self.model_name}")
665
+ self.results["error"] = str(e)
666
 
667
+ # Calculate total time
668
+ total_time = time.time() - start_time
669
+ self.results["evaluation_time_seconds"] = total_time
 
 
 
 
670
 
671
+ logger.info(f"✅ Comprehensive evaluation completed in {total_time:.2f} seconds")
672
+ return self.results
 
673
 
674
+ def print_summary(self) -> None:
675
+ """Print a comprehensive summary of all results."""
676
+ logger.info(f"\n{'=' * 60}")
677
+ logger.info(f"📊 COMPREHENSIVE EVALUATION RESULTS: {self.model_name}")
678
+ logger.info(f"{'=' * 60}")
679
 
680
+ # CodeSearchNet results
681
+ overall = self.results.get("codesearch_overall", {})
682
+ if overall:
683
+ logger.info("🔍 CodeSearchNet Performance:")
684
+ for metric, value in overall.items():
685
+ logger.info(f" 🎯 {metric.upper()}: {value:.4f}")
686
+
687
+ # Benchmark results
688
+ if not self.results.get("benchmark_skipped", False):
689
+ size_metrics = self.results.get("size_metrics", {})
690
+ if size_metrics:
691
+ logger.info(f"\n📏 Model Size: {size_metrics.get('disk_size_mb', 0):.1f}MB")
692
+ if "parameters_millions" in size_metrics:
693
+ logger.info(f"🔢 Parameters: {size_metrics['parameters_millions']:.1f}M")
694
+
695
+ speed_benchmarks = self.results.get("speed_benchmarks", {})
696
+ if "medium" in speed_benchmarks and "batch_32" in speed_benchmarks["medium"]:
697
+ batch_32 = speed_benchmarks["medium"]["batch_32"]
698
+ logger.info(f"⚡ Throughput (batch 32): {batch_32.get('texts_per_second', 0):.1f} texts/sec")
699
+
700
+ cpu_vs_gpu = self.results.get("cpu_vs_gpu", {})
701
+ if "gpu_speedup" in cpu_vs_gpu:
702
+ speedup = cpu_vs_gpu["gpu_speedup"]
703
+ logger.info(f"🚀 GPU speedup: {speedup:.1f}x")
704
+
705
+ # Language breakdown
706
+ languages = self.results.get("codesearch_languages", {})
707
+ if languages:
708
+ logger.info("\n📋 Language Breakdown:")
709
+ for lang, lang_results in languages.items():
710
+ metrics = lang_results.get("metrics", {})
711
+ ndcg10 = metrics.get("ndcg@10", 0)
712
+ mrr = metrics.get("mrr", 0)
713
+ logger.info(f" {lang}: NDCG@10={ndcg10:.4f}, MRR={mrr:.4f}")
714
 
 
715
 
716
+ # =============================================================================
717
+ # UTILITY FUNCTIONS
718
+ # =============================================================================
719
 
 
 
 
 
720
 
721
+ def check_existing_results(model_name: str, local_dir: str = LOCAL_EVALUATION_DIR) -> dict[str, Any] | None:
722
+ """Check if comprehensive evaluation results already exist for a model."""
723
+ local_path = Path(local_dir)
724
+ safe_model_name = get_safe_model_name(model_name)
725
+
726
+ # Check for new comprehensive format first
727
+ comprehensive_file = local_path / f"comprehensive_eval_{safe_model_name}.json"
728
+ if comprehensive_file.exists():
729
+ try:
730
+ with comprehensive_file.open("r") as f:
731
+ results = json.load(f)
732
+ logger.info(f"✅ Found existing comprehensive results for {model_name}")
733
+ return results
734
+ except Exception as e:
735
+ logger.warning(f"⚠️ Could not load existing comprehensive results for {model_name}: {e}")
736
+
737
+ # Fallback to legacy codesearchnet format for backward compatibility
738
+ legacy_file = local_path / f"codesearchnet_eval_{safe_model_name}.json"
739
+ if legacy_file.exists():
740
+ try:
741
+ with legacy_file.open("r") as f:
742
+ results = json.load(f)
743
+ logger.info(f"✅ Found existing legacy results for {model_name}")
744
+ return results
745
+ except Exception as e:
746
+ logger.warning(f"⚠️ Could not load existing legacy results for {model_name}: {e}")
747
+
748
+ return None
749
+
750
+
751
+ def save_evaluation_results(results: dict[str, Any], local_dir: str = LOCAL_EVALUATION_DIR) -> bool:
752
+ """Save comprehensive evaluation results to local directory as a single JSON file."""
753
+ try:
754
+ local_path = Path(local_dir)
755
+ local_path.mkdir(parents=True, exist_ok=True)
756
+
757
+ model_name = results.get("model_name", "unknown")
758
+ safe_model_name = get_safe_model_name(model_name)
759
 
760
+ # Save single comprehensive results file (CodeSearchNet + Benchmark combined)
761
+ result_file = local_path / f"comprehensive_eval_{safe_model_name}.json"
762
+ with result_file.open("w") as f:
763
+ json.dump(results, f, indent=2, default=str)
764
+
765
+ logger.info(f"💾 Saved comprehensive evaluation results for {model_name}")
766
+ return True
767
+
768
+ except Exception:
769
+ logger.exception("❌ Error saving evaluation results")
770
+ return False
771
+
772
+
773
+ def discover_local_models(models_dir: str = LOCAL_MODELS_DIR) -> list[str]:
774
+ """Discover models in the local models directory."""
775
+ models_path = Path(models_dir)
776
+ discovered_models = []
777
+
778
+ if models_path.exists():
779
+ for model_dir in models_path.iterdir():
780
+ if model_dir.is_dir() and (
781
+ any(model_dir.glob("*.json")) or any(model_dir.glob("*.bin")) or any(model_dir.glob("*.safetensors"))
782
+ ):
783
+ discovered_models.append(str(model_dir))
784
+ logger.info(f"📁 Found local model: {model_dir.name}")
785
+
786
+ return discovered_models
787
+
788
+
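The discovery helper above only requires that each model directory contain at least one JSON, .bin, or .safetensors file; a layout like the following (directory and model names are placeholders) would be picked up, assuming `directories.final` resolves to `code_model2vec/final` as the CLI below expects:

```python
# code_model2vec/final/            <- LOCAL_MODELS_DIR (directories.final)
#     my_distilled_model/          <- returned by discover_local_models()
#         config.json
#         model.safetensors
#         tokenizer.json
```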
789
+ def print_results_summary(results: dict[str, Any]) -> None:
790
+ """Print a formatted summary of comprehensive evaluation results."""
791
+ logger.info(f"\n{'=' * 60}")
792
+ logger.info(f"📊 COMPREHENSIVE EVALUATION: {results.get('model_name', 'Unknown')}")
793
+ logger.info(f"{'=' * 60}")
794
+
795
+ # CodeSearchNet results
796
+ overall = results.get("codesearch_overall", {})
797
  if overall:
798
+ logger.info("🔍 CodeSearchNet Performance:")
799
+ for metric, value in overall.items():
800
+ logger.info(f" 🎯 {metric.upper()}: {value:.4f}")
801
+
802
+ # Benchmark results
803
+ if not results.get("benchmark_skipped", False):
804
+ size_metrics = results.get("size_metrics", {})
805
+ if size_metrics:
806
+ logger.info(f"\n📏 Model Size: {size_metrics.get('disk_size_mb', 0):.1f}MB")
807
+ if "parameters_millions" in size_metrics:
808
+ logger.info(f"🔢 Parameters: {size_metrics['parameters_millions']:.1f}M")
809
+
810
+ speed_benchmarks = results.get("speed_benchmarks", {})
811
+ if "medium" in speed_benchmarks and "batch_32" in speed_benchmarks["medium"]:
812
+ batch_32 = speed_benchmarks["medium"]["batch_32"]
813
+ logger.info(f"⚡ Throughput (batch 32): {batch_32.get('texts_per_second', 0):.1f} texts/sec")
814
+
815
+ # Language breakdown
816
+ languages = results.get("codesearch_languages", {})
817
+ if languages:
818
+ logger.info("\n📋 Language Breakdown:")
819
+ for lang, lang_results in languages.items():
820
+ metrics = lang_results.get("metrics", {})
821
+ ndcg10 = metrics.get("ndcg@10", 0)
822
+ mrr = metrics.get("mrr", 0)
823
+ logger.info(f" {lang}: NDCG@10={ndcg10:.4f}, MRR={mrr:.4f}")
824
+
825
+
826
+ def create_comparison_report(all_results: list[dict[str, Any]], output_dir: str = LOCAL_EVALUATION_DIR) -> None:
827
+ """Create a comprehensive comparison report with both CodeSearchNet and benchmark data."""
828
  if not all_results:
829
  return
830
 
831
+ logger.info("📊 Creating comprehensive comparison report...")
832
+
833
+ # Create evaluation comparison dataframe
834
+ evaluation_data = []
835
+ benchmark_data = []
836
+
837
+ for result in all_results:
838
+ model_name = result.get("model_name", "Unknown")
839
+
840
+ # CodeSearchNet data
841
+ overall = result.get("codesearch_overall", {})
842
+ eval_row = {"model_name": model_name}
843
+ eval_row.update(overall)
844
+ evaluation_data.append(eval_row)
845
+
846
+ # Benchmark data (if available)
847
+ if not result.get("benchmark_skipped", False):
848
+ benchmark_row = {"model_name": model_name}
849
+ size_metrics = result.get("size_metrics", {})
850
+ speed_benchmarks = result.get("speed_benchmarks", {})
851
+
852
+ benchmark_row.update(size_metrics)
853
+ if "medium" in speed_benchmarks and "batch_32" in speed_benchmarks["medium"]:
854
+ batch_32 = speed_benchmarks["medium"]["batch_32"]
855
+ benchmark_row["best_throughput"] = batch_32.get("texts_per_second", 0)
856
+ benchmark_data.append(benchmark_row)
857
+
858
+ # Save comparison results
859
  output_path = Path(output_dir)
860
+ output_path.mkdir(parents=True, exist_ok=True)
861
 
862
+ # Combined evaluation comparison CSV (includes both CodeSearchNet and key benchmark metrics)
863
+ if evaluation_data and benchmark_data:
864
+ # Merge evaluation and benchmark data
865
+ combined_data = []
866
+ benchmark_dict = {row["model_name"]: row for row in benchmark_data}
867
+
868
+ for eval_row in evaluation_data:
869
+ model_name = eval_row["model_name"]
870
+ combined_row = eval_row.copy()
871
+
872
+ # Add benchmark metrics if available
873
+ if model_name in benchmark_dict:
874
+ benchmark_row = benchmark_dict[model_name]
875
+ combined_row.update(
876
+ {
877
+ "disk_size_mb": benchmark_row.get("disk_size_mb", 0),
878
+ "parameters_millions": benchmark_row.get("parameters_millions", 0),
879
+ "best_throughput": benchmark_row.get("best_throughput", 0),
880
+ }
881
+ )
882
 
883
+ combined_data.append(combined_row)
 
884
 
885
+ combined_df = pd.DataFrame(combined_data)
886
+ combined_csv = output_path / "comprehensive_comparison.csv"
887
+ combined_df.to_csv(combined_csv, index=False)
888
+ logger.info(f"📄 Comprehensive comparison CSV saved: {combined_csv}")
889
 
890
+ # Detailed JSON export
891
+ json_path = output_path / "comprehensive_evaluation.json"
892
+ with json_path.open("w") as f:
893
+ json.dump(all_results, f, indent=2, default=str)
894
+ logger.info(f"📄 Comprehensive results JSON saved: {json_path}")
895
 
896
 
897
+ # =============================================================================
898
+ # MAIN EVALUATION FUNCTIONS
899
+ # =============================================================================
900
+
901
+
902
+ def run_evaluation(
903
  models: list[str],
904
  max_queries: int = 1000,
905
  languages: list[str] | None = None,
906
+ use_beam: bool = False,
907
+ skip_benchmark: bool = False,
 
908
  ) -> list[dict[str, Any]]:
909
+ """Main evaluation function that handles both local and Beam execution."""
910
+ logger.info(f"🚀 Starting comprehensive evaluation ({'Beam' if use_beam else 'Local'})")
911
+ logger.info(f"📊 Evaluating {len(models)} models on {len(languages or languages_config.all)} languages")
912
+ logger.info(f"⚡ Benchmarking: {'Disabled' if skip_benchmark else 'Enabled'}")
913
+
914
+ # Check for existing results and skip already evaluated models
915
+ models_to_evaluate = []
916
+ skipped_models = []
917
+ all_results = []
918
 
919
+ for model_path in models:
920
+ model_name = Path(model_path).name
921
+ existing_results = check_existing_results(model_name)
922
 
923
+ if existing_results:
924
+ logger.info(f"✅ Model {model_name} already evaluated, skipping")
925
+ all_results.append(existing_results)
926
+ skipped_models.append(model_name)
927
+ else:
928
+ models_to_evaluate.append(model_path)
929
 
930
+ if not models_to_evaluate:
931
+ logger.info("🎉 All models already evaluated!")
932
+ return all_results
933
 
934
+ logger.info(f"📊 Need to evaluate {len(models_to_evaluate)} models")
935
+
936
+ if use_beam:
937
+ # Run on Beam
938
+ new_results = _run_beam_evaluation(models_to_evaluate, max_queries, languages, skip_benchmark)
939
+ else:
940
+ # Run locally
941
+ new_results = _run_local_evaluation(models_to_evaluate, max_queries, languages, skip_benchmark)
942
+
943
+ all_results.extend(new_results)
944
 
945
+ # Create comparison report
946
+ if len(all_results) > 1:
947
+ create_comparison_report(all_results)
948
+
949
+ # Print summary
950
+ newly_evaluated = len(new_results)
951
+ logger.info(f"\n{'=' * 60}")
952
+ logger.info("📊 EVALUATION SUMMARY")
953
+ logger.info(f"{'=' * 60}")
954
+ logger.info(f"📊 Total models: {len(models)}")
955
+ logger.info(f"✅ Newly evaluated: {newly_evaluated}")
956
+ logger.info(f"⏭️ Skipped (already done): {len(skipped_models)}")
957
+ logger.info(f"🎯 Total results: {len(all_results)}")
958
+ logger.info(f"⚡ Benchmarking: {'Disabled' if skip_benchmark else 'Enabled'}")
959
+
960
+ return all_results
961
+
962
+
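A minimal programmatic sketch of the entry point above (illustrative; the model path is a placeholder and the language subset assumes those names exist in `languages_config.all`):

```python
from distiller.evaluate import run_evaluation

results = run_evaluation(
    models=["./code_model2vec/final/my_distilled_model"],
    max_queries=500,
    languages=["python", "java"],  # subset; None falls back to languages_config.all
    use_beam=False,                # True dispatches each model to Beam instead
    skip_benchmark=True,           # task metrics only, no speed/memory benchmarks
)
print(results[0].get("codesearch_overall", {}))
```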
963
+ def _run_local_evaluation(
964
+ models: list[str],
965
+ max_queries: int = 1000,
966
+ languages: list[str] | None = None,
967
+ skip_benchmark: bool = False,
968
+ ) -> list[dict[str, Any]]:
969
+ """Run comprehensive evaluation locally."""
970
+ logger.info("🖥️ Running local comprehensive evaluation")
971
+
972
+ results = []
973
  for model_path in models:
974
  model_name = Path(model_path).name
975
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
976
  logger.info(f"\n{'=' * 60}")
977
  logger.info(f"🔍 Evaluating model: {model_name}")
 
978
  logger.info(f"{'=' * 60}")
979
 
980
  try:
981
+ evaluator = ComprehensiveModelEvaluator(model_path, model_name)
982
+ result = evaluator.run_comprehensive_evaluation(max_queries, languages, skip_benchmark)
 
 
983
 
984
+ # Save results locally
985
+ save_evaluation_results(result)
986
+ print_results_summary(result)
987
+ results.append(result)
 
988
 
989
  except Exception:
990
  logger.exception(f"❌ Failed to evaluate {model_name}")
991
  continue
992
 
993
+ return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
994
 
995
 
996
  @function(
997
  gpu=GPU_NAME,
998
+ volumes=[Volume(name=VOLUME_CONFIG.name, mount_path=VOLUME_CONFIG.mount_path)],
999
  image=IMAGE,
1000
  secrets=["HF_ACCESS_TOKEN"],
1001
+ env=BEAM_ENV_SETTINGS,
1002
+ timeout=3600 * 8, # 8 hours for comprehensive evaluation
 
 
 
1003
  )
1004
+ def _beam_evaluate_single_model(
1005
+ model_path: str,
1006
+ max_queries: int = 1000,
1007
+ languages: list[str] | None = None,
1008
+ skip_benchmark: bool = False,
1009
+ ) -> dict[str, Any]:
1010
+ """Beam function to comprehensively evaluate a single model."""
1011
+ model_name = Path(model_path).name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1012
 
1013
+ logger.info(f"🚀 Beam comprehensive evaluation starting for {model_name}")
 
 
1014
 
1015
+ try:
1016
+ evaluator = ComprehensiveModelEvaluator(model_path, model_name)
1017
+ results = evaluator.run_comprehensive_evaluation(max_queries, languages, skip_benchmark)
1018
 
1019
+ # Save to Beam volume as single comprehensive file
1020
+ volume_results_dir = Path(VOLUME_CONFIG.mount_path) / "evaluation_results"
1021
+ volume_results_dir.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
1022
 
1023
+ safe_model_name = get_safe_model_name(model_name)
1024
+ result_file = volume_results_dir / f"comprehensive_eval_{safe_model_name}.json"
 
 
1025
 
1026
+ with result_file.open("w") as f:
1027
+ json.dump(results, f, indent=2, default=str)
1028
 
1029
+ logger.info(f"💾 Saved Beam comprehensive evaluation results for {model_name}")
1030
+ return results
1031
 
1032
+ except Exception:
1033
+ logger.exception(f"❌ Beam comprehensive evaluation failed for {model_name}")
1034
+ return {}
 
 
 
 
 
 
 
 
 
 
 
1035
 
1036
 
1037
+ def _run_beam_evaluation(
1038
+ models: list[str],
1039
  max_queries: int = 1000,
1040
  languages: list[str] | None = None,
1041
+ skip_benchmark: bool = False,
1042
  ) -> list[dict[str, Any]]:
1043
+ """Run comprehensive evaluation on Beam and download results."""
1044
+ logger.info("☁️ Running Beam comprehensive evaluation")
 
1045
 
1046
+ results = []
1047
  for model_path in models:
1048
  model_name = Path(model_path).name
1049
 
1050
+ logger.info(f"🚀 Starting Beam comprehensive evaluation for {model_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1051
 
1052
  try:
1053
+ # Run evaluation on Beam
1054
+ result = _beam_evaluate_single_model.remote(model_path, max_queries, languages, skip_benchmark)
 
 
 
 
 
 
 
1055
 
1056
+ if result:
1057
+ # Download the comprehensive result file from Beam
1058
+ success = download_specific_evaluation_file(
1059
+ VOLUME_CONFIG.name,
1060
+ model_name,
1061
+ "evaluation_results",
1062
+ LOCAL_EVALUATION_DIR,
1063
+ file_prefix="comprehensive_eval",
1064
+ )
1065
 
1066
+ if success:
1067
+ logger.info(f"📥 Downloaded comprehensive results for {model_name}")
1068
+ print_results_summary(result)
1069
+ results.append(result)
1070
+ else:
1071
+ logger.warning(f"⚠️ Could not download results for {model_name}")
1072
 
1073
  except Exception:
1074
+ logger.exception(f"❌ Beam comprehensive evaluation failed for {model_name}")
1075
  continue
1076
 
1077
+ return results
 
 
 
1078
 
 
 
 
 
 
 
 
1079
 
1080
+ # =============================================================================
1081
+ # CLI INTERFACE
1082
+ # =============================================================================
1083
 
 
1084
 
1085
+ def main(
1086
+ use_beam: bool = typer.Option(default=False, help="Use Beam for evaluation"),
1087
+ skip_third_party: bool = typer.Option(default=False, help="Skip third-party models"),
1088
+ skip_benchmark: bool = typer.Option(default=False, help="Skip performance benchmarking"),
1089
+ max_queries: int = typer.Option(default=1000, help="Maximum queries per language"),
1090
+ ) -> None:
1091
+ """Main comprehensive evaluation function."""
1092
+ logger.info("🚀 Starting comprehensive model evaluation (CodeSearchNet + Performance)")
1093
 
1094
+ # Build model list
1095
+ models = []
 
 
 
 
 
1096
 
1097
+ # Add third-party models if not skipped
1098
+ if not skip_third_party:
1099
+ logger.info("📊 Including third-party peer models for comparison")
1100
+ models.extend(DEFAULT_EVALUATION_MODELS)
1101
+ else:
1102
+ logger.info("⏭️ Skipping third-party models")
1103
 
1104
+ # Discover local models from code_model2vec/final
1105
+ logger.info("🔍 Discovering local distillation models...")
1106
+ local_models = discover_local_models()
1107
 
1108
+ if local_models:
1109
+ logger.info(f"✅ Found {len(local_models)} local models:")
1110
+ for model_path in local_models:
1111
+ models.append(model_path)
1112
+ logger.info(f" 📁 {Path(model_path).name}")
1113
+ else:
1114
+ logger.warning("⚠️ No local distillation models found")
1115
+ if skip_third_party:
1116
+ logger.error("❌ No models to evaluate!")
1117
+ return
1118
+
1119
+ if not models:
1120
+ logger.error("❌ No models to evaluate!")
1121
+ return
1122
 
1123
+ logger.info(f"📊 Will evaluate {len(models)} models:")
1124
+ for i, model in enumerate(models, 1):
1125
+ logger.info(f" {i}. {Path(model).name}")
1126
+
1127
+ # Run evaluation
1128
+ results = run_evaluation(
1129
+ models=models,
1130
  max_queries=max_queries,
1131
+ languages=languages_config.all,
1132
+ use_beam=use_beam,
1133
+ skip_benchmark=skip_benchmark,
1134
  )
1135
 
1136
+ logger.info("🎉 Comprehensive evaluation workflow completed!")
1137
+ logger.info(f"📊 Successfully evaluated {len(results)} models")
1138
+ logger.info(f"💾 Results saved to: {LOCAL_EVALUATION_DIR}")
1139
+ logger.info("📄 Format: Single comprehensive JSON per model (CodeSearchNet + Benchmarks)")
1140
+
1141
 
1142
  if __name__ == "__main__":
1143
+ typer.run(main)
src/distiller/patch_utils.py CHANGED
@@ -67,6 +67,15 @@ def apply_patch_file(patch_file: Path, target_dir: Path) -> bool:
67
  try:
68
  logger.info(f"Applying patch: {patch_file.name}")
69
 
 
 
 
 
 
 
 
 
 
70
  # Use patch command with the following options:
71
  # -p1: strip 1 leading directory from paths
72
  # -d: change to directory before applying
@@ -121,6 +130,9 @@ def apply_all_patches() -> int:
121
  target_dir = get_site_packages_path()
122
  logger.info(f"Applying patches to: {target_dir}")
123
 
 
 
 
124
  success_count = 0
125
 
126
  # Sort patch files for consistent ordering
@@ -132,6 +144,113 @@ def apply_all_patches() -> int:
132
  return success_count
133
 
134
 
135
  def main() -> None:
136
  """Main function for standalone execution."""
137
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 
67
  try:
68
  logger.info(f"Applying patch: {patch_file.name}")
69
 
70
+ # Check if patch is already applied
71
+ if is_patch_already_applied(patch_file, target_dir):
72
+ logger.info(f"Patch {patch_file.name} already applied")
73
+ return True
74
+
75
+ # Clean any duplicate validation code before applying
76
+ if "model2vec.patch" in patch_file.name:
77
+ clean_duplicate_validation_code(target_dir)
78
+
79
  # Use patch command with the following options:
80
  # -p1: strip 1 leading directory from paths
81
  # -d: change to directory before applying
 
130
  target_dir = get_site_packages_path()
131
  logger.info(f"Applying patches to: {target_dir}")
132
 
133
+ # Clean any existing duplicates first
134
+ clean_duplicate_validation_code(target_dir)
135
+
136
  success_count = 0
137
 
138
  # Sort patch files for consistent ordering
 
144
  return success_count
145
 
146
 
147
+ def is_patch_already_applied(patch_file: Path, target_dir: Path) -> bool:
148
+ """
149
+ Check if a patch has already been applied by looking for specific markers.
150
+
151
+ Args:
152
+ patch_file: Path to the .patch file
153
+ target_dir: Target directory (usually site-packages)
154
+
155
+ Returns:
156
+ True if patch appears to be already applied, False otherwise
157
+ """
158
+ try:
159
+ # For model2vec.patch, check if the validation code is already present
160
+ if "model2vec.patch" in patch_file.name:
161
+ inference_file = target_dir / "model2vec" / "distill" / "inference.py"
162
+ if inference_file.exists():
163
+ inference_content = inference_file.read_text()
164
+ # Check for the specific validation code we're adding
165
+ if (
166
+ "Token-vector mismatch:" in inference_content
167
+ and "Truncating to prevent failure" in inference_content
168
+ ):
169
+ # Also make sure it's in the right place (before return statement, not after)
170
+ lines = inference_content.split("\n")
171
+ for i, line in enumerate(lines):
172
+ if "return out_tokens, out_weights" in line:
173
+ # Check if validation code appears before this return
174
+ preceding_lines = lines[max(0, i - 10) : i]
175
+ if any("Token-vector mismatch:" in pline for pline in preceding_lines):
176
+ return True
177
+ break
178
+
179
+ return False
180
+
181
+ except Exception as e:
182
+ logger.warning(f"Error checking if patch {patch_file.name} is applied: {e}")
183
+ return False
184
+
185
+
186
+ def clean_duplicate_validation_code(target_dir: Path) -> bool:
187
+ """
188
+ Clean up duplicate validation code that might have been added by multiple patch applications.
189
+
190
+ Args:
191
+ target_dir: Target directory (usually site-packages)
192
+
193
+ Returns:
194
+ True if cleanup was successful, False otherwise
195
+ """
196
+ try:
197
+ inference_file = target_dir / "model2vec" / "distill" / "inference.py"
198
+ if not inference_file.exists():
199
+ return True
200
+
201
+ content = inference_file.read_text()
202
+ lines = content.split("\n")
203
+
204
+ # Find all instances of the validation code
205
+ validation_indices = []
206
+ for i, line in enumerate(lines):
207
+ if "Token-vector mismatch:" in line:
208
+ validation_indices.append(i)
209
+
210
+ if len(validation_indices) <= 1:
211
+ return True # No duplicates or no validation code
212
+
213
+ # Keep only the validation code that appears before a return statement
214
+ lines_to_keep = []
215
+ skip_until = -1
216
+
217
+ for i, line in enumerate(lines):
218
+ if i <= skip_until:
219
+ continue
220
+
221
+ # If this is validation code
222
+ if "Token-vector mismatch:" in line:
223
+ # Look ahead to see if there's a return statement nearby
224
+ has_return_after = False
225
+ for j in range(i, min(len(lines), i + 20)):
226
+ if "return out_tokens, out_weights" in lines[j]:
227
+ has_return_after = True
228
+ break
229
+
230
+ # Keep this validation block only if it's followed by a return
231
+ if has_return_after:
232
+ lines_to_keep.append(line)
233
+ else:
234
+ # Skip this validation block (it's a duplicate)
235
+ # Find the end of this validation block
236
+ for j in range(i + 1, len(lines)):
237
+ if lines[j].strip() == "" or not lines[j].startswith(" "):
238
+ skip_until = j - 1
239
+ break
240
+ else:
241
+ lines_to_keep.append(line)
242
+
243
+ # Write back the cleaned content
244
+ cleaned_content = "\n".join(lines_to_keep)
245
+ inference_file.write_text(cleaned_content)
246
+ logger.info("Cleaned duplicate validation code from inference.py")
247
+ return True
248
+
249
+ except Exception as e:
250
+ logger.warning(f"Error cleaning duplicate validation code: {e}")
251
+ return False
252
+
253
+
254
  def main() -> None:
255
  """Main function for standalone execution."""
256
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
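A minimal usage sketch of the two helpers above, assuming it runs inside this patches module (the patch filename and location are illustrative, not part of this commit):

    from pathlib import Path

    target_dir = get_site_packages_path()          # site-packages of the active environment
    patch_file = Path("patches/model2vec.patch")   # illustrative path
    if not is_patch_already_applied(patch_file, target_dir):
        clean_duplicate_validation_code(target_dir)  # drop stale duplicates before re-applying
        # ...then apply the patch, e.g. via the `patch -p1 -d` invocation described earlier.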
src/distiller/sync.py DELETED
@@ -1,262 +0,0 @@
1
- """
2
- Sync utility for downloading files from Beam volume to local directory.
3
-
4
- This module provides functionality to download generated files from the Beam volume
5
- back to the local filesystem, including:
6
- - Final distilled model files (model.safetensors, tokenizer.json, etc.)
7
- - Analysis reports and charts (README.md, comparison charts, etc.)
8
- """
9
-
10
- import logging
11
- import shutil
12
- from pathlib import Path
13
-
14
- from .beam_utils import create_beam_utilities
15
-
16
- # Configure logging
17
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
18
- logger = logging.getLogger(__name__)
19
-
20
- # Beam volume configuration (must match distill.py)
21
- VOLUME_NAME = "gte_qwen2_m2v_code"
22
- VOLUME_PATH = "./gte_qwen2_m2v_code"
23
-
24
- # Model files to sync
25
- MODEL_FILES = [
26
- "model.safetensors",
27
- "tokenizer.json",
28
- "modules.json",
29
- "config.json",
30
- "pytorch_model.bin", # Backup format
31
- "vocab.txt", # If present
32
- ]
33
-
34
- # Analysis directories and files
35
- ANALYSIS_DIRS = [
36
- "analysis_results/reports",
37
- "analysis_results/charts",
38
- "evaluation_results",
39
- ]
40
-
41
- ANALYSIS_FILES = [
42
- "analysis_results/reports/analysis_report.md",
43
- "analysis_results/reports/README.md",
44
- "analysis_results/charts/*.png",
45
- "analysis_results/charts/*.html",
46
- "evaluation_results/*.json",
47
- "evaluation_results/comparisons/*.csv",
48
- ]
49
-
50
-
51
- def sync_model_files(output_dir: str) -> bool:
52
- """Download final model files from Beam volume."""
53
- logger.info("🔄 Syncing model files from Beam volume...")
54
-
55
- output_path = Path(output_dir)
56
- output_path.mkdir(parents=True, exist_ok=True)
57
-
58
- # First, let's debug what's actually in the volume
59
- volume_root = Path(VOLUME_PATH)
60
- logger.info(f"🔍 Debugging volume contents at: {volume_root}")
61
-
62
- if volume_root.exists():
63
- logger.info("📁 Volume root directory contents:")
64
- for item in volume_root.iterdir():
65
- if item.is_file():
66
- logger.info(f" 📄 {item.name} ({item.stat().st_size} bytes)")
67
- elif item.is_dir():
68
- logger.info(f" 📁 {item.name}/ (directory)")
69
- # List files in important subdirectories
70
- if item.name in ["models", "checkpoints", "gte_qwen2_m2v_code"]:
71
- try:
72
- logger.info(f" Contents of {item.name}/:")
73
- for subitem in item.iterdir():
74
- if subitem.is_file():
75
- logger.info(f" 📄 {subitem.name} ({subitem.stat().st_size} bytes)")
76
- else:
77
- logger.info(f" 📁 {subitem.name}/")
78
- # Check one level deeper for model files
79
- if subitem.is_dir():
80
- for subsubitem in subitem.iterdir():
81
- if subsubitem.is_file() and subsubitem.name in MODEL_FILES:
82
- logger.info(f" 🎯 FOUND MODEL FILE: {subsubitem}")
83
- except Exception as e:
84
- logger.warning(f" Error exploring {item.name}: {e}")
85
-
86
- # Also check for model files directly in root
87
- logger.info("🔍 Checking for model files directly in volume root:")
88
- for model_file in MODEL_FILES:
89
- root_file = volume_root / model_file
90
- if root_file.exists():
91
- logger.info(f" 🎯 FOUND: {model_file} in root ({root_file.stat().st_size} bytes)")
92
- else:
93
- logger.error(f"❌ Volume root does not exist: {volume_root}")
94
- return False
95
-
96
- # Since training completed successfully, look for model files in all possible locations
97
- model_locations = [
98
- Path(VOLUME_PATH), # Root of volume (where final model was saved)
99
- Path(VOLUME_PATH) / "models" / "refined_model", # Refined model directory
100
- ]
101
-
102
- synced_files = []
103
-
104
- for location in model_locations:
105
- logger.info(f"📂 Checking model location: {location}")
106
-
107
- if not location.exists():
108
- logger.info(f" ⚠️ Location does not exist: {location}")
109
- continue
110
-
111
- # Try to download each model file directly
112
- for model_file in MODEL_FILES:
113
- source_path = location / model_file
114
- dest_path = output_path / model_file
115
-
116
- if source_path.exists():
117
- try:
118
- shutil.copy2(source_path, dest_path)
119
- synced_files.append(model_file)
120
- logger.info(f"✅ Downloaded: {model_file}")
121
- except Exception as e:
122
- logger.warning(f"⚠️ Failed to copy {model_file}: {e}")
123
-
124
- if synced_files:
125
- logger.info(f"🎉 Successfully synced {len(synced_files)} model files:")
126
- for file in synced_files:
127
- logger.info(f" ✓ {file}")
128
- return True
129
- logger.error("❌ No model files found to sync")
130
- return False
131
-
132
-
133
- def sync_analysis_files(output_dir: str) -> bool:
134
- """Download analysis reports and charts from Beam volume."""
135
- logger.info("🔄 Syncing analysis files from Beam volume...")
136
-
137
- output_path = Path(output_dir)
138
- output_path.mkdir(parents=True, exist_ok=True)
139
-
140
- synced_files = []
141
-
142
- # Sync analysis reports (including README.md)
143
- analysis_reports_dir = Path(VOLUME_PATH) / "analysis_results" / "reports"
144
- if analysis_reports_dir.exists():
145
- for report_file in analysis_reports_dir.glob("*.md"):
146
- dest_path = output_path / report_file.name
147
- try:
148
- shutil.copy2(report_file, dest_path)
149
- synced_files.append(str(report_file.name))
150
- logger.info(f"✅ Downloaded report: {report_file.name}")
151
-
152
- # Special handling for README.md - copy to root
153
- if report_file.name in {"analysis_report.md", "README.md"}:
154
- root_readme = Path(output_dir) / "README.md"
155
- shutil.copy2(report_file, root_readme)
156
- logger.info("✅ Updated root README.md")
157
-
158
- except Exception as e:
159
- logger.warning(f"⚠️ Failed to copy {report_file.name}: {e}")
160
-
161
- # Sync charts
162
- charts_dir = Path(VOLUME_PATH) / "analysis_results" / "charts"
163
- local_charts_dir = output_path / "charts"
164
- if charts_dir.exists():
165
- local_charts_dir.mkdir(exist_ok=True)
166
-
167
- for chart_file in charts_dir.glob("*"):
168
- if chart_file.is_file():
169
- dest_path = local_charts_dir / chart_file.name
170
- try:
171
- shutil.copy2(chart_file, dest_path)
172
- synced_files.append(f"charts/{chart_file.name}")
173
- logger.info(f"✅ Downloaded chart: {chart_file.name}")
174
- except Exception as e:
175
- logger.warning(f"⚠️ Failed to copy chart {chart_file.name}: {e}")
176
-
177
- # Sync evaluation results
178
- eval_dir = Path(VOLUME_PATH) / "evaluation_results"
179
- local_eval_dir = output_path / "evaluation_results"
180
- if eval_dir.exists():
181
- local_eval_dir.mkdir(exist_ok=True)
182
-
183
- for eval_file in eval_dir.glob("*.json"):
184
- dest_path = local_eval_dir / eval_file.name
185
- try:
186
- shutil.copy2(eval_file, dest_path)
187
- synced_files.append(f"evaluation_results/{eval_file.name}")
188
- logger.info(f"✅ Downloaded evaluation: {eval_file.name}")
189
- except Exception as e:
190
- logger.warning(f"⚠️ Failed to copy evaluation {eval_file.name}: {e}")
191
-
192
- if synced_files:
193
- logger.info(f"🎉 Successfully synced {len(synced_files)} analysis files:")
194
- for file in synced_files[:10]: # Show first 10
195
- logger.info(f" ✓ {file}")
196
- if len(synced_files) > 10:
197
- logger.info(f" ... and {len(synced_files) - 10} more files")
198
- return True
199
- logger.error("❌ No analysis files found to sync")
200
- return False
201
-
202
-
203
- def sync_files(
204
- model_files: bool = False,
205
- analysis_files: bool = False,
206
- all_files: bool = False,
207
- output_dir: str = ".",
208
- ) -> None:
209
- """Main sync function to download files from Beam volume."""
210
- logger.info("🚀 Starting file sync from Beam volume")
211
- logger.info(f"📁 Local output directory: {output_dir}")
212
-
213
- # Initialize Beam utilities (read-only)
214
- try:
215
- volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(VOLUME_NAME, VOLUME_PATH)
216
- logger.info(f"✅ Connected to Beam volume: {VOLUME_NAME}")
217
- except Exception:
218
- logger.exception("❌ Failed to connect to Beam volume")
219
- logger.info("Make sure you have run the distillation/evaluation on Beam first")
220
- return
221
-
222
- # Check what files to sync
223
- sync_model = model_files or all_files
224
- sync_analysis = analysis_files or all_files
225
-
226
- if not (sync_model or sync_analysis):
227
- logger.error("❌ No file types specified. Use --model-files, --analysis-files, or --all")
228
- return
229
-
230
- success_count = 0
231
-
232
- # Sync model files
233
- if sync_model:
234
- logger.info("\n" + "=" * 60) # noqa: G003
235
- logger.info("MODEL FILES SYNC")
236
- logger.info("=" * 60)
237
- if sync_model_files(output_dir):
238
- success_count += 1
239
-
240
- # Sync analysis files
241
- if sync_analysis:
242
- logger.info("\n" + "=" * 60) # noqa: G003
243
- logger.info("ANALYSIS FILES SYNC")
244
- logger.info("=" * 60)
245
- if sync_analysis_files(output_dir):
246
- success_count += 1
247
-
248
- # Summary
249
- logger.info("\n" + "=" * 60) # noqa: G003
250
- logger.info("SYNC SUMMARY")
251
- logger.info("=" * 60)
252
-
253
- total_requested = sum([sync_model, sync_analysis])
254
-
255
- if success_count == total_requested:
256
- logger.info("🎉 All requested files synced successfully!")
257
- elif success_count > 0:
258
- logger.info(f"⚠️ Partial sync: {success_count}/{total_requested} file types synced")
259
- else:
260
- logger.error("❌ No files were synced")
261
-
262
- logger.info(f"📂 Files saved to: {Path(output_dir).absolute()}")

src/distiller/utils.py ADDED
@@ -0,0 +1,373 @@
1
+ """
2
+ Common utilities for the distiller package.
3
+
4
+ This module provides shared functionality used across multiple components
5
+ including model discovery, result management, and initialization helpers.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+ from types import TracebackType
12
+ from typing import Any
13
+
14
+ from .beam_utils import (
15
+ BeamCheckpointManager,
16
+ BeamEvaluationManager,
17
+ BeamModelManager,
18
+ BeamVolumeManager,
19
+ create_beam_utilities,
20
+ )
21
+ from .config import VolumeConfig, get_safe_model_name, get_volume_config, setup_logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ # =============================================================================
26
+ # BEAM UTILITIES MANAGEMENT
27
+ # =============================================================================
28
+
29
+
30
+ class BeamContext:
31
+ """Context manager for Beam utilities with consistent initialization."""
32
+
33
+ def __init__(self, workflow: str, volume_config: VolumeConfig | None = None) -> None:
34
+ """
35
+ Initialize Beam context.
36
+
37
+ Args:
38
+ workflow: Workflow type (distill, evaluate, benchmark, etc.)
39
+ volume_config: Optional custom volume config; defaults to get_volume_config() when omitted
40
+ """
41
+ self.workflow = workflow
42
+ self.volume_config = volume_config or get_volume_config()
43
+ self.volume_manager: BeamVolumeManager | None = None
44
+ self.checkpoint_manager: BeamCheckpointManager | None = None
45
+ self.model_manager: BeamModelManager | None = None
46
+ self.evaluation_manager: BeamEvaluationManager | None = None
47
+
48
+ def __enter__(self) -> tuple[BeamVolumeManager, BeamCheckpointManager, BeamModelManager, BeamEvaluationManager]:
49
+ """Enter context and initialize utilities."""
50
+ logger.info(f"🚀 Initializing Beam utilities for {self.workflow}")
51
+ logger.info(f"📁 Volume: {self.volume_config.name} at {self.volume_config.mount_path}")
52
+
53
+ self.volume_manager, self.checkpoint_manager, self.model_manager, self.evaluation_manager = (
54
+ create_beam_utilities(self.volume_config.name, self.volume_config.mount_path)
55
+ )
56
+
57
+ return self.volume_manager, self.checkpoint_manager, self.model_manager, self.evaluation_manager
58
+
59
+ def __exit__(
60
+ self,
61
+ exc_type: type[BaseException] | None,
62
+ exc_val: BaseException | None,
63
+ exc_tb: TracebackType | None,
64
+ ) -> None:
65
+ """Exit context with cleanup if needed."""
66
+ if exc_type:
67
+ logger.error(f"❌ Error in Beam context for {self.workflow}: {exc_val}")
68
+ else:
69
+ logger.info(f"✅ Beam context for {self.workflow} completed successfully")
70
+
71
+
72
+ def get_beam_utilities() -> tuple[BeamVolumeManager, BeamCheckpointManager, BeamModelManager, BeamEvaluationManager]:
73
+ """
74
+ Get Beam utilities using the default volume configuration.
75
+
76
+ Returns:
77
+ Tuple of (volume_manager, checkpoint_manager, model_manager, evaluation_manager)
78
+ """
79
+ volume_config = get_volume_config()
80
+ return create_beam_utilities(volume_config.name, volume_config.mount_path)
81
+
82
+
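A minimal sketch of BeamContext used as a context manager (the workflow name is illustrative):

    from distiller.utils import BeamContext

    with BeamContext(workflow="evaluate") as (volume_mgr, checkpoint_mgr, model_mgr, eval_mgr):
        # All four managers are bound to the same Beam volume for the duration of the block;
        # success or failure is logged on exit.
        ...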
83
+ # =============================================================================
84
+ # MODEL DISCOVERY
85
+ # =============================================================================
86
+
87
+
88
+ def discover_simplified_models(base_path: str | Path = ".") -> list[str]:
89
+ """
90
+ Discover simplified distillation models in the specified directory.
91
+
92
+ Args:
93
+ base_path: Base path to search for models
94
+
95
+ Returns:
96
+ List of model paths sorted alphabetically
97
+ """
98
+ base = Path(base_path)
99
+
100
+ # Look for models in common locations
101
+ search_patterns = [
102
+ "code_model2vec/final/**/",
103
+ "final/**/",
104
+ "code_model2vec_*/",
105
+ "*/config.json",
106
+ "*.safetensors",
107
+ ]
108
+
109
+ discovered_models = []
110
+
111
+ for pattern in search_patterns:
112
+ matches = list(base.glob(pattern))
113
+ for match in matches:
114
+ if match.is_dir():
115
+ # Check if it's a valid model directory
116
+ if (match / "config.json").exists() or (match / "model.safetensors").exists():
117
+ discovered_models.append(str(match))
118
+ elif match.name == "config.json":
119
+ # Add parent directory if config.json found
120
+ discovered_models.append(str(match.parent))
121
+
122
+ # Remove duplicates and sort
123
+ unique_models = sorted(set(discovered_models))
124
+
125
+ logger.info(f"🔍 Discovered {len(unique_models)} models in {base_path}")
126
+ for model in unique_models:
127
+ logger.info(f" 📁 {model}")
128
+
129
+ return unique_models
130
+
131
+
132
+ def validate_model_path(model_path: str | Path, volume_manager: BeamVolumeManager | None = None) -> str | None:
133
+ """
134
+ Validate and resolve model path, checking local filesystem and Beam volumes.
135
+
136
+ Args:
137
+ model_path: Path to model (can be local path or HuggingFace model name)
138
+ volume_manager: Optional volume manager for Beam volume checks
139
+
140
+ Returns:
141
+ Resolved model path or None if not found
142
+ """
143
+ path = Path(model_path)
144
+
145
+ # Check if it's a HuggingFace model name
146
+ if "/" in str(model_path) and not path.exists() and not str(model_path).startswith("/"):
147
+ logger.info(f"📥 Treating as HuggingFace model: {model_path}")
148
+ return str(model_path)
149
+
150
+ # Check local filesystem
151
+ if path.exists():
152
+ logger.info(f"✅ Found local model: {model_path}")
153
+ return str(path)
154
+
155
+ # Check Beam volume if available
156
+ if volume_manager:
157
+ volume_path = Path(volume_manager.mount_path) / path.name
158
+ if volume_path.exists():
159
+ logger.info(f"✅ Found model in Beam volume: {volume_path}")
160
+ return str(volume_path)
161
+
162
+ # Check volume root
163
+ root_path = Path(volume_manager.mount_path)
164
+ if (root_path / "config.json").exists():
165
+ logger.info(f"✅ Found model in Beam volume root: {root_path}")
166
+ return str(root_path)
167
+
168
+ logger.warning(f"⚠️ Model not found: {model_path}")
169
+ return None
170
+
171
+
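A minimal sketch combining discovery and validation (the search directory is illustrative):

    from distiller.utils import discover_simplified_models, validate_model_path

    for candidate in discover_simplified_models("code_model2vec"):
        resolved = validate_model_path(candidate)
        if resolved is not None:
            print(f"Ready to evaluate: {resolved}")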
172
+ # =============================================================================
173
+ # RESULT MANAGEMENT
174
+ # =============================================================================
175
+
176
+
177
+ def save_results_with_backup(
178
+ results: dict[str, Any],
179
+ primary_path: str | Path,
180
+ model_name: str,
181
+ result_type: str = "evaluation",
182
+ volume_manager: BeamVolumeManager | None = None,
183
+ evaluation_manager: BeamEvaluationManager | None = None,
184
+ ) -> bool:
185
+ """
186
+ Save results with multiple backup strategies.
187
+
188
+ Args:
189
+ results: Results dictionary to save
190
+ primary_path: Primary save location
191
+ model_name: Model name for filename generation
192
+ result_type: Type of results (evaluation, benchmark, etc.)
193
+ volume_manager: Optional volume manager for Beam storage
194
+ evaluation_manager: Optional evaluation manager for specialized storage
195
+
196
+ Returns:
197
+ True if saved successfully to at least one location
198
+ """
199
+ success_count = 0
200
+ safe_name = get_safe_model_name(model_name)
201
+
202
+ # Save to primary location
203
+ try:
204
+ primary = Path(primary_path)
205
+ primary.mkdir(parents=True, exist_ok=True)
206
+ filename = f"{result_type}_{safe_name}.json"
207
+ filepath = primary / filename
208
+
209
+ with filepath.open("w") as f:
210
+ json.dump(results, f, indent=2, default=str)
211
+
212
+ logger.info(f"💾 Saved {result_type} results to: {filepath}")
213
+ success_count += 1
214
+ except Exception as e:
215
+ logger.warning(f"⚠️ Failed to save to primary location: {e}")
216
+
217
+ # Save to Beam volume if available
218
+ if volume_manager:
219
+ try:
220
+ volume_path = Path(volume_manager.mount_path) / f"{result_type}_results"
221
+ volume_path.mkdir(parents=True, exist_ok=True)
222
+ filename = f"{result_type}_{safe_name}.json"
223
+ filepath = volume_path / filename
224
+
225
+ with filepath.open("w") as f:
226
+ json.dump(results, f, indent=2, default=str)
227
+
228
+ logger.info(f"💾 Saved {result_type} results to Beam volume: {filepath}")
229
+ success_count += 1
230
+ except Exception as e:
231
+ logger.warning(f"⚠️ Failed to save to Beam volume: {e}")
232
+
233
+ # Save via evaluation manager if available and appropriate
234
+ if evaluation_manager and result_type == "evaluation":
235
+ try:
236
+ success = evaluation_manager.save_evaluation_results(model_name, results)
237
+ if success:
238
+ logger.info(f"💾 Saved via evaluation manager for {model_name}")
239
+ success_count += 1
240
+ except Exception as e:
241
+ logger.warning(f"⚠️ Failed to save via evaluation manager: {e}")
242
+
243
+ return success_count > 0
244
+
245
+
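A minimal sketch of saving results to the primary directory plus any configured backups (the model name and payload are placeholders):

    from distiller.utils import save_results_with_backup

    dummy_results = {"example_metric": 0.0}  # placeholder payload, not real numbers
    save_results_with_backup(
        dummy_results,
        primary_path="evaluation_results",
        model_name="code_model2vec_example",
        result_type="evaluation",
    )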
246
+ def load_existing_results(
247
+ model_name: str,
248
+ result_type: str = "evaluation",
249
+ search_paths: list[str | Path] | None = None,
250
+ volume_manager: BeamVolumeManager | None = None,
251
+ evaluation_manager: BeamEvaluationManager | None = None,
252
+ ) -> dict[str, Any] | None:
253
+ """
254
+ Load existing results from multiple possible locations.
255
+
256
+ Args:
257
+ model_name: Model name to search for
258
+ result_type: Type of results to load
259
+ search_paths: Additional paths to search
260
+ volume_manager: Optional volume manager
261
+ evaluation_manager: Optional evaluation manager
262
+
263
+ Returns:
264
+ Results dictionary if found, None otherwise
265
+ """
266
+ safe_name = get_safe_model_name(model_name)
267
+ filename = f"{result_type}_{safe_name}.json"
268
+
269
+ # Search in provided paths
270
+ if search_paths:
271
+ for search_path in search_paths:
272
+ filepath = Path(search_path) / filename
273
+ if filepath.exists():
274
+ try:
275
+ with filepath.open("r") as f:
276
+ results = json.load(f)
277
+ logger.info(f"📂 Loaded existing {result_type} results from: {filepath}")
278
+ return results
279
+ except Exception as e:
280
+ logger.warning(f"⚠️ Failed to load from {filepath}: {e}")
281
+
282
+ # Search in Beam volume
283
+ if volume_manager:
284
+ volume_path = Path(volume_manager.mount_path) / f"{result_type}_results" / filename
285
+ if volume_path.exists():
286
+ try:
287
+ with volume_path.open("r") as f:
288
+ results = json.load(f)
289
+ logger.info(f"📂 Loaded existing {result_type} results from Beam volume: {volume_path}")
290
+ return results
291
+ except Exception as e:
292
+ logger.warning(f"⚠️ Failed to load from Beam volume: {e}")
293
+
294
+ # Try evaluation manager
295
+ if evaluation_manager and result_type == "evaluation":
296
+ try:
297
+ results = evaluation_manager.load_evaluation_results(model_name)
298
+ if results:
299
+ logger.info(f"📂 Loaded existing {result_type} results via evaluation manager")
300
+ return results
301
+ except Exception as e:
302
+ logger.warning(f"⚠️ Failed to load via evaluation manager: {e}")
303
+
304
+ logger.info(f"ℹ️ No existing {result_type} results found for {model_name}")
305
+ return None
306
+
307
+
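A minimal sketch of reusing cached results before re-running an evaluation (the model name and search path are placeholders):

    from distiller.utils import load_existing_results

    cached = load_existing_results(
        "code_model2vec_example",
        result_type="evaluation",
        search_paths=["evaluation_results"],
    )
    if cached is None:
        ...  # run the evaluation here, then persist it with save_results_with_backup(...)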
308
+ # =============================================================================
309
+ # WORKFLOW HELPERS
310
+ # =============================================================================
311
+
312
+
313
+ def print_workflow_summary(
314
+ workflow_name: str,
315
+ total_items: int,
316
+ processed_items: int,
317
+ skipped_items: int,
318
+ execution_time: float | None = None,
319
+ ) -> None:
320
+ """Print a standardized workflow summary."""
321
+ logger.info(f"\n✅ {workflow_name} complete!")
322
+ logger.info(f"📊 Total items: {total_items}")
323
+ logger.info(f"✨ Newly processed: {processed_items}")
324
+ logger.info(f"⏭️ Skipped (already done): {skipped_items}")
325
+
326
+ if execution_time:
327
+ logger.info(f"⏱️ Execution time: {execution_time:.2f} seconds")
328
+
329
+
330
+ def check_existing_results(
331
+ items: list[str],
332
+ result_type: str,
333
+ search_paths: list[str | Path] | None = None,
334
+ volume_manager: BeamVolumeManager | None = None,
335
+ ) -> tuple[list[str], list[str]]:
336
+ """
337
+ Check which items already have results and which need processing.
338
+
339
+ Args:
340
+ items: List of items (model names, etc.) to check
341
+ result_type: Type of results to check for
342
+ search_paths: Paths to search for existing results
343
+ volume_manager: Optional volume manager
344
+
345
+ Returns:
346
+ Tuple of (items_to_process, items_to_skip)
347
+ """
348
+ to_process = []
349
+ to_skip = []
350
+
351
+ for item in items:
352
+ existing = load_existing_results(item, result_type, search_paths, volume_manager)
353
+ if existing:
354
+ to_skip.append(item)
355
+ else:
356
+ to_process.append(item)
357
+
358
+ return to_process, to_skip
359
+
360
+
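A minimal sketch of the skip/process split followed by the standard summary (the model names are placeholders):

    from distiller.utils import check_existing_results, print_workflow_summary

    models = ["model_a", "model_b"]
    to_process, to_skip = check_existing_results(models, result_type="evaluation")
    print_workflow_summary(
        "Evaluation",
        total_items=len(models),
        processed_items=len(to_process),
        skipped_items=len(to_skip),
    )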
361
+ # =============================================================================
362
+ # INITIALIZATION
363
+ # =============================================================================
364
+
365
+
366
+ def initialize_distiller_logging(level: int = logging.INFO) -> None:
367
+ """Initialize logging for distiller package."""
368
+ setup_logging(level)
369
+ logger.info("🚀 Distiller package initialized")
370
+
371
+
372
+ # Ensure logging is set up when module is imported
373
+ initialize_distiller_logging()
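Because logging is configured at import time (initialize_distiller_logging above), downstream modules only need a plain import to pick it up; a minimal sketch, with an illustrative selection of names:

    from distiller.utils import BeamContext, discover_simplified_models, save_results_with_backup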