Add dataset format checking script, training configuration, and dataset files for ZamAI-Mistral-7B-Pashto
- check_dataset_format.py +136 -0
- config.json +23 -0
- prepare_for_autotrain.py +148 -0
- requirements.txt +4 -0
- tokenizer_config.json +43 -0
- train_model.py +273 -0
- training_config.yaml +93 -0
- zamai_pashto_dataset.json +42 -0
check_dataset_format.py
ADDED
@@ -0,0 +1,136 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Script to check if a dataset is properly formatted for AutoTrain
"""

import argparse

from datasets import load_dataset

def parse_args():
    parser = argparse.ArgumentParser(description="Check dataset format for AutoTrain compatibility")
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="Path to dataset (local or HF Hub)"
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        help="Split to check (train/test/validation)"
    )
    return parser.parse_args()

def check_dataset_format(dataset):
    """Check if dataset is properly formatted"""
    print("\n=== Dataset Format Check ===")

    # Check columns
    columns = dataset.column_names
    print(f"Columns found: {columns}")

    # Check for expected columns
    text_format = "text" in columns
    instruction_format = "instruction" in columns and "response" in columns

    if text_format:
        print("✅ Dataset is in TEXT format (single 'text' column)")
    elif instruction_format:
        print("✅ Dataset is in INSTRUCTION-RESPONSE format")
    else:
        print("❌ Dataset format not recognized. Expected either:")
        print("  - 'text' column for text generation")
        print("  - 'instruction' and 'response' columns for instruction tuning")
        return False

    # Check dataset size
    print(f"\nDataset size: {len(dataset)} examples")

    # Sample some examples
    print("\n=== Sample Examples ===")
    for i in range(min(3, len(dataset))):
        print(f"\nExample {i+1}:")
        example = dataset[i]

        if text_format:
            text = example['text']
            print(f"Text length: {len(text)} characters")
            print(f"Preview: {text[:200]}...")
        elif instruction_format:
            instruction = example['instruction']
            response = example['response']
            print(f"Instruction: {instruction[:100]}...")
            print(f"Response: {response[:100]}...")

    # Check for empty values
    print("\n=== Data Quality Check ===")
    empty_count = 0
    short_count = 0

    for example in dataset:
        if text_format:
            if not example['text'] or len(example['text'].strip()) == 0:
                empty_count += 1
            elif len(example['text']) < 10:
                short_count += 1
        elif instruction_format:
            if not example['instruction'] or not example['response']:
                empty_count += 1
            elif len(example['instruction']) < 5 or len(example['response']) < 5:
                short_count += 1

    print(f"Empty examples: {empty_count}")
    print(f"Very short examples: {short_count}")

    if empty_count > 0:
        print("⚠️ Warning: Found empty examples. Consider removing them.")
    if short_count > len(dataset) * 0.1:
        print("⚠️ Warning: Many short examples found. This might affect training quality.")

    # Language detection (simple check for Pashto characters)
    print("\n=== Language Check ===")
    pashto_count = 0
    sample_size = min(100, len(dataset))

    for i in range(sample_size):
        example = dataset[i]
        # Fall back to instruction + response when there is no 'text' column
        text = example.get('text', '') or (example.get('instruction', '') + example.get('response', ''))
        # Check for Pashto/Arabic script characters
        if any('\u0600' <= c <= '\u06FF' for c in text):
            pashto_count += 1

    # Guard against an empty dataset to avoid division by zero
    pashto_percentage = (pashto_count / sample_size) * 100 if sample_size else 0.0
    print(f"Examples with Pashto/Arabic script: {pashto_percentage:.1f}%")

    if pashto_percentage < 50:
        print("⚠️ Warning: Less than 50% of samples contain Pashto script.")

    return True

def main():
    args = parse_args()

    print(f"Loading dataset from: {args.dataset_path}")

    try:
        # Load dataset
        dataset = load_dataset(args.dataset_path, split=args.split)

        # Check format
        is_valid = check_dataset_format(dataset)

        if is_valid:
            print("\n✅ Dataset is ready for AutoTrain!")
        else:
            print("\n❌ Dataset needs formatting adjustments.")

    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("\nMake sure the dataset path is correct and accessible.")

if __name__ == "__main__":
    main()
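As a quick smoke test, the checker can be exercised against a small in-memory dataset without touching the Hub. A minimal sketch, assuming check_dataset_format.py is importable from the working directory (the sample record is illustrative):

from datasets import Dataset

from check_dataset_format import check_dataset_format

sample = Dataset.from_list([
    {"instruction": "ستاسو نوم څه دی؟", "response": "زما نوم ZamAI دی."},
])
assert check_dataset_format(sample)  # prints the format/quality/language report

From the command line, `python check_dataset_format.py --dataset_path tasal9/ZamAI_Pashto_Training` runs the same checks against the Hub copy.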
config.json
ADDED
@@ -0,0 +1,23 @@
{
  "model_type": "mistral",
  "architectures": ["MistralForCausalLM"],
  "hidden_size": 4096,
  "intermediate_size": 14336,
  "num_hidden_layers": 32,
  "num_attention_heads": 32,
  "num_key_value_heads": 8,
  "hidden_act": "silu",
  "max_position_embeddings": 32768,
  "initializer_range": 0.02,
  "rms_norm_eps": 1e-05,
  "use_cache": true,
  "pad_token_id": null,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "tie_word_embeddings": false,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "attention_dropout": 0.0,
  "torch_dtype": "float16",
  "transformers_version": "4.34.0"
}
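The head counts above (32 attention heads over 8 key/value heads) describe grouped-query attention: four query heads share each key/value head. A sketch of reading the values back, assuming transformers is installed and this config.json sits in the current directory:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained(".")  # directory containing config.json
# 32 query heads share 8 key/value heads, i.e. 4 query heads per KV head
print(cfg.num_attention_heads, cfg.num_key_value_heads)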
prepare_for_autotrain.py
ADDED
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Script to prepare a Pashto dataset for AutoTrain fine-tuning.
Converts data to the format expected by AutoTrain and uploads it to the HF Hub.
"""

import argparse
import json

from datasets import Dataset, DatasetDict
from huggingface_hub import create_repo

def parse_args():
    parser = argparse.ArgumentParser(description="Prepare dataset for AutoTrain")
    parser.add_argument(
        "--input_file",
        type=str,
        default="zamai_pashto_dataset.json",
        help="Input JSON file with the dataset"
    )
    parser.add_argument(
        "--format_type",
        type=str,
        choices=["text", "instruction-response"],
        default="instruction-response",
        help="Format type for the dataset"
    )
    parser.add_argument(
        "--train_split",
        type=float,
        default=0.9,
        help="Proportion of data for training (rest goes to test)"
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Push dataset to Hugging Face Hub"
    )
    parser.add_argument(
        "--hub_dataset_name",
        type=str,
        default="ZamAI_Pashto_Training",
        help="Name for the dataset on HF Hub"
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        help="Hugging Face API token"
    )
    return parser.parse_args()

def format_for_text_generation(data, format_type):
    """Format data for text generation training"""
    formatted_data = []

    for item in data:
        if format_type == "text":
            # Simple text format - concatenate input and output
            if 'input' in item and 'output' in item:
                text = f"{item['input']}\n\n{item['output']}"
                formatted_data.append({"text": text})

        elif format_type == "instruction-response":
            # Instruction-response format
            if 'input' in item and 'output' in item:
                formatted_data.append({
                    "instruction": item['input'],
                    "response": item['output']
                })

    return formatted_data

def main():
    args = parse_args()

    # Load the dataset
    print(f"Loading dataset from {args.input_file}")
    with open(args.input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"Loaded {len(data)} examples")

    # Format the data
    print(f"Formatting data as {args.format_type}")
    formatted_data = format_for_text_generation(data, args.format_type)

    # Split into train/test (order-preserving; shuffle upstream if needed)
    split_idx = int(len(formatted_data) * args.train_split)
    train_data = formatted_data[:split_idx]
    test_data = formatted_data[split_idx:]

    print(f"Train examples: {len(train_data)}")
    print(f"Test examples: {len(test_data)}")

    # Create datasets
    train_dataset = Dataset.from_list(train_data)
    test_dataset = Dataset.from_list(test_data)

    # Create DatasetDict
    dataset_dict = DatasetDict({
        "train": train_dataset,
        "test": test_dataset
    })

    # Save locally
    local_path = f"{args.hub_dataset_name}_local"
    dataset_dict.save_to_disk(local_path)
    print(f"Dataset saved locally to {local_path}")

    # Push to Hub if requested
    if args.push_to_hub:
        if not args.hf_token:
            print("Error: HF token required for pushing to Hub")
            return

        repo_id = f"tasal9/{args.hub_dataset_name}"
        print(f"Pushing dataset to Hub as {repo_id}")

        # Create repo if needed
        try:
            create_repo(
                repo_id=repo_id,
                token=args.hf_token,
                repo_type="dataset",
                exist_ok=True
            )
        except Exception as e:
            print(f"Note: {e}")

        # Push dataset
        dataset_dict.push_to_hub(
            repo_id,
            token=args.hf_token,
            commit_message="Upload Pashto training dataset for AutoTrain"
        )

        print(f"Dataset uploaded to https://huggingface.co/datasets/{repo_id}")

    # Print sample
    print("\nSample from training data:")
    print(train_dataset[0])

if __name__ == "__main__":
    main()
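For reference, this is what format_for_text_generation produces for one record shaped like the entries in zamai_pashto_dataset.json. A sketch, assuming the script is importable:

from prepare_for_autotrain import format_for_text_generation

raw = [{"input": "ستاسو نوم څه دی؟", "output": "زما نوم ZamAI دی."}]
print(format_for_text_generation(raw, "instruction-response"))
# -> [{'instruction': 'ستاسو نوم څه دی؟', 'response': 'زما نوم ZamAI دی.'}]
print(format_for_text_generation(raw, "text"))
# -> [{'text': 'ستاسو نوم څه دی؟\n\nزما نوم ZamAI دی.'}]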
requirements.txt
CHANGED
@@ -43,3 +43,7 @@ ipykernel>=6.25.0
 black>=23.9.0
 isort>=5.12.0
 pylint>=2.17.0
+
+# Additional training dependencies
+wandb>=0.15.0
+pyyaml>=6.0
tokenizer_config.json
ADDED
@@ -0,0 +1,43 @@
{
  "add_bos_token": true,
  "add_eos_token": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "additional_special_tokens": [],
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "legacy": true,
  "model_max_length": 32768,
  "pad_token": null,
  "sp_model_kwargs": {},
  "spaces_between_special_tokens": false,
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": "<unk>",
  "use_default_system_prompt": true,
  "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
}
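The chat_template above renders conversations in the Mistral [INST] ... [/INST] convention. A sketch of applying it, assuming the remaining tokenizer files (e.g. tokenizer.model) sit alongside this config and a transformers version with apply_chat_template (>= 4.34):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".")  # directory with the tokenizer files
messages = [{"role": "user", "content": "ستاسو نوم څه دی؟"}]
print(tok.apply_chat_template(messages, tokenize=False))
# -> <s>[INST] ستاسو نوم څه دی؟ [/INST]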
train_model.py
ADDED
@@ -0,0 +1,273 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Comprehensive training script for ZamAI-Mistral-7B-Pashto.
Supports both local training and AutoTrain.
"""

import argparse

import torch
import yaml
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)

try:
    import wandb
except ImportError:
    wandb = None

def parse_args():
    parser = argparse.ArgumentParser(description="Train ZamAI-Mistral-7B-Pashto")
    parser.add_argument(
        "--config",
        type=str,
        default="training_config.yaml",
        help="Path to training configuration file"
    )
    parser.add_argument(
        "--local",
        action="store_true",
        help="Run training locally instead of using AutoTrain"
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        help="Override dataset path from config"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Override output directory from config"
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        help="Hugging Face API token"
    )
    return parser.parse_args()

def load_config(config_path):
    """Load training configuration from YAML file"""
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config

def prepare_model_and_tokenizer(config):
    """Prepare model and tokenizer with quantization and LoRA"""

    # Quantization config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=config['model'].get('load_in_4bit', True),
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        config['model']['name'],
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        use_flash_attention_2=config['model'].get('use_flash_attention_2', True)
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        config['model']['name'],
        trust_remote_code=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Prepare model for k-bit training
    model = prepare_model_for_kbit_training(model)

    # LoRA configuration
    lora_config = LoraConfig(
        r=config['lora']['r'],
        lora_alpha=config['lora']['lora_alpha'],
        target_modules=config['lora']['target_modules'],
        lora_dropout=config['lora']['lora_dropout'],
        bias=config['lora']['bias'],
        task_type=config['lora']['task_type']
    )

    # Apply LoRA
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    return model, tokenizer

def prepare_dataset(config, tokenizer):
    """Load and prepare dataset for training"""

    # Load train and validation splits separately
    train_dataset = load_dataset(
        config['dataset']['name'],
        split=config['dataset']['train_split']
    )
    validation_dataset = load_dataset(
        config['dataset']['name'],
        split=config['dataset']['validation_split']
    )

    # Tokenization function
    def tokenize_function(examples):
        # Handle different dataset formats
        if config['dataset']['text_column'] in examples:
            texts = examples[config['dataset']['text_column']]
        else:
            # For instruction-response format
            texts = [
                f"[INST] {inst} [/INST] {resp}"
                for inst, resp in zip(examples['instruction'], examples['response'])
            ]

        return tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            max_length=config['dataset']['max_seq_length']
        )

    # Tokenize datasets
    tokenized_train = train_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=train_dataset.column_names
    )
    tokenized_validation = validation_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=validation_dataset.column_names
    )

    return tokenized_train, tokenized_validation

def main():
    args = parse_args()

    # Load configuration
    config = load_config(args.config)

    # Override config with command line arguments
    if args.dataset_path:
        config['dataset']['name'] = args.dataset_path
    if args.output_dir:
        config['output']['output_dir'] = args.output_dir
    if args.hf_token:
        config['hub']['hub_token'] = args.hf_token

    # Set up wandb if configured
    if 'wandb' in config['advanced'].get('report_to', []):
        if wandb is None:
            raise ImportError("report_to includes 'wandb' but wandb is not installed")
        wandb.init(
            project="zamai-mistral-pashto",
            name=f"mistral-7b-pashto-{config['lora']['r']}r",
            config=config
        )

    print("Loading model and tokenizer...")
    model, tokenizer = prepare_model_and_tokenizer(config)

    print("Preparing dataset...")
    train_dataset, validation_dataset = prepare_dataset(config, tokenizer)

    # Data collator (mlm=False -> causal language modeling labels)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir=config['output']['output_dir'],
        num_train_epochs=config['training']['num_train_epochs'],
        per_device_train_batch_size=config['training']['per_device_train_batch_size'],
        per_device_eval_batch_size=config['training']['per_device_eval_batch_size'],
        gradient_accumulation_steps=config['training']['gradient_accumulation_steps'],
        gradient_checkpointing=config['training']['gradient_checkpointing'],
        learning_rate=config['training']['learning_rate'],
        lr_scheduler_type=config['training']['lr_scheduler_type'],
        warmup_ratio=config['training']['warmup_ratio'],
        weight_decay=config['training']['weight_decay'],
        max_grad_norm=config['training']['max_grad_norm'],
        optim=config['optimization']['optim'],
        fp16=config['optimization']['fp16'],
        bf16=config['optimization']['bf16'],
        tf32=config['optimization']['tf32'],
        logging_dir=config['output']['logging_dir'],
        logging_steps=config['output']['logging_steps'],
        save_steps=config['output']['save_steps'],
        save_total_limit=config['output']['save_total_limit'],
        evaluation_strategy=config['output']['evaluation_strategy'],
        eval_steps=config['output']['eval_steps'],
        save_strategy=config['output']['save_strategy'],
        load_best_model_at_end=config['output']['load_best_model_at_end'],
        metric_for_best_model=config['output']['metric_for_best_model'],
        greater_is_better=config['output']['greater_is_better'],
        push_to_hub=config['hub']['push_to_hub'],
        hub_model_id=config['hub']['hub_model_id'],
        hub_strategy=config['hub']['hub_strategy'],
        hub_token=config['hub']['hub_token'],
        seed=config['advanced']['seed'],
        data_seed=config['advanced']['data_seed'],
        report_to=config['advanced']['report_to'],
        remove_unused_columns=config['advanced']['remove_unused_columns']
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        data_collator=data_collator
    )

    # Train
    print("Starting training...")
    trainer.train()

    # Save the final model
    print("Saving final model...")
    trainer.save_model()

    # Push to hub if configured
    if config['hub']['push_to_hub']:
        print("Pushing to Hugging Face Hub...")
        trainer.push_to_hub()

    print("Training completed successfully!")

    # Test the model
    print("\nTesting the model with a sample prompt...")
    test_prompt = "ستاسو نوم څه دی؟"
    inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=config['inference']['max_new_tokens'],
            temperature=config['inference']['temperature'],
            top_p=config['inference']['top_p'],
            do_sample=config['inference']['do_sample']
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Prompt: {test_prompt}")
    print(f"Response: {response}")

if __name__ == "__main__":
    main()
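The pieces can also be driven programmatically for a dry run. A sketch, assuming a CUDA GPU, the dependencies from requirements.txt, and training_config.yaml in the working directory (the overrides here are illustrative, not part of the script):

from train_model import load_config, prepare_model_and_tokenizer, prepare_dataset

config = load_config("training_config.yaml")
config['hub']['push_to_hub'] = False   # skip Hub auth for a local dry run
config['advanced']['report_to'] = []   # skip wandb/tensorboard reporting
model, tokenizer = prepare_model_and_tokenizer(config)
train_ds, val_ds = prepare_dataset(config, tokenizer)
print(len(train_ds), len(val_ds))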
training_config.yaml
ADDED
@@ -0,0 +1,93 @@
# ZamAI-Mistral-7B-Pashto Training Configuration

# Model Configuration
model:
  name: "mistralai/Mistral-7B-Instruct-v0.1"
  type: "causal-lm"
  use_flash_attention_2: true
  load_in_8bit: false
  load_in_4bit: true  # Use 4-bit quantization for efficiency

# Dataset Configuration
dataset:
  name: "tasal9/ZamAI_Pashto_Training"
  text_column: "text"  # or "instruction" for instruction-response format
  train_split: "train"
  validation_split: "test"
  max_seq_length: 2048

# Training Parameters
training:
  num_train_epochs: 3
  per_device_train_batch_size: 4
  per_device_eval_batch_size: 4
  gradient_accumulation_steps: 4
  gradient_checkpointing: true
  learning_rate: 2e-4
  lr_scheduler_type: "cosine"
  warmup_ratio: 0.03
  weight_decay: 0.001
  max_grad_norm: 0.3

# LoRA Configuration
lora:
  r: 16  # LoRA attention dimension
  lora_alpha: 32  # LoRA scaling parameter
  lora_dropout: 0.05  # LoRA dropout
  bias: "none"
  task_type: "CAUSAL_LM"
  target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "gate_proj"
    - "up_proj"
    - "down_proj"
    - "lm_head"

# Optimization
optimization:
  optim: "paged_adamw_32bit"
  fp16: false
  bf16: true  # Use bfloat16 for better stability
  tf32: true

# Logging and Saving
output:
  output_dir: "./results"
  logging_dir: "./logs"
  logging_steps: 10
  save_steps: 500
  save_total_limit: 3
  evaluation_strategy: "steps"
  eval_steps: 500
  save_strategy: "steps"
  load_best_model_at_end: true
  metric_for_best_model: "eval_loss"
  greater_is_better: false

# Hub Configuration
hub:
  push_to_hub: true
  hub_model_id: "tasal9/ZamAI-Mistral-7B-Pashto"
  hub_strategy: "every_save"
  hub_token: "YOUR_HF_TOKEN"  # Replace with actual token

# Advanced Settings
advanced:
  seed: 42
  data_seed: 42
  dataloader_num_workers: 4
  remove_unused_columns: false
  label_names: ["labels"]
  report_to: ["tensorboard", "wandb"]  # Optional: remove wandb if not using

# Inference Configuration (for testing)
inference:
  max_new_tokens: 256
  temperature: 0.7
  top_p: 0.9
  top_k: 50
  do_sample: true
  repetition_penalty: 1.1
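One worked number from the values above: the effective batch size seen by the optimizer is the per-device batch size times the gradient-accumulation steps times the number of GPUs. A sketch, assuming a single GPU:

# Effective batch size implied by training_config.yaml
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
num_gpus = 1  # assumption; scale accordingly for multi-GPU
print(per_device_train_batch_size * gradient_accumulation_steps * num_gpus)  # 16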
zamai_pashto_dataset.json
ADDED
@@ -0,0 +1,42 @@
[
  {
    "input": "ستاسو نوم څه دی؟",
    "output": "زما نوم ZamAI دی. زه د مصنوعی ذکاوت پر بنسټ جوړ شوی یو ژبنی مودل یم چې په پښتو ژبه کې مرسته کولی شم."
  },
  {
    "input": "د افغانستان د کلتور په اړه ماته معلومات راکړئ.",
    "output": "د افغانستان کلتور ډیر غني او بډای دی. افغانستان د بیلابیلو قومونو، ژبو او دودونو هیواد دی. اصلي قومونه پښتانه، تاجک، هزاره، ازبک او نور دي. دا هیواد ډیرې دودیزې موسیقۍ، خواړه، لباسونه او هنرونه لري."
  },
  {
    "input": "په پښتو کې سلام څنګه کوو؟",
    "output": "په پښتو کې د سلام کولو څو لارې دي:\n- السلام علیکم (رسمي)\n- سلام ورور/خور (غیر رسمي)\n- ستړی مه شې (د ورځې په هر وخت)\n- په خیر راغلاست (د هرکلي لپاره)"
  },
  {
    "input": "د کابل په اړه معلومات راکړئ.",
    "output": "کابل د افغانستان پلازمېنه او تر ټولو لوی ښار دی. دا ښار د هندوکش غرونو په منځ کې موقعیت لري او نږدې ۴.۵ میلیونه وګړي لري. کابل یو تاریخي ښار دی چې له ۳۵۰۰ کلونو څخه زیات تاریخ لري."
  },
  {
    "input": "پښتو ژبه څومره خلک خبرې کوي؟",
    "output": "پښتو د نړۍ په کچه نږدې ۵۰-۶۰ میلیونه خلک خبرې کوي. دا د افغانستان دوه رسمي ژبو څخه یوه ده او په پاکستان کې هم په پراخه کچه ویل کیږي. پښتو د هند-اروپایي ژبو له کورنۍ څخه ده."
  },
  {
    "input": "د پښتنو دودونه څه دي؟",
    "output": "پښتانه ډیر دودونه لري چې د پښتونولي په نوم پیژندل کیږي:\n- میلمه پالنه: میلمانه ته د عزت ورکول\n- ننواتې: د بخښنې غوښتل او ورکول\n- بدل: د ښې یا بدې بدله اخیستل\n- غیرت: د کورنۍ او قوم ساتنه\n- جرګه: د ستونزو د حل لپاره ټولنیز غونډې"
  },
  {
    "input": "څنګه کولی شم پښتو زده کړم؟",
    "output": "د پښتو زده کړې لپاره:\n۱. د پښتو الفبا زده کړئ (۴۴ توري)\n۲. بنسټیزې جملې او کلمې زده کړئ\n۳. پښتو موسیقي واورئ او فلمونه وګورئ\n۴. له پښتو ویونکو سره خبرې وکړئ\n۵. پښتو کتابونه او شعرونه ولولئ\n۶. آنلاین کورسونه یا اپلیکیشنونه وکاروئ"
  },
  {
    "input": "د افغانستان ملي سرود په کومه ژبه دی؟",
    "output": "د افغانستان ملي سرود په پښتو ژبه دی. دا سرود په ۲۰۰۶ کال کې رسمي شو او د افغانستان د بیلابیلو قومونو نومونه پکې یاد شوي دي. سرود د 'دا وطن افغانستان دی' په جمله پیل کیږي."
  },
  {
    "input": "پښتو شاعري څنګه ده؟",
    "output": "پښتو شاعري ډیره بډایه ده:\n- لنډۍ: دوه کرښیزه لنډ شعرونه\n- غزل: د مینې او عرفان شعرونه\n- رباعي: څلور کرښیز شعرونه\n- قصیده: اوږد شعرونه\nمشهور شاعران: خوشحال خان خټک، رحمان بابا، حمزه شینواری"
  },
  {
    "input": "د پښتو ژبې تاریخ څه دی؟",
    "output": "پښتو یوه لرغونې ژبه ده چې له ۲۵۰۰ کلونو څخه زیات تاریخ لري. د پښتو لومړنی لیکلی اثر د امیر کروړ تذکره ده (۸ پیړۍ). پښتو د ۱۹۳۶ کال راهیسې د افغانستان رسمي ژبه ده. دا ژبه عربي الفبا کاروي او ۴۴ توري لري."
  }
]
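The seed file holds ten input/output pairs. A sketch of loading it and feeding it through the pipeline above, assuming all the committed files are in the working directory:

import json

from datasets import Dataset

from check_dataset_format import check_dataset_format
from prepare_for_autotrain import format_for_text_generation

with open("zamai_pashto_dataset.json", encoding="utf-8") as f:
    records = json.load(f)
print(len(records))  # 10

formatted = format_for_text_generation(records, "instruction-response")
check_dataset_format(Dataset.from_list(formatted))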