mjbuehler commited on
Commit
b182650
·
verified ·
1 Parent(s): 0b7012d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +94 -0
README.md CHANGED
@@ -137,6 +137,100 @@ A similar mechanism can be employed to generate 3D models:
137
 
138
  ![image/png](https://cdn-uploads.huggingface.co/production/uploads/623ce1c6b66fedf374859fe7/6ZsvCZ3x3TGvugly44MMI.png)
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  ## Citation
141
 
142
  Please cite as:
 
137
 
138
  ![image/png](https://cdn-uploads.huggingface.co/production/uploads/623ce1c6b66fedf374859fe7/6ZsvCZ3x3TGvugly44MMI.png)
139
 
140
+ ## Fine-tuning
141
+
142
+
143
+ Load base model
144
+
145
+ ```python
146
+ model_id = "microsoft/Phi-3-vision-128k-instruct"
147
+
148
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto")
149
+
150
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
151
+ ```
152
+
153
+ Define FT_repo_id to push on HF hub/save model:
154
+ ```
155
+ FT_repo_id='xxxxx/' #<repo_ID>
156
+ ```
157
+
158
+ ```
159
+ from datasets import load_dataset
160
+
161
+ train_dataset = load_dataset("lamm-mit/Cephalo-Wikipedia-Materials", split="train")
162
+ ```
163
+
164
+ ```python
165
+ import random
166
+
167
+ class MyDataCollator:
168
+ def __init__(self, processor):
169
+ self.processor = processor
170
+
171
+ def __call__(self, examples):
172
+ texts = []
173
+ images = []
174
+ for example in examples:
175
+ image = example["image"]
176
+ question = example["query"]
177
+ answer = example["answer"]
178
+ messages = [ {
179
+ "role": "user", "content": '<|image_1|>\n'+question},
180
+ {"role": "assistant", "content": f"{answer}"}, ]
181
+
182
+ text = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
183
+
184
+ images.append(image)
185
+
186
+ batch = processor(text=text, images=[image], return_tensors="pt", padding=True
187
+
188
+ labels = batch["input_ids"].clone()
189
+ labels[labels <0] = -100
190
+
191
+ batch["labels"] = labels
192
+
193
+ return batch
194
+
195
+ data_collator = MyDataCollator(processor)
196
+ ```
197
+ Then set up trainer, and train:
198
+ ```python
199
+ from transformers import TrainingArguments, Trainer
200
+
201
+ optim = "paged_adamw_8bit"
202
+
203
+ training_args = TrainingArguments(
204
+ num_train_epochs=2,
205
+ per_device_train_batch_size=1,
206
+ #per_device_eval_batch_size=4,
207
+ gradient_accumulation_steps=4,
208
+ warmup_steps=250,
209
+ learning_rate=1e-5,
210
+ weight_decay=0.01,
211
+ logging_steps=25,
212
+ output_dir="output_training",
213
+ optim=optim,
214
+ save_strategy="steps",
215
+ save_steps=1000,
216
+ save_total_limit=16,
217
+ #fp16=True,
218
+ bf16=True,
219
+ push_to_hub_model_id=FT_repo_id,
220
+ remove_unused_columns=False,
221
+ report_to="none",
222
+ )
223
+
224
+ trainer = Trainer(
225
+ model=model,
226
+ args=training_args,
227
+ data_collator=data_collator,
228
+ train_dataset=train_dataset,
229
+ )
230
+
231
+ trainer.train()
232
+ ```
233
+
234
  ## Citation
235
 
236
  Please cite as: