Arnic committed on
Commit e7802d6 · verified · 1 Parent(s): 19f23cc

Update README.md

Files changed (1)
  1. README.md +56 -7
README.md CHANGED
@@ -20,13 +20,13 @@ base_model:
 
 This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
 
- - **Developed by:** [More Information Needed]
 - **Funded by [optional]:** [More Information Needed]
 - **Shared by [optional]:** [More Information Needed]
 - **Model type:** [More Information Needed]
 - **Language(s) (NLP):** [More Information Needed]
 - **License:** [More Information Needed]
- - **Finetuned from model [optional]:** [More Information Needed]
 
 ### Model Sources [optional]
 
@@ -86,17 +86,65 @@ Use the code below to get started with the model.
 
 ### Training Procedure
 
- <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
 
 #### Preprocessing [optional]
 
- [More Information Needed]
 
- #### Training Hyperparameters
 
- - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
 
 #### Speeds, Sizes, Times [optional]
 
 <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
@@ -105,7 +153,8 @@ Use the code below to get started with the model.
 
 ## Evaluation
 
- <!-- This section describes the evaluation protocols and provides the results. -->
 
 ### Testing Data, Factors & Metrics
 
@@ -20,13 +20,13 @@ base_model:
 
 This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
 
+ - **Developed by:** Arash Nicoomanesh
 - **Funded by [optional]:** [More Information Needed]
 - **Shared by [optional]:** [More Information Needed]
 - **Model type:** [More Information Needed]
 - **Language(s) (NLP):** [More Information Needed]
 - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** google/gemma-2b-it
 
 ### Model Sources [optional]
 
 
@@ -86,17 +86,65 @@ Use the code below to get started with the model.
 
 ### Training Procedure
 
+ model = Gemma2ForCausalLM.from_pretrained(  # Changed here
+     base_model,
+     quantization_config=bnb_config,
+     device_map="auto",
+     attn_implementation=attn_implementation
+ )
+ tokenizer = GemmaTokenizerFast.from_pretrained(base_model, padding_side="right",
+                                                truncation_side="right", trust_remote_code=True)
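The loading snippet above refers to `base_model`, `bnb_config`, and `attn_implementation`, none of which are defined anywhere in the diff. A minimal sketch of what they would typically look like for 4-bit QLoRA-style loading is below; the concrete values (checkpoint id, NF4 settings, eager attention) are assumptions, not taken from the commit.

```python
import torch
from transformers import BitsAndBytesConfig

# Assumed base checkpoint: the card's metadata says google/gemma-2b-it, but
# Gemma2ForCausalLM and the W&B project name point at a Gemma 2 checkpoint.
base_model = "google/gemma-2-2b-it"

# Gemma 2 fine-tuning generally uses eager attention rather than SDPA/FlashAttention-2.
attn_implementation = "eager"

# 4-bit NF4 quantization so the model fits on a single modest GPU;
# float16 compute matches fp16=True in the training arguments further down.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
```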
 
 #### Preprocessing [optional]
 
+ dataset = load_dataset(dataset_name, split="all", cache_dir="./cache")
+ dataset = dataset.shuffle(seed=42).select(range(3000))  # Use 3k samples for a better demo
 
+ # Define a cleaning function to remove unwanted artifacts
+ def clean_text(text):
+     # Remove URLs and any "Chat Doctor" or similar phrases
+     text = re.sub(r'\b(?:www\.[^\s]+|http\S+)', '', text)  # Remove URLs
+     text = re.sub(r'\b(?:Chat Doctor(?:.com)?(?:.in)?|www\.(?:google|yahoo)\S*)', '', text)  # Remove site names
+     text = re.sub(r'\s+', ' ', text)  # Collapse multiple spaces
+     return text.strip()
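The diff defines `clean_text` but does not show it being applied, nor the step that produces the `text` column and the train/test split that the trainer below consumes. A plausible sketch of that glue code is given here; the dataset id, column names, and chat formatting are assumptions rather than something recorded in the commit.

```python
import re  # required by clean_text above
from datasets import load_dataset

# Assumed dataset id; the actual `dataset_name` is not shown in the diff.
dataset_name = "ruslanmv/ai-medical-chatbot"

def format_chat(row):
    # "Patient"/"Doctor" column names are assumptions about the source dataset.
    question = clean_text(row["Patient"])
    answer = clean_text(row["Doctor"])
    # Build the single "text" field that SFTTrainer(dataset_text_field="text") expects.
    row["text"] = tokenizer.apply_chat_template(
        [{"role": "user", "content": question},
         {"role": "assistant", "content": answer}],
        tokenize=False,
    )
    return row

dataset = dataset.map(format_chat)
dataset = dataset.train_test_split(test_size=0.1, seed=42)  # -> dataset["train"] / dataset["test"]
```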
 
 
 
+ #### Training Hyperparameters
 
+ training_args = TrainingArguments(
+     output_dir=new_model,
+     per_device_train_batch_size=1,
+     per_device_eval_batch_size=1,
+     gradient_accumulation_steps=2,
+     optim="paged_adamw_32bit",
+     num_train_epochs=1,
+     eval_strategy="steps",
+     eval_steps=200,
+     save_steps=500,  # Keep save_steps as 500
+     logging_steps=1,
+     warmup_steps=10,
+     logging_strategy="steps",
+     learning_rate=2e-4,
+     fp16=True,
+     bf16=False,
+     group_by_length=True,
+     report_to="wandb",
+     load_best_model_at_end=False  # Disable loading best model at the end
+ )
+ 
+ # Trainer with early stopping callback
+ trainer = SFTTrainer(
+     model=model,
+     train_dataset=dataset["train"],
+     eval_dataset=dataset["test"],
+     peft_config=peft_config,
+     max_seq_length=512,
+     dataset_text_field="text",  # Specify the text field in your dataset
+     tokenizer=tokenizer,
+     args=training_args,
+     packing=False,
+ )
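The trainer above is handed a `peft_config` that is never defined in the diff. A typical LoRA configuration for QLoRA fine-tuning of a Gemma 2 model is sketched below, followed by the training call; the rank, alpha, dropout, and target modules are assumptions, not values recorded in this commit.

```python
from peft import LoraConfig

# Hypothetical LoRA adapter settings; only task_type is dictated by the use case.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)

# Run fine-tuning and save the adapter plus tokenizer under the new model name.
trainer.train()
trainer.model.save_pretrained(new_model)
tokenizer.save_pretrained(new_model)
```

With per_device_train_batch_size=1 and gradient_accumulation_steps=2, the effective batch size is 2 sequences per optimizer step, and fp16=True keeps the compute dtype consistent with the float16 4-bit setup assumed above.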
 #### Speeds, Sizes, Times [optional]
 
 <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
 
@@ -105,7 +153,8 @@ Use the code below to get started with the model.
 
 ## Evaluation
 
+ View run noble-hill-29 at: https://wandb.ai/anicomanesh/Fine-tune%20Gemma-2-2b-it%20on%20Medical%20Dataset/runs/06xd9vvz
+ wandb: ⭐️ View project at: https://wandb.ai/anicomanesh/Fine-tune%20Gemma-2-2b-it%20on%20Medical%20Dat
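The links above are standard Weights & Biases run output, produced because the training arguments set `report_to="wandb"`. A minimal sketch of the corresponding logging setup is below; only the project name is taken from the URL, and the API-key handling is an assumption.

```python
import os
import wandb

# Assumes the key is exported as WANDB_API_KEY in the environment.
wandb.login(key=os.environ["WANDB_API_KEY"])
wandb.init(
    project="Fine-tune Gemma-2-2b-it on Medical Dataset",  # matches the project in the URLs above
    job_type="training",
    anonymous="allow",
)
```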
 
 ### Testing Data, Factors & Metrics