tasal9 committed
Commit 7714053 · 1 Parent(s): d0bd165

Add dataset format checking script, training configuration, and dataset files for ZamAI-Mistral-7B-Pashto

check_dataset_format.py ADDED
@@ -0,0 +1,136 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ """
+ Script to check if a dataset is properly formatted for AutoTrain
+ """
+
+ import argparse
+ from datasets import load_dataset
+ from collections import Counter
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Check dataset format for AutoTrain compatibility")
+     parser.add_argument(
+         "--dataset_path",
+         type=str,
+         required=True,
+         help="Path to dataset (local or HF Hub)"
+     )
+     parser.add_argument(
+         "--split",
+         type=str,
+         default="train",
+         help="Split to check (train/test/validation)"
+     )
+     return parser.parse_args()
+
+ def check_dataset_format(dataset):
+     """Check if dataset is properly formatted"""
+     print("\n=== Dataset Format Check ===")
+
+     # Check columns
+     columns = dataset.column_names
+     print(f"Columns found: {columns}")
+
+     # Check for expected columns
+     text_format = "text" in columns
+     instruction_format = "instruction" in columns and "response" in columns
+
+     if text_format:
+         print("✅ Dataset is in TEXT format (single 'text' column)")
+     elif instruction_format:
+         print("✅ Dataset is in INSTRUCTION-RESPONSE format")
+     else:
+         print("❌ Dataset format not recognized. Expected either:")
+         print("   - 'text' column for text generation")
+         print("   - 'instruction' and 'response' columns for instruction tuning")
+         return False
+
+     # Check dataset size
+     print(f"\nDataset size: {len(dataset)} examples")
+
+     # Sample some examples
+     print("\n=== Sample Examples ===")
+     for i in range(min(3, len(dataset))):
+         print(f"\nExample {i+1}:")
+         example = dataset[i]
+
+         if text_format:
+             text = example['text']
+             print(f"Text length: {len(text)} characters")
+             print(f"Preview: {text[:200]}...")
+         elif instruction_format:
+             instruction = example['instruction']
+             response = example['response']
+             print(f"Instruction: {instruction[:100]}...")
+             print(f"Response: {response[:100]}...")
+
+     # Check for empty values
+     print("\n=== Data Quality Check ===")
+     empty_count = 0
+     short_count = 0
+
+     for example in dataset:
+         if text_format:
+             if not example['text'] or len(example['text'].strip()) == 0:
+                 empty_count += 1
+             elif len(example['text']) < 10:
+                 short_count += 1
+         elif instruction_format:
+             if not example['instruction'] or not example['response']:
+                 empty_count += 1
+             elif len(example['instruction']) < 5 or len(example['response']) < 5:
+                 short_count += 1
+
+     print(f"Empty examples: {empty_count}")
+     print(f"Very short examples: {short_count}")
+
+     if empty_count > 0:
+         print("⚠️ Warning: Found empty examples. Consider removing them.")
+     if short_count > len(dataset) * 0.1:
+         print("⚠️ Warning: Many short examples found. This might affect training quality.")
+
+     # Language detection (simple check for Pashto characters)
+     print("\n=== Language Check ===")
+     pashto_count = 0
+     sample_size = min(100, len(dataset))
+
+     for i in range(sample_size):
+         example = dataset[i]
+         text = example.get('text', '') or example.get('instruction', '') + example.get('response', '')
+         # Check for Pashto/Arabic script characters
+         if any('\u0600' <= c <= '\u06FF' for c in text):
+             pashto_count += 1
+
+     pashto_percentage = (pashto_count / sample_size) * 100
+     print(f"Examples with Pashto/Arabic script: {pashto_percentage:.1f}%")
+
+     if pashto_percentage < 50:
+         print("⚠️ Warning: Less than 50% of samples contain Pashto script.")
+
+     return True
+
+ def main():
+     args = parse_args()
+
+     print(f"Loading dataset from: {args.dataset_path}")
+
+     try:
+         # Load dataset
+         dataset = load_dataset(args.dataset_path, split=args.split)
+
+         # Check format
+         is_valid = check_dataset_format(dataset)
+
+         if is_valid:
+             print("\n✅ Dataset is ready for AutoTrain!")
+         else:
+             print("\n❌ Dataset needs formatting adjustments.")
+
+     except Exception as e:
+         print(f"Error loading dataset: {e}")
+         print("\nMake sure the dataset path is correct and accessible.")
+
+ if __name__ == "__main__":
+     main()
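
Usage note: the checker accepts any datasets.Dataset that has either a 'text' column or 'instruction'/'response' columns. A minimal sketch of validating a small in-memory sample before uploading anything, assuming the script above is importable from the working directory as check_dataset_format:

    from datasets import Dataset
    from check_dataset_format import check_dataset_format  # hypothetical local import

    # Two rows in the instruction-response layout the script recognizes
    sample = Dataset.from_list([
        {"instruction": "ستاسو نوم څه دی؟", "response": "زما نوم ZamAI دی."},
        {"instruction": "سلام", "response": "وعلیکم سلام"},
    ])
    check_dataset_format(sample)  # prints the format, quality, and language report
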
config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "model_type": "mistral",
+   "architectures": ["MistralForCausalLM"],
+   "hidden_size": 4096,
+   "intermediate_size": 14336,
+   "num_hidden_layers": 32,
+   "num_attention_heads": 32,
+   "num_key_value_heads": 8,
+   "hidden_act": "silu",
+   "max_position_embeddings": 32768,
+   "initializer_range": 0.02,
+   "rms_norm_eps": 1e-05,
+   "use_cache": true,
+   "pad_token_id": null,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "tie_word_embeddings": false,
+   "rope_theta": 10000.0,
+   "sliding_window": 4096,
+   "attention_dropout": 0.0,
+   "torch_dtype": "float16",
+   "transformers_version": "4.34.0"
+ }
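
These values mirror the published Mistral-7B architecture (grouped-query attention with 32 query heads sharing 8 key/value heads, 32768 maximum positions, 4096-token sliding window). A quick sketch of loading this file as a transformers config object without downloading any weights, assuming transformers >= 4.34 is installed:

    from transformers import MistralConfig

    cfg = MistralConfig.from_json_file("config.json")
    print(cfg.num_attention_heads, cfg.num_key_value_heads)  # 32 8
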
prepare_for_autotrain.py ADDED
@@ -0,0 +1,148 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ """
+ Script to prepare Pashto dataset for AutoTrain fine-tuning
+ Converts data to the format expected by AutoTrain and uploads to HF Hub
+ """
+
+ import argparse
+ import json
+ import os
+ from datasets import Dataset, DatasetDict
+ from huggingface_hub import HfApi, create_repo
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Prepare dataset for AutoTrain")
+     parser.add_argument(
+         "--input_file",
+         type=str,
+         default="zamai_pashto_dataset.json",
+         help="Input JSON file with the dataset"
+     )
+     parser.add_argument(
+         "--format_type",
+         type=str,
+         choices=["text", "instruction-response"],
+         default="instruction-response",
+         help="Format type for the dataset"
+     )
+     parser.add_argument(
+         "--train_split",
+         type=float,
+         default=0.9,
+         help="Proportion of data for training (rest goes to test)"
+     )
+     parser.add_argument(
+         "--push_to_hub",
+         action="store_true",
+         help="Push dataset to Hugging Face Hub"
+     )
+     parser.add_argument(
+         "--hub_dataset_name",
+         type=str,
+         default="ZamAI_Pashto_Training",
+         help="Name for the dataset on HF Hub"
+     )
+     parser.add_argument(
+         "--hf_token",
+         type=str,
+         help="Hugging Face API token"
+     )
+     return parser.parse_args()
+
+ def format_for_text_generation(data, format_type):
+     """Format data for text generation training"""
+     formatted_data = []
+
+     for item in data:
+         if format_type == "text":
+             # Simple text format - concatenate input and output
+             if 'input' in item and 'output' in item:
+                 text = f"{item['input']}\n\n{item['output']}"
+                 formatted_data.append({"text": text})
+
+         elif format_type == "instruction-response":
+             # Instruction-response format
+             if 'input' in item and 'output' in item:
+                 formatted_data.append({
+                     "instruction": item['input'],
+                     "response": item['output']
+                 })
+
+     return formatted_data
+
+ def main():
+     args = parse_args()
+
+     # Load the dataset
+     print(f"Loading dataset from {args.input_file}")
+     with open(args.input_file, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+
+     print(f"Loaded {len(data)} examples")
+
+     # Format the data
+     print(f"Formatting data as {args.format_type}")
+     formatted_data = format_for_text_generation(data, args.format_type)
+
+     # Split into train/test
+     split_idx = int(len(formatted_data) * args.train_split)
+     train_data = formatted_data[:split_idx]
+     test_data = formatted_data[split_idx:]
+
+     print(f"Train examples: {len(train_data)}")
+     print(f"Test examples: {len(test_data)}")
+
+     # Create datasets
+     train_dataset = Dataset.from_list(train_data)
+     test_dataset = Dataset.from_list(test_data)
+
+     # Create DatasetDict
+     dataset_dict = DatasetDict({
+         "train": train_dataset,
+         "test": test_dataset
+     })
+
+     # Save locally
+     local_path = f"{args.hub_dataset_name}_local"
+     dataset_dict.save_to_disk(local_path)
+     print(f"Dataset saved locally to {local_path}")
+
+     # Push to Hub if requested
+     if args.push_to_hub:
+         if not args.hf_token:
+             print("Error: HF token required for pushing to Hub")
+             return
+
+         print(f"Pushing dataset to Hub as tasal9/{args.hub_dataset_name}")
+
+         # Create repo if needed
+         api = HfApi(token=args.hf_token)
+         repo_id = f"tasal9/{args.hub_dataset_name}"
+
+         try:
+             create_repo(
+                 repo_id=repo_id,
+                 token=args.hf_token,
+                 repo_type="dataset",
+                 exist_ok=True
+             )
+         except Exception as e:
+             print(f"Note: {e}")
+
+         # Push dataset
+         dataset_dict.push_to_hub(
+             repo_id,
+             token=args.hf_token,
+             commit_message="Upload Pashto training dataset for AutoTrain"
+         )
+
+         print(f"Dataset uploaded to https://huggingface.co/datasets/{repo_id}")
+
+     # Print sample
+     print("\nSample from training data:")
+     print(train_dataset[0])
+
+ if __name__ == "__main__":
+     main()
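
Input format note: the converter expects a JSON list of objects with 'input' and 'output' keys, as in zamai_pashto_dataset.json below. A small sketch of both conversion modes, assuming the script above is importable as prepare_for_autotrain:

    from prepare_for_autotrain import format_for_text_generation  # hypothetical local import

    raw = [{"input": "په پښتو کې سلام څنګه کوو؟", "output": "السلام علیکم"}]
    print(format_for_text_generation(raw, "instruction-response"))
    # -> [{"instruction": <input>, "response": <output>}]
    print(format_for_text_generation(raw, "text"))
    # -> [{"text": "<input>\n\n<output>"}]
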
requirements.txt CHANGED
@@ -43,3 +43,7 @@ ipykernel>=6.25.0
  black>=23.9.0
  isort>=5.12.0
  pylint>=2.17.0
+
+ # Additional training dependencies
+ wandb>=0.15.0
+ pyyaml>=6.0
tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
+ {
+   "add_bos_token": true,
+   "add_eos_token": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [],
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "</s>",
+   "legacy": true,
+   "model_max_length": 32768,
+   "pad_token": null,
+   "sp_model_kwargs": {},
+   "spaces_between_special_tokens": false,
+   "tokenizer_class": "LlamaTokenizer",
+   "unk_token": "<unk>",
+   "use_default_system_prompt": true,
+   "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+ }
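
The chat_template reproduces Mistral's [INST] prompt wrapping. A sketch of rendering one user turn with it, assuming network access to the base tokenizer (the commit does not include tokenizer.model, so the local directory alone is not loadable):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
    prompt = tok.apply_chat_template(
        [{"role": "user", "content": "ستاسو نوم څه دی؟"}],
        tokenize=False,
    )
    print(prompt)  # expected: <s>[INST] ستاسو نوم څه دی؟ [/INST]
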
train_model.py ADDED
@@ -0,0 +1,273 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ """
+ Comprehensive training script for ZamAI-Mistral-7B-Pashto
+ Supports both local training and AutoTrain
+ """
+
+ import argparse
+ import os
+ import yaml
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling,
+     BitsAndBytesConfig
+ )
+ from datasets import load_dataset
+ from peft import LoraConfig, get_peft_model, TaskType
+ from peft.utils import prepare_model_for_kbit_training
+ try:
+     import wandb
+ except ImportError:
+     wandb = None
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Train ZamAI-Mistral-7B-Pashto")
+     parser.add_argument(
+         "--config",
+         type=str,
+         default="training_config.yaml",
+         help="Path to training configuration file"
+     )
+     parser.add_argument(
+         "--local",
+         action="store_true",
+         help="Run training locally instead of using AutoTrain"
+     )
+     parser.add_argument(
+         "--dataset_path",
+         type=str,
+         help="Override dataset path from config"
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         help="Override output directory from config"
+     )
+     parser.add_argument(
+         "--hf_token",
+         type=str,
+         help="Hugging Face API token"
+     )
+     return parser.parse_args()
+
+ def load_config(config_path):
+     """Load training configuration from YAML file"""
+     with open(config_path, 'r') as f:
+         config = yaml.safe_load(f)
+     return config
+
+ def prepare_model_and_tokenizer(config):
+     """Prepare model and tokenizer with quantization and LoRA"""
+
+     # Quantization config
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=config['model'].get('load_in_4bit', True),
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16
+     )
+
+     # Load model
+     model = AutoModelForCausalLM.from_pretrained(
+         config['model']['name'],
+         quantization_config=bnb_config,
+         device_map="auto",
+         trust_remote_code=True,
+         use_flash_attention_2=config['model'].get('use_flash_attention_2', True)
+     )
+
+     # Load tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(
+         config['model']['name'],
+         trust_remote_code=True
+     )
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "right"
+
+     # Prepare model for k-bit training
+     model = prepare_model_for_kbit_training(model)
+
+     # LoRA configuration
+     lora_config = LoraConfig(
+         r=config['lora']['r'],
+         lora_alpha=config['lora']['lora_alpha'],
+         target_modules=config['lora']['target_modules'],
+         lora_dropout=config['lora']['lora_dropout'],
+         bias=config['lora']['bias'],
+         task_type=config['lora']['task_type']
+     )
+
+     # Apply LoRA
+     model = get_peft_model(model, lora_config)
+     model.print_trainable_parameters()
+
+     return model, tokenizer
+
+ def prepare_dataset(config, tokenizer):
+     """Load and prepare dataset for training"""
+
+     # Load train and validation splits separately
+     train_dataset = load_dataset(
+         config['dataset']['name'],
+         split=config['dataset']['train_split']
+     )
+     validation_dataset = load_dataset(
+         config['dataset']['name'],
+         split=config['dataset']['validation_split']
+     )
+
+     # Tokenization function
+     def tokenize_function(examples):
+         # Handle different dataset formats
+         if config['dataset']['text_column'] in examples:
+             texts = examples[config['dataset']['text_column']]
+         else:
+             # For instruction-response format
+             texts = [
+                 f"[INST] {inst} [/INST] {resp}"
+                 for inst, resp in zip(examples['instruction'], examples['response'])
+             ]
+
+         return tokenizer(
+             texts,
+             truncation=True,
+             padding="max_length",
+             max_length=config['dataset']['max_seq_length']
+         )
+
+     # Tokenize datasets
+     tokenized_train = train_dataset.map(
+         tokenize_function,
+         batched=True,
+         remove_columns=train_dataset.column_names
+     )
+     tokenized_validation = validation_dataset.map(
+         tokenize_function,
+         batched=True,
+         remove_columns=validation_dataset.column_names
+     )
+
+     # Return datasets directly
+     return tokenized_train, tokenized_validation
+
+ def main():
+     args = parse_args()
+
+     # Load configuration
+     config = load_config(args.config)
+
+     # Override config with command line arguments
+     if args.dataset_path:
+         config['dataset']['name'] = args.dataset_path
+     if args.output_dir:
+         config['output']['output_dir'] = args.output_dir
+     if args.hf_token:
+         config['hub']['hub_token'] = args.hf_token
+
+     # Set up wandb if configured (skip when wandb is not installed)
+     if wandb is not None and 'wandb' in config['advanced'].get('report_to', []):
+         wandb.init(
+             project="zamai-mistral-pashto",
+             name=f"mistral-7b-pashto-{config['lora']['r']}r",
+             config=config
+         )
+
+     print("Loading model and tokenizer...")
+     model, tokenizer = prepare_model_and_tokenizer(config)
+
+     print("Preparing dataset...")
+     train_dataset, validation_dataset = prepare_dataset(config, tokenizer)
+
+     # Data collator
+     data_collator = DataCollatorForLanguageModeling(
+         tokenizer=tokenizer,
+         mlm=False
+     )
+
+     # Training arguments
+     training_args = TrainingArguments(
+         output_dir=config['output']['output_dir'],
+         num_train_epochs=config['training']['num_train_epochs'],
+         per_device_train_batch_size=config['training']['per_device_train_batch_size'],
+         per_device_eval_batch_size=config['training']['per_device_eval_batch_size'],
+         gradient_accumulation_steps=config['training']['gradient_accumulation_steps'],
+         gradient_checkpointing=config['training']['gradient_checkpointing'],
+         learning_rate=config['training']['learning_rate'],
+         lr_scheduler_type=config['training']['lr_scheduler_type'],
+         warmup_ratio=config['training']['warmup_ratio'],
+         weight_decay=config['training']['weight_decay'],
+         max_grad_norm=config['training']['max_grad_norm'],
+         optim=config['optimization']['optim'],
+         fp16=config['optimization']['fp16'],
+         bf16=config['optimization']['bf16'],
+         tf32=config['optimization']['tf32'],
+         logging_dir=config['output']['logging_dir'],
+         logging_steps=config['output']['logging_steps'],
+         save_steps=config['output']['save_steps'],
+         save_total_limit=config['output']['save_total_limit'],
+         evaluation_strategy=config['output']['evaluation_strategy'],
+         eval_steps=config['output']['eval_steps'],
+         save_strategy=config['output']['save_strategy'],
+         load_best_model_at_end=config['output']['load_best_model_at_end'],
+         metric_for_best_model=config['output']['metric_for_best_model'],
+         greater_is_better=config['output']['greater_is_better'],
+         push_to_hub=config['hub']['push_to_hub'],
+         hub_model_id=config['hub']['hub_model_id'],
+         hub_strategy=config['hub']['hub_strategy'],
+         hub_token=config['hub']['hub_token'],
+         seed=config['advanced']['seed'],
+         data_seed=config['advanced']['data_seed'],
+         report_to=config['advanced']['report_to'],
+         remove_unused_columns=config['advanced']['remove_unused_columns']
+     )
+
+     # Initialize trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=validation_dataset,
+         data_collator=data_collator
+     )
+
+     # Train
+     print("Starting training...")
+     trainer.train()
+
+     # Save the final model
+     print("Saving final model...")
+     trainer.save_model()
+
+     # Push to hub if configured
+     if config['hub']['push_to_hub']:
+         print("Pushing to Hugging Face Hub...")
+         trainer.push_to_hub()
+
+     print("Training completed successfully!")
+
+     # Test the model
+     print("\nTesting the model with a sample prompt...")
+     test_prompt = "ستاسو نوم څه دی؟"
+     inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=config['inference']['max_new_tokens'],
+             temperature=config['inference']['temperature'],
+             top_p=config['inference']['top_p'],
+             do_sample=config['inference']['do_sample']
+         )
+
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     print(f"Prompt: {test_prompt}")
+     print(f"Response: {response}")
+
+ if __name__ == "__main__":
+     main()
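
For reference, tokenize_function above renders instruction-response rows into Mistral's [INST] prompt layout before tokenizing. A minimal sketch of that rendering for one batch, with column names assumed to match the dataset produced by prepare_for_autotrain.py:

    batch = {
        "instruction": ["ستاسو نوم څه دی؟"],
        "response": ["زما نوم ZamAI دی."],
    }
    texts = [
        f"[INST] {inst} [/INST] {resp}"
        for inst, resp in zip(batch["instruction"], batch["response"])
    ]
    print(texts[0])  # [INST] ستاسو نوم څه دی؟ [/INST] زما نوم ZamAI دی.
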
training_config.yaml ADDED
@@ -0,0 +1,93 @@
+ # ZamAI-Mistral-7B-Pashto Training Configuration
+
+ # Model Configuration
+ model:
+   name: "mistralai/Mistral-7B-Instruct-v0.1"
+   type: "causal-lm"
+   use_flash_attention_2: true
+   load_in_8bit: false
+   load_in_4bit: true  # Use 4-bit quantization for efficiency
+
+ # Dataset Configuration
+ dataset:
+   name: "tasal9/ZamAI_Pashto_Training"
+   text_column: "text"  # or "instruction" for instruction-response format
+   train_split: "train"
+   validation_split: "test"
+   max_seq_length: 2048
+
+ # Training Parameters
+ training:
+   num_train_epochs: 3
+   per_device_train_batch_size: 4
+   per_device_eval_batch_size: 4
+   gradient_accumulation_steps: 4
+   gradient_checkpointing: true
+   learning_rate: 2e-4
+   lr_scheduler_type: "cosine"
+   warmup_ratio: 0.03
+   weight_decay: 0.001
+   max_grad_norm: 0.3
+
+ # LoRA Configuration
+ lora:
+   r: 16  # LoRA attention dimension
+   lora_alpha: 32  # LoRA scaling parameter
+   lora_dropout: 0.05  # LoRA dropout
+   bias: "none"
+   task_type: "CAUSAL_LM"
+   target_modules:
+     - "q_proj"
+     - "k_proj"
+     - "v_proj"
+     - "o_proj"
+     - "gate_proj"
+     - "up_proj"
+     - "down_proj"
+     - "lm_head"
+
+ # Optimization
+ optimization:
+   optim: "paged_adamw_32bit"
+   fp16: false
+   bf16: true  # Use bfloat16 for better stability
+   tf32: true
+
+ # Logging and Saving
+ output:
+   output_dir: "./results"
+   logging_dir: "./logs"
+   logging_steps: 10
+   save_steps: 500
+   save_total_limit: 3
+   evaluation_strategy: "steps"
+   eval_steps: 500
+   save_strategy: "steps"
+   load_best_model_at_end: true
+   metric_for_best_model: "eval_loss"
+   greater_is_better: false
+
+ # Hub Configuration
+ hub:
+   push_to_hub: true
+   hub_model_id: "tasal9/ZamAI-Mistral-7B-Pashto"
+   hub_strategy: "every_save"
+   hub_token: "YOUR_HF_TOKEN"  # Replace with actual token
+
+ # Advanced Settings
+ advanced:
+   seed: 42
+   data_seed: 42
+   dataloader_num_workers: 4
+   remove_unused_columns: false
+   label_names: ["labels"]
+   report_to: ["tensorboard", "wandb"]  # Optional: remove wandb if not using
+
+ # Inference Configuration (for testing)
+ inference:
+   max_new_tokens: 256
+   temperature: 0.7
+   top_p: 0.9
+   top_k: 50
+   do_sample: true
+   repetition_penalty: 1.1
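
With the values above, each optimizer step sees per_device_train_batch_size * gradient_accumulation_steps examples per device. A quick sketch of reading that straight from the file, assuming training_config.yaml sits in the working directory:

    import yaml

    with open("training_config.yaml") as f:
        cfg = yaml.safe_load(f)

    effective = cfg["training"]["per_device_train_batch_size"] * cfg["training"]["gradient_accumulation_steps"]
    print(effective)  # 16 examples per optimizer step per device
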
zamai_pashto_dataset.json ADDED
@@ -0,0 +1,42 @@
+ [
+   {
+     "input": "ستاسو نوم څه دی؟",
+     "output": "زما نوم ZamAI دی. زه د مصنوعی ذکاوت پر بنسټ جوړ شوی یو ژبنی مودل یم چې په پښتو ژبه کې مرسته کولی شم."
+   },
+   {
+     "input": "د افغانستان د کلتور په اړه ماته معلومات راکړئ.",
+     "output": "د افغانستان کلتور ډیر غني او بډای دی. افغانستان د بیلابیلو قومونو، ژبو او دودونو هیواد دی. اصلي قومونه پښتانه، تاجک، هزاره، ازبک او نور دي. دا هیواد ډیرې دودیزې موسیقۍ، خواړه، لباسونه او هنرونه لري."
+   },
+   {
+     "input": "په پښتو کې سلام څنګه کوو؟",
+     "output": "په پښتو کې د سلام کولو څو لارې دي:\n- السلام علیکم (رسمي)\n- سلام ورور/خور (غیر رسمي)\n- ستړی مه شې (د ورځې په هر وخت)\n- په خیر راغلاست (د هرکلي لپاره)"
+   },
+   {
+     "input": "د کابل په اړه معلومات راکړئ.",
+     "output": "کابل د افغانستان پلازمېنه او تر ټولو لوی ښار دی. دا ښار د هندوکش غرونو په منځ کې موقعیت لري او نږدې ۴.۵ میلیونه وګړي لري. کابل یو تاریخي ښار دی چې له ۳۵۰۰ کلونو څخه زیات تاریخ لري."
+   },
+   {
+     "input": "پښتو ژبه څومره خلک خبرې کوي؟",
+     "output": "پښتو د نړۍ په کچه نږدې ۵۰-۶۰ میلیونه خلک خبرې کوي. دا د افغانستان دوه رسمي ژبو څخه یوه ده او په پاکستان کې هم په پراخه کچه ویل کیږي. پښتو د هند-اروپایي ژبو له کورنۍ څخه ده."
+   },
+   {
+     "input": "د پښتنو دودونه څه دي؟",
+     "output": "پښتانه ډیر دودونه لري چې د پښتونولي په نوم پیژندل کیږي:\n- میلمه پالنه: میلمانه ته د عزت ورکول\n- ننواتې: د بخښنې غوښتل او ورکول\n- بدل: د ښې یا بدې بدله اخیستل\n- غیرت: د کورنۍ او قوم ساتنه\n- جرګه: د ستونزو د حل لپاره ټولنیز غونډې"
+   },
+   {
+     "input": "څنګه کولی شم پښتو زده کړم؟",
+     "output": "د پښتو زده کړې لپاره:\n۱. د پښتو الفبا زده کړئ (۴۴ توري)\n۲. بنسټیزې جملې او کلمې زده کړئ\n۳. پښتو موسیقي واورئ او فلمونه وګورئ\n۴. له پښتو ویونکو سره خبرې وکړئ\n۵. پښتو کتابونه او شعرونه ولولئ\n۶. آنلاین کورسونه یا اپلیکیشنونه وکاروئ"
+   },
+   {
+     "input": "د افغانستان ملي سرود په کومه ژبه دی؟",
+     "output": "د افغانستان ملي سرود په پښتو ژبه دی. دا سرود په ۲۰۰۶ کال کې رسمي شو او د افغانستان د بیلابیلو قومونو نومونه پکې یاد شوي دي. سرود د 'دا وطن افغانستان دی' په جمله پیل کیږي."
+   },
+   {
+     "input": "پښتو شاعري څنګه ده؟",
+     "output": "پښتو شاعري ډیره بډایه ده:\n- لنډۍ: دوه کرښیزه لنډ شعرونه\n- غزل: د مینې او عرفان شعرونه\n- رباعي: څلور کرښیز شعرونه\n- قصیده: اوږد شعرونه\nمشهور شاعران: خوشحال خان خټک، رحمان بابا، حمزه شینواری"
+   },
+   {
+     "input": "د پښتو ژبې تاریخ څه دی؟",
+     "output": "پښتو یوه لرغونې ژبه ده چې له ۲۵۰۰ کلونو څخه زیات تاریخ لري. د پښتو لومړنی لیکلی اثر د امیر کروړ تذکره ده (۸ پیړۍ). پښتو د ۱۹۳۶ کال راهیسې د افغانستان رسمي ژبه ده. دا ژبه عربي الفبا کاروي او ۴۴ توري لري."
+   }
+ ]