Sarthak committed
Commit 7837959 · 1 Parent(s): 37196da

chore: update dependencies and configuration for improved training


This commit updates the model configuration in `.codemap.yml` to use a lighter model variant. It also updates `pyproject.toml` and `uv.lock`, adding dependencies such as `jinja2`, `joblib`, `rich`, and `safetensors`, and replacing `tokenlearn` with `tokenizers`. `REPORT.md` is regenerated to reflect the updated performance metrics, and the distillation configuration now supports using the pre-built optimized dataset during training.

.codemap.yml CHANGED
@@ -5,7 +5,7 @@
 # LLM Configuration - Controls which model is used for AI operations
 llm:
   # Format: "provider:model-name", e.g., "openai:gpt-4o", "anthropic:claude-3-opus"
-  model: "google-gla:gemini-2.0-flash"
+  model: "google-gla:gemini-2.0-flash-lite"
   temperature: 0.5 # Lower for more deterministic outputs, higher for creativity
   max_input_tokens: 1000000 # Maximum tokens in input
   max_output_tokens: 10000 # Maximum tokens in responses
REPORT.md CHANGED
@@ -28,8 +28,8 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
28
  | code_model2vec_all_MiniLM_L6_v2 | [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) | 0.7385 | 0.7049 | 0.7910 | 🥈 2nd |
29
  | code_model2vec_jina_embeddings_v2_base_code | [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code) | 0.7381 | 0.6996 | 0.8130 | 🥉 3rd |
30
  | code_model2vec_paraphrase_MiniLM_L6_v2 | [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) | 0.7013 | 0.6638 | 0.7665 | #4 |
31
- | code_model2vec_all_mpnet_base_v2_fine_tuned | [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) | 0.6906 | 0.6372 | 0.7917 | #5 |
32
- | code_model2vec_Reason_ModernColBERT | [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT) | 0.6598 | 0.6228 | 0.7260 | #6 |
33
  | code_model2vec_bge_m3 | [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | 0.4863 | 0.4439 | 0.5514 | #7 |
34
  | code_model2vec_jina_embeddings_v3 | [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3) | 0.4755 | 0.4416 | 0.5456 | #8 |
35
  | code_model2vec_nomic_embed_text_v2_moe | [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe) | 0.4532 | 0.4275 | 0.5094 | #9 |
@@ -50,8 +50,8 @@ Our distilled models exhibit consistent architectural characteristics across dif
50
  | all_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
51
  | jina_embeddings_v2_base_code | 61,053 | 15.6M | 256 | 29.8MB |
52
  | paraphrase_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
53
- | all_mpnet_base_v2_fine_tuned | 77,316 | 19.8M | 256 | 75.5MB |
54
  | Reason_ModernColBERT | 50,254 | 12.9M | 256 | 24.5MB |
55
  | bge_m3 | 249,999 | 64.0M | 256 | 122.1MB |
56
  | jina_embeddings_v3 | 249,999 | 64.0M | 256 | 122.1MB |
57
  | nomic_embed_text_v2_moe | 249,999 | 64.0M | 256 | 122.1MB |
@@ -69,9 +69,9 @@ Our distilled models exhibit consistent architectural characteristics across dif
69
  #### Key Insights from Model Specifications:
70
 
71
 
72
- - **Vocabulary Consistency**: All models use vocabulary sizes ranging from 29,525 to 249,999 tokens (avg: 104,501)
73
- - **Parameter Efficiency**: Models range from 7.6M to 64.0M parameters (avg: 26.8M)
74
- - **Storage Efficiency**: Disk usage ranges from 14.4MB to 122.1MB (avg: 53.7MB)
75
  - **Embedding Dimensions**: Consistent 256 dimensions across all models (optimized for efficiency)
76
 
77
 
@@ -81,85 +81,13 @@ Our distilled models exhibit consistent architectural characteristics across dif
81
  - **Best Teacher Model**: code_model2vec_all_mpnet_base_v2 (NDCG@10: 0.7387)
82
  - **Least Effective Teacher**: code_model2vec_codebert_base (NDCG@10: 0.2779)
83
  - **Performance Range**: 62.4% difference between best and worst
84
- - **Average Performance**: 0.5302 NDCG@10
85
 
86
 
87
  ## 🎯 Language Performance Radar Charts
88
 
89
  ### Best Model vs Peer Models Comparison
90
 
91
- ![Comparative Radar Chart](analysis_charts/comparative_radar.png)
92
-
93
- *Comparative view showing how the best simplified distillation model performs against top peer models across programming languages.*
94
-
95
- ### Individual Model Performance by Language
96
-
97
- #### code_model2vec_all_mpnet_base_v2 (Teacher: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) - NDCG@10: 0.7387
98
-
99
- ![code_model2vec_all_mpnet_base_v2 Radar Chart](analysis_charts/radar_code_model2vec_all_mpnet_base_v2.png)
100
-
101
- #### code_model2vec_all_MiniLM_L6_v2 (Teacher: [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)) - NDCG@10: 0.7385
102
-
103
- ![code_model2vec_all_MiniLM_L6_v2 Radar Chart](analysis_charts/radar_code_model2vec_all_MiniLM_L6_v2.png)
104
-
105
- #### code_model2vec_jina_embeddings_v2_base_code (Teacher: [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code)) - NDCG@10: 0.7381
106
-
107
- ![code_model2vec_jina_embeddings_v2_base_code Radar Chart](analysis_charts/radar_code_model2vec_jina_embeddings_v2_base_code.png)
108
-
109
- #### code_model2vec_paraphrase_MiniLM_L6_v2 (Teacher: [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2)) - NDCG@10: 0.7013
110
-
111
- ![code_model2vec_paraphrase_MiniLM_L6_v2 Radar Chart](analysis_charts/radar_code_model2vec_paraphrase_MiniLM_L6_v2.png)
112
-
113
- #### code_model2vec_all_mpnet_base_v2_fine_tuned (Teacher: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) - NDCG@10: 0.6906
114
-
115
- ![code_model2vec_all_mpnet_base_v2_fine_tuned Radar Chart](analysis_charts/radar_code_model2vec_all_mpnet_base_v2_fine_tuned.png)
116
-
117
- #### code_model2vec_Reason_ModernColBERT (Teacher: [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT)) - NDCG@10: 0.6598
118
-
119
- ![code_model2vec_Reason_ModernColBERT Radar Chart](analysis_charts/radar_code_model2vec_Reason_ModernColBERT.png)
120
-
121
- #### code_model2vec_bge_m3 (Teacher: [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3)) - NDCG@10: 0.4863
122
-
123
- ![code_model2vec_bge_m3 Radar Chart](analysis_charts/radar_code_model2vec_bge_m3.png)
124
-
125
- #### code_model2vec_jina_embeddings_v3 (Teacher: [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3)) - NDCG@10: 0.4755
126
-
127
- ![code_model2vec_jina_embeddings_v3 Radar Chart](analysis_charts/radar_code_model2vec_jina_embeddings_v3.png)
128
-
129
- #### code_model2vec_nomic_embed_text_v2_moe (Teacher: [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe)) - NDCG@10: 0.4532
130
-
131
- ![code_model2vec_nomic_embed_text_v2_moe Radar Chart](analysis_charts/radar_code_model2vec_nomic_embed_text_v2_moe.png)
132
-
133
- #### code_model2vec_gte_Qwen2_1.5B_instruct (Teacher: [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)) - NDCG@10: 0.4238
134
-
135
- ![code_model2vec_gte_Qwen2_1.5B_instruct Radar Chart](analysis_charts/radar_code_model2vec_gte_Qwen2_15B_instruct.png)
136
-
137
- #### code_model2vec_Qodo_Embed_1_1.5B (Teacher: [Qodo/Qodo-Embed-1-1.5B](https://huggingface.co/Qodo/Qodo-Embed-1-1.5B)) - NDCG@10: 0.4101
138
-
139
- ![code_model2vec_Qodo_Embed_1_1.5B Radar Chart](analysis_charts/radar_code_model2vec_Qodo_Embed_1_15B.png)
140
-
141
- #### code_model2vec_graphcodebert_base (Teacher: [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base)) - NDCG@10: 0.3420
142
-
143
- ![code_model2vec_graphcodebert_base Radar Chart](analysis_charts/radar_code_model2vec_graphcodebert_base.png)
144
-
145
- #### code_model2vec_Linq_Embed_Mistral (Teacher: [Linq-AI-Research/Linq-Embed-Mistral](https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral)) - NDCG@10: 0.2868
146
-
147
- ![code_model2vec_Linq_Embed_Mistral Radar Chart](analysis_charts/radar_code_model2vec_Linq_Embed_Mistral.png)
148
-
149
- #### code_model2vec_codebert_base (Teacher: [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base)) - NDCG@10: 0.2779
150
-
151
- ![code_model2vec_codebert_base Radar Chart](analysis_charts/radar_code_model2vec_codebert_base.png)
152
-
153
-
154
-
155
- ## 🏆 Peer Model Comparison
156
-
157
- ![Peer Comparison](analysis_charts/peer_comparison.png)
158
-
159
- *Comparison with established code-specialized embedding models using actual evaluation results.*
160
-
161
- ### Complete Model Ranking
162
-
163
  | Rank | Model | Type | NDCG@10 | MRR | Recall@5 |
164
  |------|-------|------|---------|-----|----------|
165
  | 1 | Alibaba-NLP/gte-Qwen2-1.5B-instruct | General | 0.9729 | 0.9676 | 0.9825 |
@@ -180,10 +108,10 @@ Our distilled models exhibit consistent architectural characteristics across dif
180
  | 16 | code_model2vec_all_MiniLM_L6_v2 | **🔥 Simplified Distillation** | 0.7385 | 0.7049 | 0.7910 |
181
  | 17 | code_model2vec_jina_embeddings_v2_base_code | **🔥 Simplified Distillation** | 0.7381 | 0.6996 | 0.8130 |
182
  | 18 | code_model2vec_paraphrase_MiniLM_L6_v2 | **🔥 Simplified Distillation** | 0.7013 | 0.6638 | 0.7665 |
183
- | 19 | code_model2vec_all_mpnet_base_v2_fine_tuned | **🎓 Fine-tuned Distillation** | 0.6906 | 0.6372 | 0.7917 |
184
- | 20 | code_model2vec_Reason_ModernColBERT | **🔥 Simplified Distillation** | 0.6598 | 0.6228 | 0.7260 |
185
- | 21 | potion-multilingual-128M | Model2Vec | 0.6124 | 0.5683 | 0.7017 |
186
- | 22 | huggingface/CodeBERTa-small-v1 | Code-Specific | 0.5903 | 0.5350 | 0.6779 |
187
  | 23 | Salesforce/codet5-base | Code-Specific | 0.4872 | 0.4500 | 0.5742 |
188
  | 24 | code_model2vec_bge_m3 | **🔥 Simplified Distillation** | 0.4863 | 0.4439 | 0.5514 |
189
  | 25 | code_model2vec_jina_embeddings_v3 | **🔥 Simplified Distillation** | 0.4755 | 0.4416 | 0.5456 |
@@ -243,12 +171,12 @@ Our distilled models exhibit consistent architectural characteristics across dif
243
 
244
  | Language | Best Model Performance | Average Performance | Language Difficulty |
245
  |----------|------------------------|--------------------|--------------------|
246
- | Go | 0.9780 | 0.6978 | Easy |
247
- | Java | 0.9921 | 0.6618 | Easy |
248
- | Javascript | 0.9550 | 0.5877 | Easy |
249
- | Php | 1.0000 | 0.6355 | Easy |
250
- | Python | 1.0000 | 0.8615 | Easy |
251
- | Ruby | 0.9493 | 0.6398 | Easy |
252
 
253
 
254
  ## 🎯 Conclusions and Recommendations
@@ -302,5 +230,5 @@ Based on the evaluation results across all simplified distillation models:
302
 
303
  ---
304
 
305
- *Report generated on 2025-05-31 16:36:16 using automated analysis pipeline.*
306
  *For questions about methodology or results, please refer to the CodeSearchNet documentation.*
 
28
  | code_model2vec_all_MiniLM_L6_v2 | [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) | 0.7385 | 0.7049 | 0.7910 | 🥈 2nd |
29
  | code_model2vec_jina_embeddings_v2_base_code | [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code) | 0.7381 | 0.6996 | 0.8130 | 🥉 3rd |
30
  | code_model2vec_paraphrase_MiniLM_L6_v2 | [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) | 0.7013 | 0.6638 | 0.7665 | #4 |
31
+ | code_model2vec_Reason_ModernColBERT | [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT) | 0.6598 | 0.6228 | 0.7260 | #5 |
32
+ | code_model2vec_all_mpnet_base_v2_fine_tuned | [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) | 0.5347 | 0.4875 | 0.6200 | #6 |
33
  | code_model2vec_bge_m3 | [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | 0.4863 | 0.4439 | 0.5514 | #7 |
34
  | code_model2vec_jina_embeddings_v3 | [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3) | 0.4755 | 0.4416 | 0.5456 | #8 |
35
  | code_model2vec_nomic_embed_text_v2_moe | [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe) | 0.4532 | 0.4275 | 0.5094 | #9 |
 
50
  | all_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
51
  | jina_embeddings_v2_base_code | 61,053 | 15.6M | 256 | 29.8MB |
52
  | paraphrase_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
 
53
  | Reason_ModernColBERT | 50,254 | 12.9M | 256 | 24.5MB |
54
+ | all_mpnet_base_v2_fine_tuned | 29,528 | 7.6M | 256 | 28.8MB |
55
  | bge_m3 | 249,999 | 64.0M | 256 | 122.1MB |
56
  | jina_embeddings_v3 | 249,999 | 64.0M | 256 | 122.1MB |
57
  | nomic_embed_text_v2_moe | 249,999 | 64.0M | 256 | 122.1MB |
 
69
  #### Key Insights from Model Specifications:
70
 
71
 
72
+ - **Vocabulary Consistency**: All models use vocabulary sizes ranging from 29,525 to 249,999 tokens (avg: 101,087)
73
+ - **Parameter Efficiency**: Models range from 7.6M to 64.0M parameters (avg: 25.9M)
74
+ - **Storage Efficiency**: Disk usage ranges from 14.4MB to 122.1MB (avg: 50.4MB)
75
  - **Embedding Dimensions**: Consistent 256 dimensions across all models (optimized for efficiency)
76
 
77
 
 
81
  - **Best Teacher Model**: code_model2vec_all_mpnet_base_v2 (NDCG@10: 0.7387)
82
  - **Least Effective Teacher**: code_model2vec_codebert_base (NDCG@10: 0.2779)
83
  - **Performance Range**: 62.4% difference between best and worst
84
+ - **Average Performance**: 0.5190 NDCG@10
85
 
86
 
87
  ## 🎯 Language Performance Radar Charts
88
 
89
  ### Best Model vs Peer Models Comparison
90
 
91
  | Rank | Model | Type | NDCG@10 | MRR | Recall@5 |
92
  |------|-------|------|---------|-----|----------|
93
  | 1 | Alibaba-NLP/gte-Qwen2-1.5B-instruct | General | 0.9729 | 0.9676 | 0.9825 |
 
108
  | 16 | code_model2vec_all_MiniLM_L6_v2 | **🔥 Simplified Distillation** | 0.7385 | 0.7049 | 0.7910 |
109
  | 17 | code_model2vec_jina_embeddings_v2_base_code | **🔥 Simplified Distillation** | 0.7381 | 0.6996 | 0.8130 |
110
  | 18 | code_model2vec_paraphrase_MiniLM_L6_v2 | **🔥 Simplified Distillation** | 0.7013 | 0.6638 | 0.7665 |
111
+ | 19 | code_model2vec_Reason_ModernColBERT | **🔥 Simplified Distillation** | 0.6598 | 0.6228 | 0.7260 |
112
+ | 20 | potion-multilingual-128M | Model2Vec | 0.6124 | 0.5683 | 0.7017 |
113
+ | 21 | huggingface/CodeBERTa-small-v1 | Code-Specific | 0.5903 | 0.5350 | 0.6779 |
114
+ | 22 | code_model2vec_all_mpnet_base_v2_fine_tuned | **🎓 Fine-tuned Distillation** | 0.5347 | 0.4875 | 0.6200 |
115
  | 23 | Salesforce/codet5-base | Code-Specific | 0.4872 | 0.4500 | 0.5742 |
116
  | 24 | code_model2vec_bge_m3 | **🔥 Simplified Distillation** | 0.4863 | 0.4439 | 0.5514 |
117
  | 25 | code_model2vec_jina_embeddings_v3 | **🔥 Simplified Distillation** | 0.4755 | 0.4416 | 0.5456 |
 
171
 
172
  | Language | Best Model Performance | Average Performance | Language Difficulty |
173
  |----------|------------------------|--------------------|--------------------|
174
+ | Go | 0.9780 | 0.6923 | Easy |
175
+ | Java | 0.9921 | 0.6545 | Easy |
176
+ | Javascript | 0.9550 | 0.5831 | Easy |
177
+ | Php | 1.0000 | 0.6325 | Easy |
178
+ | Python | 1.0000 | 0.8599 | Easy |
179
+ | Ruby | 0.9493 | 0.6333 | Easy |
180
 
181
 
182
  ## 🎯 Conclusions and Recommendations
 
230
 
231
  ---
232
 
233
+ *Report generated on 2025-05-31 21:07:06 using automated analysis pipeline.*
234
  *For questions about methodology or results, please refer to the CodeSearchNet documentation.*
patches/model2vec.patch DELETED
@@ -1,39 +0,0 @@
1
- --- a/model2vec/train/base.py
2
- +++ b/model2vec/train/base.py
3
- @@ -35,7 +35,7 @@ class FinetunableStaticModel(nn.Module):
4
- )
5
- self.vectors = vectors.float()
6
-
7
- - self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
8
- + self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=pad_id)
9
- self.head = self.construct_head()
10
- self.w = self.construct_weights()
11
- self.tokenizer = tokenizer
12
- --- a/model2vec/distill/distillation.py
13
- +++ b/model2vec/distill/distillation.py
14
- @@ -137,7 +137,10 @@ def distill_from_model(
15
- # Get the language from the model card.
16
- try:
17
- info = model_info(model_name)
18
- - language = info.cardData.get("language", None)
19
- + if info is not None and hasattr(info, 'cardData') and info.cardData is not None:
20
- + language = info.cardData.get("language", None)
21
- + else:
22
- + language = None
23
- except RepositoryNotFoundError:
24
- logger.info("No model info found for the model. Setting language to None.")
25
- language = None
26
- --- a/model2vec/distill/inference.py
27
- +++ b/model2vec/distill/inference.py
28
- @@ -109,5 +109,12 @@ def create_embeddings(
29
- out_tokens.extend([Token(x, False) for x in tokens])
30
- out_weights = np.stack(intermediate_weights)
31
-
32
- + # Validate token-vector consistency to prevent failures
33
- + if len(out_tokens) != out_weights.shape[0]:
34
- + logger.warning(f"Token-vector mismatch: {len(out_tokens)} tokens vs {out_weights.shape[0]} vectors. Truncating to prevent failure.")
35
- + min_count = min(len(out_tokens), out_weights.shape[0])
36
- + out_tokens = out_tokens[:min_count]
37
- + out_weights = out_weights[:min_count]
38
- +
39
- return out_tokens, out_weights
 
patches/tokenlearn.patch DELETED
@@ -1,25 +0,0 @@
1
- --- a/tokenlearn/pretrain.py
2
- +++ b/tokenlearn/pretrain.py
3
- @@ -38,7 +38,10 @@ class FinetunableStaticModel(nn.Module):
4
- """Run the model using input IDs."""
5
- input_ids = input_ids.view(-1)
6
- input_ids = input_ids[input_ids != self.pad_token_id]
7
- - w = self.w[input_ids]
8
- + # Fix for index out of bounds issue
9
- + # Clamp input_ids to valid range to prevent IndexError during training
10
- + valid_input_ids = torch.clamp(input_ids, 0, self.w.shape[0] - 1)
11
- + w = self.w[valid_input_ids]
12
- return self.sub_forward(w)
13
-
14
- def forward(self, x):
15
- @@ -46,7 +49,10 @@ class FinetunableStaticModel(nn.Module):
16
- # Add a small epsilon to avoid division by zero
17
- length = zeros.sum(1) + 1e-16
18
- - embedded = self.embeddings(input_ids)
19
- + # Fix for embedding index out of bounds issue
20
- + # Clamp input_ids to valid embedding range
21
- + valid_input_ids = torch.clamp(input_ids, 0, self.embeddings.num_embeddings - 1)
22
- + embedded = self.embeddings(valid_input_ids)
23
- # Zero out the padding
24
- embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
25
- # Simulate actual mean
pyproject.toml CHANGED
@@ -19,24 +19,31 @@ dependencies = [
     "flash-attn>=2.7.4.post1",
     "hatchling>=1.27.0",
     "iso639>=0.1.4",
+    "jinja2>=3.0.0",
+    "joblib>=1.0.0",
     "kaleido==1.0.0rc13",
     "lightning>=2.5.1.post0",
     "matplotlib>=3.10.3",
-    "model2vec[train]>=0.5.0",
+    "more-itertools>=10.5.0",
     "mteb>=1.14.15",
     "numpy>=1.26.4",
     "plotly>=6.1.1",
     "psutil>=7.0.0",
     "pydantic>=2.11.5",
     "requests>=2.32.3",
+    "rich>=10.0.0",
+    "safetensors>=0.3.0",
     "scikit-learn>=1.6.1",
     "seaborn>=0.13.2",
     "sentence-transformers>=4.1.0",
     "setuptools>=80.8.0",
+    "skops>=0.11.0",
     "smart-open[s3]>=7.1.0",
     "statsmodels>=0.14.4",
-    "tokenlearn>=0.2.0",
+    "tokenizers>=0.20",
     "torch>=2.7.0",
+    "transformers<=4.52.1",
+    "tqdm>=4.65.0",
     "typer>=0.16.0",
 ]

@@ -78,7 +85,9 @@ exclude = [
     "__pycache__",
     "build",
     "dist",
-    "vendor"
+    "vendor",
+    "src/distiller/model2vec",
+    "src/distiller/tokenlearn"
 ]

 [tool.ruff.lint]
@@ -114,6 +123,8 @@ ignore = [
     "E501", # Line too long
     "PLR2004",
     "RUF001",
+    "D100", # Missing docstring in public module
+    "D101", # Missing docstring in public class
 ]

 [tool.ruff.lint.mccabe]
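For context, the dependency swap above lines up with the new ruff excludes (`src/distiller/model2vec`, `src/distiller/tokenlearn`) and the `distiller.model2vec` import later in this commit, which suggest those packages are now vendored while `tokenizers` and `safetensors` are used directly. Below is an illustrative sketch only of what that enables; the file paths and the "embeddings" tensor key are assumptions, not taken from this repository.

```python
# Illustrative sketch: load a distilled static model's artifacts directly with
# the newly added `tokenizers` and `safetensors` dependencies. Paths and the
# "embeddings" key are assumed for the example.
import numpy as np
from safetensors.numpy import load_file
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("code_model2vec/final/model/tokenizer.json")
embeddings = load_file("code_model2vec/final/model/model.safetensors")["embeddings"]

def embed(text: str) -> np.ndarray:
    """Mean-pool static token embeddings for a single string."""
    ids = tokenizer.encode(text).ids
    return embeddings[ids].mean(axis=0)

query_vec = embed("How to sort data in python")
print(query_vec.shape)  # (256,) for the 256-dimension distilled models
```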
src/distiller/__main__.py CHANGED
@@ -17,12 +17,41 @@ def distill(
17
  train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False,
18
  teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None,
19
  pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None,
 
20
  ) -> None:
21
  """Run unified Model2Vec distillation with optional training."""
22
  from .distill import main as distill_main
23
 
24
- # Call the distill main function with arguments
25
- distill_main(use_beam, train, teacher_models, pca_dims)
26
 
27
 
28
  @app.command()
@@ -53,5 +82,26 @@ def analyze(
53
  analyze_main(results_dir or "code_model2vec/evaluation_results", model_name, output, export_csv)
54
 
55
 
 
56
  if __name__ == "__main__":
57
  app()
 
17
  train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False,
18
  teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None,
19
  pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None,
20
+ clear_cache: Annotated[
21
+ bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation")
22
+ ] = False,
23
+ clear_checkpoints: Annotated[
24
+ bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
25
+ ] = False,
26
+ skip_ptr: Annotated[
27
+ bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
28
+ ] = False,
29
+ use_optimized_dataset: Annotated[
30
+ bool,
31
+ typer.Option(
32
+ "--use-optimized-dataset", help="Use the pre-created optimized dataset from code_model2vec/dataset"
33
+ ),
34
+ ] = False,
35
+ dataset_path: Annotated[
36
+ str | None,
37
+ typer.Option("--dataset-path", help="Path to custom dataset directory (defaults to code_model2vec/dataset)"),
38
+ ] = None,
39
  ) -> None:
40
  """Run unified Model2Vec distillation with optional training."""
41
  from .distill import main as distill_main
42
 
43
+ # Call the distill main function with all arguments
44
+ distill_main(
45
+ use_beam,
46
+ train,
47
+ teacher_models,
48
+ pca_dims,
49
+ clear_cache,
50
+ clear_checkpoints,
51
+ skip_ptr,
52
+ use_optimized_dataset,
53
+ dataset_path,
54
+ )
55
 
56
 
57
  @app.command()
 
82
  analyze_main(results_dir or "code_model2vec/evaluation_results", model_name, output, export_csv)
83
 
84
 
85
+ @app.command()
86
+ def dataset(
87
+ max_samples_per_lang: Annotated[int, typer.Option(help="Maximum samples per language")] = 50000,
88
+ min_doc_words: Annotated[int, typer.Option(help="Minimum words in documentation")] = 3,
89
+ max_doc_words: Annotated[int, typer.Option(help="Maximum words in documentation")] = 100,
90
+ min_code_chars: Annotated[int, typer.Option(help="Minimum characters in code")] = 50,
91
+ max_code_chars: Annotated[int, typer.Option(help="Maximum characters in code")] = 2000,
92
+ output_dir: Annotated[str | None, typer.Option(help="Output directory for dataset")] = None,
93
+ simple_format: Annotated[
94
+ bool, typer.Option(help="Create only simple format (not multiple training formats)")
95
+ ] = False,
96
+ ) -> None:
97
+ """Create optimized training dataset from CodeSearchNet for code search tasks."""
98
+ from .dataset import main as dataset_main
99
+
100
+ # Call the dataset main function with arguments
101
+ dataset_main(
102
+ max_samples_per_lang, min_doc_words, max_doc_words, min_code_chars, max_code_chars, output_dir, simple_format
103
+ )
104
+
105
+
106
  if __name__ == "__main__":
107
  app()
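A minimal sketch of exercising the new CLI surface added above (the `dataset` command and the extra `distill` options). The import path `distiller.__main__` is an assumption about how the package installs; flag names follow Typer's default conversion of the parameters shown in the diff.

```python
# Sketch only: drive the Typer app in-process with CliRunner.
from typer.testing import CliRunner

from distiller.__main__ import app  # assumed import path

runner = CliRunner()

# Build the optimized dataset with the new `dataset` command.
result = runner.invoke(app, ["dataset", "--max-samples-per-lang", "1000", "--simple-format"])
print(result.exit_code)

# Run distillation against that dataset using the new options.
result = runner.invoke(
    app,
    ["distill", "--train", "--use-optimized-dataset", "--dataset-path", "code_model2vec/dataset"],
)
print(result.exit_code)
```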
src/distiller/analyze.py CHANGED
@@ -510,7 +510,7 @@ class CodeSearchNetAnalyzer:
510
 
511
  try:
512
  # Try to load the model and get specifications
513
- from model2vec import StaticModel
514
 
515
  model = StaticModel.from_pretrained(str(model_dir))
516
 
 
510
 
511
  try:
512
  # Try to load the model and get specifications
513
+ from distiller.model2vec import StaticModel
514
 
515
  model = StaticModel.from_pretrained(str(model_dir))
516
 
src/distiller/config.py CHANGED
@@ -212,13 +212,19 @@ class DistillationConfig(BaseModel):
212
  # Tokenlearn-specific parameters (POTION approach)
213
  tokenlearn_dataset: str = "sentence-transformers/codesearchnet" # Dataset for tokenlearn featurization
214
  tokenlearn_dataset_name: str = "pair" # Use 'pair' configuration (only available config)
215
- tokenlearn_text_key: str = "code" # Text field to use from the dataset ('code' or 'comment')
 
 
216
  tokenlearn_timeout_featurize: int = 21600 # 6 hour timeout for featurization (dataset needs ~5 hours)
217
  tokenlearn_timeout_train: int = 7200 # 2 hour timeout for training
218
 
219
  # Post-training configuration
220
  skip_post_training_regularization: bool = False # Skip PCA + SIF re-regularization step
221
 
 
 
 
 
222
 
223
  distillation_config = DistillationConfig()
224
 
 
212
  # Tokenlearn-specific parameters (POTION approach)
213
  tokenlearn_dataset: str = "sentence-transformers/codesearchnet" # Dataset for tokenlearn featurization
214
  tokenlearn_dataset_name: str = "pair" # Use 'pair' configuration (only available config)
215
+ tokenlearn_text_key: str = (
216
+ "combined_text" # Text field to use from the dataset ('combined_text' for doc-code pairs)
217
+ )
218
  tokenlearn_timeout_featurize: int = 21600 # 6 hour timeout for featurization (dataset needs ~5 hours)
219
  tokenlearn_timeout_train: int = 7200 # 2 hour timeout for training
220
 
221
  # Post-training configuration
222
  skip_post_training_regularization: bool = False # Skip PCA + SIF re-regularization step
223
 
224
+ # Dataset configuration
225
+ use_optimized_dataset: bool = True # Use the pre-created optimized dataset from dataset.py
226
+ custom_dataset_path: str | None = "code_model2vec/dataset" # Path to custom dataset directory
227
+
228
 
229
  distillation_config = DistillationConfig()
230
 
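One plausible way the new `use_optimized_dataset` and `custom_dataset_path` fields could be consumed at training time. This is an illustrative sketch; the actual wiring lives in `distill.py` and may differ. `load_optimized_dataset` is the helper defined in the new `src/distiller/dataset.py` module below.

```python
# Sketch: prefer the pre-built optimized dataset, fall back to streaming CodeSearchNet.
from pathlib import Path

from distiller.config import distillation_config
from distiller.dataset import load_optimized_dataset


def get_training_texts(max_samples: int = 50000) -> list[str]:
    """Return training texts according to the new dataset configuration."""
    if distillation_config.use_optimized_dataset:
        dataset_dir = Path(distillation_config.custom_dataset_path or "code_model2vec/dataset")
        df = load_optimized_dataset(output_dir=dataset_dir, split="train")
        return df["text"].head(max_samples).tolist()
    # Fallback path (not shown here): load_codesearchnet_dataset(max_samples=max_samples)
    raise NotImplementedError
```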
src/distiller/dataset.py ADDED
@@ -0,0 +1,659 @@
1
+ """
2
+ Custom Dataset Generation for Code-Specialized Model Training.
3
+
4
+ This module creates optimized training datasets from CodeSearchNet that are specifically
5
+ designed to improve performance on code search evaluation tasks.
6
+
7
+ Features:
8
+ - High-quality doc-code pairs optimized for retrieval
9
+ - Balanced sampling across programming languages
10
+ - Multiple training formats (doc-only, code-only, combined)
11
+ - Quality filtering and data cleaning
12
+ - Train/test/eval splits with proper stratification
13
+ - Efficient parquet format output
14
+ """
15
+
16
+ import json
17
+ import logging
18
+ import time
19
+ from pathlib import Path
20
+ from typing import Annotated, Any
21
+
22
+ import pandas as pd
23
+ import typer
24
+ from datasets import load_dataset
25
+ from tqdm import tqdm
26
+
27
+ from .config import languages_config
28
+
29
+ # Set up logging
30
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # Dataset configuration
34
+ DATASET_OUTPUT_DIR = Path("code_model2vec/dataset")
35
+ DEFAULT_MAX_SAMPLES_PER_LANG = 50000
36
+ DEFAULT_MIN_DOC_WORDS = 3
37
+ DEFAULT_MAX_DOC_WORDS = 100
38
+ DEFAULT_MIN_CODE_CHARS = 50
39
+ DEFAULT_MAX_CODE_CHARS = 2000
40
+
41
+
42
+ def create_optimized_dataset(
43
+ max_samples_per_lang: int = DEFAULT_MAX_SAMPLES_PER_LANG,
44
+ min_doc_words: int = DEFAULT_MIN_DOC_WORDS,
45
+ max_doc_words: int = DEFAULT_MAX_DOC_WORDS,
46
+ min_code_chars: int = DEFAULT_MIN_CODE_CHARS,
47
+ max_code_chars: int = DEFAULT_MAX_CODE_CHARS,
48
+ output_dir: Path | None = None,
49
+ create_multiple_formats: bool = True,
50
+ ) -> dict[str, Any]:
51
+ """
52
+ Create optimized training dataset from CodeSearchNet for code search tasks.
53
+
54
+ Args:
55
+ max_samples_per_lang: Maximum samples per programming language
56
+ min_doc_words: Minimum words in documentation
57
+ max_doc_words: Maximum words in documentation
58
+ min_code_chars: Minimum characters in code
59
+ max_code_chars: Maximum characters in code
60
+ output_dir: Output directory for dataset
61
+ create_multiple_formats: Create multiple training formats
62
+
63
+ Returns:
64
+ Dictionary with dataset statistics and file paths
65
+ """
66
+ output_dir = DATASET_OUTPUT_DIR if output_dir is None else Path(output_dir)
67
+
68
+ output_dir.mkdir(parents=True, exist_ok=True)
69
+
70
+ logger.info("🚀 Starting optimized CodeSearchNet dataset creation...")
71
+ logger.info(f"📁 Output directory: {output_dir}")
72
+ logger.info(f"📊 Target: {max_samples_per_lang} samples per language")
73
+ logger.info(f"🔍 Languages: {', '.join(languages_config.all)}")
74
+
75
+ start_time = time.time()
76
+ all_samples = []
77
+ language_stats = {}
78
+
79
+ # Process each programming language
80
+ for language in languages_config.all:
81
+ logger.info(f"\n🔄 Processing {language}...")
82
+
83
+ try:
84
+ # Load CodeSearchNet dataset for this language
85
+ dataset = load_dataset("code_search_net", language, split="train", trust_remote_code=True)
86
+
87
+ language_samples = []
88
+ processed_count = 0
89
+ quality_filtered = 0
90
+
91
+ # Process examples with quality filtering
92
+ for example in tqdm(dataset, desc=f"Processing {language}", unit="examples"):
93
+ processed_count += 1
94
+
95
+ # Extract documentation and code
96
+ doc_string = example.get("func_documentation_string", "").strip()
97
+ code_string = example.get("func_code_string", "").strip()
98
+ func_name = example.get("func_name", "").strip()
99
+
100
+ # Quality filters
101
+ if not _passes_quality_filters(
102
+ doc_string, code_string, func_name, min_doc_words, max_doc_words, min_code_chars, max_code_chars
103
+ ):
104
+ continue
105
+
106
+ quality_filtered += 1
107
+
108
+ # Create optimized training samples
109
+ samples = _create_training_samples(
110
+ doc_string, code_string, func_name, language, create_multiple_formats
111
+ )
112
+ language_samples.extend(samples)
113
+
114
+ # Stop if we have enough samples
115
+ if len(language_samples) >= max_samples_per_lang:
116
+ break
117
+
118
+ # Truncate to exact target size
119
+ language_samples = language_samples[:max_samples_per_lang]
120
+ all_samples.extend(language_samples)
121
+
122
+ # Track statistics
123
+ language_stats[language] = {
124
+ "processed": processed_count,
125
+ "quality_filtered": quality_filtered,
126
+ "final_samples": len(language_samples),
127
+ "quality_rate": quality_filtered / processed_count if processed_count > 0 else 0,
128
+ }
129
+
130
+ logger.info(f"✅ {language}: {len(language_samples)} samples from {quality_filtered} quality examples")
131
+
132
+ except Exception:
133
+ logger.exception(f"❌ Failed to process {language}")
134
+ language_stats[language] = {
135
+ "processed": 0,
136
+ "quality_filtered": 0,
137
+ "final_samples": 0,
138
+ "quality_rate": 0.0,
139
+ }
140
+
141
+ # Create DataFrame
142
+ logger.info(f"\n📊 Creating dataset with {len(all_samples)} total samples...")
143
+ df = pd.DataFrame(all_samples)
144
+
145
+ # Create stratified splits
146
+ train_df, test_df = _create_stratified_splits(df)
147
+
148
+ # Save datasets
149
+ dataset_files = _save_datasets(output_dir, train_df, test_df)
150
+
151
+ # Save metadata
152
+ metadata = {
153
+ "creation_time": time.strftime("%Y-%m-%d %H:%M:%S"),
154
+ "total_samples": len(all_samples),
155
+ "train_samples": len(train_df),
156
+ "test_samples": len(test_df),
157
+ "languages": languages_config.all,
158
+ "language_stats": language_stats,
159
+ "quality_filters": {
160
+ "min_doc_words": min_doc_words,
161
+ "max_doc_words": max_doc_words,
162
+ "min_code_chars": min_code_chars,
163
+ "max_code_chars": max_code_chars,
164
+ },
165
+ "files": dataset_files,
166
+ "processing_time": time.time() - start_time,
167
+ }
168
+
169
+ metadata_file = output_dir / "metadata.json"
170
+ with metadata_file.open("w") as f:
171
+ json.dump(metadata, f, indent=2)
172
+
173
+ logger.info(f"\n🎉 Dataset creation completed in {metadata['processing_time']:.2f} seconds!")
174
+ logger.info("📊 Final statistics:")
175
+ logger.info(f" - Total samples: {metadata['total_samples']}")
176
+ logger.info(f" - Train: {metadata['train_samples']}")
177
+ logger.info(f" - Test: {metadata['test_samples']}")
178
+ logger.info(f"💾 Metadata saved to: {metadata_file}")
179
+
180
+ return metadata
181
+
182
+
183
+ def _passes_quality_filters(
184
+ doc_string: str,
185
+ code_string: str,
186
+ func_name: str,
187
+ min_doc_words: int,
188
+ max_doc_words: int,
189
+ min_code_chars: int,
190
+ max_code_chars: int,
191
+ ) -> bool:
192
+ """Apply quality filters optimized for code retrieval following RAG best practices."""
193
+ # Basic existence checks
194
+ if not doc_string or not code_string or not func_name:
195
+ return False
196
+
197
+ # Documentation quality filters for code retrieval
198
+ doc_words = len(doc_string.split())
199
+ if doc_words < min_doc_words or doc_words > max_doc_words:
200
+ return False
201
+
202
+ # Code quality filters
203
+ code_length = len(code_string)
204
+ if code_length < min_code_chars or code_length > max_code_chars:
205
+ return False
206
+
207
+ # Content quality filters for code retrieval
208
+ doc_lower = doc_string.lower()
209
+ code_string.lower()
210
+
211
+ # Skip low-quality documentation (expanded for code context)
212
+ skip_phrases = [
213
+ "todo",
214
+ "fixme",
215
+ "hack",
216
+ "temp",
217
+ "test",
218
+ "placeholder",
219
+ "not implemented",
220
+ "coming soon",
221
+ "tbd",
222
+ "xxx",
223
+ "broken",
224
+ "deprecated",
225
+ "legacy",
226
+ "old version",
227
+ "outdated",
228
+ ]
229
+ if any(phrase in doc_lower for phrase in skip_phrases):
230
+ return False
231
+
232
+ # Ensure meaningful documentation for code retrieval
233
+ if func_name.lower() in doc_lower and doc_words < 5:
234
+ return False
235
+
236
+ # Code structure validation (more comprehensive for retrieval)
237
+ has_function = any(
238
+ pattern in code_string for pattern in ["def ", "function ", "class ", "public ", "private ", "static "]
239
+ )
240
+ if not has_function:
241
+ return False
242
+
243
+ # Skip trivial or incomplete code
244
+ trivial_code_patterns = [
245
+ "pass",
246
+ "return None",
247
+ "return;",
248
+ "throw new Error",
249
+ "# TODO",
250
+ "// TODO",
251
+ "print(",
252
+ "console.log(",
253
+ ]
254
+ if any(pattern in code_string for pattern in trivial_code_patterns) and len(code_string) < 100:
255
+ return False
256
+
257
+ # Ensure documentation describes functionality (not just naming)
258
+ generic_docs = [
259
+ "returns a value",
260
+ "does something",
261
+ "helper function",
262
+ "utility method",
263
+ "this function",
264
+ "this method",
265
+ "returns the result",
266
+ "performs operation",
267
+ ]
268
+ if any(generic in doc_lower for generic in generic_docs):
269
+ return False
270
+
271
+ # Ensure documentation has descriptive content for retrieval
272
+ descriptive_words = [
273
+ "parse",
274
+ "convert",
275
+ "transform",
276
+ "calculate",
277
+ "validate",
278
+ "format",
279
+ "filter",
280
+ "sort",
281
+ "search",
282
+ "find",
283
+ "create",
284
+ "generate",
285
+ "process",
286
+ "handle",
287
+ "manage",
288
+ "update",
289
+ "modify",
290
+ "remove",
291
+ "delete",
292
+ "add",
293
+ ]
294
+ if not any(word in doc_lower for word in descriptive_words) and doc_words < 8:
295
+ return False
296
+
297
+ # Code-documentation alignment check (key for retrieval quality)
298
+ return _check_code_doc_alignment(doc_string, code_string, func_name)
299
+
300
+
301
+ def _check_code_doc_alignment(doc_string: str, code_string: str, func_name: str) -> bool:
302
+ """Check if documentation and code are well-aligned for retrieval tasks."""
303
+ doc_lower = doc_string.lower()
304
+ code_lower = code_string.lower()
305
+
306
+ # Function name should relate to documentation
307
+ func_base = func_name.lower().replace("_", " ").replace("-", " ")
308
+
309
+ # Check for obvious mismatches
310
+ doc_has_return = any(word in doc_lower for word in ["return", "returns", "gives", "outputs"])
311
+ code_has_return = "return " in code_lower
312
+
313
+ # If doc mentions returning something, code should have returns
314
+ if doc_has_return and not code_has_return and len(code_string.split("\n")) > 3:
315
+ return False
316
+
317
+ # Check for parameter mentions alignment
318
+ any(word in doc_lower for word in ["parameter", "param", "argument", "input"])
319
+ "(" in func_name and func_name.count("(") == 1
320
+
321
+ # Basic semantic alignment
322
+ action_words = ["sort", "parse", "convert", "validate", "format", "filter", "search", "calculate"]
323
+ doc_actions = [word for word in action_words if word in doc_lower]
324
+ [word for word in action_words if word in code_lower or word in func_base]
325
+
326
+ # If documentation mentions specific actions, code or function name should reflect them
327
+ return not (doc_actions and not any(action in code_lower or action in func_base for action in doc_actions))
328
+
329
+
330
+ def _create_training_samples(
331
+ doc_string: str,
332
+ code_string: str,
333
+ func_name: str,
334
+ language: str,
335
+ create_multiple_formats: bool,
336
+ ) -> list[dict[str, Any]]:
337
+ """Create optimized training samples for code retrieval with proper training schema."""
338
+ samples = []
339
+
340
+ if create_multiple_formats:
341
+ # Format 1: Documentation query → Code (direct evaluation format)
342
+ query_1 = doc_string
343
+ text_1 = _format_training_text(query_1, code_string, language)
344
+ samples.append(
345
+ {
346
+ "language": language,
347
+ "query": query_1,
348
+ "code": code_string,
349
+ "text": text_1,
350
+ }
351
+ )
352
+
353
+ # Format 2: How-to query (realistic developer search)
354
+ query_2 = _generate_how_to_query(doc_string, func_name, language)
355
+ text_2 = _format_training_text(query_2, code_string, language)
356
+ samples.append(
357
+ {
358
+ "language": language,
359
+ "query": query_2,
360
+ "code": code_string,
361
+ "text": text_2,
362
+ }
363
+ )
364
+
365
+ # Format 3: Functional requirement query
366
+ query_3 = _generate_functional_query(doc_string, func_name)
367
+ text_3 = _format_training_text(query_3, code_string, language)
368
+ samples.append(
369
+ {
370
+ "language": language,
371
+ "query": query_3,
372
+ "code": code_string,
373
+ "text": text_3,
374
+ }
375
+ )
376
+
377
+ # Format 4: Implementation-specific query
378
+ query_4 = _generate_implementation_query(doc_string, func_name, language)
379
+ text_4 = _format_training_text(query_4, code_string, language)
380
+ samples.append(
381
+ {
382
+ "language": language,
383
+ "query": query_4,
384
+ "code": code_string,
385
+ "text": text_4,
386
+ }
387
+ )
388
+
389
+ else:
390
+ # Simple format - direct documentation to code
391
+ query = doc_string
392
+ text = _format_training_text(query, code_string, language)
393
+ samples.append(
394
+ {
395
+ "language": language,
396
+ "query": query,
397
+ "code": code_string,
398
+ "text": text,
399
+ }
400
+ )
401
+
402
+ return samples
403
+
404
+
405
+ def _format_training_text(query: str, code: str, language: str) -> str:
406
+ """Format query and code into a single training text chunk with markdown-style code blocks."""
407
+ # Clean up query but preserve internal code formatting
408
+ query_clean = query.strip()
409
+ code_clean = code.strip()
410
+
411
+ # Create training text with proper markdown format and newline separation
412
+ # Structure: query + empty line + markdown code block with language
413
+ return f"{query_clean}\n\n```{language}\n{code_clean}\n```"
414
+
415
+
416
+ def _generate_how_to_query(doc_string: str, func_name: str, language: str) -> str:
417
+ """Generate realistic 'how to' queries that developers might actually search for."""
418
+ # Extract key action words from documentation
419
+ doc_lower = doc_string.lower()
420
+ func_lower = func_name.lower()
421
+
422
+ # Common developer query patterns
423
+ if "sort" in doc_lower or "sort" in func_lower:
424
+ return f"How to sort data in {language}"
425
+ if "parse" in doc_lower or "parse" in func_lower:
426
+ return f"How to parse data in {language}"
427
+ if "convert" in doc_lower or "transform" in doc_lower or "convert" in func_lower:
428
+ return f"How to convert data in {language}"
429
+ if "validate" in doc_lower or "check" in doc_lower or "validate" in func_lower:
430
+ return f"How to validate input in {language}"
431
+ if "calculate" in doc_lower or "compute" in doc_lower or "calc" in func_lower:
432
+ return f"How to calculate values in {language}"
433
+ if "format" in doc_lower or "format" in func_lower:
434
+ return f"How to format output in {language}"
435
+ if "filter" in doc_lower or "filter" in func_lower:
436
+ return f"How to filter data in {language}"
437
+ if "search" in doc_lower or "find" in doc_lower or "search" in func_lower or "find" in func_lower:
438
+ return f"How to search through data in {language}"
439
+ # Use function name for more specific queries
440
+ if func_name and len(func_name) > 2:
441
+ # Extract meaningful words from function name
442
+ func_words = func_name.replace("_", " ").replace("-", " ").strip()
443
+ if func_words:
444
+ return f"How to {func_words.lower()} in {language}"
445
+ # Fallback to more generic query
446
+ action = doc_string.split()[0] if doc_string.split() else "implement"
447
+ return f"How to {action.lower()} in {language}"
448
+
449
+
450
+ def _generate_functional_query(doc_string: str, func_name: str) -> str:
451
+ """Generate functional requirement queries focusing on what the code accomplishes."""
452
+ # Clean up documentation to create natural query
453
+ doc_clean = doc_string.strip().rstrip(".")
454
+
455
+ # Transform to question format
456
+ if doc_clean.startswith(("Returns", "Return")):
457
+ return f"Function that {doc_clean.lower()}"
458
+ if doc_clean.startswith(("Creates", "Create")):
459
+ return f"Code to {doc_clean.lower()}"
460
+ if doc_clean.startswith(("Checks", "Check")):
461
+ return f"Function to {doc_clean.lower()}"
462
+
463
+ # Use function name to enhance the query if available
464
+ if func_name and len(func_name) > 2:
465
+ func_words = func_name.replace("_", " ").replace("-", " ").strip()
466
+ if func_words and len(doc_clean) < 30: # Only for short docs
467
+ return f"Function named '{func_name}' that {doc_clean.lower()}"
468
+
469
+ return f"Implementation that {doc_clean.lower()}"
470
+
471
+
472
+ def _generate_implementation_query(doc_string: str, func_name: str, language: str) -> str:
473
+ """Generate implementation-specific queries with technical details."""
474
+ doc_lower = doc_string.lower()
475
+ func_lower = func_name.lower() if func_name else ""
476
+
477
+ # Add language-specific implementation details
478
+ if language == "python":
479
+ if "list" in doc_lower or "array" in doc_lower or "list" in func_lower:
480
+ return f"Python function to {doc_string.lower()} using lists"
481
+ if "dict" in doc_lower or "hash" in doc_lower or "dict" in func_lower:
482
+ return f"Python function to {doc_string.lower()} using dictionaries"
483
+ # Include function name for context if available
484
+ if func_name and len(func_name) > 2:
485
+ return f"Python implementation of {func_name}: {doc_string.lower()}"
486
+ return f"Python implementation: {doc_string.lower()}"
487
+ if language == "java":
488
+ func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
489
+ return f"Java method to {doc_string.lower()}{func_suffix}"
490
+ if language == "javascript":
491
+ func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
492
+ return f"JavaScript function to {doc_string.lower()}{func_suffix}"
493
+ if language == "php":
494
+ func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
495
+ return f"PHP function to {doc_string.lower()}{func_suffix}"
496
+ if language == "ruby":
497
+ func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
498
+ return f"Ruby method to {doc_string.lower()}{func_suffix}"
499
+ if language == "go":
500
+ func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
501
+ return f"Go function to {doc_string.lower()}{func_suffix}"
502
+ return f"{language} code to {doc_string.lower()}"
503
+
504
+
505
+ def _create_stratified_splits(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
506
+ """Create stratified train/test splits preserving language distribution."""
507
+ # Define split ratios
508
+ train_ratio = 0.9
509
+ # test_ratio = 0.1 (remainder)
510
+
511
+ train_dfs = []
512
+ test_dfs = []
513
+
514
+ # Split by language to ensure balanced representation
515
+ for language in df["language"].unique():
516
+ lang_df = df[df["language"] == language].copy()
517
+ n_samples = len(lang_df)
518
+
519
+ # Calculate split sizes
520
+ n_train = int(n_samples * train_ratio)
521
+ # Remainder goes to test
522
+
523
+ # Shuffle and split
524
+ lang_df = lang_df.sample(frac=1, random_state=42).reset_index(drop=True)
525
+
526
+ train_dfs.append(lang_df[:n_train])
527
+ test_dfs.append(lang_df[n_train:])
528
+
529
+ # Combine and shuffle again
530
+ train_df = pd.concat(train_dfs, ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
531
+ test_df = pd.concat(test_dfs, ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
532
+
533
+ logger.info("📊 Created stratified splits:")
534
+ logger.info(f" - Train: {len(train_df)} samples")
535
+ logger.info(f" - Test: {len(test_df)} samples")
536
+
537
+ return train_df, test_df
538
+
539
+
540
+ def _save_datasets(
541
+ output_dir: Path,
542
+ train_df: pd.DataFrame,
543
+ test_df: pd.DataFrame,
544
+ ) -> dict[str, str]:
545
+ """Save datasets in parquet format with compression."""
546
+ dataset_files = {}
547
+
548
+ # Save each split
549
+ for split_name, df in [("train", train_df), ("test", test_df)]:
550
+ filepath = output_dir / f"{split_name}.parquet"
551
+ df.to_parquet(
552
+ filepath,
553
+ compression="snappy",
554
+ index=False,
555
+ )
556
+ dataset_files[split_name] = str(filepath)
557
+ logger.info(f"💾 Saved {split_name}: {len(df)} samples → {filepath}")
558
+
559
+ # Also save a combined dataset for convenience
560
+ combined_df = pd.concat([train_df, test_df], ignore_index=True)
561
+ combined_filepath = output_dir / "combined.parquet"
562
+ combined_df.to_parquet(combined_filepath, compression="snappy", index=False)
563
+ dataset_files["combined"] = str(combined_filepath)
564
+ logger.info(f"💾 Saved combined: {len(combined_df)} samples → {combined_filepath}")
565
+
566
+ return dataset_files
567
+
568
+
569
+ def load_optimized_dataset(
570
+ output_dir: Path | None = None,
571
+ split: str = "train",
572
+ ) -> pd.DataFrame:
573
+ """
574
+ Load a previously created optimized dataset.
575
+
576
+ Args:
577
+ output_dir: Directory containing the dataset files
578
+ split: Which split to load ('train', 'test', 'combined')
579
+
580
+ Returns:
581
+ DataFrame with the requested dataset split
582
+ """
583
+ if output_dir is None:
584
+ output_dir = DATASET_OUTPUT_DIR
585
+
586
+ filepath = output_dir / f"{split}.parquet"
587
+
588
+ if not filepath.exists():
589
+ available_files = list(output_dir.glob("*.parquet"))
590
+ available_splits = [f.stem for f in available_files]
591
+ msg = f"Dataset split '{split}' not found at {filepath}. Available splits: {available_splits}"
592
+ raise FileNotFoundError(msg)
593
+
594
+ logger.info(f"📂 Loading {split} dataset from {filepath}")
595
+ df = pd.read_parquet(filepath)
596
+ logger.info(f"✅ Loaded {len(df)} samples")
597
+
598
+ return df
599
+
600
+
601
+ def main(
602
+ max_samples_per_lang: Annotated[
603
+ int, typer.Option(help="Maximum samples per language")
604
+ ] = DEFAULT_MAX_SAMPLES_PER_LANG,
605
+ min_doc_words: Annotated[int, typer.Option(help="Minimum words in documentation")] = DEFAULT_MIN_DOC_WORDS,
606
+ max_doc_words: Annotated[int, typer.Option(help="Maximum words in documentation")] = DEFAULT_MAX_DOC_WORDS,
607
+ min_code_chars: Annotated[int, typer.Option(help="Minimum characters in code")] = DEFAULT_MIN_CODE_CHARS,
608
+ max_code_chars: Annotated[int, typer.Option(help="Maximum characters in code")] = DEFAULT_MAX_CODE_CHARS,
609
+ output_dir: Annotated[str | None, typer.Option(help="Output directory for dataset")] = None,
610
+ simple_format: Annotated[
611
+ bool, typer.Option(help="Create only simple format (not multiple training formats)")
612
+ ] = False,
613
+ ) -> None:
614
+ """Create optimized training dataset from CodeSearchNet for code search tasks."""
615
+ logger.info("🚀 Starting optimized dataset creation command...")
616
+
617
+ # Convert output_dir to Path if provided
618
+ output_path = Path(output_dir) if output_dir else None
619
+
620
+ # Create the dataset
621
+ try:
622
+ metadata = create_optimized_dataset(
623
+ max_samples_per_lang=max_samples_per_lang,
624
+ min_doc_words=min_doc_words,
625
+ max_doc_words=max_doc_words,
626
+ min_code_chars=min_code_chars,
627
+ max_code_chars=max_code_chars,
628
+ output_dir=output_path,
629
+ create_multiple_formats=not simple_format,
630
+ )
631
+
632
+ logger.info("✅ Dataset creation completed successfully!")
633
+ logger.info(f"📁 Output directory: {metadata['files']['train']}")
634
+
635
+ # Print summary statistics
636
+ print("\n" + "=" * 60)
637
+ print("📊 DATASET CREATION SUMMARY")
638
+ print("=" * 60)
639
+ print(f"Total samples created: {metadata['total_samples']:,}")
640
+ print(f"Processing time: {metadata['processing_time']:.2f} seconds")
641
+ print("\nSplit distribution:")
642
+ print(f" • Train: {metadata['train_samples']:,} samples")
643
+ print(f" • Test: {metadata['test_samples']:,} samples")
644
+
645
+ print("\nLanguage distribution:")
646
+ for lang, stats in metadata["language_stats"].items():
647
+ if "error" not in stats:
648
+ print(f" • {lang}: {stats['final_samples']:,} samples ({stats['quality_rate']:.1%} quality rate)")
649
+
650
+ print(f"\nDataset files saved to: {output_path or DATASET_OUTPUT_DIR}")
651
+ print("=" * 60)
652
+
653
+ except Exception as e:
654
+ logger.exception("❌ Dataset creation failed")
655
+ raise typer.Exit(1) from e
656
+
657
+
658
+ if __name__ == "__main__":
659
+ typer.run(main)
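A short usage sketch of the module added above, assuming it is importable as `distiller.dataset`; the sample count is kept deliberately small for illustration (the default is 50,000 per language).

```python
# Sketch: build and reload the optimized dataset using the functions defined above.
from distiller.dataset import create_optimized_dataset, load_optimized_dataset

metadata = create_optimized_dataset(max_samples_per_lang=500, create_multiple_formats=False)
print(metadata["train_samples"], metadata["test_samples"])

train_df = load_optimized_dataset(split="train")
# Each row holds a natural-language query, the code, and a combined training
# text formatted as "<query>\n\n```<language>\n<code>\n```".
print(train_df.loc[0, "text"][:200])
```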
src/distiller/distill.py CHANGED
@@ -28,13 +28,14 @@ import time
28
  from pathlib import Path
29
  from typing import Annotated, Any
30
 
 
31
  import torch
32
  import typer
33
  from beam import function
34
- from datasets import load_dataset
35
- from model2vec.distill import distill
36
  from sentence_transformers import SentenceTransformer
37
 
 
 
38
  # Try to import flash_attn to check if it's available
39
  from .beam_utils import (
40
  BeamCheckpointManager,
@@ -145,25 +146,6 @@ def load_model_with_flash_attention(model_path: str, device: str = "auto") -> Se
145
  # =============================================================================
146
 
147
 
148
- def apply_local_patches() -> bool:
149
- """Apply patches locally without requiring Beam utilities."""
150
- try:
151
- try:
152
- from .patch_utils import apply_all_patches
153
-
154
- patches_applied = apply_all_patches()
155
- logger.info(f"Successfully applied {patches_applied} patches via patch_utils")
156
- return True
157
- except ImportError:
158
- logger.warning("patch_utils not available, trying direct patching")
159
-
160
- return False
161
-
162
- except Exception as e:
163
- logger.warning(f"Failed to apply patches: {e}")
164
- return False
165
-
166
-
167
  def get_current_config_hash(enable_training: bool) -> str:
168
  """Generate a hash of current configuration parameters for checkpoint validation."""
169
  import hashlib
@@ -217,22 +199,22 @@ def check_existing_final_model(teacher_name: str, enable_training: bool = False)
217
  model_name = f"code_model2vec_{teacher_name}"
218
  if enable_training:
219
  model_name += "_fine_tuned"
220
- model_dir = final_dir / model_name
221
 
222
- if model_dir.exists():
223
  # Check for essential model files
224
- has_config = (model_dir / "config.json").exists()
225
  has_model_file = any(
226
  [
227
- (model_dir / "model.safetensors").exists(),
228
- (model_dir / "model.bin").exists(),
229
- (model_dir / "pytorch_model.bin").exists(),
230
  ]
231
  )
232
 
233
  if has_config and has_model_file:
234
  logger.info(f"✅ Found existing final model: {teacher_name}{'_fine_tuned' if enable_training else ''}")
235
- return str(model_dir)
236
 
237
  return None
238
 
@@ -427,11 +409,65 @@ def simple_distillation(
427
  return None
428
 
429
 
430
  def load_codesearchnet_dataset(
431
  max_samples: int = 50000,
432
  checkpoint_manager: BeamCheckpointManager | None = None,
433
  ) -> list[str]:
434
  """Load and format the CodeSearchNet dataset for token frequency computation."""
 
 
435
  logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
436
  logger.info(f"Limiting to {max_samples} samples for training efficiency")
437
  logger.info(f"Languages: {', '.join(languages_config.all)}")
@@ -482,6 +518,8 @@ def load_codesearchnet_dataset(
482
 
483
  try:
484
  # Load training split for the specific language (same format as evaluate.py)
 
 
485
  dataset = load_dataset(
486
  codesearchnet_config.dataset_name,
487
  language,
@@ -709,8 +747,33 @@ def compute_token_frequencies_for_sif(
709
  logger.info("📊 Computing token frequencies for SIF weighting...")
710
 
711
  try:
712
- # Load CodeSearchNet dataset to compute frequencies (limited sample for efficiency)
713
- dataset_texts = load_codesearchnet_dataset(max_samples=10000)
 
714
 
715
  logger.info(f"📊 Computing frequencies on {len(dataset_texts)} texts...")
716
 
@@ -763,7 +826,6 @@ def apply_post_training_regularization(
763
  """
764
  import json
765
 
766
- import numpy as np
767
  from sklearn.decomposition import PCA
768
 
769
  logger.info("🔧 Starting post-training re-regularization (POTION Step 4)")
@@ -836,7 +898,7 @@ def apply_post_training_regularization(
836
  final_embeddings = embeddings_pca.astype(np.float32)
837
 
838
  # Create new model with updated embeddings
839
- from model2vec.model import StaticModel
840
 
841
  # Save tokenizer and config from original model
842
  tokenizer = model.tokenizer
@@ -866,7 +928,6 @@ def tokenlearn_training(
866
  3. Tokenlearn training
867
  4. Post-training re-regularization (PCA + SIF weighting)
868
  """
869
- import subprocess
870
  from pathlib import Path
871
 
872
  logger.info("🧪 Starting tokenlearn training (POTION approach)...")
@@ -914,6 +975,9 @@ def tokenlearn_training(
914
 
915
  logger.info(f"📊 Using teacher model: {teacher_model_name}")
916
 
917
  # Check if featurization already completed (checkpoint detection)
918
  featurization_complete_marker = features_dir / ".featurization_complete"
919
  if featurization_complete_marker.exists() and verify_featurization_output(features_dir):
@@ -936,47 +1000,42 @@ def tokenlearn_training(
936
  logger.info(f"📊 Using teacher model: {teacher_model_name}")
937
 
938
  try:
939
- # Use configured dataset for code specialization
940
- featurize_cmd = [
941
- "python",
942
- "-m",
943
- "tokenlearn.featurize",
944
- "--model-name",
945
- str(teacher_model_name),
946
- "--output-dir",
947
- str(features_dir),
948
- "--dataset-path",
949
- str(distillation_config.tokenlearn_dataset),
950
- "--dataset-name",
951
- str(distillation_config.tokenlearn_dataset_name),
952
- "--dataset-split",
953
- "train",
954
- "--key",
955
- str(distillation_config.tokenlearn_text_key), # Use configured text field
956
- "--batch-size",
957
- "1024", # Optimized batch size for A100-40G
958
- ]
959
 
960
  logger.info("🔄 Running tokenlearn featurization...")
961
- logger.info(
962
- f"📊 Dataset: {distillation_config.tokenlearn_dataset} (config: {distillation_config.tokenlearn_dataset_name})"
963
- )
964
- logger.info(f"📝 Text field: {distillation_config.tokenlearn_text_key}")
965
- logger.info(f"Command: {' '.join(featurize_cmd)}")
966
- print(f"\n🔄 Executing: {' '.join(featurize_cmd)}\n")
967
-
968
- result = subprocess.run( # noqa: S603
969
- featurize_cmd,
970
- text=True,
971
- timeout=distillation_config.tokenlearn_timeout_featurize,
972
- check=False,
973
- )
974
 
975
- if result.returncode != 0:
976
- logger.error(f"❌ Featurization failed with return code: {result.returncode}")
977
- logger.error("💥 Tokenlearn featurization is required for training - cannot proceed")
978
- msg = f"Tokenlearn featurization failed with return code: {result.returncode}"
979
- raise RuntimeError(msg)
 
980
 
981
  logger.info("✅ Featurization completed successfully")
982
 
@@ -1025,65 +1084,74 @@ def tokenlearn_training(
1025
  logger.info("🔄 No valid training checkpoint found - starting training...")
1026
 
1027
  try:
1028
- train_cmd = [
1029
- "python",
1030
- "-m",
1031
- "tokenlearn.train",
1032
- "--model-name",
1033
- str(teacher_model_name),
1034
- "--data-path",
1035
- str(features_dir),
1036
- "--save-path",
1037
- str(trained_dir),
1038
- ]
1039
 
1040
- logger.info("🔄 Running tokenlearn training...")
1041
- logger.info(f"Command: {' '.join(train_cmd)}")
1042
- print(f"\n🎓 Executing: {' '.join(train_cmd)}\n")
1043
 
1044
- result = subprocess.run( # noqa: S603
1045
- train_cmd,
1046
- text=True,
1047
- capture_output=True, # Capture stdout and stderr
1048
- timeout=distillation_config.tokenlearn_timeout_train,
1049
- check=False,
1050
- )
1051
 
1052
- if result.returncode != 0:
1053
- logger.error(f"❌ Tokenlearn training failed with return code: {result.returncode}")
1054
 
1055
- # Log the actual error output for debugging
1056
- if result.stderr:
1057
- logger.error(f"stderr: {result.stderr}")
1058
- if result.stdout:
1059
- logger.info(f"stdout: {result.stdout}")
1060
 
1061
- # Check if it's the token-vector mismatch issue
1062
- error_output = str(result.stderr) + str(result.stdout)
1063
- if "Number of tokens" in error_output and "does not match number of vectors" in error_output:
1064
- logger.error("🔧 Token-vector mismatch detected in tokenlearn")
1065
- logger.error("💥 This is a known issue with tokenlearn/Model2Vec integration")
1066
 
1067
  # Create training marker to indicate we tried but failed
1068
  training_fallback_marker = trained_dir / ".training_fallback"
1069
  training_fallback_marker.touch()
1070
 
1071
- logger.error(" Tokenlearn training failed due to token-vector mismatch")
1072
- msg = f"Tokenlearn training failed with token-vector mismatch: {error_output}"
1073
- raise RuntimeError(msg)
1074
- logger.error("💥 Tokenlearn training failed with different error")
1075
- msg = f"Tokenlearn training failed with return code: {result.returncode}"
1076
- raise RuntimeError(msg)
1077
- logger.info("✅ Tokenlearn training completed successfully")
1078
-
1079
- # Create checkpoint marker to indicate training is complete
1080
- training_complete_marker.touch()
1081
- logger.info(f"💾 Created training checkpoint: {training_complete_marker}")
1082
 
1083
  except Exception as e:
1084
- logger.exception("💥 Tokenlearn training failed")
1085
- logger.exception("💥 Tokenlearn training is required - cannot proceed")
1086
- msg = f"Tokenlearn training failed: {e}"
1087
  raise RuntimeError(msg) from e
1088
 
1089
  # Step 4: Load the trained model and apply post-training re-regularization
@@ -1098,7 +1166,7 @@ def tokenlearn_training(
1098
  raise RuntimeError(msg)
1099
 
1100
  try:
1101
- from model2vec.model import StaticModel
1102
 
1103
  # Load the trained model from tokenlearn
1104
  trained_model_path = trained_dir / "model"
@@ -1213,12 +1281,13 @@ def distill_single_teacher(
1213
  existing_final = check_existing_final_model(teacher_name, enable_training)
1214
  if existing_final:
1215
  logger.info(f"✅ Final model already exists: {teacher_name}{'_fine_tuned' if enable_training else ''}")
 
1216
  return {
1217
  "teacher_model": teacher_model,
1218
  "teacher_name": teacher_name,
1219
  "status": "skipped_existing_final",
1220
  "final_path": existing_final,
1221
- "distillation_time": 0.0,
1222
  }
1223
 
1224
  # Step 1.5: Sync existing checkpoints from Beam if using Beam utilities
@@ -1236,7 +1305,7 @@ def distill_single_teacher(
1236
  logger.info(f"✅ Found existing base model: {teacher_name}")
1237
  if enable_training:
1238
  # Load base model for training
1239
- from model2vec.model import StaticModel
1240
 
1241
  base_model = StaticModel.from_pretrained(existing_base)
1242
  elif use_beam_utilities:
@@ -1244,7 +1313,7 @@ def distill_single_teacher(
1244
  if synced:
1245
  existing_base = str(base_dir)
1246
  if enable_training:
1247
- from model2vec.model import StaticModel
1248
 
1249
  base_model = StaticModel.from_pretrained(existing_base)
1250
 
@@ -1263,11 +1332,13 @@ def distill_single_teacher(
1263
  base_model = simple_distillation(teacher_model, str(base_dir), pca_dims)
1264
 
1265
  if base_model is None:
 
1266
  return {
1267
  "teacher_model": teacher_model,
1268
  "teacher_name": teacher_name,
1269
  "status": "failed_base_distillation",
1270
  "error": "Simple distillation failed",
 
1271
  }
1272
 
1273
  # Sync base model and checkpoints to Beam
@@ -1280,71 +1351,74 @@ def distill_single_teacher(
1280
 
1281
  existing_base = str(base_dir)
1282
 
1283
- # Step 3: Handle final model creation
1284
- if enable_training and base_model is not None:
1285
  # Perform tokenlearn training (POTION approach)
1286
- logger.info(f"🧪 Starting tokenlearn training for {teacher_name}")
1287
 
1288
- try:
1289
- # Load teacher model for training
1290
- device = "cuda" if torch.cuda.is_available() else "cpu"
1291
- teacher_st_model = load_model_with_flash_attention(teacher_model, device)
1292
-
1293
- # Perform tokenlearn training (POTION approach)
1294
- final_model = tokenlearn_training(
1295
- base_model,
1296
- teacher_st_model,
1297
- checkpoint_mgr,
1298
- skip_post_training_regularization=distillation_config.skip_post_training_regularization,
1299
- )
1300
 
1301
- # Save final model
1302
- final_dir.mkdir(parents=True, exist_ok=True)
1303
- final_model.save_pretrained(str(final_dir))
1304
 
1305
- # Sync final model and training checkpoints to Beam
1306
- if use_beam_utilities:
1307
- sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities)
1308
- if checkpoint_mgr:
1309
- sync_checkpoints_to_beam(
1310
- VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints
1311
- )
1312
 
1313
  del teacher_st_model
1314
- if torch.cuda.is_available():
1315
- torch.cuda.empty_cache()
1316
-
1317
- except RuntimeError as e:
1318
- # Training failed - clean up and return failure
1319
- logger.exception(f"❌ Training failed for {teacher_name}")
1320
-
1321
- # Clean up teacher model if it was loaded
1322
- if "teacher_st_model" in locals():
1323
- del teacher_st_model
1324
- if torch.cuda.is_available():
1325
- torch.cuda.empty_cache()
1326
-
1327
- return {
1328
- "teacher_model": teacher_model,
1329
- "teacher_name": teacher_name,
1330
- "status": "failed_training",
1331
- "error": f"Training failed: {e!s}",
1332
- "base_path": existing_base, # Base model was created successfully
1333
- }
1334
 
1335
- else:
1336
- # Copy base to final (no training)
1337
- logger.info(f"📁 Copying base to final for {teacher_name}")
1338
- if not copy_base_to_final(teacher_name, enable_training):
1339
- return {
1340
- "teacher_model": teacher_model,
1341
- "teacher_name": teacher_name,
1342
- "status": "failed_copy_to_final",
1343
- "error": "Failed to copy base to final",
1344
- }
1345
 
1346
- total_time = time.time() - start_time
1347
 
 
1348
  return {
1349
  "teacher_model": teacher_model,
1350
  "teacher_name": teacher_name,
@@ -1357,11 +1431,13 @@ def distill_single_teacher(
1357
 
1358
  except Exception as e:
1359
  logger.exception(f"❌ Failed to process {teacher_model}")
 
1360
  return {
1361
  "teacher_model": teacher_model,
1362
  "teacher_name": teacher_name,
1363
  "status": "failed",
1364
  "error": str(e),
 
1365
  }
1366
 
1367
 
@@ -1382,13 +1458,6 @@ def run_local_distillation(
1382
  if teacher_models is None:
1383
  teacher_models = DEFAULT_TEACHER_MODELS
1384
 
1385
- # Apply patches
1386
- patch_success = apply_local_patches()
1387
- if patch_success:
1388
- logger.info("✅ Successfully applied patches")
1389
- else:
1390
- logger.warning("⚠️ Failed to apply patches - some models may fail")
1391
-
1392
  results = {}
1393
  successful_models = []
1394
 
@@ -1468,13 +1537,6 @@ def _beam_distill_internal(
1468
  clear_cache: bool = False,
1469
  ) -> dict[str, Any]:
1470
  """Shared internal implementation for beam distillation."""
1471
- # Apply patches
1472
- patch_success = apply_local_patches()
1473
- if patch_success:
1474
- logger.info("✅ Successfully applied patches")
1475
- else:
1476
- logger.warning("⚠️ Failed to apply patches - some models may fail")
1477
-
1478
  if teacher_models is None:
1479
  teacher_models = DEFAULT_TEACHER_MODELS
1480
 
@@ -1647,6 +1709,16 @@ def main(
1647
  skip_ptr: Annotated[
1648
  bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
1649
  ] = False,
1650
  ) -> None:
1651
  """Unified distillation command with optional training."""
1652
  logger.info("🚀 Starting unified Model2Vec distillation workflow")
@@ -1656,6 +1728,13 @@ def main(
1656
  if skip_ptr and train:
1657
  logger.info("⏭️ Post-training re-regularization will be skipped (PCA + SIF weighting disabled)")
1658
 
1659
  logger.info(f"🎓 Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
1660
  logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")
1661
 
@@ -1894,7 +1973,7 @@ def salesforce_model_distillation(
1894
  logger.info("✅ Successfully loaded with SentenceTransformer method")
1895
 
1896
  # Now use Model2Vec's distill_from_model function directly
1897
- from model2vec.distill.distillation import distill_from_model
1898
 
1899
  distilled_model = distill_from_model(
1900
  model=model,
@@ -2004,7 +2083,7 @@ def baai_bge_model_distillation(
2004
  return None
2005
 
2006
  # Now use Model2Vec's distill_from_model function directly
2007
- from model2vec.distill.distillation import distill_from_model
2008
 
2009
  distilled_model = distill_from_model(
2010
  model=model,
@@ -2090,5 +2169,77 @@ def verify_training_output(trained_dir: Path) -> bool:
2090
  return False
2091
 
2092
 
2093
  if __name__ == "__main__":
2094
  typer.run(main)
 
28
  from pathlib import Path
29
  from typing import Annotated, Any
30
 
31
+ import numpy as np
32
  import torch
33
  import typer
34
  from beam import function
 
 
35
  from sentence_transformers import SentenceTransformer
36
 
37
+ from distiller.model2vec.distill import distill
38
+
39
  # Try to import flash_attn to check if it's available
40
  from .beam_utils import (
41
  BeamCheckpointManager,
 
146
  # =============================================================================
147
 
148
 
149
  def get_current_config_hash(enable_training: bool) -> str:
150
  """Generate a hash of current configuration parameters for checkpoint validation."""
151
  import hashlib
 
199
  model_name = f"code_model2vec_{teacher_name}"
200
  if enable_training:
201
  model_name += "_fine_tuned"
202
+ final_path = final_dir / model_name
203
 
204
+ if final_path.exists():
205
  # Check for essential model files
206
+ has_config = (final_path / "config.json").exists()
207
  has_model_file = any(
208
  [
209
+ (final_path / "model.safetensors").exists(),
210
+ (final_path / "model.bin").exists(),
211
+ (final_path / "pytorch_model.bin").exists(),
212
  ]
213
  )
214
 
215
  if has_config and has_model_file:
216
  logger.info(f"✅ Found existing final model: {teacher_name}{'_fine_tuned' if enable_training else ''}")
217
+ return str(final_path)
218
 
219
  return None
220
 
 
409
  return None
410
 
411
 
412
+ def load_optimized_dataset(
413
+ max_samples: int = 50000,
414
+ checkpoint_manager: BeamCheckpointManager | None = None,
415
+ dataset_path: str | None = None,
416
+ ) -> list[str]:
417
+ """Load our pre-created optimized dataset for tokenlearn training."""
418
+ from .dataset import DATASET_OUTPUT_DIR
419
+ from .dataset import load_optimized_dataset as load_dataset_func
420
+
421
+ # Use configuration if not provided as parameter
422
+ if dataset_path is None:
423
+ dataset_path = distillation_config.custom_dataset_path
424
+
425
+ dataset_dir = Path(dataset_path) if dataset_path else DATASET_OUTPUT_DIR
426
+
427
+ logger.info(f"🎯 Loading optimized dataset from {dataset_dir}")
428
+ logger.info(f"📊 Target samples: {max_samples}")
429
+
430
+ try:
431
+ # Load the training split of our optimized dataset
432
+ df = load_dataset_func(output_dir=dataset_dir, split="train")
433
+
434
+ # Extract the text column (which contains our formatted query + code)
435
+ texts = df["text"].tolist()
436
+
437
+ # Shuffle for better training distribution
438
+ import random
439
+
440
+ random.seed(42)
441
+ random.shuffle(texts)
442
+
443
+ # Limit to max_samples
444
+ if len(texts) > max_samples:
445
+ texts = texts[:max_samples]
446
+
447
+ logger.info(f"✅ Loaded {len(texts)} optimized training samples")
448
+
449
+ # Log language distribution
450
+ languages = df["language"].value_counts()
451
+ logger.info("📊 Language distribution:")
452
+ for lang, count in languages.items():
453
+ percentage = (count / len(df)) * 100
454
+ logger.info(f" {lang}: {count} samples ({percentage:.1f}%)")
455
+
456
+ return texts
457
+
458
+ except Exception as e:
459
+ logger.warning(f"⚠️ Failed to load optimized dataset: {e}")
460
+ logger.info("🔄 Falling back to original CodeSearchNet loading...")
461
+ return load_codesearchnet_dataset(max_samples, checkpoint_manager)
462
+
463
+
464
  def load_codesearchnet_dataset(
465
  max_samples: int = 50000,
466
  checkpoint_manager: BeamCheckpointManager | None = None,
467
  ) -> list[str]:
468
  """Load and format the CodeSearchNet dataset for token frequency computation."""
469
+ from datasets import load_dataset
470
+
471
  logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
472
  logger.info(f"Limiting to {max_samples} samples for training efficiency")
473
  logger.info(f"Languages: {', '.join(languages_config.all)}")
 
518
 
519
  try:
520
  # Load training split for the specific language (same format as evaluate.py)
521
+ from datasets import load_dataset
522
+
523
  dataset = load_dataset(
524
  codesearchnet_config.dataset_name,
525
  language,
 
747
  logger.info("📊 Computing token frequencies for SIF weighting...")
748
 
749
  try:
750
+ # Load dataset to compute frequencies (limited sample for efficiency)
751
+ if distillation_config.use_optimized_dataset:
752
+ # Use the custom optimized dataset
753
+ from .dataset import load_optimized_dataset as load_custom_dataset
754
+
755
+ custom_dataset_dir = (
756
+ Path(distillation_config.custom_dataset_path)
757
+ if distillation_config.custom_dataset_path
758
+ else Path("code_model2vec/dataset")
759
+ )
760
+
761
+ if custom_dataset_dir.exists() and (custom_dataset_dir / "train.parquet").exists():
762
+ train_df = load_custom_dataset(output_dir=custom_dataset_dir, split="train")
763
+ # Sample a subset for frequency computation
764
+ sample_size = min(10000, len(train_df))
765
+ train_df_sample = train_df.sample(n=sample_size, random_state=42)
766
+ dataset_texts = train_df_sample["text"].tolist()
767
+ logger.info(f"📊 Using {len(dataset_texts)} samples from custom optimized dataset")
768
+ else:
769
+ # Fallback to original dataset loading
770
+ dataset_texts = load_codesearchnet_dataset(max_samples=10000)
771
+ logger.info(
772
+ f"📊 Custom dataset not found, using original CodeSearchNet with {len(dataset_texts)} texts"
773
+ )
774
+ else:
775
+ dataset_texts = load_codesearchnet_dataset(max_samples=10000)
776
+ logger.info(f"📊 Using original CodeSearchNet with {len(dataset_texts)} texts")
777
 
778
  logger.info(f"📊 Computing frequencies on {len(dataset_texts)} texts...")
779
 
 
826
  """
827
  import json
828
 
 
829
  from sklearn.decomposition import PCA
830
 
831
  logger.info("🔧 Starting post-training re-regularization (POTION Step 4)")
 
898
  final_embeddings = embeddings_pca.astype(np.float32)
899
 
900
  # Create new model with updated embeddings
901
+ from distiller.model2vec.model import StaticModel
902
 
903
  # Save tokenizer and config from original model
904
  tokenizer = model.tokenizer
 
928
  3. Tokenlearn training
929
  4. Post-training re-regularization (PCA + SIF weighting)
930
  """
 
931
  from pathlib import Path
932
 
933
  logger.info("🧪 Starting tokenlearn training (POTION approach)...")
 
975
 
976
  logger.info(f"📊 Using teacher model: {teacher_model_name}")
977
 
978
+ # Prepare dataset for tokenlearn featurization
979
+ dataset_path, dataset_name, text_key = _prepare_tokenlearn_dataset(persistent_tokenlearn_dir)
980
+
981
  # Check if featurization already completed (checkpoint detection)
982
  featurization_complete_marker = features_dir / ".featurization_complete"
983
  if featurization_complete_marker.exists() and verify_featurization_output(features_dir):
 
1000
  logger.info(f"📊 Using teacher model: {teacher_model_name}")
1001
 
1002
  try:
1003
+ # Use direct function call instead of subprocess
1004
+ from datasets import load_dataset
1005
+
1006
+ from distiller.tokenlearn.featurize import featurize
1007
 
1008
  logger.info("🔄 Running tokenlearn featurization...")
1009
+ logger.info(f"📊 Dataset: {dataset_path} (config: {dataset_name})")
1010
+ logger.info(f"📝 Text field: {text_key}")
1011
 
1012
+ # Load the dataset
1013
+ if dataset_name is None:
1014
+ # For local JSON files, don't pass name parameter
1015
+ dataset = load_dataset(
1016
+ "json",
1017
+ data_files=dataset_path,
1018
+ split="train",
1019
+ streaming=True,
1020
+ )
1021
+ else:
1022
+ # For remote datasets with specific configurations
1023
+ dataset = load_dataset(
1024
+ dataset_path,
1025
+ name=dataset_name,
1026
+ split="train",
1027
+ streaming=True,
1028
+ )
1029
+
1030
+ # Call featurization function directly
1031
+ featurize(
1032
+ dataset=iter(dataset),
1033
+ model=teacher_model,
1034
+ output_dir=str(features_dir),
1035
+ max_means=50000, # IMPROVEMENT: Limit means to prevent overfitting
1036
+ batch_size=512, # IMPROVEMENT: Smaller batch for better gradients
1037
+ text_key=text_key,
1038
+ )
1039
 
1040
  logger.info("✅ Featurization completed successfully")
1041
 
 
1084
  logger.info("🔄 No valid training checkpoint found - starting training...")
1085
 
1086
  try:
1087
+ # Use direct function call instead of subprocess
1088
+ from distiller.tokenlearn.train import train_model
1089
+ from distiller.tokenlearn.utils import collect_means_and_texts
1090
 
1091
+ # IMPROVED APPROACH: Try optimized parameters first
1092
+ logger.info("🚀 Attempting IMPROVED tokenlearn training with optimized parameters...")
1093
+ logger.info("📊 Using smaller vocabulary and conservative PCA to prevent overfitting")
1094
 
1095
+ # Collect training data from features directory
1096
+ paths = sorted(features_dir.glob("*.json"))
1097
+ train_txt, train_vec = collect_means_and_texts(paths)
1098
+
1099
+ logger.info(f"📊 Collected {len(train_txt)} texts and {train_vec.shape[0]} vectors for training")
1100
+
1101
+ try:
1102
+ # Try improved parameters first
1103
+ trained_model = train_model(
1104
+ model_name=str(teacher_model_name),
1105
+ train_txt=train_txt,
1106
+ train_vec=train_vec,
1107
+ device="cuda" if torch.cuda.is_available() else "cpu",
1108
+ vocab_size=25000, # IMPROVEMENT: Smaller vocabulary to prevent overfitting
1109
+ pca_dims=256, # IMPROVEMENT: Conservative PCA dimensions
1110
+ )
1111
 
1112
+ # Save the trained model
1113
+ trained_model.save_pretrained(str(trained_dir))
1114
+ logger.info("✅ IMPROVED tokenlearn training completed successfully")
1115
+ training_complete_marker.touch()
1116
+ logger.info(f"💾 Created improved training checkpoint: {training_complete_marker}")
1117
 
1118
+ except Exception as e:
1119
+ logger.warning(f"⚠️ Improved training failed: {e}")
1120
+ logger.info("🔄 Falling back to CONSERVATIVE tokenlearn training...")
 
 
1121
 
1122
+ # FALLBACK: Ultra-conservative training approach
1123
+ try:
1124
+ trained_model = train_model(
1125
+ model_name=str(teacher_model_name),
1126
+ train_txt=train_txt,
1127
+ train_vec=train_vec,
1128
+ device="cuda" if torch.cuda.is_available() else "cpu",
1129
+ vocab_size=15000, # FALLBACK: Even smaller vocabulary
1130
+ pca_dims=128, # FALLBACK: Smaller PCA dimensions
1131
+ )
1132
+
1133
+ # Save the trained model
1134
+ trained_model.save_pretrained(str(trained_dir))
1135
+ logger.info("✅ Conservative tokenlearn training completed successfully")
1136
+ training_complete_marker.touch()
1137
+ logger.info(f"💾 Created conservative training checkpoint: {training_complete_marker}")
1138
+
1139
+ except Exception as e2:
1140
+ logger.exception("❌ Conservative tokenlearn training also failed")
1141
+ logger.exception("💥 All training approaches failed - check output above for details")
1142
 
1143
  # Create training marker to indicate we tried but failed
1144
  training_fallback_marker = trained_dir / ".training_fallback"
1145
  training_fallback_marker.touch()
1146
 
1147
+ logger.exception("💥 Tokenlearn training failed completely")
1148
+ msg = f"All tokenlearn training approaches failed: {e2}"
1149
+ raise RuntimeError(msg) from e2
1150
 
1151
  except Exception as e:
1152
+ logger.warning("💥 All tokenlearn training approaches failed")
1153
+ logger.exception("💥 All training approaches failed completely - cannot proceed")
1154
+ msg = f"All training approaches failed: {e}"
1155
  raise RuntimeError(msg) from e
1156
 
1157
  # Step 4: Load the trained model and apply post-training re-regularization
 
1166
  raise RuntimeError(msg)
1167
 
1168
  try:
1169
+ from distiller.model2vec.model import StaticModel
1170
 
1171
  # Load the trained model from tokenlearn
1172
  trained_model_path = trained_dir / "model"
 
1281
  existing_final = check_existing_final_model(teacher_name, enable_training)
1282
  if existing_final:
1283
  logger.info(f"✅ Final model already exists: {teacher_name}{'_fine_tuned' if enable_training else ''}")
1284
+ total_time = time.time() - start_time
1285
  return {
1286
  "teacher_model": teacher_model,
1287
  "teacher_name": teacher_name,
1288
  "status": "skipped_existing_final",
1289
  "final_path": existing_final,
1290
+ "distillation_time": total_time,
1291
  }
1292
 
1293
  # Step 1.5: Sync existing checkpoints from Beam if using Beam utilities
 
1305
  logger.info(f"✅ Found existing base model: {teacher_name}")
1306
  if enable_training:
1307
  # Load base model for training
1308
+ from distiller.model2vec.model import StaticModel
1309
 
1310
  base_model = StaticModel.from_pretrained(existing_base)
1311
  elif use_beam_utilities:
 
1313
  if synced:
1314
  existing_base = str(base_dir)
1315
  if enable_training:
1316
+ from distiller.model2vec.model import StaticModel
1317
 
1318
  base_model = StaticModel.from_pretrained(existing_base)
1319
 
 
1332
  base_model = simple_distillation(teacher_model, str(base_dir), pca_dims)
1333
 
1334
  if base_model is None:
1335
+ total_time = time.time() - start_time
1336
  return {
1337
  "teacher_model": teacher_model,
1338
  "teacher_name": teacher_name,
1339
  "status": "failed_base_distillation",
1340
  "error": "Simple distillation failed",
1341
+ "distillation_time": total_time,
1342
  }
1343
 
1344
  # Sync base model and checkpoints to Beam
 
1351
 
1352
  existing_base = str(base_dir)
1353
 
1354
+ # Step 3: Handle final model creation
1355
+ if enable_training and base_model is not None:
1356
+ # Perform tokenlearn training (POTION approach)
1357
+ logger.info(f"🧪 Starting tokenlearn training for {teacher_name}")
1358
+
1359
+ try:
1360
+ # Load teacher model for training
1361
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1362
+ teacher_st_model = load_model_with_flash_attention(teacher_model, device)
1363
+
1364
  # Perform tokenlearn training (POTION approach)
1365
+ final_model = tokenlearn_training(
1366
+ base_model,
1367
+ teacher_st_model,
1368
+ checkpoint_mgr,
1369
+ skip_post_training_regularization=distillation_config.skip_post_training_regularization,
1370
+ )
1371
 
1372
+ # Save final model
1373
+ final_dir.mkdir(parents=True, exist_ok=True)
1374
+ final_model.save_pretrained(str(final_dir))
1375
 
1376
+ # Sync final model and training checkpoints to Beam
1377
+ if use_beam_utilities:
1378
+ sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities)
1379
+ if checkpoint_mgr:
1380
+ sync_checkpoints_to_beam(
1381
+ VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints
1382
+ )
1383
 
1384
+ del teacher_st_model
1385
+ if torch.cuda.is_available():
1386
+ torch.cuda.empty_cache()
1387
 
1388
+ except RuntimeError as e:
1389
+ # Training failed - clean up and return failure
1390
+ logger.exception(f"❌ Training failed for {teacher_name}")
1391
+
1392
+ # Clean up teacher model if it was loaded
1393
+ if "teacher_st_model" in locals():
1394
  del teacher_st_model
1395
+ if torch.cuda.is_available():
1396
+ torch.cuda.empty_cache()
1397
 
1398
+ total_time = time.time() - start_time
1399
+ return {
1400
+ "teacher_model": teacher_model,
1401
+ "teacher_name": teacher_name,
1402
+ "status": "failed_training",
1403
+ "error": f"Training failed: {e!s}",
1404
+ "base_path": existing_base, # Base model was created successfully
1405
+ "distillation_time": total_time,
1406
+ }
 
1407
 
1408
+ else:
1409
+ # Copy base to final (no training)
1410
+ logger.info(f"📁 Copying base to final for {teacher_name}")
1411
+ if not copy_base_to_final(teacher_name, enable_training):
1412
+ total_time = time.time() - start_time
1413
+ return {
1414
+ "teacher_model": teacher_model,
1415
+ "teacher_name": teacher_name,
1416
+ "status": "failed_copy_to_final",
1417
+ "error": "Failed to copy base to final",
1418
+ "distillation_time": total_time,
1419
+ }
1420
 
1421
+ total_time = time.time() - start_time
1422
  return {
1423
  "teacher_model": teacher_model,
1424
  "teacher_name": teacher_name,
 
1431
 
1432
  except Exception as e:
1433
  logger.exception(f"❌ Failed to process {teacher_model}")
1434
+ total_time = time.time() - start_time
1435
  return {
1436
  "teacher_model": teacher_model,
1437
  "teacher_name": teacher_name,
1438
  "status": "failed",
1439
  "error": str(e),
1440
+ "distillation_time": total_time,
1441
  }
1442
 
1443
 
 
1458
  if teacher_models is None:
1459
  teacher_models = DEFAULT_TEACHER_MODELS
1460
 
 
1461
  results = {}
1462
  successful_models = []
1463
 
 
1537
  clear_cache: bool = False,
1538
  ) -> dict[str, Any]:
1539
  """Shared internal implementation for beam distillation."""
1540
  if teacher_models is None:
1541
  teacher_models = DEFAULT_TEACHER_MODELS
1542
 
 
1709
  skip_ptr: Annotated[
1710
  bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
1711
  ] = False,
1712
+ use_optimized_dataset: Annotated[
1713
+ bool,
1714
+ typer.Option(
1715
+ "--use-optimized-dataset", help="Use the pre-created optimized dataset from code_model2vec/dataset"
1716
+ ),
1717
+ ] = False,
1718
+ dataset_path: Annotated[
1719
+ str | None,
1720
+ typer.Option("--dataset-path", help="Path to custom dataset directory (defaults to code_model2vec/dataset)"),
1721
+ ] = None,
1722
  ) -> None:
1723
  """Unified distillation command with optional training."""
1724
  logger.info("🚀 Starting unified Model2Vec distillation workflow")
 
1728
  if skip_ptr and train:
1729
  logger.info("⏭️ Post-training re-regularization will be skipped (PCA + SIF weighting disabled)")
1730
 
1731
+ # Set dataset configuration
1732
+ distillation_config.use_optimized_dataset = use_optimized_dataset
1733
+ distillation_config.custom_dataset_path = dataset_path
1734
+ if use_optimized_dataset and train:
1735
+ dataset_source = dataset_path or "code_model2vec/dataset"
1736
+ logger.info(f"🎯 Using optimized dataset from: {dataset_source}")
1737
+
1738
  logger.info(f"🎓 Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
1739
  logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")
1740
 
 
1973
  logger.info("✅ Successfully loaded with SentenceTransformer method")
1974
 
1975
  # Now use Model2Vec's distill_from_model function directly
1976
+ from distiller.model2vec.distill.distillation import distill_from_model
1977
 
1978
  distilled_model = distill_from_model(
1979
  model=model,
 
2083
  return None
2084
 
2085
  # Now use Model2Vec's distill_from_model function directly
2086
+ from distiller.model2vec.distill.distillation import distill_from_model
2087
 
2088
  distilled_model = distill_from_model(
2089
  model=model,
 
2169
  return False
2170
 
2171
 
2172
+ def _prepare_tokenlearn_dataset(tokenlearn_dir: Path) -> tuple[str, str | None, str]:
2173
+ """
2174
+ Prepare dataset for tokenlearn featurization.
2175
+
2176
+ Returns:
2177
+ Tuple of (dataset_path, dataset_name, text_key) for tokenlearn
2178
+ """
2179
+ if distillation_config.use_optimized_dataset:
2180
+ return _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir)
2181
+ return _prepare_original_dataset_for_tokenlearn()
2182
+
2183
+
2184
+ def _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir: Path) -> tuple[str, str | None, str]:
2185
+ """Prepare custom optimized dataset for tokenlearn featurization."""
2186
+ logger.info("🎯 Preparing custom optimized dataset for tokenlearn...")
2187
+
2188
+ # Import the dataset module
2189
+ from .dataset import create_optimized_dataset, load_optimized_dataset
2190
+
2191
+ # Define paths
2192
+ custom_dataset_dir = (
2193
+ Path(distillation_config.custom_dataset_path)
2194
+ if distillation_config.custom_dataset_path
2195
+ else Path("code_model2vec/dataset")
2196
+ )
2197
+ tokenlearn_dataset_dir = tokenlearn_dir / "custom_dataset"
2198
+
2199
+ # Check if we need to create the custom dataset
2200
+ if not custom_dataset_dir.exists() or not (custom_dataset_dir / "train.parquet").exists():
2201
+ logger.info("📊 Custom dataset not found - creating optimized dataset...")
2202
+ create_optimized_dataset(
2203
+ max_samples_per_lang=10000, # Reasonable size for tokenlearn
2204
+ output_dir=custom_dataset_dir,
2205
+ create_multiple_formats=False, # Use simple format for tokenlearn
2206
+ )
2207
+
2208
+ # Load the custom dataset
2209
+ logger.info(f"📂 Loading custom dataset from {custom_dataset_dir}")
2210
+ train_df = load_optimized_dataset(output_dir=custom_dataset_dir, split="train")
2211
+
2212
+ # Prepare dataset for tokenlearn (save as JSON files that load_dataset can read)
2213
+ tokenlearn_dataset_dir.mkdir(parents=True, exist_ok=True)
2214
+
2215
+ # Save as JSON file that tokenlearn can load with load_dataset()
2216
+ train_json_path = tokenlearn_dataset_dir / "train.json"
2217
+
2218
+ # Create JSON lines format
2219
+ import json
2220
+
2221
+ with train_json_path.open("w") as f:
2222
+ for text in train_df["text"]:
2223
+ json.dump({"text": text}, f)
2224
+ f.write("\n")
2225
+
2226
+ logger.info(f"✅ Prepared custom dataset with {len(train_df)} samples for tokenlearn")
2227
+ logger.info(f"💾 Saved JSON dataset to {train_json_path}")
2228
+
2229
+ # Return the JSON file path directly (not directory) and no config name for JSON loading
2230
+ return str(train_json_path), None, "text"
2231
+
2232
+
2233
+ def _prepare_original_dataset_for_tokenlearn() -> tuple[str, str, str]:
2234
+ """Prepare original CodeSearchNet dataset for tokenlearn featurization."""
2235
+ logger.info("📊 Using original CodeSearchNet dataset for tokenlearn...")
2236
+
2237
+ return (
2238
+ str(distillation_config.tokenlearn_dataset), # "sentence-transformers/codesearchnet"
2239
+ str(distillation_config.tokenlearn_dataset_name), # "pair"
2240
+ str(distillation_config.tokenlearn_text_key), # "combined_text"
2241
+ )
2242
+
2243
+
2244
  if __name__ == "__main__":
2245
  typer.run(main)
src/distiller/patch_utils.py DELETED
@@ -1,276 +0,0 @@
1
- """
2
- Patch utilities for applying fixes to installed packages.
3
-
4
- This module provides functionality to automatically apply all patches
5
- from the patches directory to fix bugs in third-party libraries.
6
- """
7
-
8
- import logging
9
- import subprocess
10
- import sys
11
- from pathlib import Path
12
-
13
- logger = logging.getLogger(__name__)
14
-
15
-
16
- def find_patches_directory() -> Path:
17
- """Find the patches directory relative to the current script location."""
18
- # Go up from src/distiller/ to project root, then to patches/
19
- current_file = Path(__file__)
20
- project_root = current_file.parent.parent.parent # Go up 3 levels: distiller -> src -> project_root
21
- patches_dir = project_root / "patches"
22
-
23
- if not patches_dir.exists():
24
- # Alternative: try relative to current working directory
25
- patches_dir = Path("patches")
26
-
27
- return patches_dir
28
-
29
-
30
- def get_site_packages_path() -> Path:
31
- """Get the site-packages directory path."""
32
- import site
33
-
34
- # Try to get the site-packages from the current environment
35
- site_packages_dirs = site.getsitepackages()
36
-
37
- # Prefer the first site-packages directory
38
- if site_packages_dirs:
39
- return Path(site_packages_dirs[0])
40
-
41
- # Fallback: try to find it relative to Python executable
42
- python_path = Path(sys.executable)
43
- if python_path.name == "python" or python_path.name.startswith("python"):
44
- # Standard virtual environment structure
45
- venv_lib = python_path.parent.parent / "lib"
46
- for item in venv_lib.iterdir():
47
- if item.name.startswith("python"):
48
- site_packages = item / "site-packages"
49
- if site_packages.exists():
50
- return site_packages
51
-
52
- # Last resort: use current directory
53
- return Path()
54
-
55
-
56
- def apply_patch_file(patch_file: Path, target_dir: Path) -> bool:
57
- """
58
- Apply a single patch file to the target directory.
59
-
60
- Args:
61
- patch_file: Path to the .patch file
62
- target_dir: Target directory (usually site-packages)
63
-
64
- Returns:
65
- True if patch was applied successfully, False otherwise
66
- """
67
- try:
68
- logger.info(f"Applying patch: {patch_file.name}")
69
-
70
- # Check if patch is already applied
71
- if is_patch_already_applied(patch_file, target_dir):
72
- logger.info(f"Patch {patch_file.name} already applied")
73
- return True
74
-
75
- # Clean any duplicate validation code before applying
76
- if "model2vec.patch" in patch_file.name:
77
- clean_duplicate_validation_code(target_dir)
78
-
79
- # Use patch command with the following options:
80
- # -p1: strip 1 leading directory from paths
81
- # -d: change to directory before applying
82
- # -f: force (don't ask questions)
83
- # -N: don't reverse patches that appear to be already applied
84
- result = subprocess.run( # noqa: S603
85
- ["patch", "-p1", "-d", str(target_dir), "-f", "-N"], # noqa: S607
86
- input=patch_file.read_text(),
87
- text=True,
88
- capture_output=True,
89
- check=False, # Don't raise exception on non-zero exit
90
- )
91
-
92
- if result.returncode == 0:
93
- logger.info(f"Successfully applied patch: {patch_file.name}")
94
- return True
95
- if "already applied" in result.stderr.lower() or "reversed" in result.stderr.lower():
96
- logger.info(f"Patch {patch_file.name} already applied")
97
- return True
98
- logger.warning(f"Failed to apply patch {patch_file.name}: {result.stderr}")
99
- return False
100
-
101
- except FileNotFoundError:
102
- logger.exception("'patch' command not found. Please install patch utility.")
103
- return False
104
- except Exception:
105
- logger.exception(f"Error applying patch {patch_file.name}")
106
- return False
107
-
108
-
109
- def apply_all_patches() -> int:
110
- """
111
- Apply all patches from the patches directory.
112
-
113
- Returns:
114
- Number of patches successfully applied
115
- """
116
- patches_dir = find_patches_directory()
117
-
118
- if not patches_dir.exists():
119
- logger.warning(f"Patches directory not found: {patches_dir}")
120
- return 0
121
-
122
- # Find all .patch files
123
- patch_files = list(patches_dir.glob("*.patch"))
124
-
125
- if not patch_files:
126
- logger.info("No patch files found")
127
- return 0
128
-
129
- # Get target directory (site-packages)
130
- target_dir = get_site_packages_path()
131
- logger.info(f"Applying patches to: {target_dir}")
132
-
133
- # Clean any existing duplicates first
134
- clean_duplicate_validation_code(target_dir)
135
-
136
- success_count = 0
137
-
138
- # Sort patch files for consistent ordering
139
- for patch_file in sorted(patch_files):
140
- if apply_patch_file(patch_file, target_dir):
141
- success_count += 1
142
-
143
- logger.info(f"Applied {success_count}/{len(patch_files)} patches successfully")
144
- return success_count
145
-
146
-
147
- def is_patch_already_applied(patch_file: Path, target_dir: Path) -> bool:
148
- """
149
- Check if a patch has already been applied by looking for specific markers.
150
-
151
- Args:
152
- patch_file: Path to the .patch file
153
- target_dir: Target directory (usually site-packages)
154
-
155
- Returns:
156
- True if patch appears to be already applied, False otherwise
157
- """
158
- try:
159
- # For model2vec.patch, check if the validation code is already present
160
- if "model2vec.patch" in patch_file.name:
161
- inference_file = target_dir / "model2vec" / "distill" / "inference.py"
162
- if inference_file.exists():
163
- inference_content = inference_file.read_text()
164
- # Check for the specific validation code we're adding
165
- if (
166
- "Token-vector mismatch:" in inference_content
167
- and "Truncating to prevent failure" in inference_content
168
- ):
169
- # Also make sure it's in the right place (before return statement, not after)
170
- lines = inference_content.split("\n")
171
- for i, line in enumerate(lines):
172
- if "return out_tokens, out_weights" in line:
173
- # Check if validation code appears before this return
174
- preceding_lines = lines[max(0, i - 10) : i]
175
- if any("Token-vector mismatch:" in pline for pline in preceding_lines):
176
- return True
177
- break
178
-
179
- # For tokenlearn.patch, check if the indexing fix is already present
180
- if "tokenlearn.patch" in patch_file.name:
181
- pretrain_file = target_dir / "tokenlearn" / "pretrain.py"
182
- if pretrain_file.exists():
183
- pretrain_content = pretrain_file.read_text()
184
- # Check for the specific fix we're adding
185
- if (
186
- "Fix for index out of bounds issue" in pretrain_content
187
- and "torch.clamp(input_ids, 0, self.w.shape[0] - 1)" in pretrain_content
188
- ):
189
- return True
190
-
191
- return False
192
-
193
- except Exception as e:
194
- logger.warning(f"Error checking if patch {patch_file.name} is applied: {e}")
195
- return False
196
-
197
-
198
- def clean_duplicate_validation_code(target_dir: Path) -> bool:
199
- """
200
- Clean up duplicate validation code that might have been added by multiple patch applications.
201
-
202
- Args:
203
- target_dir: Target directory (usually site-packages)
204
-
205
- Returns:
206
- True if cleanup was successful, False otherwise
207
- """
208
- try:
209
- inference_file = target_dir / "model2vec" / "distill" / "inference.py"
210
- if not inference_file.exists():
211
- return True
212
-
213
- content = inference_file.read_text()
214
- lines = content.split("\n")
215
-
216
- # Find all instances of the validation code
217
- validation_indices = []
218
- for i, line in enumerate(lines):
219
- if "Token-vector mismatch:" in line:
220
- validation_indices.append(i)
221
-
222
- if len(validation_indices) <= 1:
223
- return True # No duplicates or no validation code
224
-
225
- # Keep only the validation code that appears before a return statement
226
- lines_to_keep = []
227
- skip_until = -1
228
-
229
- for i, line in enumerate(lines):
230
- if i <= skip_until:
231
- continue
232
-
233
- # If this is validation code
234
- if "Token-vector mismatch:" in line:
235
- # Look ahead to see if there's a return statement nearby
236
- has_return_after = False
237
- for j in range(i, min(len(lines), i + 20)):
238
- if "return out_tokens, out_weights" in lines[j]:
239
- has_return_after = True
240
- break
241
-
242
- # Keep this validation block only if it's followed by a return
243
- if has_return_after:
244
- lines_to_keep.append(line)
245
- else:
246
- # Skip this validation block (it's a duplicate)
247
- # Find the end of this validation block
248
- for j in range(i + 1, len(lines)):
249
- if lines[j].strip() == "" or not lines[j].startswith(" "):
250
- skip_until = j - 1
251
- break
252
- else:
253
- lines_to_keep.append(line)
254
-
255
- # Write back the cleaned content
256
- cleaned_content = "\n".join(lines_to_keep)
257
- inference_file.write_text(cleaned_content)
258
- logger.info("Cleaned duplicate validation code from inference.py")
259
- return True
260
-
261
- except Exception as e:
262
- logger.warning(f"Error cleaning duplicate validation code: {e}")
263
- return False
264
-
265
-
266
- def main() -> None:
267
- """Main function for standalone execution."""
268
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
269
-
270
- print("Applying all patches...")
271
- success_count = apply_all_patches()
272
- print(f"Done. Applied {success_count} patches.")
273
-
274
-
275
- if __name__ == "__main__":
276
- main()
uv.lock CHANGED
@@ -774,24 +774,31 @@ dependencies = [
774
  { name = "flash-attn" },
775
  { name = "hatchling" },
776
  { name = "iso639" },
 
 
777
  { name = "kaleido" },
778
  { name = "lightning" },
779
  { name = "matplotlib" },
780
- { name = "model2vec", extra = ["train"] },
781
  { name = "mteb" },
782
  { name = "numpy" },
783
  { name = "plotly" },
784
  { name = "psutil" },
785
  { name = "pydantic" },
786
  { name = "requests" },
 
 
787
  { name = "scikit-learn" },
788
  { name = "seaborn" },
789
  { name = "sentence-transformers" },
790
  { name = "setuptools" },
 
791
  { name = "smart-open", extra = ["s3"] },
792
  { name = "statsmodels" },
793
- { name = "tokenlearn" },
794
  { name = "torch" },
 
 
795
  { name = "typer" },
796
  ]
797
 
@@ -813,24 +820,31 @@ requires-dist = [
813
  { name = "flash-attn", specifier = ">=2.7.4.post1" },
814
  { name = "hatchling", specifier = ">=1.27.0" },
815
  { name = "iso639", specifier = ">=0.1.4" },
 
 
816
  { name = "kaleido", specifier = "==1.0.0rc13" },
817
  { name = "lightning", specifier = ">=2.5.1.post0" },
818
  { name = "matplotlib", specifier = ">=3.10.3" },
819
- { name = "model2vec", extras = ["train"], specifier = ">=0.5.0" },
820
  { name = "mteb", specifier = ">=1.14.15" },
821
  { name = "numpy", specifier = ">=1.26.4" },
822
  { name = "plotly", specifier = ">=6.1.1" },
823
  { name = "psutil", specifier = ">=7.0.0" },
824
  { name = "pydantic", specifier = ">=2.11.5" },
825
  { name = "requests", specifier = ">=2.32.3" },
 
 
826
  { name = "scikit-learn", specifier = ">=1.6.1" },
827
  { name = "seaborn", specifier = ">=0.13.2" },
828
  { name = "sentence-transformers", specifier = ">=4.1.0" },
829
  { name = "setuptools", specifier = ">=80.8.0" },
 
830
  { name = "smart-open", extras = ["s3"], specifier = ">=7.1.0" },
831
  { name = "statsmodels", specifier = ">=0.14.4" },
832
- { name = "tokenlearn", specifier = ">=0.2.0" },
833
  { name = "torch", specifier = ">=2.7.0" },
 
 
834
  { name = "typer", specifier = ">=0.16.0" },
835
  ]
836
 
@@ -1187,38 +1201,6 @@ wheels = [
1187
  { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
1188
  ]
1189
 
1190
- [[package]]
1191
- name = "model2vec"
1192
- version = "0.5.0"
1193
- source = { registry = "https://pypi.org/simple" }
1194
- dependencies = [
1195
- { name = "jinja2" },
1196
- { name = "joblib" },
1197
- { name = "numpy" },
1198
- { name = "rich" },
1199
- { name = "safetensors" },
1200
- { name = "setuptools" },
1201
- { name = "tokenizers" },
1202
- { name = "tqdm" },
1203
- ]
1204
- sdist = { url = "https://files.pythonhosted.org/packages/93/18/c546916657e47e52b6e25b231803903bcf4e7ef2497fe41e9869236d7dee/model2vec-0.5.0.tar.gz", hash = "sha256:0771fd99d5c58fac631a2faa233759a8cec7a3be6e9aeeeeeca2d5e7048d1c7b", size = 2665840 }
1205
- wheels = [
1206
- { url = "https://files.pythonhosted.org/packages/66/ab/5263bc4605e9960fece76b710c01fef33859dc6ae72832d5987db75eed63/model2vec-0.5.0-py3-none-any.whl", hash = "sha256:12f14a18556975c037961a836a702388876bfec1ff76176f056884d219735271", size = 44578 },
1207
- ]
1208
-
1209
- [package.optional-dependencies]
1210
- distill = [
1211
- { name = "scikit-learn" },
1212
- { name = "torch" },
1213
- { name = "transformers" },
1214
- ]
1215
- train = [
1216
- { name = "lightning" },
1217
- { name = "scikit-learn" },
1218
- { name = "skops" },
1219
- { name = "torch" },
1220
- ]
1221
-
1222
  [[package]]
1223
  name = "more-itertools"
1224
  version = "10.7.0"
@@ -2492,22 +2474,6 @@ wheels = [
2492
  { url = "https://files.pythonhosted.org/packages/e6/b6/072a8e053ae600dcc2ac0da81a23548e3b523301a442a6ca900e92ac35be/tokenizers-0.21.1-cp39-abi3-win_amd64.whl", hash = "sha256:0f0dcbcc9f6e13e675a66d7a5f2f225a736745ce484c1a4e07476a89ccdad382", size = 2435481 },
2493
  ]
2494
 
2495
- [[package]]
2496
- name = "tokenlearn"
2497
- version = "0.2.0"
2498
- source = { registry = "https://pypi.org/simple" }
2499
- dependencies = [
2500
- { name = "datasets" },
2501
- { name = "model2vec", extra = ["distill"] },
2502
- { name = "more-itertools" },
2503
- { name = "sentence-transformers" },
2504
- { name = "torch" },
2505
- ]
2506
- sdist = { url = "https://files.pythonhosted.org/packages/58/b6/f9587ea271a9a7464cd25025b65f471d49bbceb48cc90742a89ac085edfd/tokenlearn-0.2.0.tar.gz", hash = "sha256:7a8faa0f51a510d185a40bef197a88116464adb8ce85ffd12c1d6905369c2375", size = 149042 }
2507
- wheels = [
2508
- { url = "https://files.pythonhosted.org/packages/40/3d/1c2b2e80ffd929bb8e7930d6a48e3b4252676cdc6c0c38f13a6f0f374b9c/tokenlearn-0.2.0-py3-none-any.whl", hash = "sha256:7a05e2800420eb2914c30e7377adeb14822c63585a0b9ed018bc82735dae1f29", size = 11970 },
2509
- ]
2510
-
2511
  [[package]]
2512
  name = "torch"
2513
  version = "2.7.0"
@@ -2580,7 +2546,7 @@ wheels = [
2580
 
2581
  [[package]]
2582
  name = "transformers"
2583
- version = "4.52.3"
2584
  source = { registry = "https://pypi.org/simple" }
2585
  dependencies = [
2586
  { name = "filelock" },
@@ -2594,9 +2560,9 @@ dependencies = [
2594
  { name = "tokenizers" },
2595
  { name = "tqdm" },
2596
  ]
2597
- sdist = { url = "https://files.pythonhosted.org/packages/07/42/271bcf364788337ac24e7f200005ac7142aaf022206bd6119d2daca22c04/transformers-4.52.3.tar.gz", hash = "sha256:2e1de29374f27920aaf6d589d4e6339f33def2fb08809e1a1d792e040e9fbce7", size = 8951324 }
2598
  wheels = [
2599
- { url = "https://files.pythonhosted.org/packages/36/f8/1f086942bc6a044e4e68dacf6de761a45367795efd5f57ad356765691c79/transformers-4.52.3-py3-none-any.whl", hash = "sha256:cd04059da50e7cf2a617ce3143ba8beffbf119f8c25a0717c3454fd9d0f19609", size = 10460322 },
2600
  ]
2601
 
2602
  [[package]]
 
774
  { name = "flash-attn" },
775
  { name = "hatchling" },
776
  { name = "iso639" },
777
+ { name = "jinja2" },
778
+ { name = "joblib" },
779
  { name = "kaleido" },
780
  { name = "lightning" },
781
  { name = "matplotlib" },
782
+ { name = "more-itertools" },
783
  { name = "mteb" },
784
  { name = "numpy" },
785
  { name = "plotly" },
786
  { name = "psutil" },
787
  { name = "pydantic" },
788
  { name = "requests" },
789
+ { name = "rich" },
790
+ { name = "safetensors" },
791
  { name = "scikit-learn" },
792
  { name = "seaborn" },
793
  { name = "sentence-transformers" },
794
  { name = "setuptools" },
795
+ { name = "skops" },
796
  { name = "smart-open", extra = ["s3"] },
797
  { name = "statsmodels" },
798
+ { name = "tokenizers" },
799
  { name = "torch" },
800
+ { name = "tqdm" },
801
+ { name = "transformers" },
802
  { name = "typer" },
803
  ]
804
 
 
820
  { name = "flash-attn", specifier = ">=2.7.4.post1" },
821
  { name = "hatchling", specifier = ">=1.27.0" },
822
  { name = "iso639", specifier = ">=0.1.4" },
823
+ { name = "jinja2", specifier = ">=3.0.0" },
824
+ { name = "joblib", specifier = ">=1.0.0" },
825
  { name = "kaleido", specifier = "==1.0.0rc13" },
826
  { name = "lightning", specifier = ">=2.5.1.post0" },
827
  { name = "matplotlib", specifier = ">=3.10.3" },
828
+ { name = "more-itertools", specifier = ">=10.5.0" },
829
  { name = "mteb", specifier = ">=1.14.15" },
830
  { name = "numpy", specifier = ">=1.26.4" },
831
  { name = "plotly", specifier = ">=6.1.1" },
832
  { name = "psutil", specifier = ">=7.0.0" },
833
  { name = "pydantic", specifier = ">=2.11.5" },
834
  { name = "requests", specifier = ">=2.32.3" },
835
+ { name = "rich", specifier = ">=10.0.0" },
836
+ { name = "safetensors", specifier = ">=0.3.0" },
837
  { name = "scikit-learn", specifier = ">=1.6.1" },
838
  { name = "seaborn", specifier = ">=0.13.2" },
839
  { name = "sentence-transformers", specifier = ">=4.1.0" },
840
  { name = "setuptools", specifier = ">=80.8.0" },
841
+ { name = "skops", specifier = ">=0.11.0" },
842
  { name = "smart-open", extras = ["s3"], specifier = ">=7.1.0" },
843
  { name = "statsmodels", specifier = ">=0.14.4" },
844
+ { name = "tokenizers", specifier = ">=0.20" },
845
  { name = "torch", specifier = ">=2.7.0" },
846
+ { name = "tqdm", specifier = ">=4.65.0" },
847
+ { name = "transformers", specifier = "<=4.52.1" },
848
  { name = "typer", specifier = ">=0.16.0" },
849
  ]
850
 
 
1201
  { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
1202
  ]
1203
 
1204
  [[package]]
1205
  name = "more-itertools"
1206
  version = "10.7.0"
 
2474
  { url = "https://files.pythonhosted.org/packages/e6/b6/072a8e053ae600dcc2ac0da81a23548e3b523301a442a6ca900e92ac35be/tokenizers-0.21.1-cp39-abi3-win_amd64.whl", hash = "sha256:0f0dcbcc9f6e13e675a66d7a5f2f225a736745ce484c1a4e07476a89ccdad382", size = 2435481 },
2475
  ]
2476
 
2477
  [[package]]
2478
  name = "torch"
2479
  version = "2.7.0"
 
2546
 
2547
  [[package]]
2548
  name = "transformers"
2549
+ version = "4.52.1"
2550
  source = { registry = "https://pypi.org/simple" }
2551
  dependencies = [
2552
  { name = "filelock" },
 
2560
  { name = "tokenizers" },
2561
  { name = "tqdm" },
2562
  ]
2563
+ sdist = { url = "https://files.pythonhosted.org/packages/4a/de/f3f3a0649dc522aeff55a5739e06e132c875c53701307a2ddd7ce7528ec5/transformers-4.52.1.tar.gz", hash = "sha256:c380d583ed9c7ebe3e30ca5e55ec1249db39eb9ee277f8e74dab1abc6a03c938", size = 8944009 }
2564
  wheels = [
2565
+ { url = "https://files.pythonhosted.org/packages/b8/1e/2b00e5021c3545d4a0ae32f3d332ae29e62a6259092f1468976e7b9d4adb/transformers-4.52.1-py3-none-any.whl", hash = "sha256:604b2bb357c480dc5883b7944e8562c967f6b06f63dfb6a1c4665d13d067148f", size = 10459023 },
2566
  ]
2567
 
2568
  [[package]]