Sarthak committed
Commit ecfceb8 · 1 Parent(s): 7bb46ce

initial commit
.gitattributes CHANGED
@@ -1,37 +1,6 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
  *.png filter=lfs diff=lfs merge=lfs -text
+ evaluation/** filter=lfs diff=lfs merge=lfs -text
+ *.skops* filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+
+ # Cache
+ .ruff_cache
+ .mypy_cache
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
MTEB_evaluate.py ADDED
@@ -0,0 +1,350 @@
+ #!/usr/bin/env python
+ """
+ MTEB Evaluation Script for Distilled Model - Code-Focused Tasks.
+
+ This script evaluates the distilled gte-Qwen2-7B-instruct model using MTEB
+ (Massive Text Embedding Benchmark) with a focus on tasks relevant for code:
+
+ - Classification: Tests ability to distinguish between different categories (e.g., programming languages)
+ - Clustering: Tests ability to group similar code by functionality
+ - STS: Tests semantic similarity understanding between code snippets
+ - Retrieval: Tests code search and duplicate detection capabilities
+
+ Features:
+ - Incremental evaluation: Skips tasks that already have results in mteb_results/
+ - Combines existing and new results automatically
+ - Saves results in multiple formats for analysis
+
+ Usage:
+     python MTEB_evaluate.py
+
+ Configuration:
+ - Set EVAL_ALL_TASKS = False to use only CODE_SPECIFIC_TASKS
+ - Modify CODE_SPECIFIC_TASKS for granular task selection
+ """
+
+ import json
+ import logging
+ import sys
+ import time
+ from pathlib import Path
+
+ import mteb
+ from model2vec import StaticModel
+ from mteb import ModelMeta
+
+ from evaluation import (
+     CustomMTEB,
+     get_tasks,
+     make_leaderboard,
+     parse_mteb_results,
+     summarize_results,
+ )
+
+ # =============================================================================
+ # CONFIGURATION CONSTANTS
+ # =============================================================================
+
+ # Model Configuration
+ MODEL_PATH = "."  # Path to the distilled model directory
+ MODEL_NAME = "gte-Qwen2-7B-instruct-M2V-Distilled"  # Name for the model in results
+
+ # Evaluation Configuration
+ OUTPUT_DIR = "mteb_results"  # Directory to save evaluation results
+
+ EVAL_ALL_TASKS = True
+
+ # Specific tasks most relevant for code evaluation (focused selection)
+ CODE_SPECIFIC_TASKS = [
+     # Classification - Programming language/category classification
+     "Banking77Classification",  # Fine-grained classification (77 classes)
+     # Clustering - Code grouping by functionality
+     "StackExchangeClustering.v2",  # Technical Q&A clustering (most relevant)
+     # STS - Code similarity understanding
+     "STSBenchmark",  # Standard semantic similarity benchmark
+     # Retrieval - Code search capabilities
+     "CQADupstackProgrammersRetrieval",  # Programming Q&A retrieval
+     # PairClassification - Duplicate/similar code detection
+     "SprintDuplicateQuestions",  # Duplicate question detection
+ ]
+
+ # Evaluation settings
+ EVAL_SPLITS = ["test"]  # Dataset splits to evaluate on
+ VERBOSITY = 2  # MTEB verbosity level
+
+ # =============================================================================
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+
+ def check_existing_results(output_path: Path, tasks: list) -> list:
+     """Check for existing task results and filter out completed tasks."""
+     remaining_tasks = []
+     completed_tasks = []
+
+     for task in tasks:
+         task_name = task.metadata.name
+         # MTEB saves results under {model_name}/{revision}/{task_name}.json ("distilled" revision per ModelMeta below)
+         result_file = output_path / MODEL_NAME / "distilled" / f"{task_name}.json"
+         if result_file.exists():
+             completed_tasks.append(task_name)
+             logger.info(f"Skipping {task_name} - results already exist")
+         else:
+             remaining_tasks.append(task)
+
+     if completed_tasks:
+         logger.info(f"Found existing results for {len(completed_tasks)} tasks: {completed_tasks}")
+
+     return remaining_tasks
+
+
+ def load_existing_parsed_results(output_path: Path) -> dict:
+     """Load existing parsed results if they exist."""
+     parsed_results_file = output_path / "mteb_parsed_results.json"
+     if parsed_results_file.exists():
+         try:
+             with parsed_results_file.open("r") as f:
+                 return json.load(f)
+         except (json.JSONDecodeError, OSError) as e:
+             logger.warning(f"Could not load existing parsed results: {e}")
+     return {}
+
+
+ def load_and_display_existing_results(output_path: Path) -> None:
+     """Load and display existing MTEB results."""
+     summary_file = output_path / "mteb_summary.json"
+     if summary_file.exists():
+         with summary_file.open("r") as f:
+             summary = json.load(f)
+
+         logger.info("=" * 80)
+         logger.info("EXISTING MTEB EVALUATION RESULTS:")
+         logger.info("=" * 80)
+
+         stats = summary.get("summary_stats")
+         if stats:
+             logger.info(f"Total Datasets: {stats.get('total_datasets', 'N/A')}")
+             logger.info(f"Average Score: {stats.get('average_score', 0):.4f}")
+             logger.info(f"Median Score: {stats.get('median_score', 0):.4f}")
+
+         logger.info("=" * 80)
+     else:
+         logger.info("No existing summary found. Individual task results may still exist.")
+
+
+ def run_mteb_evaluation() -> None:
+     """Run MTEB evaluation using the evaluation package."""
+     output_path = Path(OUTPUT_DIR)
+     output_path.mkdir(parents=True, exist_ok=True)
+
+     logger.info(f"Loading model from {MODEL_PATH}")
+     model = StaticModel.from_pretrained(MODEL_PATH)
+     logger.info("Model loaded successfully")
+
+     # Set up model metadata for MTEB
+     model.mteb_model_meta = ModelMeta(  # type: ignore[attr-defined]
+         name=MODEL_NAME, revision="distilled", release_date=None, languages=["eng"]
+     )
+
+     # Get specific code-relevant tasks (focused selection)
+     logger.info("Getting focused code-relevant MTEB tasks")
+     logger.info(f"Selected specific tasks: {CODE_SPECIFIC_TASKS}")
+
+     if EVAL_ALL_TASKS:
+         all_tasks = get_tasks()
+     else:
+         all_tasks = [mteb.get_task(task_name, languages=["eng"]) for task_name in CODE_SPECIFIC_TASKS]
+
+     logger.info(f"Found {len(all_tasks)} total tasks")
+
+     # Check for existing results and filter out completed tasks
+     tasks = check_existing_results(output_path, all_tasks)
+     logger.info(f"Will evaluate {len(tasks)} remaining tasks")
+
+     if not tasks:
+         logger.info("No new tasks to evaluate - all tasks already completed!")
+
+         # Load and display existing results
+         logger.info("Loading existing results...")
+         try:
+             load_and_display_existing_results(output_path)
+         except (json.JSONDecodeError, OSError, KeyError) as e:
+             logger.warning(f"Could not load existing results: {e}")
+         return
+
+     # Define the CustomMTEB object with the specified tasks
+     evaluation = CustomMTEB(tasks=tasks)
+
+     # Run the evaluation
+     logger.info("Starting MTEB evaluation...")
+     start_time = time.time()
+
+     results = evaluation.run(model, eval_splits=EVAL_SPLITS, output_folder=str(output_path), verbosity=VERBOSITY)
+
+     end_time = time.time()
+     evaluation_time = end_time - start_time
+     logger.info(f"Evaluation completed in {evaluation_time:.2f} seconds")
+
+     # Parse the results and summarize them
+     logger.info("Parsing and summarizing results...")
+     parsed_results = parse_mteb_results(mteb_results=results, model_name=MODEL_NAME)
+
+     # Load existing results if any and combine them
+     existing_results = load_existing_parsed_results(output_path)
+     if existing_results:
+         logger.info("Combining with existing results...")
+         # Convert to dict for merging
+         parsed_dict = dict(parsed_results) if hasattr(parsed_results, "items") else {}
+         # Simple merge - existing results take precedence to avoid overwriting
+         for key, value in existing_results.items():
+             if key not in parsed_dict:
+                 parsed_dict[key] = value
+         parsed_results = parsed_dict
+
+     task_scores = summarize_results(parsed_results)
+
+     # Save results in different formats
+     save_results(output_path, results, parsed_results, task_scores, evaluation_time)
+
+     # Print the results in a leaderboard format
+     logger.info("MTEB Evaluation Results:")
+     logger.info("=" * 80)
+     leaderboard = make_leaderboard(task_scores)  # type: ignore[arg-type]
+     logger.info(leaderboard.to_string(index=False))
+     logger.info("=" * 80)
+
+     logger.info(f"Evaluation completed successfully. Results saved to {OUTPUT_DIR}")
+
+
+ def save_results(
+     output_path: Path, raw_results: list, parsed_results: dict, task_scores: dict, evaluation_time: float
+ ) -> None:
+     """Save evaluation results in multiple formats."""
+     # Save raw results
+     raw_results_file = output_path / "mteb_raw_results.json"
+     with raw_results_file.open("w") as f:
+         json.dump(raw_results, f, indent=2, default=str)
+     logger.info(f"Raw results saved to {raw_results_file}")
+
+     # Save parsed results
+     parsed_results_file = output_path / "mteb_parsed_results.json"
+     with parsed_results_file.open("w") as f:
+         json.dump(parsed_results, f, indent=2, default=str)
+     logger.info(f"Parsed results saved to {parsed_results_file}")
+
+     # Generate summary statistics
+     summary_stats = generate_summary_stats(task_scores)
+
+     # Save task scores summary
+     summary = {
+         "model_name": MODEL_NAME,
+         "evaluation_time_seconds": evaluation_time,
+         "task_scores": task_scores,
+         "summary_stats": summary_stats,
+     }
+
+     summary_file = output_path / "mteb_summary.json"
+     with summary_file.open("w") as f:
+         json.dump(summary, f, indent=2, default=str)
+     logger.info(f"Summary saved to {summary_file}")
+
+     # Save human-readable report
+     report_file = output_path / "mteb_report.txt"
+     generate_report(output_path, task_scores, summary_stats, evaluation_time)
+     logger.info(f"Report saved to {report_file}")
+
+
+ def generate_summary_stats(task_scores: dict) -> dict:
+     """Generate summary statistics from task scores."""
+     if not task_scores:
+         return {}
+
+     # Extract all individual dataset scores
+     all_scores = []
+     for model_data in task_scores.values():
+         if isinstance(model_data, dict) and "dataset_scores" in model_data:
+             dataset_scores = model_data["dataset_scores"]
+             if isinstance(dataset_scores, dict):
+                 all_scores.extend(
+                     [
+                         float(score)
+                         for score in dataset_scores.values()
+                         if isinstance(score, int | float) and str(score).lower() != "nan"
+                     ]
+                 )
+
+     if not all_scores:
+         return {}
+
+     import numpy as np
+
+     return {
+         "total_datasets": len(all_scores),
+         "average_score": float(np.mean(all_scores)),
+         "median_score": float(np.median(all_scores)),
+         "std_dev": float(np.std(all_scores)),
+         "min_score": float(np.min(all_scores)),
+         "max_score": float(np.max(all_scores)),
+     }
+
+
+ def generate_report(output_path: Path, task_scores: dict, summary_stats: dict, evaluation_time: float) -> None:
+     """Generate human-readable evaluation report."""
+     report_file = output_path / "mteb_report.txt"
+
+     with report_file.open("w") as f:
+         f.write("=" * 80 + "\n")
+         f.write("MTEB Evaluation Report\n")
+         f.write("=" * 80 + "\n\n")
+         f.write(f"Model: {MODEL_NAME}\n")
+         f.write(f"Model Path: {MODEL_PATH}\n")
+         f.write(f"Evaluation Time: {evaluation_time:.2f} seconds\n")
+
+         # Write summary stats
+         if summary_stats:
+             f.write(f"Total Datasets: {summary_stats['total_datasets']}\n\n")
+             f.write("Summary Statistics:\n")
+             f.write(f"  Average Score: {summary_stats['average_score']:.4f}\n")
+             f.write(f"  Median Score: {summary_stats['median_score']:.4f}\n")
+             f.write(f"  Standard Deviation: {summary_stats['std_dev']:.4f}\n")
+             f.write(f"  Score Range: {summary_stats['min_score']:.4f} - {summary_stats['max_score']:.4f}\n\n")
+         else:
+             f.write("Summary Statistics: No valid results found\n\n")
+
+         # Write leaderboard
+         f.write("Detailed Results:\n")
+         f.write("-" * 50 + "\n")
+         if task_scores:
+             leaderboard = make_leaderboard(task_scores)  # type: ignore[arg-type]
+             f.write(leaderboard.to_string(index=False))
+         else:
+             f.write("No results available\n")
+
+         f.write("\n\n" + "=" * 80 + "\n")
+
+
+ def main() -> None:
+     """Main evaluation function."""
+     logger.info(f"Starting MTEB evaluation for {MODEL_NAME}")
+     logger.info(f"Model path: {MODEL_PATH}")
+     logger.info(f"Output directory: {OUTPUT_DIR}")
+     logger.info("Running focused MTEB evaluation on code-relevant tasks:")
+     logger.info("  - Classification: Programming language classification")
+     logger.info("  - Clustering: Code clustering by functionality")
+     logger.info("  - STS: Semantic similarity between code snippets")
+     logger.info("  - Retrieval: Code search and retrieval")
+
+     try:
+         run_mteb_evaluation()
+         logger.info("Evaluation pipeline completed successfully!")
+
+     except Exception:
+         logger.exception("Evaluation failed")
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     main()
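
The incremental skip in check_existing_results keys off JSON files already on disk. A hedged sketch for listing which tasks would be skipped, using the mteb_results layout committed later in this diff:

from pathlib import Path

# One JSON file per completed task, under {model_name}/{revision}/.
results_dir = Path("mteb_results/gte-Qwen2-7B-instruct-M2V-Distilled/distilled")
done = sorted(p.stem for p in results_dir.glob("*.json"))
print(done)  # e.g. ['AmazonCounterfactualClassification', 'Banking77Classification', ...]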
README.md CHANGED
@@ -1,3 +1,3 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ ---
config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "model_type": "model2vec",
+   "architectures": [
+     "StaticModel"
+   ],
+   "tokenizer_name": "Alibaba-NLP/gte-Qwen2-7B-instruct",
+   "apply_pca": 256,
+   "apply_zipf": null,
+   "sif_coefficient": 0.0001,
+   "hidden_dim": 256,
+   "seq_length": 1000000,
+   "normalize": true
+ }
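
This config is what model2vec's StaticModel reads back at load time: apply_pca/hidden_dim fix the embedding width at 256, and "normalize": true makes every vector unit-length. A minimal sanity-check sketch (the repo-root path "." and the sample snippet are illustrative):

import numpy as np
from model2vec import StaticModel

# Load the distilled model from the repo root (config.json + model.safetensors + tokenizer.json).
model = StaticModel.from_pretrained(".")

emb = model.encode(["def binary_search(arr, target):"])
print(emb.shape)               # expect (1, 256), per "hidden_dim": 256
print(np.linalg.norm(emb[0]))  # expect ~1.0, per "normalize": true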
distill.py ADDED
@@ -0,0 +1,116 @@
+ #!/usr/bin/env python
+ """
+ Script to distill Alibaba-NLP/gte-Qwen2-7B-instruct using Model2Vec.
+
+ This script performs the following operations:
+ 1. Downloads the Alibaba-NLP/gte-Qwen2-7B-instruct model
+ 2. Distills it using Model2Vec to create a smaller, faster static model
+ 3. Saves the distilled model for further use
+ """
+
+ import logging
+ import shutil
+ import time
+ from pathlib import Path
+
+ from model2vec.distill import distill
+
+ # =============================================================================
+ # CONFIGURATION CONSTANTS
+ # =============================================================================
+
+ # Model Configuration
+ MODEL_NAME = "Alibaba-NLP/gte-Qwen2-7B-instruct"  # Model name or path for the source model
+ OUTPUT_DIR = "."  # Directory to save the distilled model (current directory)
+ PCA_DIMS = 256  # Dimensions for PCA reduction (smaller = faster but less accurate)
+
+ # Hub Configuration
+ SAVE_TO_HUB = False  # Whether to push the model to HuggingFace Hub
+ HUB_MODEL_ID = None  # Model ID for HuggingFace Hub (if saving to hub)
+
+ # Generation Configuration
+ SKIP_README = True  # Skip generating the README file
+
+ # =============================================================================
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+
+ def main() -> None:
+     """Run the distillation process for Alibaba-NLP/gte-Qwen2-7B-instruct."""
+     # Create output directory if it doesn't exist
+     output_dir = Path(OUTPUT_DIR)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     logger.info(f"Starting distillation of {MODEL_NAME}")
+     logger.info(f"Distilled model will be saved to {output_dir}")
+     logger.info(f"Using PCA dimensions: {PCA_DIMS}")
+     logger.info(f"Skipping README generation: {SKIP_README}")
+
+     # Record start time for benchmarking
+     start_time = time.time()
+
+     # Run the distillation
+     try:
+         logger.info("Starting Model2Vec distillation...")
+         m2v_model = distill(
+             model_name=MODEL_NAME,
+             pca_dims=PCA_DIMS,
+         )
+
+         distill_time = time.time() - start_time
+         logger.info(f"Distillation completed in {distill_time:.2f} seconds")
+
+         # Save the distilled model
+         m2v_model.save_pretrained(OUTPUT_DIR)
+         logger.info(f"Model saved to {OUTPUT_DIR}")
+
+         # Remove README.md if it was created and we want to skip it
+         if SKIP_README and (output_dir / "README.md").exists():
+             (output_dir / "README.md").unlink()
+             logger.info("Removed auto-generated README.md")
+
+         # Get model size information
+         model_size_mb = sum(
+             f.stat().st_size for f in output_dir.glob("**/*") if f.is_file() and f.name != "README.md"
+         ) / (1024 * 1024)
+         logger.info(f"Distilled model size: {model_size_mb:.2f} MB")
+
+         # Push to hub if requested
+         if SAVE_TO_HUB:
+             if HUB_MODEL_ID:
+                 logger.info(f"Pushing model to HuggingFace Hub as {HUB_MODEL_ID}")
+
+                 # Create a temporary README for Hub upload if needed
+                 readme_path = output_dir / "README.md"
+                 had_readme = readme_path.exists()
+
+                 if SKIP_README and had_readme:
+                     # Backup the README
+                     shutil.move(readme_path, output_dir / "README.md.bak")
+
+                 # Push to Hub
+                 m2v_model.push_to_hub(HUB_MODEL_ID)
+
+                 # Restore state
+                 if SKIP_README:
+                     if had_readme:
+                         # Restore the backup
+                         shutil.move(output_dir / "README.md.bak", readme_path)
+                     elif (output_dir / "README.md").exists():
+                         # Remove README created during push_to_hub
+                         (output_dir / "README.md").unlink()
+             else:
+                 logger.error("HUB_MODEL_ID must be specified when SAVE_TO_HUB is True")
+
+         logger.info("Distillation process completed successfully!")
+
+     except Exception:
+         logger.exception("Error during distillation")
+         raise
+
+
+ if __name__ == "__main__":
+     main()
evaluate.py ADDED
@@ -0,0 +1,422 @@
+ #!/usr/bin/env python
+ """
+ Script to evaluate the performance of the distilled gte-Qwen2-7B-instruct model.
+
+ This script performs the following:
+ 1. Loads both the original gte-Qwen2-7B-instruct model and the distilled version
+ 2. Compares them on:
+    - Embedding similarity
+    - Inference speed
+    - Memory usage
+ 3. Outputs a comprehensive evaluation report
+ """
+
+ import argparse
+ import gc
+ import logging
+ import os
+ import time
+ from pathlib import Path
+ from typing import Any, cast
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import psutil  # type: ignore [import]
+ import torch
+ from model2vec import StaticModel
+ from sentence_transformers import SentenceTransformer
+ from sklearn.metrics.pairwise import cosine_similarity  # type: ignore [import]
+
+ # For transformer models
+ from transformers import AutoModel, AutoTokenizer
+ from transformers.modeling_utils import PreTrainedModel
+
+ # =============================================================================
+ # CONFIGURATION CONSTANTS
+ # =============================================================================
+
+ # Model Configuration
+ ORIGINAL_MODEL = "Alibaba-NLP/gte-Qwen2-7B-instruct"  # Original model name or path
+ DISTILLED_MODEL = "."  # Path to the distilled model (current directory)
+ OUTPUT_DIR = "evaluation"  # Directory to save evaluation results
+
+ # =============================================================================
+
+ # Constants
+ BYTES_PER_KB = 1024.0
+ TEXT_TRUNCATE_LENGTH = 20
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+ # Sample texts for evaluation
+ SAMPLE_TEXTS = [
+     "def process_data_stream(source_iterator):",
+     "implement binary search tree",
+     "how to handle memory efficient data streaming",
+     """class LazyLoader:
+     def __init__(self, source):
+         self.generator = iter(source)
+         self._cache = []""",
+     """def dfs_traversal(root):
+     if not root:
+         return []
+     visited = []
+     stack = [root]
+     while stack:
+         node = stack.pop()
+         visited.append(node.val)
+         if node.right:
+             stack.append(node.right)
+         if node.left:
+             stack.append(node.left)
+     return visited""",
+ ]
+
+
+ def load_models(
+     original_model_name: str, distilled_model_path: str
+ ) -> tuple[tuple[SentenceTransformer | PreTrainedModel, str], StaticModel]:
+     """Load both the original and distilled models."""
+     logger.info(f"Loading original model: {original_model_name}")
+
+     try:
+         # Try to load as a sentence transformer first
+         original_model = SentenceTransformer(original_model_name)
+         model_type = "sentence_transformer"
+     except (ValueError, OSError, ImportError) as e:
+         # If that fails, try loading as a Hugging Face transformer
+         logger.info(f"Failed to load as SentenceTransformer: {e}")
+         AutoTokenizer.from_pretrained(original_model_name)
+         original_model = AutoModel.from_pretrained(original_model_name)
+         model_type = "huggingface"
+
+     logger.info(f"Loading distilled model from: {distilled_model_path}")
+     distilled_model = StaticModel.from_pretrained(distilled_model_path)
+
+     return (original_model, model_type), distilled_model
+
+
+ def measure_memory_usage(model: SentenceTransformer | PreTrainedModel | StaticModel) -> float:
+     """Measure memory usage of a model in MB."""
+     gc.collect()
+     torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
+     process = psutil.Process(os.getpid())
+     memory_before = process.memory_info().rss / (1024 * 1024)  # MB
+
+     # Force model to allocate memory if it hasn't already
+     if isinstance(model, StaticModel | SentenceTransformer):
+         _ = model.encode(["Test"])
+     else:
+         # For HF models, we need to handle differently
+         pass
+
+     gc.collect()
+     torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
+     process = psutil.Process(os.getpid())
+     memory_after = process.memory_info().rss / (1024 * 1024)  # MB
+
+     return memory_after - memory_before
+
+
+ def compute_embeddings(
+     original_model: SentenceTransformer | PreTrainedModel,
+     original_model_type: str,
+     distilled_model: StaticModel,
+     texts: list[str],
+     original_model_name: str = "unknown",
+ ) -> tuple[np.ndarray, np.ndarray]:
+     """Compute embeddings using both models."""
+     # Original model embeddings
+     if original_model_type == "sentence_transformer":
+         # Type narrowing: we know it's a SentenceTransformer here
+         sentence_model = cast("SentenceTransformer", original_model)
+         original_embeddings = sentence_model.encode(texts)
+     else:
+         # Type narrowing: we know it's a PreTrainedModel here
+         auto_model = original_model  # AutoModel.from_pretrained returns a PreTrainedModel instance
+
+         # For HF models, we need more custom code
+         # Simple mean pooling function for HF models
+         def mean_pooling(model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+             token_embeddings = model_output
+             input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+             return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+                 input_mask_expanded.sum(1), min=1e-9
+             )
+
+         # Get model name for tokenizer
+         model_name = getattr(auto_model.config, "name_or_path", original_model_name)
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
+
+         with torch.no_grad():
+             model_output = auto_model(**encoded_input)
+         original_embeddings = mean_pooling(model_output.last_hidden_state, encoded_input["attention_mask"]).numpy()
+
+     # Distilled model embeddings
+     distilled_embeddings = distilled_model.encode(texts)
+
+     return original_embeddings, distilled_embeddings
+
+
+ def measure_inference_speed(
+     model: SentenceTransformer | PreTrainedModel | StaticModel, model_type: str, texts: list[str], n_runs: int = 5
+ ) -> float:
+     """Measure inference speed in texts/second."""
+     # Warmup
+     if model_type in {"sentence_transformer", "static_model"}:
+         # Type narrowing: we know it has encode method here
+         encode_model = cast("SentenceTransformer | StaticModel", model)
+         _ = encode_model.encode(texts[:1])
+     else:
+         # Type narrowing: we know it's a PreTrainedModel here
+         auto_model = cast("PreTrainedModel", model)
+         # Warmup for HF models
+         model_name = getattr(auto_model.config, "name_or_path", "unknown")
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         encoded_input = tokenizer(texts[:1], padding=True, truncation=True, return_tensors="pt")
+         with torch.no_grad():
+             _ = auto_model(**encoded_input)
+
+     # Measure speed
+     start_time = time.time()
+
+     if model_type in {"sentence_transformer", "static_model"}:
+         # Type narrowing: we know it has encode method here
+         encode_model = cast("SentenceTransformer | StaticModel", model)
+         for _ in range(n_runs):
+             _ = encode_model.encode(texts)
+     else:
+         # Type narrowing: we know it's a PreTrainedModel here
+         auto_model = cast("PreTrainedModel", model)
+         # For HF models
+         model_name = getattr(auto_model.config, "name_or_path", "unknown")
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         for _ in range(n_runs):
+             encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
+             with torch.no_grad():
+                 _ = auto_model(**encoded_input)
+
+     total_time = time.time() - start_time
+     return (len(texts) * n_runs) / total_time
+
+
+ def compute_cosine_similarity(embeddings1: np.ndarray, embeddings2: np.ndarray) -> np.ndarray:
+     """Compute cosine similarity between embeddings, handling different dimensions.
+
+     For embeddings with different dimensions, we compute similarity by comparing
+     how they rank the same texts (semantically equivalent).
+     """
+     # Ensure embeddings1 and embeddings2 are 2D arrays with shapes (n_samples, n_features)
+     if embeddings1.ndim == 1:
+         embeddings1 = embeddings1.reshape(1, -1)
+     if embeddings2.ndim == 1:
+         embeddings2 = embeddings2.reshape(1, -1)
+
+     # Check and transpose if needed to ensure samples are in rows
+     if embeddings2.shape[0] != len(SAMPLE_TEXTS) and embeddings2.shape[1] == len(SAMPLE_TEXTS):
+         embeddings2 = embeddings2.T
+
+     logger.info(f"Embeddings shapes: original={embeddings1.shape}, distilled={embeddings2.shape}")
+
+     # If dimensions differ, we compute similarity matrix based on how each model ranks text pairs
+     # This is a form of semantic similarity evaluation rather than direct vector comparison
+     similarity_matrix = np.zeros((len(SAMPLE_TEXTS), len(SAMPLE_TEXTS)))
+
+     # Compute similarity matrices within each embedding space
+     sim1 = cosine_similarity(embeddings1)
+     sim2 = cosine_similarity(embeddings2)
+
+     # The similarity between samples i and j is the correlation between how they rank other samples
+     for i in range(len(SAMPLE_TEXTS)):
+         for j in range(len(SAMPLE_TEXTS)):
+             # For diagonal elements (same sample), use a direct measure of how similarly
+             # the two models rank that sample against all others
+             if i == j:
+                 # Pearson correlation between the rankings (excluding self-comparison)
+                 rankings1 = np.delete(sim1[i], i)
+                 rankings2 = np.delete(sim2[i], i)
+                 # Higher correlation means the models agree on the semantic similarity
+                 similarity_matrix[i, j] = np.corrcoef(rankings1, rankings2)[0, 1]
+             else:
+                 # For off-diagonal elements, compare how similarly both models relate samples i and j
+                 similarity_matrix[i, j] = 1 - abs(sim1[i, j] - sim2[i, j])
+
+     return similarity_matrix
+
+
+ def format_size(size_bytes: float) -> str:
+     """Format size in bytes to human-readable format."""
+     for unit in ["B", "KB", "MB", "GB"]:
+         if size_bytes < BYTES_PER_KB:
+             return f"{size_bytes:.2f} {unit}"
+         size_bytes /= BYTES_PER_KB
+     return f"{size_bytes:.2f} TB"
+
+
+ def plot_comparison(results: dict[str, Any], output_dir: str) -> None:
+     """Generate comparison plots and save them."""
+     output_path = Path(output_dir)
+     output_path.mkdir(exist_ok=True, parents=True)
+
+     # Speed comparison
+     plt.figure(figsize=(10, 6))
+     models = ["Original", "Distilled"]
+     speeds = [results["original_speed"], results["distilled_speed"]]
+     plt.bar(models, speeds, color=["#1f77b4", "#ff7f0e"])
+     plt.ylabel("Texts per second")
+     plt.title("Inference Speed Comparison")
+     plt.savefig(output_path / "speed_comparison.png", dpi=300, bbox_inches="tight")
+
+     # Memory comparison
+     plt.figure(figsize=(10, 6))
+     memories = [results["original_memory"], results["distilled_memory"]]
+     plt.bar(models, memories, color=["#1f77b4", "#ff7f0e"])
+     plt.ylabel("Memory Usage (MB)")
+     plt.title("Memory Usage Comparison")
+     plt.savefig(output_path / "memory_comparison.png", dpi=300, bbox_inches="tight")
+
+     # Size comparison
+     plt.figure(figsize=(10, 6))
+     sizes = [results["original_size"], results["distilled_size"]]
+     plt.bar(models, sizes, color=["#1f77b4", "#ff7f0e"])
+     plt.ylabel("Model Size (MB)")
+     plt.title("Model Size Comparison")
+     plt.savefig(output_path / "size_comparison.png", dpi=300, bbox_inches="tight")
+
+     # Similarity matrix heatmap
+     plt.figure(figsize=(8, 6))
+     plt.imshow(results["similarity_matrix"], cmap="viridis", interpolation="nearest")
+     plt.colorbar(label="Cosine Similarity")
+     plt.title("Embedding Similarity Between Original and Distilled Models")
+     plt.xticks([])
+     plt.yticks(
+         range(len(SAMPLE_TEXTS)),
+         [t[:TEXT_TRUNCATE_LENGTH] + "..." if len(t) > TEXT_TRUNCATE_LENGTH else t for t in SAMPLE_TEXTS],
+     )
+     plt.savefig(output_path / "similarity_matrix.png", dpi=300, bbox_inches="tight")
+
+
+ def evaluate_models(original_model_name: str, distilled_model_path: str, output_dir: str) -> dict[str, Any]:
+     """Evaluate the original and distilled models."""
+     # Load models
+     (original_model, original_model_type), distilled_model = load_models(original_model_name, distilled_model_path)
+
+     # Measure model sizes
+     if isinstance(original_model, SentenceTransformer):
+         # For SentenceTransformer, get parameters from all modules
+         total_params = 0
+         for module in original_model.modules():
+             if hasattr(module, "parameters"):
+                 for param in module.parameters():
+                     total_params += param.numel()
+         original_model_size = total_params * 4 / (1024 * 1024)  # MB (assuming float32)
+     else:
+         # For PreTrainedModel
+         auto_model = original_model  # AutoModel.from_pretrained returns a PreTrainedModel instance
+         original_model_size = sum(p.numel() * 4 for p in auto_model.parameters()) / (
+             1024 * 1024
+         )  # MB (assuming float32)
+
+     # Calculate distilled model size - only count actual model files
+     model_files = ["model.safetensors", "config.json", "modules.json", "tokenizer.json"]
+     distilled_model_size = 0.0
+     for file_name in model_files:
+         file_path = Path(distilled_model_path) / file_name
+         if file_path.exists():
+             distilled_model_size += file_path.stat().st_size
+     distilled_model_size = distilled_model_size / (1024 * 1024)  # Convert to MB
+
+     # Measure memory usage
+     original_memory = measure_memory_usage(original_model)
+     distilled_memory = measure_memory_usage(distilled_model)
+
+     # Compute embeddings
+     original_embeddings, distilled_embeddings = compute_embeddings(
+         original_model, original_model_type, distilled_model, SAMPLE_TEXTS, original_model_name
+     )
+
+     # Compute similarity between embeddings
+     similarity_matrix = compute_cosine_similarity(original_embeddings, distilled_embeddings)
+     similarity_diagonal = np.diag(similarity_matrix)
+     avg_similarity = np.mean(similarity_diagonal)
+
+     # Measure inference speed
+     original_speed = measure_inference_speed(original_model, original_model_type, SAMPLE_TEXTS, n_runs=5)
+     distilled_speed = measure_inference_speed(distilled_model, "static_model", SAMPLE_TEXTS, n_runs=5)
+
+     # Collect results
+     results = {
+         "original_size": original_model_size,
+         "distilled_size": distilled_model_size,
+         "original_memory": original_memory,
+         "distilled_memory": distilled_memory,
+         "similarity_matrix": similarity_matrix,
+         "avg_similarity": avg_similarity,
+         "original_speed": original_speed,
+         "distilled_speed": distilled_speed,
+         "speed_improvement": distilled_speed / original_speed if original_speed > 0 else float("inf"),
+         "size_reduction": original_model_size / distilled_model_size if distilled_model_size > 0 else float("inf"),
+         "memory_reduction": original_memory / distilled_memory if distilled_memory > 0 else float("inf"),
+     }
+
+     # Generate plots
+     plot_comparison(results, output_dir)
+
+     # Print results
+     separator = "=" * 50
+     logger.info("\n%s", separator)
+     logger.info("Model Evaluation Results")
+     logger.info("%s", separator)
+     logger.info(f"Original Model Size: {results['original_size']:.2f} MB")
+     logger.info(f"Distilled Model Size: {results['distilled_size']:.2f} MB")
+     logger.info(f"Size Reduction Factor: {results['size_reduction']:.2f}x")
+     logger.info("\n")
+     logger.info(f"Original Model Memory: {results['original_memory']:.2f} MB")
+     logger.info(f"Distilled Model Memory: {results['distilled_memory']:.2f} MB")
+     logger.info(f"Memory Reduction Factor: {results['memory_reduction']:.2f}x")
+     logger.info("\n")
+     logger.info(f"Original Model Speed: {results['original_speed']:.2f} texts/second")
+     logger.info(f"Distilled Model Speed: {results['distilled_speed']:.2f} texts/second")
+     logger.info(f"Speed Improvement Factor: {results['speed_improvement']:.2f}x")
+     logger.info("\n")
+     logger.info(f"Average Embedding Similarity: {results['avg_similarity']:.4f}")
+     logger.info("%s", separator)
+
+     return results
+
+
+ def main() -> None:
+     """Run the evaluation process."""
+     parser = argparse.ArgumentParser(description="Evaluate the distilled model against the original")
+     parser.add_argument("--original_model", default=ORIGINAL_MODEL, help="Original model name or path")
+     parser.add_argument("--distilled_model", default=DISTILLED_MODEL, help="Path to the distilled model")
+     parser.add_argument("--output_dir", default=OUTPUT_DIR, help="Directory to save evaluation results")
+
+     args = parser.parse_args()
+
+     # Validate configuration
+     if not args.distilled_model:
+         logger.error("Distilled model path must be provided")
+         logger.error("Use --distilled_model to specify the path or set DISTILLED_MODEL constant")
+         return
+
+     # Create output directory
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Run evaluation
+     try:
+         evaluate_models(args.original_model, args.distilled_model, args.output_dir)
+         logger.info(f"Evaluation completed. Results saved to {args.output_dir}")
+     except Exception:
+         logger.exception("Error during evaluation")
+         raise
+
+
+ if __name__ == "__main__":
+     main()
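
A note on compute_cosine_similarity above: because the original and distilled embeddings have different dimensionalities, vectors are never compared directly. Each model's within-space similarity matrix is computed first, and the diagonal of the reported matrix is the correlation between how the two models rank each text against the others. A toy sketch of that diagonal measure on random data (illustrative only):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(0)
emb_a = rng.random((3, 8))  # "original" space, 8-dim
emb_b = rng.random((3, 4))  # "distilled" space, 4-dim

sim_a = cosine_similarity(emb_a)
sim_b = cosine_similarity(emb_b)

# Agreement for text 0: correlate its similarities to the other texts
# in each space, excluding the self-comparison.
rankings_a = np.delete(sim_a[0], 0)
rankings_b = np.delete(sim_b[0], 0)
print(np.corrcoef(rankings_a, rankings_b)[0, 1])  # near 1.0 when both models order neighbours alike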
evaluation/memory_comparison.png ADDED

Git LFS Details

  • SHA256: 3e1258d250bc8fb9a5f8deb0a5b91eea2d7de8aa99b8c1353383e30a203874e4
  • Pointer size: 130 Bytes
  • Size of remote file: 72.3 kB
evaluation/similarity_matrix.png ADDED

Git LFS Details

  • SHA256: c9e597292d3f120e347b0382d80a7cf5da00cce234167276bc6acc219602a7ef
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
evaluation/size_comparison.png ADDED

Git LFS Details

  • SHA256: 0e8356904ec54d4837361f88af9dfe7cb2161acbb3bb9b08ac1d11e718804d77
  • Pointer size: 130 Bytes
  • Size of remote file: 77.2 kB
evaluation/speed_comparison.png ADDED

Git LFS Details

  • SHA256: e8edf611a573a36223262ccfdcfa6251e14e1d1982af70144d21560684d69cba
  • Pointer size: 130 Bytes
  • Size of remote file: 80 kB
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9dc253eb31caa17834057d92999b03869cda542d5f70ecca3d4a5f03b3563b3f
+ size 155283544
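
The three lines above are the Git LFS pointer that stands in for the ~155 MB weight file; the actual tensors live in LFS storage and are fetched on checkout. Once the real file is present, a hedged way to inspect it (tensor names are whatever model2vec wrote, not assumed here):

from safetensors.numpy import load_file

# Load every tensor in the file as NumPy arrays and list name, shape, dtype.
tensors = load_file("model.safetensors")
for name, tensor in tensors.items():
    print(name, tensor.shape, tensor.dtype)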
modules.json ADDED
@@ -0,0 +1,14 @@
+ [
+   {
+     "idx": 0,
+     "name": "0",
+     "path": ".",
+     "type": "sentence_transformers.models.StaticEmbedding"
+   },
+   {
+     "idx": 1,
+     "name": "1",
+     "path": "1_Normalize",
+     "type": "sentence_transformers.models.Normalize"
+   }
+ ]
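
modules.json is the sentence-transformers pipeline description: module 0 is a StaticEmbedding lookup table at the repo root and module 1 is an L2-Normalize step. A minimal sketch of loading the same checkpoint through that route (assumes a sentence-transformers version that ships StaticEmbedding):

from sentence_transformers import SentenceTransformer

# Assembles StaticEmbedding + Normalize exactly as listed in modules.json.
st_model = SentenceTransformer(".")
emb = st_model.encode("implement binary search tree")
print(emb.shape)  # (256,): one 256-dim vector, already unit-normalized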
mteb_results/gte-Qwen2-7B-instruct-M2V-Distilled/distilled/AmazonCounterfactualClassification.json ADDED
@@ -0,0 +1,179 @@
+ {
+   "dataset_revision": "e8379541af4e31359cca9fbcf4b00f2671dba205",
+   "evaluation_time": 8.737873554229736,
+   "kg_co2_emissions": null,
+   "mteb_version": "1.14.15",
+   "scores": {
+     "test": [
+       {
+         "accuracy": 0.5690404797601201,
+         "ap": 0.13918928297805203,
+         "ap_weighted": 0.13918928297805203,
+         "f1": 0.47354721284407075,
+         "f1_weighted": 0.6484719142466673,
+         "hf_subset": "en-ext",
+         "languages": [
+           "eng-Latn"
+         ],
+         "main_score": 0.5690404797601201,
+         "scores_per_experiment": [
+           {
+             "accuracy": 0.6041979010494752,
+             "ap": 0.1403174548244783,
+             "ap_weighted": 0.1403174548244783,
+             "f1": 0.49333755341034974,
+             "f1_weighted": 0.6809473724919468
+           },
+           {
+             "accuracy": 0.6319340329835083,
+             "ap": 0.15145229144021116,
+             "ap_weighted": 0.15145229144021116,
+             "f1": 0.5150100219839455,
+             "f1_weighted": 0.7035165191069046
+           },
+           {
+             "accuracy": 0.6124437781109445,
+             "ap": 0.15370328203750555,
+             "ap_weighted": 0.15370328203750555,
+             "f1": 0.5069738581294719,
+             "f1_weighted": 0.6874863954073245
+           },
+           {
+             "accuracy": 0.5712143928035982,
+             "ap": 0.13285280504159222,
+             "ap_weighted": 0.13285280504159222,
+             "f1": 0.471264367816092,
+             "f1_weighted": 0.6532423443450689
+           },
+           {
+             "accuracy": 0.6101949025487257,
+             "ap": 0.1382528418572316,
+             "ap_weighted": 0.1382528418572316,
+             "f1": 0.49459093982420554,
+             "f1_weighted": 0.6859354298509973
+           },
+           {
+             "accuracy": 0.5194902548725637,
+             "ap": 0.12777013417285304,
+             "ap_weighted": 0.12777013417285304,
+             "f1": 0.4405866978944166,
+             "f1_weighted": 0.6068983868543434
+           },
+           {
+             "accuracy": 0.43778110944527737,
+             "ap": 0.11875450153550213,
+             "ap_weighted": 0.11875450153550213,
+             "f1": 0.3875609684433214,
+             "f1_weighted": 0.5263894210560583
+           },
+           {
+             "accuracy": 0.6176911544227887,
+             "ap": 0.14128018744097307,
+             "ap_weighted": 0.14128018744097307,
+             "f1": 0.5005725863284003,
+             "f1_weighted": 0.6920233725631899
+           },
+           {
+             "accuracy": 0.49025487256371814,
+             "ap": 0.14466447368521512,
+             "ap_weighted": 0.14466447368521512,
+             "f1": 0.43475703375805064,
+             "f1_weighted": 0.5749621002144737
+           },
+           {
+             "accuracy": 0.5952023988005997,
+             "ap": 0.14284485774495798,
+             "ap_weighted": 0.14284485774495798,
+             "f1": 0.4908181008524535,
+             "f1_weighted": 0.6733178005763648
+           }
+         ]
+       },
+       {
+         "accuracy": 0.62,
+         "ap": 0.26415963699316264,
+         "ap_weighted": 0.26415963699316264,
+         "f1": 0.5644640290850564,
+         "f1_weighted": 0.6579491434972964,
+         "hf_subset": "en",
+         "languages": [
+           "eng-Latn"
+         ],
+         "main_score": 0.62,
+         "scores_per_experiment": [
+           {
+             "accuracy": 0.5955223880597015,
+             "ap": 0.25283011702254965,
+             "ap_weighted": 0.25283011702254965,
+             "f1": 0.5461419440632507,
+             "f1_weighted": 0.637305840672083
+           },
+           {
+             "accuracy": 0.5835820895522388,
+             "ap": 0.26075921450734807,
+             "ap_weighted": 0.26075921450734807,
+             "f1": 0.5441743417924314,
+             "f1_weighted": 0.6257903879142659
+           },
+           {
+             "accuracy": 0.6029850746268657,
+             "ap": 0.24791359505097144,
+             "ap_weighted": 0.24791359505097144,
+             "f1": 0.5467492700989818,
+             "f1_weighted": 0.643970491486171
+           },
+           {
+             "accuracy": 0.6029850746268657,
+             "ap": 0.26571020411740476,
+             "ap_weighted": 0.26571020411740476,
+             "f1": 0.5578808446455505,
+             "f1_weighted": 0.6438739560117962
+           },
+           {
+             "accuracy": 0.6761194029850747,
+             "ap": 0.2747168088064611,
+             "ap_weighted": 0.2747168088064611,
+             "f1": 0.5995328480020714,
+             "f1_weighted": 0.7061789723470043
+           },
+           {
+             "accuracy": 0.6537313432835821,
+             "ap": 0.2721721103504757,
+             "ap_weighted": 0.2721721103504757,
+             "f1": 0.587775408670931,
+             "f1_weighted": 0.6881859359916834
+           },
+           {
+             "accuracy": 0.6597014925373135,
+             "ap": 0.26444025941241256,
+             "ap_weighted": 0.26444025941241256,
+             "f1": 0.5851663570893213,
+             "f1_weighted": 0.692245002380803
+           },
+           {
+             "accuracy": 0.6149253731343284,
+             "ap": 0.23113683661630094,
+             "ap_weighted": 0.23113683661630094,
+             "f1": 0.538936721825689,
+             "f1_weighted": 0.6529196987886481
+           },
+           {
+             "accuracy": 0.5791044776119403,
+             "ap": 0.2786207978292612,
+             "ap_weighted": 0.2786207978292612,
+             "f1": 0.5501428571428572,
+             "f1_weighted": 0.6196507462686567
+           },
+           {
+             "accuracy": 0.6313432835820896,
+             "ap": 0.29329642621844076,
+             "ap_weighted": 0.29329642621844076,
+             "f1": 0.5881396975194806,
+             "f1_weighted": 0.6693704031118514
+           }
+         ]
+       }
+     ]
+   },
+   "task_name": "AmazonCounterfactualClassification"
+ }
mteb_results/gte-Qwen2-7B-instruct-M2V-Distilled/distilled/Banking77Classification.json ADDED
@@ -0,0 +1,73 @@
+ {
+   "dataset_revision": "0fd18e25b25c072e09e0d92ab615fda904d66300",
+   "evaluation_time": 6.451777696609497,
+   "kg_co2_emissions": null,
+   "mteb_version": "1.14.15",
+   "scores": {
+     "test": [
+       {
+         "accuracy": 0.4396103896103896,
+         "f1": 0.4142711532114576,
+         "f1_weighted": 0.4142711532114576,
+         "hf_subset": "default",
+         "languages": [
+           "eng-Latn"
+         ],
+         "main_score": 0.4396103896103896,
+         "scores_per_experiment": [
+           {
+             "accuracy": 0.4279220779220779,
+             "f1": 0.4030476288783657,
+             "f1_weighted": 0.4030476288783656
+           },
+           {
+             "accuracy": 0.4211038961038961,
+             "f1": 0.39776168133611584,
+             "f1_weighted": 0.39776168133611584
+           },
+           {
+             "accuracy": 0.45064935064935063,
+             "f1": 0.42872843564828145,
+             "f1_weighted": 0.42872843564828145
+           },
+           {
+             "accuracy": 0.4448051948051948,
+             "f1": 0.420756828398419,
+             "f1_weighted": 0.42075682839841905
+           },
+           {
+             "accuracy": 0.44675324675324674,
+             "f1": 0.42100682221185654,
+             "f1_weighted": 0.42100682221185654
+           },
+           {
+             "accuracy": 0.45324675324675323,
+             "f1": 0.4392342490231314,
+             "f1_weighted": 0.4392342490231314
+           },
+           {
+             "accuracy": 0.437012987012987,
+             "f1": 0.4056017558988273,
+             "f1_weighted": 0.40560175589882724
+           },
+           {
+             "accuracy": 0.42337662337662335,
+             "f1": 0.39123709562594644,
+             "f1_weighted": 0.39123709562594655
+           },
+           {
+             "accuracy": 0.44512987012987015,
+             "f1": 0.41578171494860966,
+             "f1_weighted": 0.41578171494860966
+           },
+           {
+             "accuracy": 0.4461038961038961,
+             "f1": 0.4195553201450221,
+             "f1_weighted": 0.419555320145022
+           }
+         ]
+       }
+     ]
+   },
+   "task_name": "Banking77Classification"
+ }
mteb_results/gte-Qwen2-7B-instruct-M2V-Distilled/distilled/CQADupstackProgrammersRetrieval.json ADDED
@@ -0,0 +1,158 @@
+ {
+   "dataset_revision": "6184bc1440d2dbc7612be22b50686b8826d22b32",
+   "evaluation_time": 99.69791841506958,
+   "kg_co2_emissions": null,
+   "mteb_version": "1.14.15",
+   "scores": {
+     "test": [
+       {
+         "hf_subset": "default",
+         "languages": [
+           "eng-Latn"
+         ],
+         "main_score": 0.0501,
+         "map_at_1": 0.02467,
+         "map_at_10": 0.03898,
+         "map_at_100": 0.04261,
+         "map_at_1000": 0.04333,
+         "map_at_20": 0.04068,
+         "map_at_3": 0.03388,
+         "map_at_5": 0.03693,
+         "mrr_at_1": 0.030821917808219176,
+         "mrr_at_10": 0.04904462926723201,
+         "mrr_at_100": 0.05339942610218758,
+         "mrr_at_1000": 0.05413492750157237,
+         "mrr_at_20": 0.05126402659708249,
+         "mrr_at_3": 0.04280821917808219,
+         "mrr_at_5": 0.04634703196347032,
+         "nauc_map_at_1000_diff1": 0.03644747951501248,
+         "nauc_map_at_1000_max": 0.2240572170754659,
+         "nauc_map_at_1000_std": -0.17708810912472517,
+         "nauc_map_at_100_diff1": 0.03759221625144172,
+         "nauc_map_at_100_max": 0.22324901446317413,
+         "nauc_map_at_100_std": -0.17630470695891512,
+         "nauc_map_at_10_diff1": 0.03906418656483989,
+         "nauc_map_at_10_max": 0.22061594321968936,
+         "nauc_map_at_10_std": -0.17777470317814356,
+         "nauc_map_at_1_diff1": 0.1731091343679673,
+         "nauc_map_at_1_max": 0.33459947679728974,
+         "nauc_map_at_1_std": -0.23115450977179597,
+         "nauc_map_at_20_diff1": 0.03795725531499195,
+         "nauc_map_at_20_max": 0.22396003211648763,
+         "nauc_map_at_20_std": -0.17867373725662639,
+         "nauc_map_at_3_diff1": 0.06042780588964212,
+         "nauc_map_at_3_max": 0.2486807528974488,
+         "nauc_map_at_3_std": -0.18512855007450404,
+         "nauc_map_at_5_diff1": 0.04407217741234605,
+         "nauc_map_at_5_max": 0.22647048266105405,
+         "nauc_map_at_5_std": -0.18107585673560017,
+         "nauc_mrr_at_1000_diff1": 0.033601872249839834,
+         "nauc_mrr_at_1000_max": 0.2523936325136619,
+         "nauc_mrr_at_1000_std": -0.19078164353963076,
+         "nauc_mrr_at_100_diff1": 0.03435870935950355,
+         "nauc_mrr_at_100_max": 0.2523932973431928,
+         "nauc_mrr_at_100_std": -0.1900913512193067,
+         "nauc_mrr_at_10_diff1": 0.03361519179733555,
+         "nauc_mrr_at_10_max": 0.25392922716866984,
+         "nauc_mrr_at_10_std": -0.1935061134919541,
+         "nauc_mrr_at_1_diff1": 0.1772995319079407,
+         "nauc_mrr_at_1_max": 0.35182174117717013,
+         "nauc_mrr_at_1_std": -0.24426280067522707,
+         "nauc_mrr_at_20_diff1": 0.03479828151019169,
+         "nauc_mrr_at_20_max": 0.25624951214228564,
+         "nauc_mrr_at_20_std": -0.19212268093923462,
+         "nauc_mrr_at_3_diff1": 0.06173430027850725,
+         "nauc_mrr_at_3_max": 0.26889485727748363,
+         "nauc_mrr_at_3_std": -0.19153801111553947,
+         "nauc_mrr_at_5_diff1": 0.036743759763164886,
+         "nauc_mrr_at_5_max": 0.253857849052297,
+         "nauc_mrr_at_5_std": -0.19604549670316734,
+         "nauc_ndcg_at_1000_diff1": -0.010372586628261796,
+         "nauc_ndcg_at_1000_max": 0.20925878430027478,
+         "nauc_ndcg_at_1000_std": -0.1717044268161809,
+         "nauc_ndcg_at_100_diff1": 0.0023309149151885546,
+         "nauc_ndcg_at_100_max": 0.20125970115134734,
+         "nauc_ndcg_at_100_std": -0.15865628929382014,
+         "nauc_ndcg_at_10_diff1": 0.0026192804576363727,
+         "nauc_ndcg_at_10_max": 0.19884193622357532,
+         "nauc_ndcg_at_10_std": -0.16919003671988075,
+         "nauc_ndcg_at_1_diff1": 0.1772995319079407,
+         "nauc_ndcg_at_1_max": 0.35182174117717013,
+         "nauc_ndcg_at_1_std": -0.24426280067522707,
+         "nauc_ndcg_at_20_diff1": 0.0031543394811079034,
+         "nauc_ndcg_at_20_max": 0.20925361343315524,
+         "nauc_ndcg_at_20_std": -0.17106125631597793,
+         "nauc_ndcg_at_3_diff1": 0.03670154146101528,
+         "nauc_ndcg_at_3_max": 0.23212930749840155,
+         "nauc_ndcg_at_3_std": -0.1728371812831961,
+         "nauc_ndcg_at_5_diff1": 0.0107566708693031,
+         "nauc_ndcg_at_5_max": 0.20474332948099355,
+         "nauc_ndcg_at_5_std": -0.1734952739301359,
+         "nauc_precision_at_1000_diff1": -0.07195606207962846,
+         "nauc_precision_at_1000_max": 0.2542912736794115,
+         "nauc_precision_at_1000_std": -0.1881459402790264,
+         "nauc_precision_at_100_diff1": -0.04518222914182943,
+         "nauc_precision_at_100_max": 0.22138981394024387,
+         "nauc_precision_at_100_std": -0.13384472263037697,
+         "nauc_precision_at_10_diff1": -0.052513811685878764,
+         "nauc_precision_at_10_max": 0.18962064467698705,
+         "nauc_precision_at_10_std": -0.14827004787357115,
+         "nauc_precision_at_1_diff1": 0.1772995319079407,
+         "nauc_precision_at_1_max": 0.35182174117717013,
+         "nauc_precision_at_1_std": -0.24426280067522707,
+         "nauc_precision_at_20_diff1": -0.040789324913047875,
+         "nauc_precision_at_20_max": 0.22086458009752882,
+         "nauc_precision_at_20_std": -0.14430508663959002,
+         "nauc_precision_at_3_diff1": -0.013044619440245884,
+         "nauc_precision_at_3_max": 0.21285488271783465,
+         "nauc_precision_at_3_std": -0.1483164417030193,
+         "nauc_precision_at_5_diff1": -0.05113181393685194,
+         "nauc_precision_at_5_max": 0.1756649379589832,
+         "nauc_precision_at_5_std": -0.15632134056178232,
+         "nauc_recall_at_1000_diff1": -0.047075752528689695,
+         "nauc_recall_at_1000_max": 0.16414155669676642,
+         "nauc_recall_at_1000_std": -0.1513320281746568,
+         "nauc_recall_at_100_diff1": -0.023004658252697183,
+         "nauc_recall_at_100_max": 0.14861973646512244,
+         "nauc_recall_at_100_std": -0.12240747671934184,
+         "nauc_recall_at_10_diff1": -0.051375323084735164,
+         "nauc_recall_at_10_max": 0.1384336247044034,
+         "nauc_recall_at_10_std": -0.14737738059263306,
+         "nauc_recall_at_1_diff1": 0.1731091343679673,
+         "nauc_recall_at_1_max": 0.33459947679728974,
+         "nauc_recall_at_1_std": -0.23115450977179597,
+         "nauc_recall_at_20_diff1": -0.03578815918976938,
+         "nauc_recall_at_20_max": 0.16386688869593355,
+         "nauc_recall_at_20_std": -0.1528456365862212,
+         "nauc_recall_at_3_diff1": -0.021696811828998432,
+         "nauc_recall_at_3_max": 0.1864107664448688,
+         "nauc_recall_at_3_std": -0.14586036842324565,
+         "nauc_recall_at_5_diff1": -0.0538517948884412,
+         "nauc_recall_at_5_max": 0.1453135254521713,
+         "nauc_recall_at_5_std": -0.1531619473747777,
+         "ndcg_at_1": 0.03082,
+         "ndcg_at_10": 0.0501,
+         "ndcg_at_100": 0.07072,
+         "ndcg_at_1000": 0.09327,
+         "ndcg_at_20": 0.05662,
+         "ndcg_at_3": 0.03989,
+         "ndcg_at_5": 0.04484,
+         "precision_at_1": 0.03082,
+         "precision_at_10": 0.00993,
+         "precision_at_100": 0.00241,
+         "precision_at_1000": 0.00052,
+         "precision_at_20": 0.00685,
+         "precision_at_3": 0.02017,
+         "precision_at_5": 0.0153,
+         "recall_at_1": 0.02467,
+         "recall_at_10": 0.07499,
+         "recall_at_100": 0.16969,
+         "recall_at_1000": 0.33718,
+         "recall_at_20": 0.09901,
+         "recall_at_3": 0.04648,
+         "recall_at_5": 0.05869
+       }
+     ]
+   },
+   "task_name": "CQADupstackProgrammersRetrieval"
+ }
mteb_results/gte-Qwen2-7B-instruct-M2V-Distilled/distilled/STSBenchmark.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "dataset_revision": "b0fddb56ed78048fa8b90373c8a3cfc37b684831",
+   "evaluation_time": 0.12331175804138184,
+   "kg_co2_emissions": null,
+   "mteb_version": "1.14.15",
+   "scores": {
+     "test": [
+       {
+         "cosine_pearson": 0.34632056143460516,
+         "cosine_spearman": 0.42973159111999676,
+         "euclidean_pearson": 0.4043313982401531,
+         "euclidean_spearman": 0.42973159111999676,
+         "hf_subset": "default",
+         "languages": [
+           "eng-Latn"
+         ],
+         "main_score": 0.42973159111999676,
+         "manhattan_pearson": 0.511950240807258,
+         "manhattan_spearman": 0.5019330550880601,
+         "pearson": 0.34632056143460516,
+         "spearman": 0.42973159111999676
+       }
+     ]
+   },
+   "task_name": "STSBenchmark"
+ }
mteb_results/gte-Qwen2-7B-instruct-M2V-Distilled/distilled/SprintDuplicateQuestions.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "dataset_revision": "d66bd1f72af766a5cc4b0ca5e00c162f89e8cc46",
+   "evaluation_time": 1.9629368782043457,
+   "kg_co2_emissions": null,
+   "mteb_version": "1.14.15",
+   "scores": {
+     "test": [
+       {
+         "cosine_accuracy": 0.9926237623762376,
+         "cosine_accuracy_threshold": 0.9106360077857971,
+         "cosine_ap": 0.4700755863552174,
+         "cosine_f1": 0.4925187032418952,
+         "cosine_f1_threshold": 0.8986777067184448,
+         "cosine_precision": 0.6539735099337748,
+         "cosine_recall": 0.395,
+         "dot_accuracy": 0.9926237623762376,
+         "dot_accuracy_threshold": 0.9106361269950867,
+         "dot_ap": 0.47007548398718707,
+         "dot_f1": 0.4925187032418952,
+         "dot_f1_threshold": 0.8986777663230896,
+         "dot_precision": 0.6539735099337748,
+         "dot_recall": 0.395,
+         "euclidean_accuracy": 0.9926237623762376,
+         "euclidean_accuracy_threshold": 0.42276236414909363,
+         "euclidean_ap": 0.47007558217981027,
+         "euclidean_f1": 0.4925187032418952,
+         "euclidean_f1_threshold": 0.4501606225967407,
+         "euclidean_precision": 0.6539735099337748,
+         "euclidean_recall": 0.395,
+         "hf_subset": "default",
+         "languages": [
+           "eng-Latn"
+         ],
+         "main_score": 0.6386707007383838,
+         "manhattan_accuracy": 0.9939207920792079,
+         "manhattan_accuracy_threshold": 4.824772834777832,
+         "manhattan_ap": 0.6386707007383838,
+         "manhattan_f1": 0.6293103448275862,
+         "manhattan_f1_threshold": 5.194998741149902,
+         "manhattan_precision": 0.6822429906542056,
+         "manhattan_recall": 0.584,
+         "max_accuracy": 0.9939207920792079,
+         "max_ap": 0.6386707007383838,
+         "max_f1": 0.6293103448275862,
+         "max_precision": 0.6822429906542056,
+         "max_recall": 0.584,
+         "similarity_accuracy": 0.9926237623762376,
+         "similarity_accuracy_threshold": 0.9106360077857971,
+         "similarity_ap": 0.4700755863552174,
+         "similarity_f1": 0.4925187032418952,
+         "similarity_f1_threshold": 0.8986777067184448,
+         "similarity_precision": 0.6539735099337748,
+         "similarity_recall": 0.395
+       }
+     ]
+   },
+   "task_name": "SprintDuplicateQuestions"
+ }
mteb_results/gte-Qwen2-7B-instruct-M2V-Distilled/distilled/StackExchangeClustering.json ADDED
@@ -0,0 +1,47 @@
+ {
+   "dataset_revision": "6cbc1f7b2bc0622f2e39d2c77fa502909748c259",
+   "evaluation_time": 1075.5739603042603,
+   "kg_co2_emissions": null,
+   "mteb_version": "1.14.15",
+   "scores": {
+     "test": [
+       {
+         "hf_subset": "default",
+         "languages": [
+           "eng-Latn"
+         ],
+         "main_score": 0.2747977935355363,
+         "v_measure": 0.2747977935355363,
+         "v_measure_std": 0.04408138950391278,
+         "v_measures": [
+           0.2671568735697825,
+           0.35324106044655595,
+           0.2134334295678833,
+           0.26069561242914296,
+           0.2360037867112385,
+           0.18352010080864292,
+           0.21227539957559294,
+           0.22564157353303899,
+           0.31014309699664405,
+           0.2792317143409387,
+           0.30736400840236383,
+           0.33654065468328326,
+           0.3375811203083562,
+           0.23635769205347795,
+           0.2889733490218442,
+           0.2628972368553193,
+           0.2892573063858698,
+           0.3093369539018476,
+           0.2778955236652676,
+           0.29489160764728006,
+           0.3092126928451642,
+           0.22100223054084894,
+           0.23711645754707986,
+           0.3264131545037563,
+           0.2937622020471872
+         ]
+       }
+     ]
+   },
+   "task_name": "StackExchangeClustering"
+ }
mteb_results/gte-Qwen2-7B-instruct-M2V-Distilled/distilled/model_meta.json ADDED
@@ -0,0 +1 @@
+ {"name": "gte-Qwen2-7B-instruct-M2V-Distilled", "revision": "distilled", "release_date": null, "languages": ["eng"], "n_parameters": null, "memory_usage": null, "max_tokens": null, "embed_dim": null, "license": null, "open_source": null, "similarity_fn_name": null, "framework": [], "loader": null}
mteb_results/mteb_parsed_results.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "gte-Qwen2-7B-instruct-M2V-Distilled": "ResultSet(datasets={'Banking77Classification': DatasetResult(scores=[0.4396103896103896], time=6.451777696609497), 'StackExchangeClustering': DatasetResult(scores=[0.2747977935355363], time=1075.5739603042603), 'STSBenchmark': DatasetResult(scores=[0.42973159111999676], time=0.12331175804138184), 'CQADupstackProgrammersRetrieval': DatasetResult(scores=[0.0501], time=99.69791841506958), 'SprintDuplicateQuestions': DatasetResult(scores=[0.6386707007383838], time=1.9629368782043457)})"
+ }
mteb_results/mteb_raw_results.json ADDED
@@ -0,0 +1,7 @@
+ [
+ "dataset_revision='0fd18e25b25c072e09e0d92ab615fda904d66300' task_name='Banking77Classification' mteb_version='1.14.15' scores={'test': [{'accuracy': 0.4396103896103896, 'f1': 0.4142711532114576, 'f1_weighted': 0.4142711532114576, 'scores_per_experiment': [{'accuracy': 0.4279220779220779, 'f1': 0.4030476288783657, 'f1_weighted': 0.4030476288783656}, {'accuracy': 0.4211038961038961, 'f1': 0.39776168133611584, 'f1_weighted': 0.39776168133611584}, {'accuracy': 0.45064935064935063, 'f1': 0.42872843564828145, 'f1_weighted': 0.42872843564828145}, {'accuracy': 0.4448051948051948, 'f1': 0.420756828398419, 'f1_weighted': 0.42075682839841905}, {'accuracy': 0.44675324675324674, 'f1': 0.42100682221185654, 'f1_weighted': 0.42100682221185654}, {'accuracy': 0.45324675324675323, 'f1': 0.4392342490231314, 'f1_weighted': 0.4392342490231314}, {'accuracy': 0.437012987012987, 'f1': 0.4056017558988273, 'f1_weighted': 0.40560175589882724}, {'accuracy': 0.42337662337662335, 'f1': 0.39123709562594644, 'f1_weighted': 0.39123709562594655}, {'accuracy': 0.44512987012987015, 'f1': 0.41578171494860966, 'f1_weighted': 0.41578171494860966}, {'accuracy': 0.4461038961038961, 'f1': 0.4195553201450221, 'f1_weighted': 0.419555320145022}], 'main_score': 0.4396103896103896, 'hf_subset': 'default', 'languages': ['eng-Latn']}]} evaluation_time=6.451777696609497 kg_co2_emissions=None",
+ "dataset_revision='6cbc1f7b2bc0622f2e39d2c77fa502909748c259' task_name='StackExchangeClustering' mteb_version='1.14.15' scores={'test': [{'v_measure': 0.2747977935355363, 'v_measure_std': 0.04408138950391278, 'v_measures': [0.2671568735697825, 0.35324106044655595, 0.2134334295678833, 0.26069561242914296, 0.2360037867112385, 0.18352010080864292, 0.21227539957559294, 0.22564157353303899, 0.31014309699664405, 0.2792317143409387, 0.30736400840236383, 0.33654065468328326, 0.3375811203083562, 0.23635769205347795, 0.2889733490218442, 0.2628972368553193, 0.2892573063858698, 0.3093369539018476, 0.2778955236652676, 0.29489160764728006, 0.3092126928451642, 0.22100223054084894, 0.23711645754707986, 0.3264131545037563, 0.2937622020471872], 'main_score': 0.2747977935355363, 'hf_subset': 'default', 'languages': ['eng-Latn']}]} evaluation_time=1075.5739603042603 kg_co2_emissions=None",
+ "dataset_revision='b0fddb56ed78048fa8b90373c8a3cfc37b684831' task_name='STSBenchmark' mteb_version='1.14.15' scores={'test': [{'pearson': 0.34632056143460516, 'spearman': 0.42973159111999676, 'cosine_pearson': 0.34632056143460516, 'cosine_spearman': 0.42973159111999676, 'manhattan_pearson': 0.511950240807258, 'manhattan_spearman': 0.5019330550880601, 'euclidean_pearson': 0.4043313982401531, 'euclidean_spearman': 0.42973159111999676, 'main_score': 0.42973159111999676, 'hf_subset': 'default', 'languages': ['eng-Latn']}]} evaluation_time=0.12331175804138184 kg_co2_emissions=None",
+ "dataset_revision='6184bc1440d2dbc7612be22b50686b8826d22b32' task_name='CQADupstackProgrammersRetrieval' mteb_version='1.14.15' scores={'test': [{'ndcg_at_1': 0.03082, 'ndcg_at_3': 0.03989, 'ndcg_at_5': 0.04484, 'ndcg_at_10': 0.0501, 'ndcg_at_20': 0.05662, 'ndcg_at_100': 0.07072, 'ndcg_at_1000': 0.09327, 'map_at_1': 0.02467, 'map_at_3': 0.03388, 'map_at_5': 0.03693, 'map_at_10': 0.03898, 'map_at_20': 0.04068, 'map_at_100': 0.04261, 'map_at_1000': 0.04333, 'recall_at_1': 0.02467, 'recall_at_3': 0.04648, 'recall_at_5': 0.05869, 'recall_at_10': 0.07499, 'recall_at_20': 0.09901, 'recall_at_100': 0.16969, 'recall_at_1000': 0.33718, 'precision_at_1': 0.03082, 'precision_at_3': 0.02017, 'precision_at_5': 0.0153, 'precision_at_10': 0.00993, 'precision_at_20': 0.00685, 'precision_at_100': 0.00241, 'precision_at_1000': 0.00052, 'mrr_at_1': 0.030821917808219176, 'mrr_at_3': 0.04280821917808219, 'mrr_at_5': 0.04634703196347032, 'mrr_at_10': 0.04904462926723201, 'mrr_at_20': 0.05126402659708249, 'mrr_at_100': 0.05339942610218758, 'mrr_at_1000': 0.05413492750157237, 'nauc_ndcg_at_1_max': 0.35182174117717013, 'nauc_ndcg_at_1_std': -0.24426280067522707, 'nauc_ndcg_at_1_diff1': 0.1772995319079407, 'nauc_ndcg_at_3_max': 0.23212930749840155, 'nauc_ndcg_at_3_std': -0.1728371812831961, 'nauc_ndcg_at_3_diff1': 0.03670154146101528, 'nauc_ndcg_at_5_max': 0.20474332948099355, 'nauc_ndcg_at_5_std': -0.1734952739301359, 'nauc_ndcg_at_5_diff1': 0.0107566708693031, 'nauc_ndcg_at_10_max': 0.19884193622357532, 'nauc_ndcg_at_10_std': -0.16919003671988075, 'nauc_ndcg_at_10_diff1': 0.0026192804576363727, 'nauc_ndcg_at_20_max': 0.20925361343315524, 'nauc_ndcg_at_20_std': -0.17106125631597793, 'nauc_ndcg_at_20_diff1': 0.0031543394811079034, 'nauc_ndcg_at_100_max': 0.20125970115134734, 'nauc_ndcg_at_100_std': -0.15865628929382014, 'nauc_ndcg_at_100_diff1': 0.0023309149151885546, 'nauc_ndcg_at_1000_max': 0.20925878430027478, 'nauc_ndcg_at_1000_std': -0.1717044268161809, 'nauc_ndcg_at_1000_diff1': -0.010372586628261796, 'nauc_map_at_1_max': 0.33459947679728974, 'nauc_map_at_1_std': -0.23115450977179597, 'nauc_map_at_1_diff1': 0.1731091343679673, 'nauc_map_at_3_max': 0.2486807528974488, 'nauc_map_at_3_std': -0.18512855007450404, 'nauc_map_at_3_diff1': 0.06042780588964212, 'nauc_map_at_5_max': 0.22647048266105405, 'nauc_map_at_5_std': -0.18107585673560017, 'nauc_map_at_5_diff1': 0.04407217741234605, 'nauc_map_at_10_max': 0.22061594321968936, 'nauc_map_at_10_std': -0.17777470317814356, 'nauc_map_at_10_diff1': 0.03906418656483989, 'nauc_map_at_20_max': 0.22396003211648763, 'nauc_map_at_20_std': -0.17867373725662639, 'nauc_map_at_20_diff1': 0.03795725531499195, 'nauc_map_at_100_max': 0.22324901446317413, 'nauc_map_at_100_std': -0.17630470695891512, 'nauc_map_at_100_diff1': 0.03759221625144172, 'nauc_map_at_1000_max': 0.2240572170754659, 'nauc_map_at_1000_std': -0.17708810912472517, 'nauc_map_at_1000_diff1': 0.03644747951501248, 'nauc_recall_at_1_max': 0.33459947679728974, 'nauc_recall_at_1_std': -0.23115450977179597, 'nauc_recall_at_1_diff1': 0.1731091343679673, 'nauc_recall_at_3_max': 0.1864107664448688, 'nauc_recall_at_3_std': -0.14586036842324565, 'nauc_recall_at_3_diff1': -0.021696811828998432, 'nauc_recall_at_5_max': 0.1453135254521713, 'nauc_recall_at_5_std': -0.1531619473747777, 'nauc_recall_at_5_diff1': -0.0538517948884412, 'nauc_recall_at_10_max': 0.1384336247044034, 'nauc_recall_at_10_std': -0.14737738059263306, 'nauc_recall_at_10_diff1': -0.051375323084735164, 'nauc_recall_at_20_max': 0.16386688869593355, 
'nauc_recall_at_20_std': -0.1528456365862212, 'nauc_recall_at_20_diff1': -0.03578815918976938, 'nauc_recall_at_100_max': 0.14861973646512244, 'nauc_recall_at_100_std': -0.12240747671934184, 'nauc_recall_at_100_diff1': -0.023004658252697183, 'nauc_recall_at_1000_max': 0.16414155669676642, 'nauc_recall_at_1000_std': -0.1513320281746568, 'nauc_recall_at_1000_diff1': -0.047075752528689695, 'nauc_precision_at_1_max': 0.35182174117717013, 'nauc_precision_at_1_std': -0.24426280067522707, 'nauc_precision_at_1_diff1': 0.1772995319079407, 'nauc_precision_at_3_max': 0.21285488271783465, 'nauc_precision_at_3_std': -0.1483164417030193, 'nauc_precision_at_3_diff1': -0.013044619440245884, 'nauc_precision_at_5_max': 0.1756649379589832, 'nauc_precision_at_5_std': -0.15632134056178232, 'nauc_precision_at_5_diff1': -0.05113181393685194, 'nauc_precision_at_10_max': 0.18962064467698705, 'nauc_precision_at_10_std': -0.14827004787357115, 'nauc_precision_at_10_diff1': -0.052513811685878764, 'nauc_precision_at_20_max': 0.22086458009752882, 'nauc_precision_at_20_std': -0.14430508663959002, 'nauc_precision_at_20_diff1': -0.040789324913047875, 'nauc_precision_at_100_max': 0.22138981394024387, 'nauc_precision_at_100_std': -0.13384472263037697, 'nauc_precision_at_100_diff1': -0.04518222914182943, 'nauc_precision_at_1000_max': 0.2542912736794115, 'nauc_precision_at_1000_std': -0.1881459402790264, 'nauc_precision_at_1000_diff1': -0.07195606207962846, 'nauc_mrr_at_1_max': 0.35182174117717013, 'nauc_mrr_at_1_std': -0.24426280067522707, 'nauc_mrr_at_1_diff1': 0.1772995319079407, 'nauc_mrr_at_3_max': 0.26889485727748363, 'nauc_mrr_at_3_std': -0.19153801111553947, 'nauc_mrr_at_3_diff1': 0.06173430027850725, 'nauc_mrr_at_5_max': 0.253857849052297, 'nauc_mrr_at_5_std': -0.19604549670316734, 'nauc_mrr_at_5_diff1': 0.036743759763164886, 'nauc_mrr_at_10_max': 0.25392922716866984, 'nauc_mrr_at_10_std': -0.1935061134919541, 'nauc_mrr_at_10_diff1': 0.03361519179733555, 'nauc_mrr_at_20_max': 0.25624951214228564, 'nauc_mrr_at_20_std': -0.19212268093923462, 'nauc_mrr_at_20_diff1': 0.03479828151019169, 'nauc_mrr_at_100_max': 0.2523932973431928, 'nauc_mrr_at_100_std': -0.1900913512193067, 'nauc_mrr_at_100_diff1': 0.03435870935950355, 'nauc_mrr_at_1000_max': 0.2523936325136619, 'nauc_mrr_at_1000_std': -0.19078164353963076, 'nauc_mrr_at_1000_diff1': 0.033601872249839834, 'main_score': 0.0501, 'hf_subset': 'default', 'languages': ['eng-Latn']}]} evaluation_time=99.69791841506958 kg_co2_emissions=None",
+ "dataset_revision='d66bd1f72af766a5cc4b0ca5e00c162f89e8cc46' task_name='SprintDuplicateQuestions' mteb_version='1.14.15' scores={'test': [{'similarity_accuracy': 0.9926237623762376, 'similarity_accuracy_threshold': 0.9106360077857971, 'similarity_f1': 0.4925187032418952, 'similarity_f1_threshold': 0.8986777067184448, 'similarity_precision': 0.6539735099337748, 'similarity_recall': 0.395, 'similarity_ap': 0.4700755863552174, 'cosine_accuracy': 0.9926237623762376, 'cosine_accuracy_threshold': 0.9106360077857971, 'cosine_f1': 0.4925187032418952, 'cosine_f1_threshold': 0.8986777067184448, 'cosine_precision': 0.6539735099337748, 'cosine_recall': 0.395, 'cosine_ap': 0.4700755863552174, 'manhattan_accuracy': 0.9939207920792079, 'manhattan_accuracy_threshold': 4.824772834777832, 'manhattan_f1': 0.6293103448275862, 'manhattan_f1_threshold': 5.194998741149902, 'manhattan_precision': 0.6822429906542056, 'manhattan_recall': 0.584, 'manhattan_ap': 0.6386707007383838, 'euclidean_accuracy': 0.9926237623762376, 'euclidean_accuracy_threshold': 0.42276236414909363, 'euclidean_f1': 0.4925187032418952, 'euclidean_f1_threshold': 0.4501606225967407, 'euclidean_precision': 0.6539735099337748, 'euclidean_recall': 0.395, 'euclidean_ap': 0.47007558217981027, 'dot_accuracy': 0.9926237623762376, 'dot_accuracy_threshold': 0.9106361269950867, 'dot_f1': 0.4925187032418952, 'dot_f1_threshold': 0.8986777663230896, 'dot_precision': 0.6539735099337748, 'dot_recall': 0.395, 'dot_ap': 0.47007548398718707, 'max_accuracy': 0.9939207920792079, 'max_f1': 0.6293103448275862, 'max_precision': 0.6822429906542056, 'max_recall': 0.584, 'max_ap': 0.6386707007383838, 'main_score': 0.6386707007383838, 'hf_subset': 'default', 'languages': ['eng-Latn']}]} evaluation_time=1.9629368782043457 kg_co2_emissions=None"
+ ]
mteb_results/mteb_report.txt ADDED
@@ -0,0 +1,21 @@
+ ================================================================================
+ MTEB Evaluation Report
+ ================================================================================
+
+ Model: gte-Qwen2-7B-instruct-M2V-Distilled
+ Model Path: .
+ Evaluation Time: 1235.71 seconds
+ Total Datasets: 1
+
+ Summary Statistics:
+   Average Score: 0.0501
+   Median Score: 0.0501
+   Standard Deviation: 0.0000
+   Score Range: 0.0501 - 0.0501
+
+ Detailed Results:
+ --------------------------------------------------
+ Model                                Average (All)  Average (MTEB)  Classification  Clustering  PairClassification  Reranking  Retrieval  STS  Summarization  PEARL  WordSim
+ gte-Qwen2-7B-instruct-M2V-Distilled  nan            nan             nan             nan         nan                 nan        5.01       nan  nan            nan    nan
+
+ ================================================================================
mteb_results/mteb_summary.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "model_name": "gte-Qwen2-7B-instruct-M2V-Distilled",
+   "evaluation_time_seconds": 1235.7057559490204,
+   "task_scores": {
+     "gte-Qwen2-7B-instruct-M2V-Distilled": {
+       "task_means": "Classification NaN\nClustering NaN\nPairClassification NaN\nReranking NaN\nRetrieval 0.0501\nSTS NaN\nSummarization NaN\nPEARL NaN\nWordSim NaN\ndtype: float64",
+       "dataset_scores": {
+         "CQADupstack": 0.0501
+       }
+     }
+   },
+   "summary_stats": {
+     "total_datasets": 1,
+     "average_score": 0.0501,
+     "median_score": 0.0501,
+     "std_dev": 0.0,
+     "min_score": 0.0501,
+     "max_score": 0.0501
+   }
+ }
pipeline.skops ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bd33bcb8eee34a4df1a0d5e7d22b1e2b241ea683750204be74f78055882c76c3
+ size 3843639
pyproject.toml ADDED
@@ -0,0 +1,101 @@
+ [project]
+ name = "gte-qwen2-7b-instruct-m2v"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "datasets>=3.6.0",
+     "evaluation",
+     "iso639>=0.1.4",
+     "lightning>=2.5.1.post0",
+     "matplotlib>=3.10.3",
+     "model2vec[train]>=0.5.0",
+     "mteb>=1.14.15",
+     "psutil>=7.0.0",
+     "scikit-learn>=1.6.1",
+     "sentence-transformers>=4.1.0",
+     "torch>=2.7.0",
+ ]
+
+ [dependency-groups]
+ dev = [
+     "mypy>=1.15.0",
+     "ruff>=0.11.6",
+ ]
+
+ [tool.mypy]
+ exclude = [
+     ".git",
+     ".ruff_cache",
+     ".venv",
+     "venv",
+     "__pycache__",
+     "build",
+     "dist",
+     "vendor",
+ ]
+ follow_untyped_imports = true
+
+ [tool.ruff]
+ line-length = 120
+ target-version = "py312"
+
+ # Exclude files/directories
+ exclude = [
+     ".git",
+     ".ruff_cache",
+     ".venv",
+     "venv",
+     "__pycache__",
+     "build",
+     "dist",
+     "vendor"
+ ]
+
+ [tool.ruff.lint]
+ # Enable all rules by default, then selectively disable
+ select = ["ALL"]
+ ignore = [
+     # Rules that conflict with other tools/preferences
+     "D203", # one-blank-line-before-class
+     "D212", # multi-line-summary-first-line
+     "FBT001", # Boolean positional arg in function definition (required for typer)
+     "FBT002", # Boolean default value in function definition (required for typer)
+     "C901", # function too complex
+     "PLR0911", # too many return statements
+     "PLR0912", # too many branches
+     "PLR0913", # too many arguments in function definition
+     "PLR0915", # too many statements
+     "TRY300", # Consider moving this statement to an `else` block
+     "COM812", # Missing trailing comma (conflicts with the ruff formatter)
+     "TC001", # Move application import into a type-checking block
+     "ERA001", # Found commented-out code
+     "G004", # Logging statement uses f-string
+     "TD003", # Missing link in to-do
+     "TRY301", # Abstract raise to an inner function
+     # Disable rules that conflict with tab indentation
+     "E101", # Indentation contains mixed spaces and tabs
+     "W191", # indentation contains tabs
+     "D206", # indent with spaces, not tabs
+ ]
+
+ [tool.ruff.lint.mccabe]
+ max-complexity = 10
+
+ [tool.ruff.lint.pylint]
+ max-args = 5
+ max-branches = 12
+ max-statements = 50
+
+ [tool.ruff.lint.pydocstyle]
+ convention = "google"
+
+ [tool.ruff.format]
+ quote-style = "double"
+ indent-style = "tab"
+ skip-magic-trailing-comma = false
+ line-ending = "auto"
+
+ [tool.uv.sources]
+ evaluation = { git = "https://github.com/MinishLab/evaluation.git", rev = "main" }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e76d728582b9955c7afa6e5757b5b7825b3d40ef49d935b3cb7b148ad556dce4
+ size 11418179
train_code_classification.py ADDED
@@ -0,0 +1,365 @@
+ #!/usr/bin/env python
+ """
+ Script to train a code classification model using CodeSearchNet dataset with Model2Vec.
+
+ This script performs the following operations:
+ 1. Downloads the Alibaba-NLP/gte-Qwen2-7B-instruct model
+ 2. Optionally distills it using Model2Vec to create a smaller, faster static model
+ 3. Trains a programming language classifier on CodeSearchNet dataset
+ 4. Evaluates the classifier and saves the trained model
+
+ Based on the official CodeSearchNet dataset: https://github.com/github/CodeSearchNet
+ """
+
+ import json
+ import logging
+ import re
+ import time
+ from pathlib import Path
+ from time import perf_counter
+ from typing import Any, cast
+
+ from datasets import Dataset, DatasetDict, load_dataset  # type: ignore [import]
+ from model2vec.distill import distill
+ from model2vec.train import StaticModelForClassification
+
+ # =============================================================================
+ # CONFIGURATION CONSTANTS
+ # =============================================================================
+
+ # Model Configuration
+ MODEL_NAME = "Alibaba-NLP/gte-Qwen2-7B-instruct"  # Source model to distill
+ OUTPUT_DIR = "."  # Directory to save the trained model
+
+ # Distillation Configuration
+ SKIP_DISTILLATION = False  # Set to True to skip distillation and use existing model
+ DISTILLED_MODEL_PATH = "."  # Path to existing distilled model (if skipping distillation)
+ PCA_DIMS = 256  # Dimensions for PCA reduction (smaller = faster but less accurate)
+
+ # Dataset Configuration
+ DATASET_NAME = "code-search-net/code_search_net"  # CodeSearchNet dataset
+ CLASSIFICATION_TASK = "language"  # Task: classify programming language
+ MAX_SAMPLES_PER_LANGUAGE = 5000  # Limit samples per language for balanced training
+ MIN_CODE_LENGTH = 50  # Minimum code length in characters
+ MAX_CODE_LENGTH = 2000  # Maximum code length in characters (for memory efficiency)
+
+ # Text processing constants
+ MAX_COMMENT_LENGTH = 200  # Maximum length for comment lines before truncation
+
+ # Training Configuration
+ MAX_EPOCHS = 30  # Maximum number of training epochs
+ PATIENCE = 5  # Early stopping patience
+ BATCH_SIZE = 32  # Training batch size
+ LEARNING_RATE = 1e-3  # Learning rate
+
+ # Saving Configuration
+ SAVE_PIPELINE = True  # Save as scikit-learn compatible pipeline
+ SAVE_TO_HUB = False  # Whether to push the model to HuggingFace Hub
+ HUB_MODEL_ID = None  # Model ID for HuggingFace Hub (if saving to hub)
+
+ # =============================================================================
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+
+ def clean_code_text(code: str) -> str:
+     """Clean and normalize code text for better classification."""
+     if not code:
+         return ""
+
+     # Remove excessive whitespace while preserving structure
+     code = re.sub(r"\n\s*\n\s*\n", "\n\n", code)  # Remove multiple empty lines
+     code = re.sub(r" +", " ", code)  # Replace multiple spaces with single space
+
+     # Remove very long comments that might bias classification
+     lines = code.split("\n")
+     cleaned_lines = []
+     for original_line in lines:
+         line = original_line
+         # Keep comment lines but limit their length
+         if line.strip().startswith(("#", "//", "/*", "*", "--")) and len(line) > MAX_COMMENT_LENGTH:
+             line = line[:MAX_COMMENT_LENGTH] + "..."
+         cleaned_lines.append(line)
+
+     return "\n".join(cleaned_lines)
+
+
+ def load_codesearchnet_dataset() -> tuple[Dataset, Dataset, str, str]:
+     """Load and preprocess the CodeSearchNet dataset for programming language classification."""
+     logger.info("Loading CodeSearchNet dataset...")
+
+     try:
+         # Load the dataset with trust_remote_code=True
+         logger.info("Downloading and loading CodeSearchNet data...")
+         ds = cast(
+             "DatasetDict",
+             load_dataset(
+                 DATASET_NAME,
+                 trust_remote_code=True,
+                 # Load a reasonable sample for training
+             ),
+         )
+
+         logger.info(f"Available splits: {list(ds.keys())}")
+
+         # Use train/test splits if available, otherwise split the data
+         if "train" in ds and "test" in ds:
+             train_dataset = ds["train"]
+             test_dataset = ds["test"]
+         elif "train" in ds:
+             # Split the train set
+             split_ds = ds["train"].train_test_split(test_size=0.2, seed=42)
+             train_dataset = split_ds["train"]
+             test_dataset = split_ds["test"]
+         else:
+             # Use all data and split
+             all_data = ds[next(iter(ds.keys()))]
+             split_ds = all_data.train_test_split(test_size=0.2, seed=42)
+             train_dataset = split_ds["train"]
+             test_dataset = split_ds["test"]
+
+         logger.info(f"Raw dataset sizes - Train: {len(train_dataset)}, Test: {len(test_dataset)}")
+
+         # Filter and preprocess the data
+         def filter_and_clean(dataset: Dataset) -> Dataset:
+             # Filter examples with valid code and language
+             filtered = dataset.filter(
+                 lambda x: (
+                     x["func_code_string"] is not None
+                     and x["language"] is not None
+                     and len(x["func_code_string"]) >= MIN_CODE_LENGTH
+                     and len(x["func_code_string"]) <= MAX_CODE_LENGTH
+                     and x["language"] in ["python", "java", "javascript", "go", "php", "ruby"]
+                 )
+             )
+
+             # Balance the dataset by limiting samples per language
+             if len(filtered) > MAX_SAMPLES_PER_LANGUAGE * 6:  # 6 languages
+                 # Group by language and sample
+                 language_samples: dict[str, list[dict[str, Any]]] = {}
+                 for example in filtered:
+                     lang = example["language"]
+                     if lang not in language_samples:
+                         language_samples[lang] = []
+                     if len(language_samples[lang]) < MAX_SAMPLES_PER_LANGUAGE:
+                         language_samples[lang].append(example)
+
+                 # Combine all samples
+                 balanced_examples = []
+                 for lang_examples in language_samples.values():
+                     balanced_examples.extend(lang_examples)
+
+                 # Convert back to dataset format
+                 if balanced_examples:
+                     filtered = Dataset.from_list(balanced_examples)
+
+             # Clean the code text
+             def clean_example(example: dict[str, Any]) -> dict[str, Any]:
+                 example["func_code_string"] = clean_code_text(example["func_code_string"])
+                 return example
+
+             return filtered.map(clean_example)
+
+         train_dataset = filter_and_clean(train_dataset)
+         test_dataset = filter_and_clean(test_dataset)
+
+         logger.info(f"Filtered dataset sizes - Train: {len(train_dataset)}, Test: {len(test_dataset)}")
+
+         # Show language distribution
+         if len(train_dataset) > 0:
+             from collections import Counter
+
+             train_lang_dist = Counter(train_dataset["language"])
+             test_lang_dist = Counter(test_dataset["language"])
+             logger.info(f"Training language distribution: {dict(train_lang_dist)}")
+             logger.info(f"Test language distribution: {dict(test_lang_dist)}")
+
+         return train_dataset, test_dataset, "func_code_string", "language"
+
+     except Exception:
+         logger.exception("Error loading CodeSearchNet dataset")
+         raise
+
+
+ def main() -> None:
+     """Run the code classification training pipeline."""
+     # Create output directory if it doesn't exist
+     output_dir = Path(OUTPUT_DIR)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     logger.info(f"Starting CodeSearchNet code classification pipeline for {MODEL_NAME}")
+     logger.info(f"Classification task: {CLASSIFICATION_TASK}")
+     logger.info(f"Trained model will be saved to {output_dir}")
+
+     # Record start time for benchmarking
+     total_start_time = time.time()
+
+     try:
+         # Step 1: Get the static model (either distill or load existing)
+         static_model = None
+
+         if SKIP_DISTILLATION:
+             if DISTILLED_MODEL_PATH:
+                 logger.info(f"Loading existing distilled model from {DISTILLED_MODEL_PATH}")
+                 # Note: We'll create the classifier from pretrained instead
+             else:
+                 logger.error("DISTILLED_MODEL_PATH must be specified when SKIP_DISTILLATION is True")
+                 return
+         else:
+             logger.info("Starting Model2Vec distillation...")
+             distill_start_time = time.time()
+
+             static_model = distill(
+                 model_name=MODEL_NAME,
+                 pca_dims=PCA_DIMS,
+             )
+
+             distill_time = time.time() - distill_start_time
+             logger.info(f"Distillation completed in {distill_time:.2f} seconds")
+
+         # Step 2: Create the classifier
+         logger.info("Creating classifier...")
+
+         if static_model is not None:
+             # From a distilled model
+             classifier = StaticModelForClassification.from_static_model(model=static_model)
+         else:
+             # From a pre-trained model path
+             classifier = StaticModelForClassification.from_pretrained(model_name=DISTILLED_MODEL_PATH)
+
+         # Step 3: Load the CodeSearchNet dataset
+         train_dataset, test_dataset, text_column, label_column = load_codesearchnet_dataset()
+
+         if len(train_dataset) == 0 or len(test_dataset) == 0:
+             logger.error("No valid data found after filtering. Please check dataset configuration.")
+             return
+
+         logger.info(f"Training dataset size: {len(train_dataset)}")
+         logger.info(f"Test dataset size: {len(test_dataset)}")
+
+         # Get unique languages for reference
+         unique_languages = sorted(set(train_dataset[label_column]))
+         logger.info(f"Programming languages to classify: {unique_languages}")
+
+         # Step 4: Train the classifier
+         logger.info("Starting training...")
+         train_start_time = perf_counter()
+
+         classifier = classifier.fit(
+             train_dataset[text_column],
+             train_dataset[label_column],
+             max_epochs=MAX_EPOCHS,
+             batch_size=BATCH_SIZE,
+             learning_rate=LEARNING_RATE,
+             early_stopping_patience=PATIENCE,
+         )
+
+         train_time = perf_counter() - train_start_time
+         logger.info(f"Training completed in {int(train_time)} seconds")
+
+         # Step 5: Evaluate the classifier
+         logger.info("Evaluating classifier...")
+         eval_start_time = perf_counter()
+
+         classification_report = classifier.evaluate(test_dataset[text_column], test_dataset[label_column])
+
+         eval_time = perf_counter() - eval_start_time
+         logger.info(f"Evaluation completed in {int(eval_time * 1000)} milliseconds")
+         logger.info(f"Classification results:\n{classification_report}")
+
+         # Step 6: Test with some examples
+         logger.info("Testing with sample code snippets...")
+
+         # Test examples for different languages
+         test_examples = [
+             'def hello_world():\n print("Hello, World!")\n return True',  # Python
+             (
+                 "public class HelloWorld {\n"
+                 " public static void main(String[] args) {\n"
+                 ' System.out.println("Hello, World!");\n'
+                 " }\n"
+                 "}"
+             ),  # Java
+             'function helloWorld() {\n console.log("Hello, World!");\n return true;\n}',  # JavaScript
+             'package main\n\nimport "fmt"\n\nfunc main() {\n fmt.Println("Hello, World!")\n}',  # Go
+             '<?php\nfunction hello_world() {\n echo "Hello, World!";\n return true;\n}\n?>',  # PHP
+             'def hello_world\n puts "Hello, World!"\n true\nend',  # Ruby
+         ]
+
+         predictions = classifier.predict(test_examples)
+         for i, (code, pred) in enumerate(zip(test_examples, predictions, strict=False)):
+             logger.info(f"Example {i + 1}: {pred}")
+             logger.info(f"Code snippet: {code[:100]}...")
+
+         # Step 7: Benchmark inference speed
+         logger.info("Benchmarking inference speed...")
+         inference_start_time = perf_counter()
+         _ = classifier.predict(test_dataset[text_column][:100])  # Test on first 100 samples
+         inference_time = perf_counter() - inference_start_time
+         logger.info(f"Inference took {int(inference_time * 1000)} milliseconds for 100 code snippets on CPU")
+
+         # Step 8: Save the model
+         if SAVE_PIPELINE:
+             logger.info("Converting to scikit-learn pipeline...")
+             pipeline = classifier.to_pipeline()
+
+             # Save locally
+             pipeline_path = output_dir / "pipeline"
+             pipeline.save_pretrained(str(pipeline_path))
+             logger.info(f"Pipeline saved to {pipeline_path}")
+
+             # Save additional metadata
+             metadata = {
+                 "model_name": MODEL_NAME,
+                 "dataset": DATASET_NAME,
+                 "task": "programming_language_classification",
+                 "languages": unique_languages,
+                 "pca_dims": PCA_DIMS,
+                 "train_samples": len(train_dataset),
+                 "test_samples": len(test_dataset),
+             }
+
+             metadata_path = output_dir / "metadata.json"
+             with metadata_path.open("w") as f:
+                 json.dump(metadata, f, indent=2)
+             logger.info("Metadata saved to metadata.json")
+
+             # Push to hub if requested
+             if SAVE_TO_HUB and HUB_MODEL_ID:
+                 logger.info(f"Pushing pipeline to HuggingFace Hub as {HUB_MODEL_ID}")
+                 pipeline.push_to_hub(HUB_MODEL_ID)
+         else:
+             # Save the classifier directly
+             classifier_path = output_dir / "classifier"
+             classifier_path.mkdir(exist_ok=True)
+
+             # Note: StaticModelForClassification might not have save_pretrained
+             # We'll save the underlying static model and create instructions
+             if static_model is not None:
+                 static_model.save_pretrained(str(classifier_path / "static_model"))
+
+             logger.info(f"Classifier components saved to {classifier_path}")
+
+         # Summary
+         total_time = time.time() - total_start_time
+         logger.info("=" * 60)
+         logger.info("CODE CLASSIFICATION TRAINING COMPLETED SUCCESSFULLY!")
+         logger.info(f"Total time: {total_time:.2f} seconds")
+         if not SKIP_DISTILLATION:
+             logger.info(f"Distillation time: {distill_time:.2f} seconds")
+         logger.info(f"Training time: {int(train_time)} seconds")
+         logger.info(f"Dataset: {DATASET_NAME}")
+         logger.info("Task: Programming Language Classification")
+         logger.info(f"Languages: {', '.join(unique_languages)}")
+         logger.info(f"Model saved to: {output_dir}")
+         logger.info("=" * 60)
+
+     except Exception:
+         logger.exception("Error during code classification training pipeline")
+         raise
+
+
+ if __name__ == "__main__":
+     main()
uv.lock ADDED