salma-remyx committed
Commit 76d0523 · verified · 1 Parent(s): 5f6674d

add train.py

Files changed (1):
  1. train.py +258 -0
train.py ADDED
@@ -0,0 +1,258 @@
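"""LoRA finetuning script for a Qwen2.5-VL "SpaceThinker"-style model.

Loads the remyxai/SpaceThinker, remyxai/SpaceOm, and remyxai/Robo2VLM-Reasoning
datasets, formats them into chat-style samples, and runs TRL's SFTTrainer over a
4-bit (QLoRA-style) UCSC-VLAA/VLAA-Thinker-Qwen2.5VL-3B base model with LoRA
adapters, logging to Weights & Biases.
"""
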
import re
import random
import argparse
from dataclasses import dataclass, field
from typing import List

import torch
import wandb
from tqdm import tqdm
from PIL import Image
from datasets import load_dataset
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
)
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, get_peft_model
from trl import SFTConfig, SFTTrainer


def extract_question(raw_text: str) -> str:
    """Pull the user question out of a Llama-3-style chat transcript stored in `input`."""
    pattern = r"<\|start_header_id\|>user<\|end_header_id\|>\s*(.*?)\s*<\|eot_id\|>"
    m = re.search(pattern, raw_text, re.DOTALL)
    return m.group(1).strip() if m else raw_text.strip()


def format_data_spacethinker(sample):
    """Build a system/user/assistant chat message list from a raw dataset row."""
    system_message = {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": (
                    "You are VL-Thinking 🤔, a helpful assistant with excellent reasoning ability.\n"
                    "A user asks you a question, and you should try to solve it. "
                    "You should first think about the reasoning process in the mind and then provides the user with the answer.\n"
                    "The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."
                )
            }
        ]
    }
    formatted = [system_message]

    user_msg = {"role": "user", "content": []}
    question = extract_question(sample.get("input", ""))
    if question:
        user_msg["content"].append({"type": "text", "text": question})
    images = sample.get("images") or []
    if images:
        user_msg["content"].append({"type": "image", "image": images[0]})
    formatted.append(user_msg)

    if sample.get("output"):
        formatted.append({
            "role": "assistant",
            "content": [{"type": "text", "text": sample["output"]}]
        })
    return formatted


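# Each formatted sample handed to the collator is a chat-style message list,
# roughly of the form (the assistant turn is only present when the row has an
# `output` field):
#   [
#     {"role": "system",    "content": [{"type": "text", "text": "You are VL-Thinking ..."}]},
#     {"role": "user",      "content": [{"type": "text", "text": question},
#                                       {"type": "image", "image": first_pil_image}]},
#     {"role": "assistant", "content": [{"type": "text", "text": sample["output"]}]},
#   ]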
def collate_fn(examples, processor):
    # examples: list of formatted samples (each a list of message dicts)
    texts = [processor.apply_chat_template(sample, tokenize=False) for sample in examples]
    image_batches = [process_vision_info(sample)[0] for sample in examples]
    batch = processor(text=texts, images=image_batches, return_tensors="pt", padding=True)
    # Keep tensors on CPU; the Trainer moves batches to the right device.
    batch = {k: v.cpu() for k, v in batch.items()}

    # Mask padding so it does not contribute to the loss.
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Also mask the vision special tokens
    # (Qwen2.5-VL: <|vision_start|>, <|vision_end|>, <|image_pad|>).
    image_token_ids = (
        [151652, 151653, 151655]
        if hasattr(processor, "image_processor")
        else [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]
    )
    for tid in image_token_ids:
        labels[labels == tid] = -100

    batch["labels"] = labels
    return batch


@dataclass
class TrainingConfig:
    model_id: str = "UCSC-VLAA/VLAA-Thinker-Qwen2.5VL-3B"
    lora_r: int = 128
    lora_alpha: int = 256
    lora_dropout: float = 0.05
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj", "o_proj"])
    num_train_epochs: int = 3
    train_batch_size: int = 1
    eval_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.03
    output_dir: str = "spaceom"
    wandb_project: str = "spaceom"
    wandb_run_name: str = "spaceom"


def parse_args() -> TrainingConfig:
    default_cfg = TrainingConfig()
    parser = argparse.ArgumentParser(description="Train a VL Spacethinker model with LoRA")
    parser.add_argument("--model_id", default=default_cfg.model_id)
    parser.add_argument("--lora_r", type=int, default=default_cfg.lora_r)
    parser.add_argument("--lora_alpha", type=int, default=default_cfg.lora_alpha)
    parser.add_argument("--lora_dropout", type=float, default=default_cfg.lora_dropout)
    parser.add_argument(
        "--target_modules",
        default=",".join(default_cfg.target_modules),
        help="Comma-separated list of target modules for LoRA",
    )
    parser.add_argument("--num_train_epochs", type=int, default=default_cfg.num_train_epochs)
    parser.add_argument("--train_batch_size", type=int, default=default_cfg.train_batch_size)
    parser.add_argument("--eval_batch_size", type=int, default=default_cfg.eval_batch_size)
    parser.add_argument(
        "--gradient_accumulation_steps", type=int, default=default_cfg.gradient_accumulation_steps
    )
    parser.add_argument("--learning_rate", type=float, default=default_cfg.learning_rate)
    parser.add_argument("--warmup_ratio", type=float, default=default_cfg.warmup_ratio)
    parser.add_argument("--output_dir", default=default_cfg.output_dir)
    parser.add_argument("--wandb_project", default=default_cfg.wandb_project)
    parser.add_argument("--wandb_run_name", default=default_cfg.wandb_run_name)

    args = parser.parse_args()
    return TrainingConfig(
        model_id=args.model_id,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.target_modules.split(","),
        num_train_epochs=args.num_train_epochs,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        output_dir=args.output_dir,
        wandb_project=args.wandb_project,
        wandb_run_name=args.wandb_run_name,
    )


def prepare_datasets(cfg: TrainingConfig):
    print("Loading dataset: SpaceThinker")
    raw_train_spacethinker = load_dataset("remyxai/SpaceThinker", split="train")
    raw_eval_spacethinker = load_dataset("remyxai/SpaceThinker", split="test")

    print("Loading dataset: SpaceOm")
    raw_train_spaceom = load_dataset("remyxai/SpaceOm", split="train")
    raw_eval_spaceom = load_dataset("remyxai/SpaceOm", split="test")

    print("Loading dataset: Robo2VLM")
    raw_train_robo2vlm = load_dataset("remyxai/Robo2VLM-Reasoning", split="train")
    raw_eval_robo2vlm = load_dataset("remyxai/Robo2VLM-Reasoning", split="test")

    print("Formatting train samples…")
    train_ds_spacethinker = [format_data_spacethinker(s) for s in tqdm(raw_train_spacethinker, desc="Train")]
    train_ds_spaceom = [format_data_spacethinker(s) for s in tqdm(raw_train_spaceom, desc="Train")]
    train_ds_robo2vlm = [format_data_spacethinker(s) for s in tqdm(raw_train_robo2vlm, desc="Train")]
    print("Formatting eval samples…")
    eval_ds_spacethinker = [format_data_spacethinker(s) for s in tqdm(raw_eval_spacethinker, desc="Eval")]
    eval_ds_spaceom = [format_data_spacethinker(s) for s in tqdm(raw_eval_spaceom, desc="Eval")]
    eval_ds_robo2vlm = [format_data_spacethinker(s) for s in tqdm(raw_eval_robo2vlm, desc="Eval")]

    # Mix the three sources and shuffle so batches interleave datasets.
    train_ds = train_ds_spacethinker + train_ds_spaceom + train_ds_robo2vlm
    eval_ds = eval_ds_spacethinker + eval_ds_spaceom + eval_ds_robo2vlm
    random.shuffle(train_ds)
    random.shuffle(eval_ds)

    return train_ds, eval_ds


def prepare_model_and_optimizer(cfg: TrainingConfig):
    """Load the 4-bit quantized base model and wrap it with LoRA adapters."""
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        cfg.model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb,
    )
    processor = AutoProcessor.from_pretrained(cfg.model_id)

    peft_cfg = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        target_modules=cfg.target_modules,
        task_type="CAUSAL_LM",
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    peft_model = get_peft_model(model, peft_cfg).to(device)
    peft_model.print_trainable_parameters()
    return peft_model, processor, peft_cfg


def main():
    cfg = parse_args()
    train_ds, eval_ds = prepare_datasets(cfg)
    model, processor, peft_cfg = prepare_model_and_optimizer(cfg)

    sft_args = SFTConfig(
        output_dir=cfg.output_dir,
        num_train_epochs=cfg.num_train_epochs,
        per_device_train_batch_size=cfg.train_batch_size,
        per_device_eval_batch_size=cfg.eval_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        gradient_checkpointing=True,
        optim="adamw_torch_fused",
        learning_rate=cfg.learning_rate,
        lr_scheduler_type="constant",
        logging_steps=10,
        eval_steps=10,
        eval_strategy="steps",
        save_strategy="steps",
        save_steps=20,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        load_best_model_at_end=True,
        bf16=True,
        tf32=True,
        max_grad_norm=0.3,
        warmup_ratio=cfg.warmup_ratio,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        push_to_hub=True,
        report_to="wandb",
        dataset_kwargs={"skip_prepare_dataset": True},
    )
    # Keep the raw formatted samples intact for the custom collator.
    sft_args.remove_unused_columns = False

    wandb.init(
        project=cfg.wandb_project,
        name=cfg.wandb_run_name,
        config=sft_args,
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=lambda ex: collate_fn(ex, processor),
        peft_config=peft_cfg,
        tokenizer=processor.tokenizer,
    )

    trainer.train()
    trainer.save_model(cfg.output_dir)


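# Example invocation (a sketch; every flag is optional and falls back to the
# TrainingConfig defaults above). Note that push_to_hub=True assumes a
# logged-in Hugging Face account and report_to="wandb" a configured W&B account:
#   python train.py \
#       --model_id UCSC-VLAA/VLAA-Thinker-Qwen2.5VL-3B \
#       --target_modules q_proj,v_proj,o_proj \
#       --num_train_epochs 3 \
#       --output_dir spaceom --wandb_project spaceom --wandb_run_name spaceom
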
if __name__ == "__main__":
    main()