Sarthak committed · Commit 7837959 · Parent(s): 37196da

chore: update dependencies and configuration for improved training

This commit updates the model configuration in `.codemap.yml` to use a lighter version of the model. It also adds new dependencies such as `jinja2`, `joblib`, `rich`, and `safetensors` to `pyproject.toml` and `uv.lock`, and replaces `tokenlearn` with `tokenizers`. The report has been regenerated to reflect updated model performance metrics, and the dataset configuration now supports an optimized dataset during training.
- .codemap.yml +1 -1
- REPORT.md +18 -90
- patches/model2vec.patch +0 -39
- patches/tokenlearn.patch +0 -25
- pyproject.toml +14 -3
- src/distiller/__main__.py +52 -2
- src/distiller/analyze.py +1 -1
- src/distiller/config.py +7 -1
- src/distiller/dataset.py +659 -0
- src/distiller/distill.py +345 -194
- src/distiller/patch_utils.py +0 -276
- uv.lock +21 -55
.codemap.yml
CHANGED
@@ -5,7 +5,7 @@
 # LLM Configuration - Controls which model is used for AI operations
 llm:
   # Format: "provider:model-name", e.g., "openai:gpt-4o", "anthropic:claude-3-opus"
-  model: "google-gla:gemini-2.0-flash"
+  model: "google-gla:gemini-2.0-flash-lite"
   temperature: 0.5 # Lower for more deterministic outputs, higher for creativity
   max_input_tokens: 1000000 # Maximum tokens in input
   max_output_tokens: 10000 # Maximum tokens in responses
REPORT.md
CHANGED
@@ -28,8 +28,8 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
 | code_model2vec_all_MiniLM_L6_v2 | [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) | 0.7385 | 0.7049 | 0.7910 | 🥈 2nd |
 | code_model2vec_jina_embeddings_v2_base_code | [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code) | 0.7381 | 0.6996 | 0.8130 | 🥉 3rd |
 | code_model2vec_paraphrase_MiniLM_L6_v2 | [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) | 0.7013 | 0.6638 | 0.7665 | #4 |
-
-
+| code_model2vec_Reason_ModernColBERT | [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT) | 0.6598 | 0.6228 | 0.7260 | #5 |
+| code_model2vec_all_mpnet_base_v2_fine_tuned | [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) | 0.5347 | 0.4875 | 0.6200 | #6 |
 | code_model2vec_bge_m3 | [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | 0.4863 | 0.4439 | 0.5514 | #7 |
 | code_model2vec_jina_embeddings_v3 | [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3) | 0.4755 | 0.4416 | 0.5456 | #8 |
 | code_model2vec_nomic_embed_text_v2_moe | [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe) | 0.4532 | 0.4275 | 0.5094 | #9 |
@@ -50,8 +50,8 @@ Our distilled models exhibit consistent architectural characteristics across dif
 | all_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
 | jina_embeddings_v2_base_code | 61,053 | 15.6M | 256 | 29.8MB |
 | paraphrase_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
-| all_mpnet_base_v2_fine_tuned | 77,316 | 19.8M | 256 | 75.5MB |
 | Reason_ModernColBERT | 50,254 | 12.9M | 256 | 24.5MB |
+| all_mpnet_base_v2_fine_tuned | 29,528 | 7.6M | 256 | 28.8MB |
 | bge_m3 | 249,999 | 64.0M | 256 | 122.1MB |
 | jina_embeddings_v3 | 249,999 | 64.0M | 256 | 122.1MB |
 | nomic_embed_text_v2_moe | 249,999 | 64.0M | 256 | 122.1MB |
@@ -69,9 +69,9 @@ Our distilled models exhibit consistent architectural characteristics across dif
 #### Key Insights from Model Specifications:
 
-- **Vocabulary Consistency**: All models use vocabulary sizes ranging from 29,525 to 249,999 tokens (avg:
-- **Parameter Efficiency**: Models range from 7.6M to 64.0M parameters (avg:
-- **Storage Efficiency**: Disk usage ranges from 14.4MB to 122.1MB (avg:
+- **Vocabulary Consistency**: All models use vocabulary sizes ranging from 29,525 to 249,999 tokens (avg: 101,087)
+- **Parameter Efficiency**: Models range from 7.6M to 64.0M parameters (avg: 25.9M)
+- **Storage Efficiency**: Disk usage ranges from 14.4MB to 122.1MB (avg: 50.4MB)
 - **Embedding Dimensions**: Consistent 256 dimensions across all models (optimized for efficiency)
 
@@ -81,85 +81,13 @@ Our distilled models exhibit consistent architectural characteristics across dif
 - **Best Teacher Model**: code_model2vec_all_mpnet_base_v2 (NDCG@10: 0.7387)
 - **Least Effective Teacher**: code_model2vec_codebert_base (NDCG@10: 0.2779)
 - **Performance Range**: 62.4% difference between best and worst
-- **Average Performance**: 0.
+- **Average Performance**: 0.5190 NDCG@10
 
 ## 🎯 Language Performance Radar Charts
 
 ### Best Model vs Peer Models Comparison
 
-
-*Comparative view showing how the best simplified distillation model performs against top peer models across programming languages.*
-
-### Individual Model Performance by Language
-
-#### code_model2vec_all_mpnet_base_v2 (Teacher: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) - NDCG@10: 0.7387
-
-#### code_model2vec_all_MiniLM_L6_v2 (Teacher: [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)) - NDCG@10: 0.7385
-
-#### code_model2vec_jina_embeddings_v2_base_code (Teacher: [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code)) - NDCG@10: 0.7381
-
-#### code_model2vec_paraphrase_MiniLM_L6_v2 (Teacher: [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2)) - NDCG@10: 0.7013
-
-#### code_model2vec_all_mpnet_base_v2_fine_tuned (Teacher: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) - NDCG@10: 0.6906
-
-#### code_model2vec_Reason_ModernColBERT (Teacher: [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT)) - NDCG@10: 0.6598
-
-#### code_model2vec_bge_m3 (Teacher: [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3)) - NDCG@10: 0.4863
-
-#### code_model2vec_jina_embeddings_v3 (Teacher: [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3)) - NDCG@10: 0.4755
-
-#### code_model2vec_nomic_embed_text_v2_moe (Teacher: [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe)) - NDCG@10: 0.4532
-
-#### code_model2vec_gte_Qwen2_1.5B_instruct (Teacher: [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)) - NDCG@10: 0.4238
-
-#### code_model2vec_Qodo_Embed_1_1.5B (Teacher: [Qodo/Qodo-Embed-1-1.5B](https://huggingface.co/Qodo/Qodo-Embed-1-1.5B)) - NDCG@10: 0.4101
-
-#### code_model2vec_graphcodebert_base (Teacher: [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base)) - NDCG@10: 0.3420
-
-#### code_model2vec_Linq_Embed_Mistral (Teacher: [Linq-AI-Research/Linq-Embed-Mistral](https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral)) - NDCG@10: 0.2868
-
-#### code_model2vec_codebert_base (Teacher: [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base)) - NDCG@10: 0.2779
-
-
-## 🏆 Peer Model Comparison
-
-
-*Comparison with established code-specialized embedding models using actual evaluation results.*
-
-### Complete Model Ranking
-
 | Rank | Model | Type | NDCG@10 | MRR | Recall@5 |
 |------|-------|------|---------|-----|----------|
 | 1 | Alibaba-NLP/gte-Qwen2-1.5B-instruct | General | 0.9729 | 0.9676 | 0.9825 |
@@ -180,10 +108,10 @@ Our distilled models exhibit consistent architectural characteristics across dif
 | 16 | code_model2vec_all_MiniLM_L6_v2 | **🔥 Simplified Distillation** | 0.7385 | 0.7049 | 0.7910 |
 | 17 | code_model2vec_jina_embeddings_v2_base_code | **🔥 Simplified Distillation** | 0.7381 | 0.6996 | 0.8130 |
 | 18 | code_model2vec_paraphrase_MiniLM_L6_v2 | **🔥 Simplified Distillation** | 0.7013 | 0.6638 | 0.7665 |
-| 19 |
-| 20 |
-| 21 |
-| 22 |
+| 19 | code_model2vec_Reason_ModernColBERT | **🔥 Simplified Distillation** | 0.6598 | 0.6228 | 0.7260 |
+| 20 | potion-multilingual-128M | Model2Vec | 0.6124 | 0.5683 | 0.7017 |
+| 21 | huggingface/CodeBERTa-small-v1 | Code-Specific | 0.5903 | 0.5350 | 0.6779 |
+| 22 | code_model2vec_all_mpnet_base_v2_fine_tuned | **🎓 Fine-tuned Distillation** | 0.5347 | 0.4875 | 0.6200 |
 | 23 | Salesforce/codet5-base | Code-Specific | 0.4872 | 0.4500 | 0.5742 |
 | 24 | code_model2vec_bge_m3 | **🔥 Simplified Distillation** | 0.4863 | 0.4439 | 0.5514 |
 | 25 | code_model2vec_jina_embeddings_v3 | **🔥 Simplified Distillation** | 0.4755 | 0.4416 | 0.5456 |
@@ -243,12 +171,12 @@ Our distilled models exhibit consistent architectural characteristics across dif
 
 | Language | Best Model Performance | Average Performance | Language Difficulty |
 |----------|------------------------|--------------------|--------------------|
-| Go | 0.9780 | 0.
-| Java | 0.9921 | 0.
-| Javascript | 0.9550 | 0.
-| Php | 1.0000 | 0.
-| Python | 1.0000 | 0.
-| Ruby | 0.9493 | 0.
+| Go | 0.9780 | 0.6923 | Easy |
+| Java | 0.9921 | 0.6545 | Easy |
+| Javascript | 0.9550 | 0.5831 | Easy |
+| Php | 1.0000 | 0.6325 | Easy |
+| Python | 1.0000 | 0.8599 | Easy |
+| Ruby | 0.9493 | 0.6333 | Easy |
 
 
 ## 🎯 Conclusions and Recommendations
@@ -302,5 +230,5 @@ Based on the evaluation results across all simplified distillation models:
 
 ---
 
-*Report generated on 2025-05-31
+*Report generated on 2025-05-31 21:07:06 using automated analysis pipeline.*
 *For questions about methodology or results, please refer to the CodeSearchNet documentation.*
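The REPORT.md tables above rank models by NDCG@10, MRR, and Recall@5. As context for the headline metric, here is a minimal sketch of NDCG@10 under a binary-relevance assumption (one gold code snippet per query); it is not the repository's evaluation code.

```python
# Minimal NDCG@10 sketch for binary relevance; illustrative only.
import math


def ndcg_at_10(ranked_relevance: list[int]) -> float:
    """ranked_relevance: relevance (0/1) of each retrieved item, best rank first."""
    gains = ranked_relevance[:10]
    dcg = sum(rel / math.log2(rank + 2) for rank, rel in enumerate(gains))
    ideal = sorted(ranked_relevance, reverse=True)[:10]
    idcg = sum(rel / math.log2(rank + 2) for rank, rel in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0


# Example: the gold snippet is retrieved at rank 3.
print(ndcg_at_10([0, 0, 1, 0, 0]))  # ≈ 0.5
```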
patches/model2vec.patch
DELETED
@@ -1,39 +0,0 @@
---- a/model2vec/train/base.py
-+++ b/model2vec/train/base.py
-@@ -35,7 +35,7 @@ class FinetunableStaticModel(nn.Module):
-         )
-         self.vectors = vectors.float()
-
--        self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
-+        self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=pad_id)
-         self.head = self.construct_head()
-         self.w = self.construct_weights()
-         self.tokenizer = tokenizer
---- a/model2vec/distill/distillation.py
-+++ b/model2vec/distill/distillation.py
-@@ -137,7 +137,10 @@ def distill_from_model(
-     # Get the language from the model card.
-     try:
-         info = model_info(model_name)
--        language = info.cardData.get("language", None)
-+        if info is not None and hasattr(info, 'cardData') and info.cardData is not None:
-+            language = info.cardData.get("language", None)
-+        else:
-+            language = None
-     except RepositoryNotFoundError:
-         logger.info("No model info found for the model. Setting language to None.")
-         language = None
---- a/model2vec/distill/inference.py
-+++ b/model2vec/distill/inference.py
-@@ -109,5 +109,12 @@ def create_embeddings(
-         out_tokens.extend([Token(x, False) for x in tokens])
-     out_weights = np.stack(intermediate_weights)
-
-+    # Validate token-vector consistency to prevent failures
-+    if len(out_tokens) != out_weights.shape[0]:
-+        logger.warning(f"Token-vector mismatch: {len(out_tokens)} tokens vs {out_weights.shape[0]} vectors. Truncating to prevent failure.")
-+        min_count = min(len(out_tokens), out_weights.shape[0])
-+        out_tokens = out_tokens[:min_count]
-+        out_weights = out_weights[:min_count]
-+
-     return out_tokens, out_weights
patches/tokenlearn.patch
DELETED
@@ -1,25 +0,0 @@
---- a/tokenlearn/pretrain.py
-+++ b/tokenlearn/pretrain.py
-@@ -38,7 +38,10 @@ class FinetunableStaticModel(nn.Module):
-         """Run the model using input IDs."""
-         input_ids = input_ids.view(-1)
-         input_ids = input_ids[input_ids != self.pad_token_id]
--        w = self.w[input_ids]
-+        # Fix for index out of bounds issue
-+        # Clamp input_ids to valid range to prevent IndexError during training
-+        valid_input_ids = torch.clamp(input_ids, 0, self.w.shape[0] - 1)
-+        w = self.w[valid_input_ids]
-         return self.sub_forward(w)
-
-     def forward(self, x):
-@@ -46,7 +49,10 @@ class FinetunableStaticModel(nn.Module):
-         # Add a small epsilon to avoid division by zero
-         length = zeros.sum(1) + 1e-16
--        embedded = self.embeddings(input_ids)
-+        # Fix for embedding index out of bounds issue
-+        # Clamp input_ids to valid embedding range
-+        valid_input_ids = torch.clamp(input_ids, 0, self.embeddings.num_embeddings - 1)
-+        embedded = self.embeddings(valid_input_ids)
-         # Zero out the padding
-         embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
-         # Simulate actual mean
pyproject.toml
CHANGED
@@ -19,24 +19,31 @@ dependencies = [
     "flash-attn>=2.7.4.post1",
     "hatchling>=1.27.0",
     "iso639>=0.1.4",
+    "jinja2>=3.0.0",
+    "joblib>=1.0.0",
     "kaleido==1.0.0rc13",
     "lightning>=2.5.1.post0",
     "matplotlib>=3.10.3",
-    "
+    "more-itertools>=10.5.0",
     "mteb>=1.14.15",
     "numpy>=1.26.4",
     "plotly>=6.1.1",
     "psutil>=7.0.0",
     "pydantic>=2.11.5",
     "requests>=2.32.3",
+    "rich>=10.0.0",
+    "safetensors>=0.3.0",
     "scikit-learn>=1.6.1",
     "seaborn>=0.13.2",
     "sentence-transformers>=4.1.0",
     "setuptools>=80.8.0",
+    "skops>=0.11.0",
     "smart-open[s3]>=7.1.0",
     "statsmodels>=0.14.4",
-    "
+    "tokenizers>=0.20",
     "torch>=2.7.0",
+    "transformers<=4.52.1",
+    "tqdm>=4.65.0",
     "typer>=0.16.0",
 ]
 
@@ -78,7 +85,9 @@ exclude = [
     "__pycache__",
     "build",
     "dist",
-    "vendor"
+    "vendor",
+    "src/distiller/model2vec",
+    "src/distiller/tokenlearn"
 ]
 
 [tool.ruff.lint]
@@ -114,6 +123,8 @@ ignore = [
     "E501", # Line too long
     "PLR2004",
     "RUF001",
+    "D100", # Missing docstring in public module
+    "D101", # Missing docstring in public class
 ]
 
 [tool.ruff.lint.mccabe]
src/distiller/__main__.py
CHANGED
@@ -17,12 +17,41 @@ def distill(
     train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False,
     teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None,
     pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None,
+    clear_cache: Annotated[
+        bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation")
+    ] = False,
+    clear_checkpoints: Annotated[
+        bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
+    ] = False,
+    skip_ptr: Annotated[
+        bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
+    ] = False,
+    use_optimized_dataset: Annotated[
+        bool,
+        typer.Option(
+            "--use-optimized-dataset", help="Use the pre-created optimized dataset from code_model2vec/dataset"
+        ),
+    ] = False,
+    dataset_path: Annotated[
+        str | None,
+        typer.Option("--dataset-path", help="Path to custom dataset directory (defaults to code_model2vec/dataset)"),
+    ] = None,
 ) -> None:
     """Run unified Model2Vec distillation with optional training."""
     from .distill import main as distill_main
 
-    # Call the distill main function with arguments
-    distill_main(
+    # Call the distill main function with all arguments
+    distill_main(
+        use_beam,
+        train,
+        teacher_models,
+        pca_dims,
+        clear_cache,
+        clear_checkpoints,
+        skip_ptr,
+        use_optimized_dataset,
+        dataset_path,
+    )
 
 
 @app.command()
@@ -53,5 +82,26 @@ def analyze(
     analyze_main(results_dir or "code_model2vec/evaluation_results", model_name, output, export_csv)
 
 
+@app.command()
+def dataset(
+    max_samples_per_lang: Annotated[int, typer.Option(help="Maximum samples per language")] = 50000,
+    min_doc_words: Annotated[int, typer.Option(help="Minimum words in documentation")] = 3,
+    max_doc_words: Annotated[int, typer.Option(help="Maximum words in documentation")] = 100,
+    min_code_chars: Annotated[int, typer.Option(help="Minimum characters in code")] = 50,
+    max_code_chars: Annotated[int, typer.Option(help="Maximum characters in code")] = 2000,
+    output_dir: Annotated[str | None, typer.Option(help="Output directory for dataset")] = None,
+    simple_format: Annotated[
+        bool, typer.Option(help="Create only simple format (not multiple training formats)")
+    ] = False,
+) -> None:
+    """Create optimized training dataset from CodeSearchNet for code search tasks."""
+    from .dataset import main as dataset_main
+
+    # Call the dataset main function with arguments
+    dataset_main(
+        max_samples_per_lang, min_doc_words, max_doc_words, min_code_chars, max_code_chars, output_dir, simple_format
+    )
+
+
 if __name__ == "__main__":
     app()
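A quick way to exercise the new `dataset` sub-command is Typer's test runner. The snippet below is a hypothetical smoke test, not part of the commit; it assumes the package is installed so that the `app` object in `src/distiller/__main__.py` is importable as `distiller.__main__`.

```python
# Hypothetical smoke test for the `dataset` command added in this commit.
from typer.testing import CliRunner

from distiller.__main__ import app

runner = CliRunner()
# Small run: 1,000 samples per language, single-format output.
result = runner.invoke(app, ["dataset", "--max-samples-per-lang", "1000", "--simple-format"])
print(result.exit_code)  # 0 on success
```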
src/distiller/analyze.py
CHANGED
@@ -510,7 +510,7 @@ class CodeSearchNetAnalyzer:
 
         try:
             # Try to load the model and get specifications
-            from model2vec import StaticModel
+            from distiller.model2vec import StaticModel
 
             model = StaticModel.from_pretrained(str(model_dir))
 
src/distiller/config.py
CHANGED
@@ -212,13 +212,19 @@ class DistillationConfig(BaseModel):
     # Tokenlearn-specific parameters (POTION approach)
     tokenlearn_dataset: str = "sentence-transformers/codesearchnet"  # Dataset for tokenlearn featurization
     tokenlearn_dataset_name: str = "pair"  # Use 'pair' configuration (only available config)
-    tokenlearn_text_key: str =
+    tokenlearn_text_key: str = (
+        "combined_text"  # Text field to use from the dataset ('combined_text' for doc-code pairs)
+    )
     tokenlearn_timeout_featurize: int = 21600  # 6 hour timeout for featurization (dataset needs ~5 hours)
     tokenlearn_timeout_train: int = 7200  # 2 hour timeout for training
 
     # Post-training configuration
     skip_post_training_regularization: bool = False  # Skip PCA + SIF re-regularization step
 
+    # Dataset configuration
+    use_optimized_dataset: bool = True  # Use the pre-created optimized dataset from dataset.py
+    custom_dataset_path: str | None = "code_model2vec/dataset"  # Path to custom dataset directory
+
 
 distillation_config = DistillationConfig()
 
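The two new `DistillationConfig` fields are presumably what lets the distillation path switch between the raw CodeSearchNet stream and the pre-built parquet dataset. The sketch below is illustrative only; the helper name is hypothetical and not the actual code in `distill.py`.

```python
# Illustrative only: how the new config fields could steer dataset selection.
from pathlib import Path

import pandas as pd

from distiller.config import distillation_config


def load_training_texts() -> list[str] | None:
    """Return texts from the optimized parquet dataset, or None to fall back to CodeSearchNet loading."""
    if not distillation_config.use_optimized_dataset or not distillation_config.custom_dataset_path:
        return None
    train_path = Path(distillation_config.custom_dataset_path) / "train.parquet"
    if not train_path.exists():
        return None
    # The "text" column is the query + fenced code block produced by dataset.py.
    return pd.read_parquet(train_path)["text"].tolist()
```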
src/distiller/dataset.py
ADDED
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Custom Dataset Generation for Code-Specialized Model Training.
|
3 |
+
|
4 |
+
This module creates optimized training datasets from CodeSearchNet that are specifically
|
5 |
+
designed to improve performance on code search evaluation tasks.
|
6 |
+
|
7 |
+
Features:
|
8 |
+
- High-quality doc-code pairs optimized for retrieval
|
9 |
+
- Balanced sampling across programming languages
|
10 |
+
- Multiple training formats (doc-only, code-only, combined)
|
11 |
+
- Quality filtering and data cleaning
|
12 |
+
- Train/test/eval splits with proper stratification
|
13 |
+
- Efficient parquet format output
|
14 |
+
"""
|
15 |
+
|
16 |
+
import json
|
17 |
+
import logging
|
18 |
+
import time
|
19 |
+
from pathlib import Path
|
20 |
+
from typing import Annotated, Any
|
21 |
+
|
22 |
+
import pandas as pd
|
23 |
+
import typer
|
24 |
+
from datasets import load_dataset
|
25 |
+
from tqdm import tqdm
|
26 |
+
|
27 |
+
from .config import languages_config
|
28 |
+
|
29 |
+
# Set up logging
|
30 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
# Dataset configuration
|
34 |
+
DATASET_OUTPUT_DIR = Path("code_model2vec/dataset")
|
35 |
+
DEFAULT_MAX_SAMPLES_PER_LANG = 50000
|
36 |
+
DEFAULT_MIN_DOC_WORDS = 3
|
37 |
+
DEFAULT_MAX_DOC_WORDS = 100
|
38 |
+
DEFAULT_MIN_CODE_CHARS = 50
|
39 |
+
DEFAULT_MAX_CODE_CHARS = 2000
|
40 |
+
|
41 |
+
|
42 |
+
def create_optimized_dataset(
|
43 |
+
max_samples_per_lang: int = DEFAULT_MAX_SAMPLES_PER_LANG,
|
44 |
+
min_doc_words: int = DEFAULT_MIN_DOC_WORDS,
|
45 |
+
max_doc_words: int = DEFAULT_MAX_DOC_WORDS,
|
46 |
+
min_code_chars: int = DEFAULT_MIN_CODE_CHARS,
|
47 |
+
max_code_chars: int = DEFAULT_MAX_CODE_CHARS,
|
48 |
+
output_dir: Path | None = None,
|
49 |
+
create_multiple_formats: bool = True,
|
50 |
+
) -> dict[str, Any]:
|
51 |
+
"""
|
52 |
+
Create optimized training dataset from CodeSearchNet for code search tasks.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
max_samples_per_lang: Maximum samples per programming language
|
56 |
+
min_doc_words: Minimum words in documentation
|
57 |
+
max_doc_words: Maximum words in documentation
|
58 |
+
min_code_chars: Minimum characters in code
|
59 |
+
max_code_chars: Maximum characters in code
|
60 |
+
output_dir: Output directory for dataset
|
61 |
+
create_multiple_formats: Create multiple training formats
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
Dictionary with dataset statistics and file paths
|
65 |
+
"""
|
66 |
+
output_dir = DATASET_OUTPUT_DIR if output_dir is None else Path(output_dir)
|
67 |
+
|
68 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
69 |
+
|
70 |
+
logger.info("🚀 Starting optimized CodeSearchNet dataset creation...")
|
71 |
+
logger.info(f"📁 Output directory: {output_dir}")
|
72 |
+
logger.info(f"📊 Target: {max_samples_per_lang} samples per language")
|
73 |
+
logger.info(f"🔍 Languages: {', '.join(languages_config.all)}")
|
74 |
+
|
75 |
+
start_time = time.time()
|
76 |
+
all_samples = []
|
77 |
+
language_stats = {}
|
78 |
+
|
79 |
+
# Process each programming language
|
80 |
+
for language in languages_config.all:
|
81 |
+
logger.info(f"\n🔄 Processing {language}...")
|
82 |
+
|
83 |
+
try:
|
84 |
+
# Load CodeSearchNet dataset for this language
|
85 |
+
dataset = load_dataset("code_search_net", language, split="train", trust_remote_code=True)
|
86 |
+
|
87 |
+
language_samples = []
|
88 |
+
processed_count = 0
|
89 |
+
quality_filtered = 0
|
90 |
+
|
91 |
+
# Process examples with quality filtering
|
92 |
+
for example in tqdm(dataset, desc=f"Processing {language}", unit="examples"):
|
93 |
+
processed_count += 1
|
94 |
+
|
95 |
+
# Extract documentation and code
|
96 |
+
doc_string = example.get("func_documentation_string", "").strip()
|
97 |
+
code_string = example.get("func_code_string", "").strip()
|
98 |
+
func_name = example.get("func_name", "").strip()
|
99 |
+
|
100 |
+
# Quality filters
|
101 |
+
if not _passes_quality_filters(
|
102 |
+
doc_string, code_string, func_name, min_doc_words, max_doc_words, min_code_chars, max_code_chars
|
103 |
+
):
|
104 |
+
continue
|
105 |
+
|
106 |
+
quality_filtered += 1
|
107 |
+
|
108 |
+
# Create optimized training samples
|
109 |
+
samples = _create_training_samples(
|
110 |
+
doc_string, code_string, func_name, language, create_multiple_formats
|
111 |
+
)
|
112 |
+
language_samples.extend(samples)
|
113 |
+
|
114 |
+
# Stop if we have enough samples
|
115 |
+
if len(language_samples) >= max_samples_per_lang:
|
116 |
+
break
|
117 |
+
|
118 |
+
# Truncate to exact target size
|
119 |
+
language_samples = language_samples[:max_samples_per_lang]
|
120 |
+
all_samples.extend(language_samples)
|
121 |
+
|
122 |
+
# Track statistics
|
123 |
+
language_stats[language] = {
|
124 |
+
"processed": processed_count,
|
125 |
+
"quality_filtered": quality_filtered,
|
126 |
+
"final_samples": len(language_samples),
|
127 |
+
"quality_rate": quality_filtered / processed_count if processed_count > 0 else 0,
|
128 |
+
}
|
129 |
+
|
130 |
+
logger.info(f"✅ {language}: {len(language_samples)} samples from {quality_filtered} quality examples")
|
131 |
+
|
132 |
+
except Exception:
|
133 |
+
logger.exception(f"❌ Failed to process {language}")
|
134 |
+
language_stats[language] = {
|
135 |
+
"processed": 0,
|
136 |
+
"quality_filtered": 0,
|
137 |
+
"final_samples": 0,
|
138 |
+
"quality_rate": 0.0,
|
139 |
+
}
|
140 |
+
|
141 |
+
# Create DataFrame
|
142 |
+
logger.info(f"\n📊 Creating dataset with {len(all_samples)} total samples...")
|
143 |
+
df = pd.DataFrame(all_samples)
|
144 |
+
|
145 |
+
# Create stratified splits
|
146 |
+
train_df, test_df = _create_stratified_splits(df)
|
147 |
+
|
148 |
+
# Save datasets
|
149 |
+
dataset_files = _save_datasets(output_dir, train_df, test_df)
|
150 |
+
|
151 |
+
# Save metadata
|
152 |
+
metadata = {
|
153 |
+
"creation_time": time.strftime("%Y-%m-%d %H:%M:%S"),
|
154 |
+
"total_samples": len(all_samples),
|
155 |
+
"train_samples": len(train_df),
|
156 |
+
"test_samples": len(test_df),
|
157 |
+
"languages": languages_config.all,
|
158 |
+
"language_stats": language_stats,
|
159 |
+
"quality_filters": {
|
160 |
+
"min_doc_words": min_doc_words,
|
161 |
+
"max_doc_words": max_doc_words,
|
162 |
+
"min_code_chars": min_code_chars,
|
163 |
+
"max_code_chars": max_code_chars,
|
164 |
+
},
|
165 |
+
"files": dataset_files,
|
166 |
+
"processing_time": time.time() - start_time,
|
167 |
+
}
|
168 |
+
|
169 |
+
metadata_file = output_dir / "metadata.json"
|
170 |
+
with metadata_file.open("w") as f:
|
171 |
+
json.dump(metadata, f, indent=2)
|
172 |
+
|
173 |
+
logger.info(f"\n🎉 Dataset creation completed in {metadata['processing_time']:.2f} seconds!")
|
174 |
+
logger.info("📊 Final statistics:")
|
175 |
+
logger.info(f" - Total samples: {metadata['total_samples']}")
|
176 |
+
logger.info(f" - Train: {metadata['train_samples']}")
|
177 |
+
logger.info(f" - Test: {metadata['test_samples']}")
|
178 |
+
logger.info(f"💾 Metadata saved to: {metadata_file}")
|
179 |
+
|
180 |
+
return metadata
|
181 |
+
|
182 |
+
|
183 |
+
def _passes_quality_filters(
|
184 |
+
doc_string: str,
|
185 |
+
code_string: str,
|
186 |
+
func_name: str,
|
187 |
+
min_doc_words: int,
|
188 |
+
max_doc_words: int,
|
189 |
+
min_code_chars: int,
|
190 |
+
max_code_chars: int,
|
191 |
+
) -> bool:
|
192 |
+
"""Apply quality filters optimized for code retrieval following RAG best practices."""
|
193 |
+
# Basic existence checks
|
194 |
+
if not doc_string or not code_string or not func_name:
|
195 |
+
return False
|
196 |
+
|
197 |
+
# Documentation quality filters for code retrieval
|
198 |
+
doc_words = len(doc_string.split())
|
199 |
+
if doc_words < min_doc_words or doc_words > max_doc_words:
|
200 |
+
return False
|
201 |
+
|
202 |
+
# Code quality filters
|
203 |
+
code_length = len(code_string)
|
204 |
+
if code_length < min_code_chars or code_length > max_code_chars:
|
205 |
+
return False
|
206 |
+
|
207 |
+
# Content quality filters for code retrieval
|
208 |
+
doc_lower = doc_string.lower()
|
209 |
+
code_string.lower()
|
210 |
+
|
211 |
+
# Skip low-quality documentation (expanded for code context)
|
212 |
+
skip_phrases = [
|
213 |
+
"todo",
|
214 |
+
"fixme",
|
215 |
+
"hack",
|
216 |
+
"temp",
|
217 |
+
"test",
|
218 |
+
"placeholder",
|
219 |
+
"not implemented",
|
220 |
+
"coming soon",
|
221 |
+
"tbd",
|
222 |
+
"xxx",
|
223 |
+
"broken",
|
224 |
+
"deprecated",
|
225 |
+
"legacy",
|
226 |
+
"old version",
|
227 |
+
"outdated",
|
228 |
+
]
|
229 |
+
if any(phrase in doc_lower for phrase in skip_phrases):
|
230 |
+
return False
|
231 |
+
|
232 |
+
# Ensure meaningful documentation for code retrieval
|
233 |
+
if func_name.lower() in doc_lower and doc_words < 5:
|
234 |
+
return False
|
235 |
+
|
236 |
+
# Code structure validation (more comprehensive for retrieval)
|
237 |
+
has_function = any(
|
238 |
+
pattern in code_string for pattern in ["def ", "function ", "class ", "public ", "private ", "static "]
|
239 |
+
)
|
240 |
+
if not has_function:
|
241 |
+
return False
|
242 |
+
|
243 |
+
# Skip trivial or incomplete code
|
244 |
+
trivial_code_patterns = [
|
245 |
+
"pass",
|
246 |
+
"return None",
|
247 |
+
"return;",
|
248 |
+
"throw new Error",
|
249 |
+
"# TODO",
|
250 |
+
"// TODO",
|
251 |
+
"print(",
|
252 |
+
"console.log(",
|
253 |
+
]
|
254 |
+
if any(pattern in code_string for pattern in trivial_code_patterns) and len(code_string) < 100:
|
255 |
+
return False
|
256 |
+
|
257 |
+
# Ensure documentation describes functionality (not just naming)
|
258 |
+
generic_docs = [
|
259 |
+
"returns a value",
|
260 |
+
"does something",
|
261 |
+
"helper function",
|
262 |
+
"utility method",
|
263 |
+
"this function",
|
264 |
+
"this method",
|
265 |
+
"returns the result",
|
266 |
+
"performs operation",
|
267 |
+
]
|
268 |
+
if any(generic in doc_lower for generic in generic_docs):
|
269 |
+
return False
|
270 |
+
|
271 |
+
# Ensure documentation has descriptive content for retrieval
|
272 |
+
descriptive_words = [
|
273 |
+
"parse",
|
274 |
+
"convert",
|
275 |
+
"transform",
|
276 |
+
"calculate",
|
277 |
+
"validate",
|
278 |
+
"format",
|
279 |
+
"filter",
|
280 |
+
"sort",
|
281 |
+
"search",
|
282 |
+
"find",
|
283 |
+
"create",
|
284 |
+
"generate",
|
285 |
+
"process",
|
286 |
+
"handle",
|
287 |
+
"manage",
|
288 |
+
"update",
|
289 |
+
"modify",
|
290 |
+
"remove",
|
291 |
+
"delete",
|
292 |
+
"add",
|
293 |
+
]
|
294 |
+
if not any(word in doc_lower for word in descriptive_words) and doc_words < 8:
|
295 |
+
return False
|
296 |
+
|
297 |
+
# Code-documentation alignment check (key for retrieval quality)
|
298 |
+
return _check_code_doc_alignment(doc_string, code_string, func_name)
|
299 |
+
|
300 |
+
|
301 |
+
def _check_code_doc_alignment(doc_string: str, code_string: str, func_name: str) -> bool:
|
302 |
+
"""Check if documentation and code are well-aligned for retrieval tasks."""
|
303 |
+
doc_lower = doc_string.lower()
|
304 |
+
code_lower = code_string.lower()
|
305 |
+
|
306 |
+
# Function name should relate to documentation
|
307 |
+
func_base = func_name.lower().replace("_", " ").replace("-", " ")
|
308 |
+
|
309 |
+
# Check for obvious mismatches
|
310 |
+
doc_has_return = any(word in doc_lower for word in ["return", "returns", "gives", "outputs"])
|
311 |
+
code_has_return = "return " in code_lower
|
312 |
+
|
313 |
+
# If doc mentions returning something, code should have returns
|
314 |
+
if doc_has_return and not code_has_return and len(code_string.split("\n")) > 3:
|
315 |
+
return False
|
316 |
+
|
317 |
+
# Check for parameter mentions alignment
|
318 |
+
any(word in doc_lower for word in ["parameter", "param", "argument", "input"])
|
319 |
+
"(" in func_name and func_name.count("(") == 1
|
320 |
+
|
321 |
+
# Basic semantic alignment
|
322 |
+
action_words = ["sort", "parse", "convert", "validate", "format", "filter", "search", "calculate"]
|
323 |
+
doc_actions = [word for word in action_words if word in doc_lower]
|
324 |
+
[word for word in action_words if word in code_lower or word in func_base]
|
325 |
+
|
326 |
+
# If documentation mentions specific actions, code or function name should reflect them
|
327 |
+
return not (doc_actions and not any(action in code_lower or action in func_base for action in doc_actions))
|
328 |
+
|
329 |
+
|
330 |
+
def _create_training_samples(
|
331 |
+
doc_string: str,
|
332 |
+
code_string: str,
|
333 |
+
func_name: str,
|
334 |
+
language: str,
|
335 |
+
create_multiple_formats: bool,
|
336 |
+
) -> list[dict[str, Any]]:
|
337 |
+
"""Create optimized training samples for code retrieval with proper training schema."""
|
338 |
+
samples = []
|
339 |
+
|
340 |
+
if create_multiple_formats:
|
341 |
+
# Format 1: Documentation query → Code (direct evaluation format)
|
342 |
+
query_1 = doc_string
|
343 |
+
text_1 = _format_training_text(query_1, code_string, language)
|
344 |
+
samples.append(
|
345 |
+
{
|
346 |
+
"language": language,
|
347 |
+
"query": query_1,
|
348 |
+
"code": code_string,
|
349 |
+
"text": text_1,
|
350 |
+
}
|
351 |
+
)
|
352 |
+
|
353 |
+
# Format 2: How-to query (realistic developer search)
|
354 |
+
query_2 = _generate_how_to_query(doc_string, func_name, language)
|
355 |
+
text_2 = _format_training_text(query_2, code_string, language)
|
356 |
+
samples.append(
|
357 |
+
{
|
358 |
+
"language": language,
|
359 |
+
"query": query_2,
|
360 |
+
"code": code_string,
|
361 |
+
"text": text_2,
|
362 |
+
}
|
363 |
+
)
|
364 |
+
|
365 |
+
# Format 3: Functional requirement query
|
366 |
+
query_3 = _generate_functional_query(doc_string, func_name)
|
367 |
+
text_3 = _format_training_text(query_3, code_string, language)
|
368 |
+
samples.append(
|
369 |
+
{
|
370 |
+
"language": language,
|
371 |
+
"query": query_3,
|
372 |
+
"code": code_string,
|
373 |
+
"text": text_3,
|
374 |
+
}
|
375 |
+
)
|
376 |
+
|
377 |
+
# Format 4: Implementation-specific query
|
378 |
+
query_4 = _generate_implementation_query(doc_string, func_name, language)
|
379 |
+
text_4 = _format_training_text(query_4, code_string, language)
|
380 |
+
samples.append(
|
381 |
+
{
|
382 |
+
"language": language,
|
383 |
+
"query": query_4,
|
384 |
+
"code": code_string,
|
385 |
+
"text": text_4,
|
386 |
+
}
|
387 |
+
)
|
388 |
+
|
389 |
+
else:
|
390 |
+
# Simple format - direct documentation to code
|
391 |
+
query = doc_string
|
392 |
+
text = _format_training_text(query, code_string, language)
|
393 |
+
samples.append(
|
394 |
+
{
|
395 |
+
"language": language,
|
396 |
+
"query": query,
|
397 |
+
"code": code_string,
|
398 |
+
"text": text,
|
399 |
+
}
|
400 |
+
)
|
401 |
+
|
402 |
+
return samples
|
403 |
+
|
404 |
+
|
405 |
+
def _format_training_text(query: str, code: str, language: str) -> str:
|
406 |
+
"""Format query and code into a single training text chunk with markdown-style code blocks."""
|
407 |
+
# Clean up query but preserve internal code formatting
|
408 |
+
query_clean = query.strip()
|
409 |
+
code_clean = code.strip()
|
410 |
+
|
411 |
+
# Create training text with proper markdown format and newline separation
|
412 |
+
# Structure: query + empty line + markdown code block with language
|
413 |
+
return f"{query_clean}\n\n```{language}\n{code_clean}\n```"
|
414 |
+
|
415 |
+
|
416 |
+
def _generate_how_to_query(doc_string: str, func_name: str, language: str) -> str:
|
417 |
+
"""Generate realistic 'how to' queries that developers might actually search for."""
|
418 |
+
# Extract key action words from documentation
|
419 |
+
doc_lower = doc_string.lower()
|
420 |
+
func_lower = func_name.lower()
|
421 |
+
|
422 |
+
# Common developer query patterns
|
423 |
+
if "sort" in doc_lower or "sort" in func_lower:
|
424 |
+
return f"How to sort data in {language}"
|
425 |
+
if "parse" in doc_lower or "parse" in func_lower:
|
426 |
+
return f"How to parse data in {language}"
|
427 |
+
if "convert" in doc_lower or "transform" in doc_lower or "convert" in func_lower:
|
428 |
+
return f"How to convert data in {language}"
|
429 |
+
if "validate" in doc_lower or "check" in doc_lower or "validate" in func_lower:
|
430 |
+
return f"How to validate input in {language}"
|
431 |
+
if "calculate" in doc_lower or "compute" in doc_lower or "calc" in func_lower:
|
432 |
+
return f"How to calculate values in {language}"
|
433 |
+
if "format" in doc_lower or "format" in func_lower:
|
434 |
+
return f"How to format output in {language}"
|
435 |
+
if "filter" in doc_lower or "filter" in func_lower:
|
436 |
+
return f"How to filter data in {language}"
|
437 |
+
if "search" in doc_lower or "find" in doc_lower or "search" in func_lower or "find" in func_lower:
|
438 |
+
return f"How to search through data in {language}"
|
439 |
+
# Use function name for more specific queries
|
440 |
+
if func_name and len(func_name) > 2:
|
441 |
+
# Extract meaningful words from function name
|
442 |
+
func_words = func_name.replace("_", " ").replace("-", " ").strip()
|
443 |
+
if func_words:
|
444 |
+
return f"How to {func_words.lower()} in {language}"
|
445 |
+
# Fallback to more generic query
|
446 |
+
action = doc_string.split()[0] if doc_string.split() else "implement"
|
447 |
+
return f"How to {action.lower()} in {language}"
|
448 |
+
|
449 |
+
|
450 |
+
def _generate_functional_query(doc_string: str, func_name: str) -> str:
|
451 |
+
"""Generate functional requirement queries focusing on what the code accomplishes."""
|
452 |
+
# Clean up documentation to create natural query
|
453 |
+
doc_clean = doc_string.strip().rstrip(".")
|
454 |
+
|
455 |
+
# Transform to question format
|
456 |
+
if doc_clean.startswith(("Returns", "Return")):
|
457 |
+
return f"Function that {doc_clean.lower()}"
|
458 |
+
if doc_clean.startswith(("Creates", "Create")):
|
459 |
+
return f"Code to {doc_clean.lower()}"
|
460 |
+
if doc_clean.startswith(("Checks", "Check")):
|
461 |
+
return f"Function to {doc_clean.lower()}"
|
462 |
+
|
463 |
+
# Use function name to enhance the query if available
|
464 |
+
if func_name and len(func_name) > 2:
|
465 |
+
func_words = func_name.replace("_", " ").replace("-", " ").strip()
|
466 |
+
if func_words and len(doc_clean) < 30: # Only for short docs
|
467 |
+
return f"Function named '{func_name}' that {doc_clean.lower()}"
|
468 |
+
|
469 |
+
return f"Implementation that {doc_clean.lower()}"
|
470 |
+
|
471 |
+
|
472 |
+
def _generate_implementation_query(doc_string: str, func_name: str, language: str) -> str:
|
473 |
+
"""Generate implementation-specific queries with technical details."""
|
474 |
+
doc_lower = doc_string.lower()
|
475 |
+
func_lower = func_name.lower() if func_name else ""
|
476 |
+
|
477 |
+
# Add language-specific implementation details
|
478 |
+
if language == "python":
|
479 |
+
if "list" in doc_lower or "array" in doc_lower or "list" in func_lower:
|
480 |
+
return f"Python function to {doc_string.lower()} using lists"
|
481 |
+
if "dict" in doc_lower or "hash" in doc_lower or "dict" in func_lower:
|
482 |
+
return f"Python function to {doc_string.lower()} using dictionaries"
|
483 |
+
# Include function name for context if available
|
484 |
+
if func_name and len(func_name) > 2:
|
485 |
+
return f"Python implementation of {func_name}: {doc_string.lower()}"
|
486 |
+
return f"Python implementation: {doc_string.lower()}"
|
487 |
+
if language == "java":
|
488 |
+
func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
|
489 |
+
return f"Java method to {doc_string.lower()}{func_suffix}"
|
490 |
+
if language == "javascript":
|
491 |
+
func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
|
492 |
+
return f"JavaScript function to {doc_string.lower()}{func_suffix}"
|
493 |
+
if language == "php":
|
494 |
+
func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
|
495 |
+
return f"PHP function to {doc_string.lower()}{func_suffix}"
|
496 |
+
if language == "ruby":
|
497 |
+
func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
|
498 |
+
return f"Ruby method to {doc_string.lower()}{func_suffix}"
|
499 |
+
if language == "go":
|
500 |
+
func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
|
501 |
+
return f"Go function to {doc_string.lower()}{func_suffix}"
|
502 |
+
return f"{language} code to {doc_string.lower()}"
|
503 |
+
|
504 |
+
|
505 |
+
def _create_stratified_splits(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
|
506 |
+
"""Create stratified train/test splits preserving language distribution."""
|
507 |
+
# Define split ratios
|
508 |
+
train_ratio = 0.9
|
509 |
+
# test_ratio = 0.1 (remainder)
|
510 |
+
|
511 |
+
train_dfs = []
|
512 |
+
test_dfs = []
|
513 |
+
|
514 |
+
# Split by language to ensure balanced representation
|
515 |
+
for language in df["language"].unique():
|
516 |
+
lang_df = df[df["language"] == language].copy()
|
517 |
+
n_samples = len(lang_df)
|
518 |
+
|
519 |
+
# Calculate split sizes
|
520 |
+
n_train = int(n_samples * train_ratio)
|
521 |
+
# Remainder goes to test
|
522 |
+
|
523 |
+
# Shuffle and split
|
524 |
+
lang_df = lang_df.sample(frac=1, random_state=42).reset_index(drop=True)
|
525 |
+
|
526 |
+
train_dfs.append(lang_df[:n_train])
|
527 |
+
test_dfs.append(lang_df[n_train:])
|
528 |
+
|
529 |
+
# Combine and shuffle again
|
530 |
+
train_df = pd.concat(train_dfs, ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
|
531 |
+
test_df = pd.concat(test_dfs, ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
|
532 |
+
|
533 |
+
logger.info("📊 Created stratified splits:")
|
534 |
+
logger.info(f" - Train: {len(train_df)} samples")
|
535 |
+
logger.info(f" - Test: {len(test_df)} samples")
|
536 |
+
|
537 |
+
return train_df, test_df
|
538 |
+
|
539 |
+
|
540 |
+
def _save_datasets(
|
541 |
+
output_dir: Path,
|
542 |
+
train_df: pd.DataFrame,
|
543 |
+
test_df: pd.DataFrame,
|
544 |
+
) -> dict[str, str]:
|
545 |
+
"""Save datasets in parquet format with compression."""
|
546 |
+
dataset_files = {}
|
547 |
+
|
548 |
+
# Save each split
|
549 |
+
for split_name, df in [("train", train_df), ("test", test_df)]:
|
550 |
+
filepath = output_dir / f"{split_name}.parquet"
|
551 |
+
df.to_parquet(
|
552 |
+
filepath,
|
553 |
+
compression="snappy",
|
554 |
+
index=False,
|
555 |
+
)
|
556 |
+
dataset_files[split_name] = str(filepath)
|
557 |
+
logger.info(f"💾 Saved {split_name}: {len(df)} samples → {filepath}")
|
558 |
+
|
559 |
+
# Also save a combined dataset for convenience
|
560 |
+
combined_df = pd.concat([train_df, test_df], ignore_index=True)
|
561 |
+
combined_filepath = output_dir / "combined.parquet"
|
562 |
+
combined_df.to_parquet(combined_filepath, compression="snappy", index=False)
|
563 |
+
dataset_files["combined"] = str(combined_filepath)
|
564 |
+
logger.info(f"💾 Saved combined: {len(combined_df)} samples → {combined_filepath}")
|
565 |
+
|
566 |
+
return dataset_files
|
567 |
+
|
568 |
+
|
569 |
+
def load_optimized_dataset(
|
570 |
+
output_dir: Path | None = None,
|
571 |
+
split: str = "train",
|
572 |
+
) -> pd.DataFrame:
|
573 |
+
"""
|
574 |
+
Load a previously created optimized dataset.
|
575 |
+
|
576 |
+
Args:
|
577 |
+
output_dir: Directory containing the dataset files
|
578 |
+
        split: Which split to load ('train', 'test', 'combined')

    Returns:
        DataFrame with the requested dataset split
    """
    if output_dir is None:
        output_dir = DATASET_OUTPUT_DIR

    filepath = output_dir / f"{split}.parquet"

    if not filepath.exists():
        available_files = list(output_dir.glob("*.parquet"))
        available_splits = [f.stem for f in available_files]
        msg = f"Dataset split '{split}' not found at {filepath}. Available splits: {available_splits}"
        raise FileNotFoundError(msg)

    logger.info(f"📂 Loading {split} dataset from {filepath}")
    df = pd.read_parquet(filepath)
    logger.info(f"✅ Loaded {len(df)} samples")

    return df


def main(
    max_samples_per_lang: Annotated[
        int, typer.Option(help="Maximum samples per language")
    ] = DEFAULT_MAX_SAMPLES_PER_LANG,
    min_doc_words: Annotated[int, typer.Option(help="Minimum words in documentation")] = DEFAULT_MIN_DOC_WORDS,
    max_doc_words: Annotated[int, typer.Option(help="Maximum words in documentation")] = DEFAULT_MAX_DOC_WORDS,
    min_code_chars: Annotated[int, typer.Option(help="Minimum characters in code")] = DEFAULT_MIN_CODE_CHARS,
    max_code_chars: Annotated[int, typer.Option(help="Maximum characters in code")] = DEFAULT_MAX_CODE_CHARS,
    output_dir: Annotated[str | None, typer.Option(help="Output directory for dataset")] = None,
    simple_format: Annotated[
        bool, typer.Option(help="Create only simple format (not multiple training formats)")
    ] = False,
) -> None:
    """Create optimized training dataset from CodeSearchNet for code search tasks."""
    logger.info("🚀 Starting optimized dataset creation command...")

    # Convert output_dir to Path if provided
    output_path = Path(output_dir) if output_dir else None

    # Create the dataset
    try:
        metadata = create_optimized_dataset(
            max_samples_per_lang=max_samples_per_lang,
            min_doc_words=min_doc_words,
            max_doc_words=max_doc_words,
            min_code_chars=min_code_chars,
            max_code_chars=max_code_chars,
            output_dir=output_path,
            create_multiple_formats=not simple_format,
        )

        logger.info("✅ Dataset creation completed successfully!")
        logger.info(f"📁 Output directory: {metadata['files']['train']}")

        # Print summary statistics
        print("\n" + "=" * 60)
        print("📊 DATASET CREATION SUMMARY")
        print("=" * 60)
        print(f"Total samples created: {metadata['total_samples']:,}")
        print(f"Processing time: {metadata['processing_time']:.2f} seconds")
        print("\nSplit distribution:")
        print(f"  • Train: {metadata['train_samples']:,} samples")
        print(f"  • Test: {metadata['test_samples']:,} samples")

        print("\nLanguage distribution:")
        for lang, stats in metadata["language_stats"].items():
            if "error" not in stats:
                print(f"  • {lang}: {stats['final_samples']:,} samples ({stats['quality_rate']:.1%} quality rate)")

        print(f"\nDataset files saved to: {output_path or DATASET_OUTPUT_DIR}")
        print("=" * 60)

    except Exception as e:
        logger.exception("❌ Dataset creation failed")
        raise typer.Exit(1) from e


if __name__ == "__main__":
    typer.run(main)
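The splits written by this command are plain Parquet files, so they can be sanity-checked outside the distiller package as well. A minimal sketch, assuming the default `code_model2vec/dataset` output directory and the `text`/`language` columns produced by `create_optimized_dataset`:

```python
from pathlib import Path

import pandas as pd

# Default output location of the dataset command above (adjust if --output-dir was used).
dataset_dir = Path("code_model2vec/dataset")

# Each split is saved as <split>.parquet, mirroring load_optimized_dataset().
train_df = pd.read_parquet(dataset_dir / "train.parquet")

print(f"{len(train_df)} training samples")
print(train_df["language"].value_counts())  # per-language sample counts
print(train_df["text"].iloc[0][:200])       # formatted query + code text used for training
```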
src/distiller/distill.py
CHANGED
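The hunks below replace the runtime patching and subprocess calls with direct calls into vendored `distiller.model2vec` / `distiller.tokenlearn` modules, and they keep the POTION-style post-training step (PCA followed by SIF re-weighting of the token embeddings). As a reminder of what that last step computes, SIF weighting in its usual form scales each token vector by a/(a + p(token)); a minimal sketch, where the smoothing constant 1e-3 and the toy counts are illustrative assumptions rather than values taken from this repository:

```python
import numpy as np

# Toy token frequencies (corpus counts) and token embeddings; both are illustrative.
token_counts = np.array([5000, 120, 3, 40], dtype=np.float64)
embeddings = np.random.default_rng(0).normal(size=(4, 256)).astype(np.float32)

# SIF weight per token: a / (a + p(token)), with p the relative frequency.
a = 1e-3
p = token_counts / token_counts.sum()
sif_weights = a / (a + p)

# Re-weight the embedding rows; frequent tokens are down-weighted.
reweighted = embeddings * sif_weights[:, None].astype(np.float32)
print(reweighted.shape)
```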
@@ -28,13 +28,14 @@ import time
from pathlib import Path
from typing import Annotated, Any

+import numpy as np
import torch
import typer
from beam import function
-from datasets import load_dataset
-from model2vec.distill import distill
from sentence_transformers import SentenceTransformer

+from distiller.model2vec.distill import distill
+
# Try to import flash_attn to check if it's available
from .beam_utils import (
    BeamCheckpointManager,

@@ -145,25 +146,6 @@ def load_model_with_flash_attention(model_path: str, device: str = "auto") -> Se
# =============================================================================


-def apply_local_patches() -> bool:
-    """Apply patches locally without requiring Beam utilities."""
-    try:
-        try:
-            from .patch_utils import apply_all_patches
-
-            patches_applied = apply_all_patches()
-            logger.info(f"Successfully applied {patches_applied} patches via patch_utils")
-            return True
-        except ImportError:
-            logger.warning("patch_utils not available, trying direct patching")
-
-            return False
-
-    except Exception as e:
-        logger.warning(f"Failed to apply patches: {e}")
-        return False
-
-
def get_current_config_hash(enable_training: bool) -> str:
    """Generate a hash of current configuration parameters for checkpoint validation."""
    import hashlib

@@ -217,22 +199,22 @@ def check_existing_final_model(teacher_name: str, enable_training: bool = False)
    model_name = f"code_model2vec_{teacher_name}"
    if enable_training:
        model_name += "_fine_tuned"
+    final_path = final_dir / model_name

-    if
+    if final_path.exists():
        # Check for essential model files
-        has_config = (
+        has_config = (final_path / "config.json").exists()
        has_model_file = any(
            [
-                (
-                (
-                (
+                (final_path / "model.safetensors").exists(),
+                (final_path / "model.bin").exists(),
+                (final_path / "pytorch_model.bin").exists(),
            ]
        )

        if has_config and has_model_file:
            logger.info(f"✅ Found existing final model: {teacher_name}{'_fine_tuned' if enable_training else ''}")
-            return str(
+            return str(final_path)

    return None

@@ -427,11 +409,65 @@ def simple_distillation(
    return None


+def load_optimized_dataset(
+    max_samples: int = 50000,
+    checkpoint_manager: BeamCheckpointManager | None = None,
+    dataset_path: str | None = None,
+) -> list[str]:
+    """Load our pre-created optimized dataset for tokenlearn training."""
+    from .dataset import DATASET_OUTPUT_DIR
+    from .dataset import load_optimized_dataset as load_dataset_func
+
+    # Use configuration if not provided as parameter
+    if dataset_path is None:
+        dataset_path = distillation_config.custom_dataset_path
+
+    dataset_dir = Path(dataset_path) if dataset_path else DATASET_OUTPUT_DIR
+
+    logger.info(f"🎯 Loading optimized dataset from {dataset_dir}")
+    logger.info(f"📊 Target samples: {max_samples}")
+
+    try:
+        # Load the training split of our optimized dataset
+        df = load_dataset_func(output_dir=dataset_dir, split="train")
+
+        # Extract the text column (which contains our formatted query + code)
+        texts = df["text"].tolist()
+
+        # Shuffle for better training distribution
+        import random
+
+        random.seed(42)
+        random.shuffle(texts)
+
+        # Limit to max_samples
+        if len(texts) > max_samples:
+            texts = texts[:max_samples]
+
+        logger.info(f"✅ Loaded {len(texts)} optimized training samples")
+
+        # Log language distribution
+        languages = df["language"].value_counts()
+        logger.info("📊 Language distribution:")
+        for lang, count in languages.items():
+            percentage = (count / len(df)) * 100
+            logger.info(f"  {lang}: {count} samples ({percentage:.1f}%)")
+
+        return texts
+
+    except Exception as e:
+        logger.warning(f"⚠️ Failed to load optimized dataset: {e}")
+        logger.info("🔄 Falling back to original CodeSearchNet loading...")
+        return load_codesearchnet_dataset(max_samples, checkpoint_manager)
+
+
def load_codesearchnet_dataset(
    max_samples: int = 50000,
    checkpoint_manager: BeamCheckpointManager | None = None,
) -> list[str]:
    """Load and format the CodeSearchNet dataset for token frequency computation."""
+    from datasets import load_dataset
+
    logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
    logger.info(f"Limiting to {max_samples} samples for training efficiency")
    logger.info(f"Languages: {', '.join(languages_config.all)}")

@@ -482,6 +518,8 @@ def load_codesearchnet_dataset(

        try:
            # Load training split for the specific language (same format as evaluate.py)
+            from datasets import load_dataset
+
            dataset = load_dataset(
                codesearchnet_config.dataset_name,
                language,

@@ -709,8 +747,33 @@ def compute_token_frequencies_for_sif(
    logger.info("📊 Computing token frequencies for SIF weighting...")

    try:
-        # Load
+        # Load dataset to compute frequencies (limited sample for efficiency)
+        if distillation_config.use_optimized_dataset:
+            # Use the custom optimized dataset
+            from .dataset import load_optimized_dataset as load_custom_dataset
+
+            custom_dataset_dir = (
+                Path(distillation_config.custom_dataset_path)
+                if distillation_config.custom_dataset_path
+                else Path("code_model2vec/dataset")
+            )
+
+            if custom_dataset_dir.exists() and (custom_dataset_dir / "train.parquet").exists():
+                train_df = load_custom_dataset(output_dir=custom_dataset_dir, split="train")
+                # Sample a subset for frequency computation
+                sample_size = min(10000, len(train_df))
+                train_df_sample = train_df.sample(n=sample_size, random_state=42)
+                dataset_texts = train_df_sample["text"].tolist()
+                logger.info(f"📊 Using {len(dataset_texts)} samples from custom optimized dataset")
+            else:
+                # Fallback to original dataset loading
+                dataset_texts = load_codesearchnet_dataset(max_samples=10000)
+                logger.info(
+                    f"📊 Custom dataset not found, using original CodeSearchNet with {len(dataset_texts)} texts"
+                )
+        else:
+            dataset_texts = load_codesearchnet_dataset(max_samples=10000)
+            logger.info(f"📊 Using original CodeSearchNet with {len(dataset_texts)} texts")

        logger.info(f"📊 Computing frequencies on {len(dataset_texts)} texts...")

@@ -763,7 +826,6 @@ def apply_post_training_regularization(
    """
    import json

-    import numpy as np
    from sklearn.decomposition import PCA

    logger.info("🔧 Starting post-training re-regularization (POTION Step 4)")

@@ -836,7 +898,7 @@
    final_embeddings = embeddings_pca.astype(np.float32)

    # Create new model with updated embeddings
-    from model2vec.model import StaticModel
+    from distiller.model2vec.model import StaticModel

    # Save tokenizer and config from original model
    tokenizer = model.tokenizer

@@ -866,7 +928,6 @@ def tokenlearn_training(
    3. Tokenlearn training
    4. Post-training re-regularization (PCA + SIF weighting)
    """
-    import subprocess
    from pathlib import Path

    logger.info("🧪 Starting tokenlearn training (POTION approach)...")

@@ -914,6 +975,9 @@

    logger.info(f"📊 Using teacher model: {teacher_model_name}")

+    # Prepare dataset for tokenlearn featurization
+    dataset_path, dataset_name, text_key = _prepare_tokenlearn_dataset(persistent_tokenlearn_dir)
+
    # Check if featurization already completed (checkpoint detection)
    featurization_complete_marker = features_dir / ".featurization_complete"
    if featurization_complete_marker.exists() and verify_featurization_output(features_dir):

@@ -936,47 +1000,42 @@
    logger.info(f"📊 Using teacher model: {teacher_model_name}")

    try:
-        # Use
-            "tokenlearn.featurize",
-            "--model-name",
-            str(teacher_model_name),
-            "--output-dir",
-            str(features_dir),
-            "--dataset-path",
-            str(distillation_config.tokenlearn_dataset),
-            "--dataset-name",
-            str(distillation_config.tokenlearn_dataset_name),
-            "--dataset-split",
-            "train",
-            "--key",
-            str(distillation_config.tokenlearn_text_key),  # Use configured text field
-            "--batch-size",
-            "1024",  # Optimized batch size for A100-40G
-        ]
+        # Use direct function call instead of subprocess
+        from datasets import load_dataset
+
+        from distiller.tokenlearn.featurize import featurize

        logger.info("🔄 Running tokenlearn featurization...")
-        logger.info(
-        )
-        logger.info(f"📝 Text field: {distillation_config.tokenlearn_text_key}")
-        logger.info(f"Command: {' '.join(featurize_cmd)}")
-        print(f"\n🔄 Executing: {' '.join(featurize_cmd)}\n")
-        result = subprocess.run(  # noqa: S603
-            featurize_cmd,
-            text=True,
-            timeout=distillation_config.tokenlearn_timeout_featurize,
-            check=False,
-        )
+        logger.info(f"📊 Dataset: {dataset_path} (config: {dataset_name})")
+        logger.info(f"📝 Text field: {text_key}")

+        # Load the dataset
+        if dataset_name is None:
+            # For local JSON files, don't pass name parameter
+            dataset = load_dataset(
+                "json",
+                data_files=dataset_path,
+                split="train",
+                streaming=True,
+            )
+        else:
+            # For remote datasets with specific configurations
+            dataset = load_dataset(
+                dataset_path,
+                name=dataset_name,
+                split="train",
+                streaming=True,
+            )
+
+        # Call featurization function directly
+        featurize(
+            dataset=iter(dataset),
+            model=teacher_model,
+            output_dir=str(features_dir),
+            max_means=50000,  # IMPROVEMENT: Limit means to prevent overfitting
+            batch_size=512,  # IMPROVEMENT: Smaller batch for better gradients
+            text_key=text_key,
+        )

        logger.info("✅ Featurization completed successfully")

@@ -1025,65 +1084,74 @@
    logger.info("🔄 No valid training checkpoint found - starting training...")

    try:
-            "tokenlearn.train",
-            "--model-name",
-            str(teacher_model_name),
-            "--data-path",
-            str(features_dir),
-            "--save-path",
-            str(trained_dir),
-        ]
+        # Use direct function call instead of subprocess
+        from distiller.tokenlearn.train import train_model
+        from distiller.tokenlearn.utils import collect_means_and_texts

-        logger.info(
+        # IMPROVED APPROACH: Try optimized parameters first
+        logger.info("🚀 Attempting IMPROVED tokenlearn training with optimized parameters...")
+        logger.info("📊 Using smaller vocabulary and conservative PCA to prevent overfitting")

+        # Collect training data from features directory
+        paths = sorted(features_dir.glob("*.json"))
+        train_txt, train_vec = collect_means_and_texts(paths)
+
+        logger.info(f"📊 Collected {len(train_txt)} texts and {train_vec.shape[0]} vectors for training")
+
+        try:
+            # Try improved parameters first
+            trained_model = train_model(
+                model_name=str(teacher_model_name),
+                train_txt=train_txt,
+                train_vec=train_vec,
+                device="cuda" if torch.cuda.is_available() else "cpu",
+                vocab_size=25000,  # IMPROVEMENT: Smaller vocabulary to prevent overfitting
+                pca_dims=256,  # IMPROVEMENT: Conservative PCA dimensions
+            )

+            # Save the trained model
+            trained_model.save_pretrained(str(trained_dir))
+            logger.info("✅ IMPROVED tokenlearn training completed successfully")
+            training_complete_marker.touch()
+            logger.info(f"💾 Created improved training checkpoint: {training_complete_marker}")

-        if result.stdout:
-            logger.info(f"stdout: {result.stdout}")
+        except Exception as e:
+            logger.warning(f"⚠️ Improved training failed: {e}")
+            logger.info("🔄 Falling back to CONSERVATIVE tokenlearn training...")

-        #
+            # FALLBACK: Ultra-conservative training approach
+            try:
+                trained_model = train_model(
+                    model_name=str(teacher_model_name),
+                    train_txt=train_txt,
+                    train_vec=train_vec,
+                    device="cuda" if torch.cuda.is_available() else "cpu",
+                    vocab_size=15000,  # FALLBACK: Even smaller vocabulary
+                    pca_dims=128,  # FALLBACK: Smaller PCA dimensions
+                )
+
+                # Save the trained model
+                trained_model.save_pretrained(str(trained_dir))
+                logger.info("✅ Conservative tokenlearn training completed successfully")
+                training_complete_marker.touch()
+                logger.info(f"💾 Created conservative training checkpoint: {training_complete_marker}")
+
+            except Exception as e2:
+                logger.exception("❌ Conservative tokenlearn training also failed")
+                logger.exception("💥 All training approaches failed - check output above for details")

                # Create training marker to indicate we tried but failed
                training_fallback_marker = trained_dir / ".training_fallback"
                training_fallback_marker.touch()

-            logger.
-            msg = f"
-            raise RuntimeError(msg)
-        logger.error("💥 Tokenlearn training failed with different error")
-        msg = f"Tokenlearn training failed with return code: {result.returncode}"
-        raise RuntimeError(msg)
-        logger.info("✅ Tokenlearn training completed successfully")
-        # Create checkpoint marker to indicate training is complete
-        training_complete_marker.touch()
-        logger.info(f"💾 Created training checkpoint: {training_complete_marker}")
+                logger.exception("💥 Tokenlearn training failed completely")
+                msg = f"All tokenlearn training approaches failed: {e2}"
+                raise RuntimeError(msg) from e2

    except Exception as e:
-        logger.
-        logger.exception("💥
-        msg = f"
+        logger.warning("💥 All tokenlearn training approaches failed")
+        logger.exception("💥 All training approaches failed completely - cannot proceed")
+        msg = f"All training approaches failed: {e}"
        raise RuntimeError(msg) from e

    # Step 4: Load the trained model and apply post-training re-regularization

@@ -1098,7 +1166,7 @@
    raise RuntimeError(msg)

    try:
-        from model2vec.model import StaticModel
+        from distiller.model2vec.model import StaticModel

        # Load the trained model from tokenlearn
        trained_model_path = trained_dir / "model"

@@ -1213,12 +1281,13 @@ def distill_single_teacher(
    existing_final = check_existing_final_model(teacher_name, enable_training)
    if existing_final:
        logger.info(f"✅ Final model already exists: {teacher_name}{'_fine_tuned' if enable_training else ''}")
+        total_time = time.time() - start_time
        return {
            "teacher_model": teacher_model,
            "teacher_name": teacher_name,
            "status": "skipped_existing_final",
            "final_path": existing_final,
-            "distillation_time":
+            "distillation_time": total_time,
        }

    # Step 1.5: Sync existing checkpoints from Beam if using Beam utilities

@@ -1236,7 +1305,7 @@ def distill_single_teacher(
        logger.info(f"✅ Found existing base model: {teacher_name}")
        if enable_training:
            # Load base model for training
-            from model2vec.model import StaticModel
+            from distiller.model2vec.model import StaticModel

            base_model = StaticModel.from_pretrained(existing_base)
    elif use_beam_utilities:

@@ -1244,7 +1313,7 @@
        if synced:
            existing_base = str(base_dir)
            if enable_training:
-                from model2vec.model import StaticModel
+                from distiller.model2vec.model import StaticModel

                base_model = StaticModel.from_pretrained(existing_base)

@@ -1263,11 +1332,13 @@ def distill_single_teacher(
        base_model = simple_distillation(teacher_model, str(base_dir), pca_dims)

        if base_model is None:
+            total_time = time.time() - start_time
            return {
                "teacher_model": teacher_model,
                "teacher_name": teacher_name,
                "status": "failed_base_distillation",
                "error": "Simple distillation failed",
+                "distillation_time": total_time,
            }

    # Sync base model and checkpoints to Beam

@@ -1280,71 +1351,74 @@ def distill_single_teacher(

        existing_base = str(base_dir)

+    # Step 3: Handle final model creation
+    if enable_training and base_model is not None:
+        # Perform tokenlearn training (POTION approach)
+        logger.info(f"🧪 Starting tokenlearn training for {teacher_name}")
+
+        try:
+            # Load teacher model for training
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            teacher_st_model = load_model_with_flash_attention(teacher_model, device)
+
            # Perform tokenlearn training (POTION approach)
+            final_model = tokenlearn_training(
+                base_model,
+                teacher_st_model,
+                checkpoint_mgr,
+                skip_post_training_regularization=distillation_config.skip_post_training_regularization,
+            )

-        teacher_st_model = load_model_with_flash_attention(teacher_model, device)
-        # Perform tokenlearn training (POTION approach)
-        final_model = tokenlearn_training(
-            base_model,
-            teacher_st_model,
-            checkpoint_mgr,
-            skip_post_training_regularization=distillation_config.skip_post_training_regularization,
-        )
+            # Save final model
+            final_dir.mkdir(parents=True, exist_ok=True)
+            final_model.save_pretrained(str(final_dir))

+            # Sync final model and training checkpoints to Beam
+            if use_beam_utilities:
+                sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities)
+                if checkpoint_mgr:
+                    sync_checkpoints_to_beam(
+                        VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints
+                    )

-        if checkpoint_mgr:
-            sync_checkpoints_to_beam(
-                VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints
-            )
+            del teacher_st_model
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()

+        except RuntimeError as e:
+            # Training failed - clean up and return failure
+            logger.exception(f"❌ Training failed for {teacher_name}")
+
+            # Clean up teacher model if it was loaded
+            if "teacher_st_model" in locals():
                del teacher_st_model
-    except RuntimeError as e:
-        # Training failed - clean up and return failure
-        logger.exception(f"❌ Training failed for {teacher_name}")
-        # Clean up teacher model if it was loaded
-        if "teacher_st_model" in locals():
-            del teacher_st_model
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        return {
-            "teacher_model": teacher_model,
-            "teacher_name": teacher_name,
-            "status": "failed_training",
-            "error": f"Training failed: {e!s}",
-            "base_path": existing_base,  # Base model was created successfully
-        }
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()

-        }
+            total_time = time.time() - start_time
+            return {
+                "teacher_model": teacher_model,
+                "teacher_name": teacher_name,
+                "status": "failed_training",
+                "error": f"Training failed: {e!s}",
+                "base_path": existing_base,  # Base model was created successfully
+                "distillation_time": total_time,
+            }

+    else:
+        # Copy base to final (no training)
+        logger.info(f"📁 Copying base to final for {teacher_name}")
+        if not copy_base_to_final(teacher_name, enable_training):
+            total_time = time.time() - start_time
+            return {
+                "teacher_model": teacher_model,
+                "teacher_name": teacher_name,
+                "status": "failed_copy_to_final",
+                "error": "Failed to copy base to final",
+                "distillation_time": total_time,
+            }

+    total_time = time.time() - start_time
    return {
        "teacher_model": teacher_model,
        "teacher_name": teacher_name,

@@ -1357,11 +1431,13 @@ def distill_single_teacher(

    except Exception as e:
        logger.exception(f"❌ Failed to process {teacher_model}")
+        total_time = time.time() - start_time
        return {
            "teacher_model": teacher_model,
            "teacher_name": teacher_name,
            "status": "failed",
            "error": str(e),
+            "distillation_time": total_time,
        }


@@ -1382,13 +1458,6 @@ def run_local_distillation(
    if teacher_models is None:
        teacher_models = DEFAULT_TEACHER_MODELS

-    # Apply patches
-    patch_success = apply_local_patches()
-    if patch_success:
-        logger.info("✅ Successfully applied patches")
-    else:
-        logger.warning("⚠️ Failed to apply patches - some models may fail")
-
    results = {}
    successful_models = []

@@ -1468,13 +1537,6 @@ def _beam_distill_internal(
    clear_cache: bool = False,
) -> dict[str, Any]:
    """Shared internal implementation for beam distillation."""
-    # Apply patches
-    patch_success = apply_local_patches()
-    if patch_success:
-        logger.info("✅ Successfully applied patches")
-    else:
-        logger.warning("⚠️ Failed to apply patches - some models may fail")
-
    if teacher_models is None:
        teacher_models = DEFAULT_TEACHER_MODELS

@@ -1647,6 +1709,16 @@ def main(
    skip_ptr: Annotated[
        bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
    ] = False,
+    use_optimized_dataset: Annotated[
+        bool,
+        typer.Option(
+            "--use-optimized-dataset", help="Use the pre-created optimized dataset from code_model2vec/dataset"
+        ),
+    ] = False,
+    dataset_path: Annotated[
+        str | None,
+        typer.Option("--dataset-path", help="Path to custom dataset directory (defaults to code_model2vec/dataset)"),
+    ] = None,
) -> None:
    """Unified distillation command with optional training."""
    logger.info("🚀 Starting unified Model2Vec distillation workflow")

@@ -1656,6 +1728,13 @@ def main(
    if skip_ptr and train:
        logger.info("⏭️ Post-training re-regularization will be skipped (PCA + SIF weighting disabled)")

+    # Set dataset configuration
+    distillation_config.use_optimized_dataset = use_optimized_dataset
+    distillation_config.custom_dataset_path = dataset_path
+    if use_optimized_dataset and train:
+        dataset_source = dataset_path or "code_model2vec/dataset"
+        logger.info(f"🎯 Using optimized dataset from: {dataset_source}")
+
    logger.info(f"🎓 Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
    logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")

@@ -1894,7 +1973,7 @@ def salesforce_model_distillation(
    logger.info("✅ Successfully loaded with SentenceTransformer method")

    # Now use Model2Vec's distill_from_model function directly
-    from model2vec.distill.distillation import distill_from_model
+    from distiller.model2vec.distill.distillation import distill_from_model

    distilled_model = distill_from_model(
        model=model,

@@ -2004,7 +2083,7 @@ def baai_bge_model_distillation(
    return None

    # Now use Model2Vec's distill_from_model function directly
-    from model2vec.distill.distillation import distill_from_model
+    from distiller.model2vec.distill.distillation import distill_from_model

    distilled_model = distill_from_model(
        model=model,

@@ -2090,5 +2169,77 @@ def verify_training_output(trained_dir: Path) -> bool:
    return False


+def _prepare_tokenlearn_dataset(tokenlearn_dir: Path) -> tuple[str, str | None, str]:
+    """
+    Prepare dataset for tokenlearn featurization.
+
+    Returns:
+        Tuple of (dataset_path, dataset_name, text_key) for tokenlearn
+    """
+    if distillation_config.use_optimized_dataset:
+        return _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir)
+    return _prepare_original_dataset_for_tokenlearn()
+
+
+def _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir: Path) -> tuple[str, str | None, str]:
+    """Prepare custom optimized dataset for tokenlearn featurization."""
+    logger.info("🎯 Preparing custom optimized dataset for tokenlearn...")
+
+    # Import the dataset module
+    from .dataset import create_optimized_dataset, load_optimized_dataset
+
+    # Define paths
+    custom_dataset_dir = (
+        Path(distillation_config.custom_dataset_path)
+        if distillation_config.custom_dataset_path
+        else Path("code_model2vec/dataset")
+    )
+    tokenlearn_dataset_dir = tokenlearn_dir / "custom_dataset"
+
+    # Check if we need to create the custom dataset
+    if not custom_dataset_dir.exists() or not (custom_dataset_dir / "train.parquet").exists():
+        logger.info("📊 Custom dataset not found - creating optimized dataset...")
+        create_optimized_dataset(
+            max_samples_per_lang=10000,  # Reasonable size for tokenlearn
+            output_dir=custom_dataset_dir,
+            create_multiple_formats=False,  # Use simple format for tokenlearn
+        )
+
+    # Load the custom dataset
+    logger.info(f"📂 Loading custom dataset from {custom_dataset_dir}")
+    train_df = load_optimized_dataset(output_dir=custom_dataset_dir, split="train")
+
+    # Prepare dataset for tokenlearn (save as JSON files that load_dataset can read)
+    tokenlearn_dataset_dir.mkdir(parents=True, exist_ok=True)
+
+    # Save as JSON file that tokenlearn can load with load_dataset()
+    train_json_path = tokenlearn_dataset_dir / "train.json"
+
+    # Create JSON lines format
+    import json
+
+    with train_json_path.open("w") as f:
+        for text in train_df["text"]:
+            json.dump({"text": text}, f)
+            f.write("\n")
+
+    logger.info(f"✅ Prepared custom dataset with {len(train_df)} samples for tokenlearn")
+    logger.info(f"💾 Saved JSON dataset to {train_json_path}")
+
+    # Return the JSON file path directly (not directory) and no config name for JSON loading
+    return str(train_json_path), None, "text"
+
+
+def _prepare_original_dataset_for_tokenlearn() -> tuple[str, str, str]:
+    """Prepare original CodeSearchNet dataset for tokenlearn featurization."""
+    logger.info("📊 Using original CodeSearchNet dataset for tokenlearn...")
+
+    return (
+        str(distillation_config.tokenlearn_dataset),  # "sentence-transformers/codesearchnet"
+        str(distillation_config.tokenlearn_dataset_name),  # "pair"
+        str(distillation_config.tokenlearn_text_key),  # "combined_text"
+    )
+
+
if __name__ == "__main__":
    typer.run(main)
|
src/distiller/patch_utils.py
DELETED
@@ -1,276 +0,0 @@
"""
Patch utilities for applying fixes to installed packages.

This module provides functionality to automatically apply all patches
from the patches directory to fix bugs in third-party libraries.
"""

import logging
import subprocess
import sys
from pathlib import Path

logger = logging.getLogger(__name__)


def find_patches_directory() -> Path:
    """Find the patches directory relative to the current script location."""
    # Go up from src/distiller/ to project root, then to patches/
    current_file = Path(__file__)
    project_root = current_file.parent.parent.parent  # Go up 3 levels: distiller -> src -> project_root
    patches_dir = project_root / "patches"

    if not patches_dir.exists():
        # Alternative: try relative to current working directory
        patches_dir = Path("patches")

    return patches_dir


def get_site_packages_path() -> Path:
    """Get the site-packages directory path."""
    import site

    # Try to get the site-packages from the current environment
    site_packages_dirs = site.getsitepackages()

    # Prefer the first site-packages directory
    if site_packages_dirs:
        return Path(site_packages_dirs[0])

    # Fallback: try to find it relative to Python executable
    python_path = Path(sys.executable)
    if python_path.name == "python" or python_path.name.startswith("python"):
        # Standard virtual environment structure
        venv_lib = python_path.parent.parent / "lib"
        for item in venv_lib.iterdir():
            if item.name.startswith("python"):
                site_packages = item / "site-packages"
                if site_packages.exists():
                    return site_packages

    # Last resort: use current directory
    return Path()


def apply_patch_file(patch_file: Path, target_dir: Path) -> bool:
    """
    Apply a single patch file to the target directory.

    Args:
        patch_file: Path to the .patch file
        target_dir: Target directory (usually site-packages)

    Returns:
        True if patch was applied successfully, False otherwise
    """
    try:
        logger.info(f"Applying patch: {patch_file.name}")

        # Check if patch is already applied
        if is_patch_already_applied(patch_file, target_dir):
            logger.info(f"Patch {patch_file.name} already applied")
            return True

        # Clean any duplicate validation code before applying
        if "model2vec.patch" in patch_file.name:
            clean_duplicate_validation_code(target_dir)

        # Use patch command with the following options:
        # -p1: strip 1 leading directory from paths
        # -d: change to directory before applying
        # -f: force (don't ask questions)
        # -N: don't reverse patches that appear to be already applied
        result = subprocess.run(  # noqa: S603
            ["patch", "-p1", "-d", str(target_dir), "-f", "-N"],  # noqa: S607
            input=patch_file.read_text(),
            text=True,
            capture_output=True,
            check=False,  # Don't raise exception on non-zero exit
        )

        if result.returncode == 0:
            logger.info(f"Successfully applied patch: {patch_file.name}")
            return True
        if "already applied" in result.stderr.lower() or "reversed" in result.stderr.lower():
            logger.info(f"Patch {patch_file.name} already applied")
            return True
        logger.warning(f"Failed to apply patch {patch_file.name}: {result.stderr}")
        return False

    except FileNotFoundError:
        logger.exception("'patch' command not found. Please install patch utility.")
        return False
    except Exception:
        logger.exception(f"Error applying patch {patch_file.name}")
        return False


def apply_all_patches() -> int:
    """
    Apply all patches from the patches directory.

    Returns:
        Number of patches successfully applied
    """
    patches_dir = find_patches_directory()

    if not patches_dir.exists():
        logger.warning(f"Patches directory not found: {patches_dir}")
        return 0

    # Find all .patch files
    patch_files = list(patches_dir.glob("*.patch"))

    if not patch_files:
        logger.info("No patch files found")
        return 0

    # Get target directory (site-packages)
    target_dir = get_site_packages_path()
    logger.info(f"Applying patches to: {target_dir}")

    # Clean any existing duplicates first
    clean_duplicate_validation_code(target_dir)

    success_count = 0

    # Sort patch files for consistent ordering
    for patch_file in sorted(patch_files):
        if apply_patch_file(patch_file, target_dir):
            success_count += 1

    logger.info(f"Applied {success_count}/{len(patch_files)} patches successfully")
    return success_count


def is_patch_already_applied(patch_file: Path, target_dir: Path) -> bool:
    """
    Check if a patch has already been applied by looking for specific markers.

    Args:
        patch_file: Path to the .patch file
        target_dir: Target directory (usually site-packages)

    Returns:
        True if patch appears to be already applied, False otherwise
    """
    try:
        # For model2vec.patch, check if the validation code is already present
        if "model2vec.patch" in patch_file.name:
            inference_file = target_dir / "model2vec" / "distill" / "inference.py"
            if inference_file.exists():
                inference_content = inference_file.read_text()
                # Check for the specific validation code we're adding
                if (
                    "Token-vector mismatch:" in inference_content
                    and "Truncating to prevent failure" in inference_content
                ):
                    # Also make sure it's in the right place (before return statement, not after)
                    lines = inference_content.split("\n")
                    for i, line in enumerate(lines):
                        if "return out_tokens, out_weights" in line:
                            # Check if validation code appears before this return
                            preceding_lines = lines[max(0, i - 10) : i]
                            if any("Token-vector mismatch:" in pline for pline in preceding_lines):
                                return True
                            break

        # For tokenlearn.patch, check if the indexing fix is already present
        if "tokenlearn.patch" in patch_file.name:
            pretrain_file = target_dir / "tokenlearn" / "pretrain.py"
            if pretrain_file.exists():
                pretrain_content = pretrain_file.read_text()
                # Check for the specific fix we're adding
                if (
                    "Fix for index out of bounds issue" in pretrain_content
                    and "torch.clamp(input_ids, 0, self.w.shape[0] - 1)" in pretrain_content
                ):
                    return True

        return False

    except Exception as e:
        logger.warning(f"Error checking if patch {patch_file.name} is applied: {e}")
        return False


def clean_duplicate_validation_code(target_dir: Path) -> bool:
    """
    Clean up duplicate validation code that might have been added by multiple patch applications.

    Args:
        target_dir: Target directory (usually site-packages)

    Returns:
        True if cleanup was successful, False otherwise
    """
    try:
        inference_file = target_dir / "model2vec" / "distill" / "inference.py"
        if not inference_file.exists():
            return True

        content = inference_file.read_text()
        lines = content.split("\n")

        # Find all instances of the validation code
        validation_indices = []
        for i, line in enumerate(lines):
            if "Token-vector mismatch:" in line:
                validation_indices.append(i)

        if len(validation_indices) <= 1:
            return True  # No duplicates or no validation code

        # Keep only the validation code that appears before a return statement
        lines_to_keep = []
        skip_until = -1

        for i, line in enumerate(lines):
            if i <= skip_until:
                continue

            # If this is validation code
            if "Token-vector mismatch:" in line:
                # Look ahead to see if there's a return statement nearby
                has_return_after = False
                for j in range(i, min(len(lines), i + 20)):
                    if "return out_tokens, out_weights" in lines[j]:
                        has_return_after = True
                        break

                # Keep this validation block only if it's followed by a return
                if has_return_after:
                    lines_to_keep.append(line)
                else:
                    # Skip this validation block (it's a duplicate)
                    # Find the end of this validation block
                    for j in range(i + 1, len(lines)):
                        if lines[j].strip() == "" or not lines[j].startswith(" "):
                            skip_until = j - 1
                            break
            else:
                lines_to_keep.append(line)

        # Write back the cleaned content
        cleaned_content = "\n".join(lines_to_keep)
        inference_file.write_text(cleaned_content)
        logger.info("Cleaned duplicate validation code from inference.py")
        return True

    except Exception as e:
        logger.warning(f"Error cleaning duplicate validation code: {e}")
        return False


def main() -> None:
    """Main function for standalone execution."""
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    print("Applying all patches...")
    success_count = apply_all_patches()
    print(f"Done. Applied {success_count} patches.")


if __name__ == "__main__":
    main()
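One note on the featurization hand-off introduced in `distill.py` above: `_prepare_custom_dataset_for_tokenlearn` writes one `{"text": ...}` object per line, and the featurization step streams it back with 🤗 `datasets`. A minimal sketch of that round-trip, where the path and sample texts are illustrative rather than the exact values used by the workflow:

```python
import json
from pathlib import Path

from datasets import load_dataset

# Illustrative path; the real code writes <tokenlearn_dir>/custom_dataset/train.json.
train_json_path = Path("custom_dataset/train.json")
train_json_path.parent.mkdir(parents=True, exist_ok=True)

# One JSON object per line, matching the format written above.
with train_json_path.open("w") as f:
    for text in ["find the maximum value in a list", "def find_max(xs): return max(xs)"]:
        json.dump({"text": text}, f)
        f.write("\n")

# Streaming load mirrors the featurization call: no config name for local JSON files.
dataset = load_dataset("json", data_files=str(train_json_path), split="train", streaming=True)
for record in dataset:
    print(record["text"])
```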
|
uv.lock
CHANGED
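The lockfile below drops `model2vec` and `tokenlearn` as external packages and instead lists their former runtime dependencies directly, now that the code imports those libraries from the vendored `distiller.model2vec` / `distiller.tokenlearn` modules. Loading a distilled model is unchanged apart from the import path; a minimal sketch, where the model path is illustrative and `encode` follows the upstream model2vec API that the vendored copy mirrors:

```python
from distiller.model2vec.model import StaticModel

# Illustrative path; final models are written as code_model2vec_<teacher_name> by the workflow above.
model = StaticModel.from_pretrained("code_model2vec/final/code_model2vec_all_mpnet_base_v2")

embeddings = model.encode([
    "def binary_search(arr, target): ...",
    "how do I search a sorted list?",
])
print(embeddings.shape)  # (2, embedding_dim)
```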
@@ -774,24 +774,31 @@ dependencies = [
    { name = "flash-attn" },
    { name = "hatchling" },
    { name = "iso639" },
+    { name = "jinja2" },
+    { name = "joblib" },
    { name = "kaleido" },
    { name = "lightning" },
    { name = "matplotlib" },
-    { name = "
+    { name = "more-itertools" },
    { name = "mteb" },
    { name = "numpy" },
    { name = "plotly" },
    { name = "psutil" },
    { name = "pydantic" },
    { name = "requests" },
+    { name = "rich" },
+    { name = "safetensors" },
    { name = "scikit-learn" },
    { name = "seaborn" },
    { name = "sentence-transformers" },
    { name = "setuptools" },
+    { name = "skops" },
    { name = "smart-open", extra = ["s3"] },
    { name = "statsmodels" },
-    { name = "
+    { name = "tokenizers" },
    { name = "torch" },
+    { name = "tqdm" },
+    { name = "transformers" },
    { name = "typer" },
]

@@ -813,24 +820,31 @@ requires-dist = [
    { name = "flash-attn", specifier = ">=2.7.4.post1" },
    { name = "hatchling", specifier = ">=1.27.0" },
    { name = "iso639", specifier = ">=0.1.4" },
+    { name = "jinja2", specifier = ">=3.0.0" },
+    { name = "joblib", specifier = ">=1.0.0" },
    { name = "kaleido", specifier = "==1.0.0rc13" },
    { name = "lightning", specifier = ">=2.5.1.post0" },
    { name = "matplotlib", specifier = ">=3.10.3" },
-    { name = "
+    { name = "more-itertools", specifier = ">=10.5.0" },
    { name = "mteb", specifier = ">=1.14.15" },
    { name = "numpy", specifier = ">=1.26.4" },
    { name = "plotly", specifier = ">=6.1.1" },
    { name = "psutil", specifier = ">=7.0.0" },
    { name = "pydantic", specifier = ">=2.11.5" },
    { name = "requests", specifier = ">=2.32.3" },
+    { name = "rich", specifier = ">=10.0.0" },
+    { name = "safetensors", specifier = ">=0.3.0" },
    { name = "scikit-learn", specifier = ">=1.6.1" },
    { name = "seaborn", specifier = ">=0.13.2" },
    { name = "sentence-transformers", specifier = ">=4.1.0" },
    { name = "setuptools", specifier = ">=80.8.0" },
+    { name = "skops", specifier = ">=0.11.0" },
    { name = "smart-open", extras = ["s3"], specifier = ">=7.1.0" },
    { name = "statsmodels", specifier = ">=0.14.4" },
-    { name = "
+    { name = "tokenizers", specifier = ">=0.20" },
    { name = "torch", specifier = ">=2.7.0" },
+    { name = "tqdm", specifier = ">=4.65.0" },
+    { name = "transformers", specifier = "<=4.52.1" },
    { name = "typer", specifier = ">=0.16.0" },
]

@@ -1187,38 +1201,6 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
]

-[[package]]
-name = "model2vec"
-version = "0.5.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "jinja2" },
-    { name = "joblib" },
-    { name = "numpy" },
-    { name = "rich" },
-    { name = "safetensors" },
-    { name = "setuptools" },
-    { name = "tokenizers" },
-    { name = "tqdm" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/93/18/c546916657e47e52b6e25b231803903bcf4e7ef2497fe41e9869236d7dee/model2vec-0.5.0.tar.gz", hash = "sha256:0771fd99d5c58fac631a2faa233759a8cec7a3be6e9aeeeeeca2d5e7048d1c7b", size = 2665840 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/66/ab/5263bc4605e9960fece76b710c01fef33859dc6ae72832d5987db75eed63/model2vec-0.5.0-py3-none-any.whl", hash = "sha256:12f14a18556975c037961a836a702388876bfec1ff76176f056884d219735271", size = 44578 },
-]
-
-[package.optional-dependencies]
-distill = [
-    { name = "scikit-learn" },
-    { name = "torch" },
-    { name = "transformers" },
-]
-train = [
-    { name = "lightning" },
-    { name = "scikit-learn" },
-    { name = "skops" },
-    { name = "torch" },
-]
-
[[package]]
name = "more-itertools"
version = "10.7.0"

@@ -2492,22 +2474,6 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/e6/b6/072a8e053ae600dcc2ac0da81a23548e3b523301a442a6ca900e92ac35be/tokenizers-0.21.1-cp39-abi3-win_amd64.whl", hash = "sha256:0f0dcbcc9f6e13e675a66d7a5f2f225a736745ce484c1a4e07476a89ccdad382", size = 2435481 },
]

-[[package]]
-name = "tokenlearn"
-version = "0.2.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "datasets" },
-    { name = "model2vec", extra = ["distill"] },
-    { name = "more-itertools" },
-    { name = "sentence-transformers" },
-    { name = "torch" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/58/b6/f9587ea271a9a7464cd25025b65f471d49bbceb48cc90742a89ac085edfd/tokenlearn-0.2.0.tar.gz", hash = "sha256:7a8faa0f51a510d185a40bef197a88116464adb8ce85ffd12c1d6905369c2375", size = 149042 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/40/3d/1c2b2e80ffd929bb8e7930d6a48e3b4252676cdc6c0c38f13a6f0f374b9c/tokenlearn-0.2.0-py3-none-any.whl", hash = "sha256:7a05e2800420eb2914c30e7377adeb14822c63585a0b9ed018bc82735dae1f29", size = 11970 },
-]
-
[[package]]
name = "torch"
version = "2.7.0"

@@ -2580,7 +2546,7 @@ wheels = [

[[package]]
name = "transformers"
-version = "4.52.
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "filelock" },

@@ -2594,9 +2560,9 @@ dependencies = [
    { name = "tokenizers" },
    { name = "tqdm" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/
wheels = [
-    { url = "https://files.pythonhosted.org/packages/
]

[[package]]
{ url = "https://files.pythonhosted.org/packages/e6/b6/072a8e053ae600dcc2ac0da81a23548e3b523301a442a6ca900e92ac35be/tokenizers-0.21.1-cp39-abi3-win_amd64.whl", hash = "sha256:0f0dcbcc9f6e13e675a66d7a5f2f225a736745ce484c1a4e07476a89ccdad382", size = 2435481 },
|
2475 |
]
|
2476 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2477 |
[[package]]
|
2478 |
name = "torch"
|
2479 |
version = "2.7.0"
|
|
|
2546 |
|
2547 |
[[package]]
|
2548 |
name = "transformers"
|
2549 |
+
version = "4.52.1"
|
2550 |
source = { registry = "https://pypi.org/simple" }
|
2551 |
dependencies = [
|
2552 |
{ name = "filelock" },
|
|
|
2560 |
{ name = "tokenizers" },
|
2561 |
{ name = "tqdm" },
|
2562 |
]
|
2563 |
+
sdist = { url = "https://files.pythonhosted.org/packages/4a/de/f3f3a0649dc522aeff55a5739e06e132c875c53701307a2ddd7ce7528ec5/transformers-4.52.1.tar.gz", hash = "sha256:c380d583ed9c7ebe3e30ca5e55ec1249db39eb9ee277f8e74dab1abc6a03c938", size = 8944009 }
|
2564 |
wheels = [
|
2565 |
+
{ url = "https://files.pythonhosted.org/packages/b8/1e/2b00e5021c3545d4a0ae32f3d332ae29e62a6259092f1468976e7b9d4adb/transformers-4.52.1-py3-none-any.whl", hash = "sha256:604b2bb357c480dc5883b7944e8562c967f6b06f63dfb6a1c4665d13d067148f", size = 10459023 },
|
2566 |
]
|
2567 |
|
2568 |
[[package]]
|