Sarthak committed
Commit · 454e47c · Parent(s): ea0b2a0

feat: overhaul distiller package with unified CLI, enhanced evaluation, and modular structure
- src/distiller/__init__.py            +64    -4
- src/distiller/__main__.py            +41    -167
- src/distiller/analyze.py             +361   -138
- src/distiller/beam_utils.py          +443   -21
- src/distiller/benchmark.py           +0     -1181
- src/distiller/config.py              +339   -0
- src/distiller/distill.py             +1248  -988
- src/distiller/distill_simplified.py  +0     -413
- src/distiller/evaluate.py            +897   -593
- src/distiller/patch_utils.py         +119   -0
- src/distiller/sync.py                +0     -262
- src/distiller/utils.py               +373   -0
src/distiller/__init__.py
CHANGED
@@ -1,7 +1,67 @@
-"""
-…
+"""
+Distiller package for code-specialized embedding model distillation and evaluation.
+
+This package provides a complete pipeline for:
+1. Distilling code-specialized embedding models using Model2Vec
+2. Comprehensive evaluation including CodeSearchNet and performance benchmarks
+3. Analysis and reporting of model performance
+
+Main modules:
+- distill: Model2Vec distillation with optional advanced training
+- evaluate: Comprehensive evaluation (CodeSearchNet + performance benchmarks)
+- analyze: Analysis and reporting tools
+- config: Centralized configuration management
+- beam_utils: Beam cloud utilities for distributed processing
+
+Usage:
+    from distiller import distill, evaluate, analyze
+"""
+
+from . import analyze, config, distill, evaluate
+from .analyze import CodeSearchNetAnalyzer
+from .config import (
+    BEAM_ENV_SETTINGS,
+    DEFAULT_EVALUATION_MODELS,
+    GPU_NAME,
+    IMAGE,
+    codesearchnet_config,
+    directories,
+    distillation_config,
+    get_volume_config,
+    languages_config,
+)
+from .distill import (
+    run_beam_distillation,
+    run_local_distillation,
+)
+from .evaluate import (
+    CodeSearchNetEvaluator,
+    ComprehensiveModelEvaluator,
+    run_evaluation,
+)
+
+__all__ = [
+    # Configuration
+    "BEAM_ENV_SETTINGS",
+    "DEFAULT_EVALUATION_MODELS",
+    "GPU_NAME",
+    "IMAGE",
+    # Main classes
+    "CodeSearchNetAnalyzer",
+    "CodeSearchNetEvaluator",
+    "ComprehensiveModelEvaluator",
+    # Modules
+    "analyze",
+    "codesearchnet_config",
+    "config",
+    "directories",
+    "distill",
+    "distillation_config",
+    "evaluate",
+    "get_volume_config",
+    "languages_config",
+    # Main functions
+    "run_beam_distillation",
+    "run_evaluation",
+    "run_local_distillation",
+]
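The re-exports above make the three pipeline stages callable without going through the CLI. A minimal sketch of driving them from Python; the argument lists mirror how __main__.py (shown next) calls each module's main, and should be treated as assumptions beyond that:

from distiller import analyze, distill, evaluate

# Distill locally with config defaults (no Beam, no extra training).
distill.main(False, False, None, None)    # use_beam, train, teacher_models, pca_dims

# Evaluate everything, including the performance benchmark pass.
evaluate.main(False, False, False, 1000)  # use_beam, skip_third_party, skip_benchmark, max_queries

# Build the markdown report from the evaluation results.
analyze.main("code_model2vec/evaluation_results", "gte_qwen2_m2v_code (Ours)", "REPORT.md", None)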
src/distiller/__main__.py
CHANGED
@@ -1,183 +1,57 @@
 """Main entry point for the distiller package."""
 
-import …
-import sys
 
-…
-"…
-…
 
-# Distillation command
-distill_parser = subparsers.add_parser("distill", help="Run code-specialized model distillation")
-distill_parser.add_argument("--model", default="Alibaba-NLP/gte-Qwen2-7B-instruct", help="Model to distill")
-distill_parser.add_argument("--output-dir", default="gte_qwen2_m2v_code", help="Output directory")
-distill_parser.add_argument("--pca-dims", type=int, default=512, help="PCA dimensions")
-distill_parser.add_argument("--max-samples", type=int, default=50000, help="Max CodeSearchNet samples")
-distill_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud GPU distillation")
 
-…
-)
-…
 
-# …
-…
-evaluate_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud evaluation")
 
-# CodeSearchNet evaluation command (simplified models only)
-evaluate_simple_parser = subparsers.add_parser(
-    "evaluate-simple", help="Run CodeSearchNet evaluation on simplified models only"
-)
-evaluate_simple_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud evaluation")
 
-…
 
-# …
-…
-sync_parser.add_argument("--model-files", action="store_true", help="Download final model files")
-sync_parser.add_argument(
-    "--analysis-files",
-    action="store_true",
-    help="Download analysis reports and charts",
-)
-sync_parser.add_argument("--all", action="store_true", help="Download all generated files")
-sync_parser.add_argument("--output-dir", default=".", help="Local output directory")
 
-# Benchmark command
-benchmark_parser = subparsers.add_parser("benchmark", help="Run performance benchmarking on all default models")
-benchmark_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud benchmarking")
 
-…
-)
-…
 
-if args.command == "distill":
-    from .distill_simplified import run_local_distillation, beam_distill_all_teachers
-
-    if args.use_beam:
-        # Run on Beam
-        print("Running comprehensive teacher model distillation on Beam...")
-        results = beam_distill_all_teachers()
-    else:
-        # Run locally
-        print("Running comprehensive teacher model distillation locally...")
-        results = run_local_distillation()
-
-    print(f"✅ Distillation complete! Created {results['total_successful']} models")
-    print("📁 Models location: ./code_model2vec/final/")
-    print("\n✅ Created models:")
-    for model_name in results["successful_models"]:
-        model_info = results["all_results"][model_name]
-        print(f"  • {model_name} (from {model_info['teacher_model']})")
-
-elif args.command == "distill-simple":
-    from .distill_simplified import run_local_distillation
-
-    # Run simplified distillation for all teacher models locally
-    print("Running comprehensive teacher model distillation locally...")
-    results = run_local_distillation()
-    print(f"✅ Distillation complete! Created {results['total_successful']} models")
-    print("📁 Models location: ./code_model2vec/final/")
-    print("\n✅ Created models:")
-    for model_name in results["successful_models"]:
-        model_info = results["all_results"][model_name]
-        print(f"  • {model_name} (from {model_info['teacher_model']})")
-
-elif args.command == "evaluate":
-    from .evaluate import main as evaluate_main, run_local_evaluation
-
-    if args.use_beam:
-        # Run on Beam with all default models
-        print("Running comprehensive evaluation on Beam...")
-        evaluate_main()
-    else:
-        # Run locally with all default models
-        print("Running comprehensive evaluation locally...")
-        run_local_evaluation()
-
-elif args.command == "evaluate-simple":
-    from .evaluate import evaluate_simplified_only, run_local_evaluation_simplified
-
-    if args.use_beam:
-        # Run on Beam with simplified models only
-        print("Running simplified model evaluation on Beam...")
-        evaluate_simplified_only()
-    else:
-        # Run locally with simplified models only
-        print("Running simplified model evaluation locally...")
-        run_local_evaluation_simplified()
-
-elif args.command == "analyze":
-    from .analyze import main as analyze_main
-
-    # Run locally - Override sys.argv to pass arguments to the analyze script
-    sys.argv = ["analyze.py"]
-    if args.results_dir != "code_evaluation_results":
-        sys.argv.extend(["--results-dir", args.results_dir])
-    if args.results_file:
-        sys.argv.extend(["--results-file", args.results_file])
-    if args.model_name != "gte_qwen2_m2v_code":
-        sys.argv.extend(["--model-name", args.model_name])
-    if args.output != "README.md":
-        sys.argv.extend(["--output", args.output])
-    if args.export_csv:
-        sys.argv.extend(["--export-csv", args.export_csv])
-    analyze_main()
-
-elif args.command == "sync":
-    from .sync import sync_files
-
-    # Run locally
-    sync_files(
-        model_files=args.model_files,
-        analysis_files=args.analysis_files,
-        all_files=args.all,
-        output_dir=args.output_dir,
-    )
-
-elif args.command == "benchmark":
-    from .benchmark import main as benchmark_main, run_local_benchmark
-
-    if args.use_beam:
-        # Run on Beam with all default models
-        print("Running comprehensive benchmarking on Beam...")
-        benchmark_main()
-    else:
-        # Run locally with all default models
-        print("Running comprehensive benchmarking locally...")
-        run_local_benchmark()
-
-elif args.command == "benchmark-simple":
-    from .benchmark import benchmark_simplified_only, run_local_benchmark_simplified
-
-    if args.use_beam:
-        # Run on Beam with simplified models only
-        print("Running simplified model benchmarking on Beam...")
-        benchmark_simplified_only()
-    else:
-        # Run locally with simplified models only
-        print("Running simplified model benchmarking locally...")
-        run_local_benchmark_simplified()
-
-else:
-    parser.print_help()
 
 
 if __name__ == "__main__":
-    …
 """Main entry point for the distiller package."""
 
+from typing import Annotated
 
+import typer
 
+app = typer.Typer(
+    help="Model2Vec Code-Specialized Distillation Pipeline",
+    no_args_is_help=True,
+    context_settings={"help_option_names": ["-h", "--help"]},
+)
 
 
+@app.command()
+def distill(
+    use_beam: Annotated[bool, typer.Option(help="Use Beam for distillation")] = False,
+    train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False,
+    teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None,
+    pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None,
+) -> None:
+    """Run unified Model2Vec distillation with optional training."""
+    from .distill import main as distill_main
 
+    # Call the distill main function with arguments
+    distill_main(use_beam, train, teacher_models, pca_dims)
 
 
+@app.command()
+def evaluate(
+    use_beam: Annotated[bool, typer.Option(help="Use Beam for evaluation")] = False,
+    skip_third_party: Annotated[bool, typer.Option(help="Skip third-party models")] = False,
+    skip_benchmark: Annotated[bool, typer.Option(help="Skip performance benchmarking")] = False,
+    max_queries: Annotated[int, typer.Option(help="Maximum queries per language")] = 1000,
+) -> None:
+    """Run CodeSearchNet evaluation on models."""
+    from .evaluate import main as evaluate_main
 
+    # Call the evaluate main function with arguments
+    evaluate_main(use_beam, skip_third_party, skip_benchmark, max_queries)
 
 
+@app.command()
+def analyze(
+    results_dir: Annotated[str | None, typer.Option(help="Results directory")] = None,
+    model_name: Annotated[str, typer.Option(help="Model name for analysis")] = "gte_qwen2_m2v_code (Ours)",
+    output: Annotated[str, typer.Option(help="Output report file")] = "REPORT.md",
+    export_csv: Annotated[str | None, typer.Option(help="Export results to CSV")] = None,
+) -> None:
+    """Generate comprehensive analysis reports."""
+    from .analyze import main as analyze_main
 
+    # Call the analyze main function with arguments
+    analyze_main(results_dir or "code_model2vec/evaluation_results", model_name, output, export_csv)
 
 
 if __name__ == "__main__":
+    app()
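The Typer rewrite collapses the old hand-rolled argparse subcommand tree into three declarative commands. A hedged sketch of exercising the app in-process with Typer's test runner; flag spellings follow Typer's kebab-case conversion of the parameter names above:

from typer.testing import CliRunner

from distiller.__main__ import app

runner = CliRunner()

# Equivalent to: python -m distiller distill --use-beam --pca-dims 256
result = runner.invoke(app, ["distill", "--use-beam", "--pca-dims", "256"])
print(result.exit_code, result.output)

# Equivalent to: python -m distiller evaluate --skip-benchmark --max-queries 500
result = runner.invoke(app, ["evaluate", "--skip-benchmark", "--max-queries", "500"])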
src/distiller/analyze.py
CHANGED
@@ -23,7 +23,6 @@ Usage:
     distiller analyze --results-dir evaluation_results
 """
 
-import argparse
 import json
 import logging
 import time

@@ -35,6 +34,8 @@ import numpy as np
 import pandas as pd
 import seaborn as sns
 

@@ -65,48 +66,140 @@ OUTPUT_DIR = Path("analysis_results")
 IMAGES_DIR = Path("analysis_charts")
 REPORT_FILE = Path("REPORT.md")  # Changed from README.md
 
-# Local directories for results - …
-DEFAULT_EVALUATION_DIR = …
-DEFAULT_BENCHMARK_DIR = …
 
 # CodeSearchNet Languages
 CODE_LANGUAGES = ["python", "javascript", "java", "php", "ruby", "go"]
 
 # Model name mapping from the default models in evaluate.py and benchmark.py
 MODEL_NAME_MAPPING = {
-    # File names to display names
-    …
 }
 
-# Reverse mapping for lookups
-DISPLAY_NAME_TO_FILE = {v: k for k, v in MODEL_NAME_MAPPING.items()}
 
 # Peer models for comparison (code-specialized models)
 PEER_MODELS = {
-    "sentence-transformers/all-MiniLM-L6-v2": {
-    …
 }
 
 # Model specifications for efficiency analysis
 MODEL_SPECS = {
-    "sentence-transformers/all-MiniLM-L6-v2": {
-    …
 }
 
 # Distilled model specifications

@@ -134,13 +227,12 @@ def setup_directories(base_path: Path | None = None) -> tuple[Path, Path, Path]:
         images_dir = base_path / "analysis_results" / "charts"
         reports_dir = base_path / "analysis_results" / "reports"
     else:
-        output_dir = …
-        images_dir = IMAGES_DIR
-        reports_dir = …
 
-    …
     images_dir.mkdir(parents=True, exist_ok=True)
-    reports_dir.mkdir(parents=True, exist_ok=True)
 
     return output_dir, images_dir, reports_dir

@@ -152,17 +244,94 @@ def extract_model_name_from_filename(filename: str) -> str:
 
     # Check if it's in our mapping
     if name in MODEL_NAME_MAPPING:
-        return MODEL_NAME_MAPPING[name]
 
     # Try to find partial matches
-    for file_key, …
         if file_key in name or name in file_key:
-            return …
 
     # If no mapping found, return the cleaned name
     return name
 
 
 class CodeSearchNetAnalyzer:
     """Analyzer for CodeSearchNet evaluation results and performance benchmarks."""

@@ -182,33 +351,43 @@ class CodeSearchNetAnalyzer:
         self.benchmark_df: pd.DataFrame | None = None
 
     def load_benchmark_results(self) -> None:
-        """Load benchmark results from …"""
-        logger.info("📊 Loading benchmark results...")
 
-        if not self.…
-            logger.warning(f"…
             return
 
-        logger.info(f"🔍 Searching for …
-        benchmark_files = list(self.benchmark_dir.glob("benchmark_*.json"))
-        logger.info(f"📁 Found {len(benchmark_files)} benchmark files")
 
-        for …
             try:
-                logger.info(f"📖 Loading: {…
-                with …
                     data = json.load(f)
                 if data is not None:
-                    …
             except (json.JSONDecodeError, KeyError) as e:
-                logger.warning(f"❌ Failed to load {…
 
         logger.info(f"📊 Total benchmark results loaded: {len(self.benchmark_results)}")
         if self.benchmark_results:

@@ -217,6 +396,34 @@ class CodeSearchNetAnalyzer:
 
         self._create_benchmark_dataframe()
 

@@ -263,10 +470,10 @@ class CodeSearchNetAnalyzer:
         )
 
         # CPU vs GPU comparison
-        for device in …
-            if …
                 device_key = f"{device.upper()}_TextsPerSec"
-                row[device_key] = …
 
         benchmark_data.append(row)

@@ -281,23 +488,31 @@ class CodeSearchNetAnalyzer:
         return
 
         logger.info(f"🔍 Searching for evaluation files in: {self.results_dir}")
-        json_files = list(self.results_dir.glob("codesearchnet_eval_*.json"))
-        logger.info(f"📁 Found {len(json_files)} evaluation files")
 
-        for …
             try:
                 logger.info(f"📖 Loading: {json_file.name}")
                 with json_file.open() as f:
                     data = json.load(f)
                 if data is not None:
-                    …
-                    self.results.append(…
-                    logger.info(f"✅ Successfully loaded: {…
             except (json.JSONDecodeError, KeyError) as e:
                 logger.warning(f"❌ Failed to load {json_file}: {e}")

@@ -311,6 +526,32 @@ class CodeSearchNetAnalyzer:
         # Also load benchmark results
         self.load_benchmark_results()
 

@@ -453,7 +694,7 @@ class CodeSearchNetAnalyzer:
         if cpu_vs_gpu:
             print("🖥️ CPU vs GPU:")
             for device, metrics in cpu_vs_gpu.items():
-                if "error" not in metrics:
                     print(f"   {device.upper()}: {metrics.get('texts_per_second', 0):.1f} texts/sec")
 
         # Memory efficiency

@@ -978,7 +1219,7 @@ class CodeSearchNetAnalyzer:
         # Safe conversion to float for pandas values
         score_value = pd.to_numeric(current_model_score, errors="coerce")
         scores.append(float(score_value) if not pd.isna(score_value) else 0.0)
-        params.append(float(MODEL_SPECS[model_key].get("parameters", 100)))
         is_user_model.append(False)
 
     if not models:

@@ -1098,7 +1339,7 @@ class CodeSearchNetAnalyzer:
 
         # Create visualizations
         logger.info("Generating visualizations...")
-        setup_directories()
 
         self.create_performance_radar_chart(main_model_name, language_scores)
         comparison_chart = self.plot_model_comparison()

@@ -1163,21 +1404,14 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
         overall_metrics = result.get("overall", {})
 
         # Extract teacher model name from model name
-        …
-        if "all_MiniLM_L6_v2" in model_display:
-            teacher = "all-MiniLM-L6-v2"
-        elif "codebert_base" in model_display:
-            teacher = "codebert-base"
-        elif "graphcodebert_base" in model_display:
-            teacher = "graphcodebert-base"
-        elif "gte_Qwen2_7B_instruct" in model_display:
-            teacher = "gte-Qwen2-7B-instruct"
-        elif "all_mpnet_base_v2" in model_display:
-            teacher = "all-mpnet-base-v2"
 
         status = "🥇 Best" if rank == 1 else "🥈 2nd" if rank == 2 else "🥉 3rd" if rank == 3 else f"#{rank}"
 
-        …
 
         report += """

@@ -1215,19 +1449,12 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
         report += "### Individual Model Performance by Language\n\n"
         for chart_model_name, chart_path in individual_radar_charts.items():
             # Extract teacher name for cleaner display
-            …
-                teacher = "graphcodebert-base"
-            elif "gte_Qwen2_7B_instruct" in chart_model_name:
-                teacher = "gte-Qwen2-7B-instruct"
-            elif "all_mpnet_base_v2" in chart_model_name:
-                teacher = "all-mpnet-base-v2"
-
-            report += f"#### {chart_model_name} (Teacher: {teacher})\n\n"
             report += f"\n\n"
 
         report += f"""

@@ -1324,7 +1551,7 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
 
     if language_scores:
         report += "| Language | Best Model Performance | Average Performance | Language Difficulty |\n"
-        report += "…
 
         for lang in sorted(language_scores.keys()):
             # Find best performance for this language across all models

@@ -1358,16 +1585,8 @@ Based on the evaluation results across all simplified distillation models:
         model_name = result["model_name"]
         score = result.get("overall", {}).get("ndcg@10", 0)
 
-        …
-        elif "codebert_base" in model_name:
-            teacher_performance["codebert-base"] = score
-        elif "graphcodebert_base" in model_name:
-            teacher_performance["graphcodebert-base"] = score
-        elif "gte_Qwen2_7B_instruct" in model_name:
-            teacher_performance["gte-Qwen2-7B-instruct"] = score
-        elif "all_mpnet_base_v2" in model_name:
-            teacher_performance["all-mpnet-base-v2"] = score
 
     if teacher_performance:
         best_teacher = max(teacher_performance.items(), key=lambda x: x[1])

@@ -1397,11 +1616,20 @@ Based on the evaluation results across all simplified distillation models:
     - **Evaluation**: Retrieval of correct code for each documentation query
 
 ### Teacher Models Tested
-- sentence-transformers/all-MiniLM-L6-v2 (proven baseline)
-- …
-- …
-- …
-- …
 
 ### Distillation Method
 - **Technique**: Model2Vec static embedding generation

@@ -1424,31 +1652,27 @@ Based on the evaluation results across all simplified distillation models:
     logger.info(f"Results exported to {output_file}")
 
 
-def main(…
     """Main analysis function."""
-    …
-    parser.add_argument("--results-dir", default=DEFAULT_EVALUATION_DIR, help="Evaluation results directory")
-    parser.add_argument("--benchmark-dir", default=DEFAULT_BENCHMARK_DIR, help="Benchmark results directory")
-    parser.add_argument("--model-name", default="gte_qwen2_m2v_code (Ours)", help="Model name for report")
-    parser.add_argument("--output", default="REPORT.md", help="Output report file")
-    parser.add_argument("--export-csv", help="Export comparison results to CSV")
-
-    args = parser.parse_args()
-
-    logger.info("Starting CodeSearchNet Analysis with Benchmark Integration")
     logger.info("=" * 60)
 
     # Setup output directories
     output_dir, images_dir, reports_dir = setup_directories()
 
-    # Initialize analyzer with …
     analyzer = CodeSearchNetAnalyzer(
-        results_dir=…
-        benchmark_dir=…
         images_dir=images_dir,
     )
 
-    # Load results (this will also load benchmark …
     analyzer.load_results()
 
     if not analyzer.results:

@@ -1463,33 +1687,32 @@ def main() -> None:
     if analyzer.benchmark_results:
         analyzer.analyze_benchmark_performance()
     else:
-        logger.warning("No benchmark results found. …
 
     # Generate comprehensive report with benchmark integration
-    logger.info("Generating comprehensive report with benchmark data...")
-    report = analyzer.generate_comprehensive_report(…
 
     # Save report
-    report_path = Path(…
     with report_path.open("w") as f:
         f.write(report)
 
     # Export CSV if requested
-    if …
-        analyzer.export_results(…
 
     # Export benchmark CSV if available
     if analyzer.benchmark_df is not None and not analyzer.benchmark_df.empty:
-        benchmark_csv = report_path.parent / f"{…
         analyzer.benchmark_df.to_csv(benchmark_csv, index=False)
         logger.info(f"📊 Benchmark comparison saved to: {benchmark_csv}")
 
-    logger.info("✅ CodeSearchNet analysis with benchmarks complete!")
     logger.info(f"📊 Report saved to: {report_path}")
     logger.info(f"🖼️ Charts saved to: {images_dir}")
 
 
 if __name__ == "__main__":
-    import argparse
-    …
     main()
     distiller analyze --results-dir evaluation_results
 """
 
 import json
 import logging
 import time
 
 import pandas as pd
 import seaborn as sns
 
+from .config import directories
+
 # Optional Plotly import with fallback
 PLOTLY_AVAILABLE = True
 try:
 
 IMAGES_DIR = Path("analysis_charts")
 REPORT_FILE = Path("REPORT.md")  # Changed from README.md
 
+# Local directories for results - using standardized directories from config
+DEFAULT_EVALUATION_DIR = directories.evaluation_results
+DEFAULT_BENCHMARK_DIR = directories.benchmark_results
 
 # CodeSearchNet Languages
 CODE_LANGUAGES = ["python", "javascript", "java", "php", "ruby", "go"]
 
 # Model name mapping from the default models in evaluate.py and benchmark.py
 MODEL_NAME_MAPPING = {
+    # File names to display names and HuggingFace links
+    "all-MiniLM-L6-v2": {
+        "name": "sentence-transformers/all-MiniLM-L6-v2",
+        "link": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
+    },
+    "all-mpnet-base-v2": {
+        "name": "sentence-transformers/all-mpnet-base-v2",
+        "link": "https://huggingface.co/sentence-transformers/all-mpnet-base-v2",
+    },
+    "paraphrase-MiniLM-L6-v2": {
+        "name": "sentence-transformers/paraphrase-MiniLM-L6-v2",
+        "link": "https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2",
+    },
+    "codebert-base": {"name": "microsoft/codebert-base", "link": "https://huggingface.co/microsoft/codebert-base"},
+    "graphcodebert-base": {
+        "name": "microsoft/graphcodebert-base",
+        "link": "https://huggingface.co/microsoft/graphcodebert-base",
+    },
+    "CodeBERTa-small-v1": {
+        "name": "huggingface/CodeBERTa-small-v1",
+        "link": "https://huggingface.co/huggingface/CodeBERTa-small-v1",
+    },
+    "all-MiniLM-L12-v2": {
+        "name": "sentence-transformers/all-MiniLM-L12-v2",
+        "link": "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2",
+    },
+    "potion-base-8M": {"name": "minishlab/potion-base-8M", "link": "https://huggingface.co/minishlab/potion-base-8M"},
+    "potion-retrieval-32M": {
+        "name": "minishlab/potion-retrieval-32M",
+        "link": "https://huggingface.co/minishlab/potion-retrieval-32M",
+    },
+    "codet5-base": {"name": "Salesforce/codet5-base", "link": "https://huggingface.co/Salesforce/codet5-base"},
+    "gte-Qwen2-1.5B-instruct": {
+        "name": "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
+        "link": "https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct",
+    },
+    "bge-m3": {"name": "BAAI/bge-m3", "link": "https://huggingface.co/BAAI/bge-m3"},
+    "jina-embeddings-v3": {
+        "name": "jinaai/jina-embeddings-v3",
+        "link": "https://huggingface.co/jinaai/jina-embeddings-v3",
+    },
+    "nomic-embed-text-v2-moe": {
+        "name": "nomic-ai/nomic-embed-text-v2-moe",
+        "link": "https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe",
+    },
+    "Qodo-Embed-1-1.5B": {"name": "Qodo/Qodo-Embed-1-1.5B", "link": "https://huggingface.co/Qodo/Qodo-Embed-1-1.5B"},
+    "Reason-ModernColBERT": {
+        "name": "lightonai/Reason-ModernColBERT",
+        "link": "https://huggingface.co/lightonai/Reason-ModernColBERT",
+    },
+    "Linq-Embed-Mistral": {
+        "name": "Linq-AI-Research/Linq-Embed-Mistral",
+        "link": "https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral",
+    },
+    "bge-code-v1": {"name": "BAAI/bge-code-v1", "link": "https://huggingface.co/BAAI/bge-code-v1"},
+    "SFR-Embedding-Code-2B_R": {
+        "name": "Salesforce/SFR-Embedding-Code-2B_R",
+        "link": "https://huggingface.co/Salesforce/SFR-Embedding-Code-2B_R",
+    },
 }
 
+# Reverse mapping for lookups - using just the names
+DISPLAY_NAME_TO_FILE = {v["name"]: k for k, v in MODEL_NAME_MAPPING.items()}
 
 # Peer models for comparison (code-specialized models)
 PEER_MODELS = {
+    "sentence-transformers/all-MiniLM-L6-v2": {
+        "overall_ndcg": 0.25,
+        "type": "General",
+        "link": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
+    },
+    "microsoft/codebert-base": {
+        "overall_ndcg": 0.32,
+        "type": "Code-Specific",
+        "link": "https://huggingface.co/microsoft/codebert-base",
+    },
+    "microsoft/graphcodebert-base": {
+        "overall_ndcg": 0.35,
+        "type": "Code-Specific",
+        "link": "https://huggingface.co/microsoft/graphcodebert-base",
+    },
+    "huggingface/CodeBERTa-small-v1": {
+        "overall_ndcg": 0.28,
+        "type": "Code-Specific",
+        "link": "https://huggingface.co/huggingface/CodeBERTa-small-v1",
+    },
+    "sentence-transformers/all-mpnet-base-v2": {
+        "overall_ndcg": 0.27,
+        "type": "General",
+        "link": "https://huggingface.co/sentence-transformers/all-mpnet-base-v2",
+    },
 }
 
 # Model specifications for efficiency analysis
 MODEL_SPECS = {
+    "sentence-transformers/all-MiniLM-L6-v2": {
+        "parameters": 22.7,
+        "size_mb": 90,
+        "link": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
+    },
+    "microsoft/codebert-base": {
+        "parameters": 125.0,
+        "size_mb": 500,
+        "link": "https://huggingface.co/microsoft/codebert-base",
+    },
+    "microsoft/graphcodebert-base": {
+        "parameters": 125.0,
+        "size_mb": 500,
+        "link": "https://huggingface.co/microsoft/graphcodebert-base",
+    },
+    "huggingface/CodeBERTa-small-v1": {
+        "parameters": 84.0,
+        "size_mb": 340,
+        "link": "https://huggingface.co/huggingface/CodeBERTa-small-v1",
+    },
+    "sentence-transformers/all-mpnet-base-v2": {
+        "parameters": 109.0,
+        "size_mb": 440,
+        "link": "https://huggingface.co/sentence-transformers/all-mpnet-base-v2",
+    },
+    "Alibaba-NLP/gte-Qwen2-1.5B-instruct": {
+        "parameters": 1500.0,
+        "size_mb": 3000,
+        "link": "https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct",
+    },
 }
 
 # Distilled model specifications
 
         images_dir = base_path / "analysis_results" / "charts"
         reports_dir = base_path / "analysis_results" / "reports"
     else:
+        output_dir = Path()  # Use current directory
+        images_dir = IMAGES_DIR  # Use analysis_charts
+        reports_dir = Path()  # Use current directory for reports
 
+    # Only create directories that we actually use
     images_dir.mkdir(parents=True, exist_ok=True)
 
     return output_dir, images_dir, reports_dir
 
 
     # Check if it's in our mapping
     if name in MODEL_NAME_MAPPING:
+        return MODEL_NAME_MAPPING[name]["name"]
 
     # Try to find partial matches
+    for file_key, model_info in MODEL_NAME_MAPPING.items():
         if file_key in name or name in file_key:
+            return model_info["name"]
 
     # If no mapping found, return the cleaned name
     return name
+def get_model_link(model_name: str) -> str:
+    """Get HuggingFace link for a model."""
+    # First try direct lookup by file key
+    for model_info in MODEL_NAME_MAPPING.values():
+        if model_info["name"] == model_name:
+            return model_info["link"]
+
+    # Try partial matches
+    for model_info in MODEL_NAME_MAPPING.values():
+        if model_name.lower() in model_info["name"].lower() or model_info["name"].lower() in model_name.lower():
+            return model_info["link"]
+
+    # If no mapping found, construct link from model name
+    if "/" in model_name:
+        return f"https://huggingface.co/{model_name}"
+    return ""
+
+
+def format_model_with_link(model_name: str) -> str:
+    """Format model name with markdown link."""
+    link = get_model_link(model_name)
+    if link:
+        return f"[{model_name}]({link})"
+    return model_name
+
+
+def get_teacher_model_info(model_display_name: str) -> tuple[str, str]:
+    """Extract teacher model name and link from distilled model display name."""
+    # Mapping from model display patterns to teacher models
+    teacher_mapping = {
+        "all_MiniLM_L6_v2": (
+            "sentence-transformers/all-MiniLM-L6-v2",
+            "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
+        ),
+        "all_mpnet_base_v2": (
+            "sentence-transformers/all-mpnet-base-v2",
+            "https://huggingface.co/sentence-transformers/all-mpnet-base-v2",
+        ),
+        "paraphrase_MiniLM_L6_v2": (
+            "sentence-transformers/paraphrase-MiniLM-L6-v2",
+            "https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2",
+        ),
+        "codebert_base": ("microsoft/codebert-base", "https://huggingface.co/microsoft/codebert-base"),
+        "graphcodebert_base": ("microsoft/graphcodebert-base", "https://huggingface.co/microsoft/graphcodebert-base"),
+        "gte_Qwen2_1.5B_instruct": (
+            "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
+            "https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct",
+        ),
+        "bge_m3": ("BAAI/bge-m3", "https://huggingface.co/BAAI/bge-m3"),
+        "jina_embeddings_v3": ("jinaai/jina-embeddings-v3", "https://huggingface.co/jinaai/jina-embeddings-v3"),
+        "nomic_embed_text_v2_moe": (
+            "nomic-ai/nomic-embed-text-v2-moe",
+            "https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe",
+        ),
+        "Qodo_Embed_1_1.5B": ("Qodo/Qodo-Embed-1-1.5B", "https://huggingface.co/Qodo/Qodo-Embed-1-1.5B"),
+        "Reason_ModernColBERT": (
+            "lightonai/Reason-ModernColBERT",
+            "https://huggingface.co/lightonai/Reason-ModernColBERT",
+        ),
+        "Linq_Embed_Mistral": (
+            "Linq-AI-Research/Linq-Embed-Mistral",
+            "https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral",
+        ),
+        "bge_code_v1": ("BAAI/bge-code-v1", "https://huggingface.co/BAAI/bge-code-v1"),
+        "SFR_Embedding_Code_2B_R": (
+            "Salesforce/SFR-Embedding-Code-2B_R",
+            "https://huggingface.co/Salesforce/SFR-Embedding-Code-2B_R",
+        ),
+    }
+
+    for pattern, (teacher_name, teacher_link) in teacher_mapping.items():
+        if pattern in model_display_name:
+            return teacher_name, teacher_link
+
+    return "Unknown", ""
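These helpers replace the duplicated if/elif teacher chains that the old report code carried (visible in the removed lines above). A small sketch of the lookups, with expected values implied by the mappings; the distilled-model name in the last call is an invented example:

print(get_model_link("microsoft/codebert-base"))
# https://huggingface.co/microsoft/codebert-base

print(format_model_with_link("BAAI/bge-m3"))
# [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3)

print(get_teacher_model_info("code_model2vec_all_MiniLM_L6_v2"))
# ('sentence-transformers/all-MiniLM-L6-v2', 'https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2')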
 class CodeSearchNetAnalyzer:
     """Analyzer for CodeSearchNet evaluation results and performance benchmarks."""
 
         self.benchmark_df: pd.DataFrame | None = None
 
     def load_benchmark_results(self) -> None:
+        """Load benchmark results from comprehensive evaluation files."""
+        logger.info("📊 Loading benchmark results from comprehensive evaluations...")
 
+        if not self.results_dir.exists():
+            logger.warning(f"Evaluation directory not found: {self.results_dir}")
             return
 
+        logger.info(f"🔍 Searching for comprehensive evaluation files in: {self.results_dir}")
 
+        # Look for both new comprehensive format and legacy formats
+        comprehensive_files = list(self.results_dir.glob("comprehensive_eval_*.json"))
+        legacy_files = list(self.results_dir.glob("codesearchnet_eval_*.json"))
+
+        all_files = comprehensive_files + legacy_files
+        logger.info(
+            f"📁 Found {len(all_files)} evaluation files ({len(comprehensive_files)} comprehensive, {len(legacy_files)} legacy)"
+        )
+
+        for eval_file_path in all_files:
             try:
+                logger.info(f"📖 Loading: {eval_file_path.name}")
+                with eval_file_path.open() as f:
                     data = json.load(f)
+
                 if data is not None:
+                    if not isinstance(data, dict):
+                        logger.warning(f"⚠️ Skipping {eval_file_path.name} (not a dict)")
+                        continue
+
+                    # Extract benchmark data if available
+                    benchmark_data = self._extract_benchmark_data(data, eval_file_path)
+                    if benchmark_data:
+                        self.benchmark_results.append(benchmark_data)
+                        logger.info(f"✅ Successfully loaded benchmark data: {benchmark_data['model_name']}")
+
             except (json.JSONDecodeError, KeyError) as e:
+                logger.warning(f"❌ Failed to load {eval_file_path}: {e}")
 
         logger.info(f"📊 Total benchmark results loaded: {len(self.benchmark_results)}")
         if self.benchmark_results:
 
         self._create_benchmark_dataframe()
 
+    def _extract_benchmark_data(self, data: dict, file_path: Path) -> dict[str, Any] | None:
+        """Extract benchmark data from comprehensive evaluation results."""
+        # Check if this evaluation contains benchmark data
+        if data.get("benchmark_skipped", False):
+            return None
+
+        # Check for benchmark fields
+        if not any(key in data for key in ["size_metrics", "speed_benchmarks", "memory_benchmarks", "cpu_vs_gpu"]):
+            return None
+
+        # Extract model name
+        original_name = data.get("model_name") or "Unknown"
+        mapped_name = extract_model_name_from_filename(
+            file_path.stem.replace("comprehensive_eval_", "").replace("codesearchnet_eval_", "")
+        )
+
+        # Create benchmark result structure
+        result: dict[str, Any] = {
+            "model_name": mapped_name,
+            "original_model_name": original_name,
+            "size_metrics": data.get("size_metrics", {}),
+            "speed_benchmarks": data.get("speed_benchmarks", {}),
+            "memory_benchmarks": data.get("memory_benchmarks", {}),
+            "cpu_vs_gpu": data.get("cpu_vs_gpu", {}),
+        }
+
+        return result
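Since _extract_benchmark_data only reads the loaded JSON dict, its contract is easy to illustrate. A sketch with an invented payload; the constructor arguments mirror how the new main() builds the analyzer:

sample = {
    "model_name": "code_model2vec_all_MiniLM_L6_v2",  # invented example payload
    "size_metrics": {"size_mb": 30},
    "cpu_vs_gpu": {"cpu": {"texts_per_second": 1200.0}},
}

analyzer = CodeSearchNetAnalyzer(
    results_dir=DEFAULT_EVALUATION_DIR, benchmark_dir=None, images_dir=IMAGES_DIR
)
record = analyzer._extract_benchmark_data(sample, Path("comprehensive_eval_all-MiniLM-L6-v2.json"))
# record carries mapped/original names plus the four benchmark sections;
# files written with --skip-benchmark ("benchmark_skipped": true) yield None instead.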
     def _create_benchmark_dataframe(self) -> None:
         """Create benchmark comparison DataFrame from results."""
         if not self.benchmark_results:
 
         )
 
         # CPU vs GPU comparison
+        for device, metrics in cpu_vs_gpu.items():
+            if isinstance(metrics, dict) and "error" not in metrics:
                 device_key = f"{device.upper()}_TextsPerSec"
+                row[device_key] = metrics.get("texts_per_second", 0)
 
         benchmark_data.append(row)
 
         return
 
         logger.info(f"🔍 Searching for evaluation files in: {self.results_dir}")
 
+        # Look for both new comprehensive format and legacy formats
+        comprehensive_files = list(self.results_dir.glob("comprehensive_eval_*.json"))
+        legacy_files = list(self.results_dir.glob("codesearchnet_eval_*.json"))
+
+        all_files = comprehensive_files + legacy_files
+        logger.info(
+            f"📁 Found {len(all_files)} evaluation files ({len(comprehensive_files)} comprehensive, {len(legacy_files)} legacy)"
+        )
+
+        for json_file in all_files:
             try:
                 logger.info(f"📖 Loading: {json_file.name}")
                 with json_file.open() as f:
                     data = json.load(f)
                 if data is not None:
+                    if not isinstance(data, dict):
+                        logger.warning(f"⚠️ Skipping {json_file.name} (not a dict)")
+                        continue
+
+                    # Normalize data format for analysis
+                    normalized_data = self._normalize_evaluation_data(data, json_file)
+                    self.results.append(normalized_data)
+                    logger.info(f"✅ Successfully loaded: {normalized_data['model_name']}")
+
             except (json.JSONDecodeError, KeyError) as e:
                 logger.warning(f"❌ Failed to load {json_file}: {e}")
 
         # Also load benchmark results
         self.load_benchmark_results()
 
+    def _normalize_evaluation_data(self, data: dict, file_path: Path) -> dict[str, Any]:
+        """Normalize evaluation data to consistent format for analysis."""
+        # Extract model name
+        original_name = data.get("model_name", "Unknown")
+        file_stem = file_path.stem.replace("comprehensive_eval_", "").replace("codesearchnet_eval_", "")
+        mapped_name = extract_model_name_from_filename(file_stem)
+
+        # Handle comprehensive format (new)
+        if "codesearch_overall" in data and "codesearch_languages" in data:
+            result = {
+                "model_name": mapped_name,
+                "original_model_name": original_name,
+                "overall": data.get("codesearch_overall", {}),
+                "languages": data.get("codesearch_languages", {}),
+            }
+        # Handle legacy format (old codesearchnet_eval files)
+        else:
+            result = {
+                "model_name": mapped_name,
+                "original_model_name": original_name,
+                "overall": data.get("overall", {}),
+                "languages": data.get("languages", {}),
+            }
+
+        return result
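Both on-disk formats are funneled into the same shape, so downstream code never branches on file vintage. Continuing the sketch above with an invented legacy-format record:

legacy = {
    "model_name": "codebert-base",
    "overall": {"ndcg@10": 0.31},
    "languages": {"python": {"ndcg@10": 0.35}},
}
normalized = analyzer._normalize_evaluation_data(legacy, Path("codesearchnet_eval_codebert-base.json"))
# normalized["overall"] == {"ndcg@10": 0.31}; a comprehensive-format file would
# have fed "codesearch_overall"/"codesearch_languages" into the same two keys.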
     def _create_comparison_dataframe(self) -> None:
         """Create comparison DataFrame from results."""
         if not self.results:
 
         if cpu_vs_gpu:
             print("🖥️ CPU vs GPU:")
             for device, metrics in cpu_vs_gpu.items():
+                if isinstance(metrics, dict) and "error" not in metrics:
                     print(f"   {device.upper()}: {metrics.get('texts_per_second', 0):.1f} texts/sec")
 
         # Memory efficiency
 
         # Safe conversion to float for pandas values
         score_value = pd.to_numeric(current_model_score, errors="coerce")
         scores.append(float(score_value) if not pd.isna(score_value) else 0.0)
+        params.append(float(MODEL_SPECS[model_key].get("parameters", 100.0)))
         is_user_model.append(False)
 
     if not models:
 
         # Create visualizations
         logger.info("Generating visualizations...")
+        output_dir, images_dir, reports_dir = setup_directories()
 
         self.create_performance_radar_chart(main_model_name, language_scores)
         comparison_chart = self.plot_model_comparison()
 
         overall_metrics = result.get("overall", {})
 
         # Extract teacher model name from model name
+        teacher_name, teacher_link = get_teacher_model_info(model_display)
 
         status = "🥇 Best" if rank == 1 else "🥈 2nd" if rank == 2 else "🥉 3rd" if rank == 3 else f"#{rank}"
 
+        # Use linked teacher name if available
+        teacher_display = f"[{teacher_name}]({teacher_link})" if teacher_link else teacher_name
+
+        report += f"| {model_display} | {teacher_display} | {overall_metrics.get('ndcg@10', 0):.4f} | {overall_metrics.get('mrr', 0):.4f} | {overall_metrics.get('recall@5', 0):.4f} | {status} |\n"
 
         report += """
 
         report += "### Individual Model Performance by Language\n\n"
         for chart_model_name, chart_path in individual_radar_charts.items():
             # Extract teacher name for cleaner display
+            teacher_name, teacher_link = get_teacher_model_info(chart_model_name)
+
+            # Use linked teacher name if available
+            teacher_display = f"[{teacher_name}]({teacher_link})" if teacher_link else teacher_name
+
+            report += f"#### {chart_model_name} (Teacher: {teacher_display})\n\n"
             report += f"\n\n"
 
         report += f"""
 
     if language_scores:
         report += "| Language | Best Model Performance | Average Performance | Language Difficulty |\n"
+        report += "|----------|------------------------|--------------------|--------------------|\n"
 
         for lang in sorted(language_scores.keys()):
             # Find best performance for this language across all models
 
         model_name = result["model_name"]
         score = result.get("overall", {}).get("ndcg@10", 0)
 
+        teacher_name, teacher_link = get_teacher_model_info(model_name)
+        teacher_performance[teacher_name] = score
 
     if teacher_performance:
         best_teacher = max(teacher_performance.items(), key=lambda x: x[1])
 
     - **Evaluation**: Retrieval of correct code for each documentation query
 
 ### Teacher Models Tested
+- [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) (proven baseline)
+- [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) (general purpose)
+- [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) (paraphrase model)
+- [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base) (code-specialized)
+- [microsoft/graphcodebert-base](https://huggingface.co/microsoft/graphcodebert-base) (graph-aware code model)
+- [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct) (instruction model)
+- [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) (multilingual model)
+- [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3) (modern embedding model)
+- [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe) (mixture of experts)
+- [Qodo/Qodo-Embed-1-1.5B](https://huggingface.co/Qodo/Qodo-Embed-1-1.5B) (code-specialized)
+- [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT) (ColBERT architecture)
+- [Linq-AI-Research/Linq-Embed-Mistral](https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral) (Mistral-based)
+- [BAAI/bge-code-v1](https://huggingface.co/BAAI/bge-code-v1) (code-specialized BGE)
+- [Salesforce/SFR-Embedding-Code-2B_R](https://huggingface.co/Salesforce/SFR-Embedding-Code-2B_R) (large code model)
 
 ### Distillation Method
 - **Technique**: Model2Vec static embedding generation
 
     logger.info(f"Results exported to {output_file}")
 
 
+def main(
+    results_dir: str = DEFAULT_EVALUATION_DIR,
+    model_name: str = "code_model2vec_distilled_models",
+    output: str = "REPORT.md",
+    export_csv: str | None = None,
+) -> None:
     """Main analysis function."""
+    logger.info("Starting CodeSearchNet Analysis with Integrated Benchmarks")
     logger.info("=" * 60)
 
     # Setup output directories
     output_dir, images_dir, reports_dir = setup_directories()
 
+    # Initialize analyzer with results directory (benchmarks are integrated)
     analyzer = CodeSearchNetAnalyzer(
+        results_dir=results_dir,
+        benchmark_dir=None,  # No longer needed - benchmarks are in comprehensive files
         images_dir=images_dir,
     )
 
+    # Load results (this will also load benchmark data from comprehensive files)
     analyzer.load_results()
 
     if not analyzer.results:
 
     if analyzer.benchmark_results:
         analyzer.analyze_benchmark_performance()
     else:
+        logger.warning("No benchmark results found. Models may have been evaluated with --skip-benchmark flag.")
 
     # Generate comprehensive report with benchmark integration
+    logger.info("Generating comprehensive report with integrated benchmark data...")
+    report = analyzer.generate_comprehensive_report(model_name)
 
     # Save report
+    report_path = Path(output)
     with report_path.open("w") as f:
         f.write(report)
 
     # Export CSV if requested
+    if export_csv:
+        analyzer.export_results(export_csv)
 
     # Export benchmark CSV if available
     if analyzer.benchmark_df is not None and not analyzer.benchmark_df.empty:
+        benchmark_csv = report_path.parent / f"{model_name}_benchmark_comparison.csv"
         analyzer.benchmark_df.to_csv(benchmark_csv, index=False)
         logger.info(f"📊 Benchmark comparison saved to: {benchmark_csv}")
 
+    logger.info("✅ CodeSearchNet analysis with integrated benchmarks complete!")
     logger.info(f"📊 Report saved to: {report_path}")
     logger.info(f"🖼️ Charts saved to: {images_dir}")
+    logger.info(f"💾 Source: Comprehensive evaluation files in {results_dir}")
 
 
 if __name__ == "__main__":
     main()
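Because main() now takes plain parameters instead of parsing sys.argv, the analysis stage is directly scriptable. A short sketch; the paths shown are the defaults wired through config and the CLI:

from distiller.analyze import main as analyze_main

analyze_main(
    results_dir="code_model2vec/evaluation_results",
    model_name="gte_qwen2_m2v_code (Ours)",
    output="REPORT.md",
    export_csv="model_comparison.csv",
)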
src/distiller/beam_utils.py
CHANGED
@@ -16,6 +16,7 @@ Features:
|
|
16 |
import json
|
17 |
import logging
|
18 |
import shutil
|
|
|
19 |
import time
|
20 |
from pathlib import Path
|
21 |
from typing import Any
|
@@ -204,7 +205,7 @@ class BeamVolumeManager:
|
|
204 |
|
205 |
|
206 |
class BeamCheckpointManager:
|
207 |
-
"""Manager for checkpoint operations on Beam volumes."""
|
208 |
|
209 |
def __init__(self, volume_manager: BeamVolumeManager, checkpoint_prefix: str = "checkpoints") -> None:
|
210 |
"""
|
@@ -216,14 +217,21 @@ class BeamCheckpointManager:
|
|
216 |
"""
|
217 |
self.volume = volume_manager
|
218 |
self.checkpoint_prefix = checkpoint_prefix
|
219 |
-
self.
|
220 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
def save_checkpoint(self, stage: str, data: dict[str, Any], step: int = 0) -> bool:
|
223 |
-
"""Save checkpoint to volume."""
|
224 |
try:
|
|
|
225 |
checkpoint_filename = f"{self.checkpoint_prefix}_{stage}_step_{step}.json"
|
226 |
-
checkpoint_path =
|
227 |
|
228 |
with checkpoint_path.open("w") as f:
|
229 |
                json.dump(data, f, indent=2, default=str)

@@ -236,10 +244,11 @@ class BeamCheckpointManager:
         return False

     def load_checkpoint(self, stage: str, step: int = 0) -> dict[str, Any] | None:
-        """Load checkpoint from volume."""
         try:
             checkpoint_filename = f"{self.checkpoint_prefix}_{stage}_step_{step}.json"
-            checkpoint_path =

             if checkpoint_path.exists():
                 with checkpoint_path.open("r") as f:

@@ -257,11 +266,13 @@ class BeamCheckpointManager:
     def get_latest_checkpoint(self, stage: str) -> tuple[int, dict[str, Any]] | None:
         """Get the latest checkpoint for a stage."""
         try:
             # Find checkpoint files for this stage
             pattern = f"{self.checkpoint_prefix}_{stage}_step_*.json"
             stage_checkpoints: list[tuple[int, Path]] = []

-            for checkpoint_file in
                 try:
                     # Extract step number from filename
                     step_str = checkpoint_file.stem.replace(f"{self.checkpoint_prefix}_{stage}_step_", "")

@@ -290,11 +301,13 @@ class BeamCheckpointManager:
     def cleanup_old_checkpoints(self, stage: str, keep_latest: int = 3) -> list[str]:
         """Clean up old checkpoints, keeping only the latest N."""
         try:
             # Find checkpoint files for this stage
             pattern = f"{self.checkpoint_prefix}_{stage}_step_*.json"
             stage_checkpoints: list[tuple[int, Path]] = []

-            for checkpoint_file in
                 try:
                     step_str = checkpoint_file.stem.replace(f"{self.checkpoint_prefix}_{stage}_step_", "")
                     step = int(step_str)

@@ -329,29 +342,55 @@ class BeamCheckpointManager:
         """List all checkpoints, optionally filtered by stage."""
         try:
             checkpoints: list[dict[str, Any]] = []
-            pattern = f"{self.checkpoint_prefix}_*.json"

-
-            #
-
-
-
-
-
-

-            if stage is None or checkpoint_stage == stage:
                 stat = checkpoint_file.stat()
                 checkpoints.append(
                     {
-                        "stage":
                         "step": step,
                         "filename": checkpoint_file.name,
                         "size": f"{stat.st_size / 1024:.1f}KB",
                         "modified": time.ctime(stat.st_mtime),
                     }
                 )

             return sorted(checkpoints, key=lambda x: (x["stage"], x["step"]))

@@ -747,6 +786,389 @@ def example_distillation_workflow() -> None:
     logger.info(f"Workspace info: {info}")


 if __name__ == "__main__":
     # Example usage
     logging.basicConfig(level=logging.INFO)
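The new-side listing below shows how these hunks land in the rewritten BeamCheckpointManager: checkpoints move from a flat directory into per-stage subdirectories built by the added _get_stage_dir helper. As an illustrative sketch only (not part of the commit), the added path logic implies this kind of round trip, assuming a BeamVolumeManager constructed elsewhere in this module:

    # Illustrative sketch; mirrors the path logic added in this commit.
    # `volume_mgr` is assumed to be a BeamVolumeManager whose mount_path
    # points at the mounted Beam volume.
    mgr = BeamCheckpointManager(volume_mgr, checkpoint_prefix="checkpoints")
    mgr.save_checkpoint("distillation", {"loss": 0.42}, step=100)
    # -> <mount_path>/checkpoints/distillation/checkpoints_distillation_step_100.json
    latest = mgr.get_latest_checkpoint("distillation")
    if latest is not None:
        step, data = latest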
@@ +16 @@
 import json
 import logging
 import shutil
+import subprocess
 import time
 from pathlib import Path
 from typing import Any

@@ +205 @@


 class BeamCheckpointManager:
+    """Manager for checkpoint operations on Beam volumes with stage-based organization."""

     def __init__(self, volume_manager: BeamVolumeManager, checkpoint_prefix: str = "checkpoints") -> None:
         """

@@ +217 @@
         """
         self.volume = volume_manager
         self.checkpoint_prefix = checkpoint_prefix
+        self.checkpoint_base_dir = self.volume.mount_path / checkpoint_prefix
+        self.checkpoint_base_dir.mkdir(parents=True, exist_ok=True)
+
+    def _get_stage_dir(self, stage: str) -> Path:
+        """Get stage-specific checkpoint directory."""
+        stage_dir = self.checkpoint_base_dir / stage
+        stage_dir.mkdir(parents=True, exist_ok=True)
+        return stage_dir

     def save_checkpoint(self, stage: str, data: dict[str, Any], step: int = 0) -> bool:
+        """Save checkpoint to volume in stage-specific directory."""
         try:
+            stage_dir = self._get_stage_dir(stage)
             checkpoint_filename = f"{self.checkpoint_prefix}_{stage}_step_{step}.json"
+            checkpoint_path = stage_dir / checkpoint_filename

             with checkpoint_path.open("w") as f:
                 json.dump(data, f, indent=2, default=str)

@@ +244 @@
         return False

     def load_checkpoint(self, stage: str, step: int = 0) -> dict[str, Any] | None:
+        """Load checkpoint from volume stage-specific directory."""
         try:
+            stage_dir = self._get_stage_dir(stage)
             checkpoint_filename = f"{self.checkpoint_prefix}_{stage}_step_{step}.json"
+            checkpoint_path = stage_dir / checkpoint_filename

             if checkpoint_path.exists():
                 with checkpoint_path.open("r") as f:

@@ +266 @@
     def get_latest_checkpoint(self, stage: str) -> tuple[int, dict[str, Any]] | None:
         """Get the latest checkpoint for a stage."""
         try:
+            stage_dir = self._get_stage_dir(stage)
+
             # Find checkpoint files for this stage
             pattern = f"{self.checkpoint_prefix}_{stage}_step_*.json"
             stage_checkpoints: list[tuple[int, Path]] = []

+            for checkpoint_file in stage_dir.glob(pattern):
                 try:
                     # Extract step number from filename
                     step_str = checkpoint_file.stem.replace(f"{self.checkpoint_prefix}_{stage}_step_", "")

@@ +301 @@
     def cleanup_old_checkpoints(self, stage: str, keep_latest: int = 3) -> list[str]:
         """Clean up old checkpoints, keeping only the latest N."""
         try:
+            stage_dir = self._get_stage_dir(stage)
+
             # Find checkpoint files for this stage
             pattern = f"{self.checkpoint_prefix}_{stage}_step_*.json"
             stage_checkpoints: list[tuple[int, Path]] = []

+            for checkpoint_file in stage_dir.glob(pattern):
                 try:
                     step_str = checkpoint_file.stem.replace(f"{self.checkpoint_prefix}_{stage}_step_", "")
                     step = int(step_str)

@@ +342 @@
         """List all checkpoints, optionally filtered by stage."""
         try:
             checkpoints: list[dict[str, Any]] = []

+            if stage:
+                # List checkpoints for specific stage
+                stage_dir = self._get_stage_dir(stage)
+                pattern = f"{self.checkpoint_prefix}_{stage}_*.json"
+
+                for checkpoint_file in stage_dir.glob(pattern):
+                    name_parts = checkpoint_file.stem.split("_")
+                    if len(name_parts) >= 4:
+                        try:
+                            step = int(name_parts[3])
+                        except ValueError:
+                            step = 0

                     stat = checkpoint_file.stat()
                     checkpoints.append(
                         {
+                            "stage": stage,
                             "step": step,
                             "filename": checkpoint_file.name,
                             "size": f"{stat.st_size / 1024:.1f}KB",
                             "modified": time.ctime(stat.st_mtime),
                         }
                     )
+            else:
+                # List checkpoints for all stages
+                for stage_dir in self.checkpoint_base_dir.iterdir():
+                    if stage_dir.is_dir():
+                        stage_name = stage_dir.name
+                        pattern = f"{self.checkpoint_prefix}_{stage_name}_*.json"
+
+                        for checkpoint_file in stage_dir.glob(pattern):
+                            name_parts = checkpoint_file.stem.split("_")
+                            if len(name_parts) >= 4:
+                                try:
+                                    step = int(name_parts[3])
+                                except ValueError:
+                                    step = 0
+
+                            stat = checkpoint_file.stat()
+                            checkpoints.append(
+                                {
+                                    "stage": stage_name,
+                                    "step": step,
+                                    "filename": checkpoint_file.name,
+                                    "size": f"{stat.st_size / 1024:.1f}KB",
+                                    "modified": time.ctime(stat.st_mtime),
+                                }
+                            )

             return sorted(checkpoints, key=lambda x: (x["stage"], x["step"]))

@@ +786 @@
     logger.info(f"Workspace info: {info}")


+def download_evaluation_results_from_beam(
+    volume_name: str,
+    remote_results_dir: str = "evaluation_results",
+    local_results_dir: str = "code_model2vec/evaluation_results",
+) -> bool:
+    """
+    Download evaluation result files from Beam volume to local directory using beam cp.
+
+    Args:
+        volume_name: Name of the Beam volume
+        remote_results_dir: Directory path in the Beam volume containing results
+        local_results_dir: Local directory to download results to
+
+    Returns:
+        True if download successful, False otherwise
+    """
+    try:
+        local_path = Path(local_results_dir)
+        local_path.mkdir(parents=True, exist_ok=True)
+
+        # Use beam cp to download individual JSON files
+        remote_path = f"{volume_name}:{remote_results_dir}"
+
+        # First, list files in the remote directory
+        list_cmd = ["beam", "cp", "-r", "--list-only", remote_path]
+        try:
+            result = subprocess.run(list_cmd, capture_output=True, text=True, check=True)  # noqa: S603
+            remote_files = [line.strip() for line in result.stdout.split("\n") if line.strip().endswith(".json")]
+        except subprocess.CalledProcessError:
+            logger.warning(f"Could not list files in {remote_path}")
+            remote_files = []
+
+        # Download each JSON file individually
+        downloaded_files = []
+        for file_name in remote_files:
+            if file_name.endswith(".json"):
+                remote_file_path = f"{volume_name}:{remote_results_dir}/{file_name}"
+                local_file_path = local_path / file_name
+
+                try:
+                    download_cmd = ["beam", "cp", remote_file_path, str(local_file_path)]
+                    subprocess.run(download_cmd, check=True, capture_output=True)  # noqa: S603
+                    downloaded_files.append(file_name)
+                    logger.info(f"📥 Downloaded: {file_name}")
+
+                    # Delete the file from Beam volume after successful download
+                    delete_cmd = ["beam", "rm", remote_file_path]
+                    try:
+                        subprocess.run(delete_cmd, check=True, capture_output=True)  # noqa: S603
+                        logger.info(f"🗑️ Deleted from volume: {file_name}")
+                    except subprocess.CalledProcessError as e:
+                        logger.warning(f"⚠️ Could not delete {file_name} from volume: {e}")
+
+                except subprocess.CalledProcessError as e:
+                    logger.warning(f"⚠️ Failed to download {file_name}: {e}")
+
+        if downloaded_files:
+            logger.info(f"✅ Downloaded {len(downloaded_files)} evaluation result files")
+            return True
+        logger.info("ℹ️ No new evaluation files to download")
+        return True
+
+    except Exception:
+        logger.exception("❌ Error downloading evaluation results from Beam")
+        return False
+
+
+def download_specific_evaluation_file(
+    volume_name: str,
+    model_name: str,
+    remote_results_dir: str = "evaluation_results",
+    local_results_dir: str = "code_model2vec/evaluation_results",
+    file_prefix: str = "codesearchnet_eval",
+) -> bool:
+    """
+    Download a specific evaluation or benchmark result file from Beam volume.
+
+    Args:
+        volume_name: Name of the Beam volume
+        model_name: Name of the model whose results to download
+        remote_results_dir: Directory path in the Beam volume containing results
+        local_results_dir: Local directory to download results to
+        file_prefix: Prefix for the file (e.g., 'codesearchnet_eval', 'benchmark')
+
+    Returns:
+        True if download successful, False otherwise
+    """
+    try:
+        local_path = Path(local_results_dir)
+        local_path.mkdir(parents=True, exist_ok=True)
+
+        # Generate filename following the pattern
+        safe_model_name = model_name.replace("/", "_")
+        filename = f"{file_prefix}_{safe_model_name}.json"
+
+        remote_file_path = f"{volume_name}:{remote_results_dir}/{filename}"
+        local_file_path = local_path / filename
+
+        # Download the specific file
+        download_cmd = ["beam", "cp", remote_file_path, str(local_file_path)]
+        subprocess.run(download_cmd, check=True, capture_output=True)  # noqa: S603
+
+        logger.info(f"📥 Downloaded {file_prefix} results for {model_name}")
+
+        # Delete the file from Beam volume after successful download
+        delete_cmd = ["beam", "rm", remote_file_path]
+        try:
+            subprocess.run(delete_cmd, check=True, capture_output=True)  # noqa: S603
+            logger.info(f"🗑️ Deleted {file_prefix} results for {model_name} from volume")
+        except subprocess.CalledProcessError as e:
+            logger.warning(f"⚠️ Could not delete {filename} from volume: {e}")
+
+        return True
+
+    except subprocess.CalledProcessError:
+        logger.warning(f"⚠️ No {file_prefix} results found for {model_name} on Beam")
+        return False
+    except Exception:
+        logger.exception(f"❌ Error downloading {file_prefix} results for {model_name}")
+        return False
+
+
+def download_model_from_beam(
+    volume_name: str,
+    model_name: str,
+    local_dir: str,
+) -> bool:
+    """
+    Download a model from Beam volume to local directory.
+
+    Args:
+        volume_name: Name of the Beam volume
+        model_name: Name of the model to download
+        local_dir: Local directory to download model to
+
+    Returns:
+        True if download successful, False otherwise
+    """
+    try:
+        local_path = Path(local_dir)
+        local_path.mkdir(parents=True, exist_ok=True)
+
+        # Use beam cp to download the model directory
+        remote_path = f"{volume_name}:models/{model_name}"
+        local_model_path = local_path / model_name
+
+        download_cmd = ["beam", "cp", "-r", remote_path, str(local_model_path)]
+        subprocess.run(download_cmd, check=True, capture_output=True)  # noqa: S603
+
+        logger.info(f"📥 Downloaded model {model_name} from Beam to {local_dir}")
+        return True
+
+    except subprocess.CalledProcessError as e:
+        logger.warning(f"⚠️ Failed to download model {model_name} from Beam: {e}")
+        return False
+    except Exception:
+        logger.exception(f"❌ Error downloading model {model_name} from Beam")
+        return False
+
+
+def upload_model_to_beam(
+    volume_name: str,
+    model_name: str,
+    local_dir: str,
+) -> bool:
+    """
+    Upload a model from local directory to Beam volume.
+
+    Args:
+        volume_name: Name of the Beam volume
+        model_name: Name for the model on Beam
+        local_dir: Local directory containing the model
+
+    Returns:
+        True if upload successful, False otherwise
+    """
+    try:
+        local_path = Path(local_dir)
+        if not local_path.exists():
+            logger.error(f"❌ Local model directory does not exist: {local_dir}")
+            return False
+
+        # Use beam cp to upload the model directory
+        remote_path = f"{volume_name}:models/{model_name}"
+
+        upload_cmd = ["beam", "cp", "-r", str(local_path), remote_path]
+        subprocess.run(upload_cmd, check=True, capture_output=True)  # noqa: S603
+
+        logger.info(f"📤 Uploaded model {model_name} to Beam from {local_dir}")
+        return True
+
+    except subprocess.CalledProcessError as e:
+        logger.warning(f"⚠️ Failed to upload model {model_name} to Beam: {e}")
+        return False
+    except Exception:
+        logger.exception(f"❌ Error uploading model {model_name} to Beam")
+        return False
+
+
+def download_checkpoints_from_beam(
+    volume_name: str,
+    stage: str | None = None,
+    remote_checkpoints_dir: str = "checkpoints",
+    local_checkpoints_dir: str = "code_model2vec/checkpoints",
+) -> bool:
+    """
+    Download checkpoint files from Beam volume to local directory.
+
+    Args:
+        volume_name: Name of the Beam volume
+        stage: Specific stage to download (e.g., 'distillation', 'training'), or None for all
+        remote_checkpoints_dir: Directory path in the Beam volume containing checkpoints
+        local_checkpoints_dir: Local directory to download checkpoints to
+
+    Returns:
+        True if download successful, False otherwise
+    """
+    try:
+        local_path = Path(local_checkpoints_dir)
+        local_path.mkdir(parents=True, exist_ok=True)
+
+        # Build the pattern for files to download
+        if stage:
+            local_stage_dir = local_path / stage
+            local_stage_dir.mkdir(parents=True, exist_ok=True)
+        else:
+            pass
+
+        # Use beam cp to download checkpoint files
+        remote_path = f"{volume_name}:{remote_checkpoints_dir}"
+
+        # First, try to list files
+        list_cmd = ["beam", "cp", "-r", "--list-only", remote_path]
+        try:
+            result = subprocess.run(list_cmd, capture_output=True, text=True, check=True)  # noqa: S603
+            remote_files = [
+                line.strip()
+                for line in result.stdout.split("\n")
+                if line.strip().endswith(".json") and "checkpoints_" in line.strip()
+            ]
+        except subprocess.CalledProcessError:
+            logger.warning(f"Could not list checkpoint files in {remote_path}")
+            remote_files = []
+
+        # Filter by stage if specified
+        if stage:
+            remote_files = [f for f in remote_files if f"checkpoints_{stage}_" in f]
+
+        # Download each checkpoint file
+        downloaded_files = []
+        for file_name in remote_files:
+            remote_file_path = f"{volume_name}:{remote_checkpoints_dir}/{file_name}"
+
+            # Determine local subdirectory based on checkpoint stage
+            file_stage = file_name.split("_")[1] if "_" in file_name else "unknown"
+            local_stage_dir = local_path / file_stage
+            local_stage_dir.mkdir(parents=True, exist_ok=True)
+            local_file_path = local_stage_dir / file_name
+
+            try:
+                download_cmd = ["beam", "cp", remote_file_path, str(local_file_path)]
+                subprocess.run(download_cmd, check=True, capture_output=True)  # noqa: S603
+                downloaded_files.append(file_name)
+                logger.info(f"📥 Downloaded checkpoint: {file_name}")
+
+            except subprocess.CalledProcessError as e:
+                logger.warning(f"⚠️ Failed to download checkpoint {file_name}: {e}")
+
+        if downloaded_files:
+            logger.info(f"✅ Downloaded {len(downloaded_files)} checkpoint files")
+            return True
+        logger.info("ℹ️ No new checkpoint files to download")
+        return True
+
+    except Exception:
+        logger.exception("❌ Error downloading checkpoints from Beam")
+        return False
+
+
+def upload_checkpoints_to_beam(
+    volume_name: str,
+    stage: str | None = None,
+    local_checkpoints_dir: str = "code_model2vec/checkpoints",
+    remote_checkpoints_dir: str = "checkpoints",
+) -> bool:
+    """
+    Upload checkpoint files from local directory to Beam volume.
+
+    Args:
+        volume_name: Name of the Beam volume
+        stage: Specific stage to upload (e.g., 'distillation', 'training'), or None for all
+        local_checkpoints_dir: Local directory containing checkpoints
+        remote_checkpoints_dir: Directory path in the Beam volume to store checkpoints
+
+    Returns:
+        True if upload successful, False otherwise
+    """
+    try:
+        local_path = Path(local_checkpoints_dir)
+        if not local_path.exists():
+            logger.warning(f"⚠️ Local checkpoints directory does not exist: {local_checkpoints_dir}")
+            return True  # Not an error - no checkpoints to upload
+
+        # Find checkpoint files to upload
+        if stage:
+            # Look in the stage subdirectory
+            stage_dir = local_path / stage
+            checkpoint_files = list(stage_dir.glob(f"checkpoints_{stage}_*.json")) if stage_dir.exists() else []
+        else:
+            # Look for all checkpoint files in all subdirectories
+            checkpoint_files = []
+            for subdir in local_path.iterdir():
+                if subdir.is_dir():
+                    checkpoint_files.extend(subdir.glob("checkpoints_*.json"))
+
+        if not checkpoint_files:
+            logger.info(f"ℹ️ No checkpoint files found to upload for stage: {stage or 'all'}")
+            return True
+
+        # Upload each checkpoint file
+        uploaded_files = []
+        for checkpoint_file in checkpoint_files:
+            remote_file_path = f"{volume_name}:{remote_checkpoints_dir}/{checkpoint_file.name}"
+
+            try:
+                upload_cmd = ["beam", "cp", str(checkpoint_file), remote_file_path]
+                subprocess.run(upload_cmd, check=True, capture_output=True)  # noqa: S603
+                uploaded_files.append(checkpoint_file.name)
+                logger.info(f"📤 Uploaded checkpoint: {checkpoint_file.name}")
+
+            except subprocess.CalledProcessError as e:
+                logger.warning(f"⚠️ Failed to upload checkpoint {checkpoint_file.name}: {e}")
+
+        if uploaded_files:
+            logger.info(f"✅ Uploaded {len(uploaded_files)} checkpoint files")
+            return True
+        return False
+
+    except Exception:
+        logger.exception("❌ Error uploading checkpoints to Beam")
+        return False
+
+
+def sync_checkpoints_from_beam(
+    volume_name: str,
+    stage: str,
+    local_checkpoints_dir: str = "code_model2vec/checkpoints",
+) -> bool:
+    """
+    Sync specific stage checkpoints from Beam to local directory.
+
+    Args:
+        volume_name: Name of the Beam volume
+        stage: Stage to sync (e.g., 'distillation', 'training')
+        local_checkpoints_dir: Local directory for checkpoints
+
+    Returns:
+        True if sync successful, False otherwise
+    """
+    logger.info(f"🔄 Syncing {stage} checkpoints from Beam...")
+    return download_checkpoints_from_beam(volume_name, stage, "checkpoints", local_checkpoints_dir)
+
+
+def sync_checkpoints_to_beam(
+    volume_name: str,
+    stage: str,
+    local_checkpoints_dir: str = "code_model2vec/checkpoints",
+) -> bool:
+    """
+    Sync specific stage checkpoints from local directory to Beam.
+
+    Args:
+        volume_name: Name of the Beam volume
+        stage: Stage to sync (e.g., 'distillation', 'training')
+        local_checkpoints_dir: Local directory containing checkpoints
+
+    Returns:
+        True if sync successful, False otherwise
+    """
+    logger.info(f"🔄 Syncing {stage} checkpoints to Beam...")
+    return upload_checkpoints_to_beam(volume_name, stage, local_checkpoints_dir, "checkpoints")
+
+
 if __name__ == "__main__":
     # Example usage
     logging.basicConfig(level=logging.INFO)
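The helpers added above all shell out to the beam CLI (beam cp, beam rm), so they assume an installed and authenticated client. A hedged usage sketch, not part of the commit; the volume name matches the one used elsewhere in this repository, while the model name and directories are illustrative:

    # Illustrative sketch; exercises the helpers added in this commit.
    from distiller.beam_utils import (
        download_model_from_beam,
        sync_checkpoints_from_beam,
        sync_checkpoints_to_beam,
        upload_model_to_beam,
    )

    # Push local distillation checkpoints to the volume, then pull them back.
    sync_checkpoints_to_beam("gte_qwen2_m2v_code", "distillation")
    sync_checkpoints_from_beam("gte_qwen2_m2v_code", "distillation")

    # Round-trip a model directory stored under <volume>:models/<name>.
    upload_model_to_beam("gte_qwen2_m2v_code", "my_model", "code_model2vec/my_model")
    download_model_from_beam("gte_qwen2_m2v_code", "my_model", "downloads")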
src/distiller/benchmark.py
DELETED
@@ -1,1181 +0,0 @@
-"""
-Operational Performance Benchmarking for Embedding Models.
-
-This module benchmarks embedding models on operational metrics like:
-- Inference speed (latency and throughput)
-- Memory efficiency (RAM and GPU usage)
-- Model size and storage requirements
-- Scalability with batch size
-- CPU vs GPU performance
-"""
-
-import gc
-import json
-import logging
-import os
-import time
-from pathlib import Path
-from typing import Any
-
-import pandas as pd
-import psutil
-import torch
-from beam import GpuType, Image, Volume, function
-from sentence_transformers import SentenceTransformer
-
-from .beam_utils import (
-    BeamCheckpointManager,
-    BeamEvaluationManager,
-    create_beam_utilities,
-)
-
-logger = logging.getLogger(__name__)
-
-# =============================================================================
-# BEAM CONFIGURATION
-# =============================================================================
-
-GPU_NAME = GpuType.A100_40
-VOLUME_NAME = "gte_qwen2_m2v_code"  # Same volume as distill.py and evaluate.py
-VOLUME_PATH = "./gte_qwen2_m2v_code"  # Same mount path as distill.py and evaluate.py
-BENCHMARK_RESULTS_DIR = "benchmark_results"  # Subdirectory within volume
-BENCHMARK_CACHE_DIR = "benchmark_cache"  # Cache for models
-
-IMAGE = Image(python_version="python3.12").add_python_packages(
-    [
-        "torch>=2.7.0",
-        "transformers>=4.40.0",
-        "datasets>=3.2.0",
-        "sentence-transformers>=4.1.0",
-        "model2vec[train]>=0.5.0",
-        "numpy>=1.26.4",
-        "scikit-learn>=1.6.1",
-        "pandas>=2.0.0",
-        "tqdm>=4.65.0",
-        "psutil>=5.9.0",
-    ]
-)
-
-# =============================================================================
-# CONFIGURATION
-# =============================================================================
-
-DEFAULT_OUTPUT_DIR = "benchmark_results"  # Local fallback directory
-
-# Default models to benchmark (can be overridden via command line)
-DEFAULT_BENCHMARK_MODELS = [
-    # Your distilled model (local files in Beam volume root)
-    "gte_qwen2_m2v_code",  # This will be resolved to VOLUME_PATH in Beam
-    # Established Code Models
-    "sentence-transformers/all-MiniLM-L6-v2",
-    "microsoft/codebert-base",
-    "microsoft/graphcodebert-base",
-    "huggingface/CodeBERTa-small-v1",
-    "sentence-transformers/all-mpnet-base-v2",
-    "sentence-transformers/all-MiniLM-L12-v2",
-    # Model2Vec & Efficiency Models (Direct Competitors)
-    "minishlab/potion-base-8M",
-    "minishlab/potion-retrieval-32M",
-    # Small Transformer-Based Code Models
-    "Salesforce/codet5-base",
-]
-
-# =============================================================================
-# CHECKPOINT CONFIGURATION
-# =============================================================================
-
-# Prevent conflicts with other modules by using unique prefixes
-BENCHMARK_CHECKPOINT_PREFIX = "benchmark_checkpoints"
-MODEL_CACHE_PREFIX = "model_cache"
-
-# Sample texts for benchmarking (various lengths)
-BENCHMARK_TEXTS = {
-    "short": [
-        "def add(a, b): return a + b",
-        "function multiply(x, y) { return x * y; }",
-        "class Calculator { public int subtract(int a, int b) { return a - b; } }",
-    ]
-    * 100,  # 300 short texts
-    "medium": [
-        "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
-        "function quickSort(arr) {\n if (arr.length <= 1) return arr;\n const pivot = arr[arr.length - 1];\n const left = [], right = [];\n for (let i = 0; i < arr.length - 1; i++) {\n if (arr[i] < pivot) left.push(arr[i]);\n else right.push(arr[i]);\n }\n return [...quickSort(left), pivot, ...quickSort(right)];\n}",
-    ]
-    * 50,  # 100 medium texts
-    "long": [
-        """
-def complex_algorithm(data, config):
-    '''
-    Complex data processing algorithm with multiple steps.
-
-    Args:
-        data: Input data structure
-        config: Configuration parameters
-
-    Returns:
-        Processed results
-    '''
-    results = []
-
-    # Step 1: Data validation
-    if not isinstance(data, (list, tuple)):
-        raise ValueError("Data must be list or tuple")
-
-    # Step 2: Preprocessing
-    processed_data = []
-    for item in data:
-        if config.get('normalize', False):
-            item = normalize_item(item)
-        if config.get('filter', False):
-            if not filter_item(item, config['filter_criteria']):
-                continue
-        processed_data.append(item)
-
-    # Step 3: Main processing
-    for item in processed_data:
-        result = process_item(item, config)
-        if result is not None:
-            results.append(result)
-
-    # Step 4: Post-processing
-    if config.get('sort', False):
-        results.sort(key=lambda x: x.get('score', 0), reverse=True)
-
-    return results
-""".strip(),
-    ]
-    * 20,  # 20 long texts
-}
-
-
-class PerformanceBenchmark:
-    """Comprehensive performance benchmarking for embedding models."""
-
-    def __init__(
-        self,
-        model_path: str,
-        model_name: str | None = None,
-        checkpoint_manager: BeamCheckpointManager | None = None,
-        eval_manager: BeamEvaluationManager | None = None,
-    ) -> None:
-        """Initialize benchmarker with model and optional Beam utilities."""
-        self.model_path = model_path
-        self.model_name = model_name or Path(model_path).name
-        self.model: SentenceTransformer | None = None
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.results: dict[str, Any] = {}
-        self.checkpoint_manager = checkpoint_manager
-        self.eval_manager = eval_manager
-
-    def load_model(self) -> None:
-        """Load the embedding model."""
-        logger.info(f"Loading model from {self.model_path}")
-        start_time = time.time()
-
-        try:
-            self.model = SentenceTransformer(self.model_path, device=self.device, trust_remote_code=True)
-            load_time = time.time() - start_time
-
-            logger.info(f"✅ Model loaded in {load_time:.2f}s on {self.device}")
-            self.results["model_load_time"] = load_time
-
-        except Exception:
-            logger.exception("❌ Failed to load model")
-            raise
-
-    def measure_model_size(self) -> dict[str, float]:
-        """Measure model size metrics."""
-        logger.info("📏 Measuring model size...")
-
-        size_metrics = {}
-
-        # Disk size - handle both local paths and HuggingFace models
-        try:
-            if Path(self.model_path).is_dir():
-                # Local directory - calculate size of model files only
-                model_extensions = {".safetensors", ".bin", ".json", ".txt", ".tokenizer"}
-                total_size = 0
-                model_dir = Path(self.model_path)
-
-                for file_path in model_dir.rglob("*"):
-                    if file_path.is_file() and (
-                        file_path.suffix.lower() in model_extensions
-                        or file_path.name.lower() in {"config.json", "tokenizer.json", "modules.json", "README.md"}
-                    ):
-                        total_size += file_path.stat().st_size
-
-                size_metrics["disk_size_mb"] = total_size / (1024 * 1024)
-            elif Path(self.model_path).is_file():
-                # Single file
-                total_size = Path(self.model_path).stat().st_size
-                size_metrics["disk_size_mb"] = total_size / (1024 * 1024)
-            else:
-                # HuggingFace model - estimate from cache if available
-                from transformers import AutoConfig
-
-                try:
-                    config = AutoConfig.from_pretrained(self.model_path)
-                    # Estimate size based on parameters (rough approximation)
-                    if hasattr(config, "hidden_size") and hasattr(config, "num_hidden_layers"):
-                        # Rough estimation for transformer models
-                        estimated_params = config.hidden_size * config.num_hidden_layers * 1000  # Very rough
-                        size_metrics["disk_size_mb"] = estimated_params * 4 / (1024 * 1024)  # 4 bytes per float32
-                    else:
-                        size_metrics["disk_size_mb"] = 0  # Unknown
-                except Exception:
-                    logger.warning(f"Could not determine disk size for HuggingFace model: {self.model_path}")
-                    size_metrics["disk_size_mb"] = 0  # Unknown
-        except Exception as e:
-            logger.warning(f"Could not determine disk size: {e}")
-            size_metrics["disk_size_mb"] = 0
-
-        # Model parameters (if accessible)
-        try:
-            if self.model is not None and hasattr(self.model, "modules"):
-                total_params = sum(p.numel() for p in self.model.parameters())
-                size_metrics["parameters_millions"] = total_params / 1_000_000
-
-                # Try to get embedding dimension from model config
-                try:
-                    # Use the public modules() method instead of private _modules
-                    modules = list(self.model.modules())
-                    if len(modules) > 1:  # modules[0] is usually the entire model, modules[1] is first submodule
-                        first_module = modules[1]
-                        if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "config"):
-                            config = first_module.auto_model.config
-                            if hasattr(config, "hidden_size"):
-                                size_metrics["embedding_dim"] = config.hidden_size
-                            elif hasattr(config, "model_dim"):
-                                size_metrics["embedding_dim"] = config.model_dim
-                except Exception as e:
-                    logger.debug(
-                        f"Could not extract embedding dimension from model config: {e}"
-                    )  # Silently continue if this method fails
-
-            # For Model2Vec static models
-            elif self.model is not None and hasattr(self.model, "embedding"):
-                # Handle both tensor and numpy array embeddings
-                embedding = self.model.embedding
-                if hasattr(embedding, "shape"):
-                    vocab_size, embedding_dim = embedding.shape  # type: ignore[misc]
-                    total_params = vocab_size * embedding_dim
-                    size_metrics["parameters_millions"] = total_params / 1_000_000
-                    size_metrics["vocab_size"] = vocab_size
-                    size_metrics["embedding_dim"] = embedding_dim
-                else:
-                    logger.warning("Could not determine embedding shape for Model2Vec model")
-
-            # Alternative method: get embedding dimension from a test encoding
-            if "embedding_dim" not in size_metrics and self.model is not None:
-                try:
-                    test_embedding = self.model.encode(["test"], convert_to_tensor=False)
-                    if hasattr(test_embedding, "shape") and len(test_embedding.shape) >= 2:
-                        size_metrics["embedding_dim"] = test_embedding.shape[1]
-                    elif (
-                        isinstance(test_embedding, (list, tuple))
-                        and len(test_embedding) > 0
-                        and hasattr(test_embedding[0], "__len__")
-                    ):
-                        size_metrics["embedding_dim"] = len(test_embedding[0])
-                except Exception as e:
-                    logger.warning(f"Could not determine embedding dimension: {e}")
-
-        except Exception as e:
-            logger.warning(f"Could not determine parameter count: {e}")
-
-        # Memory footprint
-        if self.device == "cuda" and torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            size_metrics["gpu_memory_mb"] = torch.cuda.memory_allocated() / (1024 * 1024)
-
-        # RAM usage (approximate)
-        process = psutil.Process(os.getpid())
-        size_metrics["ram_usage_mb"] = process.memory_info().rss / (1024 * 1024)
-
-        self.results["size_metrics"] = size_metrics
-        return size_metrics
-
-    def benchmark_inference_speed(self, batch_sizes: list[int] | None = None) -> dict[str, Any]:
-        """Benchmark inference speed across different batch sizes."""
-        if batch_sizes is None:
-            batch_sizes = [1, 8, 16, 32, 64, 128]
-        logger.info("⚡ Benchmarking inference speed...")
-
-        if self.model is None:
-            self.load_model()
-
-        if self.model is None:
-            msg = "Failed to load model"
-            raise RuntimeError(msg)
-
-        speed_results = {}
-        text_lengths = ["short", "medium", "long"]
-
-        for text_length in text_lengths:
-            logger.info(f"  📝 Testing {text_length} texts...")
-            texts = BENCHMARK_TEXTS[text_length]
-
-            length_results = {}
-
-            for batch_size in batch_sizes:
-                if batch_size > len(texts):
-                    continue
-
-                logger.info(f"    🔄 Batch size: {batch_size}")
-
-                # Prepare batch
-                batch_texts = texts[:batch_size]
-
-                # Warmup
-                if self.device == "cuda":
-                    torch.cuda.synchronize()
-                _ = self.model.encode(batch_texts[: min(2, batch_size)], convert_to_tensor=False)
-
-                # Clear cache
-                if self.device == "cuda":
-                    torch.cuda.empty_cache()
-                    torch.cuda.synchronize()
-
-                # Measure inference time
-                start_time = time.perf_counter()
-
-                embeddings = self.model.encode(batch_texts, convert_to_tensor=False, show_progress_bar=False)
-
-                if self.device == "cuda":
-                    torch.cuda.synchronize()
-
-                end_time = time.perf_counter()
-
-                # Calculate metrics
-                total_time = end_time - start_time
-                time_per_text = total_time / batch_size
-                texts_per_second = batch_size / total_time
-
-                # Estimate tokens (rough approximation)
-                avg_tokens = sum(len(text.split()) for text in batch_texts) / batch_size
-                total_tokens = avg_tokens * batch_size
-                tokens_per_second = total_tokens / total_time
-
-                length_results[f"batch_{batch_size}"] = {
-                    "total_time_ms": total_time * 1000,
-                    "time_per_text_ms": time_per_text * 1000,
-                    "texts_per_second": texts_per_second,
-                    "tokens_per_second": tokens_per_second,
-                    "avg_tokens_per_text": avg_tokens,
-                    "embedding_shape": embeddings.shape
-                    if hasattr(embeddings, "shape")
-                    else f"({len(embeddings)}, {len(embeddings[0]) if embeddings else 0})",
-                }
-
-            speed_results[text_length] = length_results
-
-        self.results["speed_benchmarks"] = speed_results
-        return speed_results
-
-    def benchmark_memory_scaling(self, batch_sizes: list[int] | None = None) -> dict[str, Any]:
-        """Benchmark memory usage across batch sizes."""
-        if batch_sizes is None:
-            batch_sizes = [1, 8, 16, 32, 64, 128, 256]
-        logger.info("💾 Benchmarking memory scaling...")
-
-        if self.model is None:
-            self.load_model()
-
-        if self.model is None:
-            msg = "Failed to load model"
-            raise RuntimeError(msg)
-
-        memory_results: dict[str, Any] = {}
-        texts = BENCHMARK_TEXTS["medium"]
-
-        baseline_memory = 0
-        if self.device == "cuda":
-            torch.cuda.empty_cache()
-            baseline_memory = torch.cuda.memory_allocated()
-
-        for batch_size in batch_sizes:
-            if batch_size > len(texts):
-                continue
-
-            logger.info(f"  📊 Testing batch size: {batch_size}")
-
-            # Clear cache
-            if self.device == "cuda":
-                torch.cuda.empty_cache()
-            gc.collect()
-
-            batch_texts = texts[:batch_size]
-
-            # Measure memory before
-            if self.device == "cuda":
-                torch.cuda.memory_allocated()
-
-            # Run inference
-            try:
-                embeddings = self.model.encode(
-                    batch_texts,
-                    convert_to_tensor=self.device == "cuda",
-                    show_progress_bar=False,
-                )
-
-                # Measure memory after
-                memory_after = 0
-                if self.device == "cuda":
-                    memory_after = torch.cuda.max_memory_allocated()
-                    torch.cuda.reset_peak_memory_stats()
-
-                memory_used_mb = (memory_after - baseline_memory) / (1024 * 1024)
-                memory_per_text_mb = memory_used_mb / batch_size if batch_size > 0 else 0
-
-                memory_results[f"batch_{batch_size}"] = {
-                    "memory_used_mb": memory_used_mb,
-                    "memory_per_text_mb": memory_per_text_mb,
-                    "baseline_memory_mb": baseline_memory / (1024 * 1024),
-                    "peak_memory_mb": memory_after / (1024 * 1024),
-                }
-
-                # Clean up
-                del embeddings
-
-            except torch.cuda.OutOfMemoryError:
-                logger.warning(f"❌ OOM at batch size {batch_size}")
-                memory_results[f"batch_{batch_size}"] = {"oom": True}
-                break
-            except Exception as e:
-                logger.warning(f"❌ Error at batch size {batch_size}: {e}")
-                memory_results[f"batch_{batch_size}"] = {"error": str(e)}
-
-        self.results["memory_benchmarks"] = memory_results
-        return memory_results
-
-    def benchmark_cpu_vs_gpu(self) -> dict[str, Any]:
-        """Compare CPU vs GPU performance."""
-        logger.info("🖥️ Benchmarking CPU vs GPU performance...")
-
-        comparison_results = {}
-        test_texts = BENCHMARK_TEXTS["medium"][:32]  # Fixed batch size
-
-        devices = ["cpu"]
-        if torch.cuda.is_available():
-            devices.append("cuda")
-
-        for device in devices:
-            logger.info(f"  🔄 Testing on {device}")
-
-            # Load model on device
-            try:
-                model = SentenceTransformer(self.model_path, device=device)
-
-                # Warmup
-                _ = model.encode(test_texts[:2], convert_to_tensor=False)
-
-                # Benchmark
-                start_time = time.perf_counter()
-                embeddings = model.encode(test_texts, convert_to_tensor=False, show_progress_bar=False)
-                end_time = time.perf_counter()
-
-                total_time = end_time - start_time
-
-                comparison_results[device] = {
-                    "total_time_ms": total_time * 1000,
-                    "texts_per_second": len(test_texts) / total_time,
-                    "time_per_text_ms": (total_time / len(test_texts)) * 1000,
-                    "embedding_shape": embeddings.shape
-                    if hasattr(embeddings, "shape")
-                    else f"({len(embeddings)}, {len(embeddings[0]) if embeddings else 0})",
-                }
-
-                del model
-                if device == "cuda":
-                    torch.cuda.empty_cache()
-
-            except Exception as e:
-                logger.warning(f"❌ Failed on {device}: {e}")
-                comparison_results[device] = {"error": str(e)}
-
-        self.results["cpu_vs_gpu"] = comparison_results
-        return comparison_results
-
-    def run_comprehensive_benchmark(self) -> dict[str, Any]:
-        """Run all benchmarks and return comprehensive results."""
-        logger.info(f"🚀 Starting comprehensive benchmark for {self.model_name}")
-
-        # Load model
-        self.load_model()
-
-        # Run all benchmarks
-        self.measure_model_size()
-        self.benchmark_inference_speed()
-        self.benchmark_memory_scaling()
-        self.benchmark_cpu_vs_gpu()
-
-        # Add metadata
-        self.results["model_name"] = self.model_name
-        self.results["model_path"] = self.model_path
-        self.results["device"] = self.device
-        self.results["torch_version"] = torch.__version__
-        self.results["cuda_available"] = torch.cuda.is_available()
-
-        if torch.cuda.is_available():
-            self.results["gpu_name"] = torch.cuda.get_device_name(0)
-            self.results["gpu_memory_gb"] = torch.cuda.get_device_properties(0).total_memory / (1024**3)
-
-        # System info
-        self.results["cpu_count"] = psutil.cpu_count()
-        self.results["ram_gb"] = psutil.virtual_memory().total / (1024**3)
-
-        logger.info("✅ Comprehensive benchmark completed!")
-        return self.results
-
-    def save_results(self, output_file: str) -> None:
-        """Save benchmark results to JSON file."""
-        output_path = Path(output_file)
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-
-        with output_path.open("w") as f:
-            json.dump(self.results, f, indent=2, default=str)
-
-        logger.info(f"📄 Results saved to {output_path}")
-
-    def print_summary(self) -> None:
-        """Print a summary of benchmark results."""
-        if not self.results:
-            logger.warning("No results to summarize")
-            return
-
-        print(f"\n{'=' * 60}")
-        print(f"Performance Benchmark Summary: {self.model_name}")
-        print(f"{'=' * 60}")
-
-        # Model size
-        if "size_metrics" in self.results:
-            size = self.results["size_metrics"]
-            print("\n📏 Model Size:")
-            print(f"  Disk Size: {size.get('disk_size_mb', 0):.1f} MB")
-            if "parameters_millions" in size:
-                print(f"  Parameters: {size['parameters_millions']:.1f}M")
-            if "embedding_dim" in size:
-                print(f"  Embedding Dim: {size['embedding_dim']}")
-
-        # Speed summary
-        if "speed_benchmarks" in self.results:
-            speed = self.results["speed_benchmarks"]
-            print("\n⚡ Speed (medium texts, batch 32):")
-            if "medium" in speed and "batch_32" in speed["medium"]:
-                batch_32 = speed["medium"]["batch_32"]
-                print(f"  Throughput: {batch_32['texts_per_second']:.1f} texts/sec")
-                print(f"  Latency: {batch_32['time_per_text_ms']:.1f} ms/text")
-                print(f"  Token Speed: {batch_32['tokens_per_second']:.0f} tokens/sec")
-
-        # CPU vs GPU
-        if "cpu_vs_gpu" in self.results:
-            comparison = self.results["cpu_vs_gpu"]
-            print("\n🖥️ CPU vs GPU:")
-            for device, metrics in comparison.items():
-                if "error" not in metrics:
-                    print(f"  {device.upper()}: {metrics['texts_per_second']:.1f} texts/sec")
-
-        print()
-
-
-def run_benchmark(
-    model_path: str | list[str],
-    model_name: str | None = None,
-    output: str = "benchmark_results.json",
-    quick: bool = False,
-    compare_models: list[str] | None = None,
-) -> None:
-    """Run benchmark for one or multiple models with comparison."""
-    # Handle both single model and multiple models
-    models_to_benchmark = [model_path] if isinstance(model_path, str) else model_path
-
-    if compare_models:
-        models_to_benchmark.extend(compare_models)
-
-    all_results = []
-
-    for i, model in enumerate(models_to_benchmark):
-        current_model_name = model_name if i == 0 else Path(model).name
-
-        print(f"\n{'=' * 60}")
-        print(f"Benchmarking Model {i + 1}/{len(models_to_benchmark)}: {current_model_name}")
-        print(f"{'=' * 60}")
-
-        try:
-            benchmarker = PerformanceBenchmark(model, current_model_name)
-
-            if quick:
-                # Quick benchmark
-                benchmarker.load_model()
-                benchmarker.measure_model_size()
-                benchmarker.benchmark_inference_speed([1, 16, 32])
-            else:
-                # Comprehensive benchmark
-                benchmarker.run_comprehensive_benchmark()
-
-            all_results.append(benchmarker.results)
-            benchmarker.print_summary()
-
-        except Exception:
-            logger.exception(f"❌ Failed to benchmark {current_model_name}")
-            continue
-
-    # Save individual results
-    output_dir = Path(output).parent if Path(output).suffix else Path(output)
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    for results in all_results:
-        model_name_safe = "".join(c for c in results["model_name"] if c.isalnum() or c in ("-", "_", "."))
-        output_path = output_dir / f"benchmark_{model_name_safe}.json"
-
-        with output_path.open("w") as f:
-            json.dump(results, f, indent=2, default=str)
-
-        logger.info(f"📄 Results saved to {output_path}")
-
-    # Create comparison if multiple models
-    if len(all_results) > 1:
-        create_benchmark_comparison(all_results, str(output_dir / "benchmark_comparison.json"))
-
-    print(f"\n✅ Benchmark complete! Results saved to {output_dir}")
-
-
-def create_benchmark_comparison(all_results: list[dict[str, Any]], output_path: str) -> None:
-    """Create a comparison report for multiple benchmark results."""
-    print(f"\n{'=' * 80}")
-    print("Performance Benchmark Comparison")
-    print(f"{'=' * 80}")
-
-    comparison_data = []
-
-    for results in all_results:
-        model_name = results.get("model_name", "Unknown")
-        size_metrics = results.get("size_metrics", {})
-        speed_benchmarks = results.get("speed_benchmarks", {})
-        cpu_vs_gpu = results.get("cpu_vs_gpu", {})
-
-        # Extract key metrics
-        row = {
-            "Model": model_name,
-            "Disk Size (MB)": size_metrics.get("disk_size_mb", 0),
-            "Parameters (M)": size_metrics.get("parameters_millions", 0),
-            "Embedding Dim": size_metrics.get("embedding_dim", 0),
-        }
-
-        # Speed metrics (medium texts, batch 32)
-        if "medium" in speed_benchmarks and "batch_32" in speed_benchmarks["medium"]:
-            batch_32 = speed_benchmarks["medium"]["batch_32"]
-            row.update(
-                {
-                    "Throughput (texts/sec)": batch_32.get("texts_per_second", 0),
-                    "Latency (ms/text)": batch_32.get("time_per_text_ms", 0),
-                    "Token Speed (tokens/sec)": batch_32.get("tokens_per_second", 0),
-                }
-            )
-
-        # CPU vs GPU comparison
-        for device in ["cpu", "cuda"]:
-            if device in cpu_vs_gpu and "error" not in cpu_vs_gpu[device]:
-                row[f"{device.upper()} Speed (texts/sec)"] = cpu_vs_gpu[device].get("texts_per_second", 0)
-
-        comparison_data.append(row)
-
-    # Create DataFrame and save
-    df = pd.DataFrame(comparison_data)
-
-    # Sort by throughput (descending)
-    if "Throughput (texts/sec)" in df.columns:
-        df = df.sort_values("Throughput (texts/sec)", ascending=False)
-
-    # Print comparison table
-    print(df.to_string(index=False, float_format="%.2f"))
-
-    # Save comparison results
-    comparison_summary = {
-        "comparison_table": df.to_dict(orient="records"),
-        "summary": {
-            "fastest_model": df.iloc[0]["Model"] if len(df) > 0 else None,
-            "smallest_model": df.loc[df["Disk Size (MB)"].idxmin()]["Model"] if len(df) > 0 else None,
-            "most_efficient": df.loc[df["Throughput (texts/sec)"].idxmax()]["Model"]
-            if "Throughput (texts/sec)" in df.columns and len(df) > 0
-            else None,
-        },
-        "timestamp": time.time(),
-    }
-
-    with Path(output_path).open("w") as f:
-        json.dump(comparison_summary, f, indent=2, default=str)
-
-    print(f"\n📊 Comparison saved to {output_path}")
-
-
-def save_benchmark_results(
-    results: dict[str, Any],
-    output_dir: str,
-    model_name: str,
-    volume_results_dir: Path | None = None,
-) -> None:
-    """Save benchmark results to JSON file with Beam volume support."""
-    # Save to Beam volume if available
-    if volume_results_dir:
-        volume_output_path = volume_results_dir / f"benchmark_{model_name}.json"
-        try:
-            with volume_output_path.open("w") as f:
-                json.dump(results, f, indent=2, default=str)
-            logger.info(f"💾 Results saved to Beam volume: {volume_output_path}")
-        except Exception as e:
-            logger.warning(f"⚠️ Failed to save to Beam volume: {e}")
-
-    # Always save local backup
-    output_path = Path(output_dir)
-    output_path.mkdir(parents=True, exist_ok=True)
-
-    # Clean model name for filename
-    safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
-    filename = f"benchmark_{safe_name}.json"
-    filepath = output_path / filename
-
-    with filepath.open("w") as f:
-        json.dump(results, f, indent=2, default=str)
-
-    logger.info(f"📄 Local backup saved to {filepath}")
-
-
-def beam_benchmark_models(
-    models: list[str],
-    quick: bool = False,
-    output_dir: str = DEFAULT_OUTPUT_DIR,
-    volume_name: str = VOLUME_NAME,
-    mount_path: str = VOLUME_PATH,
-) -> list[dict[str, Any]]:
-    """Main benchmarking function for Beam execution with checkpoint support."""
-    logger.info("🚀 Starting Beam-powered performance benchmarking")
-    logger.info(f"📊 Benchmarking {len(models)} models")
-
-    # Initialize Beam utilities
-    volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(volume_name, mount_path)
-
-    # Create benchmark results directory in volume
-    results_dir = Path(mount_path) / BENCHMARK_RESULTS_DIR
-    results_dir.mkdir(parents=True, exist_ok=True)
-
-    logger.info(f"📁 Using Beam volume: {volume_name} at {mount_path}")
-    logger.info(f"💾 Benchmark results directory: {results_dir}")
-
-    all_results = []
-    skipped_models = []
-
-    for model_path in models:
-        model_name = Path(model_path).name if model_path != str(Path(mount_path)) else "gte_qwen2_m2v_code"
-
-        # Check if this model has already been benchmarked (except for trained model)
-        is_trained_model = model_path == str(Path(mount_path)) or model_name == "gte_qwen2_m2v_code"
-
-        if not is_trained_model:
-            # Check for existing benchmark results
-            existing_result_file = results_dir / f"benchmark_{model_name}.json"
-            if existing_result_file.exists():
-                logger.info(f"✅ Model {model_name} already benchmarked - loading existing results")
-                try:
-                    with existing_result_file.open("r") as f:
-                        existing_results = json.load(f)
-                    all_results.append(existing_results)
-                    skipped_models.append(model_name)
-                    continue
-                except Exception as e:
-                    logger.warning(f"⚠️ Failed to load existing results for {model_name}: {e}")
-                    # Continue with benchmarking if loading fails
-
-        logger.info(f"\n{'=' * 60}")
-        logger.info(f"🔍 Benchmarking model: {model_name}")
-        logger.info(f"📂 Path: {model_path}")
-        if is_trained_model:
-            logger.info("🎯 Trained model - always re-benchmark")
-        logger.info(f"{'=' * 60}")
-
-        try:
-            # Distinguish between local paths and HuggingFace model names
-            is_huggingface_model = (
-                "/" in model_path and not model_path.startswith("/") and not Path(model_path).exists()
-            )
-
-            if is_huggingface_model:
-                # This is a HuggingFace model name - pass directly to benchmarker
-                logger.info(f"📥 Loading HuggingFace model: {model_path}")
-                benchmarker = PerformanceBenchmark(
-                    model_path,
-                    model_name,
-                    checkpoint_manager=checkpoint_mgr,
-                    eval_manager=eval_mgr,
-                )
-            else:
-                # This is a local path - check if it exists in Beam volume
-                actual_model_path = model_path  # Default to original path
-                if not Path(model_path).exists() and not model_path.startswith("/"):
-                    # Try to load from Beam volume
-                    local_model_path = Path(mount_path) / model_name
-                    logger.info(f"🔍 Trying to load {model_name} from Beam volume: {local_model_path}")
-                    if local_model_path.exists():
-                        actual_model_path = str(local_model_path)
-                        logger.info(f"✅ Found model in Beam volume: {actual_model_path}")
-                    else:
-                        # Try in root of volume (for your trained model)
-                        root_model_path = Path(mount_path)
-                        if (root_model_path / "config.json").exists():
-                            actual_model_path = str(root_model_path)
-                            logger.info(f"✅ Found model in Beam volume root: {actual_model_path}")
-                        else:
-                            logger.warning(f"⚠️ Model not found locally or in Beam volume: {model_name}")
-                            continue
-
-                benchmarker = PerformanceBenchmark(
-                    actual_model_path,
-                    model_name,
-                    checkpoint_manager=checkpoint_mgr,
-                    eval_manager=eval_mgr,
-                )
-
-            # Run benchmarking
-            if quick:
-                # Quick benchmark
-                benchmarker.load_model()
-                benchmarker.measure_model_size()
-                benchmarker.benchmark_inference_speed([1, 16, 32])
-            else:
-                # Comprehensive benchmark
-                benchmarker.run_comprehensive_benchmark()
-
-            # Save results with Beam support
-            save_benchmark_results(benchmarker.results, output_dir, model_name, results_dir)
-
-            # Print summary
-            benchmarker.print_summary()
-
-            all_results.append(benchmarker.results)
-
-        except Exception:
-            logger.exception(f"❌ Failed to benchmark {model_name}")
-            continue
-
-    # Create comparison report in Beam volume
-    if len(all_results) > 1:
-        comparison_dir = results_dir / "comparisons"
-        comparison_dir.mkdir(parents=True, exist_ok=True)
-        create_benchmark_comparison(all_results, str(comparison_dir / "benchmark_comparison.json"))
-        logger.info(f"📊 Comparison report saved to Beam volume: {comparison_dir}")
-
-    # Log summary of what was done
-    newly_benchmarked = len(all_results) - len(skipped_models)
-    logger.info("\n✅ Beam benchmarking complete!")
-    logger.info(f"📊 Newly benchmarked: {newly_benchmarked} models")
-    logger.info(f"⏭️ Skipped (already done): {len(skipped_models)} models")
-    logger.info(f"📁 Total results: {len(all_results)} models")
-    logger.info(f"💾 Results available in Beam volume: {volume_name}")
-
-    if skipped_models:
-        logger.info(f"⏭️ Skipped models: {', '.join(skipped_models)}")
-
-    return all_results
-
-
-@function(
-    gpu=GPU_NAME,
-    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
-    image=IMAGE,
-    secrets=["HF_ACCESS_TOKEN"],
-    env={
-        "TOKENIZERS_PARALLELISM": "false",
-        "CUDA_LAUNCH_BLOCKING": "0",
-    },
-    timeout=3600 * 4,  # 4 hours for benchmarking all models
-)
-def main() -> None:
-    """Main benchmarking function - runs all default models on Beam."""
-    logger.info("🚀 Starting comprehensive performance benchmarking on Beam")
-
-    # Use default models but replace the local model path with Beam volume path
-    models = DEFAULT_BENCHMARK_MODELS.copy()
-
-    # Replace "gte_qwen2_m2v_code" with actual Beam volume path
-    for i, model in enumerate(models):
-        if model == "gte_qwen2_m2v_code":
-            models[i] = str(Path(VOLUME_PATH))  # Use the Beam volume root
-            logger.info(f"🎯 Using trained model from Beam volume: {models[i]}")
-
-    # Discover simplified distillation models
-    logger.info("🔍 Discovering simplified distillation models...")
-    discovered_models = discover_simplified_models(".")
-
-    # Add discovered models
-    if discovered_models:
-        logger.info(f"✅ Found {len(discovered_models)} simplified models:")
-        for model_path in discovered_models:
-            models.append(model_path)
-            logger.info(f"  📁 {model_path}")
-    else:
-        logger.warning("⚠️ No simplified distillation models found")
-
-    logger.info(f"📊 Benchmarking {len(models)} models:")
-    for i, model in enumerate(models, 1):
-        logger.info(f"  {i}. {model}")
-
-    logger.info("\n💡 Checkpoint Info:")
-    logger.info("  - Already benchmarked models will be skipped")
-    logger.info("  - Your trained model will always be re-benchmarked")
-    logger.info("  - Results are saved persistently to Beam volume")
-
-    # Run comprehensive benchmark using Beam utilities
-    results = beam_benchmark_models(
-        models=models,
-        quick=True,  # Use quick benchmark for efficiency
-        output_dir=str(Path(VOLUME_PATH) / BENCHMARK_RESULTS_DIR),
-        volume_name=VOLUME_NAME,
-        mount_path=VOLUME_PATH,
-    )
-
-    # Print final summary
-    print("\n🎯 Benchmarking Summary:")
-    print(f"📊 Total models processed: {len(results)}")
-    print(f"💾 Results saved to Beam volume: {VOLUME_NAME}")
-    print(f"📁 Directory: {BENCHMARK_RESULTS_DIR}")
-    print("\n🔍 To view analysis:")
-    print("  beam run src.distiller.analyze:beam_analysis")
-    print("\n📈 To run benchmarks again:")
-    print("  distiller benchmark (will skip already completed models)")
|
944 |
-
|
945 |
-
|
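One detail worth keeping from the removed `beam_benchmark_models` above: a model string is routed to the HuggingFace hub only when it contains a slash, is not an absolute path, and does not exist on disk. A minimal standalone sketch of that heuristic (the sample inputs are illustrative):

```python
from pathlib import Path


def is_huggingface_model(model_path: str) -> bool:
    """Treat slash-containing, non-absolute, non-existent paths as hub IDs."""
    return "/" in model_path and not model_path.startswith("/") and not Path(model_path).exists()


# True, assuming no local directory of that name exists:
print(is_huggingface_model("sentence-transformers/all-MiniLM-L6-v2"))
# False: absolute paths are always treated as local models.
print(is_huggingface_model("/models/my_model"))
```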
946 | - def discover_simplified_models(base_path: str = ".") -> list[str]:
947 | -     """
948 | -     Discover all simplified distillation models in the correct directory.
949 | -
950 | -     Looks for directories matching the pattern: ./code_model2vec/final/code_model2vec_*
951 | -     """
952 | -     discovered_models: list[str] = []
953 | -
954 | -     # Look in the correct location where distill_simplified.py saves models
955 | -     models_dir = Path(base_path) / "code_model2vec" / "final"
956 | -
957 | -     if not models_dir.exists():
958 | -         logger.warning(f"Models directory not found: {models_dir}")
959 | -         return discovered_models
960 | -
961 | -     # Look for simplified model directories with the updated pattern
962 | -     pattern = "code_model2vec_*"
963 | -     for model_dir in models_dir.glob(pattern):
964 | -         if model_dir.is_dir() and (model_dir / "config.json").exists():
965 | -             discovered_models.append(str(model_dir))
966 | -             logger.info(f"🔍 Discovered simplified model: {model_dir}")
967 | -
968 | -     # Sort alphabetically for consistent ordering
969 | -     discovered_models.sort()
970 | -
971 | -     return discovered_models
972 | -
973 | -
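The discovery helper above is a plain glob plus a `config.json` sanity check; the following sketch mirrors that logic outside the module (the base path is illustrative):

```python
from pathlib import Path

# Mirror of discover_simplified_models(): every code_model2vec_* directory
# under code_model2vec/final that contains a config.json counts as a model.
models_dir = Path(".") / "code_model2vec" / "final"
discovered = sorted(
    str(d)
    for d in models_dir.glob("code_model2vec_*")
    if d.is_dir() and (d / "config.json").exists()
)
for model_path in discovered:
    print(f"found model: {model_path}")
```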
974 | - @function(
975 | -     gpu=GPU_NAME,
976 | -     volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
977 | -     image=IMAGE,
978 | -     secrets=["HF_ACCESS_TOKEN"],
979 | -     env={
980 | -         "TOKENIZERS_PARALLELISM": "false",
981 | -         "CUDA_LAUNCH_BLOCKING": "0",
982 | -     },
983 | -     timeout=3600 * 3,  # 3 hours for simplified models only
984 | - )
985 | - def benchmark_simplified_only() -> None:
986 | -     """Benchmark only simplified distillation models, skipping 3rd party models."""
987 | -     logger.info("🚀 Starting simplified distillation models benchmarking on Beam")
988 | -     logger.info("⏭️ Skipping 3rd party models - benchmarking only simplified distillation models")
989 | -
990 | -     # Discover simplified distillation models
991 | -     logger.info("🔍 Discovering simplified distillation models...")
992 | -     discovered_models = discover_simplified_models(".")
993 | -
994 | -     if not discovered_models:
995 | -         logger.error("❌ No simplified distillation models found! Run distill-simple first.")
996 | -         return
997 | -
998 | -     logger.info(f"✅ Found {len(discovered_models)} simplified models:")
999 | -     for model_path in discovered_models:
1000 | -         logger.info(f"  📁 {model_path}")
1001 | -
1002 | -     logger.info("\n💡 Checkpoint Info:")
1003 | -     logger.info("  - Already benchmarked models will be skipped")
1004 | -     logger.info("  - Results are saved persistently to Beam volume")
1005 | -
1006 | -     # Run comprehensive benchmark using Beam utilities
1007 | -     results = beam_benchmark_models(
1008 | -         models=discovered_models,
1009 | -         quick=True,  # Use quick benchmark for efficiency
1010 | -         output_dir=str(Path(VOLUME_PATH) / BENCHMARK_RESULTS_DIR),
1011 | -         volume_name=VOLUME_NAME,
1012 | -         mount_path=VOLUME_PATH,
1013 | -     )
1014 | -
1015 | -     # Print final summary
1016 | -     print("\n🎯 Simplified Benchmarking Summary:")
1017 | -     print(f"📊 Total simplified models processed: {len(results)}")
1018 | -     print(f"💾 Results saved to Beam volume: {VOLUME_NAME}")
1019 | -     print(f"📁 Directory: {BENCHMARK_RESULTS_DIR}")
1020 | -     print("⏭️ 3rd party models were skipped")
1021 | -     print("\n🔍 To view analysis:")
1022 | -     print("  distiller analyze")
1023 | -     print("\n📈 To run full benchmarks (including 3rd party):")
1024 | -     print("  distiller benchmark")
1025 | -
1026 | -
1027 | - def run_local_benchmark(
1028 | -     models: list[str] | None = None,
1029 | -     quick: bool = False,
1030 | -     output_dir: str = DEFAULT_OUTPUT_DIR,
1031 | - ) -> list[dict[str, Any]]:
1032 | -     """Main benchmarking function for local execution without Beam utilities."""
1033 | -     logger.info("🖥️ Running performance benchmarking locally")
1034 | -
1035 | -     if models is None:
1036 | -         models = DEFAULT_BENCHMARK_MODELS.copy()
1037 | -
1038 | -     # Replace "gte_qwen2_m2v_code" with a reasonable local path
1039 | -     for i, model in enumerate(models):
1040 | -         if model == "gte_qwen2_m2v_code":
1041 | -             # Look for local trained model
1042 | -             local_model_paths = [
1043 | -                 "./gte_qwen2_m2v_code",
1044 | -                 "./models/gte_qwen2_m2v_code",
1045 | -                 "./output/gte_qwen2_m2v_code",
1046 | -             ]
1047 | -             found = False
1048 | -             for local_path in local_model_paths:
1049 | -                 if Path(local_path).exists():
1050 | -                     models[i] = local_path
1051 | -                     logger.info(f"🎯 Found local trained model: {local_path}")
1052 | -                     found = True
1053 | -                     break
1054 | -             if not found:
1055 | -                 logger.warning("⚠️ Local trained model not found, skipping")
1056 | -                 models.pop(i)
1057 | -
1058 | -     # Discover simplified distillation models
1059 | -     logger.info("🔍 Discovering simplified distillation models...")
1060 | -     discovered_models = discover_simplified_models(".")
1061 | -
1062 | -     # Add discovered models
1063 | -     if discovered_models:
1064 | -         logger.info(f"✅ Found {len(discovered_models)} simplified models:")
1065 | -         for model_path in discovered_models:
1066 | -             models.append(model_path)
1067 | -             logger.info(f"  📁 {model_path}")
1068 | -     else:
1069 | -         logger.warning("⚠️ No simplified distillation models found")
1070 | -
1071 | -     logger.info(f"📊 Benchmarking {len(models)} models")
1072 | -     logger.info(f"📁 Using local output directory: {output_dir}")
1073 | -
1074 | -     # Create local output directory
1075 | -     output_path = Path(output_dir)
1076 | -     output_path.mkdir(parents=True, exist_ok=True)
1077 | -
1078 | -     all_results = []
1079 | -     skipped_models = []
1080 | -
1081 | -     for model_path in models:
1082 | -         model_name = Path(model_path).name
1083 | -
1084 | -         # Check for existing benchmark results locally
1085 | -         safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
1086 | -         result_file = output_path / f"benchmark_{safe_name}.json"
1087 | -
1088 | -         if result_file.exists():
1089 | -             logger.info(f"✅ Model {model_name} already benchmarked - loading existing results")
1090 | -             try:
1091 | -                 with result_file.open("r") as f:
1092 | -                     existing_results = json.load(f)
1093 | -                 all_results.append(existing_results)
1094 | -                 skipped_models.append(model_name)
1095 | -                 continue
1096 | -             except Exception as e:
1097 | -                 logger.warning(f"⚠️ Failed to load existing results for {model_name}: {e}")
1098 | -
1099 | -         logger.info(f"\n{'=' * 60}")
1100 | -         logger.info(f"🔍 Benchmarking model: {model_name}")
1101 | -         logger.info(f"📂 Path: {model_path}")
1102 | -         logger.info(f"{'=' * 60}")
1103 | -
1104 | -         try:
1105 | -             # Create benchmarker without Beam utilities
1106 | -             benchmarker = PerformanceBenchmark(
1107 | -                 model_path,
1108 | -                 model_name,
1109 | -                 checkpoint_manager=None,  # No checkpointing for local benchmarking
1110 | -                 eval_manager=None,
1111 | -             )
1112 | -
1113 | -             # Run benchmarking
1114 | -             if quick:
1115 | -                 # Quick benchmark
1116 | -                 benchmarker.load_model()
1117 | -                 benchmarker.measure_model_size()
1118 | -                 benchmarker.benchmark_inference_speed([1, 16, 32])
1119 | -             else:
1120 | -                 # Comprehensive benchmark
1121 | -                 benchmarker.run_comprehensive_benchmark()
1122 | -
1123 | -             # Save results locally only
1124 | -             save_benchmark_results(benchmarker.results, output_dir, model_name, volume_results_dir=None)
1125 | -
1126 | -             # Print summary
1127 | -             benchmarker.print_summary()
1128 | -
1129 | -             all_results.append(benchmarker.results)
1130 | -
1131 | -         except Exception:
1132 | -             logger.exception(f"❌ Failed to benchmark {model_name}")
1133 | -             continue
1134 | -
1135 | -     # Create comparison report locally
1136 | -     if len(all_results) > 1:
1137 | -         create_benchmark_comparison(all_results, str(output_path / "benchmark_comparison.json"))
1138 | -         logger.info(f"📊 Comparison report saved locally: {output_dir}")
1139 | -
1140 | -     # Log summary
1141 | -     newly_benchmarked = len(all_results) - len(skipped_models)
1142 | -     logger.info("\n✅ Local benchmarking complete!")
1143 | -     logger.info(f"📊 Newly benchmarked: {newly_benchmarked} models")
1144 | -     logger.info(f"⏭️ Skipped (already done): {len(skipped_models)} models")
1145 | -     logger.info(f"📁 Total results: {len(all_results)} models")
1146 | -     logger.info(f"💾 Results available locally: {output_dir}")
1147 | -
1148 | -     if skipped_models:
1149 | -         logger.info(f"⏭️ Skipped models: {', '.join(skipped_models)}")
1150 | -
1151 | -     return all_results
1152 | -
1153 | -
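The resume behavior in `run_local_benchmark` rests on a sanitized result filename; this sketch isolates that cache-lookup step (the function name and paths are illustrative, not part of the module):

```python
import json
from pathlib import Path


def load_cached_result(output_dir: str, model_name: str) -> dict | None:
    """Return previously saved benchmark results, or None to trigger a re-run."""
    safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
    result_file = Path(output_dir) / f"benchmark_{safe_name}.json"
    if not result_file.exists():
        return None
    try:
        with result_file.open("r") as f:
            return json.load(f)
    except Exception:
        return None  # corrupt cache: fall through to a fresh benchmark
```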
1154 | - def run_local_benchmark_simplified(
1155 | -     quick: bool = False,
1156 | -     output_dir: str = DEFAULT_OUTPUT_DIR,
1157 | - ) -> list[dict[str, Any]]:
1158 | -     """Local benchmarking function for simplified models only."""
1159 | -     logger.info("🖥️ Running simplified model benchmarking locally")
1160 | -
1161 | -     # Discover simplified distillation models only
1162 | -     logger.info("🔍 Discovering simplified distillation models...")
1163 | -     discovered_models = discover_simplified_models(".")
1164 | -
1165 | -     if not discovered_models:
1166 | -         logger.error("❌ No simplified distillation models found! Run 'distiller distill-simple' first.")
1167 | -         return []
1168 | -
1169 | -     logger.info(f"✅ Found {len(discovered_models)} simplified models:")
1170 | -     for model_path in discovered_models:
1171 | -         logger.info(f"  📁 {model_path}")
1172 | -
1173 | -     return run_local_benchmark(
1174 | -         models=discovered_models,
1175 | -         quick=quick,
1176 | -         output_dir=output_dir,
1177 | -     )
1178 | -
1179 | -
1180 | - if __name__ == "__main__":
1181 | -     main()
src/distiller/config.py
ADDED
@@ -0,0 +1,339 @@
1 | + """
2 | + Shared configuration for the distiller package.
3 | +
4 | + This module centralizes all configuration constants, default values, and common
5 | + settings used across distillation, evaluation, and benchmarking modules.
6 | + """
7 | +
8 | + import logging
9 | + from pathlib import Path
10 | + from typing import Any
11 | +
12 | + from beam import GpuType, Image
13 | + from pydantic import BaseModel
14 | +
15 | + # =============================================================================
16 | + # LOGGING CONFIGURATION
17 | + # =============================================================================
18 | +
19 | +
20 | + def setup_logging(level: int = logging.INFO) -> None:
21 | +     """Set up consistent logging across the package."""
22 | +     log_dir = Path("logs")
23 | +     log_dir.mkdir(parents=True, exist_ok=True)
24 | +     log_path = log_dir / "distiller.log"
25 | +     logging.basicConfig(
26 | +         level=level,
27 | +         format="%(asctime)s - %(levelname)s - %(message)s",
28 | +         handlers=[logging.StreamHandler(), logging.FileHandler(log_path, mode="a")],
29 | +     )
30 | +
31 | +
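A minimal usage sketch for `setup_logging`; the `distiller.config` import path assumes the package is importable under that name:

```python
import logging

from distiller.config import setup_logging

setup_logging(level=logging.DEBUG)  # writes to stderr and to logs/distiller.log
logger = logging.getLogger(__name__)
logger.info("logging configured")
```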
32 | + # =============================================================================
33 | + # BEAM CLOUD CONFIGURATION
34 | + # =============================================================================
35 | +
36 | + # Beam execution settings
37 | + GPU_NAME = GpuType.A100_40
38 | +
39 | +
40 | + # Volume configurations for different workflows
41 | + class VolumeConfig(BaseModel):
42 | +     """Volume configuration container."""
43 | +
44 | +     name: str
45 | +     mount_path: str
46 | +     description: str = ""
47 | +
48 | +
49 | + # Define volume configurations - code_model2vec is the primary volume for all workflows
50 | + VOLUMES: dict[str, VolumeConfig] = {
51 | +     "primary": VolumeConfig(
52 | +         name="code_model2vec",
53 | +         mount_path="./code_model2vec",
54 | +         description="Primary volume for all distillation models, evaluations, benchmarks, and checkpoints",
55 | +     ),
56 | +     # Legacy volume name mapping for backwards compatibility
57 | +     "simplified": VolumeConfig(
58 | +         name="code_model2vec",
59 | +         mount_path="./code_model2vec",
60 | +         description="Primary volume for all distillation models, evaluations, benchmarks, and checkpoints",
61 | +     ),
62 | + }
63 | +
64 | + # Default volume name for all workflows
65 | + DEFAULT_VOLUME = "primary"
66 | +
67 | + # Beam environment settings
68 | + BEAM_ENV_SETTINGS: dict[str, str] = {
69 | +     "TOKENIZERS_PARALLELISM": "false",
70 | +     "CUDA_LAUNCH_BLOCKING": "0",
71 | +     "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:512",
72 | +     "TORCH_CUDNN_V8_API_ENABLED": "1",
73 | + }
74 | +
75 | + # Common Python packages for Beam images
76 | + COMMON_PACKAGES: list[str] = [
77 | +     "torch>=2.7.0",
78 | +     "transformers>=4.40.0",
79 | +     "datasets>=3.2.0",
80 | +     "sentence-transformers>=4.1.0",
81 | +     "model2vec[train]>=0.5.0",
82 | +     "numpy>=1.26.4",
83 | +     "scikit-learn>=1.6.1",
84 | +     "pandas>=2.0.0",
85 | +     "tqdm>=4.65.0",
86 | +     "plotly>=5.0.0",
87 | +     "matplotlib>=3.7.0",
88 | +     "seaborn>=0.12.0",
89 | + ]
90 | +
91 | + # Create common Beam image
92 | + IMAGE = Image(python_version="python3.12").add_python_packages(COMMON_PACKAGES)
93 | +
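These constants are meant to be consumed together when declaring Beam functions, as the benchmark and distillation modules in this commit do. A sketch of the wiring (the function body is illustrative; `function` and `Volume` come from the `beam` SDK used elsewhere in the package):

```python
from beam import Volume, function

from distiller.config import BEAM_ENV_SETTINGS, GPU_NAME, IMAGE, get_volume_config

volume = get_volume_config()


@function(
    gpu=GPU_NAME,
    image=IMAGE,
    volumes=[Volume(name=volume.name, mount_path=volume.mount_path)],
    env=BEAM_ENV_SETTINGS,
)
def smoke_test() -> str:
    # Illustrative body: confirm the shared volume is mounted.
    return f"volume mounted at {volume.mount_path}"
```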
94 | + # =============================================================================
95 | + # MODEL CONFIGURATION
96 | + # =============================================================================
97 | +
98 | + # Teacher model configurations
99 | + TEACHER_MODELS: list[str] = [
100 | +     "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
101 | +     "BAAI/bge-m3",
102 | +     "jinaai/jina-embeddings-v3",
103 | +     "lightonai/Reason-ModernColBERT",
104 | +     "Linq-AI-Research/Linq-Embed-Mistral",
105 | +     "microsoft/codebert-base",
106 | +     "microsoft/graphcodebert-base",
107 | +     "nomic-ai/nomic-embed-text-v2-moe",
108 | +     "Qodo/Qodo-Embed-1-1.5B",
109 | +     "sentence-transformers/all-MiniLM-L6-v2",
110 | +     "sentence-transformers/all-mpnet-base-v2",
111 | +     "sentence-transformers/paraphrase-MiniLM-L6-v2",
112 | +     "nomic-ai/nomic-embed-code",
113 | +     "nomic-ai/CodeRankEmbed",
114 | + ]
115 | +
116 | + # Default evaluation models for comparison
117 | + DEFAULT_EVALUATION_MODELS: list[str] = [
118 | +     "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
119 | +     "BAAI/bge-m3",
120 | +     "huggingface/CodeBERTa-small-v1",
121 | +     "jinaai/jina-embeddings-v3",
122 | +     "lightonai/Reason-ModernColBERT",
123 | +     "Linq-AI-Research/Linq-Embed-Mistral",
124 | +     "microsoft/codebert-base",
125 | +     "microsoft/graphcodebert-base",
126 | +     "minishlab/potion-base-8M",
127 | +     "minishlab/potion-retrieval-32M",
128 | +     "nomic-ai/nomic-embed-text-v2-moe",
129 | +     "Qodo/Qodo-Embed-1-1.5B",
130 | +     "Salesforce/codet5-base",
131 | +     "sentence-transformers/all-MiniLM-L12-v2",
132 | +     "sentence-transformers/all-MiniLM-L6-v2",
133 | +     "sentence-transformers/all-mpnet-base-v2",
134 | +     "sentence-transformers/paraphrase-MiniLM-L6-v2",
135 | +     "nvidia/NV-Embed-v2",
136 | +     "nomic-ai/nomic-embed-code",
137 | +     "nomic-ai/CodeRankEmbed",
138 | + ]
139 | +
140 | +
141 | + # Model2Vec distillation parameters
142 | + class DistillationConfig(BaseModel):
143 | +     """Configuration for Model2Vec distillation parameters."""
144 | +
145 | +     # Teacher models for distillation
146 | +     code_teacher_models: list[str] = TEACHER_MODELS
147 | +
148 | +     # Basic distillation parameters
149 | +     optimal_pca_dims: int = 256
150 | +     sif_coefficient: float = 1e-3
151 | +     apply_zipf: bool = True
152 | +
153 | +     # Training parameters (used when --train flag is enabled)
154 | +     training_epochs: int = 2
155 | +     learning_rate: float = 1e-4
156 | +     batch_size: int = 32
157 | +     max_training_samples: int = 50000
158 | +     teacher_model_config: dict[str, Any] = {}
159 | +
160 | +
161 | + distillation_config = DistillationConfig()
162 | +
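Because the settings are Pydantic models, a run can derive a variant without mutating the shared module-level default. A sketch assuming Pydantic v2's `model_copy` (the override values are illustrative):

```python
from distiller.config import distillation_config

# Derive a configuration for a quick experiment without touching the default.
fast_run = distillation_config.model_copy(
    update={"training_epochs": 1, "max_training_samples": 5000}
)
print(fast_run.optimal_pca_dims)  # 256, inherited from the default
print(fast_run.training_epochs)   # 1, overridden for this run
```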
163 | +
164 | + # =============================================================================
165 | + # DATASET CONFIGURATION
166 | + # =============================================================================
167 | +
168 | +
169 | + # Add a LanguagesConfig Pydantic model
170 | + class LanguagesConfig(BaseModel):
171 | +     """Configuration for languages used in evaluation."""
172 | +
173 | +     all: list[str] = [
174 | +         "python",
175 | +         "java",
176 | +         "javascript",
177 | +         "php",
178 | +         "ruby",
179 | +         "go",
180 | +     ]
181 | +
182 | +
183 | + languages_config = LanguagesConfig()
184 | +
185 | +
186 | + # Update CodeSearchNetConfig to use languages_config.all as the default for evaluation_languages
187 | + class CodeSearchNetConfig(BaseModel):
188 | +     """Configuration for CodeSearchNet evaluation settings."""
189 | +
190 | +     dataset_name: str = "code_search_net"
191 | +     evaluation_languages: list[str] = languages_config.all
192 | +     max_queries_per_language: int = 1000
193 | +     similarity_threshold: float = 0.7
194 | +     evaluation_metrics: list[str] = ["ndcg@1", "ndcg@5", "ndcg@10", "mrr", "recall@1", "recall@5", "recall@10"]
195 | +
196 | +
197 | + codesearchnet_config = CodeSearchNetConfig()
198 | +
199 | + # Training dataset configurations
200 | + TRAINING_DATASETS: dict[str, str] = {
201 | +     "codesearchnet": "sentence-transformers/codesearchnet",
202 | +     "code_search_net": "code_search_net",
203 | + }
204 | +
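A small sketch of how the evaluation settings above would typically be consumed, one CodeSearchNet language at a time (the loop body is illustrative):

```python
from distiller.config import codesearchnet_config

for lang in codesearchnet_config.evaluation_languages:
    # Cap the number of queries per language as configured above.
    limit = codesearchnet_config.max_queries_per_language
    print(f"evaluating {codesearchnet_config.dataset_name} [{lang}] with up to {limit} queries")
```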
205 | + # =============================================================================
206 | + # OUTPUT DIRECTORY CONFIGURATION
207 | + # =============================================================================
208 | +
209 | +
210 | + # Standardized directory structure within code_model2vec
211 | + class StandardDirectories(BaseModel):
212 | +     """Standardized directory structure for code_model2vec workspace."""
213 | +
214 | +     # Root directory
215 | +     root: str = "code_model2vec"
216 | +
217 | +     # Model directories
218 | +     base: str = "code_model2vec/base"  # Basic distilled models
219 | +     final: str = "code_model2vec/final"  # Final trained models
220 | +     models: str = "code_model2vec/models"  # Legacy/alternative models location
221 | +
222 | +     # Results directories
223 | +     evaluation_results: str = "code_model2vec/evaluation_results"
224 | +     benchmark_results: str = "code_model2vec/benchmark_results"
225 | +     analysis_results: str = "code_model2vec/analysis_results"
226 | +
227 | +     # Working directories
228 | +     checkpoints: str = "code_model2vec/checkpoints"
229 | +     cache: str = "code_model2vec/cache"
230 | +     temp: str = "code_model2vec/temp"
231 | +
232 | +
233 | + # Create global instance
234 | + directories = StandardDirectories()
235 | +
236 | +
237 | + # Legacy OutputDirs for backwards compatibility
238 | + class OutputDirs(BaseModel):
239 | +     """Base output directory structure for storing models, checkpoints, and results."""
240 | +
241 | +     base: str = "base"
242 | +     models: str = "final"
243 | +     checkpoints: str = "checkpoints"
244 | +     evaluation_results: str = "evaluation_results"
245 | +     benchmark_results: str = "benchmark_results"
246 | +     analysis_results: str = "analysis_results"
247 | +     cache: str = "cache"
248 | +
249 | +
250 | + output_dirs = OutputDirs()
251 | +
252 | +
253 | + # File naming patterns
254 | + class FilenamePatterns(BaseModel):
255 | +     """File naming patterns for evaluation, benchmark, checkpoint, and model files."""
256 | +
257 | +     evaluation: str = "codesearchnet_eval_{model_name}.json"
258 | +     benchmark: str = "benchmark_{model_name}.json"
259 | +     checkpoint: str = "checkpoints_{stage}_step_{step}.json"
260 | +     model: str = "{teacher_model}_{dims}d"
261 | +
262 | +
263 | + filename_patterns = FilenamePatterns()
264 | +
265 | + # =============================================================================
266 | + # ANALYSIS AND VISUALIZATION
267 | + # =============================================================================
268 | +
269 | +
270 | + # Chart configuration
271 | + class ChartConfig(BaseModel):
272 | +     """Chart configuration for analysis and visualization."""
273 | +
274 | +     figsize: tuple[int, int] = (12, 8)
275 | +     dpi: int = 300
276 | +     style: str = "whitegrid"
277 | +     color_palette: str = "Set2"
278 | +     save_formats: list[str] = ["png", "pdf"]
279 | +
280 | +
281 | + chart_config = ChartConfig()
282 | +
283 | +
284 | + # Performance thresholds for analysis
285 | + class PerformanceThresholds(BaseModel):
286 | +     """Performance thresholds for analysis results."""
287 | +
288 | +     excellent: float = 0.7
289 | +     good: float = 0.5
290 | +     fair: float = 0.3
291 | +     poor: float = 0.1
292 | +
293 | +
294 | + performance_thresholds = PerformanceThresholds()
295 | +
296 | + # =============================================================================
297 | + # HELPER FUNCTIONS
298 | + # =============================================================================
299 | +
300 | +
301 | + def get_volume_config() -> VolumeConfig:
302 | +     """Get volume configuration for any workflow - always returns the primary code_model2vec volume."""
303 | +     return VOLUMES["primary"]
304 | +
305 | +
306 | + def get_output_path(base_path: str | Path, output_type: str) -> Path:
307 | +     """Get standardized output path for different types of outputs."""
308 | +     base = Path(base_path)
309 | +     if hasattr(output_dirs, output_type):
310 | +         return base / getattr(output_dirs, output_type)
311 | +     return base / output_type
312 | +
313 | +
314 | + def get_standard_directory(dir_type: str) -> str:
315 | +     """Get standardized directory path for any directory type."""
316 | +     if hasattr(directories, dir_type):
317 | +         return getattr(directories, dir_type)
318 | +     # Default to relative path within code_model2vec
319 | +     return f"code_model2vec/{dir_type}"
320 | +
321 | +
322 | + def ensure_checkpoint_directory(stage: str) -> str:
323 | +     """Ensure checkpoint directory exists for a specific stage and return the path."""
324 | +     checkpoint_dir = f"{directories.checkpoints}/{stage}"
325 | +     Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
326 | +     return checkpoint_dir
327 | +
328 | +
329 | + def format_filename(pattern_key: str, **kwargs: Any) -> str:
330 | +     """Format filename using predefined patterns."""
331 | +     if hasattr(filename_patterns, pattern_key):
332 | +         return getattr(filename_patterns, pattern_key).format(**kwargs)
333 | +     msg = f"Unknown filename pattern: {pattern_key}"
334 | +     raise ValueError(msg)
335 | +
336 | +
337 | + def get_safe_model_name(model_name: str) -> str:
338 | +     """Convert model name to filesystem-safe name."""
339 | +     return "".join(c for c in model_name if c.isalnum() or c in ("-", "_", ".")).replace("/", "_")
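Taken together, the helpers give call sites one uniform way to build names and paths. A usage sketch (the model name is illustrative; note that `get_safe_model_name` filters out `/` before the `replace` call, so the slash is dropped rather than replaced):

```python
from distiller.config import (
    ensure_checkpoint_directory,
    format_filename,
    get_safe_model_name,
)

safe = get_safe_model_name("sentence-transformers/all-MiniLM-L6-v2")
eval_file = format_filename("evaluation", model_name=safe)
ckpt_dir = ensure_checkpoint_directory("distillation")
print(eval_file)  # codesearchnet_eval_sentence-transformersall-MiniLM-L6-v2.json
print(ckpt_dir)   # code_model2vec/checkpoints/distillation
```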
src/distiller/distill.py
CHANGED
@@ -1,31 +1,35 @@
|
|
1 |
"""
|
2 |
-
Code-Specialized Model2Vec Distillation Script
|
3 |
|
4 |
-
This script
|
5 |
-
using Model2Vec distillation with
|
6 |
|
7 |
Features:
|
8 |
-
-
|
9 |
-
-
|
10 |
-
-
|
11 |
-
-
|
12 |
-
- Smart
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
17 |
"""
|
18 |
|
19 |
import json
|
20 |
import logging
|
21 |
-
import os
|
22 |
import time
|
23 |
from pathlib import Path
|
24 |
-
from typing import Any
|
25 |
|
26 |
import numpy as np
|
27 |
import torch
|
28 |
-
|
|
|
29 |
from datasets import load_dataset
|
30 |
from model2vec.distill import distill
|
31 |
from model2vec.train.base import FinetunableStaticModel, TextDataset
|
@@ -35,239 +39,474 @@ from torch import nn, optim
|
|
35 |
|
36 |
from .beam_utils import (
|
37 |
BeamCheckpointManager,
|
38 |
-
BeamModelManager,
|
39 |
create_beam_utilities,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
)
|
41 |
|
42 |
# =============================================================================
|
43 |
-
#
|
44 |
# =============================================================================
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
BATCH_SIZE = 32
|
56 |
-
REGULARIZATION_WEIGHT = 0.01
|
57 |
-
|
58 |
-
# CodeSearchNet dataset configuration
|
59 |
-
CODESEARCHNET_DATASET = "sentence-transformers/codesearchnet"
|
60 |
-
MAX_TRAINING_SAMPLES = 50000 # Limit for manageable training time
|
61 |
-
|
62 |
-
# Checkpoint configuration
|
63 |
-
CHECKPOINT_INTERVAL = 1000 # Save every N samples
|
64 |
-
EMBEDDINGS_BATCH_SIZE = 100 # Save embeddings in smaller batches
|
65 |
-
|
66 |
-
# OPTIMIZED TEACHER MODEL CONFIGURATION FOR 40GB VRAM
|
67 |
-
TEACHER_MODEL_CONFIG: dict[str, Any] = {
|
68 |
-
"batch_size": 12, # Slightly reduced due to float32 memory usage
|
69 |
-
"precision": "float32", # Use float32 for quality preservation
|
70 |
-
"max_seq_length": 8192, # Reduce from 32k default for better performance
|
71 |
-
"device_map": "auto", # Automatic device placement
|
72 |
-
"torch_dtype": torch.float32, # Use float32 for quality preservation
|
73 |
-
"trust_remote_code": True,
|
74 |
-
"use_flash_attention": True, # Try to enable flash attention if available
|
75 |
-
"attn_implementation": "flash_attention_2", # Use flash attention 2 if available
|
76 |
-
}
|
77 |
|
78 |
# =============================================================================
|
79 |
-
#
|
80 |
# =============================================================================
|
81 |
|
82 |
-
GPU_NAME = GpuType.A100_40
|
83 |
-
VOLUME_NAME = "gte_qwen2_m2v_code"
|
84 |
-
VOLUME_PATH = "./gte_qwen2_m2v_code"
|
85 |
-
IMAGE = Image(python_version="python3.12").add_python_packages(
|
86 |
-
[
|
87 |
-
"torch>=2.7.0", # Install torch first
|
88 |
-
"transformers>=4.40.0", # Latest transformers with flash attention support
|
89 |
-
"accelerate>=1.7.0",
|
90 |
-
"datasets>=3.2.0",
|
91 |
-
"model2vec[train]>=0.5.0",
|
92 |
-
"numpy>=1.26.4",
|
93 |
-
"scikit-learn>=1.6.1",
|
94 |
-
"sentence-transformers>=4.1.0",
|
95 |
-
]
|
96 |
-
)
|
97 |
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
100 |
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
"""Generate a hash of current configuration parameters for checkpoint validation."""
|
104 |
import hashlib
|
105 |
|
106 |
config_params = {
|
107 |
-
"
|
108 |
-
"
|
109 |
-
"
|
110 |
-
"
|
111 |
-
"max_samples": MAX_TRAINING_SAMPLES,
|
112 |
-
"codesearchnet_dataset": CODESEARCHNET_DATASET,
|
113 |
}
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
config_str = str(sorted(config_params.items()))
|
116 |
return hashlib.md5(config_str.encode()).hexdigest()[:12] # noqa: S324
|
117 |
|
118 |
|
119 |
-
def
|
120 |
-
"""
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
-
|
124 |
-
|
|
|
125 |
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
-
|
133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
return False
|
135 |
|
136 |
-
# Additional validation checks
|
137 |
-
checkpoint_config = checkpoint_data.get("config", {})
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
142 |
return False
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
return False
|
149 |
|
150 |
-
|
151 |
-
logger.warning(
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
return False
|
155 |
|
156 |
-
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
|
160 |
-
def
|
|
|
|
|
|
|
|
|
|
|
161 |
"""
|
162 |
-
|
163 |
|
164 |
Args:
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
168 |
|
169 |
Returns:
|
170 |
-
|
171 |
"""
|
172 |
-
|
173 |
-
|
174 |
-
"config": {
|
175 |
-
"model_name": MODEL_NAME,
|
176 |
-
"pca_dims": PCA_DIMS,
|
177 |
-
"precision": TEACHER_MODEL_CONFIG["precision"],
|
178 |
-
"torch_dtype": str(TEACHER_MODEL_CONFIG["torch_dtype"]),
|
179 |
-
"max_samples": MAX_TRAINING_SAMPLES,
|
180 |
-
"codesearchnet_dataset": CODESEARCHNET_DATASET,
|
181 |
-
},
|
182 |
-
"stage": stage,
|
183 |
-
"step": step,
|
184 |
-
"timestamp": time.time(),
|
185 |
-
"data": data,
|
186 |
-
}
|
187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
-
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
checkpoint_manager: BeamCheckpointManager | None = None,
|
192 |
) -> list[str]:
|
193 |
-
"""Load and format the
|
194 |
-
|
|
|
|
|
|
|
195 |
logger.info(f"Limiting to {max_samples} samples for training efficiency")
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
-
# Check for existing dataset checkpoint with validation
|
198 |
if checkpoint_manager:
|
199 |
checkpoint_data = checkpoint_manager.load_checkpoint("dataset", 0)
|
200 |
if checkpoint_data:
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
else:
|
209 |
-
logger.warning("🔄 Incompatible dataset checkpoint found, starting fresh")
|
210 |
-
# Clean up incompatible checkpoint
|
211 |
-
checkpoint_manager.cleanup_old_checkpoints("dataset", keep_latest=0)
|
212 |
-
texts = []
|
213 |
-
start_from = 0
|
214 |
-
else:
|
215 |
-
texts = []
|
216 |
-
start_from = 0
|
217 |
-
else:
|
218 |
-
texts = []
|
219 |
-
start_from = 0
|
220 |
|
221 |
try:
|
222 |
-
#
|
223 |
-
|
|
|
|
|
224 |
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
try:
|
229 |
-
next(dataset_iter)
|
230 |
-
except StopIteration:
|
231 |
-
break
|
232 |
|
233 |
-
|
234 |
-
|
|
|
|
|
|
|
|
|
235 |
break
|
236 |
|
237 |
-
|
238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0)
|
252 |
-
logger.info(f"💾 Saved dataset checkpoint: {len(texts)} texts collected")
|
253 |
|
254 |
-
|
255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
|
257 |
# Final checkpoint save
|
258 |
if checkpoint_manager:
|
259 |
-
checkpoint_data =
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0)
|
261 |
|
262 |
-
logger.info(f"Successfully loaded {len(
|
263 |
-
return
|
264 |
|
265 |
except Exception:
|
266 |
logger.exception("Error loading CodeSearchNet dataset")
|
267 |
return texts # Return what we have so far
|
268 |
|
269 |
|
270 |
-
def
|
271 |
teacher_model: SentenceTransformer,
|
272 |
texts: list[str],
|
273 |
checkpoint_manager: BeamCheckpointManager | None = None,
|
@@ -275,13 +514,11 @@ def generate_teacher_embeddings_with_checkpoints(
|
|
275 |
"""Generate teacher embeddings for code training with checkpoint support."""
|
276 |
logger.info(f"Generating teacher embeddings for {len(texts)} texts...")
|
277 |
|
278 |
-
# Check for existing embeddings checkpoint
|
279 |
-
final_embeddings = None
|
280 |
-
|
281 |
if checkpoint_manager:
|
282 |
-
|
283 |
-
embeddings_path =
|
284 |
-
config_path =
|
285 |
|
286 |
if embeddings_path.exists() and config_path.exists():
|
287 |
try:
|
@@ -289,118 +526,78 @@ def generate_teacher_embeddings_with_checkpoints(
|
|
289 |
with config_path.open("r") as f:
|
290 |
config_data = json.load(f)
|
291 |
|
292 |
-
|
293 |
-
|
294 |
-
"config_hash": config_data.get("config_hash"),
|
295 |
-
"config": config_data.get("config", {}),
|
296 |
-
}
|
297 |
-
|
298 |
-
if validate_checkpoint_compatibility(checkpoint_data):
|
299 |
# Load the embeddings tensor
|
300 |
final_embeddings = torch.load(embeddings_path, map_location="cpu")
|
301 |
num_expected = config_data.get("num_texts", len(texts))
|
302 |
|
303 |
if final_embeddings.shape[0] >= num_expected:
|
304 |
-
logger.info(
|
305 |
-
|
306 |
-
|
307 |
-
return final_embeddings[: len(texts)] # Return only the needed amount
|
308 |
-
logger.info(
|
309 |
-
f"⚠️ Cached embeddings incomplete ({final_embeddings.shape[0]}/{num_expected}), regenerating"
|
310 |
-
)
|
311 |
-
final_embeddings = None
|
312 |
-
else:
|
313 |
-
logger.warning("🔄 Incompatible embeddings cache found, regenerating")
|
314 |
-
final_embeddings = None
|
315 |
except Exception as e:
|
316 |
logger.warning(f"Failed to load embeddings cache: {e}, regenerating...")
|
317 |
-
final_embeddings = None
|
318 |
-
|
319 |
-
# If we have complete embeddings, return them
|
320 |
-
if final_embeddings is not None:
|
321 |
-
return final_embeddings
|
322 |
|
323 |
# Generate embeddings from scratch
|
324 |
logger.info("Generating fresh teacher embeddings...")
|
325 |
|
326 |
-
|
327 |
-
batch_size_raw = TEACHER_MODEL_CONFIG["batch_size"]
|
328 |
-
current_batch_size: int = batch_size_raw if isinstance(batch_size_raw, int) else 16
|
329 |
-
logger.info(f"Using optimized batch size: {current_batch_size} for 40GB VRAM (7B model)")
|
330 |
-
|
331 |
embeddings_list = []
|
332 |
|
333 |
-
for i in range(0, len(texts),
|
334 |
-
batch_texts = texts[i : i +
|
335 |
|
336 |
try:
|
337 |
-
# Use optimized encoding with convert_to_tensor=True for efficiency
|
338 |
batch_embeddings = teacher_model.encode(
|
339 |
batch_texts,
|
340 |
convert_to_tensor=True,
|
341 |
-
batch_size=
|
342 |
-
show_progress_bar=False,
|
343 |
-
normalize_embeddings=True,
|
344 |
)
|
345 |
embeddings_list.append(batch_embeddings)
|
346 |
|
347 |
-
if i % (
|
348 |
logger.info(f"Generated embeddings for {i + len(batch_texts)}/{len(texts)} texts")
|
349 |
|
350 |
except torch.cuda.OutOfMemoryError:
|
351 |
-
logger.warning(
|
352 |
-
|
353 |
-
)
|
354 |
-
|
355 |
-
# Clear cache and reduce batch size
|
356 |
-
if torch.cuda.is_available():
|
357 |
-
torch.cuda.empty_cache()
|
358 |
-
|
359 |
-
current_batch_size = max(1, current_batch_size // 2)
|
360 |
|
361 |
# Retry with smaller batch size
|
362 |
-
batch_texts = texts[i : i + current_batch_size]
|
363 |
batch_embeddings = teacher_model.encode(
|
364 |
batch_texts,
|
365 |
convert_to_tensor=True,
|
366 |
-
batch_size=
|
367 |
show_progress_bar=False,
|
368 |
normalize_embeddings=True,
|
369 |
)
|
370 |
embeddings_list.append(batch_embeddings)
|
371 |
|
372 |
-
|
373 |
-
|
374 |
-
# Combine all embeddings and force fp32 precision
|
375 |
teacher_embeddings = torch.cat(embeddings_list, dim=0)
|
376 |
|
377 |
-
# Ensure
|
378 |
if teacher_embeddings.dtype != torch.float32:
|
379 |
-
logger.info(f"Converting teacher embeddings from {teacher_embeddings.dtype} to fp32")
|
380 |
teacher_embeddings = teacher_embeddings.to(torch.float32)
|
381 |
|
382 |
logger.info(f"Generated {teacher_embeddings.shape[0]} teacher embeddings in {teacher_embeddings.dtype}")
|
383 |
|
384 |
-
# Save embeddings cache
|
385 |
if checkpoint_manager:
|
386 |
try:
|
387 |
-
|
388 |
-
|
|
|
389 |
|
390 |
# Save embeddings tensor
|
391 |
torch.save(teacher_embeddings, embeddings_path)
|
392 |
|
393 |
# Save configuration
|
394 |
config_data = {
|
395 |
-
"config_hash": get_current_config_hash(),
|
396 |
-
"config": {
|
397 |
-
"model_name": MODEL_NAME,
|
398 |
-
"pca_dims": PCA_DIMS,
|
399 |
-
"precision": TEACHER_MODEL_CONFIG["precision"],
|
400 |
-
"torch_dtype": str(TEACHER_MODEL_CONFIG["torch_dtype"]),
|
401 |
-
"max_samples": MAX_TRAINING_SAMPLES,
|
402 |
-
"codesearchnet_dataset": CODESEARCHNET_DATASET,
|
403 |
-
},
|
404 |
"num_texts": len(texts),
|
405 |
"embedding_shape": list(teacher_embeddings.shape),
|
406 |
"timestamp": time.time(),
|
@@ -417,890 +614,953 @@ def generate_teacher_embeddings_with_checkpoints(
|
|
417 |
return teacher_embeddings
|
418 |
|
419 |
|
420 |
-
def
|
421 |
student_model: Any,
|
422 |
-
|
423 |
-
teacher_embeddings: torch.Tensor,
|
424 |
-
epochs: int = 2,
|
425 |
checkpoint_manager: BeamCheckpointManager | None = None,
|
426 |
-
model_manager: BeamModelManager | None = None,
|
427 |
) -> Any:
|
428 |
-
"""
|
429 |
-
logger.info(
|
430 |
|
431 |
-
#
|
432 |
-
|
433 |
-
logger.error("student_model is None - cannot proceed with code training")
|
434 |
-
msg = "student_model cannot be None"
|
435 |
-
raise ValueError(msg)
|
436 |
|
437 |
-
if not
|
438 |
-
logger.
|
439 |
-
|
440 |
-
raise ValueError(msg)
|
441 |
|
442 |
-
|
443 |
-
|
444 |
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
logger.info("🎯 Enforcing fp32 precision throughout for maximum quality")
|
449 |
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
|
|
458 |
|
459 |
-
|
460 |
-
|
461 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
462 |
|
463 |
-
|
|
|
|
|
464 |
|
465 |
-
|
466 |
-
|
467 |
-
from sklearn.decomposition import PCA
|
468 |
|
469 |
-
|
470 |
-
|
471 |
|
472 |
-
|
473 |
-
|
474 |
-
teacher_embeddings_projected = pca.fit_transform(teacher_embeddings_np)
|
475 |
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
dtype=target_dtype,
|
480 |
-
)
|
481 |
-
logger.info(f"PCA projection completed: {teacher_embeddings.shape} with dtype {target_dtype}")
|
482 |
-
logger.info(
|
483 |
-
f"PCA preserved variance ratio: {pca.explained_variance_ratio_[:5].sum():.4f} (first 5 components)"
|
484 |
-
)
|
485 |
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
)
|
491 |
|
492 |
-
|
493 |
-
|
|
|
494 |
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
actual_model_dtype = param.dtype
|
503 |
-
break
|
504 |
|
505 |
-
|
506 |
-
|
507 |
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
|
|
512 |
|
513 |
-
|
514 |
-
if teacher_embeddings.dtype != target_dtype:
|
515 |
-
logger.warning(f"⚠️ Teacher embeddings not in {target_dtype}: {teacher_embeddings.dtype}")
|
516 |
-
if actual_model_dtype != target_dtype:
|
517 |
-
logger.warning(f"⚠️ Model parameters not in {target_dtype}: {actual_model_dtype}")
|
518 |
|
519 |
-
|
|
|
|
|
|
|
|
|
520 |
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
if tokens.shape[1] > 0:
|
526 |
-
tokenized_texts.append(tokens[0].tolist())
|
527 |
|
528 |
-
|
529 |
-
|
530 |
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
|
535 |
-
|
536 |
-
tokenized_texts, targets, test_size=0.2, random_state=42
|
537 |
-
)
|
538 |
|
539 |
-
logger.info(f"Train
|
540 |
-
logger.info(f"Val targets dtype: {val_targets.dtype}")
|
541 |
|
542 |
-
#
|
543 |
-
|
544 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
|
546 |
-
|
547 |
-
|
|
|
548 |
|
549 |
-
|
550 |
|
551 |
-
try:
|
552 |
-
trainable_model = trainable_model.to(device)
|
553 |
-
logger.info(f"Training on {device}")
|
554 |
-
except torch.cuda.OutOfMemoryError:
|
555 |
-
logger.warning("GPU OOM loading training model, using CPU")
|
556 |
-
device = torch.device("cpu")
|
557 |
-
trainable_model = trainable_model.to(device)
|
558 |
-
if torch.cuda.is_available():
|
559 |
-
torch.cuda.empty_cache()
|
560 |
|
561 |
-
|
562 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
563 |
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
sample_texts = training_texts[: min(5, len(training_texts))]
|
570 |
-
sample_tokens = trainable_model.tokenize(sample_texts)
|
571 |
-
sample_tokens = sample_tokens.to(device)
|
572 |
-
|
573 |
-
_, student_embeddings_before = trainable_model(sample_tokens)
|
574 |
-
sample_teacher_embeddings = targets[: len(sample_texts)].to(device)
|
575 |
-
|
576 |
-
# Compute average cosine similarity
|
577 |
-
similarities_before = []
|
578 |
-
for i in range(len(sample_texts)):
|
579 |
-
sim = torch.cosine_similarity(
|
580 |
-
student_embeddings_before[i].unsqueeze(0),
|
581 |
-
sample_teacher_embeddings[i].unsqueeze(0),
|
582 |
-
).item()
|
583 |
-
similarities_before.append(sim)
|
584 |
-
|
585 |
-
avg_similarity_before = np.mean(similarities_before)
|
586 |
-
logger.info(f"📊 Pre-training average teacher-student similarity: {avg_similarity_before:.4f}")
|
587 |
-
|
588 |
-
# Training loop with validation
|
589 |
-
for epoch in range(epochs):
|
590 |
-
# Training phase
|
591 |
-
trainable_model.train()
|
592 |
-
|
593 |
-
# Try with current batch size, reduce if OOM
|
594 |
-
train_successful = False
|
595 |
-
while not train_successful and adaptive_batch_size >= 1:
|
596 |
-
try:
|
597 |
-
train_loader = train_dataset.to_dataloader(shuffle=True, batch_size=adaptive_batch_size)
|
598 |
-
|
599 |
-
epoch_loss = 0.0
|
600 |
-
num_batches = 0
|
601 |
-
|
602 |
-
for batch_idx, (tokens, targets_batch) in enumerate(train_loader):
|
603 |
-
batch_tokens = tokens.to(device)
|
604 |
-
batch_targets = targets_batch.to(device)
|
605 |
-
|
606 |
-
optimizer.zero_grad()
|
607 |
-
_, student_embeddings = trainable_model(batch_tokens)
|
608 |
-
|
609 |
-
# Debug dtype information on first batch
|
610 |
-
if batch_idx == 0:
|
611 |
-
logger.info(
|
612 |
-
f"Batch {batch_idx}: tokens shape {batch_tokens.shape}, dtype {batch_tokens.dtype}"
|
613 |
-
)
|
614 |
-
logger.info(
|
615 |
-
f"Batch {batch_idx}: targets shape {batch_targets.shape}, dtype {batch_targets.dtype}"
|
616 |
-
)
|
617 |
-
logger.info(
|
618 |
-
f"Batch {batch_idx}: student_embeddings shape {student_embeddings.shape}, dtype {student_embeddings.dtype}"
|
619 |
-
)
|
620 |
-
|
621 |
-
# Force both tensors to fp32 to avoid any precision loss
|
622 |
-
if student_embeddings.dtype != target_dtype:
|
623 |
-
logger.warning(
|
624 |
-
f"Student embeddings not in fp32: {student_embeddings.dtype}, converting to fp32"
|
625 |
-
)
|
626 |
-
student_embeddings = student_embeddings.to(target_dtype)
|
627 |
-
if batch_targets.dtype != target_dtype:
|
628 |
-
logger.info(f"Converting targets from {batch_targets.dtype} to fp32")
|
629 |
-
batch_targets = batch_targets.to(target_dtype)
|
630 |
-
|
631 |
-
try:
|
632 |
-
loss = mse_loss(student_embeddings, batch_targets)
|
633 |
-
loss.backward()
|
634 |
-
optimizer.step()
|
635 |
-
except RuntimeError as e:
|
636 |
-
if "expected scalar type" in str(e):
|
637 |
-
logger.exception("Dtype mismatch error occurred:")
|
638 |
-
logger.exception(
|
639 |
-
f"student_embeddings: {student_embeddings.shape}, {student_embeddings.dtype}"
|
640 |
-
)
|
641 |
-
logger.exception(f"batch_targets: {batch_targets.shape}, {batch_targets.dtype}")
|
642 |
-
logger.exception(
|
643 |
-
f"MSE loss input dtypes: {student_embeddings.dtype} vs {batch_targets.dtype}"
|
644 |
-
)
|
645 |
-
# Force explicit casting to fp32 for maximum precision
|
646 |
-
batch_targets = batch_targets.to(target_dtype)
|
647 |
-
student_embeddings = student_embeddings.to(target_dtype)
|
648 |
-
logger.info("Emergency dtype fix: forced both to fp32")
|
649 |
-
loss = mse_loss(student_embeddings, batch_targets)
|
650 |
-
loss.backward()
|
651 |
-
optimizer.step()
|
652 |
-
else:
|
653 |
-
raise
|
654 |
-
|
655 |
-
epoch_loss += loss.item()
|
656 |
-
num_batches += 1
|
657 |
-
|
658 |
-
# Save training checkpoint periodically
|
659 |
-
if checkpoint_manager and batch_idx % 100 == 0:
|
660 |
-
training_state = {
|
661 |
-
"epoch": epoch,
|
662 |
-
"batch": batch_idx,
|
663 |
-
"model_state": trainable_model.state_dict(),
|
664 |
-
"optimizer_state": optimizer.state_dict(),
|
665 |
-
"loss": epoch_loss / max(1, num_batches),
|
666 |
-
}
|
667 |
-
checkpoint_data = create_checkpoint_data("training", training_state, epoch)
|
668 |
-
checkpoint_manager.save_checkpoint("training", checkpoint_data, epoch)
|
669 |
-
|
670 |
-
train_successful = True
|
671 |
-
|
672 |
-
except torch.cuda.OutOfMemoryError:
|
673 |
-
logger.warning(
|
674 |
-
f"Training OOM with batch size {adaptive_batch_size}, reducing to {adaptive_batch_size // 2}"
|
675 |
-
)
|
676 |
-
adaptive_batch_size = max(1, adaptive_batch_size // 2)
|
677 |
-
if torch.cuda.is_available():
|
678 |
-
torch.cuda.empty_cache()
|
679 |
|
680 |
-
|
681 |
-
|
682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
683 |
|
684 |
-
|
685 |
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
|
|
|
|
|
|
691 |
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
696 |
|
697 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
698 |
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
|
|
|
|
704 |
|
705 |
-
|
706 |
-
val_loss += loss.item()
|
707 |
-
val_batches += 1
|
708 |
|
709 |
-
|
|
|
|
|
710 |
|
711 |
-
|
712 |
-
|
713 |
-
)
|
714 |
|
715 |
-
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
|
|
|
|
|
|
|
|
|
|
723 |
}
|
724 |
-
checkpoint_data = create_checkpoint_data("epoch", epoch_state, epoch + 1)
|
725 |
-
checkpoint_manager.save_checkpoint("epoch", checkpoint_data, epoch + 1)
|
726 |
-
|
727 |
-
# Quality monitoring: compute embedding similarity after training
|
728 |
-
logger.info("🔍 Quality monitoring: Computing post-training teacher-student similarity...")
|
729 |
-
trainable_model.eval()
|
730 |
-
with torch.no_grad():
|
731 |
-
# Use the same sample texts as before
|
732 |
-
sample_texts = training_texts[: min(5, len(training_texts))]
|
733 |
-
sample_tokens = trainable_model.tokenize(sample_texts)
|
734 |
-
sample_tokens = sample_tokens.to(device)
|
735 |
-
|
736 |
-
_, student_embeddings_after = trainable_model(sample_tokens)
|
737 |
-
sample_teacher_embeddings = targets[: len(sample_texts)].to(device)
|
738 |
-
|
739 |
-
# Compute average cosine similarity
|
740 |
-
similarities_after = []
|
741 |
-
for i in range(len(sample_texts)):
|
742 |
-
sim = torch.cosine_similarity(
|
743 |
-
student_embeddings_after[i].unsqueeze(0),
|
744 |
-
sample_teacher_embeddings[i].unsqueeze(0),
|
745 |
-
).item()
|
746 |
-
similarities_after.append(sim)
|
747 |
-
|
748 |
-
avg_similarity_after = np.mean(similarities_after)
|
749 |
-
logger.info(f"📊 Post-training average teacher-student similarity: {avg_similarity_after:.4f}")
|
750 |
-
|
751 |
-
# Quality assessment
|
752 |
-
quality_change = avg_similarity_after - avg_similarity_before
|
753 |
-
logger.info(f"📈 Quality change: {quality_change:+.4f}")
|
754 |
-
|
755 |
-
if abs(quality_change) < 0.01:
|
756 |
-
logger.info("✅ Quality well preserved during training!")
|
757 |
-
elif quality_change > 0:
|
758 |
-
logger.info("✅ Quality improved during training!")
|
759 |
-
else:
|
760 |
-
logger.warning(f"⚠️ Quality degraded by {abs(quality_change):.4f} during training")
|
761 |
|
762 |
-
|
763 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
764 |
|
765 |
-
|
766 |
-
|
767 |
-
# Save to temporary local directory first
|
768 |
-
temp_refined_path = Path("./temp_refined_save")
|
769 |
-
temp_refined_path.mkdir(exist_ok=True)
|
770 |
-
refined_model.save_pretrained(str(temp_refined_path))
|
771 |
|
772 |
-
#
|
773 |
-
|
|
|
774 |
|
775 |
-
#
|
776 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
777 |
|
778 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
779 |
|
780 |
-
|
781 |
|
782 |
-
|
783 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
784 |
|
785 |
except Exception as e:
|
786 |
-
logger.
|
787 |
-
return
|
|
|
|
|
|
|
|
|
|
|
788 |
|
789 |
|
790 |
-
|
791 |
-
|
792 |
-
|
793 |
-
if model is None:
|
794 |
-
logger.error("Cannot apply regularization: model is None")
|
795 |
-
msg = "model cannot be None"
|
796 |
-
raise ValueError(msg)
|
797 |
|
798 |
-
if not hasattr(model, "embedding"):
|
799 |
-
logger.error(f"Cannot apply regularization: model of type {type(model)} does not have 'embedding' attribute")
|
800 |
-
msg = f"model must have 'embedding' attribute, got {type(model)}"
|
801 |
-
raise ValueError(msg)
|
802 |
|
803 |
-
|
804 |
|
819 |
|
820 |
-
|
821 |
-
norms = np.where(norms == 0, 1, norms)
|
822 |
-
norms = np.where(norms > 1e6, 1e6, norms) # Prevent extremely large norms
|
823 |
|
824 |
-
|
825 |
|
826 |
-
# Create new model
|
827 |
-
from model2vec.model import StaticModel
|
828 |
|
836 |
)
|
837 |
|
840 |
|
841 |
except Exception as e:
|
842 |
-
logger.
|
843 |
-
return
|
844 |
|
845 |
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
device: str = "cuda",
|
850 |
-
resume: bool = True,
|
851 |
-
) -> SentenceTransformer:
|
852 |
-
"""Load teacher model with local caching to avoid re-downloading."""
|
853 |
-
cache_dir = Path(output_dir) / "teacher_model_cache"
|
854 |
-
|
855 |
-
# Check if cached model exists
|
856 |
-
if resume and cache_dir.exists():
|
857 |
-
try:
|
858 |
-
logger.info(f"Loading cached teacher model from {cache_dir}")
|
859 |
-
teacher_model = SentenceTransformer(str(cache_dir), device=device)
|
860 |
|
861 |
-
# Set optimized sequence length
|
862 |
-
max_seq_len = TEACHER_MODEL_CONFIG.get("max_seq_length", 8192)
|
863 |
-
if isinstance(max_seq_len, int):
|
864 |
-
teacher_model.max_seq_length = max_seq_len
|
865 |
|
866 |
-
|
867 |
-
|
868 |
-
|
869 |
-
|
870 |
-
|
871 |
|
872 |
-
# Download and cache the model
|
873 |
-
logger.info(f"Downloading teacher model {model_name} (this may take a while)")
|
874 |
|
875 |
-
|
876 |
-
|
877 |
-
|
878 |
-
|
879 |
}
|
880 |
|
881 |
-
|
882 |
-
|
883 |
-
try:
|
884 |
-
model_kwargs["attn_implementation"] = TEACHER_MODEL_CONFIG["attn_implementation"]
|
885 |
-
logger.info("Flash Attention 2 enabled")
|
886 |
-
except Exception as e:
|
887 |
-
logger.warning(f"Flash Attention not available, using default attention: {e}")
|
888 |
|
889 |
-
|
890 |
-
|
891 |
-
|
892 |
-
|
893 |
-
|
894 |
-
|
895 |
-
|
896 |
-
|
897 |
-
|
898 |
-
logger.warning("Flash Attention 2 not available, falling back to default attention")
|
899 |
-
# Remove flash attention from model_kwargs and retry
|
900 |
-
model_kwargs_fallback = {k: v for k, v in model_kwargs.items() if k != "attn_implementation"}
|
901 |
-
teacher_model = SentenceTransformer(
|
902 |
-
model_name,
|
903 |
-
device=device,
|
904 |
-
trust_remote_code=bool(TEACHER_MODEL_CONFIG["trust_remote_code"]),
|
905 |
-
model_kwargs=model_kwargs_fallback,
|
906 |
-
)
|
907 |
-
else:
|
908 |
-
raise
|
909 |
|
910 |
-
|
911 |
-
max_seq_len = TEACHER_MODEL_CONFIG.get("max_seq_length", 8192)
|
912 |
-
if isinstance(max_seq_len, int):
|
913 |
-
teacher_model.max_seq_length = max_seq_len
|
914 |
-
logger.info(f"Set max_seq_length to {max_seq_len} for better performance")
|
915 |
|
916 |
-
|
|
|
|
|
917 |
try:
|
921 |
except Exception as e:
|
922 |
-
logger.warning(f"Failed to cache
|
923 |
-
|
924 |
|
925 |
-
return
|
926 |
|
927 |
|
928 |
-
def
|
929 |
-
|
930 |
-
output_dir: str
|
931 |
-
pca_dims: int =
|
932 |
-
max_samples: int = MAX_TRAINING_SAMPLES,
|
933 |
-
resume: bool = True,
|
934 |
) -> Any:
|
935 |
-
"""
|
936 |
output_path = Path(output_dir)
|
937 |
output_path.mkdir(parents=True, exist_ok=True)
|
938 |
|
939 |
-
|
940 |
-
|
941 |
-
|
942 |
-
logger.info(f"Starting code-specialized distillation of {model_name}")
|
943 |
-
logger.info(f"Using CodeSearchNet dataset: {CODESEARCHNET_DATASET}")
|
944 |
-
logger.info(f"Resume mode: {resume}")
|
945 |
-
|
946 |
-
# GPU Diagnostics
|
947 |
-
logger.info("=== GPU DIAGNOSTICS ===")
|
948 |
-
logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
949 |
-
if torch.cuda.is_available():
|
950 |
-
logger.info(f"CUDA version: {torch.version.cuda}")
|
951 |
-
logger.info(f"GPU count: {torch.cuda.device_count()}")
|
952 |
-
for i in range(torch.cuda.device_count()):
|
953 |
-
gpu_name = torch.cuda.get_device_name(i)
|
954 |
-
gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
|
955 |
-
logger.info(f"GPU {i}: {gpu_name} ({gpu_memory:.1f} GB)")
|
956 |
-
|
957 |
-
# Current GPU memory
|
958 |
-
current_device = torch.cuda.current_device()
|
959 |
-
allocated = torch.cuda.memory_allocated(current_device) / 1024**3
|
960 |
-
total = torch.cuda.get_device_properties(current_device).total_memory / 1024**3
|
961 |
-
logger.info(f"Current GPU {current_device}: {allocated:.2f}GB allocated, {total:.1f}GB total")
|
962 |
-
else:
|
963 |
-
logger.warning("CUDA not available - will use CPU (much slower)")
|
964 |
-
logger.info("======================")
|
965 |
|
966 |
start_time = time.time()
|
967 |
|
968 |
-
|
969 |
-
|
|
|
|
|
970 |
|
971 |
-
|
972 |
-
|
973 |
-
if resume:
|
974 |
-
# Check if model files exist directly in the volume root
|
975 |
-
try:
|
976 |
-
# Try to load from the volume root where the model was successfully saved
|
977 |
-
volume_root_path = Path(VOLUME_PATH)
|
978 |
-
if (volume_root_path / "config.json").exists() and (volume_root_path / "model.safetensors").exists():
|
979 |
-
logger.info("✅ Found existing model files in volume root")
|
980 |
-
from model2vec.model import StaticModel
|
981 |
|
982 |
-
|
983 |
-
|
984 |
-
|
985 |
-
logger.info("No existing model files found in volume root")
|
986 |
-
except Exception as e:
|
987 |
-
logger.warning(f"Failed to load existing model from volume: {e}")
|
988 |
-
m2v_model = None
|
989 |
|
990 |
-
|
991 |
-
|
992 |
-
if torch.cuda.is_available():
|
993 |
-
torch.cuda.empty_cache()
|
994 |
-
current_device = torch.cuda.current_device()
|
995 |
-
allocated = torch.cuda.memory_allocated(current_device) / 1024**3
|
996 |
-
total = torch.cuda.get_device_properties(current_device).total_memory / 1024**3
|
997 |
-
logger.info(f"GPU memory before distillation: {allocated:.2f}GB allocated / {total:.1f}GB total")
|
998 |
-
else:
|
999 |
-
logger.info("Using CPU for distillation")
|
1000 |
|
1001 |
-
|
1002 |
-
|
1003 |
-
|
1004 |
-
pca_dims=pca_dims,
|
1005 |
-
apply_zipf=None,
|
1006 |
-
sif_coefficient=1e-4,
|
1007 |
trust_remote_code=True,
|
|
|
|
|
1008 |
)
|
1009 |
-
logger.info("Basic distillation completed with preserved precision")
|
1010 |
|
1011 |
-
#
|
1012 |
-
if
|
1013 |
-
|
1014 |
-
|
1015 |
-
|
1016 |
-
|
1017 |
-
|
|
|
1018 |
|
1019 |
-
#
|
1020 |
-
|
1021 |
|
1022 |
-
|
1023 |
-
logger.warning("GPU OOM during distillation, clearing cache and retrying...")
|
1024 |
-
torch.cuda.empty_cache()
|
1025 |
|
1026 |
-
|
1027 |
-
|
1028 |
|
1029 |
-
|
1030 |
-
|
1031 |
-
|
1032 |
-
|
1033 |
-
apply_zipf=None,
|
1034 |
-
sif_coefficient=1e-4,
|
1035 |
trust_remote_code=True,
|
|
|
1036 |
)
|
1037 |
-
logger.info("Basic distillation completed on CPU")
|
1038 |
|
1039 |
-
#
|
1040 |
-
if
|
1041 |
-
|
1042 |
-
raise ValueError(msg) from None
|
1043 |
|
1044 |
-
|
1045 |
-
|
|
|
1046 |
|
1047 |
-
|
1048 |
-
# model_mgr.save_model("base_distilled_model", str(output_path))
|
1049 |
|
1050 |
-
|
1051 |
-
|
1052 |
-
raise
|
1053 |
|
1058 |
|
1059 |
-
|
1060 |
-
logger.info("Step 2: Loading CodeSearchNet training data...")
|
1061 |
-
code_texts = load_codesearchnet_dataset_with_resume(max_samples, checkpoint_mgr)
|
1062 |
|
1063 |
-
|
1064 |
-
|
1065 |
-
|
1066 |
-
logger.info("Step 3: Code specialization training...")
|
1067 |
-
|
1068 |
-
# Check for existing refined model
|
1069 |
-
if resume:
|
1070 |
-
# Check if refined model exists in beam volume
|
1071 |
-
models = model_mgr.list_models()
|
1072 |
-
refined_model_exists = any(model["name"] == "refined_model" for model in models)
|
1073 |
-
|
1074 |
-
if refined_model_exists:
|
1075 |
-
# Download model to local path for loading
|
1076 |
-
temp_model_path = Path("./temp_refined_model")
|
1077 |
-
if model_mgr.load_model("refined_model", temp_model_path):
|
1078 |
-
try:
|
1079 |
-
from model2vec.model import StaticModel
|
1080 |
-
|
1081 |
-
refined_model = StaticModel.from_pretrained(str(temp_model_path / "refined_model"))
|
1082 |
-
logger.info("✅ Resumed from existing refined model")
|
1083 |
-
m2v_model = refined_model
|
1084 |
-
# Clean up temp directory
|
1085 |
-
import shutil
|
1086 |
-
|
1087 |
-
shutil.rmtree(temp_model_path, ignore_errors=True)
|
1088 |
-
except Exception as e:
|
1089 |
-
logger.warning(f"Failed to load existing refined model: {e}")
|
1090 |
-
refined_model = None
|
1091 |
-
# Clean up temp directory
|
1092 |
-
import shutil
|
1093 |
-
|
1094 |
-
shutil.rmtree(temp_model_path, ignore_errors=True)
|
1095 |
-
else:
|
1096 |
-
refined_model = None
|
1097 |
-
else:
|
1098 |
-
refined_model = None
|
1099 |
-
|
1100 |
-
if refined_model is None:
|
1101 |
-
# Load teacher model with memory management
|
1102 |
-
try:
|
1103 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
1104 |
-
logger.info(f"Loading teacher model on {device} with optimized settings")
|
1105 |
-
logger.info(
|
1106 |
-
f"Using precision: {TEACHER_MODEL_CONFIG['precision']}, batch_size: {TEACHER_MODEL_CONFIG['batch_size']}"
|
1107 |
-
)
|
1108 |
-
logger.info("Attempting to enable Flash Attention 2 for maximum performance")
|
1109 |
|
1110 |
-
|
1111 |
|
1112 |
-
|
1113 |
-
|
1114 |
-
teacher_model, code_texts, checkpoint_mgr
|
1115 |
-
)
|
1116 |
|
1117 |
-
|
1118 |
-
|
1119 |
-
|
1120 |
-
|
1121 |
-
|
1122 |
-
|
1123 |
-
checkpoint_manager=checkpoint_mgr,
|
1124 |
-
model_manager=model_mgr,
|
1125 |
-
)
|
1126 |
|
1127 |
-
|
1128 |
-
if torch.cuda.is_available():
|
1129 |
-
torch.cuda.empty_cache()
|
1130 |
-
|
1131 |
-
except torch.cuda.OutOfMemoryError:
|
1132 |
-
logger.warning("GPU OOM during code training, falling back to CPU...")
|
1133 |
-
|
1134 |
-
if torch.cuda.is_available():
|
1135 |
-
torch.cuda.empty_cache()
|
1136 |
-
|
1137 |
-
# Force CPU for teacher model with optimized settings (no flash attention on CPU)
|
1138 |
-
try:
|
1139 |
-
teacher_model = load_teacher_model_with_cache(
|
1140 |
-
model_name, output_dir, device="cpu", resume=resume
|
1141 |
-
)
|
1142 |
-
except ImportError as e:
|
1143 |
-
if "flash_attn" in str(e):
|
1144 |
-
logger.warning("Flash Attention 2 not available on CPU, using default attention")
|
1145 |
-
# Fallback without any special attention implementation
|
1146 |
-
teacher_model = load_teacher_model_with_cache(
|
1147 |
-
model_name, output_dir, device="cpu", resume=resume
|
1148 |
-
)
|
1149 |
-
else:
|
1150 |
-
raise
|
1151 |
-
|
1152 |
-
# Generate teacher embeddings on CPU with checkpoints
|
1153 |
-
teacher_embeddings = generate_teacher_embeddings_with_checkpoints(
|
1154 |
-
teacher_model, code_texts, checkpoint_mgr
|
1155 |
-
)
|
1156 |
|
1157 |
-
|
1158 |
-
|
1159 |
-
|
1160 |
-
code_texts,
|
1161 |
-
teacher_embeddings,
|
1162 |
-
epochs=TRAINING_EPOCHS,
|
1163 |
-
checkpoint_manager=checkpoint_mgr,
|
1164 |
-
model_manager=model_mgr,
|
1165 |
-
)
|
1166 |
|
1167 |
-
del teacher_model
|
1168 |
-
else:
|
1169 |
-
# Fresh training without resume
|
1170 |
-
try:
|
1171 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
1172 |
-
logger.info(f"Loading teacher model on {device} with optimized settings")
|
1173 |
-
logger.info(
|
1174 |
-
f"Using precision: {TEACHER_MODEL_CONFIG['precision']}, batch_size: {TEACHER_MODEL_CONFIG['batch_size']}"
|
1175 |
-
)
|
1176 |
-
logger.info("Attempting to enable Flash Attention 2 for maximum performance")
|
1177 |
|
1178 |
-
|
1179 |
|
1180 |
-
|
1181 |
-
|
1182 |
-
teacher_model, code_texts, checkpoint_mgr
|
1183 |
-
)
|
1184 |
|
1185 |
-
|
1186 |
-
|
1187 |
-
m2v_model,
|
1188 |
-
code_texts,
|
1189 |
-
teacher_embeddings,
|
1190 |
-
epochs=TRAINING_EPOCHS,
|
1191 |
-
checkpoint_manager=checkpoint_mgr,
|
1192 |
-
model_manager=model_mgr,
|
1193 |
-
)
|
1194 |
|
1195 |
-
|
1196 |
-
if torch.cuda.is_available():
|
1197 |
-
torch.cuda.empty_cache()
|
1198 |
|
1199 |
-
|
1200 |
-
|
|
|
|
|
1201 |
|
1202 |
-
|
1203 |
-
torch.cuda.empty_cache()
|
1204 |
|
1205 |
-
|
1206 |
-
|
1207 |
-
teacher_model = load_teacher_model_with_cache(model_name, output_dir, device="cpu", resume=resume)
|
1208 |
-
except ImportError as e:
|
1209 |
-
if "flash_attn" in str(e):
|
1210 |
-
logger.warning("Flash Attention 2 not available on CPU, using default attention")
|
1211 |
-
# Fallback without any special attention implementation
|
1212 |
-
teacher_model = load_teacher_model_with_cache(
|
1213 |
-
model_name, output_dir, device="cpu", resume=resume
|
1214 |
-
)
|
1215 |
-
else:
|
1216 |
-
raise
|
1217 |
-
|
1218 |
-
# Generate teacher embeddings on CPU with checkpoints
|
1219 |
-
teacher_embeddings = generate_teacher_embeddings_with_checkpoints(
|
1220 |
-
teacher_model, code_texts, checkpoint_mgr
|
1221 |
-
)
|
1222 |
|
1223 |
-
|
1224 |
-
|
1225 |
-
|
1226 |
-
|
1227 |
-
teacher_embeddings,
|
1228 |
-
epochs=TRAINING_EPOCHS,
|
1229 |
-
checkpoint_manager=checkpoint_mgr,
|
1230 |
-
model_manager=model_mgr,
|
1231 |
-
)
|
1232 |
|
1233 |
-
|
|
|
|
|
1234 |
|
1235 |
-
|
1236 |
-
|
1237 |
-
|
|
|
1238 |
|
1239 |
-
|
1240 |
-
|
1241 |
|
1242 |
-
|
1243 |
-
|
1244 |
-
|
1245 |
-
|
1246 |
|
1247 |
-
|
1248 |
-
|
1249 |
-
raise ValueError(msg)
|
1250 |
|
1251 |
-
|
1252 |
-
|
1253 |
|
1254 |
-
|
1255 |
|
1256 |
-
|
1257 |
-
|
|
|
1258 |
|
1259 |
-
|
1260 |
-
|
|
|
|
|
|
|
1261 |
|
1262 |
-
|
|
|
1263 |
|
1264 |
|
1265 |
-
|
1266 |
-
gpu=GPU_NAME,
|
1267 |
-
volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
|
1268 |
-
image=IMAGE,
|
1269 |
-
secrets=["HF_ACCESS_TOKEN"],
|
1270 |
-
env={
|
1271 |
-
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:512",
|
1272 |
-
"TOKENIZERS_PARALLELISM": "false",
|
1273 |
-
"CUDA_LAUNCH_BLOCKING": "0", # Allow async CUDA operations
|
1274 |
-
"TORCH_CUDNN_V8_API_ENABLED": "1", # Enable optimized cuDNN
|
1275 |
-
"OMP_NUM_THREADS": "8", # Limit CPU threads for better GPU utilization
|
1276 |
-
},
|
1277 |
-
timeout=3600 * 12, # 12 hours
|
1278 |
-
)
|
1279 |
-
def beam_code_distillation(
|
1280 |
-
model_name: str = MODEL_NAME,
|
1281 |
-
output_dir: str = OUTPUT_DIR,
|
1282 |
-
pca_dims: int = PCA_DIMS,
|
1283 |
-
max_samples: int = MAX_TRAINING_SAMPLES,
|
1284 |
-
resume: bool = True,
|
1285 |
-
) -> Any:
|
1286 |
-
# Apply all patches from the patches directory
|
1287 |
-
try:
|
1288 |
-
from .patch_utils import apply_all_patches
|
1289 |
|
1290 |
-
|
1291 |
-
|
1292 |
-
|
1293 |
-
except Exception as e:
|
1294 |
-
logger.warning(f"Failed to apply patches: {e}. Continuing without patches.")
|
1295 |
-
|
1296 |
-
return code_specialized_distillation(
|
1297 |
-
model_name=model_name,
|
1298 |
-
output_dir=output_dir,
|
1299 |
-
pca_dims=pca_dims,
|
1300 |
-
max_samples=max_samples,
|
1301 |
-
resume=resume,
|
1302 |
-
)
|
1303 |
|
1304 |
|
1305 |
if __name__ == "__main__":
|
1306 |
-
|
|
|
1 |
"""
|
2 |
+
Unified Code-Specialized Model2Vec Distillation Script.
|
3 |
|
4 |
+
This script provides a unified approach for creating code-specialized embeddings
|
5 |
+
using Model2Vec distillation with optional code-specific training.
|
6 |
|
7 |
Features:
|
8 |
+
- Basic distillation (default): Simple Model2Vec distillation
|
9 |
+
- Advanced training (--train flag): Additional CodeSearchNet fine-tuning
|
10 |
+
- Checkpoint support with Beam sync utilities
|
11 |
+
- Multi-teacher model processing
|
12 |
+
- Smart resume capabilities
|
13 |
+
- Hierarchical storage: base → final
|
14 |
+
|
15 |
+
Directory Structure:
|
16 |
+
- code_model2vec/base: Basic distilled models (first step)
|
17 |
+
- code_model2vec/final: Final models (copied from base or after training)
|
18 |
+
|
19 |
+
Usage:
|
20 |
+
distiller distill [--use-beam] [--train] # Basic distillation or with training
|
21 |
"""
|
22 |
|
23 |
import json
|
24 |
import logging
|
|
|
25 |
import time
|
26 |
from pathlib import Path
|
27 |
+
from typing import Annotated, Any
|
28 |
|
29 |
import numpy as np
|
30 |
import torch
|
31 |
+
import typer
|
32 |
+
from beam import Volume, function
|
33 |
from datasets import load_dataset
|
34 |
from model2vec.distill import distill
|
35 |
from model2vec.train.base import FinetunableStaticModel, TextDataset
|
|
|
39 |
|
40 |
from .beam_utils import (
|
41 |
BeamCheckpointManager,
|
|
|
42 |
create_beam_utilities,
|
43 |
+
download_model_from_beam,
|
44 |
+
sync_checkpoints_from_beam,
|
45 |
+
sync_checkpoints_to_beam,
|
46 |
+
upload_model_to_beam,
|
47 |
+
)
|
48 |
+
from .config import (
|
49 |
+
BEAM_ENV_SETTINGS,
|
50 |
+
GPU_NAME,
|
51 |
+
IMAGE,
|
52 |
+
codesearchnet_config,
|
53 |
+
directories,
|
54 |
+
distillation_config,
|
55 |
+
get_volume_config,
|
56 |
+
languages_config,
|
57 |
)
|
58 |
|
59 |
# =============================================================================
|
60 |
+
# CONFIGURATION
|
61 |
# =============================================================================
|
62 |
|
63 |
+
VOLUME_CONFIG = get_volume_config()
|
64 |
+
LOCAL_BASE_DIR = directories.base
|
65 |
+
LOCAL_FINAL_DIR = directories.final
|
66 |
+
|
67 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
68 |
+
logger = logging.getLogger(__name__)
|
69 |
+
|
70 |
+
# Teacher models for distillation
|
71 |
+
DEFAULT_TEACHER_MODELS = list(distillation_config.code_teacher_models)
|
72 |
|
73 |
# =============================================================================
|
74 |
+
# UTILITY FUNCTIONS
|
75 |
# =============================================================================
|
76 |
|
77 |
|
78 |
+
def apply_local_patches() -> bool:
|
79 |
+
"""Apply patches locally without requiring Beam utilities."""
|
80 |
+
try:
|
81 |
+
try:
|
82 |
+
from .patch_utils import apply_all_patches
|
83 |
|
84 |
+
patches_applied = apply_all_patches()
|
85 |
+
logger.info(f"Successfully applied {patches_applied} patches via patch_utils")
|
86 |
+
return True
|
87 |
+
except ImportError:
|
88 |
+
logger.warning("patch_utils not available, trying direct patching")
|
89 |
|
90 |
+
return False
|
91 |
+
|
92 |
+
except Exception as e:
|
93 |
+
logger.warning(f"Failed to apply patches: {e}")
|
94 |
+
return False
|
95 |
+
|
96 |
+
|
97 |
+
def get_current_config_hash(enable_training: bool) -> str:
|
98 |
"""Generate a hash of current configuration parameters for checkpoint validation."""
|
99 |
import hashlib
|
100 |
|
101 |
config_params = {
|
102 |
+
"pca_dims": distillation_config.optimal_pca_dims,
|
103 |
+
"sif_coefficient": distillation_config.sif_coefficient,
|
104 |
+
"apply_zipf": distillation_config.apply_zipf,
|
105 |
+
"enable_training": enable_training,
|
|
|
|
|
106 |
}
|
107 |
|
108 |
+
if enable_training:
|
109 |
+
config_params.update(
|
110 |
+
{
|
111 |
+
"training_epochs": distillation_config.training_epochs,
|
112 |
+
"learning_rate": distillation_config.learning_rate,
|
113 |
+
"max_samples": distillation_config.max_training_samples,
|
114 |
+
}
|
115 |
+
)
|
116 |
+
|
117 |
config_str = str(sorted(config_params.items()))
|
118 |
return hashlib.md5(config_str.encode()).hexdigest()[:12] # noqa: S324
|
119 |
|
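A note on what this hash buys: the checkpoint and embeddings caches below store this digest next to their payloads, so changing any distillation parameter invalidates stale artifacts instead of silently reusing them. A minimal self-contained sketch of the same idea (illustrative values, not the project's config objects):

import hashlib

def config_hash(params: dict) -> str:
    # Sorting the items makes the digest independent of dict insertion order.
    return hashlib.md5(str(sorted(params.items())).encode()).hexdigest()[:12]

a = config_hash({"pca_dims": 256, "enable_training": False})
b = config_hash({"pca_dims": 512, "enable_training": False})
assert a != b  # any parameter change yields a new hash, so old checkpoints are ignored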
120 |
|
121 |
+
def check_existing_base_model(teacher_name: str) -> str | None:
|
122 |
+
"""Check if base distilled model already exists locally."""
|
123 |
+
base_dir = Path(LOCAL_BASE_DIR)
|
124 |
+
model_dir = base_dir / f"code_model2vec_{teacher_name}"
|
125 |
+
|
126 |
+
if model_dir.exists():
|
127 |
+
# Check for essential model files
|
128 |
+
has_config = (model_dir / "config.json").exists()
|
129 |
+
has_model_file = any(
|
130 |
+
[
|
131 |
+
(model_dir / "model.safetensors").exists(),
|
132 |
+
(model_dir / "model.bin").exists(),
|
133 |
+
(model_dir / "pytorch_model.bin").exists(),
|
134 |
+
]
|
135 |
+
)
|
136 |
|
137 |
+
if has_config and has_model_file:
|
138 |
+
logger.info(f"✅ Found existing base model: {teacher_name}")
|
139 |
+
return str(model_dir)
|
140 |
|
141 |
+
return None
|
142 |
+
|
143 |
+
|
144 |
+
def check_existing_final_model(teacher_name: str, enable_training: bool = False) -> str | None:
|
145 |
+
"""Check if final model already exists locally."""
|
146 |
+
final_dir = Path(LOCAL_FINAL_DIR)
|
147 |
+
|
148 |
+
# Add suffix for trained models
|
149 |
+
model_name = f"code_model2vec_{teacher_name}"
|
150 |
+
if enable_training:
|
151 |
+
model_name += "_fine_tuned"
|
152 |
+
model_dir = final_dir / model_name
|
153 |
+
|
154 |
+
if model_dir.exists():
|
155 |
+
# Check for essential model files
|
156 |
+
has_config = (model_dir / "config.json").exists()
|
157 |
+
has_model_file = any(
|
158 |
+
[
|
159 |
+
(model_dir / "model.safetensors").exists(),
|
160 |
+
(model_dir / "model.bin").exists(),
|
161 |
+
(model_dir / "pytorch_model.bin").exists(),
|
162 |
+
]
|
163 |
+
)
|
164 |
+
|
165 |
+
if has_config and has_model_file:
|
166 |
+
logger.info(f"✅ Found existing final model: {teacher_name}{'_fine_tuned' if enable_training else ''}")
|
167 |
+
return str(model_dir)
|
168 |
|
169 |
+
return None
|
170 |
+
|
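For orientation, the directory layout these two checks (and copy_base_to_final below) assume looks roughly like this; teacher names are illustrative:

code_model2vec/
    base/
        code_model2vec_<teacher_name>/
            config.json
            model.safetensors            # or model.bin / pytorch_model.bin
    final/
        code_model2vec_<teacher_name>/              # basic distillation
        code_model2vec_<teacher_name>_fine_tuned/   # with --train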
171 |
+
|
172 |
+
def copy_base_to_final(teacher_name: str, enable_training: bool = False) -> bool:
|
173 |
+
"""Copy base model to final directory."""
|
174 |
+
import shutil
|
175 |
+
|
176 |
+
base_path = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
|
177 |
+
|
178 |
+
# Add suffix for trained models
|
179 |
+
final_model_name = f"code_model2vec_{teacher_name}"
|
180 |
+
if enable_training:
|
181 |
+
final_model_name += "_fine_tuned"
|
182 |
+
final_path = Path(LOCAL_FINAL_DIR) / final_model_name
|
183 |
+
|
184 |
+
try:
|
185 |
+
final_path.parent.mkdir(parents=True, exist_ok=True)
|
186 |
+
if final_path.exists():
|
187 |
+
shutil.rmtree(final_path)
|
188 |
+
shutil.copytree(base_path, final_path)
|
189 |
+
logger.info(f"📁 Copied {teacher_name} from base to final{'_fine_tuned' if enable_training else ''}")
|
190 |
+
return True
|
191 |
+
except Exception:
|
192 |
+
logger.exception(f"❌ Failed to copy {teacher_name} to final{'_fine_tuned' if enable_training else ''}")
|
193 |
return False
|
194 |
|
|
|
|
|
195 |
|
196 |
+
def sync_model_from_beam(
|
197 |
+
teacher_name: str,
|
198 |
+
target_dir: str,
|
199 |
+
use_beam_utilities: bool = False,
|
200 |
+
) -> bool:
|
201 |
+
"""Sync model from Beam volume to local directory."""
|
202 |
+
if not use_beam_utilities:
|
203 |
return False
|
204 |
|
205 |
+
try:
|
206 |
+
target_path = Path(target_dir)
|
207 |
+
target_path.mkdir(parents=True, exist_ok=True)
|
208 |
+
|
209 |
+
beam_model_name = f"{teacher_name}_model"
|
210 |
+
success = download_model_from_beam(VOLUME_CONFIG.name, beam_model_name, str(target_path))
|
211 |
+
|
212 |
+
if success:
|
213 |
+
logger.info(f"📥 Synced {teacher_name} from Beam to {target_dir}")
|
214 |
+
return True
|
215 |
+
logger.warning(f"⚠️ Failed to sync {teacher_name} from Beam")
|
216 |
return False
|
217 |
|
218 |
+
except Exception as e:
|
219 |
+
logger.warning(f"Failed to sync {teacher_name} from Beam: {e}")
|
220 |
+
return False
|
221 |
+
|
222 |
+
|
223 |
+
def sync_model_to_beam(
|
224 |
+
teacher_name: str,
|
225 |
+
source_dir: str,
|
226 |
+
use_beam_utilities: bool = False,
|
227 |
+
) -> bool:
|
228 |
+
"""Sync model from local directory to Beam volume."""
|
229 |
+
if not use_beam_utilities:
|
230 |
+
return False
|
231 |
+
|
232 |
+
try:
|
233 |
+
beam_model_name = f"{teacher_name}_model"
|
234 |
+
success = upload_model_to_beam(VOLUME_CONFIG.name, beam_model_name, source_dir)
|
235 |
+
|
236 |
+
if success:
|
237 |
+
logger.info(f"📤 Synced {teacher_name} to Beam from {source_dir}")
|
238 |
+
return True
|
239 |
+
logger.warning(f"⚠️ Failed to sync {teacher_name} to Beam")
|
240 |
return False
|
241 |
|
242 |
+
except Exception as e:
|
243 |
+
logger.warning(f"Failed to sync {teacher_name} to Beam: {e}")
|
244 |
+
return False
|
245 |
+
|
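A minimal round-trip sketch for these two helpers (the teacher name and path are illustrative; both helpers simply return False when use_beam_utilities is False):

local_dir = "code_model2vec/base/code_model2vec_all_MiniLM_L6_v2"  # hypothetical path
if sync_model_to_beam("all_MiniLM_L6_v2", local_dir, use_beam_utilities=True):
    # Later, e.g. on another machine, pull the same artifact back down.
    sync_model_from_beam("all_MiniLM_L6_v2", local_dir, use_beam_utilities=True)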
246 |
+
|
247 |
+
# =============================================================================
|
248 |
+
# DISTILLATION FUNCTIONS
|
249 |
+
# =============================================================================
|
250 |
|
251 |
|
252 |
+
def simple_distillation(
|
253 |
+
teacher_model: str,
|
254 |
+
output_dir: str,
|
255 |
+
pca_dims: int | None = None,
|
256 |
+
retry_with_cache_clear: bool = False,
|
257 |
+
) -> Any:
|
258 |
"""
|
259 |
+
Perform simple Model2Vec distillation without additional training.
|
260 |
|
261 |
Args:
|
262 |
+
teacher_model: Name of teacher model
|
263 |
+
output_dir: Output directory for the distilled model
|
264 |
+
pca_dims: PCA dimensions (uses config default if None)
|
265 |
+
retry_with_cache_clear: Whether this is a retry after clearing cache
|
266 |
|
267 |
Returns:
|
268 |
+
Distilled model or None if failed
|
269 |
"""
|
270 |
+
if pca_dims is None:
|
271 |
+
pca_dims = int(distillation_config.optimal_pca_dims)
|
272 |
|
273 |
+
output_path = Path(output_dir)
|
274 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
275 |
+
|
276 |
+
retry_suffix = " (retry after cache clear)" if retry_with_cache_clear else ""
|
277 |
+
logger.info(f"🔄 Simple distillation{retry_suffix}: {teacher_model} → {output_dir}")
|
278 |
+
logger.info(f"📊 PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}")
|
279 |
+
|
280 |
+
start_time = time.time()
|
281 |
+
|
282 |
+
try:
|
283 |
+
# Perform distillation with optimal parameters
|
284 |
+
model = distill(
|
285 |
+
model_name=teacher_model,
|
286 |
+
pca_dims=int(pca_dims),
|
287 |
+
apply_zipf=bool(distillation_config.apply_zipf),
|
288 |
+
sif_coefficient=float(distillation_config.sif_coefficient),
|
289 |
+
trust_remote_code=True,
|
290 |
+
)
|
291 |
+
|
292 |
+
logger.info("✅ Core distillation completed successfully")
|
293 |
+
|
294 |
+
# Save the model
|
295 |
+
model.save_pretrained(str(output_path))
|
296 |
+
logger.info(f"💾 Model saved to {output_path}")
|
297 |
+
|
298 |
+
# Log model info
|
299 |
+
logger.info(f"Model type: {type(model)}")
|
300 |
+
if hasattr(model, "embedding"):
|
301 |
+
logger.info(f"Embedding shape: {model.embedding.shape}")
|
302 |
+
logger.info(f"Embedding dtype: {model.embedding.dtype}")
|
303 |
|
304 |
+
total_time = time.time() - start_time
|
305 |
+
logger.info(f"🎉 Simple distillation completed in {total_time:.2f} seconds")
|
306 |
+
return model
|
307 |
+
|
308 |
+
except ValueError as e:
|
309 |
+
if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
|
310 |
+
logger.warning(f"⚠️ Token-vector mismatch with {teacher_model} - this is a Model2Vec library issue")
|
311 |
+
logger.warning(f"Error details: {e}")
|
312 |
+
logger.warning("💡 This model has incompatible tokenization. Skipping...")
|
313 |
+
return None
|
314 |
+
if "weight is on the meta device" in str(e):
|
315 |
+
logger.warning(f"⚠️ Device placement issue with {teacher_model} - model weights on meta device")
|
316 |
+
logger.warning(f"Error details: {e}")
|
317 |
+
logger.warning("💡 This model has device placement issues. Skipping...")
|
318 |
+
return None
|
319 |
+
raise
|
320 |
+
except AttributeError as e:
|
321 |
+
if "backend_tokenizer" in str(e):
|
322 |
+
logger.warning(f"⚠️ Tokenizer compatibility issue with {teacher_model}")
|
323 |
+
logger.warning(f"Error details: {e}")
|
324 |
+
logger.warning("💡 This model's tokenizer is incompatible with Model2Vec. Skipping...")
|
325 |
+
return None
|
326 |
+
raise
|
327 |
+
except FileNotFoundError as e:
|
328 |
+
if "transformers_modules" in str(e) or "xlm_padding.py" in str(e):
|
329 |
+
logger.warning(f"⚠️ Missing custom model files for {teacher_model}")
|
330 |
+
logger.warning(f"Error details: {e}")
|
331 |
+
|
332 |
+
# Try clearing cache and retrying once
|
333 |
+
if not retry_with_cache_clear:
|
334 |
+
logger.info("🔧 Attempting to clear cache and retry...")
|
335 |
+
if clear_model_cache(teacher_model):
|
336 |
+
logger.info("🔄 Retrying distillation after cache clear...")
|
337 |
+
return simple_distillation(teacher_model, output_dir, pca_dims, retry_with_cache_clear=True)
|
338 |
+
|
339 |
+
logger.warning("💡 This model has missing dependencies. Manual intervention may be required.")
|
340 |
+
return None
|
341 |
+
raise
|
342 |
+
except Exception:
|
343 |
+
logger.exception(f"❌ Simple distillation failed for {teacher_model}")
|
344 |
+
return None
|
345 |
+
|
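A usage sketch under the defaults above (the teacher checkpoint is illustrative; any sentence-transformers-compatible model should work):

model = simple_distillation(
    teacher_model="sentence-transformers/all-MiniLM-L6-v2",
    output_dir="code_model2vec/base/code_model2vec_all_MiniLM_L6_v2",
)
if model is not None:  # None signals a skipped or incompatible teacher
    vectors = model.encode(["def add(a, b):\n    return a + b"])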
346 |
+
|
347 |
+
def load_codesearchnet_dataset(
|
348 |
+
max_samples: int | None = None,
|
349 |
checkpoint_manager: BeamCheckpointManager | None = None,
|
350 |
) -> list[str]:
|
351 |
+
"""Load and format the CodeSearchNet dataset for training with balanced language distribution."""
|
352 |
+
if max_samples is None:
|
353 |
+
max_samples = int(distillation_config.max_training_samples)
|
354 |
+
|
355 |
+
logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
|
356 |
logger.info(f"Limiting to {max_samples} samples for training efficiency")
|
357 |
+
logger.info(f"Languages: {', '.join(languages_config.all)}")
|
358 |
+
|
359 |
+
# Check for existing dataset checkpoint
|
360 |
+
texts = []
|
361 |
+
start_from = 0
|
362 |
|
|
|
363 |
if checkpoint_manager:
|
364 |
checkpoint_data = checkpoint_manager.load_checkpoint("dataset", 0)
|
365 |
if checkpoint_data:
|
366 |
+
cached_texts = checkpoint_data.get("data", {}).get("texts", [])
|
367 |
+
if len(cached_texts) >= max_samples:
|
368 |
+
logger.info(f"✅ Resumed dataset loading: {len(cached_texts)} texts from checkpoint")
|
369 |
+
return cached_texts[:max_samples]
|
370 |
+
logger.info(f"📋 Partial dataset found: {len(cached_texts)} texts, continuing...")
|
371 |
+
texts = cached_texts
|
372 |
+
start_from = len(texts)
|
373 |
|
374 |
try:
|
375 |
+
# Calculate samples per language for balanced distribution
|
376 |
+
num_languages = len(languages_config.all)
|
377 |
+
samples_per_language = max_samples // num_languages
|
378 |
+
remaining_samples = max_samples % num_languages
|
379 |
|
380 |
+
logger.info(f"📊 Target distribution: {samples_per_language} samples per language")
|
381 |
+
if remaining_samples > 0:
|
382 |
+
logger.info(f"📊 Extra {remaining_samples} samples will be distributed to first languages")
|
383 |
|
384 |
+
# Load training data from each language separately for balanced distribution
|
385 |
+
language_texts: dict[str, list[str]] = {}
|
386 |
+
total_collected = len(texts)
|
387 |
+
|
388 |
+
for i, language in enumerate(languages_config.all):
|
389 |
+
if total_collected >= max_samples:
|
390 |
break
|
391 |
|
392 |
+
logger.info(f"🔍 Loading {language} training data...")
|
393 |
+
|
394 |
+
# Determine how many samples to collect for this language
|
395 |
+
target_for_lang = samples_per_language
|
396 |
+
if i < remaining_samples: # Distribute extra samples to first languages
|
397 |
+
target_for_lang += 1
|
398 |
+
|
399 |
+
# Skip if we already have enough from this language
|
400 |
+
if language in language_texts and len(language_texts[language]) >= target_for_lang:
|
401 |
+
continue
|
402 |
+
|
403 |
+
try:
|
404 |
+
# Load training split for the specific language (same format as evaluate.py)
|
405 |
+
dataset = load_dataset(
|
406 |
+
codesearchnet_config.dataset_name,
|
407 |
+
language,
|
408 |
+
split="train",
|
409 |
+
trust_remote_code=True,
|
410 |
+
)
|
411 |
+
|
412 |
+
lang_texts: list[str] = []
|
413 |
+
processed_count = 0
|
414 |
|
415 |
+
for processed_count, example in enumerate(dataset, 1):
|
416 |
+
if len(lang_texts) >= target_for_lang:
|
417 |
+
break
|
418 |
|
419 |
+
# Use same field names as evaluate.py
|
420 |
+
doc_string = example.get("func_documentation_string", "").strip()
|
421 |
+
code_string = example.get("func_code_string", "").strip()
|
422 |
|
423 |
+
if doc_string and code_string and len(doc_string.split()) >= 3 and len(code_string) > 50:
|
424 |
+
# Format as documentation-code pair for training (same as evaluate.py)
|
425 |
+
text = f"Documentation: {doc_string}\nCode:\n{code_string}"
|
|
|
|
|
426 |
|
427 |
+
# Ensure reasonable length for embedding models
|
428 |
+
if len(text) <= 2048:
|
429 |
+
lang_texts.append(text)
|
430 |
+
|
431 |
+
if processed_count % 5000 == 0:
|
432 |
+
logger.info(f" {language}: processed {processed_count}, collected {len(lang_texts)}")
|
433 |
+
|
434 |
+
language_texts[language] = lang_texts
|
435 |
+
total_collected += len(lang_texts)
|
436 |
+
logger.info(f"✅ {language}: collected {len(lang_texts)} samples")
|
437 |
+
|
438 |
+
except Exception as e:
|
439 |
+
logger.warning(f"⚠️ Failed to load {language} data: {e}")
|
440 |
+
continue
|
441 |
+
|
442 |
+
# Combine all language texts in a balanced way
|
443 |
+
combined_texts = []
|
444 |
+
|
445 |
+
# Add existing texts first (from checkpoint)
|
446 |
+
if start_from > 0:
|
447 |
+
combined_texts = texts[:start_from]
|
448 |
+
|
449 |
+
# Interleave texts from different languages for better training distribution
|
450 |
+
max_lang_samples = max(len(lang_texts) for lang_texts in language_texts.values()) if language_texts else 0
|
451 |
+
|
452 |
+
for sample_idx in range(max_lang_samples):
|
453 |
+
for language in languages_config.all:
|
454 |
+
if len(combined_texts) >= max_samples:
|
455 |
+
break
|
456 |
+
|
457 |
+
if language in language_texts and sample_idx < len(language_texts[language]):
|
458 |
+
combined_texts.append(language_texts[language][sample_idx])
|
459 |
+
|
460 |
+
if len(combined_texts) >= max_samples:
|
461 |
+
break
|
462 |
+
|
463 |
+
# Truncate to exact max_samples
|
464 |
+
combined_texts = combined_texts[:max_samples]
|
465 |
+
|
466 |
+
# Log final distribution
|
467 |
+
logger.info("📊 Final dataset distribution:")
|
468 |
+
lang_counts: dict[str, int] = {}
|
469 |
+
for text in combined_texts:
|
470 |
+
# Simple heuristic to identify language from code patterns
|
471 |
+
if "def " in text and ":" in text:
|
472 |
+
lang_counts["python"] = lang_counts.get("python", 0) + 1
|
473 |
+
elif "function " in text and "{" in text:
|
474 |
+
lang_counts["javascript"] = lang_counts.get("javascript", 0) + 1
|
475 |
+
elif "public " in text and "class " in text:
|
476 |
+
lang_counts["java"] = lang_counts.get("java", 0) + 1
|
477 |
+
elif "<?php" in text or "$" in text:
|
478 |
+
lang_counts["php"] = lang_counts.get("php", 0) + 1
|
479 |
+
elif "func " in text and "end" in text:
|
480 |
+
lang_counts["ruby"] = lang_counts.get("ruby", 0) + 1
|
481 |
+
elif "func " in text and "}" in text:
|
482 |
+
lang_counts["go"] = lang_counts.get("go", 0) + 1
|
483 |
+
else:
|
484 |
+
lang_counts["other"] = lang_counts.get("other", 0) + 1
|
485 |
+
|
486 |
+
for lang, count in lang_counts.items():
|
487 |
+
percentage = (count / len(combined_texts)) * 100
|
488 |
+
logger.info(f" {lang}: {count} samples ({percentage:.1f}%)")
|
489 |
|
490 |
# Final checkpoint save
|
491 |
if checkpoint_manager:
|
492 |
+
checkpoint_data = {
|
493 |
+
"config_hash": get_current_config_hash(enable_training=True),
|
494 |
+
"stage": "dataset",
|
495 |
+
"step": 0,
|
496 |
+
"timestamp": time.time(),
|
497 |
+
"data": {"texts": combined_texts},
|
498 |
+
}
|
499 |
checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0)
|
500 |
|
501 |
+
logger.info(f"Successfully loaded {len(combined_texts)} balanced code-documentation pairs from CodeSearchNet")
|
502 |
+
return combined_texts
|
503 |
|
504 |
except Exception:
|
505 |
logger.exception("Error loading CodeSearchNet dataset")
|
506 |
return texts # Return what we have so far
|
507 |
|
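The balancing above is a plain round-robin over per-language buckets; the same idea in isolation:

def interleave(buckets: dict[str, list[str]], limit: int) -> list[str]:
    # Take one sample per language per pass so no language dominates the
    # head of the training set, stopping once the limit is reached.
    out: list[str] = []
    longest = max((len(v) for v in buckets.values()), default=0)
    for i in range(longest):
        for lang in buckets:
            if i < len(buckets[lang]):
                out.append(buckets[lang][i])
                if len(out) >= limit:
                    return out
    return out

interleave({"python": ["p1", "p2"], "go": ["g1"]}, limit=3)  # -> ['p1', 'g1', 'p2']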
508 |
|
509 |
+
def generate_teacher_embeddings(
|
510 |
teacher_model: SentenceTransformer,
|
511 |
texts: list[str],
|
512 |
checkpoint_manager: BeamCheckpointManager | None = None,
|
|
|
514 |
"""Generate teacher embeddings for code training with checkpoint support."""
|
515 |
logger.info(f"Generating teacher embeddings for {len(texts)} texts...")
|
516 |
|
517 |
+
# Check for existing embeddings checkpoint
|
|
|
|
|
518 |
if checkpoint_manager:
|
519 |
+
volume_path = Path(VOLUME_CONFIG.mount_path)
|
520 |
+
embeddings_path = volume_path / "embeddings_cache.pt"
|
521 |
+
config_path = volume_path / "embeddings_config.json"
|
522 |
|
523 |
if embeddings_path.exists() and config_path.exists():
|
524 |
try:
|
|
|
526 |
with config_path.open("r") as f:
|
527 |
config_data = json.load(f)
|
528 |
|
529 |
+
current_hash = get_current_config_hash(enable_training=True)
|
530 |
+
if config_data.get("config_hash") == current_hash:
|
531 |
# Load the embeddings tensor
|
532 |
final_embeddings = torch.load(embeddings_path, map_location="cpu")
|
533 |
num_expected = config_data.get("num_texts", len(texts))
|
534 |
|
535 |
if final_embeddings.shape[0] >= num_expected:
|
536 |
+
logger.info(f"✅ Loaded embeddings from cache ({final_embeddings.shape[0]} embeddings)")
|
537 |
+
return final_embeddings[: len(texts)]
|
538 |
+
|
539 |
except Exception as e:
|
540 |
logger.warning(f"Failed to load embeddings cache: {e}, regenerating...")
|
541 |
|
542 |
# Generate embeddings from scratch
|
543 |
logger.info("Generating fresh teacher embeddings...")
|
544 |
|
545 |
+
batch_size = int(distillation_config.teacher_model_config.get("batch_size", 16))
|
546 |
embeddings_list = []
|
547 |
|
548 |
+
chunk = batch_size  # fixed slice stride; batch_size below only tunes encode's internal batching
for i in range(0, len(texts), chunk):
|
549 |
+
batch_texts = texts[i : i + chunk]
|
550 |
|
551 |
try:
|
|
|
552 |
batch_embeddings = teacher_model.encode(
|
553 |
batch_texts,
|
554 |
convert_to_tensor=True,
|
555 |
+
batch_size=batch_size,
|
556 |
+
show_progress_bar=False,
|
557 |
+
normalize_embeddings=True,
|
558 |
)
|
559 |
embeddings_list.append(batch_embeddings)
|
560 |
|
561 |
+
if i % (chunk * 10) == 0:
|
562 |
logger.info(f"Generated embeddings for {i + len(batch_texts)}/{len(texts)} texts")
|
563 |
|
564 |
except torch.cuda.OutOfMemoryError:
|
565 |
+
logger.warning(f"GPU OOM with batch size {batch_size}, reducing...")
|
566 |
+
torch.cuda.empty_cache()
|
567 |
+
batch_size = max(1, batch_size // 2)
|
568 |
|
569 |
# Retry with smaller batch size
|
|
|
570 |
batch_embeddings = teacher_model.encode(
|
571 |
batch_texts,
|
572 |
convert_to_tensor=True,
|
573 |
+
batch_size=batch_size,
|
574 |
show_progress_bar=False,
|
575 |
normalize_embeddings=True,
|
576 |
)
|
577 |
embeddings_list.append(batch_embeddings)
|
578 |
|
579 |
+
# Combine all embeddings
|
|
|
|
|
580 |
teacher_embeddings = torch.cat(embeddings_list, dim=0)
|
581 |
|
582 |
+
# Ensure fp32 precision
|
583 |
if teacher_embeddings.dtype != torch.float32:
|
|
|
584 |
teacher_embeddings = teacher_embeddings.to(torch.float32)
|
585 |
|
586 |
logger.info(f"Generated {teacher_embeddings.shape[0]} teacher embeddings in {teacher_embeddings.dtype}")
|
587 |
|
588 |
+
# Save embeddings cache for future runs
|
589 |
if checkpoint_manager:
|
590 |
try:
|
591 |
+
volume_path = Path(VOLUME_CONFIG.mount_path)
|
592 |
+
embeddings_path = volume_path / "embeddings_cache.pt"
|
593 |
+
config_path = volume_path / "embeddings_config.json"
|
594 |
|
595 |
# Save embeddings tensor
|
596 |
torch.save(teacher_embeddings, embeddings_path)
|
597 |
|
598 |
# Save configuration
|
599 |
config_data = {
|
600 |
+
"config_hash": get_current_config_hash(enable_training=True),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
601 |
"num_texts": len(texts),
|
602 |
"embedding_shape": list(teacher_embeddings.shape),
|
603 |
"timestamp": time.time(),
|
|
|
614 |
return teacher_embeddings
|
615 |
|
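The encode loop above recovers from CUDA OOM by halving the batch size and retrying; a slightly more defensive standalone variant of the same pattern (encode_fn stands in for teacher_model.encode):

import torch

def encode_with_backoff(encode_fn, chunk, batch_size: int):
    # Keep halving the internal batch size on CUDA OOM; give up only at 1.
    while True:
        try:
            return encode_fn(chunk, batch_size=batch_size)
        except torch.cuda.OutOfMemoryError:
            if batch_size == 1:
                raise
            torch.cuda.empty_cache()
            batch_size = max(1, batch_size // 2)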
616 |
|
617 |
+
def advanced_training(
|
618 |
student_model: Any,
|
619 |
+
teacher_model: SentenceTransformer,
|
|
|
|
|
620 |
checkpoint_manager: BeamCheckpointManager | None = None,
|
|
|
621 |
) -> Any:
|
622 |
+
"""Perform advanced code specialization training."""
|
623 |
+
logger.info("🎓 Starting advanced code specialization training...")
|
624 |
|
625 |
+
# Load CodeSearchNet training data
|
626 |
+
training_texts = load_codesearchnet_dataset(checkpoint_manager=checkpoint_manager)
|
|
|
|
|
|
|
627 |
|
628 |
+
if not training_texts:
|
629 |
+
logger.warning("No training data available, skipping advanced training")
|
630 |
+
return student_model
|
|
|
631 |
|
632 |
+
# Generate teacher embeddings
|
633 |
+
teacher_embeddings = generate_teacher_embeddings(teacher_model, training_texts, checkpoint_manager)
|
634 |
|
635 |
+
# Create trainable model
|
636 |
+
student_embedding_dim = student_model.embedding.shape[1]
|
637 |
+
teacher_embedding_dim = teacher_embeddings.shape[1]
|
|
|
638 |
|
639 |
+
# Project teacher embeddings if needed
|
640 |
+
if teacher_embedding_dim != student_embedding_dim:
|
641 |
+
from sklearn.decomposition import PCA
|
642 |
|
643 |
+
logger.info("Performing PCA projection for dimension matching...")
|
644 |
+
pca = PCA(n_components=student_embedding_dim)
|
645 |
+
teacher_embeddings_np = teacher_embeddings.cpu().numpy().astype(np.float64)
|
646 |
+
teacher_embeddings_projected = pca.fit_transform(teacher_embeddings_np)
|
647 |
+
teacher_embeddings = torch.tensor(teacher_embeddings_projected.astype(np.float32), dtype=torch.float32)
|
648 |
|
649 |
+
# Create trainable model
|
650 |
+
trainable_model = FinetunableStaticModel.from_static_model(
|
651 |
+
model=student_model,
|
652 |
+
out_dim=student_embedding_dim,
|
653 |
+
)
|
654 |
+
trainable_model = trainable_model.float()
|
655 |
+
|
656 |
+
# Tokenize texts
|
657 |
+
tokenized_texts = []
|
658 |
+
for text in training_texts:
|
659 |
+
tokens = trainable_model.tokenize([text])
|
660 |
+
if tokens.shape[1] > 0:
|
661 |
+
tokenized_texts.append(tokens[0].tolist())
|
662 |
+
|
663 |
+
# Prepare training data
|
664 |
+
targets = teacher_embeddings[: len(tokenized_texts)].to(torch.float32)
|
665 |
+
train_texts, val_texts, train_targets, val_targets = train_test_split(
|
666 |
+
tokenized_texts, targets, test_size=0.2, random_state=42
|
667 |
+
)
|
668 |
|
669 |
+
# Training setup
|
670 |
+
train_dataset = TextDataset(train_texts, train_targets)
|
671 |
+
val_dataset = TextDataset(val_texts, val_targets)
|
672 |
|
673 |
+
optimizer = optim.Adam(trainable_model.parameters(), lr=float(distillation_config.learning_rate))
|
674 |
+
mse_loss = nn.MSELoss()
|
|
|
675 |
|
676 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
677 |
+
trainable_model = trainable_model.to(device)
|
678 |
|
679 |
+
batch_size = int(distillation_config.batch_size)
|
680 |
+
epochs = int(distillation_config.training_epochs)
|
|
|
681 |
|
682 |
+
# Training loop
|
683 |
+
for epoch in range(epochs):
|
684 |
+
trainable_model.train()
|
|
|
686 |
+
try:
|
687 |
+
train_loader = train_dataset.to_dataloader(shuffle=True, batch_size=batch_size)
|
688 |
+
epoch_loss = 0.0
|
689 |
+
num_batches = 0
|
|
|
690 |
|
691 |
+
for _batch_idx, (tokens, targets_batch) in enumerate(train_loader):
|
692 |
+
batch_tokens = tokens.to(device)
|
693 |
+
batch_targets = targets_batch.to(device).to(torch.float32)
|
694 |
|
695 |
+
optimizer.zero_grad()
|
696 |
+
_, student_embeddings = trainable_model(batch_tokens)
|
697 |
+
student_embeddings = student_embeddings.to(torch.float32)
|
698 |
|
699 |
+
loss = mse_loss(student_embeddings, batch_targets)
|
700 |
+
loss.backward()
|
701 |
+
optimizer.step()
|
|
|
|
|
702 |
|
703 |
+
epoch_loss += loss.item()
|
704 |
+
num_batches += 1
|
705 |
|
706 |
+
except torch.cuda.OutOfMemoryError:
|
707 |
+
logger.warning(f"Training OOM with batch size {batch_size}, reducing...")
|
708 |
+
batch_size = max(1, batch_size // 2)
|
709 |
+
torch.cuda.empty_cache()
|
710 |
+
continue
|
711 |
|
712 |
+
avg_train_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
|
713 |
|
714 |
+
# Validation
|
715 |
+
trainable_model.eval()
|
716 |
+
val_loader = val_dataset.to_dataloader(shuffle=False, batch_size=batch_size)
|
717 |
+
val_loss = 0.0
|
718 |
+
val_batches = 0
|
719 |
|
720 |
+
with torch.no_grad():
|
721 |
+
for tokens, targets_batch in val_loader:
|
722 |
+
batch_tokens = tokens.to(device)
|
723 |
+
batch_targets = targets_batch.to(device).to(torch.float32)
|
|
|
|
|
724 |
|
725 |
+
_, student_embeddings = trainable_model(batch_tokens)
|
726 |
+
student_embeddings = student_embeddings.to(torch.float32)
|
727 |
|
728 |
+
loss = mse_loss(student_embeddings, batch_targets)
|
729 |
+
val_loss += loss.item()
|
730 |
+
val_batches += 1
|
731 |
|
732 |
+
avg_val_loss = val_loss / val_batches if val_batches > 0 else 0.0
|
|
|
|
|
733 |
|
734 |
+
logger.info(f"Epoch {epoch + 1}/{epochs} - Train: {avg_train_loss:.6f}, Val: {avg_val_loss:.6f}")
|
|
|
735 |
|
736 |
+
# Save checkpoint
|
737 |
+
if checkpoint_manager:
|
738 |
+
checkpoint_data = {
|
739 |
+
"config_hash": get_current_config_hash(enable_training=True),
|
740 |
+
"stage": "training",
|
741 |
+
"step": epoch + 1,
|
742 |
+
"timestamp": time.time(),
|
743 |
+
"data": {
|
744 |
+
"model_state": trainable_model.state_dict(),
|
745 |
+
"optimizer_state": optimizer.state_dict(),
|
746 |
+
"train_loss": avg_train_loss,
|
747 |
+
"val_loss": avg_val_loss,
|
748 |
+
},
|
749 |
+
}
|
750 |
+
checkpoint_manager.save_checkpoint("training", checkpoint_data, epoch + 1)
|
751 |
|
752 |
+
# Convert back to static model
|
753 |
+
refined_model = trainable_model.to_static_model()
|
754 |
+
logger.info("✅ Advanced training completed")
|
755 |
|
756 |
+
return refined_model
|
757 |
|
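At its core, the loop above minimizes the MSE between student outputs and PCA-projected teacher embeddings; a compact sketch of that objective with stand-in tensors:

import torch
from torch import nn
from sklearn.decomposition import PCA

teacher = torch.randn(200, 768)                     # stand-in teacher embeddings
student = torch.randn(200, 64, requires_grad=True)  # stand-in student outputs

# Project teacher vectors down to the student dimension before comparing.
target = torch.tensor(
    PCA(n_components=64).fit_transform(teacher.numpy()),
    dtype=torch.float32,
)
loss = nn.MSELoss()(student, target)
loss.backward()  # gradients flow only into the student side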
758 |
|
759 |
+
def distill_single_teacher(
|
760 |
+
teacher_model: str,
|
761 |
+
enable_training: bool = False,
|
762 |
+
use_beam_utilities: bool = False,
|
763 |
+
pca_dims: int | None = None,
|
764 |
+
) -> dict[str, Any]:
|
765 |
+
"""
|
766 |
+
Distill a single teacher model with optional training.
|
767 |
|
768 |
+
Args:
|
769 |
+
teacher_model: Name of teacher model
|
770 |
+
enable_training: Whether to enable advanced training
|
771 |
+
use_beam_utilities: Whether to use Beam utilities
|
772 |
+
pca_dims: PCA dimensions
|
773 |
|
774 |
+
Returns:
|
775 |
+
Dictionary with distillation results
|
776 |
+
"""
|
777 |
+
teacher_name = teacher_model.split("/")[-1].replace("-", "_")
|
778 |
+
base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
|
779 |
+
|
780 |
+
# Add suffix for trained models
|
781 |
+
final_model_name = f"code_model2vec_{teacher_name}"
|
782 |
+
if enable_training:
|
783 |
+
final_model_name += "_fine_tuned"
|
784 |
+
final_dir = Path(LOCAL_FINAL_DIR) / final_model_name
|
785 |
+
|
786 |
+
logger.info(f"\n{'=' * 60}")
|
787 |
+
logger.info(f"🔄 Processing teacher model: {teacher_model}")
|
788 |
+
logger.info(f"📁 Teacher name: {teacher_name}")
|
789 |
+
logger.info(f"🎓 Training enabled: {enable_training}")
|
790 |
+
logger.info(f"{'=' * 60}")
|
791 |
+
|
792 |
+
# Check model compatibility first
|
793 |
+
is_compatible, warning_msg = check_model_compatibility(teacher_model)
|
794 |
+
if not is_compatible:
|
795 |
+
logger.warning(f"⚠️ Known compatibility issue: {warning_msg}")
|
796 |
+
logger.info("🔧 Attempting distillation anyway, but may fail...")
|
797 |
+
|
798 |
+
# Try model-specific workarounds
|
799 |
+
workaround_type = try_model_workarounds(teacher_model)
|
800 |
+
# Don't skip if we have a workaround - we'll use it later
|
801 |
|
802 |
+
start_time = time.time()
|
803 |
|
804 |
+
# Initialize Beam utilities if requested
|
805 |
+
checkpoint_mgr = None
|
806 |
+
model_mgr = None
|
807 |
+
if use_beam_utilities:
|
808 |
+
try:
|
809 |
+
_, checkpoint_mgr, model_mgr, _ = create_beam_utilities(VOLUME_CONFIG.name, VOLUME_CONFIG.mount_path)
|
810 |
+
except Exception as e:
|
811 |
+
logger.warning(f"Failed to initialize Beam utilities: {e}")
|
812 |
|
813 |
+
try:
|
814 |
+
# Step 1: Check for existing final model
|
815 |
+
existing_final = check_existing_final_model(teacher_name, enable_training)
|
816 |
+
if existing_final:
|
817 |
+
logger.info(f"✅ Final model already exists: {teacher_name}{'_fine_tuned' if enable_training else ''}")
|
818 |
+
return {
|
819 |
+
"teacher_model": teacher_model,
|
820 |
+
"teacher_name": teacher_name,
|
821 |
+
"status": "skipped_existing_final",
|
822 |
+
"final_path": existing_final,
|
823 |
+
"distillation_time": 0.0,
|
824 |
+
}
|
825 |
|
826 |
+
# Step 1.5: Sync existing checkpoints from Beam if using Beam utilities
|
827 |
+
if use_beam_utilities and checkpoint_mgr:
|
828 |
+
logger.info(f"🔄 Syncing existing checkpoints for {teacher_name}...")
|
829 |
+
sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints)
|
830 |
+
if enable_training:
|
831 |
+
sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints)
|
832 |
+
|
833 |
+
# Step 2: Check for existing base model or create it
|
834 |
+
existing_base = check_existing_base_model(teacher_name)
|
835 |
+
base_model = None
|
836 |
+
|
837 |
+
if existing_base:
|
838 |
+
logger.info(f"✅ Found existing base model: {teacher_name}")
|
839 |
+
if enable_training:
|
840 |
+
# Load base model for training
|
841 |
+
from model2vec.model import StaticModel
|
842 |
|
843 |
+
base_model = StaticModel.from_pretrained(existing_base)
|
844 |
+
elif use_beam_utilities:
|
845 |
+
synced = sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities)
|
846 |
+
if synced:
|
847 |
+
existing_base = str(base_dir)
|
848 |
+
if enable_training:
|
849 |
+
from model2vec.model import StaticModel
|
850 |
|
851 |
+
base_model = StaticModel.from_pretrained(existing_base)
|
|
|
|
|
852 |
|
853 |
+
if not existing_base:
|
854 |
+
# Perform simple distillation to create base model
|
855 |
+
logger.info(f"🔄 Creating base model for {teacher_name}")
|
856 |
|
857 |
+
# Check if we need specialized distillation
|
858 |
+
workaround_type = try_model_workarounds(teacher_model)
|
|
|
859 |
|
860 |
+
if workaround_type == "salesforce":
|
861 |
+
base_model = salesforce_model_distillation(teacher_model, str(base_dir), pca_dims)
|
862 |
+
elif workaround_type == "baai":
|
863 |
+
base_model = baai_bge_model_distillation(teacher_model, str(base_dir), pca_dims)
|
864 |
+
else:
|
865 |
+
base_model = simple_distillation(teacher_model, str(base_dir), pca_dims)
|
866 |
+
|
867 |
+
if base_model is None:
|
868 |
+
return {
|
869 |
+
"teacher_model": teacher_model,
|
870 |
+
"teacher_name": teacher_name,
|
871 |
+
"status": "failed_base_distillation",
|
872 |
+
"error": "Simple distillation failed",
|
873 |
}
|
874 |
|
875 |
+
# Sync base model and checkpoints to Beam
|
876 |
+
if use_beam_utilities:
|
877 |
+
sync_model_to_beam(teacher_name, str(base_dir), use_beam_utilities)
|
878 |
+
if checkpoint_mgr:
|
879 |
+
sync_checkpoints_to_beam(
|
880 |
+
VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints
|
881 |
+
)
|
882 |
+
|
883 |
+
existing_base = str(base_dir)
|
884 |
+
|
885 |
+
# Step 3: Handle final model creation
|
886 |
+
if enable_training and base_model is not None:
|
887 |
+
# Perform advanced training
|
888 |
+
logger.info(f"🎓 Starting advanced training for {teacher_name}")
|
889 |
+
|
890 |
+
# Load teacher model for training
|
891 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
892 |
+
teacher_st_model = SentenceTransformer(teacher_model, device=device, trust_remote_code=True)
|
893 |
|
894 |
+
# Perform advanced training
|
895 |
+
final_model = advanced_training(base_model, teacher_st_model, checkpoint_mgr)
|
896 |
|
897 |
+
# Save final model
|
898 |
+
final_dir.mkdir(parents=True, exist_ok=True)
|
899 |
+
final_model.save_pretrained(str(final_dir))
|
900 |
|
901 |
+
# Sync final model and training checkpoints to Beam
|
902 |
+
if use_beam_utilities:
|
903 |
+
sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities)
|
904 |
+
if checkpoint_mgr:
|
905 |
+
sync_checkpoints_to_beam(VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints)
|
906 |
+
|
907 |
+
del teacher_st_model
|
908 |
+
if torch.cuda.is_available():
|
909 |
+
torch.cuda.empty_cache()
|
910 |
|
911 |
+
else:
|
912 |
+
# Copy base to final (no training)
|
913 |
+
logger.info(f"📁 Copying base to final for {teacher_name}")
|
914 |
+
if not copy_base_to_final(teacher_name, enable_training):
|
915 |
+
return {
|
916 |
+
"teacher_model": teacher_model,
|
917 |
+
"teacher_name": teacher_name,
|
918 |
+
"status": "failed_copy_to_final",
|
919 |
+
"error": "Failed to copy base to final",
|
920 |
+
}
|
921 |
|
922 |
+
total_time = time.time() - start_time
|
923 |
|
924 |
+
return {
|
925 |
+
"teacher_model": teacher_model,
|
926 |
+
"teacher_name": teacher_name,
|
927 |
+
"status": "success",
|
928 |
+
"enable_training": enable_training,
|
929 |
+
"base_path": existing_base,
|
930 |
+
"final_path": str(final_dir),
|
931 |
+
"distillation_time": total_time,
|
932 |
+
}
|
933 |
|
934 |
except Exception as e:
|
935 |
+
logger.exception(f"❌ Failed to process {teacher_model}")
|
936 |
+
return {
|
937 |
+
"teacher_model": teacher_model,
|
938 |
+
"teacher_name": teacher_name,
|
939 |
+
"status": "failed",
|
940 |
+
"error": str(e),
|
941 |
+
}
|
942 |
|
943 |
|
944 |
+
# =============================================================================
|
945 |
+
# MAIN EXECUTION FUNCTIONS
|
946 |
+
# =============================================================================
|
947 |
|
948 |
|
949 |
+
def run_local_distillation(
|
950 |
+
teacher_models: list[str] | None = None,
|
951 |
+
enable_training: bool = False,
|
952 |
+
pca_dims: int | None = None,
|
953 |
+
clear_cache: bool = False,
|
954 |
+
) -> dict[str, Any]:
|
955 |
+
"""Run distillation locally."""
|
956 |
+
logger.info("🖥️ Running distillation locally")
|
957 |
|
958 |
+
if teacher_models is None:
|
959 |
+
teacher_models = DEFAULT_TEACHER_MODELS
|
960 |
|
961 |
+
# Apply patches
|
962 |
+
patch_success = apply_local_patches()
|
963 |
+
if patch_success:
|
964 |
+
logger.info("✅ Successfully applied patches")
|
965 |
+
else:
|
966 |
+
logger.warning("⚠️ Failed to apply patches - some models may fail")
|
967 |
+
|
968 |
+
results = {}
|
969 |
+
successful_models = []
|
970 |
+
|
971 |
+
logger.info("🚀 Starting distillation workflow")
|
972 |
+
logger.info(f"📊 Processing {len(teacher_models)} teacher models")
|
973 |
+
logger.info(f"🎓 Training enabled: {enable_training}")
|
974 |
+
|
975 |
+
# Use default models if none specified
|
976 |
+
models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS
|
977 |
+
|
978 |
+
logger.info(f"📊 Teacher models to process: {len(models_to_distill)}")
|
979 |
+
for i, model in enumerate(models_to_distill, 1):
|
980 |
+
logger.info(f" {i}. {model}")
|
981 |
+
|
982 |
+
# Clear cache for problematic models if requested
|
983 |
+
if clear_cache:
|
984 |
+
logger.info("🧹 Clearing cache for known problematic models...")
|
985 |
+
problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"]
|
986 |
+
for model in problematic_models:
|
987 |
+
if model in models_to_distill:
|
988 |
+
clear_model_cache(model)
|
989 |
+
|
990 |
+
for teacher_model in models_to_distill:
|
991 |
+
result = distill_single_teacher(
|
992 |
+
teacher_model=teacher_model,
|
993 |
+
enable_training=enable_training,
|
994 |
+
use_beam_utilities=False,
|
995 |
+
pca_dims=pca_dims,
|
996 |
+
)
|
997 |
|
998 |
+
teacher_name = result["teacher_name"]
|
999 |
+
results[teacher_name] = result
|
1000 |
+
|
1001 |
+
if result["status"] == "success" or result["status"].startswith("skipped"):
|
1002 |
+
successful_models.append(teacher_name)
|
1003 |
+
|
1004 |
+
# Summary
|
1005 |
+
logger.info("\n🏆 DISTILLATION WORKFLOW COMPLETE!")
|
1006 |
+
logger.info(f"📊 Successful models: {len(successful_models)}")
|
1007 |
+
logger.info(f"🎓 Training mode: {'Enabled' if enable_training else 'Basic distillation only'}")
|
1008 |
+
|
1009 |
+
for model_name in successful_models:
|
1010 |
+
result = results[model_name]
|
1011 |
+
logger.info(f"✅ {model_name}: {result['teacher_model']}")
|
1012 |
+
|
1013 |
+
# Save results summary
|
1014 |
+
results_summary = {
|
1015 |
+
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
1016 |
+
"enable_training": enable_training,
|
1017 |
+
"successful_models": successful_models,
|
1018 |
+
"all_results": results,
|
1019 |
+
"total_successful": len(successful_models),
|
1020 |
+
"total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
|
1021 |
+
}
|
1022 |
|
1023 |
+
# Save results to file
|
1024 |
+
results_file = Path(LOCAL_BASE_DIR).parent / "distillation_results.json"
|
1025 |
+
results_file.parent.mkdir(parents=True, exist_ok=True)
|
1026 |
+
with results_file.open("w") as f:
|
1027 |
+
json.dump(results_summary, f, indent=2)
|
1028 |
|
1029 |
+
logger.info(f"📊 Results summary saved to: {results_file}")
|
|
|
|
|
1030 |
|
1031 |
+
return results_summary
|
1032 |
|
|
|
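The summary written above is plain JSON, so a run can be inspected without importing the package. A sketch, assuming the default layout where `distillation_results.json` lands in the parent of `LOCAL_BASE_DIR` (adjust the path to your checkout):

import json
from pathlib import Path

results_file = Path("distillation_results.json")  # written next to LOCAL_BASE_DIR
summary = json.loads(results_file.read_text())
for name in summary["successful_models"]:
    info = summary["all_results"][name]
    print(f"{name}: {info['teacher_model']} -> {info['status']}")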
+@function(
+    gpu=GPU_NAME,
+    volumes=[Volume(name=VOLUME_CONFIG.name, mount_path=VOLUME_CONFIG.mount_path)],
+    image=IMAGE,
+    secrets=["HF_ACCESS_TOKEN"],
+    env=BEAM_ENV_SETTINGS,
+    timeout=3600 * 12,  # 12 hours
+)
+def _beam_distill_models(
+    teacher_models: list[str] | None = None,
+    enable_training: bool = False,
+    pca_dims: int | None = None,
+    clear_cache: bool = False,
+) -> dict[str, Any]:
+    """Internal Beam function for distillation."""
+    logger.info("☁️ Running distillation on Beam")
+
+    # Apply patches
+    patch_success = apply_local_patches()
+    if patch_success:
+        logger.info("✅ Successfully applied patches")
+    else:
+        logger.warning("⚠️ Failed to apply patches - some models may fail")
+
+    if teacher_models is None:
+        teacher_models = DEFAULT_TEACHER_MODELS
+
+    # Clear cache for problematic models if requested
+    if clear_cache:
+        logger.info("🧹 Clearing cache for known problematic models...")
+        problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"]
+        for model in problematic_models:
+            if model in teacher_models:
+                clear_model_cache(model)
+
+    results = {}
+    successful_models = []
+
+    logger.info("🚀 Starting Beam distillation workflow")
+    logger.info(f"📊 Processing {len(teacher_models)} teacher models")
+    logger.info(f"🎓 Training enabled: {enable_training}")
+
+    # Use default models if none specified
+    models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS
+
+    logger.info(f"📊 Teacher models to process: {len(models_to_distill)}")
+    for i, model in enumerate(models_to_distill, 1):
+        logger.info(f"  {i}. {model}")
+
+    for teacher_model in models_to_distill:
+        result = distill_single_teacher(
+            teacher_model=teacher_model,
+            enable_training=enable_training,
+            use_beam_utilities=True,
+            pca_dims=pca_dims,
+        )
+
+        teacher_name = result["teacher_name"]
+        results[teacher_name] = result
+
+        if result["status"] == "success" or result["status"].startswith("skipped"):
+            successful_models.append(teacher_name)
+
+    # Summary
+    logger.info("\n🏆 BEAM DISTILLATION WORKFLOW COMPLETE!")
+    logger.info(f"📊 Successful models: {len(successful_models)}")
+
+    # Save results to Beam volume
+    volume_path = Path(VOLUME_CONFIG.mount_path)
+    results_summary = {
+        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
+        "enable_training": enable_training,
+        "successful_models": successful_models,
+        "all_results": results,
+        "total_successful": len(successful_models),
+        "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
+    }
+
+    results_file = volume_path / "distillation_results.json"
+    with results_file.open("w") as f:
+        json.dump(results_summary, f, indent=2)
+
+    logger.info(f"📊 Beam results saved to: {results_file}")
+
+    return results_summary
+
+
+def run_beam_distillation(
+    teacher_models: list[str] | None = None,
+    enable_training: bool = False,
+    pca_dims: int | None = None,
+    clear_cache: bool = False,
+) -> dict[str, Any]:
+    """Run distillation on Beam and sync results."""
+    logger.info("☁️ Running distillation on Beam with local sync")
+
+    try:
+        # Run distillation on Beam
+        results = _beam_distill_models.remote(teacher_models, enable_training, pca_dims, clear_cache)
+
+        # Check if Beam execution was successful
+        if not results:
+            logger.error("❌ Beam execution failed or returned no results")
+            return {
+                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
+                "enable_training": enable_training,
+                "successful_models": [],
+                "all_results": {},
+                "total_successful": 0,
+                "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
+                "error": "Beam execution failed",
+            }
+
+        # Sync models back to local directories
+        if results.get("successful_models"):
+            logger.info("📥 Syncing models from Beam to local directories...")
+
+            for teacher_name in results["successful_models"]:
+                # Sync base model
+                base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
+                sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities=True)
+
+                # Sync final model if training was enabled
+                if enable_training:
+                    final_dir = Path(LOCAL_FINAL_DIR) / f"code_model2vec_{teacher_name}"
+                    sync_model_from_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities=True)
+                else:
+                    # Copy base to final
+                    copy_base_to_final(teacher_name, enable_training)
+
+            logger.info("✅ All models synced from Beam")
+
+        return results

     except Exception as e:
+        logger.exception("❌ Beam distillation failed with exception")
+        return {
+            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
+            "enable_training": enable_training,
+            "successful_models": [],
+            "all_results": {},
+            "total_successful": 0,
+            "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
+            "error": str(e),
+        }


+# =============================================================================
+# CLI INTERFACE
+# =============================================================================
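An aside before the CLI wiring, for readers unfamiliar with Beam: the pattern used by `_beam_distill_models` above is to decorate a plain function with `@function(...)` to declare its GPU, image, volumes, and timeout, then trigger a cloud run with `.remote(...)`, which returns the function's result. A stripped-down sketch using only constructs that appear in this repository; the resource values are illustrative:

from beam import GpuType, Image, Volume, function

@function(
    gpu=GpuType.A100_40,                       # GPU class for the job
    image=Image(python_version="python3.12"),  # container image with the deps
    volumes=[Volume(name="code_model2vec", mount_path="./code_model2vec")],
    timeout=3600,                              # seconds
)
def remote_job(x: int) -> int:
    return x * 2

if __name__ == "__main__":
    print(remote_job.remote(21))  # executes remotely; returns the value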
+def main(
+    use_beam: Annotated[bool, typer.Option(help="Use Beam for distillation")] = False,
+    train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False,
+    teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None,
+    pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None,
+    clear_cache: Annotated[
+        bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation")
+    ] = False,
+) -> None:
+    """Unified distillation command with optional training."""
+    logger.info("🚀 Starting unified Model2Vec distillation workflow")
+    logger.info(f"🎓 Training mode: {'Advanced (CodeSearchNet fine-tuning)' if train else 'Basic distillation only'}")
+    logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")
+
+    # Use default models if none specified
+    models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS
+
+    logger.info(f"📊 Teacher models to process: {len(models_to_distill)}")
+    for i, model in enumerate(models_to_distill, 1):
+        logger.info(f"  {i}. {model}")
+
+    # Clear cache for problematic models if requested
+    if clear_cache:
+        logger.info("🧹 Clearing cache for known problematic models...")
+        problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"]
+        for model in problematic_models:
+            if model in models_to_distill:
+                clear_model_cache(model)
+
+    # Run distillation workflow
+    if use_beam:
+        results = run_beam_distillation(
+            teacher_models=models_to_distill,
+            enable_training=train,
+            pca_dims=pca_dims,
+            clear_cache=clear_cache,
+        )
+    else:
+        results = run_local_distillation(
+            teacher_models=models_to_distill,
+            enable_training=train,
+            pca_dims=pca_dims,
+            clear_cache=clear_cache,
+        )
+
+    # Handle case where results might be None or invalid
+    if not results or not isinstance(results, dict):
+        logger.error("❌ Distillation workflow failed - no valid results returned")
+        results = {
+            "total_successful": 0,
+            "total_attempted": len(models_to_distill),
+            "error": "Workflow failed",
+        }
+
+    # Final summary
+    successful_count = results.get("total_successful", 0)
+    total_attempted = results.get("total_attempted", 0)
+
+    logger.info("\n🎉 UNIFIED DISTILLATION WORKFLOW COMPLETED!")
+    logger.info(f"📊 Successfully processed: {successful_count}/{total_attempted} models")
+    logger.info(f"📁 Base models saved to: {LOCAL_BASE_DIR}")
+    logger.info(f"📁 Final models saved to: {LOCAL_FINAL_DIR}")
+
+    if train:
+        logger.info("🎓 Advanced training was enabled - models include CodeSearchNet specialization")
+    else:
+        logger.info("📖 Basic distillation only - use --train flag to enable advanced training")
+
+
+def check_model_compatibility(teacher_model: str) -> tuple[bool, str | None]:
+    """
+    Check if a model has known compatibility issues with Model2Vec.
+
+    Returns:
+        Tuple of (is_compatible, warning_message)
+    """
+    known_incompatible = {
+        "BAAI/bge-code-v1": "Qwen2Tokenizer lacks backend_tokenizer attribute",
+        "jinaai/jina-embeddings-v3": "Missing custom transformers module dependencies",
+        "Salesforce/SFR-Embedding-Code-2B_R": "Device placement issues with meta tensors",
     }

+    if teacher_model in known_incompatible:
+        return False, known_incompatible[teacher_model]
+
+    # Check for model families that might have issues
+    if "qwen2" in teacher_model.lower() and "bge" in teacher_model.lower():
+        return False, "BGE models with Qwen2 tokenizers may have compatibility issues"
+
+    if "jina" in teacher_model.lower() and "embeddings-v3" in teacher_model.lower():
+        return False, "Jina embeddings v3 models may have missing dependencies"
+
+    if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower():
+        return False, "Salesforce SFR embedding models may have device placement issues"
+
+    return True, None
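A typical use of this helper is to screen a candidate list before spending GPU time on it; a short sketch (the model names are examples drawn from this file):

candidates = ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-code-v1"]
for name in candidates:
    ok, warning = check_model_compatibility(name)
    if not ok:
        print(f"skipping {name}: {warning}")  # or route to a specialized path
        continue
    print(f"queueing {name}")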
+
+
+def clear_model_cache(model_name: str) -> bool:
+    """Clear HuggingFace cache for a specific model."""
     try:
+        import shutil
+        from pathlib import Path
+
+        # Get HuggingFace cache directory
+        cache_dir = Path.home() / ".cache" / "huggingface"
+
+        # Find model-specific cache directories
+        model_slug = model_name.replace("/", "--")
+
+        # Clear transformers cache
+        transformers_cache = cache_dir / "transformers" / model_slug
+        if transformers_cache.exists():
+            shutil.rmtree(transformers_cache)
+            logger.info(f"🗑️ Cleared transformers cache for {model_name}")
+
+        # Clear hub cache
+        hub_cache = cache_dir / "hub" / f"models--{model_slug}"
+        if hub_cache.exists():
+            shutil.rmtree(hub_cache)
+            logger.info(f"🗑️ Cleared hub cache for {model_name}")
+
+        # Clear modules cache
+        modules_cache = cache_dir / "modules" / "transformers_modules" / model_name.split("/")[0]
+        if modules_cache.exists():
+            shutil.rmtree(modules_cache)
+            logger.info(f"🗑️ Cleared modules cache for {model_name}")
+
+        return True
+
     except Exception as e:
+        logger.warning(f"Failed to clear cache for {model_name}: {e}")
+        return False
+
+
+def try_model_workarounds(teacher_model: str) -> str | None:
+    """
+    Try specific workarounds for problematic models.
+
+    Returns:
+        The type of workaround needed ("salesforce", "baai", etc.) or None if no workaround available
+    """
+    if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower():
+        logger.info("🔧 Salesforce SFR model detected - will use specialized distillation")
+        return "salesforce"
+
+    if "baai" in teacher_model.lower() and ("bge-code" in teacher_model.lower() or "bge-m3" in teacher_model.lower()):
+        logger.info("🔧 BAAI BGE model detected - will use specialized distillation")
+        return "baai"
+
+    return None
+
+
+def salesforce_model_distillation(
+    teacher_model: str,
+    output_dir: str,
+    pca_dims: int | None = None,
 ) -> Any:
+    """Special distillation function for Salesforce SFR models that handles device placement issues."""
+    if pca_dims is None:
+        pca_dims = int(distillation_config.optimal_pca_dims)
+
     output_path = Path(output_dir)
     output_path.mkdir(parents=True, exist_ok=True)

+    logger.info(f"🔄 Salesforce-specific distillation: {teacher_model} → {output_dir}")
+    logger.info(f"📊 PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}")

     start_time = time.time()

+    try:
+        import torch
+        from sentence_transformers import SentenceTransformer
+        from transformers import AutoModel, AutoTokenizer
+
+        # Enhanced custom model loading for Salesforce models
+        logger.info("🔧 Loading model with enhanced device settings...")
+
+        # Method 1: Try with to_empty() for meta tensor handling
+        try:
+            logger.info("🔄 Attempting with to_empty() method...")
+
+            # Load tokenizer first
+            tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True)
+
+            # Load model with meta device initially
+            model = AutoModel.from_pretrained(
+                teacher_model,
                 trust_remote_code=True,
+                torch_dtype=torch.float16,
+                device_map="meta",  # Load on meta device first
             )
+
+            # Move from meta to actual device using to_empty()
+            if torch.cuda.is_available():
+                device = torch.device("cuda")
+                # Create empty tensors on target device and copy weights
+                model = model.to_empty(device=device)
+            else:
+                device = torch.device("cpu")
+                model = model.to_empty(device=device)
+
+            # Ensure model is in the right dtype
+            model = model.to(torch.float16 if torch.cuda.is_available() else torch.float32)
+
+            logger.info("✅ Successfully loaded with to_empty() method")
+
+        except Exception as e:
+            logger.warning(f"to_empty() method failed: {e}")
+
+            # Method 2: Try SentenceTransformer with specific settings
+            logger.info("🔄 Falling back to SentenceTransformer method...")
+            sentence_model = SentenceTransformer(
+                teacher_model,
                 trust_remote_code=True,
+                device="cpu",  # Force CPU loading first
             )
+
+            # Move to GPU if available
+            if torch.cuda.is_available():
+                sentence_model = sentence_model.to("cuda")
+
+            # Extract components
+            model = sentence_model[0].auto_model
+            tokenizer = sentence_model.tokenizer
+
+            logger.info("✅ Successfully loaded with SentenceTransformer method")
+
+        # Now use Model2Vec's distill_from_model function directly
+        from model2vec.distill.distillation import distill_from_model
+
+        distilled_model = distill_from_model(
+            model=model,
+            tokenizer=tokenizer,
+            pca_dims=int(pca_dims),
+            apply_zipf=bool(distillation_config.apply_zipf),
+            sif_coefficient=float(distillation_config.sif_coefficient),
+        )
+
+        logger.info("✅ Core distillation completed successfully")
+
+        # Save the model
+        distilled_model.save_pretrained(str(output_path))
+        logger.info(f"💾 Model saved to {output_path}")
+
+        # Log model info
+        logger.info(f"Model type: {type(distilled_model)}")
+        if hasattr(distilled_model, "embedding"):
+            logger.info(f"Embedding shape: {distilled_model.embedding.shape}")
+            logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}")
+
+        total_time = time.time() - start_time
+        logger.info(f"🎉 Salesforce distillation completed in {total_time:.2f} seconds")
+
+        # Clean up
+        if "sentence_model" in locals():
+            del sentence_model
+        del model
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return distilled_model
+
+    except Exception:
+        logger.exception(f"❌ Salesforce-specific distillation failed for {teacher_model}")
+        return None
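Whichever loading path succeeds, the saved artifact is a standard Model2Vec static model, so a quick smoke test can run without the teacher. A sketch: `StaticModel` is Model2Vec's documented loader, and the directory is whatever `output_dir` was passed above (the path below is illustrative):

import numpy as np
from model2vec import StaticModel

model = StaticModel.from_pretrained("./code_model2vec_distilled")  # hypothetical path
emb = model.encode(["def add(a, b): return a + b", "add two numbers"])
a, b = emb
print(float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b))))  # cosine similarity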
+def baai_bge_model_distillation(
+    teacher_model: str,
+    output_dir: str,
+    pca_dims: int | None = None,
+) -> Any:
+    """Special distillation function for BAAI BGE models that handles Qwen2Tokenizer compatibility issues."""
+    if pca_dims is None:
+        pca_dims = int(distillation_config.optimal_pca_dims)
+
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    logger.info(f"🔄 BAAI BGE-specific distillation: {teacher_model} → {output_dir}")
+    logger.info(f"📊 PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}")
+
+    start_time = time.time()
+
+    try:
+        import torch
+        from sentence_transformers import SentenceTransformer
+        from transformers import AutoModel, AutoTokenizer
+
+        logger.info("🔧 Loading BAAI model with tokenizer workaround...")
+
+        # Try multiple approaches for BAAI models
+        success = False
+
+        # Method 1: Try SentenceTransformer first (often handles tokenizer issues better)
+        try:
+            logger.info("🔄 Attempting with SentenceTransformer wrapper...")
+            sentence_model = SentenceTransformer(teacher_model, trust_remote_code=True)
+
+            # Extract components
+            model = sentence_model[0].auto_model
+            tokenizer = sentence_model.tokenizer
+
+            # Test if tokenizer works by encoding a simple text
+            test_encoding = tokenizer.encode("test", return_tensors="pt")
+            logger.info("✅ SentenceTransformer method successful")
+            success = True
+
+        except Exception as e:
+            logger.warning(f"SentenceTransformer method failed: {e}")
+
+            # Method 2: Try direct loading with tokenizer replacement
+            try:
+                logger.info("🔄 Attempting with tokenizer replacement...")
+                from transformers import BertTokenizerFast
+
+                # Load model directly
+                model = AutoModel.from_pretrained(teacher_model, trust_remote_code=True)
+
+                # Try to use a compatible tokenizer instead
+                try:
+                    # First try the original tokenizer
+                    tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True)
+                except Exception:
+                    # Fall back to BERT tokenizer for BGE models
+                    logger.info("🔄 Falling back to BERT tokenizer...")
+                    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+
+                logger.info("✅ Tokenizer replacement method successful")
+                success = True
+
+            except Exception as e2:
+                logger.warning(f"Tokenizer replacement method failed: {e2}")
+
+        if not success:
+            logger.error("❌ All BAAI model loading methods failed")
+            return None
+
+        # Now use Model2Vec's distill_from_model function directly
+        from model2vec.distill.distillation import distill_from_model
+
+        distilled_model = distill_from_model(
+            model=model,
+            tokenizer=tokenizer,
+            pca_dims=int(pca_dims),
+            apply_zipf=bool(distillation_config.apply_zipf),
+            sif_coefficient=float(distillation_config.sif_coefficient),
+        )
+
+        logger.info("✅ Core distillation completed successfully")
+
+        # Save the model
+        distilled_model.save_pretrained(str(output_path))
+        logger.info(f"💾 Model saved to {output_path}")
+
+        # Log model info
+        logger.info(f"Model type: {type(distilled_model)}")
+        if hasattr(distilled_model, "embedding"):
+            logger.info(f"Embedding shape: {distilled_model.embedding.shape}")
+            logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}")
+
+        total_time = time.time() - start_time
+        logger.info(f"🎉 BAAI BGE distillation completed in {total_time:.2f} seconds")
+
+        # Clean up
+        if "sentence_model" in locals():
+            del sentence_model
+        del model
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return distilled_model
+
+    except Exception:
+        logger.exception(f"❌ BAAI BGE-specific distillation failed for {teacher_model}")
+        return None
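One caveat with the BERT-tokenizer fallback above: a swapped-in tokenizer must not emit token ids beyond the teacher's embedding table, or the distillation lookup will index out of range. A self-contained sanity check, sketched with standard transformers attributes (the checkpoint name is illustrative):

from transformers import AutoConfig, BertTokenizerFast

teacher = "BAAI/bge-code-v1"  # illustrative
config = AutoConfig.from_pretrained(teacher, trust_remote_code=True)
fallback = BertTokenizerFast.from_pretrained("bert-base-uncased")
if len(fallback) > config.vocab_size:
    print(f"fallback vocab {len(fallback)} exceeds model vocab {config.vocab_size}")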

 if __name__ == "__main__":
+    typer.run(main)
src/distiller/distill_simplified.py
DELETED
@@ -1,413 +0,0 @@
-"""
-Simplified Code-Specialized Model2Vec Distillation Script.
-
-This script implements a focused, simplified approach for creating code-specialized embeddings
-using only the core Model2Vec distillation without additional fine-tuning that may degrade quality.
-
-Can run locally or on Beam with the --use-beam flag.
-"""
-
-import argparse
-import json
-import logging
-import sys
-import time
-from pathlib import Path
-from typing import Any
-
-from beam import GpuType, Image, Volume, function
-from model2vec.distill import distill
-
-# =============================================================================
-# SIMPLIFIED CONFIGURATION
-# =============================================================================
-
-# Use a code-specialized teacher model instead of general instruction model
-# Ordered by success likelihood and performance:
-CODE_TEACHER_MODELS = [
-    "sentence-transformers/all-MiniLM-L6-v2",
-    "sentence-transformers/all-mpnet-base-v2",
-    "microsoft/codebert-base",
-    "microsoft/graphcodebert-base",
-    "sentence-transformers/paraphrase-MiniLM-L6-v2",
-    "Alibaba-NLP/gte-Qwen2-7B-instruct",
-]
-
-OUTPUT_BASE_DIR = "code_model2vec"
-
-# Optimal Model2Vec parameters based on successful models
-OPTIMAL_PCA_DIMS = 256  # Match other successful Model2Vec models
-SIF_COEFFICIENT = 1e-3  # Slightly higher than default for code specialization
-APPLY_ZIPF = True  # Enable Zipf weighting for better word importance
-
-# =============================================================================
-# BEAM CONFIGURATION
-# =============================================================================
-
-GPU_NAME = GpuType.A100_40
-VOLUME_NAME = "code_model2vec"
-VOLUME_PATH = "./code_model2vec"
-IMAGE = Image(python_version="python3.12").add_python_packages(
-    [
-        "torch>=2.7.0",  # Install torch first
-        "transformers>=4.40.0",  # Latest transformers with flash attention support
-        "lightning>=2.5.1.post0",
-        "model2vec[train]>=0.5.0",
-        "numpy>=1.26.4",
-        "scikit-learn>=1.6.1",
-        "sentence-transformers>=4.1.0",
-        "datasets>=3.2.0",  # For evaluation
-        "pandas>=2.0.0",
-        "tqdm>=4.65.0",
-    ]
-)
-
-# =============================================================================
-
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-logger = logging.getLogger(__name__)
-
-# Add beam utilities for proper model persistence
-try:
-    from .beam_utils import (
-        create_beam_utilities,
-    )
-
-    BEAM_UTILS_AVAILABLE = True
-except ImportError:
-    print("Beam utilities not available - models will only be saved locally")
-    BEAM_UTILS_AVAILABLE = False
-
-
-def apply_local_patches() -> bool:
-    """Apply patches locally without requiring Beam utilities."""
-    try:
-        # Try using patch_utils if available
-        try:
-            from .patch_utils import apply_all_patches
-
-            patches_applied = apply_all_patches()
-            logger.info(f"Successfully applied {patches_applied} patches via patch_utils")
-            return True
-        except ImportError:
-            logger.warning("patch_utils not available, trying direct patching")
-
-        return False
-
-    except Exception as e:
-        logger.warning(f"Failed to apply patches: {e}")
-        return False
-
-
-def simplified_code_distillation(
-    teacher_model: str,
-    output_dir: str,
-    pca_dims: int = OPTIMAL_PCA_DIMS,
-) -> Any:
-    """
-    Simplified code-specialized distillation using only core Model2Vec.
-
-    This approach:
-    1. Uses a teacher model that already performs well on code tasks
-    2. Applies optimal Model2Vec parameters
-    3. Avoids additional training that may degrade quality
-    """
-    output_path = Path(output_dir)
-    output_path.mkdir(parents=True, exist_ok=True)
-
-    logger.info(f"Starting simplified distillation from {teacher_model}")
-    logger.info(f"Target dimensions: {pca_dims}")
-    logger.info(f"SIF coefficient: {SIF_COEFFICIENT}")
-    logger.info(f"Zipf weighting: {APPLY_ZIPF}")
-
-    start_time = time.time()
-
-    try:
-        # Perform distillation with optimal parameters
-        model = distill(
-            model_name=teacher_model,
-            pca_dims=pca_dims,
-            apply_zipf=APPLY_ZIPF,
-            sif_coefficient=SIF_COEFFICIENT,
-            trust_remote_code=True,
-        )
-
-        logger.info("✅ Core distillation completed successfully")
-
-        # Save the model
-        model.save_pretrained(str(output_path))
-        logger.info(f"💾 Model saved to {output_path}")
-
-        # Log model info
-        logger.info(f"Model type: {type(model)}")
-        if hasattr(model, "embedding"):
-            logger.info(f"Embedding shape: {model.embedding.shape}")
-            logger.info(f"Embedding dtype: {model.embedding.dtype}")
-
-        total_time = time.time() - start_time
-        logger.info(f"🎉 Simplified distillation completed in {total_time:.2f} seconds")
-        return model
-
-    except ValueError as e:
-        if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
-            logger.warning(f"⚠️ Token-vector mismatch with {teacher_model} - this is a Model2Vec library issue")
-            logger.warning(f"Error details: {e}")
-            logger.warning("💡 This model has incompatible tokenization. Skipping...")
-            return None
-        raise
-    except Exception:
-        logger.exception("❌ Distillation failed")
-        return None
-
-
-def core_distill_all_teachers(use_beam_utilities: bool = False) -> dict[str, Any]:
-    """
-    Core logic for distilling all teacher models.
-
-    Args:
-        use_beam_utilities: Whether to use Beam utilities for persistence
-
-    Returns:
-        Dictionary with distillation results
-    """
-    # Apply patches
-    logger.info("Applying all patches...")
-    patch_success = apply_local_patches()
-    if patch_success:
-        logger.info("Successfully applied patches")
-    else:
-        logger.warning("Failed to apply patches - Microsoft models may fail")
-
-    # Initialize Beam utilities if requested and available
-    volume_mgr = None
-    model_mgr = None
-    if use_beam_utilities and BEAM_UTILS_AVAILABLE:
-        try:
-            volume_mgr, _, model_mgr, _ = create_beam_utilities(VOLUME_NAME, VOLUME_PATH)
-            logger.info("✅ Beam utilities initialized for model persistence")
-        except Exception as e:
-            logger.warning(f"Failed to initialize Beam utilities: {e}")
-            model_mgr = None
-
-    results = {}
-    successful_models = []
-
-    logger.info("🚀 Starting comprehensive teacher model distillation")
-    logger.info(f"📊 Processing {len(CODE_TEACHER_MODELS)} teacher models")
-
-    # Determine output base path
-    base_output_path = VOLUME_PATH if use_beam_utilities else OUTPUT_BASE_DIR
-
-    for teacher_model in CODE_TEACHER_MODELS:
-        try:
-            # Create output directory name based on teacher model
-            teacher_name = teacher_model.split("/")[-1].replace("-", "_")
-            output_dir = f"{base_output_path}/final/code_model2vec_{teacher_name}"
-
-            logger.info(f"\n{'=' * 60}")
-            logger.info(f"🔄 Processing teacher model: {teacher_model}")
-            logger.info(f"📁 Output directory: {output_dir}")
-            logger.info(f"{'=' * 60}")
-
-            # Check if model already exists
-            output_path = Path(output_dir)
-            if output_path.exists():
-                # Check for essential model files
-                has_config = (output_path / "config.json").exists()
-                has_model_file = any(
-                    [
-                        (output_path / "model.safetensors").exists(),
-                        (output_path / "model.bin").exists(),
-                        (output_path / "pytorch_model.bin").exists(),
-                    ]
-                )
-
-                if has_config and has_model_file:
-                    logger.info(f"✅ Model {teacher_name} already exists - skipping distillation")
-
-                    # Still record it as successful
-                    model_info = {
-                        "teacher_model": teacher_model,
-                        "output_dir": output_dir,
-                        "teacher_name": teacher_name,
-                        "distillation_time": 0.0,
-                        "status": "skipped_existing",
-                    }
-
-                    results[teacher_name] = model_info
-                    successful_models.append(teacher_name)
-                    logger.info(f"📁 Using existing model at: {output_dir}")
-                    continue
-
-            # Perform distillation
-            start_time = time.time()
-            model = simplified_code_distillation(
-                teacher_model=teacher_model,
-                output_dir=output_dir,
-            )
-            distill_time = time.time() - start_time
-
-            if model is not None:
-                logger.info(f"✅ Distillation successful for {teacher_model}")
-
-                # Save to Beam volume for persistence if available
-                if model_mgr:
-                    try:
-                        # Save model to beam volume with teacher-specific name
-                        beam_model_name = f"{teacher_name}_model"
-                        model_mgr.save_model(beam_model_name, output_dir)
-                        logger.info(f"💾 Saved {teacher_name} to Beam volume as {beam_model_name}")
-                    except Exception as e:
-                        logger.warning(f"Failed to save {teacher_name} to Beam volume: {e}")
-
-                # Store results
-                model_info = {
-                    "teacher_model": teacher_model,
-                    "output_dir": output_dir,
-                    "teacher_name": teacher_name,
-                    "distillation_time": distill_time,
-                    "status": "success",
-                }
-
-                results[teacher_name] = model_info
-                successful_models.append(teacher_name)
-
-                logger.info(f"💾 Model saved to: {output_dir}")
-
-        except Exception as e:
-            logger.exception(f"❌ Failed with {teacher_model}")
-            results[teacher_model.split("/")[-1]] = {
-                "teacher_model": teacher_model,
-                "status": "failed",
-                "error": str(e),
-            }
-            continue
-
-    # Summary
-    if successful_models:
-        logger.info("\n🏆 DISTILLATION COMPLETE!")
-        logger.info(f"📊 Successful models: {len(successful_models)}")
-
-        for model_name in successful_models:
-            model_info = results[model_name]
-            logger.info(f"✅ {model_name}: {model_info['teacher_model']}")
-
-        # Save comprehensive results
-        results_summary = {
-            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
-            "successful_models": successful_models,
-            "all_results": results,
-            "total_successful": len(successful_models),
-            "total_attempted": len(CODE_TEACHER_MODELS),
-        }
-
-        # Save results to file
-        results_file = Path(f"{base_output_path}/distillation_results.json")
-        results_file.parent.mkdir(parents=True, exist_ok=True)
-        with results_file.open("w") as f:
-            json.dump(results_summary, f, indent=2)
-
-        logger.info(f"📊 Results summary saved to: {results_file}")
-
-        return results_summary
-
-    logger.error("❌ No models succeeded")
-    msg = "All teacher models failed distillation"
-    raise RuntimeError(msg)
-
-
-def run_local_distillation() -> dict[str, Any]:
-    """Run distillation locally without Beam."""
-    logger.info("🖥️ Running simplified distillation locally")
-    return core_distill_all_teachers(use_beam_utilities=False)
-
-
-@function(
-    gpu=GPU_NAME,
-    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
-    image=IMAGE,
-    secrets=["HF_ACCESS_TOKEN"],
-    env={
-        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:512",
-        "TOKENIZERS_PARALLELISM": "false",
-        "CUDA_LAUNCH_BLOCKING": "0",  # Allow async CUDA operations
-        "TORCH_CUDNN_V8_API_ENABLED": "1",  # Enable optimized cuDNN
-    },
-    timeout=3600 * 12,  # 12 hours
-)
-def beam_distill_all_teachers() -> dict[str, Any]:
-    """
-    Beam version: Try all teacher models and create distilled models from each.
-
-    Returns information about all models that were successfully created.
-    """
-    logger.info("☁️ Running simplified distillation on Beam")
-    return core_distill_all_teachers(use_beam_utilities=True)
-
-
-def main() -> None:
-    """Main function with argument parsing."""
-    global OUTPUT_BASE_DIR  # Declare global at the top  # noqa: PLW0603
-
-    parser = argparse.ArgumentParser(
-        description="Simplified Code-Specialized Model2Vec Distillation",
-        formatter_class=argparse.RawDescriptionHelpFormatter,
-        epilog="""
-Examples:
-  python -m src.distiller.distill_simplified             # Run locally
-  python -m src.distiller.distill_simplified --use-beam  # Run on Beam
-  distiller distill-simple                               # CLI shortcut (runs on Beam)
-        """,
-    )
-
-    parser.add_argument(
-        "--use-beam",
-        action="store_true",
-        help="Run on Beam instead of locally",
-    )
-
-    parser.add_argument(
-        "--output-dir",
-        type=str,
-        default=OUTPUT_BASE_DIR,
-        help=f"Output directory for models (default: {OUTPUT_BASE_DIR})",
-    )
-
-    args = parser.parse_args()
-
-    # Update output directory if specified
-    if args.output_dir != OUTPUT_BASE_DIR:
-        OUTPUT_BASE_DIR = args.output_dir
-
-    try:
-        if args.use_beam:
-            logger.info("🚀 Starting Beam execution...")
-            results = beam_distill_all_teachers()
-        else:
-            logger.info("🖥️ Starting local execution...")
-            results = run_local_distillation()
-
-        # Print final summary
-        print("\n🎉 Distillation complete!")
-        print(f"📊 Successfully created {results['total_successful']} models")
-
-        if args.use_beam:
-            print(f"📁 Models location: {VOLUME_PATH}/final/")
-        else:
-            print(f"📁 Models location: {OUTPUT_BASE_DIR}/final/")
-
-        print("\n✅ Created models:")
-        for model_name in results["successful_models"]:
-            model_info = results["all_results"][model_name]
-            print(f"  • {model_name} (from {model_info['teacher_model']})")
-
-    except KeyboardInterrupt:
-        logger.info("🛑 Distillation interrupted by user")
-        sys.exit(1)
-    except Exception:
-        logger.exception("❌ Distillation failed with error")
-        sys.exit(1)
-
-
-if __name__ == "__main__":
-    main()
src/distiller/evaluate.py
CHANGED
@@ -1,130 +1,405 @@
 """
-

-This script evaluates embedding models on
-
-

 Usage:
-    distiller evaluate  # Run evaluation
 """

 import json
 import logging
 import time
 from pathlib import Path
 from typing import Any

 import numpy as np
 import pandas as pd
-
 from datasets import Dataset, load_dataset
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
 from tqdm import tqdm

-from .beam_utils import
-
-
-
 )

-# Configure logging
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)

 # =============================================================================
-#
 # =============================================================================

-
-
-
-
-
-
-IMAGE = Image(python_version="python3.12").add_python_packages(
-    [
-        "torch>=2.7.0",
-        "transformers>=4.40.0",
-        "datasets>=3.2.0",
-        "sentence-transformers>=4.1.0",
-        "model2vec[train]>=0.5.0",
-        "numpy>=1.26.4",
-        "scikit-learn>=1.6.1",
-        "pandas>=2.0.0",
-        "tqdm>=4.65.0",
-    ]
-)

 # =============================================================================
-#
 # =============================================================================

-
-
-
-
-
-
-
-#
-"
-
-
-
-"
-
-
-
-
-]

-# =============================================================================
-# CHECKPOINT CONFIGURATION
-# =============================================================================

-
-
-DATASET_CHECKPOINT_PREFIX = "dataset_cache"
-MODEL_CACHE_PREFIX = "model_cache"

-
-
-
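The hunks below rewrite the evaluator class. Because the removed `_compute_retrieval_metrics` body survives only in fragments in this diff, here is a self-contained sketch of the computation its returned keys imply: for each query the matching code sits at the same index, so the rank of the diagonal similarity drives MRR, recall@k, and NDCG@k. This is an illustration consistent with those keys, not the verbatim removed code:

import numpy as np

def retrieval_metrics(similarities: np.ndarray) -> dict[str, float]:
    """similarities[i, j] = score of code j for query i; the true match is j == i."""
    n = similarities.shape[0]
    order = np.argsort(-similarities, axis=1)  # best-first ranking per query
    ranks = np.array([int(np.where(order[i] == i)[0][0]) + 1 for i in range(n)])
    ndcg10 = np.where(ranks <= 10, 1.0 / np.log2(ranks + 1), 0.0)
    return {
        "mrr": float(np.mean(1.0 / ranks)),
        "ndcg@10": float(np.mean(ndcg10)),
        "recall@1": float(np.mean(ranks == 1)),
        "recall@5": float(np.mean(ranks <= 5)),
        "recall@10": float(np.mean(ranks <= 10)),
        "mean_rank": float(np.mean(ranks)),
        "median_rank": float(np.median(ranks)),
    }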
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
|
99 |
class CodeSearchNetEvaluator:
|
100 |
"""Evaluator for CodeSearchNet-style code search tasks."""
|
101 |
|
102 |
-
def __init__(
|
103 |
-
|
104 |
-
model_path: str,
|
105 |
-
model_name: str | None = None,
|
106 |
-
checkpoint_manager: BeamCheckpointManager | None = None,
|
107 |
-
eval_manager: BeamEvaluationManager | None = None,
|
108 |
-
) -> None:
|
109 |
-
"""Initialize the evaluator with a model and optional Beam utilities."""
|
110 |
self.model_path = model_path
|
111 |
self.model_name = model_name or Path(model_path).name
|
112 |
self.model: SentenceTransformer | None = None
|
113 |
-
self.checkpoint_manager = checkpoint_manager
|
114 |
-
self.eval_manager = eval_manager
|
115 |
self._load_model()
|
116 |
|
117 |
def _load_model(self) -> None:
|
118 |
-
"""Load the embedding model
|
119 |
logger.info(f"Loading model from {self.model_path}")
|
120 |
-
|
121 |
-
# Check if we have a cached evaluation result for this model
|
122 |
-
if self.eval_manager:
|
123 |
-
cached_result = self.eval_manager.load_evaluation_results(self.model_name)
|
124 |
-
if cached_result:
|
125 |
-
logger.info(f"✅ Found cached evaluation results for {self.model_name}")
|
126 |
-
# Note: We still need to load the model for new evaluations
|
127 |
-
|
128 |
try:
|
129 |
self.model = SentenceTransformer(self.model_path, trust_remote_code=True)
|
130 |
logger.info(f"Successfully loaded model: {self.model_name}")
|
@@ -139,7 +414,6 @@ class CodeSearchNetEvaluator:
|
|
139 |
raise RuntimeError(msg)
|
140 |
|
141 |
embeddings = []
|
142 |
-
|
143 |
for i in tqdm(range(0, len(texts), BATCH_SIZE), desc=desc):
|
144 |
batch = texts[i : i + BATCH_SIZE]
|
145 |
batch_embeddings = self.model.encode(batch, convert_to_tensor=False, normalize_embeddings=True)
|
@@ -148,33 +422,25 @@ class CodeSearchNetEvaluator:
|
|
148 |
return np.vstack(embeddings)
|
149 |
|
150 |
def evaluate_language(self, language: str, max_queries: int = 1000) -> dict[str, Any]:
|
151 |
-
"""Evaluate on a specific programming language
|
152 |
logger.info(f"Evaluating on {language} language (max {max_queries} queries)")
|
153 |
|
154 |
-
# Check for existing evaluation checkpoint
|
155 |
-
if self.checkpoint_manager:
|
156 |
-
cached_result = self.checkpoint_manager.load_checkpoint(f"{EVAL_CHECKPOINT_PREFIX}_{language}", 0)
|
157 |
-
if cached_result and cached_result.get("data", {}).get("model_name") == self.model_name:
|
158 |
-
logger.info(f"✅ Resuming from cached {language} evaluation")
|
159 |
-
return cached_result.get("data", {})
|
160 |
-
|
161 |
try:
|
162 |
# Load test split for the language
|
163 |
dataset = load_dataset(
|
164 |
-
|
165 |
language,
|
166 |
split="test",
|
167 |
trust_remote_code=True,
|
168 |
)
|
169 |
|
170 |
-
# Ensure we have a Dataset object
|
171 |
if not isinstance(dataset, Dataset):
|
172 |
logger.error(f"Unexpected dataset type for {language}: {type(dataset)}")
|
173 |
return {}
|
174 |
|
175 |
-
# Sample queries for evaluation
|
176 |
if len(dataset) > max_queries:
|
177 |
-
rng = np.random.default_rng(42)
|
178 |
indices = rng.choice(len(dataset), max_queries, replace=False)
|
179 |
dataset = dataset.select(indices)
|
180 |
|
@@ -198,642 +464,680 @@ class CodeSearchNetEvaluator:
|
|
198 |
logger.info(f"Found {len(queries)} valid query-code pairs for {language}")
|
199 |
|
200 |
# Encode queries and codes
|
|
|
201 |
query_embeddings = self.encode_texts(queries, f"Encoding {language} queries")
|
202 |
-
code_embeddings = self.encode_texts(codes, f"Encoding {language}
|
|
|
203 |
|
204 |
-
# Compute similarities
|
205 |
similarities = cosine_similarity(query_embeddings, code_embeddings)
|
206 |
-
|
207 |
-
# Evaluate retrieval metrics
|
208 |
metrics = self._compute_retrieval_metrics(similarities)
|
209 |
|
210 |
-
|
|
|
211 |
"language": language,
|
|
|
212 |
"num_queries": len(queries),
|
|
|
213 |
"metrics": metrics,
|
214 |
-
"model_name": self.model_name,
|
215 |
}
|
216 |
|
217 |
-
|
218 |
-
|
219 |
-
checkpoint_data = {
|
220 |
-
"data": result,
|
221 |
-
"timestamp": time.time(),
|
222 |
-
"config": {
|
223 |
-
"language": language,
|
224 |
-
"max_queries": max_queries,
|
225 |
-
"model_name": self.model_name,
|
226 |
-
},
|
227 |
-
}
|
228 |
-
self.checkpoint_manager.save_checkpoint(f"{EVAL_CHECKPOINT_PREFIX}_{language}", checkpoint_data, 0)
|
229 |
-
logger.info(f"💾 Saved {language} evaluation checkpoint")
|
230 |
-
|
231 |
-
return result
|
232 |
|
233 |
except Exception:
|
234 |
-
logger.exception(f"
|
235 |
return {}
|
236 |
|
237 |
def _compute_retrieval_metrics(self, similarities: np.ndarray) -> dict[str, float]:
|
238 |
-
"""Compute retrieval metrics
|
239 |
-
|
240 |
|
241 |
-
# For each query, the correct code is at the same index
|
242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
reciprocal_ranks = []
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
np.mean([self._compute_ndcg(np.argsort(similarities[i])[::-1], i, k=5) for i in range(num_queries)])
|
268 |
-
),
|
269 |
-
"ndcg@10": float(np.mean(ndcg_scores)),
|
270 |
-
"recall@1": float(np.mean([1.0 if rank == 1 else 0.0 for rank in ranks])),
|
271 |
-
"recall@5": float(np.mean([1.0 if rank <= 5 else 0.0 for rank in ranks])),
|
272 |
-
"recall@10": float(np.mean([1.0 if rank <= 10 else 0.0 for rank in ranks])),
|
273 |
-
"mean_rank": float(np.mean(ranks)),
|
274 |
-
"median_rank": float(np.median(ranks)),
|
275 |
-
}
|
276 |
|
277 |
def _compute_ndcg(self, ranked_indices: np.ndarray, correct_idx: int, k: int) -> float:
|
278 |
"""Compute NDCG@k for a single query."""
|
279 |
-
if
|
280 |
-
|
281 |
-
|
282 |
-
# Find position of correct item in top-k
|
283 |
-
top_k = ranked_indices[:k]
|
284 |
-
if correct_idx in top_k:
|
285 |
-
position = np.where(top_k == correct_idx)[0][0]
|
286 |
-
return 1.0 / np.log2(position + 2) # +2 because log2(1) is 0
|
287 |
return 0.0
|
288 |
|
289 |
def evaluate_all_languages(
|
290 |
self, max_queries_per_lang: int = 1000, languages: list[str] | None = None
|
291 |
) -> dict[str, Any]:
|
292 |
-
"""Evaluate on all
|
293 |
-
|
294 |
-
languages = EVALUATION_LANGUAGES
|
295 |
-
|
296 |
-
logger.info(f"Starting evaluation on all languages for model: {self.model_name}")
|
297 |
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
if cached_comprehensive:
|
302 |
-
logger.info(f"✅ Found comprehensive cached evaluation for {self.model_name}")
|
303 |
-
return cached_comprehensive
|
304 |
|
305 |
start_time = time.time()
|
306 |
-
|
307 |
-
results: dict[str, Any] = {
|
308 |
"model_name": self.model_name,
|
309 |
"model_path": self.model_path,
|
310 |
-
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
311 |
"languages": {},
|
312 |
"overall": {},
|
|
|
313 |
}
|
|
|
314 |
|
315 |
-
|
|
|
|
|
|
|
|
|
316 |
|
317 |
-
for language in languages:
|
318 |
-
logger.info(f"Evaluating {language}...")
|
319 |
lang_results = self.evaluate_language(language, max_queries_per_lang)
|
320 |
-
|
321 |
if lang_results:
|
322 |
-
|
323 |
-
all_metrics.append(lang_results["metrics"])
|
324 |
-
else:
|
325 |
-
logger.warning(f"Skipping {language} due to evaluation error")
|
326 |
|
327 |
-
|
328 |
-
|
|
|
|
|
329 |
overall_metrics = {}
|
330 |
-
|
331 |
-
|
332 |
-
|
|
|
|
|
333 |
|
334 |
results["overall"] = overall_metrics
|
335 |
|
336 |
total_time = time.time() - start_time
|
337 |
results["evaluation_time_seconds"] = total_time
|
338 |
|
339 |
-
# Save comprehensive results to Beam volume
|
340 |
-
if self.eval_manager:
|
341 |
-
self.eval_manager.save_evaluation_results(self.model_name, results)
|
342 |
-
logger.info("💾 Saved comprehensive evaluation results to Beam volume")
|
343 |
-
|
344 |
logger.info(f"Evaluation completed in {total_time:.2f} seconds")
|
345 |
return results
|
346 |
|
347 |
|
348 |
-
|
349 |
-
"""
|
350 |
-
try:
|
351 |
-
df = pd.read_csv(peers_file)
|
352 |
-
models = []
|
353 |
-
for _, row in df.iterrows():
|
354 |
-
model_name = row.get("model_name", row.get("Model", ""))
|
355 |
-
model_path = row.get("model_path", row.get("Path", model_name))
|
356 |
-
if model_name:
|
357 |
-
models.append((model_name, model_path))
|
358 |
-
logger.info(f"Loaded {len(models)} peer models from {peers_file}")
|
359 |
-
return models
|
360 |
-
except Exception:
|
361 |
-
logger.exception("Error loading peer models from {peers_file}")
|
362 |
-
return []
|
363 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|

-def save_results(
-    results: dict[str, Any],
-    output_dir: str,
-    model_name: str,
-    eval_manager: BeamEvaluationManager | None = None,
-    volume_results_dir: Path | None = None,
-) -> None:
-    """Save evaluation results to JSON file with Beam volume support."""
-    # Save to Beam volume if available
-    if volume_results_dir:
-        volume_output_path = volume_results_dir / f"codesearchnet_eval_{model_name}.json"

-        if success:
-            logger.info(f"💾 Results also saved via eval_manager for {model_name}")
-        else:
-            logger.warning(f"⚠️ Failed to save via eval_manager for {model_name}")

-    output_path.mkdir(parents=True, exist_ok=True)

-    logger.info(f"📄 Local backup saved to {filepath}")


-def print_results_summary(results: dict[str, Any]) -> None:
-    """Print a summary of evaluation results."""
-    model_name = results["model_name"]
-    overall = results.get("overall", {})

     if not all_results:
         return

     output_path = Path(output_dir)

-    df = df.sort_values("NDCG@10", ascending=False)  # Sort by NDCG@10


     models: list[str],
     max_queries: int = 1000,
     languages: list[str] | None = None,
-    mount_path: str = VOLUME_PATH,
 ) -> list[dict[str, Any]]:

     for model_path in models:
         model_name = Path(model_path).name

-        # Check for existing evaluation results
-        existing_result_file = results_dir / f"codesearchnet_eval_{model_name}.json"
-        if existing_result_file.exists():
-            logger.info(f"✅ Model {model_name} already evaluated - loading existing results")
-            try:
-                with existing_result_file.open("r") as f:
-                    existing_results = json.load(f)
-                all_results.append(existing_results)
-                skipped_models.append(model_name)
-                continue
-            except Exception as e:
-                logger.warning(f"⚠️ Failed to load existing results for {model_name}: {e}")
-                # Continue with evaluation if loading fails

         logger.info(f"\n{'=' * 60}")
         logger.info(f"🔍 Evaluating model: {model_name}")
-        logger.info(f"📂 Path: {model_path}")
         logger.info(f"{'=' * 60}")

         try:
-            (
-                "/" in model_path and not model_path.startswith("/") and not Path(model_path).exists()
-            )
-                model_path,
-                model_name,
-                checkpoint_manager=checkpoint_mgr,
-                eval_manager=eval_mgr,
-            )
-            else:
-                # This is a local path - check if it exists in Beam volume
-                actual_model_path = model_path  # Default to original path
-                if not Path(model_path).exists() and not model_path.startswith("/"):
-                    # Try to load from Beam volume
-                    local_model_path = Path(mount_path) / MODEL_CACHE_PREFIX / model_name
-                    logger.info(f"🔍 Trying to load {model_name} from Beam volume: {local_model_path}")
-                    if model_mgr.load_model(model_name, local_model_path.parent):
-                        actual_model_path = str(local_model_path)
-                        logger.info(f"✅ Loaded model from Beam volume: {actual_model_path}")
-                    else:
-                        logger.warning(f"⚠️ Model not found locally or in Beam volume: {model_name}")
-                        continue

-                evaluator = CodeSearchNetEvaluator(
-                    actual_model_path,
-                    model_name,
-                    checkpoint_manager=checkpoint_mgr,
-                    eval_manager=eval_mgr,
-                )

-            results = evaluator.evaluate_all_languages(max_queries, languages)

-            # Save results with Beam support
-            save_results(results, output_dir, model_name, eval_mgr, results_dir)

-            # Print summary
-            print_results_summary(results)

-            all_results.append(results)

         except Exception:
             logger.exception(f"❌ Failed to evaluate {model_name}")
             continue

-    if len(all_results) > 1:
-        comparison_dir = Path(mount_path) / EVALUATION_RESULTS_DIR / "comparisons"
-        comparison_dir.mkdir(parents=True, exist_ok=True)
-        create_comparison_report(all_results, str(comparison_dir))
-        logger.info(f"📊 Comparison report saved to Beam volume: {comparison_dir}")

-    # Log summary of what was done
-    newly_evaluated = len(all_results) - len(skipped_models)
-    logger.info("\n✅ Beam evaluation complete!")
-    logger.info(f"📊 Newly evaluated: {newly_evaluated} models")
-    logger.info(f"⏭️ Skipped (already done): {len(skipped_models)} models")
-    logger.info(f"📁 Total results: {len(all_results)} models")
-    logger.info(f"💾 Results available in Beam volume: {volume_name}")

-    if skipped_models:
-        logger.info(f"⏭️ Skipped models: {', '.join(skipped_models)}")

-    return all_results
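
Note: one detail of the deleted branch worth recording: remote Hugging Face Hub ids were distinguished from local paths with a heuristic along these lines (the function name below is hypothetical; only the boolean expression survives in the diff):

    from pathlib import Path

    def looks_like_hub_id(model_path: str) -> bool:
        # "org/name" ids contain "/", are not absolute paths, and do not exist on disk
        return "/" in model_path and not model_path.startswith("/") and not Path(model_path).exists()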


 @function(
     gpu=GPU_NAME,
-    volumes=[Volume(name=
     image=IMAGE,
     secrets=["HF_ACCESS_TOKEN"],
-    env=
-        "CUDA_LAUNCH_BLOCKING": "0",
-    },
-    timeout=3600 * 6,  # 6 hours for evaluation
 )
-    else:
-        logger.info("📊 Including 3rd party peer models for comparison")
-        models = DEFAULT_EVALUATION_MODELS.copy()

-    # Discover simplified distillation models in the current directory
-    logger.info("🔍 Discovering simplified distillation models...")
-    discovered_models = discover_simplified_models(".")

-    # Add discovered models (they're already sorted alphabetically)
-    if discovered_models:
-        logger.info(f"✅ Found {len(discovered_models)} simplified models:")
-        for model_path in discovered_models:
-            models.append(model_path)
-            logger.info(f"  📁 {model_path}")
-    else:
-        logger.warning("⚠️ No simplified distillation models found")
-        if skip_third_party:
-            logger.error("❌ No models to evaluate! Either create simplified models or include 3rd party models.")
-            return

-    for i, model in enumerate(models, 1):
-        logger.info(f"  {i}. {model}")

-        max_queries=1000,
-        languages=EVALUATION_LANGUAGES,
-        output_dir=str(Path(VOLUME_PATH) / EVALUATION_RESULTS_DIR),
-        volume_name=VOLUME_NAME,
-        mount_path=VOLUME_PATH,
-    )

-    print(f"📊 Total models processed: {len(results)}")
-    print(f"💾 Results saved to Beam volume: {VOLUME_NAME}")
-    print(f"📁 Directory: {EVALUATION_RESULTS_DIR}")
-    if skip_third_party:
-        print("⏭️ 3rd party models were skipped")
-    print("\n🔍 To view analysis:")
-    print("  beam run src.distiller.analyze:beam_analysis")
-    print("\n📈 To run evaluations again:")
-    print("  distiller evaluate  (will skip already completed models)")
-    print("  distiller evaluate --skip-third-party  (evaluate only simplified models)")

-def discover_simplified_models(base_path: str = ".") -> list[str]:
-    """
-    Discover all simplified distillation models in the correct directory.
-
-    Looks for directories matching the pattern: ./code_model2vec/final/code_model2vec_*
-    """
-    discovered_models: list[str] = []
-
-    # Look in the correct location where distill_simplified.py saves models
-    models_dir = Path(base_path) / "code_model2vec" / "final"
-
-    if not models_dir.exists():
-        logger.warning(f"Models directory not found: {models_dir}")
-        return discovered_models
-
-    # Look for simplified model directories with the updated pattern
-    pattern = "code_model2vec_*"
-    for model_dir in models_dir.glob(pattern):
-        if model_dir.is_dir() and (model_dir / "config.json").exists():
-            discovered_models.append(str(model_dir))
-            logger.info(f"🔍 Discovered simplified model: {model_dir}")
-
-    # Sort alphabetically for consistent ordering
-    discovered_models.sort()
-
-    return discovered_models


-@function(
-    image=IMAGE,
-    secrets=["HF_ACCESS_TOKEN"],
-    env={
-        "TOKENIZERS_PARALLELISM": "false",
-        "CUDA_LAUNCH_BLOCKING": "0",
-    },
-    timeout=3600 * 6,  # 6 hours for evaluation
-)
-def evaluate_simplified_only() -> None:
-    """Evaluate only simplified distillation models, skipping 3rd party models."""
-    main(skip_third_party=True)


-    models: list[str]
     max_queries: int = 1000,
     languages: list[str] | None = None,
 ) -> list[dict[str, Any]]:
-    if models is None:
-        models = DEFAULT_EVALUATION_MODELS.copy()

-    # Discover simplified distillation models in the current directory
-    logger.info("🔍 Discovering simplified distillation models...")
-    discovered_models = discover_simplified_models(".")

-    # Add discovered models
-    if discovered_models:
-        logger.info(f"✅ Found {len(discovered_models)} simplified models:")
-        for model_path in discovered_models:
-            models.append(model_path)
-            logger.info(f"  📁 {model_path}")
-    else:
-        logger.warning("⚠️ No simplified distillation models found")

-    if languages is None:
-        languages = EVALUATION_LANGUAGES

-    logger.info(f"📊 Evaluating {len(models)} models on {len(languages)} languages")
-    logger.info(f"📁 Using local output directory: {output_dir}")

-    # Create local output directory
-    output_path = Path(output_dir)
-    output_path.mkdir(parents=True, exist_ok=True)

-    all_results = []
-    skipped_models = []

     for model_path in models:
         model_name = Path(model_path).name

-        safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
-        result_file = output_path / f"codesearchnet_eval_{safe_name}.json"

-        if result_file.exists():
-            logger.info(f"✅ Model {model_name} already evaluated - loading existing results")
-            try:
-                with result_file.open("r") as f:
-                    existing_results = json.load(f)
-                all_results.append(existing_results)
-                skipped_models.append(model_name)
-                continue
-            except Exception as e:
-                logger.warning(f"⚠️ Failed to load existing results for {model_name}: {e}")

-        logger.info(f"\n{'=' * 60}")
-        logger.info(f"🔍 Evaluating model: {model_name}")
-        logger.info(f"📂 Path: {model_path}")
-        logger.info(f"{'=' * 60}")

         try:
-            evaluator = CodeSearchNetEvaluator(
-                model_path,
-                model_name,
-                checkpoint_manager=None,  # No checkpointing for local evaluation
-                eval_manager=None,
-            )

-            results = evaluator.evaluate_all_languages(max_queries, languages)

         except Exception:
             continue

-    if len(all_results) > 1:
-        create_comparison_report(all_results, output_dir)
-        logger.info(f"📊 Comparison report saved locally: {output_dir}")

-    # Log summary
-    newly_evaluated = len(all_results) - len(skipped_models)
-    logger.info("\n✅ Local evaluation complete!")
-    logger.info(f"📊 Newly evaluated: {newly_evaluated} models")
-    logger.info(f"⏭️ Skipped (already done): {len(skipped_models)} models")
-    logger.info(f"📁 Total results: {len(all_results)} models")
-    logger.info(f"💾 Results available locally: {output_dir}")

-    return all_results


-    languages: list[str] | None = None,
-    output_dir: str = DEFAULT_OUTPUT_DIR,
-) -> list[dict[str, Any]]:
-    """Local evaluation function for simplified models only."""
-    logger.info("🖥️ Running simplified model evaluation locally")

-        max_queries=max_queries,
-        languages=
     )


 if __name__ == "__main__":
-    main

1 |
"""
|
2 |
+
Comprehensive Model Evaluation Script for Code-Specialized Embedding Models.
|
3 |
|
4 |
+
This script evaluates embedding models on both task performance and operational metrics:
|
5 |
+
|
6 |
+
Task Performance:
|
7 |
+
- CodeSearchNet evaluation (NDCG, MRR, Recall metrics)
|
8 |
+
- Code search accuracy across programming languages
|
9 |
+
|
10 |
+
Operational Performance:
|
11 |
+
- Inference speed (latency and throughput)
|
12 |
+
- Memory efficiency (RAM and GPU usage)
|
13 |
+
- Model size and storage requirements
|
14 |
+
- CPU vs GPU performance scaling
|
15 |
|
16 |
Usage:
|
17 |
+
distiller evaluate [--use-beam] [--skip-benchmark] # Run evaluation locally or on Beam
|
18 |
"""
|
19 |
|
20 |
import json
|
21 |
import logging
|
22 |
import time
|
23 |
+
import traceback
|
24 |
from pathlib import Path
|
25 |
from typing import Any
|
26 |
|
27 |
import numpy as np
|
28 |
import pandas as pd
|
29 |
+
import psutil
|
30 |
+
import torch
|
31 |
+
import typer
|
32 |
+
from beam import Volume, function
|
33 |
from datasets import Dataset, load_dataset
|
34 |
from sentence_transformers import SentenceTransformer
|
35 |
from sklearn.metrics.pairwise import cosine_similarity
|
36 |
from tqdm import tqdm
|
37 |
|
38 |
+
from .beam_utils import download_specific_evaluation_file
|
39 |
+
from .config import (
|
40 |
+
BEAM_ENV_SETTINGS,
|
41 |
+
DEFAULT_EVALUATION_MODELS,
|
42 |
+
GPU_NAME,
|
43 |
+
IMAGE,
|
44 |
+
codesearchnet_config,
|
45 |
+
directories,
|
46 |
+
get_safe_model_name,
|
47 |
+
get_volume_config,
|
48 |
+
languages_config,
|
49 |
)
|
50 |
|
|
|
|
|
51 |
logger = logging.getLogger(__name__)
|
52 |
|
53 |
# =============================================================================
|
54 |
+
# EVALUATION CONFIGURATION
|
55 |
# =============================================================================
|
56 |
|
57 |
+
BATCH_SIZE = 32
|
58 |
+
LOCAL_EVALUATION_DIR = directories.evaluation_results
|
59 |
+
LOCAL_BENCHMARK_DIR = directories.benchmark_results
|
60 |
+
LOCAL_MODELS_DIR = directories.final
|
61 |
+
VOLUME_CONFIG = get_volume_config()
|
62 |
|
63 |
# =============================================================================
|
64 |
+
# CORE EVALUATION CLASSES
|
65 |
# =============================================================================
|
66 |
|
67 |
+
# Sample texts for benchmarking (various lengths)
|
68 |
+
BENCHMARK_TEXTS = {
|
69 |
+
"short": [
|
70 |
+
"def add(a, b): return a + b",
|
71 |
+
"function multiply(x, y) { return x * y; }",
|
72 |
+
"class Calculator { public int subtract(int a, int b) { return a - b; } }",
|
73 |
+
]
|
74 |
+
* 100, # 300 short texts
|
75 |
+
"medium": [
|
76 |
+
"def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
|
77 |
+
"function quickSort(arr) {\n if (arr.length <= 1) return arr;\n const pivot = arr[arr.length - 1];\n const left = [], right = [];\n for (let i = 0; i < arr.length - 1; i++) {\n if (arr[i] < pivot) left.push(arr[i]);\n else right.push(arr[i]);\n }\n return [...quickSort(left), pivot, ...quickSort(right)];\n}",
|
78 |
+
]
|
79 |
+
* 50, # 100 medium texts
|
80 |
+
"long": [
|
81 |
+
"""
|
82 |
+
def complex_algorithm(data, config):
|
83 |
+
'''
|
84 |
+
Complex data processing algorithm with multiple steps.
|
85 |
+
'''
|
86 |
+
results = []
|
87 |
+
# Data validation and processing steps...
|
88 |
+
return results
|
89 |
+
""".strip(),
|
90 |
+
]
|
91 |
+
* 20, # 20 long texts
|
92 |
+
}
|
93 |
|
|
|
|
|
|
|
94 |
|
95 |
+
class PerformanceBenchmark:
|
96 |
+
"""Comprehensive performance benchmarking for embedding models."""
|
|
|
|
|
97 |
|
98 |
+
def __init__(self, model_path: str, model_name: str | None = None) -> None:
|
99 |
+
"""Initialize benchmarker with model."""
|
100 |
+
self.model_path = model_path
|
101 |
+
self.model_name = model_name or Path(model_path).name
|
102 |
+
self.model: SentenceTransformer | None = None
|
103 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
104 |
+
self.results: dict[str, Any] = {}
|
105 |
+
|
106 |
+
def load_model(self) -> None:
|
107 |
+
"""Load the embedding model."""
|
108 |
+
logger.info(f"Loading model from {self.model_path}")
|
109 |
+
start_time = time.time()
|
110 |
+
|
111 |
+
try:
|
112 |
+
self.model = SentenceTransformer(self.model_path, device=self.device, trust_remote_code=True)
|
113 |
+
load_time = time.time() - start_time
|
114 |
+
|
115 |
+
logger.info(f"✅ Model loaded in {load_time:.2f}s on {self.device}")
|
116 |
+
self.results["model_load_time"] = load_time
|
117 |
+
|
118 |
+
except Exception:
|
119 |
+
logger.exception("❌ Failed to load model")
|
120 |
+
self.results["error"] = traceback.format_exc()
|
121 |
+
raise
|
122 |
+
|
123 |
+
def measure_model_size(self) -> dict[str, float]:
|
124 |
+
"""Measure model size metrics."""
|
125 |
+
logger.info("📏 Measuring model size...")
|
126 |
+
|
127 |
+
size_metrics: dict[str, Any] = {}
|
128 |
+
|
129 |
+
# Disk size
|
130 |
+
try:
|
131 |
+
if Path(self.model_path).is_dir():
|
132 |
+
# Local directory - calculate size of model files only
|
133 |
+
model_extensions = {".safetensors", ".bin", ".json", ".txt", ".tokenizer"}
|
134 |
+
total_size = 0
|
135 |
+
model_dir = Path(self.model_path)
|
136 |
+
|
137 |
+
for file_path in model_dir.rglob("*"):
|
138 |
+
if file_path.is_file() and (
|
139 |
+
file_path.suffix.lower() in model_extensions or "tokenizer" in file_path.name.lower()
|
140 |
+
):
|
141 |
+
total_size += file_path.stat().st_size
|
142 |
+
|
143 |
+
size_metrics["disk_size_mb"] = total_size / (1024 * 1024)
|
144 |
+
# HuggingFace model - estimate based on model parameters
|
145 |
+
elif self.model is not None:
|
146 |
+
param_count = sum(p.numel() for p in self.model.parameters())
|
147 |
+
# Rough estimate: 4 bytes per parameter (float32)
|
148 |
+
estimated_size = param_count * 4
|
149 |
+
size_metrics["disk_size_mb"] = estimated_size / (1024 * 1024)
|
150 |
+
else:
|
151 |
+
size_metrics["disk_size_mb"] = 0.0
|
152 |
+
|
153 |
+
except Exception as e:
|
154 |
+
logger.warning(f"⚠️ Could not calculate disk size: {e}")
|
155 |
+
size_metrics["disk_size_mb"] = 0.0
|
156 |
+
|
157 |
+
# Memory size (if model is loaded)
|
158 |
+
if self.model is not None:
|
159 |
+
try:
|
160 |
+
# Parameter count
|
161 |
+
param_count = sum(p.numel() for p in self.model.parameters())
|
162 |
+
size_metrics["parameter_count"] = param_count
|
163 |
+
size_metrics["parameters_millions"] = param_count / 1e6
|
164 |
+
|
165 |
+
# Memory usage estimate
|
166 |
+
param_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
|
167 |
+
buffer_size = sum(b.numel() * b.element_size() for b in self.model.buffers())
|
168 |
+
size_metrics["memory_size_mb"] = (param_size + buffer_size) / (1024 * 1024)
|
169 |
+
size_metrics["ram_usage_mb"] = size_metrics["memory_size_mb"]
|
170 |
+
|
171 |
+
# GPU memory if using CUDA
|
172 |
+
if self.device == "cuda" and torch.cuda.is_available():
|
173 |
+
size_metrics["gpu_memory_mb"] = torch.cuda.memory_allocated() / (1024 * 1024)
|
174 |
+
size_metrics["gpu_name"] = torch.cuda.get_device_name(0)
|
175 |
+
|
176 |
+
# Embedding dimension if available
|
177 |
+
if hasattr(self.model, "get_sentence_embedding_dimension"):
|
178 |
+
size_metrics["embedding_dim"] = self.model.get_sentence_embedding_dimension()
|
179 |
+
|
180 |
+
except Exception as e:
|
181 |
+
logger.warning(f"⚠️ Could not calculate memory size: {e}")
|
182 |
+
|
183 |
+
# Update results
|
184 |
+
self.results["size_metrics"] = size_metrics
|
185 |
+
return size_metrics
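
Note: the float32 sizing above (4 bytes per parameter) is easy to sanity-check in isolation. A minimal sketch, assuming sentence-transformers is installed and a model can be downloaded; the model id is illustrative:

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # illustrative model id
    param_count = sum(p.numel() for p in model.parameters())
    # element_size() generalizes the 4-bytes-per-float32 estimate to other dtypes
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    print(f"{param_count / 1e6:.1f}M params ~= {param_bytes / (1024 * 1024):.1f} MB resident")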
|
186 |
+
|
187 |
+
def benchmark_inference_speed(self, batch_sizes: list[int] | None = None) -> dict[str, Any]:
|
188 |
+
"""Benchmark inference speed with different batch sizes."""
|
189 |
+
if batch_sizes is None:
|
190 |
+
batch_sizes = [1, 8, 16, 32]
|
191 |
+
|
192 |
+
logger.info(f"⚡ Benchmarking inference speed with batch sizes: {batch_sizes}")
|
193 |
+
|
194 |
+
if self.model is None:
|
195 |
+
self.load_model()
|
196 |
+
|
197 |
+
speed_results: dict[str, Any] = {"medium": {}}
|
198 |
+
|
199 |
+
# Use medium-length texts for speed testing
|
200 |
+
test_texts = BENCHMARK_TEXTS["medium"]
|
201 |
+
|
202 |
+
for batch_size in batch_sizes:
|
203 |
+
logger.info(f" 📊 Testing batch size: {batch_size}")
|
204 |
+
|
205 |
+
# Prepare batch
|
206 |
+
batch = (
|
207 |
+
test_texts[:batch_size]
|
208 |
+
if batch_size <= len(test_texts)
|
209 |
+
else test_texts * ((batch_size // len(test_texts)) + 1)
|
210 |
+
)
|
211 |
+
batch = batch[:batch_size]
|
212 |
+
|
213 |
+
# Warmup
|
214 |
+
if self.model is not None:
|
215 |
+
_ = self.model.encode(batch[: min(4, len(batch))], convert_to_tensor=False)
|
216 |
+
|
217 |
+
# Benchmark multiple runs
|
218 |
+
latencies = []
|
219 |
+
num_runs = max(3, 20 // batch_size) # More runs for smaller batches
|
220 |
+
|
221 |
+
for _ in range(num_runs):
|
222 |
+
start_time = time.time()
|
223 |
+
if self.model is not None:
|
224 |
+
_ = self.model.encode(batch, convert_to_tensor=False, normalize_embeddings=True)
|
225 |
+
end_time = time.time()
|
226 |
+
latencies.append(end_time - start_time)
|
227 |
+
|
228 |
+
# Calculate metrics
|
229 |
+
avg_latency = sum(latencies) / len(latencies)
|
230 |
+
throughput = batch_size / avg_latency
|
231 |
+
time_per_text_ms = (avg_latency / batch_size) * 1000
|
232 |
+
|
233 |
+
batch_key = f"batch_{batch_size}"
|
234 |
+
speed_results["medium"][batch_key] = {
|
235 |
+
"time_per_text_ms": time_per_text_ms,
|
236 |
+
"texts_per_second": throughput,
|
237 |
+
"tokens_per_second": throughput * 50, # Estimate 50 tokens per text
|
238 |
+
}
|
239 |
+
|
240 |
+
logger.info(f" ⚡ Latency: {avg_latency:.3f}s, Throughput: {throughput:.1f} texts/sec")
|
241 |
+
|
242 |
+
# Update results
|
243 |
+
self.results["speed_benchmarks"] = speed_results
|
244 |
+
return speed_results
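
Note: the sweep reduces to throughput = batch_size / mean_latency and per-text latency = mean_latency / batch_size. A standalone sketch of the same timing loop, using time.perf_counter as a monotonic alternative to the time.time calls above; the helper name is hypothetical:

    import time

    def timed_throughput(encode, batch, num_runs: int = 5) -> tuple[float, float]:
        """Return (texts_per_second, ms_per_text) for a callable encode(batch)."""
        latencies = []
        for _ in range(num_runs):
            start = time.perf_counter()
            encode(batch)  # e.g. lambda b: model.encode(b, convert_to_tensor=False)
            latencies.append(time.perf_counter() - start)
        avg = sum(latencies) / len(latencies)
        return len(batch) / avg, (avg / len(batch)) * 1000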
|
245 |
+
|
246 |
+
def benchmark_memory_scaling(self, batch_sizes: list[int] | None = None) -> dict[str, Any]:
|
247 |
+
"""Benchmark memory usage scaling with batch size."""
|
248 |
+
if batch_sizes is None:
|
249 |
+
batch_sizes = [1, 8, 16, 32]
|
250 |
+
|
251 |
+
logger.info(f"🧠 Benchmarking memory scaling with batch sizes: {batch_sizes}")
|
252 |
+
|
253 |
+
if self.model is None:
|
254 |
+
self.load_model()
|
255 |
+
|
256 |
+
memory_results: dict[str, Any] = {}
|
257 |
+
test_texts = BENCHMARK_TEXTS["medium"]
|
258 |
+
|
259 |
+
for batch_size in batch_sizes:
|
260 |
+
logger.info(f" 📊 Testing memory with batch size: {batch_size}")
|
261 |
+
|
262 |
+
# Prepare batch
|
263 |
+
batch = (
|
264 |
+
test_texts[:batch_size]
|
265 |
+
if batch_size <= len(test_texts)
|
266 |
+
else test_texts * ((batch_size // len(test_texts)) + 1)
|
267 |
+
)
|
268 |
+
batch = batch[:batch_size]
|
269 |
+
|
270 |
+
# Clear GPU cache if using CUDA
|
271 |
+
if torch.cuda.is_available():
|
272 |
+
torch.cuda.empty_cache()
|
273 |
+
torch.cuda.reset_peak_memory_stats()
|
274 |
+
|
275 |
+
try:
|
276 |
+
# Run inference
|
277 |
+
if self.model is not None:
|
278 |
+
_ = self.model.encode(batch, convert_to_tensor=False)
|
279 |
+
|
280 |
+
# Measure peak memory
|
281 |
+
if torch.cuda.is_available():
|
282 |
+
peak_memory = torch.cuda.max_memory_allocated() / (1024 * 1024)
|
283 |
+
memory_per_text = peak_memory / batch_size
|
284 |
+
else:
|
285 |
+
# Use psutil for CPU memory (less accurate)
|
286 |
+
peak_memory = psutil.virtual_memory().used / (1024 * 1024)
|
287 |
+
memory_per_text = 0 # Can't accurately measure per-text on CPU
|
288 |
+
|
289 |
+
batch_key = f"batch_{batch_size}"
|
290 |
+
memory_results[batch_key] = {
|
291 |
+
"memory_used_mb": peak_memory,
|
292 |
+
"memory_per_text_mb": memory_per_text,
|
293 |
+
"oom": False,
|
294 |
+
}
|
295 |
+
|
296 |
+
logger.info(f" 🧠 Peak memory: {peak_memory:.1f}MB, Per text: {memory_per_text:.2f}MB")
|
297 |
+
|
298 |
+
except Exception as e:
|
299 |
+
logger.warning(f"⚠️ Memory benchmark failed for batch {batch_size}: {e}")
|
300 |
+
batch_key = f"batch_{batch_size}"
|
301 |
+
memory_results[batch_key] = {
|
302 |
+
"oom": True,
|
303 |
+
"error": str(e),
|
304 |
+
}
|
305 |
+
|
306 |
+
self.results["memory_benchmarks"] = memory_results
|
307 |
+
return memory_results
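
Note: the CUDA accounting pattern here is reset-then-peek. The same three calls isolated into a reusable sketch (helper name hypothetical):

    import torch

    def peak_cuda_mb(fn) -> float | None:
        """Run fn() and return peak CUDA memory in MB, or None on CPU-only machines."""
        if not torch.cuda.is_available():
            return None
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        fn()
        return torch.cuda.max_memory_allocated() / (1024 * 1024)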
|
308 |
+
|
309 |
+
def benchmark_cpu_vs_gpu(self) -> dict[str, Any]:
|
310 |
+
"""Compare CPU vs GPU performance."""
|
311 |
+
logger.info("⚖️ Benchmarking CPU vs GPU performance")
|
312 |
+
|
313 |
+
if not torch.cuda.is_available():
|
314 |
+
logger.warning("⚠️ CUDA not available - skipping GPU benchmark")
|
315 |
+
return {}
|
316 |
+
|
317 |
+
comparison_results: dict[str, Any] = {}
|
318 |
+
test_texts = BENCHMARK_TEXTS["medium"][:16] # Use 16 texts for comparison
|
319 |
+
|
320 |
+
for device in ["cpu", "cuda"]:
|
321 |
+
logger.info(f" 📊 Testing on {device.upper()}")
|
322 |
+
|
323 |
+
try:
|
324 |
+
model = SentenceTransformer(self.model_path, device=device, trust_remote_code=True)
|
325 |
+
|
326 |
+
# Warmup
|
327 |
+
_ = model.encode(test_texts[:4], convert_to_tensor=False)
|
328 |
+
|
329 |
+
# Benchmark
|
330 |
+
start_time = time.time()
|
331 |
+
_ = model.encode(test_texts, convert_to_tensor=False, normalize_embeddings=True)
|
332 |
+
end_time = time.time()
|
333 |
+
|
334 |
+
latency = end_time - start_time
|
335 |
+
throughput = len(test_texts) / latency
|
336 |
+
|
337 |
+
comparison_results[device] = {
|
338 |
+
"texts_per_second": throughput,
|
339 |
+
}
|
340 |
+
|
341 |
+
logger.info(f" ⚡ {device.upper()}: {latency:.3f}s, {throughput:.1f} texts/sec")
|
342 |
+
|
343 |
+
# Clean up
|
344 |
+
del model
|
345 |
+
if device == "cuda":
|
346 |
+
torch.cuda.empty_cache()
|
347 |
+
|
348 |
+
except Exception as e:
|
349 |
+
logger.warning(f"⚠️ Failed to benchmark {device}: {e}")
|
350 |
+
comparison_results[device] = {"error": str(e)}
|
351 |
+
|
352 |
+
# Calculate speedup
|
353 |
+
if "cpu" in comparison_results and "cuda" in comparison_results:
|
354 |
+
cpu_throughput = comparison_results["cpu"].get("texts_per_second", 0)
|
355 |
+
gpu_throughput = comparison_results["cuda"].get("texts_per_second", 0)
|
356 |
+
if cpu_throughput > 0:
|
357 |
+
speedup = gpu_throughput / cpu_throughput
|
358 |
+
comparison_results["gpu_speedup"] = speedup
|
359 |
+
logger.info(f" 🚀 GPU Speedup: {speedup:.1f}x")
|
360 |
+
|
361 |
+
self.results["cpu_vs_gpu"] = comparison_results
|
362 |
+
return comparison_results
|
363 |
+
|
364 |
+
def run_comprehensive_benchmark(self) -> dict[str, Any]:
|
365 |
+
"""Run all benchmarks and return comprehensive results."""
|
366 |
+
logger.info(f"🏁 Starting comprehensive benchmark for {self.model_name}")
|
367 |
+
|
368 |
+
# Model information
|
369 |
+
self.results["model_name"] = self.model_name
|
370 |
+
self.results["model_path"] = self.model_path
|
371 |
+
self.results["timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S")
|
372 |
+
|
373 |
+
# Run all benchmarks
|
374 |
+
try:
|
375 |
+
self.load_model()
|
376 |
+
self.measure_model_size()
|
377 |
+
self.benchmark_inference_speed([1, 8, 16, 32])
|
378 |
+
self.benchmark_memory_scaling([1, 8, 16, 32])
|
379 |
+
self.benchmark_cpu_vs_gpu()
|
380 |
+
|
381 |
+
logger.info(f"✅ Comprehensive benchmark completed for {self.model_name}")
|
382 |
+
|
383 |
+
except Exception:
|
384 |
+
logger.exception(f"❌ Benchmark failed for {self.model_name}")
|
385 |
+
self.results["error"] = traceback.format_exc()
|
386 |
+
|
387 |
+
return self.results
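
Note: a usage sketch for the class as a whole; the model directory below is hypothetical and stands in for anything under code_model2vec/final:

    bench = PerformanceBenchmark("./code_model2vec/final/code_model2vec_example")
    report = bench.run_comprehensive_benchmark()
    print(report.get("size_metrics", {}))
    print(report.get("speed_benchmarks", {}))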
|
388 |
|
389 |
|
390 |
class CodeSearchNetEvaluator:
|
391 |
"""Evaluator for CodeSearchNet-style code search tasks."""
|
392 |
|
393 |
+
def __init__(self, model_path: str, model_name: str | None = None) -> None:
|
394 |
+
"""Initialize the evaluator with a model."""
|
|
|
|
|
|
|
|
|
|
|
|
|
395 |
self.model_path = model_path
|
396 |
self.model_name = model_name or Path(model_path).name
|
397 |
self.model: SentenceTransformer | None = None
|
|
|
|
|
398 |
self._load_model()
|
399 |
|
400 |
def _load_model(self) -> None:
|
401 |
+
"""Load the embedding model."""
|
402 |
logger.info(f"Loading model from {self.model_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
403 |
try:
|
404 |
self.model = SentenceTransformer(self.model_path, trust_remote_code=True)
|
405 |
logger.info(f"Successfully loaded model: {self.model_name}")
|
|
|
414 |
raise RuntimeError(msg)
|
415 |
|
416 |
embeddings = []
|
|
|
417 |
for i in tqdm(range(0, len(texts), BATCH_SIZE), desc=desc):
|
418 |
batch = texts[i : i + BATCH_SIZE]
|
419 |
batch_embeddings = self.model.encode(batch, convert_to_tensor=False, normalize_embeddings=True)
|
|
|
422 |
return np.vstack(embeddings)
|
423 |
|
424 |
def evaluate_language(self, language: str, max_queries: int = 1000) -> dict[str, Any]:
|
425 |
+
"""Evaluate on a specific programming language."""
|
426 |
logger.info(f"Evaluating on {language} language (max {max_queries} queries)")
|
427 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
428 |
try:
|
429 |
# Load test split for the language
|
430 |
dataset = load_dataset(
|
431 |
+
codesearchnet_config.dataset_name,
|
432 |
language,
|
433 |
split="test",
|
434 |
trust_remote_code=True,
|
435 |
)
|
436 |
|
|
|
437 |
if not isinstance(dataset, Dataset):
|
438 |
logger.error(f"Unexpected dataset type for {language}: {type(dataset)}")
|
439 |
return {}
|
440 |
|
441 |
+
# Sample queries for evaluation
|
442 |
if len(dataset) > max_queries:
|
443 |
+
rng = np.random.default_rng(42)
|
444 |
indices = rng.choice(len(dataset), max_queries, replace=False)
|
445 |
dataset = dataset.select(indices)
|
446 |
|
|
|
464 |
logger.info(f"Found {len(queries)} valid query-code pairs for {language}")
|
465 |
|
466 |
# Encode queries and codes
|
467 |
+
start_time = time.time()
|
468 |
query_embeddings = self.encode_texts(queries, f"Encoding {language} queries")
|
469 |
+
code_embeddings = self.encode_texts(codes, f"Encoding {language} code")
|
470 |
+
encoding_time = time.time() - start_time
|
471 |
|
472 |
+
# Compute similarities and metrics
|
473 |
similarities = cosine_similarity(query_embeddings, code_embeddings)
|
|
|
|
|
474 |
metrics = self._compute_retrieval_metrics(similarities)
|
475 |
|
476 |
+
# Prepare results
|
477 |
+
results = {
|
478 |
"language": language,
|
479 |
+
"model_name": self.model_name,
|
480 |
"num_queries": len(queries),
|
481 |
+
"encoding_time_seconds": encoding_time,
|
482 |
"metrics": metrics,
|
|
|
483 |
}
|
484 |
|
485 |
+
logger.info(f"✅ {language} evaluation completed in {encoding_time:.2f}s")
|
486 |
+
return results
|
487 |
|
488 |
except Exception:
|
489 |
+
logger.exception(f"❌ Failed to evaluate {language}")
|
490 |
return {}
|
491 |
|
492 |
def _compute_retrieval_metrics(self, similarities: np.ndarray) -> dict[str, float]:
|
493 |
+
"""Compute retrieval metrics from similarity matrix."""
|
494 |
+
n_queries = similarities.shape[0]
|
495 |
|
496 |
+
# For each query, the correct code is at the same index
|
497 |
+
correct_indices = np.arange(n_queries)
|
498 |
+
|
499 |
+
# Rank all codes for each query
|
500 |
+
ranked_indices = np.argsort(similarities, axis=1)[:, ::-1]
|
501 |
+
|
502 |
+
metrics = {}
|
503 |
+
|
504 |
+
# Compute metrics for different k values
|
505 |
+
for k in [1, 5, 10]:
|
506 |
+
if k <= similarities.shape[1]:
|
507 |
+
# Recall@k
|
508 |
+
recall_k = np.mean([correct_indices[i] in ranked_indices[i, :k] for i in range(n_queries)])
|
509 |
+
metrics[f"recall@{k}"] = recall_k
|
510 |
+
|
511 |
+
# NDCG@k
|
512 |
+
ndcg_k = np.mean(
|
513 |
+
[self._compute_ndcg(ranked_indices[i], correct_indices[i], k) for i in range(n_queries)]
|
514 |
+
)
|
515 |
+
metrics[f"ndcg@{k}"] = ndcg_k
|
516 |
+
|
517 |
+
# Mean Reciprocal Rank
|
518 |
reciprocal_ranks = []
|
519 |
+
for i in range(n_queries):
|
520 |
+
rank = np.where(ranked_indices[i] == correct_indices[i])[0]
|
521 |
+
if len(rank) > 0:
|
522 |
+
reciprocal_ranks.append(1.0 / (rank[0] + 1))
|
523 |
+
else:
|
524 |
+
reciprocal_ranks.append(0.0)
|
525 |
+
|
526 |
+
metrics["mrr"] = np.mean(reciprocal_ranks)
|
527 |
+
|
528 |
+
# Add mean rank and median rank
|
529 |
+
mean_ranks = []
|
530 |
+
for i in range(n_queries):
|
531 |
+
rank = np.where(ranked_indices[i] == correct_indices[i])[0]
|
532 |
+
if len(rank) > 0:
|
533 |
+
mean_ranks.append(rank[0] + 1) # 1-indexed
|
534 |
+
else:
|
535 |
+
mean_ranks.append(similarities.shape[1]) # Worst possible rank
|
536 |
+
|
537 |
+
metrics["mean_rank"] = np.mean(mean_ranks)
|
538 |
+
metrics["median_rank"] = np.median(mean_ranks)
|
539 |
+
|
540 |
+
# Ensure all values are float
|
541 |
+
return {k: float(v) for k, v in metrics.items()}
|
542 |
|
543 |
def _compute_ndcg(self, ranked_indices: np.ndarray, correct_idx: int, k: int) -> float:
|
544 |
"""Compute NDCG@k for a single query."""
|
545 |
+
if correct_idx in ranked_indices[:k]:
|
546 |
+
rank = np.where(ranked_indices[:k] == correct_idx)[0][0]
|
547 |
+
return 1.0 / np.log2(rank + 2)
|
|
|
|
|
|
|
|
|
|
|
548 |
return 0.0
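
Note: because query i is paired with code i, both helpers have easily checked values on toy input. A worked example, assuming an already-constructed CodeSearchNetEvaluator here called evaluator (the method is private, so this is illustration rather than public API):

    import numpy as np

    # 3 queries x 3 candidates; entry (i, i) is the true pair
    sims = np.array([
        [0.9, 0.1, 0.3],  # correct item ranked 1st -> reciprocal rank 1
        [0.4, 0.2, 0.8],  # correct item ranked 3rd -> reciprocal rank 1/3
        [0.2, 0.7, 0.6],  # correct item ranked 2nd -> reciprocal rank 1/2
    ])
    metrics = evaluator._compute_retrieval_metrics(sims)
    assert abs(metrics["mrr"] - (1 + 1 / 3 + 1 / 2) / 3) < 1e-9  # ~0.6111
    assert abs(metrics["recall@1"] - 1 / 3) < 1e-9
    # Per-query NDCG@3 is 1/log2(rank + 1) for a 1-indexed rank: 1.0, 0.5, 1/log2(3) ~ 0.631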
|
549 |
|
550 |
def evaluate_all_languages(
|
551 |
self, max_queries_per_lang: int = 1000, languages: list[str] | None = None
|
552 |
) -> dict[str, Any]:
|
553 |
+
"""Evaluate on all specified languages."""
|
554 |
+
eval_languages = languages or languages_config.all
|
|
|
|
|
|
|
555 |
|
556 |
+
logger.info(f"🚀 Starting evaluation on {len(eval_languages)} languages")
|
557 |
+
logger.info(f"📊 Model: {self.model_name}")
|
558 |
+
logger.info(f"🔢 Max queries per language: {max_queries_per_lang}")
|
|
|
|
|
|
|
559 |
|
560 |
start_time = time.time()
|
561 |
+
results = {
|
|
|
562 |
"model_name": self.model_name,
|
563 |
"model_path": self.model_path,
|
|
|
564 |
"languages": {},
|
565 |
"overall": {},
|
566 |
+
"evaluation_time_seconds": 0,
|
567 |
}
|
568 |
+
languages_dict: dict[str, Any] = {}
|
569 |
|
570 |
+
# Evaluate each language
|
571 |
+
for language in eval_languages:
|
572 |
+
logger.info(f"\n{'=' * 50}")
|
573 |
+
logger.info(f"🔍 Evaluating {language}")
|
574 |
+
logger.info(f"{'=' * 50}")
|
575 |
|
|
|
|
|
576 |
lang_results = self.evaluate_language(language, max_queries_per_lang)
|
|
|
577 |
if lang_results:
|
578 |
+
languages_dict[language] = lang_results
|
|
|
|
|
|
|
579 |
|
580 |
+
results["languages"] = languages_dict
|
581 |
+
|
582 |
+
# Compute overall metrics
|
583 |
+
if languages_dict:
|
584 |
overall_metrics = {}
|
585 |
+
metric_names = list(next(iter(languages_dict.values()))["metrics"].keys())
|
586 |
+
|
587 |
+
for metric in metric_names:
|
588 |
+
values = [languages_dict[lang]["metrics"][metric] for lang in languages_dict]
|
589 |
+
overall_metrics[metric] = np.mean(values)
|
590 |
|
591 |
results["overall"] = overall_metrics
|
592 |
|
593 |
total_time = time.time() - start_time
|
594 |
results["evaluation_time_seconds"] = total_time
|
595 |
|
|
|
|
|
|
|
|
|
|
|
596 |
logger.info(f"Evaluation completed in {total_time:.2f} seconds")
|
597 |
return results
|
598 |
|
599 |
|
600 |
+
class ComprehensiveModelEvaluator:
|
601 |
+
"""Combined evaluator for both task performance and operational benchmarks."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
602 |
|
603 |
+
def __init__(self, model_path: str, model_name: str | None = None) -> None:
|
604 |
+
"""Initialize the comprehensive evaluator with a model."""
|
605 |
+
self.model_path = model_path
|
606 |
+
self.model_name = model_name or Path(model_path).name
|
607 |
+
|
608 |
+
# Initialize sub-evaluators
|
609 |
+
self.codesearch_evaluator = CodeSearchNetEvaluator(model_path, model_name)
|
610 |
+
self.performance_benchmarker = PerformanceBenchmark(model_path, model_name)
|
611 |
+
|
612 |
+
self.results: dict[str, Any] = {}
|
613 |
+
|
614 |
+
def run_comprehensive_evaluation(
|
615 |
+
self,
|
616 |
+
max_queries_per_lang: int = 1000,
|
617 |
+
languages: list[str] | None = None,
|
618 |
+
skip_benchmark: bool = False,
|
619 |
+
) -> dict[str, Any]:
|
620 |
+
"""Run both CodeSearchNet evaluation and performance benchmarking."""
|
621 |
+
logger.info(f"🚀 Starting comprehensive evaluation for {self.model_name}")
|
622 |
+
start_time = time.time()
|
623 |
+
|
624 |
+
# Initialize results structure
|
625 |
+
self.results = {
|
626 |
+
"model_name": self.model_name,
|
627 |
+
"model_path": self.model_path,
|
628 |
+
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
629 |
+
"evaluation_time_seconds": 0,
|
630 |
+
}
|
631 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
632 |
try:
|
633 |
+
# 1. Run CodeSearchNet evaluation
|
634 |
+
logger.info("🔍 Running CodeSearchNet task evaluation...")
|
635 |
+
codesearch_results = self.codesearch_evaluator.evaluate_all_languages(max_queries_per_lang, languages)
|
636 |
+
|
637 |
+
# Extract CodeSearchNet metrics
|
638 |
+
self.results.update(
|
639 |
+
{
|
640 |
+
"codesearch_languages": codesearch_results.get("languages", {}),
|
641 |
+
"codesearch_overall": codesearch_results.get("overall", {}),
|
642 |
+
}
|
643 |
+
)
|
644 |
+
|
645 |
+
# 2. Run performance benchmarking (unless skipped)
|
646 |
+
if not skip_benchmark:
|
647 |
+
logger.info("⚡ Running operational performance benchmarking...")
|
648 |
+
benchmark_results = self.performance_benchmarker.run_comprehensive_benchmark()
|
649 |
+
|
650 |
+
# Extract benchmark metrics
|
651 |
+
self.results.update(
|
652 |
+
{
|
653 |
+
"size_metrics": benchmark_results.get("size_metrics", {}),
|
654 |
+
"speed_benchmarks": benchmark_results.get("speed_benchmarks", {}),
|
655 |
+
"memory_benchmarks": benchmark_results.get("memory_benchmarks", {}),
|
656 |
+
"cpu_vs_gpu": benchmark_results.get("cpu_vs_gpu", {}),
|
657 |
+
}
|
658 |
+
)
|
659 |
+
else:
|
660 |
+
logger.info("⏭️ Skipping performance benchmarking")
|
661 |
+
self.results["benchmark_skipped"] = True
|
662 |
+
|
663 |
except Exception as e:
|
664 |
+
logger.exception(f"❌ Comprehensive evaluation failed for {self.model_name}")
|
665 |
+
self.results["error"] = str(e)
|
666 |
|
667 |
+
# Calculate total time
|
668 |
+
total_time = time.time() - start_time
|
669 |
+
self.results["evaluation_time_seconds"] = total_time
|
|
|
|
|
|
|
|
|
670 |
|
671 |
+
logger.info(f"✅ Comprehensive evaluation completed in {total_time:.2f} seconds")
|
672 |
+
return self.results
|
|
|
673 |
|
674 |
+
def print_summary(self) -> None:
|
675 |
+
"""Print a comprehensive summary of all results."""
|
676 |
+
logger.info(f"\n{'=' * 60}")
|
677 |
+
logger.info(f"📊 COMPREHENSIVE EVALUATION RESULTS: {self.model_name}")
|
678 |
+
logger.info(f"{'=' * 60}")
|
679 |
|
680 |
+
# CodeSearchNet results
|
681 |
+
overall = self.results.get("codesearch_overall", {})
|
682 |
+
if overall:
|
683 |
+
logger.info("🔍 CodeSearchNet Performance:")
|
684 |
+
for metric, value in overall.items():
|
685 |
+
logger.info(f" 🎯 {metric.upper()}: {value:.4f}")
|
686 |
+
|
687 |
+
# Benchmark results
|
688 |
+
if not self.results.get("benchmark_skipped", False):
|
689 |
+
size_metrics = self.results.get("size_metrics", {})
|
690 |
+
if size_metrics:
|
691 |
+
logger.info(f"\n📏 Model Size: {size_metrics.get('disk_size_mb', 0):.1f}MB")
|
692 |
+
if "parameters_millions" in size_metrics:
|
693 |
+
logger.info(f"🔢 Parameters: {size_metrics['parameters_millions']:.1f}M")
|
694 |
+
|
695 |
+
speed_benchmarks = self.results.get("speed_benchmarks", {})
|
696 |
+
if "medium" in speed_benchmarks and "batch_32" in speed_benchmarks["medium"]:
|
697 |
+
batch_32 = speed_benchmarks["medium"]["batch_32"]
|
698 |
+
logger.info(f"⚡ Throughput (batch 32): {batch_32.get('texts_per_second', 0):.1f} texts/sec")
|
699 |
+
|
700 |
+
cpu_vs_gpu = self.results.get("cpu_vs_gpu", {})
|
701 |
+
if "gpu_speedup" in cpu_vs_gpu:
|
702 |
+
speedup = cpu_vs_gpu["gpu_speedup"]
|
703 |
+
logger.info(f"🚀 GPU speedup: {speedup:.1f}x")
|
704 |
+
|
705 |
+
# Language breakdown
|
706 |
+
languages = self.results.get("codesearch_languages", {})
|
707 |
+
if languages:
|
708 |
+
logger.info("\n📋 Language Breakdown:")
|
709 |
+
for lang, lang_results in languages.items():
|
710 |
+
metrics = lang_results.get("metrics", {})
|
711 |
+
ndcg10 = metrics.get("ndcg@10", 0)
|
712 |
+
mrr = metrics.get("mrr", 0)
|
713 |
+
logger.info(f" {lang}: NDCG@10={ndcg10:.4f}, MRR={mrr:.4f}")
|
714 |
|
|
|
715 |
|
716 |
+
# =============================================================================
|
717 |
+
# UTILITY FUNCTIONS
|
718 |
+
# =============================================================================
|
719 |
|
|
|
|
|
|
|
|
|
720 |
|
721 |
+
def check_existing_results(model_name: str, local_dir: str = LOCAL_EVALUATION_DIR) -> dict[str, Any] | None:
|
722 |
+
"""Check if comprehensive evaluation results already exist for a model."""
|
723 |
+
local_path = Path(local_dir)
|
724 |
+
safe_model_name = get_safe_model_name(model_name)
|
725 |
+
|
726 |
+
# Check for new comprehensive format first
|
727 |
+
comprehensive_file = local_path / f"comprehensive_eval_{safe_model_name}.json"
|
728 |
+
if comprehensive_file.exists():
|
729 |
+
try:
|
730 |
+
with comprehensive_file.open("r") as f:
|
731 |
+
results = json.load(f)
|
732 |
+
logger.info(f"✅ Found existing comprehensive results for {model_name}")
|
733 |
+
return results
|
734 |
+
except Exception as e:
|
735 |
+
logger.warning(f"⚠️ Could not load existing comprehensive results for {model_name}: {e}")
|
736 |
+
|
737 |
+
# Fallback to legacy codesearchnet format for backward compatibility
|
738 |
+
legacy_file = local_path / f"codesearchnet_eval_{safe_model_name}.json"
|
739 |
+
if legacy_file.exists():
|
740 |
+
try:
|
741 |
+
with legacy_file.open("r") as f:
|
742 |
+
results = json.load(f)
|
743 |
+
logger.info(f"✅ Found existing legacy results for {model_name}")
|
744 |
+
return results
|
745 |
+
except Exception as e:
|
746 |
+
logger.warning(f"⚠️ Could not load existing legacy results for {model_name}: {e}")
|
747 |
+
|
748 |
+
return None
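
Note: get_safe_model_name is imported from .config and not shown in this diff; the filename scheme only requires it to be filesystem-safe. A hypothetical stand-in, in the spirit of the sanitization the deleted code did inline (the real implementation may differ):

    def get_safe_model_name(model_name: str) -> str:
        # keep alphanumerics plus "-", "_", "."; replace everything else (e.g. "/") with "_"
        return "".join(c if c.isalnum() or c in "-_." else "_" for c in model_name)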
|
749 |
+
|
750 |
+
|
751 |
+
def save_evaluation_results(results: dict[str, Any], local_dir: str = LOCAL_EVALUATION_DIR) -> bool:
|
752 |
+
"""Save comprehensive evaluation results to local directory as a single JSON file."""
|
753 |
+
try:
|
754 |
+
local_path = Path(local_dir)
|
755 |
+
local_path.mkdir(parents=True, exist_ok=True)
|
756 |
+
|
757 |
+
model_name = results.get("model_name", "unknown")
|
758 |
+
safe_model_name = get_safe_model_name(model_name)
|
759 |
|
760 |
+
# Save single comprehensive results file (CodeSearchNet + Benchmark combined)
|
761 |
+
result_file = local_path / f"comprehensive_eval_{safe_model_name}.json"
|
762 |
+
with result_file.open("w") as f:
|
763 |
+
json.dump(results, f, indent=2, default=str)
|
764 |
+
|
765 |
+
logger.info(f"💾 Saved comprehensive evaluation results for {model_name}")
|
766 |
+
return True
|
767 |
+
|
768 |
+
except Exception:
|
769 |
+
logger.exception("❌ Error saving evaluation results")
|
770 |
+
return False
|
771 |
+
|
772 |
+
|
773 |
+
def discover_local_models(models_dir: str = LOCAL_MODELS_DIR) -> list[str]:
|
774 |
+
"""Discover models in the local models directory."""
|
775 |
+
models_path = Path(models_dir)
|
776 |
+
discovered_models = []
|
777 |
+
|
778 |
+
if models_path.exists():
|
779 |
+
for model_dir in models_path.iterdir():
|
780 |
+
if model_dir.is_dir() and (
|
781 |
+
any(model_dir.glob("*.json")) or any(model_dir.glob("*.bin")) or any(model_dir.glob("*.safetensors"))
|
782 |
+
):
|
783 |
+
discovered_models.append(str(model_dir))
|
784 |
+
logger.info(f"📁 Found local model: {model_dir.name}")
|
785 |
+
|
786 |
+
return discovered_models
|
787 |
+
|
788 |
+
|
789 |
+
def print_results_summary(results: dict[str, Any]) -> None:
|
790 |
+
"""Print a formatted summary of comprehensive evaluation results."""
|
791 |
+
logger.info(f"\n{'=' * 60}")
|
792 |
+
logger.info(f"📊 COMPREHENSIVE EVALUATION: {results.get('model_name', 'Unknown')}")
|
793 |
+
logger.info(f"{'=' * 60}")
|
794 |
+
|
795 |
+
# CodeSearchNet results
|
796 |
+
overall = results.get("codesearch_overall", {})
|
797 |
if overall:
|
798 |
+
logger.info("🔍 CodeSearchNet Performance:")
|
799 |
+
for metric, value in overall.items():
|
800 |
+
logger.info(f" 🎯 {metric.upper()}: {value:.4f}")
|
801 |
+
|
802 |
+
# Benchmark results
|
803 |
+
if not results.get("benchmark_skipped", False):
|
804 |
+
size_metrics = results.get("size_metrics", {})
|
805 |
+
if size_metrics:
|
806 |
+
logger.info(f"\n📏 Model Size: {size_metrics.get('disk_size_mb', 0):.1f}MB")
|
807 |
+
if "parameters_millions" in size_metrics:
|
808 |
+
logger.info(f"🔢 Parameters: {size_metrics['parameters_millions']:.1f}M")
|
809 |
+
|
810 |
+
speed_benchmarks = results.get("speed_benchmarks", {})
|
811 |
+
if "medium" in speed_benchmarks and "batch_32" in speed_benchmarks["medium"]:
|
812 |
+
batch_32 = speed_benchmarks["medium"]["batch_32"]
|
813 |
+
logger.info(f"⚡ Throughput (batch 32): {batch_32.get('texts_per_second', 0):.1f} texts/sec")
|
814 |
+
|
815 |
+
# Language breakdown
|
816 |
+
languages = results.get("codesearch_languages", {})
|
817 |
+
if languages:
|
818 |
+
logger.info("\n📋 Language Breakdown:")
|
819 |
+
for lang, lang_results in languages.items():
|
820 |
+
metrics = lang_results.get("metrics", {})
|
821 |
+
ndcg10 = metrics.get("ndcg@10", 0)
|
822 |
+
mrr = metrics.get("mrr", 0)
|
823 |
+
logger.info(f" {lang}: NDCG@10={ndcg10:.4f}, MRR={mrr:.4f}")
|
824 |
+
|
825 |
+
|
826 |
+
def create_comparison_report(all_results: list[dict[str, Any]], output_dir: str = LOCAL_EVALUATION_DIR) -> None:
|
827 |
+
"""Create a comprehensive comparison report with both CodeSearchNet and benchmark data."""
|
828 |
if not all_results:
|
829 |
return
|
830 |
|
831 |
+
logger.info("📊 Creating comprehensive comparison report...")
|
832 |
+
|
833 |
+
# Create evaluation comparison dataframe
|
834 |
+
evaluation_data = []
|
835 |
+
benchmark_data = []
|
836 |
+
|
837 |
+
for result in all_results:
|
838 |
+
model_name = result.get("model_name", "Unknown")
|
839 |
+
|
840 |
+
# CodeSearchNet data
|
841 |
+
overall = result.get("codesearch_overall", {})
|
842 |
+
eval_row = {"model_name": model_name}
|
843 |
+
eval_row.update(overall)
|
844 |
+
evaluation_data.append(eval_row)
|
845 |
+
|
846 |
+
# Benchmark data (if available)
|
847 |
+
if not result.get("benchmark_skipped", False):
|
848 |
+
benchmark_row = {"model_name": model_name}
|
849 |
+
size_metrics = result.get("size_metrics", {})
|
850 |
+
speed_benchmarks = result.get("speed_benchmarks", {})
|
851 |
+
|
852 |
+
benchmark_row.update(size_metrics)
|
853 |
+
if "medium" in speed_benchmarks and "batch_32" in speed_benchmarks["medium"]:
|
854 |
+
batch_32 = speed_benchmarks["medium"]["batch_32"]
|
855 |
+
benchmark_row["best_throughput"] = batch_32.get("texts_per_second", 0)
|
856 |
+
benchmark_data.append(benchmark_row)
|
857 |
+
|
858 |
+
# Save comparison results
|
859 |
output_path = Path(output_dir)
|
860 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
861 |
|
862 |
+
# Combined evaluation comparison CSV (includes both CodeSearchNet and key benchmark metrics)
|
863 |
+
if evaluation_data and benchmark_data:
|
864 |
+
# Merge evaluation and benchmark data
|
865 |
+
combined_data = []
|
866 |
+
benchmark_dict = {row["model_name"]: row for row in benchmark_data}
|
867 |
+
|
868 |
+
for eval_row in evaluation_data:
|
869 |
+
model_name = eval_row["model_name"]
|
870 |
+
combined_row = eval_row.copy()
|
871 |
+
|
872 |
+
# Add benchmark metrics if available
|
873 |
+
if model_name in benchmark_dict:
|
874 |
+
benchmark_row = benchmark_dict[model_name]
|
875 |
+
combined_row.update(
|
876 |
+
{
|
877 |
+
"disk_size_mb": benchmark_row.get("disk_size_mb", 0),
|
878 |
+
"parameters_millions": benchmark_row.get("parameters_millions", 0),
|
879 |
+
"best_throughput": benchmark_row.get("best_throughput", 0),
|
880 |
+
}
|
881 |
+
)
|
882 |
|
883 |
+
combined_data.append(combined_row)
|
|
|
884 |
|
885 |
+
combined_df = pd.DataFrame(combined_data)
|
886 |
+
combined_csv = output_path / "comprehensive_comparison.csv"
|
887 |
+
combined_df.to_csv(combined_csv, index=False)
|
888 |
+
logger.info(f"📄 Comprehensive comparison CSV saved: {combined_csv}")
|
889 |
|
890 |
+
# Detailed JSON export
|
891 |
+
json_path = output_path / "comprehensive_evaluation.json"
|
892 |
+
with json_path.open("w") as f:
|
893 |
+
json.dump(all_results, f, indent=2, default=str)
|
894 |
+
logger.info(f"📄 Comprehensive results JSON saved: {json_path}")
|
895 |
|
896 |
|
897 |
+
# =============================================================================
|
898 |
+
# MAIN EVALUATION FUNCTIONS
|
899 |
+
# =============================================================================
|
900 |
+
|
901 |
+
|
902 |
+
def run_evaluation(
|
903 |
models: list[str],
|
904 |
max_queries: int = 1000,
|
905 |
languages: list[str] | None = None,
|
906 |
+
use_beam: bool = False,
|
907 |
+
skip_benchmark: bool = False,
|
|
|
908 |
) -> list[dict[str, Any]]:
|
909 |
+
"""Main evaluation function that handles both local and Beam execution."""
|
910 |
+
logger.info(f"🚀 Starting comprehensive evaluation ({'Beam' if use_beam else 'Local'})")
|
911 |
+
logger.info(f"📊 Evaluating {len(models)} models on {len(languages or languages_config.all)} languages")
|
912 |
+
logger.info(f"⚡ Benchmarking: {'Disabled' if skip_benchmark else 'Enabled'}")
|
913 |
+
|
914 |
+
# Check for existing results and skip already evaluated models
|
915 |
+
models_to_evaluate = []
|
916 |
+
skipped_models = []
|
917 |
+
all_results = []
|
918 |
|
919 |
+
for model_path in models:
|
920 |
+
model_name = Path(model_path).name
|
921 |
+
existing_results = check_existing_results(model_name)
|
922 |
|
923 |
+
if existing_results:
|
924 |
+
logger.info(f"✅ Model {model_name} already evaluated, skipping")
|
925 |
+
all_results.append(existing_results)
|
926 |
+
skipped_models.append(model_name)
|
927 |
+
else:
|
928 |
+
models_to_evaluate.append(model_path)
|
929 |
|
930 |
+
if not models_to_evaluate:
|
931 |
+
logger.info("🎉 All models already evaluated!")
|
932 |
+
return all_results
|
933 |
|
934 |
+
logger.info(f"📊 Need to evaluate {len(models_to_evaluate)} models")
|
935 |
+
|
936 |
+
if use_beam:
|
937 |
+
# Run on Beam
|
938 |
+
new_results = _run_beam_evaluation(models_to_evaluate, max_queries, languages, skip_benchmark)
|
939 |
+
else:
|
940 |
+
# Run locally
|
941 |
+
new_results = _run_local_evaluation(models_to_evaluate, max_queries, languages, skip_benchmark)
|
942 |
+
|
943 |
+
all_results.extend(new_results)
|
944 |
|
945 |
+
# Create comparison report
|
946 |
+
if len(all_results) > 1:
|
947 |
+
create_comparison_report(all_results)
|
948 |
+
|
949 |
+
# Print summary
|
950 |
+
newly_evaluated = len(new_results)
|
951 |
+
logger.info(f"\n{'=' * 60}")
|
952 |
+
logger.info("📊 EVALUATION SUMMARY")
|
953 |
+
logger.info(f"{'=' * 60}")
|
954 |
+
logger.info(f"📊 Total models: {len(models)}")
|
955 |
+
logger.info(f"✅ Newly evaluated: {newly_evaluated}")
|
956 |
+
logger.info(f"⏭️ Skipped (already done): {len(skipped_models)}")
|
957 |
+
logger.info(f"🎯 Total results: {len(all_results)}")
|
958 |
+
logger.info(f"⚡ Benchmarking: {'Disabled' if skip_benchmark else 'Enabled'}")
|
959 |
+
|
960 |
+
return all_results
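
Note: run_evaluation is the programmatic entry point (also re-exported from the package __init__); a minimal call, with an illustrative model path:

    from distiller.evaluate import run_evaluation

    results = run_evaluation(
        models=["./code_model2vec/final/code_model2vec_example"],
        max_queries=100,
        skip_benchmark=True,  # task metrics only; skip speed/memory benchmarking
    )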
|
961 |
+
|
962 |
+
|
963 |
+
def _run_local_evaluation(
|
964 |
+
models: list[str],
|
965 |
+
max_queries: int = 1000,
|
966 |
+
languages: list[str] | None = None,
|
967 |
+
skip_benchmark: bool = False,
|
968 |
+
) -> list[dict[str, Any]]:
|
969 |
+
"""Run comprehensive evaluation locally."""
|
970 |
+
logger.info("🖥️ Running local comprehensive evaluation")
|
971 |
+
|
972 |
+
results = []
|
973 |
for model_path in models:
|
974 |
model_name = Path(model_path).name
|
975 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
976 |
logger.info(f"\n{'=' * 60}")
|
977 |
logger.info(f"🔍 Evaluating model: {model_name}")
|
|
|
978 |
logger.info(f"{'=' * 60}")
|
979 |
|
980 |
try:
|
981 |
+
evaluator = ComprehensiveModelEvaluator(model_path, model_name)
|
982 |
+
result = evaluator.run_comprehensive_evaluation(max_queries, languages, skip_benchmark)
|
|
|
|
|
983 |
|
984 |
+
# Save results locally
|
985 |
+
save_evaluation_results(result)
|
986 |
+
print_results_summary(result)
|
987 |
+
results.append(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
988 |
|
989 |
except Exception:
|
990 |
logger.exception(f"❌ Failed to evaluate {model_name}")
|
991 |
continue
|
992 |
|
993 |
+
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
994 |
|
995 |
|
996 |
@function(
|
997 |
gpu=GPU_NAME,
|
998 |
+
volumes=[Volume(name=VOLUME_CONFIG.name, mount_path=VOLUME_CONFIG.mount_path)],
|
999 |
image=IMAGE,
|
1000 |
secrets=["HF_ACCESS_TOKEN"],
|
1001 |
+
env=BEAM_ENV_SETTINGS,
|
1002 |
+
timeout=3600 * 8, # 8 hours for comprehensive evaluation
|
|
|
|
|
|
|
1003 |
)
|
1004 |
+
def _beam_evaluate_single_model(
|
1005 |
+
model_path: str,
|
1006 |
+
max_queries: int = 1000,
|
1007 |
+
languages: list[str] | None = None,
|
1008 |
+
skip_benchmark: bool = False,
|
1009 |
+
) -> dict[str, Any]:
|
1010 |
+
"""Beam function to comprehensively evaluate a single model."""
|
1011 |
+
model_name = Path(model_path).name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1012 |
|
1013 |
+
logger.info(f"🚀 Beam comprehensive evaluation starting for {model_name}")
|
|
|
|
|
1014 |
|
1015 |
+
try:
|
1016 |
+
evaluator = ComprehensiveModelEvaluator(model_path, model_name)
|
1017 |
+
results = evaluator.run_comprehensive_evaluation(max_queries, languages, skip_benchmark)
|
1018 |
|
1019 |
+
# Save to Beam volume as single comprehensive file
|
1020 |
+
volume_results_dir = Path(VOLUME_CONFIG.mount_path) / "evaluation_results"
|
1021 |
+
volume_results_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
1022 |
|
1023 |
+
safe_model_name = get_safe_model_name(model_name)
|
1024 |
+
result_file = volume_results_dir / f"comprehensive_eval_{safe_model_name}.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1025 |
|
1026 |
+
with result_file.open("w") as f:
|
1027 |
+
json.dump(results, f, indent=2, default=str)
|
1028 |
|
1029 |
+
logger.info(f"💾 Saved Beam comprehensive evaluation results for {model_name}")
|
1030 |
+
return results
|
1031 |
|
1032 |
+
except Exception:
|
1033 |
+
logger.exception(f"❌ Beam comprehensive evaluation failed for {model_name}")
|
1034 |
+
return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1035 |
|
1036 |
|
1037 |
+
+def _run_beam_evaluation(
+    models: list[str],
     max_queries: int = 1000,
     languages: list[str] | None = None,
+    skip_benchmark: bool = False,
 ) -> list[dict[str, Any]]:
+    """Run comprehensive evaluation on Beam and download results."""
+    logger.info("☁️ Running Beam comprehensive evaluation")

+    results = []
     for model_path in models:
         model_name = Path(model_path).name

+        logger.info(f"🚀 Starting Beam comprehensive evaluation for {model_name}")

         try:
+            # Run evaluation on Beam
+            result = _beam_evaluate_single_model.remote(model_path, max_queries, languages, skip_benchmark)

+            if result:
+                # Download the comprehensive result file from Beam
+                success = download_specific_evaluation_file(
+                    VOLUME_CONFIG.name,
+                    model_name,
+                    "evaluation_results",
+                    LOCAL_EVALUATION_DIR,
+                    file_prefix="comprehensive_eval",
+                )

+                if success:
+                    logger.info(f"📥 Downloaded comprehensive results for {model_name}")
+                    print_results_summary(result)
+                    results.append(result)
+                else:
+                    logger.warning(f"⚠️ Could not download results for {model_name}")

         except Exception:
+            logger.exception(f"❌ Beam comprehensive evaluation failed for {model_name}")
             continue

+    return results
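This is Beam's remote-function pattern: the @function-decorated _beam_evaluate_single_model executes on a GPU worker with the volume mounted, while .remote(...) dispatches the call from the driver and returns its result, after which the saved JSON is pulled down from the volume. A minimal sketch of the same dispatch pattern, assuming the beam SDK's function and Volume API as used above (the GPU string and volume name are illustrative):

    from beam import Volume, function

    @function(
        gpu="A10G",  # placeholder; the code above uses GPU_NAME from config
        volumes=[Volume(name="demo_volume", mount_path="./demo_volume")],
    )
    def remote_double(x: int) -> int:
        # Runs on the Beam worker; anything written under ./demo_volume persists.
        return x * 2

    value = remote_double.remote(21)  # executes remotely instead of locally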

+# =============================================================================
+# CLI INTERFACE
+# =============================================================================

+def main(
+    use_beam: bool = typer.Option(default=False, help="Use Beam for evaluation"),
+    skip_third_party: bool = typer.Option(default=False, help="Skip third-party models"),
+    skip_benchmark: bool = typer.Option(default=False, help="Skip performance benchmarking"),
+    max_queries: int = typer.Option(default=1000, help="Maximum queries per language"),
+) -> None:
+    """Main comprehensive evaluation function."""
+    logger.info("🚀 Starting comprehensive model evaluation (CodeSearchNet + Performance)")

+    # Build model list
+    models = []

+    # Add third-party models if not skipped
+    if not skip_third_party:
+        logger.info("📊 Including third-party peer models for comparison")
+        models.extend(DEFAULT_EVALUATION_MODELS)
+    else:
+        logger.info("⏭️ Skipping third-party models")

+    # Discover local models from code_model2vec/final
+    logger.info("🔍 Discovering local distillation models...")
+    local_models = discover_local_models()

+    if local_models:
+        logger.info(f"✅ Found {len(local_models)} local models:")
+        for model_path in local_models:
+            models.append(model_path)
+            logger.info(f"  📁 {Path(model_path).name}")
+    else:
+        logger.warning("⚠️ No local distillation models found")
+        if skip_third_party:
+            logger.error("❌ No models to evaluate!")
+            return
+
+    if not models:
+        logger.error("❌ No models to evaluate!")
+        return

+    logger.info(f"📊 Will evaluate {len(models)} models:")
+    for i, model in enumerate(models, 1):
+        logger.info(f"  {i}. {Path(model).name}")
+
+    # Run evaluation
+    results = run_evaluation(
+        models=models,
         max_queries=max_queries,
+        languages=languages_config.all,
+        use_beam=use_beam,
+        skip_benchmark=skip_benchmark,
     )

+    logger.info("🎉 Comprehensive evaluation workflow completed!")
+    logger.info(f"📊 Successfully evaluated {len(results)} models")
+    logger.info(f"💾 Results saved to: {LOCAL_EVALUATION_DIR}")
+    logger.info("📄 Format: Single comprehensive JSON per model (CodeSearchNet + Benchmarks)")


 if __name__ == "__main__":
+    typer.run(main)
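Since typer.run exposes each keyword of main() as a CLI flag, the module can be invoked directly; a sketch of typical invocations, assuming the module path in this repository layout:

    # Local comprehensive evaluation with a reduced query budget
    python -m distiller.evaluate --max-queries 500

    # Beam-backed run of local models only, skipping the performance benchmark
    python -m distiller.evaluate --use-beam --skip-third-party --skip-benchmark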
src/distiller/patch_utils.py
CHANGED
@@ -67,6 +67,15 @@ def apply_patch_file(patch_file: Path, target_dir: Path) -> bool:
     try:
         logger.info(f"Applying patch: {patch_file.name}")

+        # Check if patch is already applied
+        if is_patch_already_applied(patch_file, target_dir):
+            logger.info(f"Patch {patch_file.name} already applied")
+            return True
+
+        # Clean any duplicate validation code before applying
+        if "model2vec.patch" in patch_file.name:
+            clean_duplicate_validation_code(target_dir)
+
         # Use patch command with the following options:
         # -p1: strip 1 leading directory from paths
         # -d: change to directory before applying
@@ -121,6 +130,9 @@ def apply_all_patches() -> int:
     target_dir = get_site_packages_path()
     logger.info(f"Applying patches to: {target_dir}")

+    # Clean any existing duplicates first
+    clean_duplicate_validation_code(target_dir)
+
     success_count = 0

     # Sort patch files for consistent ordering
@@ -132,6 +144,113 @@ def apply_all_patches() -> int:
     return success_count


+def is_patch_already_applied(patch_file: Path, target_dir: Path) -> bool:
+    """
+    Check if a patch has already been applied by looking for specific markers.
+
+    Args:
+        patch_file: Path to the .patch file
+        target_dir: Target directory (usually site-packages)
+
+    Returns:
+        True if patch appears to be already applied, False otherwise
+    """
+    try:
+        # For model2vec.patch, check if the validation code is already present
+        if "model2vec.patch" in patch_file.name:
+            inference_file = target_dir / "model2vec" / "distill" / "inference.py"
+            if inference_file.exists():
+                inference_content = inference_file.read_text()
+                # Check for the specific validation code we're adding
+                if (
+                    "Token-vector mismatch:" in inference_content
+                    and "Truncating to prevent failure" in inference_content
+                ):
+                    # Also make sure it's in the right place (before return statement, not after)
+                    lines = inference_content.split("\n")
+                    for i, line in enumerate(lines):
+                        if "return out_tokens, out_weights" in line:
+                            # Check if validation code appears before this return
+                            preceding_lines = lines[max(0, i - 10) : i]
+                            if any("Token-vector mismatch:" in pline for pline in preceding_lines):
+                                return True
+                            break
+
+        return False
+
+    except Exception as e:
+        logger.warning(f"Error checking if patch {patch_file.name} is applied: {e}")
+        return False
+
+
+def clean_duplicate_validation_code(target_dir: Path) -> bool:
+    """
+    Clean up duplicate validation code that might have been added by multiple patch applications.
+
+    Args:
+        target_dir: Target directory (usually site-packages)
+
+    Returns:
+        True if cleanup was successful, False otherwise
+    """
+    try:
+        inference_file = target_dir / "model2vec" / "distill" / "inference.py"
+        if not inference_file.exists():
+            return True
+
+        content = inference_file.read_text()
+        lines = content.split("\n")
+
+        # Find all instances of the validation code
+        validation_indices = []
+        for i, line in enumerate(lines):
+            if "Token-vector mismatch:" in line:
+                validation_indices.append(i)
+
+        if len(validation_indices) <= 1:
+            return True  # No duplicates or no validation code
+
+        # Keep only the validation code that appears before a return statement
+        lines_to_keep = []
+        skip_until = -1
+
+        for i, line in enumerate(lines):
+            if i <= skip_until:
+                continue
+
+            # If this is validation code
+            if "Token-vector mismatch:" in line:
+                # Look ahead to see if there's a return statement nearby
+                has_return_after = False
+                for j in range(i, min(len(lines), i + 20)):
+                    if "return out_tokens, out_weights" in lines[j]:
+                        has_return_after = True
+                        break
+
+                # Keep this validation block only if it's followed by a return
+                if has_return_after:
+                    lines_to_keep.append(line)
+                else:
+                    # Skip this validation block (it's a duplicate)
+                    # Find the end of this validation block
+                    for j in range(i + 1, len(lines)):
+                        if lines[j].strip() == "" or not lines[j].startswith(" "):
+                            skip_until = j - 1
+                            break
+            else:
+                lines_to_keep.append(line)
+
+        # Write back the cleaned content
+        cleaned_content = "\n".join(lines_to_keep)
+        inference_file.write_text(cleaned_content)
+        logger.info("Cleaned duplicate validation code from inference.py")
+        return True
+
+    except Exception as e:
+        logger.warning(f"Error cleaning duplicate validation code: {e}")
+        return False
+
+
 def main() -> None:
     """Main function for standalone execution."""
     logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
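Together these additions make patch application idempotent: apply_patch_file first consults is_patch_already_applied, and clean_duplicate_validation_code repairs files touched by earlier double-applications. A minimal sketch of the intended call sequence, assuming these helpers are importable from the module and using a hypothetical patch location:

    from pathlib import Path

    from distiller.patch_utils import (
        apply_patch_file,
        get_site_packages_path,
        is_patch_already_applied,
    )

    target = get_site_packages_path()
    patch = Path("patches/model2vec.patch")  # placeholder path to the patch file

    # Safe to call repeatedly: an already-applied patch is detected and skipped.
    if not is_patch_already_applied(patch, target):
        apply_patch_file(patch, target)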
src/distiller/sync.py
DELETED
@@ -1,262 +0,0 @@
-"""
-Sync utility for downloading files from Beam volume to local directory.
-
-This module provides functionality to download generated files from the Beam volume
-back to the local filesystem, including:
-- Final distilled model files (model.safetensors, tokenizer.json, etc.)
-- Analysis reports and charts (README.md, comparison charts, etc.)
-"""
-
-import logging
-import shutil
-from pathlib import Path
-
-from .beam_utils import create_beam_utilities
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
-logger = logging.getLogger(__name__)
-
-# Beam volume configuration (must match distill.py)
-VOLUME_NAME = "gte_qwen2_m2v_code"
-VOLUME_PATH = "./gte_qwen2_m2v_code"
-
-# Model files to sync
-MODEL_FILES = [
-    "model.safetensors",
-    "tokenizer.json",
-    "modules.json",
-    "config.json",
-    "pytorch_model.bin",  # Backup format
-    "vocab.txt",  # If present
-]
-
-# Analysis directories and files
-ANALYSIS_DIRS = [
-    "analysis_results/reports",
-    "analysis_results/charts",
-    "evaluation_results",
-]
-
-ANALYSIS_FILES = [
-    "analysis_results/reports/analysis_report.md",
-    "analysis_results/reports/README.md",
-    "analysis_results/charts/*.png",
-    "analysis_results/charts/*.html",
-    "evaluation_results/*.json",
-    "evaluation_results/comparisons/*.csv",
-]
-
-
-def sync_model_files(output_dir: str) -> bool:
-    """Download final model files from Beam volume."""
-    logger.info("🔄 Syncing model files from Beam volume...")
-
-    output_path = Path(output_dir)
-    output_path.mkdir(parents=True, exist_ok=True)
-
-    # First, let's debug what's actually in the volume
-    volume_root = Path(VOLUME_PATH)
-    logger.info(f"🔍 Debugging volume contents at: {volume_root}")
-
-    if volume_root.exists():
-        logger.info("📁 Volume root directory contents:")
-        for item in volume_root.iterdir():
-            if item.is_file():
-                logger.info(f"  📄 {item.name} ({item.stat().st_size} bytes)")
-            elif item.is_dir():
-                logger.info(f"  📁 {item.name}/ (directory)")
-                # List files in important subdirectories
-                if item.name in ["models", "checkpoints", "gte_qwen2_m2v_code"]:
-                    try:
-                        logger.info(f"    Contents of {item.name}/:")
-                        for subitem in item.iterdir():
-                            if subitem.is_file():
-                                logger.info(f"      📄 {subitem.name} ({subitem.stat().st_size} bytes)")
-                            else:
-                                logger.info(f"      📁 {subitem.name}/")
-                                # Check one level deeper for model files
-                                if subitem.is_dir():
-                                    for subsubitem in subitem.iterdir():
-                                        if subsubitem.is_file() and subsubitem.name in MODEL_FILES:
-                                            logger.info(f"        🎯 FOUND MODEL FILE: {subsubitem}")
-                    except Exception as e:
-                        logger.warning(f"    Error exploring {item.name}: {e}")
-
-        # Also check for model files directly in root
-        logger.info("🔍 Checking for model files directly in volume root:")
-        for model_file in MODEL_FILES:
-            root_file = volume_root / model_file
-            if root_file.exists():
-                logger.info(f"  🎯 FOUND: {model_file} in root ({root_file.stat().st_size} bytes)")
-    else:
-        logger.error(f"❌ Volume root does not exist: {volume_root}")
-        return False
-
-    # Since training completed successfully, look for model files in all possible locations
-    model_locations = [
-        Path(VOLUME_PATH),  # Root of volume (where final model was saved)
-        Path(VOLUME_PATH) / "models" / "refined_model",  # Refined model directory
-    ]
-
-    synced_files = []
-
-    for location in model_locations:
-        logger.info(f"📂 Checking model location: {location}")
-
-        if not location.exists():
-            logger.info(f"  ⚠️ Location does not exist: {location}")
-            continue
-
-        # Try to download each model file directly
-        for model_file in MODEL_FILES:
-            source_path = location / model_file
-            dest_path = output_path / model_file
-
-            if source_path.exists():
-                try:
-                    shutil.copy2(source_path, dest_path)
-                    synced_files.append(model_file)
-                    logger.info(f"✅ Downloaded: {model_file}")
-                except Exception as e:
-                    logger.warning(f"⚠️ Failed to copy {model_file}: {e}")
-
-    if synced_files:
-        logger.info(f"🎉 Successfully synced {len(synced_files)} model files:")
-        for file in synced_files:
-            logger.info(f"  ✓ {file}")
-        return True
-    logger.error("❌ No model files found to sync")
-    return False
-
-
-def sync_analysis_files(output_dir: str) -> bool:
-    """Download analysis reports and charts from Beam volume."""
-    logger.info("🔄 Syncing analysis files from Beam volume...")
-
-    output_path = Path(output_dir)
-    output_path.mkdir(parents=True, exist_ok=True)
-
-    synced_files = []
-
-    # Sync analysis reports (including README.md)
-    analysis_reports_dir = Path(VOLUME_PATH) / "analysis_results" / "reports"
-    if analysis_reports_dir.exists():
-        for report_file in analysis_reports_dir.glob("*.md"):
-            dest_path = output_path / report_file.name
-            try:
-                shutil.copy2(report_file, dest_path)
-                synced_files.append(str(report_file.name))
-                logger.info(f"✅ Downloaded report: {report_file.name}")
-
-                # Special handling for README.md - copy to root
-                if report_file.name in {"analysis_report.md", "README.md"}:
-                    root_readme = Path(output_dir) / "README.md"
-                    shutil.copy2(report_file, root_readme)
-                    logger.info("✅ Updated root README.md")
-
-            except Exception as e:
-                logger.warning(f"⚠️ Failed to copy {report_file.name}: {e}")
-
-    # Sync charts
-    charts_dir = Path(VOLUME_PATH) / "analysis_results" / "charts"
-    local_charts_dir = output_path / "charts"
-    if charts_dir.exists():
-        local_charts_dir.mkdir(exist_ok=True)
-
-        for chart_file in charts_dir.glob("*"):
-            if chart_file.is_file():
-                dest_path = local_charts_dir / chart_file.name
-                try:
-                    shutil.copy2(chart_file, dest_path)
-                    synced_files.append(f"charts/{chart_file.name}")
-                    logger.info(f"✅ Downloaded chart: {chart_file.name}")
-                except Exception as e:
-                    logger.warning(f"⚠️ Failed to copy chart {chart_file.name}: {e}")
-
-    # Sync evaluation results
-    eval_dir = Path(VOLUME_PATH) / "evaluation_results"
-    local_eval_dir = output_path / "evaluation_results"
-    if eval_dir.exists():
-        local_eval_dir.mkdir(exist_ok=True)
-
-        for eval_file in eval_dir.glob("*.json"):
-            dest_path = local_eval_dir / eval_file.name
-            try:
-                shutil.copy2(eval_file, dest_path)
-                synced_files.append(f"evaluation_results/{eval_file.name}")
-                logger.info(f"✅ Downloaded evaluation: {eval_file.name}")
-            except Exception as e:
-                logger.warning(f"⚠️ Failed to copy evaluation {eval_file.name}: {e}")
-
-    if synced_files:
-        logger.info(f"🎉 Successfully synced {len(synced_files)} analysis files:")
-        for file in synced_files[:10]:  # Show first 10
-            logger.info(f"  ✓ {file}")
-        if len(synced_files) > 10:
-            logger.info(f"  ... and {len(synced_files) - 10} more files")
-        return True
-    logger.error("❌ No analysis files found to sync")
-    return False
-
-
-def sync_files(
-    model_files: bool = False,
-    analysis_files: bool = False,
-    all_files: bool = False,
-    output_dir: str = ".",
-) -> None:
-    """Main sync function to download files from Beam volume."""
-    logger.info("🚀 Starting file sync from Beam volume")
-    logger.info(f"📁 Local output directory: {output_dir}")
-
-    # Initialize Beam utilities (read-only)
-    try:
-        volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(VOLUME_NAME, VOLUME_PATH)
-        logger.info(f"✅ Connected to Beam volume: {VOLUME_NAME}")
-    except Exception:
-        logger.exception("❌ Failed to connect to Beam volume")
-        logger.info("Make sure you have run the distillation/evaluation on Beam first")
-        return
-
-    # Check what files to sync
-    sync_model = model_files or all_files
-    sync_analysis = analysis_files or all_files
-
-    if not (sync_model or sync_analysis):
-        logger.error("❌ No file types specified. Use --model-files, --analysis-files, or --all")
-        return
-
-    success_count = 0
-
-    # Sync model files
-    if sync_model:
-        logger.info("\n" + "=" * 60)  # noqa: G003
-        logger.info("MODEL FILES SYNC")
-        logger.info("=" * 60)
-        if sync_model_files(output_dir):
-            success_count += 1
-
-    # Sync analysis files
-    if sync_analysis:
-        logger.info("\n" + "=" * 60)  # noqa: G003
-        logger.info("ANALYSIS FILES SYNC")
-        logger.info("=" * 60)
-        if sync_analysis_files(output_dir):
-            success_count += 1
-
-    # Summary
-    logger.info("\n" + "=" * 60)  # noqa: G003
-    logger.info("SYNC SUMMARY")
-    logger.info("=" * 60)
-
-    total_requested = sum([sync_model, sync_analysis])
-
-    if success_count == total_requested:
-        logger.info("🎉 All requested files synced successfully!")
-    elif success_count > 0:
-        logger.info(f"⚠️ Partial sync: {success_count}/{total_requested} file types synced")
-    else:
-        logger.error("❌ No files were synced")
-
-    logger.info(f"📂 Files saved to: {Path(output_dir).absolute()}")
src/distiller/utils.py
ADDED
@@ -0,0 +1,373 @@
+"""
+Common utilities for the distiller package.
+
+This module provides shared functionality used across multiple components
+including model discovery, result management, and initialization helpers.
+"""
+
+import json
+import logging
+from pathlib import Path
+from types import TracebackType
+from typing import Any
+
+from .beam_utils import (
+    BeamCheckpointManager,
+    BeamEvaluationManager,
+    BeamModelManager,
+    BeamVolumeManager,
+    create_beam_utilities,
+)
+from .config import VolumeConfig, get_safe_model_name, get_volume_config, setup_logging
+
+logger = logging.getLogger(__name__)
+
+# =============================================================================
+# BEAM UTILITIES MANAGEMENT
+# =============================================================================
+
+
+class BeamContext:
+    """Context manager for Beam utilities with consistent initialization."""
+
+    def __init__(self, workflow: str, volume_config: VolumeConfig | None = None) -> None:
+        """
+        Initialize Beam context.
+
+        Args:
+            workflow: Workflow type (distill, evaluate, benchmark, etc.)
+            volume_config: Optional custom volume config, otherwise inferred from workflow
+        """
+        self.workflow = workflow
+        self.volume_config = volume_config or get_volume_config()
+        self.volume_manager: BeamVolumeManager | None = None
+        self.checkpoint_manager: BeamCheckpointManager | None = None
+        self.model_manager: BeamModelManager | None = None
+        self.evaluation_manager: BeamEvaluationManager | None = None
+
+    def __enter__(self) -> tuple[BeamVolumeManager, BeamCheckpointManager, BeamModelManager, BeamEvaluationManager]:
+        """Enter context and initialize utilities."""
+        logger.info(f"🚀 Initializing Beam utilities for {self.workflow}")
+        logger.info(f"📁 Volume: {self.volume_config.name} at {self.volume_config.mount_path}")
+
+        self.volume_manager, self.checkpoint_manager, self.model_manager, self.evaluation_manager = (
+            create_beam_utilities(self.volume_config.name, self.volume_config.mount_path)
+        )
+
+        return self.volume_manager, self.checkpoint_manager, self.model_manager, self.evaluation_manager
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        """Exit context with cleanup if needed."""
+        if exc_type:
+            logger.error(f"❌ Error in Beam context for {self.workflow}: {exc_val}")
+        else:
+            logger.info(f"✅ Beam context for {self.workflow} completed successfully")
+
+
+def get_beam_utilities() -> tuple[BeamVolumeManager, BeamCheckpointManager, BeamModelManager, BeamEvaluationManager]:
+    """
+    Get Beam utilities for a specific workflow.
+
+    Returns:
+        Tuple of (volume_manager, checkpoint_manager, model_manager, evaluation_manager)
+    """
+    volume_config = get_volume_config()
+    return create_beam_utilities(volume_config.name, volume_config.mount_path)
+
+
+# =============================================================================
+# MODEL DISCOVERY
+# =============================================================================
+
+
+def discover_simplified_models(base_path: str | Path = ".") -> list[str]:
+    """
+    Discover simplified distillation models in the specified directory.
+
+    Args:
+        base_path: Base path to search for models
+
+    Returns:
+        List of model paths sorted alphabetically
+    """
+    base = Path(base_path)
+
+    # Look for models in common locations
+    search_patterns = [
+        "code_model2vec/final/**/",
+        "final/**/",
+        "code_model2vec_*/",
+        "*/config.json",
+        "*.safetensors",
+    ]
+
+    discovered_models = []
+
+    for pattern in search_patterns:
+        matches = list(base.glob(pattern))
+        for match in matches:
+            if match.is_dir():
+                # Check if it's a valid model directory
+                if (match / "config.json").exists() or (match / "model.safetensors").exists():
+                    discovered_models.append(str(match))
+            elif match.name == "config.json":
+                # Add parent directory if config.json found
+                discovered_models.append(str(match.parent))
+
+    # Remove duplicates and sort
+    unique_models = sorted(set(discovered_models))
+
+    logger.info(f"🔍 Discovered {len(unique_models)} models in {base_path}")
+    for model in unique_models:
+        logger.info(f"  📁 {model}")
+
+    return unique_models
+
+
+def validate_model_path(model_path: str | Path, volume_manager: BeamVolumeManager | None = None) -> str | None:
+    """
+    Validate and resolve model path, checking local filesystem and Beam volumes.
+
+    Args:
+        model_path: Path to model (can be local path or HuggingFace model name)
+        volume_manager: Optional volume manager for Beam volume checks
+
+    Returns:
+        Resolved model path or None if not found
+    """
+    path = Path(model_path)
+
+    # Check if it's a HuggingFace model name
+    if "/" in str(model_path) and not path.exists() and not str(model_path).startswith("/"):
+        logger.info(f"📥 Treating as HuggingFace model: {model_path}")
+        return str(model_path)
+
+    # Check local filesystem
+    if path.exists():
+        logger.info(f"✅ Found local model: {model_path}")
+        return str(path)
+
+    # Check Beam volume if available
+    if volume_manager:
+        volume_path = Path(volume_manager.mount_path) / path.name
+        if volume_path.exists():
+            logger.info(f"✅ Found model in Beam volume: {volume_path}")
+            return str(volume_path)
+
+        # Check volume root
+        root_path = Path(volume_manager.mount_path)
+        if (root_path / "config.json").exists():
+            logger.info(f"✅ Found model in Beam volume root: {root_path}")
+            return str(root_path)
+
+    logger.warning(f"⚠️ Model not found: {model_path}")
+    return None
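The two helpers compose naturally: discover candidate directories, then resolve each one before use. A short sketch (the base path is illustrative):

    from distiller.utils import discover_simplified_models, validate_model_path

    for candidate in discover_simplified_models("code_model2vec"):
        resolved = validate_model_path(candidate)
        if resolved:
            print(f"ready to evaluate: {resolved}")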
+
+
+# =============================================================================
+# RESULT MANAGEMENT
+# =============================================================================
+
+
+def save_results_with_backup(
+    results: dict[str, Any],
+    primary_path: str | Path,
+    model_name: str,
+    result_type: str = "evaluation",
+    volume_manager: BeamVolumeManager | None = None,
+    evaluation_manager: BeamEvaluationManager | None = None,
+) -> bool:
+    """
+    Save results with multiple backup strategies.
+
+    Args:
+        results: Results dictionary to save
+        primary_path: Primary save location
+        model_name: Model name for filename generation
+        result_type: Type of results (evaluation, benchmark, etc.)
+        volume_manager: Optional volume manager for Beam storage
+        evaluation_manager: Optional evaluation manager for specialized storage
+
+    Returns:
+        True if saved successfully to at least one location
+    """
+    success_count = 0
+    safe_name = get_safe_model_name(model_name)
+
+    # Save to primary location
+    try:
+        primary = Path(primary_path)
+        primary.mkdir(parents=True, exist_ok=True)
+        filename = f"{result_type}_{safe_name}.json"
+        filepath = primary / filename
+
+        with filepath.open("w") as f:
+            json.dump(results, f, indent=2, default=str)
+
+        logger.info(f"💾 Saved {result_type} results to: {filepath}")
+        success_count += 1
+    except Exception as e:
+        logger.warning(f"⚠️ Failed to save to primary location: {e}")
+
+    # Save to Beam volume if available
+    if volume_manager:
+        try:
+            volume_path = Path(volume_manager.mount_path) / f"{result_type}_results"
+            volume_path.mkdir(parents=True, exist_ok=True)
+            filename = f"{result_type}_{safe_name}.json"
+            filepath = volume_path / filename
+
+            with filepath.open("w") as f:
+                json.dump(results, f, indent=2, default=str)
+
+            logger.info(f"💾 Saved {result_type} results to Beam volume: {filepath}")
+            success_count += 1
+        except Exception as e:
+            logger.warning(f"⚠️ Failed to save to Beam volume: {e}")
+
+    # Save via evaluation manager if available and appropriate
+    if evaluation_manager and result_type == "evaluation":
+        try:
+            success = evaluation_manager.save_evaluation_results(model_name, results)
+            if success:
+                logger.info(f"💾 Saved via evaluation manager for {model_name}")
+                success_count += 1
+        except Exception as e:
+            logger.warning(f"⚠️ Failed to save via evaluation manager: {e}")
+
+    return success_count > 0
+
+
+def load_existing_results(
+    model_name: str,
+    result_type: str = "evaluation",
+    search_paths: list[str | Path] | None = None,
+    volume_manager: BeamVolumeManager | None = None,
+    evaluation_manager: BeamEvaluationManager | None = None,
+) -> dict[str, Any] | None:
+    """
+    Load existing results from multiple possible locations.
+
+    Args:
+        model_name: Model name to search for
+        result_type: Type of results to load
+        search_paths: Additional paths to search
+        volume_manager: Optional volume manager
+        evaluation_manager: Optional evaluation manager
+
+    Returns:
+        Results dictionary if found, None otherwise
+    """
+    safe_name = get_safe_model_name(model_name)
+    filename = f"{result_type}_{safe_name}.json"
+
+    # Search in provided paths
+    if search_paths:
+        for search_path in search_paths:
+            filepath = Path(search_path) / filename
+            if filepath.exists():
+                try:
+                    with filepath.open("r") as f:
+                        results = json.load(f)
+                    logger.info(f"📂 Loaded existing {result_type} results from: {filepath}")
+                    return results
+                except Exception as e:
+                    logger.warning(f"⚠️ Failed to load from {filepath}: {e}")
+
+    # Search in Beam volume
+    if volume_manager:
+        volume_path = Path(volume_manager.mount_path) / f"{result_type}_results" / filename
+        if volume_path.exists():
+            try:
+                with volume_path.open("r") as f:
+                    results = json.load(f)
+                logger.info(f"📂 Loaded existing {result_type} results from Beam volume: {volume_path}")
+                return results
+            except Exception as e:
+                logger.warning(f"⚠️ Failed to load from Beam volume: {e}")
+
+    # Try evaluation manager
+    if evaluation_manager and result_type == "evaluation":
+        try:
+            results = evaluation_manager.load_evaluation_results(model_name)
+            if results:
+                logger.info(f"📂 Loaded existing {result_type} results via evaluation manager")
+                return results
+        except Exception as e:
+            logger.warning(f"⚠️ Failed to load via evaluation manager: {e}")
+
+    logger.info(f"ℹ️ No existing {result_type} results found for {model_name}")
+    return None
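save_results_with_backup and load_existing_results agree on the {result_type}_{safe_name}.json naming, so a primary save location doubles as a search path on reload. A round-trip sketch (model name and payload are illustrative):

    from distiller.utils import load_existing_results, save_results_with_backup

    payload = {"model": "my-model", "mrr": 0.42}  # hypothetical results
    save_results_with_backup(payload, "evaluation_results", "my-model")

    cached = load_existing_results("my-model", search_paths=["evaluation_results"])
    assert cached is not None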
+
+
+# =============================================================================
+# WORKFLOW HELPERS
+# =============================================================================
+
+
+def print_workflow_summary(
+    workflow_name: str,
+    total_items: int,
+    processed_items: int,
+    skipped_items: int,
+    execution_time: float | None = None,
+) -> None:
+    """Print a standardized workflow summary."""
+    logger.info(f"\n✅ {workflow_name} complete!")
+    logger.info(f"📊 Total items: {total_items}")
+    logger.info(f"✨ Newly processed: {processed_items}")
+    logger.info(f"⏭️ Skipped (already done): {skipped_items}")
+
+    if execution_time:
+        logger.info(f"⏱️ Execution time: {execution_time:.2f} seconds")
+
+
+def check_existing_results(
+    items: list[str],
+    result_type: str,
+    search_paths: list[str | Path] | None = None,
+    volume_manager: BeamVolumeManager | None = None,
+) -> tuple[list[str], list[str]]:
+    """
+    Check which items already have results and which need processing.
+
+    Args:
+        items: List of items (model names, etc.) to check
+        result_type: Type of results to check for
+        search_paths: Paths to search for existing results
+        volume_manager: Optional volume manager
+
+    Returns:
+        Tuple of (items_to_process, items_to_skip)
+    """
+    to_process = []
+    to_skip = []
+
+    for item in items:
+        existing = load_existing_results(item, result_type, search_paths, volume_manager)
+        if existing:
+            to_skip.append(item)
+        else:
+            to_process.append(item)
+
+    return to_process, to_skip
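check_existing_results is the resume primitive: pair it with print_workflow_summary to evaluate only what is missing. A sketch with illustrative model names:

    from distiller.utils import check_existing_results, print_workflow_summary

    models = ["model-a", "model-b"]  # hypothetical
    to_process, to_skip = check_existing_results(
        models, "evaluation", search_paths=["evaluation_results"]
    )
    for name in to_process:
        ...  # run the evaluation for models without cached results
    print_workflow_summary("Evaluation", len(models), len(to_process), len(to_skip))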
+
+
+# =============================================================================
+# INITIALIZATION
+# =============================================================================
+
+
+def initialize_distiller_logging(level: int = logging.INFO) -> None:
+    """Initialize logging for distiller package."""
+    setup_logging(level)
+    logger.info("🚀 Distiller package initialized")
+
+
+# Ensure logging is set up when module is imported
+initialize_distiller_logging()