---
base_model: google/gemma-3-270m-it
tags:
- medical
- healthcare
- chain-of-thought
- cot
- reasoning
- sft
- lora
- unsloth
- gemma
- gemma-3
- long-context
- transformers
license: apache-2.0
language:
- en
pipeline_tag: text-generation
datasets:
- FreedomIntelligence/medical-o1-reasoning-SFT
---
# Medical-Diagnosis-COT-Gemma3-270M
**Alpha AI (www.alphaai.biz)** fine-tuned Gemma-3 270M for **medical question answering with explicit chain-of-thought (CoT)**. The model emits its reasoning inside a `<think> ... </think>` block followed by a final answer, making it well-suited for research on verifiable medical reasoning and for internal tooling where transparent intermediate steps are desired.
> ⚠️ **Not for clinical use.** This model is a research system, not a medical device. Do not use it for diagnosis or treatment decisions.
---
## TL;DR
* **Base**: `google/gemma-3-270m-it` (Gemma 3, 270M params). Access requires accepting Google’s Gemma license on Hugging Face.
* **Data**: `FreedomIntelligence/medical-o1-reasoning-SFT` (`en` split; columns: `Question`, `Complex_CoT`, `Response`).
* **Training**: SFT with LoRA via Unsloth; assistant-only loss; sequences templated in Gemma-3 chat format.
* **Objective**: Produce `<think>…</think>` reasoning followed by a final answer.
---
## Intended Use & Limitations
**Intended use**
* Research on medical reasoning, CoT interpretability, prompt engineering, dataset curation.
* Internal assistants that require visible reasoning traces (to be reviewed by humans).
**Out of scope / limitations**
* Not a substitute for clinician judgment; may hallucinate facts.
* No guarantee of guideline adherence (e.g., UpToDate/NICE/ACOG).
* Biases from synthetic/derived training data will propagate.
The dataset is licensed under **Apache-2.0**; the base model is covered by **Google's Gemma license**. Review both before use.
---
## Model Details
* **Model family**: Gemma 3 (270M).
* **Context window**: Gemma 3 supports up to **128K tokens** on its larger variants; the 270M model uses a **32K** window, and the practical context used during SFT was lower due to GPU limits.
* **Architecture**: decoder-only causal LM (Gemma family).
* **Fine-tuning**: Parameter-Efficient Fine-Tuning (LoRA) using Unsloth.
---
## Data
**Source**: `FreedomIntelligence/medical-o1-reasoning-SFT`, plus human-annotated and filtered medical diagnosis data.
* English subset (`en`): ~19.7k rows with fields **`Question`**, **`Complex_CoT`**, **`Response`**. Total dataset ≈ 90,120 rows across splits/languages.
* Built for advanced medical reasoning; constructed with GPT-4o and a verifier on **verifiable medical problems**. See the dataset card and paper.
* Human-annotated data: ~10k rows with the same fields.
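For reference, a minimal sketch of loading the English configuration with 🤗 Datasets (the `en` config name follows the dataset card; adjust if it differs):
```python
from datasets import load_dataset

# English configuration of the CoT dataset (assumed config name "en").
ds = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en", split="train")
print(ds.column_names)         # ['Question', 'Complex_CoT', 'Response']
print(ds[0]["Question"][:200])
```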
**Preprocessing**
* Each row rendered to a chat conversation under the **Gemma-3** template.
* Assistant output concatenates `{Complex_CoT}\n{Response}`.
* Training objective masked to **assistant** tokens only (instruction masking).
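A rough sketch of this preprocessing (illustrative, not the exact training script; the system text mirrors the prompt shown in the next section, and the tokenizer is assumed to ship the Gemma-3 chat template):
```python
SYSTEM = (
    "You are a medical expert with advanced knowledge in clinical reasoning, "
    "diagnostics, and treatment planning. Please answer the following medical question."
)

def to_chat_text(example, tokenizer):
    messages = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": example["Question"]},
        # Assistant target: chain of thought followed by the final answer.
        {"role": "assistant", "content": f"{example['Complex_CoT']}\n{example['Response']}"},
    ]
    # Render with the Gemma-3 chat template; the trainer later masks the loss so
    # that only the assistant span contributes (assistant-only loss).
    return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}

# e.g. ds = ds.map(lambda ex: to_chat_text(ex, tok))
```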
---
## Prompt & Output Format
**Training prompt style (system + user + assistant):**
```text
System:
Below is an instruction that describes a task,
paired with an input that provides further context.
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step
chain of thoughts to ensure a logical and accurate response.
### Instruction:
You are a medical expert with advanced knowledge in clinical reasoning,
diagnostics, and treatment planning.
Please answer the following medical question.
User:
### Question:
{question}
### Response:
Assistant:
{complex_cot}
{final_answer}
```
**Generation tip (serve-time):**
If you don’t want to expose CoT to end-users, post-process generated text to **strip the `<think> ... </think>` block** and only show the final answer.
Python snippet to strip CoT:
```python
import re

def strip_think(text: str) -> str:
    # Remove the <think> ... </think> reasoning block, then trim whitespace.
    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.S | re.I).strip()
```
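Quick usage check with a made-up output string:
```python
raw = "<think>Volume overload with raised JVP suggests decompensated heart failure...</think>\nStart a loop diuretic."
print(strip_think(raw))  # -> "Start a loop diuretic."
```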
---
## Quick Start
### Transformers (merged weights)
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

repo = "alphaaico/Medical-Diagnosis-COT-Gemma3-270M"
tok = AutoTokenizer.from_pretrained(repo)
mdl = AutoModelForCausalLM.from_pretrained(repo, device_map="auto")

prompt = """### Question:
A 65-year-old with exertional dyspnea and orthopnea presents with bilateral
pitting edema and raised JVP. What initial pharmacologic therapy is indicated?
### Response:
"""

streamer = TextStreamer(tok, skip_prompt=True, skip_special_tokens=True)
inputs = tok.apply_chat_template(
    [{"role": "system", "content": "You are a careful medical reasoner."},
     {"role": "user", "content": prompt}],
    tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt",
).to(mdl.device)
out = mdl.generate(
    **inputs, max_new_tokens=256, do_sample=True,
    temperature=1.0, top_p=0.95, top_k=64, streamer=streamer,
)
```
### PEFT LoRA (if this repo hosts LoRA adapters)
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
base = "google/gemma-3-270m-it" # requires accepting Gemma license on Hugging Face
repo = "alphaaico/Medical-Diagnosis-COT-Gemma3-270M"
tok = AutoTokenizer.from_pretrained(base)
base_mdl = AutoModelForCausalLM.from_pretrained(base, device_map="auto")
mdl = PeftModel.from_pretrained(base_mdl, repo)  # loads and applies the LoRA adapters (not merged)
mdl.eval()
```
> If you see a 403/404 on the base model, make sure you’ve accepted Google’s Gemma license in your Hugging Face account.
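If you prefer standalone merged weights (for export or simpler serving), PEFT can fold the adapters into the base model; a minimal sketch:
```python
merged = mdl.merge_and_unload()                     # fold LoRA weights into the base model
merged.save_pretrained("gemma3-270m-medical-cot-merged")
tok.save_pretrained("gemma3-270m-medical-cot-merged")
```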
---
## Training Procedure
* **Trainer**: TRL `SFTTrainer` (supervised finetuning).
* **Library**: Unsloth for fast loading + LoRA. An example reference repo for Gemma-3-270M is available on Hugging Face.
**Key hyperparameters (demo run that produced loss ≈ 3.3):**
* Steps: `max_steps=500`; ideally go beyond 1500 steps for better results.
* Optimizer: `adamw_8bit`, `weight_decay=0.01`, `lr_scheduler_type=linear`
* LR: `5e-5` (for longer runs, `2e-5` is often more stable)
* Batch: `per_device_train_batch_size=8`, `gradient_accumulation_steps=1`
* Warmup: `warmup_steps=5`
* Seed: `3407`
* Loss masking: **assistant-only** (user/system tokens ignored)
> Note: For very long CoT sequences on small GPUs (e.g., T4 16 GB), consider `per_device_train_batch_size=1` with gradient accumulation and a larger `max_seq_length`. Gemma 3 supports long contexts (up to **128K** tokens on larger variants), but practical fine-tuning length depends on memory. For short queries the model is small enough to run on edge devices. The hyperparameters above are illustrated in the sketch below.
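A minimal sketch of how these hyperparameters might be wired into TRL's `SFTConfig`/`SFTTrainer` (illustrative only; `model` and `train_dataset` are placeholders, and the Unsloth loading plus assistant-only masking used in the original run are not reproduced here):
```python
from trl import SFTConfig, SFTTrainer

args = SFTConfig(
    output_dir="outputs",
    max_steps=500,                      # go beyond 1500 steps for better results
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=5e-5,                 # 2e-5 is often more stable for longer runs
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_steps=5,
    seed=3407,
)
trainer = SFTTrainer(model=model, train_dataset=train_dataset, args=args)
trainer.train()
```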
---
## Evaluation
* **Training loss**: ~**3.3** at step 500 (demo run); further decreased to ~**2.3** with parameter tuning.
* **Format compliance**: qualitatively verified to produce a `<think>…</think>` block followed by an answer.
* **No formal clinical benchmarks** reported yet. Contributions welcome via the Discussions tab.
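A simple compliance check one could run over generated samples (a sketch, assuming the `<think>…</think>` convention described above):
```python
import re

THINK_RE = re.compile(r"<think>.+?</think>\s*\S", flags=re.S | re.I)

def is_format_compliant(output: str) -> bool:
    # True if the output contains a non-empty <think> block followed by answer text.
    return bool(THINK_RE.search(output))
```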
---
## Safety, Ethics, and Responsible Use
* Do not rely on outputs for patient care. Validate with authoritative sources and licensed professionals.
* Be mindful of dataset artifacts and synthetic reasoning patterns.
* Consider stripping CoT in user-facing apps to avoid over-trust in intermediate narratives.
---
## License & Attribution
* **Base model**: Google **Gemma 3** — access controlled under Google’s Gemma license on Hugging Face.
* **Dataset**: `FreedomIntelligence/medical-o1-reasoning-SFT` — **Apache-2.0**.
* **This fine-tune**: Derivative work under the base model’s license terms. Review and comply with both licenses before distribution or commercial use.
**Please cite the dataset if you use this model:**
```
@misc{chen2024huatuogpto1medicalcomplexreasoning,
title={HuatuoGPT-o1, Towards Medical Complex Reasoning with LLMs},
author={Junying Chen and Zhenyang Cai and Ke Ji and Xidong Wang and Wanlong Liu and Rongsheng Wang and Jianye Hou and Benyou Wang},
year={2024},
eprint={2412.18925},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2412.18925}
}
```
---
## Acknowledgements
* **Google** for releasing Gemma 3.
* **FreedomIntelligence** for `medical-o1-reasoning-SFT`.
* **Unsloth** for streamlined fine-tuning utilities for Gemma 3.
---
## Contact
Questions, issues, or contributions: open a Discussion on this repo.
For enterprise collaboration with **Alpha AI**, reach out via the organization profile on Hugging Face or find us on www.alphaai.biz.
---
### Appendix: Inference without CoT (server-side filter)
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import re

repo = "alphaaico/Medical-Diagnosis-COT-Gemma3-270M"
tok = AutoTokenizer.from_pretrained(repo)
mdl = AutoModelForCausalLM.from_pretrained(repo, device_map="auto")

def strip_think(txt: str) -> str:
    # Drop the <think> ... </think> reasoning block before returning the answer.
    return re.sub(r"<think>.*?</think>\s*", "", txt, flags=re.S | re.I).strip()

def ask(question: str) -> str:
    user = f"### Question:\n\n{question}\n\n### Response:"
    msgs = [
        {"role": "system", "content": "You are a medical expert. Think step-by-step, then answer succinctly."},
        {"role": "user", "content": user},
    ]
    inputs = tok.apply_chat_template(
        msgs, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
    ).to(mdl.device)
    out = mdl.generate(
        **inputs, max_new_tokens=512, do_sample=True, temperature=1.0, top_p=0.95, top_k=64
    )
    # Decode only the newly generated tokens, then strip the CoT block.
    text = tok.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return strip_think(text)

print(ask("Elderly patient with new ankle swelling on thiazolidinedione—likely cause?"))
```