add train.py
Browse files
train.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import random
|
3 |
+
import argparse
|
4 |
+
from dataclasses import dataclass, field
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import wandb
|
9 |
+
from tqdm import tqdm
|
10 |
+
from PIL import Image
|
11 |
+
from datasets import load_dataset
|
12 |
+
from transformers import (
|
13 |
+
Qwen2_5_VLForConditionalGeneration,
|
14 |
+
AutoProcessor,
|
15 |
+
BitsAndBytesConfig,
|
16 |
+
)
|
17 |
+
from qwen_vl_utils import process_vision_info
|
18 |
+
from peft import LoraConfig, get_peft_model
|
19 |
+
from trl import SFTConfig, SFTTrainer
|
20 |
+
|
21 |
+
|
22 |
+
def extract_question(raw_text: str) -> str:
    """Pull the user turn out of a Llama-style chat transcript.

    Searches for the text between the ``user`` header and the ``<|eot_id|>``
    marker; when no such span is present, the whole input (stripped) is
    returned unchanged.
    """
    user_turn = re.search(
        r"<\|start_header_id\|>user<\|end_header_id\|>\s*(.*?)\s*<\|eot_id\|>",
        raw_text,
        re.DOTALL,
    )
    if user_turn is None:
        return raw_text.strip()
    return user_turn.group(1).strip()
26 |
+
|
27 |
+
def format_data_spacethinker(sample):
    """Convert one raw dataset row into a chat-style message list.

    Returns ``[system, user]`` messages (plus an assistant message when the
    sample has a non-empty ``"output"``), suitable for
    ``processor.apply_chat_template``. The user turn carries the question
    extracted from ``sample["input"]`` and at most the first image.
    """
    # NOTE(review): "U+1F914" below looks like a scrape-garbled thinking-face
    # emoji — confirm against the upstream prompt before changing it.
    system_message = {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": (
                    "You are VL-Thinking U+1F914, a helpful assistant with excellent reasoning ability.\n"
                    # Fix: this sentence was fused with the next one
                    # ("...solve it.You should first think...") — a newline
                    # separator was missing.
                    "A user asks you a question, and you should try to solve it.\n"
                    "You should first think about the reasoning process in the mind and then provides the user with the answer.\n"
                    "The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."
                )
            }
        ]
    }
    formatted = [system_message]

    user_msg = {"role": "user", "content": []}
    question = extract_question(sample.get("input", ""))
    if question:
        user_msg["content"].append({"type": "text", "text": question})
    images = sample.get("images") or []
    if images:
        # Only the first image is attached, even if the sample carries several.
        user_msg["content"].append({"type": "image", "image": images[0]})
    formatted.append(user_msg)

    # Training rows have an "output" (assistant answer); inference rows may not.
    if sample.get("output"):
        formatted.append({
            "role": "assistant",
            "content": [{"type": "text", "text": sample["output"]}]
        })
    return formatted
59 |
+
|
60 |
+
|
61 |
+
def collate_fn(examples, processor):
    """Collate formatted chat samples into a padded model batch with labels.

    ``examples`` is a list of message lists (see ``format_data_spacethinker``).
    Padding tokens and image tokens are set to -100 in ``labels`` so the loss
    ignores them.
    """
    rendered = [processor.apply_chat_template(ex, tokenize=False) for ex in examples]
    visuals = [process_vision_info(ex)[0] for ex in examples]
    batch = processor(text=rendered, images=visuals, return_tensors="pt", padding=True)
    # Keep tensors on the CPU; the trainer handles device placement.
    batch = {key: tensor.cpu() for key, tensor in batch.items()}

    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Hard-coded ids are presumably the Qwen2.5-VL vision tokens — confirm
    # against the tokenizer; other processors fall back to their declared
    # image token.
    if hasattr(processor, "image_processor"):
        vision_token_ids = [151652, 151653, 151655]
    else:
        vision_token_ids = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]
    for token_id in vision_token_ids:
        labels[labels == token_id] = -100

    batch["labels"] = labels
    return batch
81 |
+
|
82 |
+
|
83 |
+
@dataclass
class TrainingConfig:
    """Hyper-parameters for LoRA SFT of the spatial-reasoning VL model."""

    # Hub id of the base vision-language model to fine-tune.
    model_id: str = "UCSC-VLAA/VLAA-Thinker-Qwen2.5VL-3B"
    # LoRA rank and scaling factor (alpha = 2 * r here).
    lora_r: int = 128
    lora_alpha: int = 256
    lora_dropout: float = 0.05
    # Attention projections to adapt; overridable as a comma list via CLI.
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj", "o_proj"])
    num_train_epochs: int = 3
    # Per-device sizes; effective batch = train_batch_size * gradient_accumulation_steps.
    train_batch_size: int = 1
    eval_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.03
    # Checkpoint directory and Weights & Biases identifiers.
    output_dir: str = "spaceom"
    wandb_project: str = "spaceom"
    wandb_run_name: str = "spaceom"
99 |
+
|
100 |
+
|
101 |
+
def parse_args() -> TrainingConfig:
    """Build a TrainingConfig from CLI flags, defaulting to the dataclass values."""
    defaults = TrainingConfig()
    parser = argparse.ArgumentParser(description="Train a VL Spacethinker model with LoRA")
    parser.add_argument("--model_id", default=defaults.model_id)
    parser.add_argument("--lora_r", type=int, default=defaults.lora_r)
    parser.add_argument("--lora_alpha", type=int, default=defaults.lora_alpha)
    parser.add_argument("--lora_dropout", type=float, default=defaults.lora_dropout)
    parser.add_argument(
        "--target_modules",
        default=",".join(defaults.target_modules),
        help="Comma-separated list of target modules for LoRA",
    )
    parser.add_argument("--num_train_epochs", type=int, default=defaults.num_train_epochs)
    parser.add_argument("--train_batch_size", type=int, default=defaults.train_batch_size)
    parser.add_argument("--eval_batch_size", type=int, default=defaults.eval_batch_size)
    parser.add_argument(
        "--gradient_accumulation_steps", type=int, default=defaults.gradient_accumulation_steps
    )
    parser.add_argument("--learning_rate", type=float, default=defaults.learning_rate)
    parser.add_argument("--warmup_ratio", type=float, default=defaults.warmup_ratio)
    parser.add_argument("--output_dir", default=defaults.output_dir)
    parser.add_argument("--wandb_project", default=defaults.wandb_project)
    parser.add_argument("--wandb_run_name", default=defaults.wandb_run_name)

    parsed = vars(parser.parse_args())
    # Every argparse destination matches a dataclass field one-to-one; only
    # the comma-separated module list needs to be split back into a list.
    parsed["target_modules"] = parsed["target_modules"].split(",")
    return TrainingConfig(**parsed)
142 |
+
|
143 |
+
|
144 |
+
def prepare_datasets(cfg: TrainingConfig):
    """Load, chat-format, and shuffle the train/eval splits of all corpora.

    Returns ``(train_ds, eval_ds)``: flat lists of chat-message samples (see
    ``format_data_spacethinker``). ``cfg`` is currently unused but kept for
    interface stability.

    The three near-identical load/format stanzas of the original are folded
    into a data-driven loop; load order, concatenation order, and console
    output are unchanged.
    """
    corpora = [
        ("SpaceThinker", "remyxai/SpaceThinker"),
        ("SpaceOm", "remyxai/SpaceOm"),
        ("Robo2VLM", "remyxai/Robo2VLM-Reasoning"),
    ]

    raw_splits = []
    for label, repo in corpora:
        print(f"Loading dataset: {label}")
        raw_splits.append(
            (load_dataset(repo, split="train"), load_dataset(repo, split="test"))
        )

    print("Formatting train samples…")
    train_ds = [
        format_data_spacethinker(s)
        for raw_train, _ in raw_splits
        for s in tqdm(raw_train, desc="Train")
    ]
    print("Formatting eval samples…")
    eval_ds = [
        format_data_spacethinker(s)
        for _, raw_eval in raw_splits
        for s in tqdm(raw_eval, desc="Eval")
    ]

    # NOTE: uses the unseeded global RNG, so sample order is not reproducible
    # across runs — seed `random` before calling if that matters.
    random.shuffle(train_ds)
    random.shuffle(eval_ds)

    return train_ds, eval_ds
172 |
+
|
173 |
+
|
174 |
+
def prepare_model_and_optimizer(cfg: TrainingConfig):
    """Load the 4-bit quantized base model and wrap it with LoRA adapters.

    Despite the name, no optimizer is created here (the trainer builds its
    own). Returns ``(peft_model, processor, peft_cfg)``.
    """
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        cfg.model_id,
        device_map="auto",  # accelerate places the quantized weights
        torch_dtype=torch.bfloat16,
        quantization_config=bnb,
    )
    processor = AutoProcessor.from_pretrained(cfg.model_id)

    peft_cfg = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        target_modules=cfg.target_modules,
        task_type="CAUSAL_LM",
    )
    # Fix: the original called `.to(device)` on the PEFT-wrapped model.
    # `.to` is not supported for 4-bit bitsandbytes-quantized models, and
    # `device_map="auto"` has already placed the weights.
    peft_model = get_peft_model(model, peft_cfg)
    peft_model.print_trainable_parameters()
    return peft_model, processor, peft_cfg
201 |
+
|
202 |
+
|
203 |
+
def main():
    """Entry point: parse config, build data/model, and run LoRA SFT."""
    cfg = parse_args()
    train_ds, eval_ds = prepare_datasets(cfg)
    model, processor, peft_cfg = prepare_model_and_optimizer(cfg)

    sft_args = SFTConfig(
        output_dir=cfg.output_dir,
        num_train_epochs=cfg.num_train_epochs,
        per_device_train_batch_size=cfg.train_batch_size,
        per_device_eval_batch_size=cfg.eval_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        gradient_checkpointing=True,
        optim="adamw_torch_fused",
        learning_rate=cfg.learning_rate,
        # NOTE(review): "constant" has no warmup phase, so warmup_ratio below
        # is presumably ignored — consider "constant_with_warmup"; confirm
        # against the pinned transformers version.
        lr_scheduler_type="constant",
        logging_steps=10,
        eval_steps=10,
        eval_strategy="steps",
        save_strategy="steps",
        # save_steps (20) is a multiple of eval_steps (10), as required when
        # load_best_model_at_end=True.
        save_steps=20,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        load_best_model_at_end=True,
        bf16=True,
        tf32=True,
        max_grad_norm=0.3,
        warmup_ratio=cfg.warmup_ratio,
        # Non-reentrant checkpointing is required for PEFT + gradient
        # checkpointing to propagate gradients into the adapters.
        gradient_checkpointing_kwargs={"use_reentrant": False},
        # push_to_hub=True needs a valid HF token and writable repo.
        push_to_hub=True,
        report_to="wandb",
        # We collate raw chat-message lists ourselves, so skip TRL's built-in
        # dataset preparation/tokenization.
        dataset_kwargs={"skip_prepare_dataset": True},
    )
    # Keep the raw sample structure intact for the custom collator.
    sft_args.remove_unused_columns = False

    wandb.init(
        project=cfg.wandb_project,
        name=cfg.wandb_run_name,
        config=sft_args,
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        # Bind the processor into the collator (SFTTrainer passes only the
        # example list).
        data_collator=lambda ex: collate_fn(ex, processor),
        peft_config=peft_cfg,
        # NOTE(review): the `tokenizer=` kwarg is deprecated in newer trl
        # releases (renamed `processing_class`) — confirm the pinned version.
        tokenizer=processor.tokenizer,
    )

    trainer.train()
    trainer.save_model(cfg.output_dir)
255 |
+
|
256 |
+
|
257 |
+
if __name__ == "__main__":
    # Guarded entry point: allows importing the formatting/collate helpers
    # from this module without starting a training run.
    main()
|