Update README.md
README.md CHANGED
@@ -20,13 +20,13 @@ base_model:
 
 This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
 
-- **Developed by:**
+- **Developed by:** Arash Nicoomanesh
 - **Funded by [optional]:** [More Information Needed]
 - **Shared by [optional]:** [More Information Needed]
 - **Model type:** [More Information Needed]
 - **Language(s) (NLP):** [More Information Needed]
 - **License:** [More Information Needed]
-- **Finetuned from model [optional]:**
+- **Finetuned from model [optional]:** google/gemma-2b-it
 
 ### Model Sources [optional]
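The card's "How to Get Started" section (quoted in the next hunk's context line, "Use the code below to get started with the model") is not shown in this diff. As a purely hypothetical illustration, loading a PEFT/LoRA fine-tune of the listed base model for inference might look like the sketch below; the adapter repo id and the prompt are placeholders, not values from the card.

```python
# Hypothetical getting-started sketch; the adapter id below is a placeholder, not the card's model id.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "google/gemma-2b-it"                   # base model listed in the card metadata
adapter_id = "your-username/your-gemma-adapter"  # placeholder for the fine-tuned LoRA adapter

tokenizer = AutoTokenizer.from_pretrained(base_id)
model = AutoModelForCausalLM.from_pretrained(base_id, device_map="auto")
model = PeftModel.from_pretrained(model, adapter_id)  # attach the fine-tuned adapter

prompt = "I have had a persistent cough for two weeks. What could be the cause?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```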
@@ -86,17 +86,65 @@ Use the code below to get started with the model.
 
 ### Training Procedure
 
+model = Gemma2ForCausalLM.from_pretrained(  # Changed here
+    base_model,
+    quantization_config=bnb_config,
+    device_map="auto",
+    attn_implementation=attn_implementation
+)
+tokenizer = GemmaTokenizerFast.from_pretrained(base_model, padding_side="right",
+                                               truncation_side="right", trust_remote_code=True)
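The loading snippet above references `base_model`, `bnb_config`, and `attn_implementation`, and the later cells use `load_dataset`, `TrainingArguments`, and `SFTTrainer`, none of which are defined in this diff. Below is a sketch of plausible definitions for a 4-bit QLoRA-style setup; every name and value in it is an assumption rather than something stated in the card.

```python
# Assumed imports and helper objects for the snippets in this card (not shown in the diff).
import re
import torch
from datasets import load_dataset
from transformers import (
    BitsAndBytesConfig,
    Gemma2ForCausalLM,
    GemmaTokenizerFast,
    TrainingArguments,
)
from trl import SFTTrainer

# Assumption: the Gemma 2 2B instruct checkpoint. The card metadata lists google/gemma-2b-it,
# while the Gemma2ForCausalLM class and the W&B project name point to Gemma 2.
base_model = "google/gemma-2-2b-it"
dataset_name = "your-username/your-medical-dataset"  # placeholder; the diff does not name the dataset

# A common 4-bit NF4 quantization config; the actual values are not shown in the card.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Gemma 2 is usually run with eager attention when flash-attention is unavailable.
attn_implementation = "eager"
```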
 
 #### Preprocessing [optional]
 
+dataset = load_dataset(dataset_name, split="all", cache_dir="./cache")
+dataset = dataset.shuffle(seed=42).select(range(3000))  # Use 3k samples for a better demo
+
+# Define a cleaning function to remove unwanted artifacts
+def clean_text(text):
+    # Remove URLs and any "Chat Doctor" or similar phrases
+    text = re.sub(r'\b(?:www\.[^\s]+|http\S+)', '', text)  # Remove URLs
+    text = re.sub(r'\b(?:Chat Doctor(?:\.com)?(?:\.in)?|www\.(?:google|yahoo)\S*)', '', text)  # Remove site names
+    text = re.sub(r'\s+', ' ', text)  # Collapse multiple spaces
+    return text.strip()
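`clean_text` is defined above but the diff does not show it being applied, and the trainer below reads `dataset["train"]` and `dataset["test"]`. A minimal sketch of the likely intermediate steps, assuming the text column is called `text` (to match `dataset_text_field="text"` below) and an arbitrary 90/10 split:

```python
# Assumed follow-up steps, not shown in the diff: clean the text column and create splits.
dataset = dataset.map(lambda row: {"text": clean_text(row["text"])})
dataset = dataset.train_test_split(test_size=0.1, seed=42)  # yields dataset["train"] / dataset["test"]
```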
 
 #### Training Hyperparameters
 
+training_args = TrainingArguments(
+    output_dir=new_model,
+    per_device_train_batch_size=1,
+    per_device_eval_batch_size=1,
+    gradient_accumulation_steps=2,
+    optim="paged_adamw_32bit",
+    num_train_epochs=1,
+    eval_strategy="steps",
+    eval_steps=200,
+    save_steps=500,  # Keep save_steps as 500
+    logging_steps=1,
+    warmup_steps=10,
+    logging_strategy="steps",
+    learning_rate=2e-4,
+    fp16=True,
+    bf16=False,
+    group_by_length=True,
+    report_to="wandb",
+    load_best_model_at_end=False  # Disable loading the best model at the end
+)
+
+# Trainer (note: no early-stopping callback is attached; load_best_model_at_end is disabled)
+trainer = SFTTrainer(
+    model=model,
+    train_dataset=dataset["train"],
+    eval_dataset=dataset["test"],
+    peft_config=peft_config,
+    max_seq_length=512,
+    dataset_text_field="text",  # Specify the text field in your dataset
+    tokenizer=tokenizer,
+    args=training_args,
+    packing=False,
+)
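`SFTTrainer` is handed a `peft_config`, and `TrainingArguments` writes to `new_model`, but neither appears in this hunk. The sketch below shows one plausible LoRA configuration plus the calls that launch and save the run; the rank, alpha, target modules, and output name are assumptions, and in the original notebook these definitions would precede the `TrainingArguments` cell above.

```python
from peft import LoraConfig

new_model = "gemma-2-2b-it-medical"  # placeholder output directory / repo name

# Assumed LoRA settings for a QLoRA-style fine-tune; the card does not state the real values.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# Run fine-tuning and save the resulting LoRA adapter.
trainer.train()
trainer.save_model(new_model)
```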
 
 #### Speeds, Sizes, Times [optional]
 
 <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
@@ -105,7 +153,8 @@ Use the code below to get started with the model.
 
 ## Evaluation
 
+View run noble-hill-29 at: https://wandb.ai/anicomanesh/Fine-tune%20Gemma-2-2b-it%20on%20Medical%20Dataset/runs/06xd9vvz
+wandb: ⭐️ View project at: https://wandb.ai/anicomanesh/Fine-tune%20Gemma-2-2b-it%20on%20Medical%20Dat
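`report_to="wandb"` in the arguments above and the run links here imply that a Weights & Biases run was initialized before training. A minimal sketch, assuming the project name shown in the URL:

```python
import wandb

# Project name taken from the run URL above; the run name (e.g. "noble-hill-29") is auto-generated.
wandb.init(project="Fine-tune Gemma-2-2b-it on Medical Dataset")
```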
 
 ### Testing Data, Factors & Metrics