Add pipeline tag, library name, data access, data generation, training pipeline and evaluation (#1)
Browse files- Add pipeline tag, library name, data access, data generation, training pipeline and evaluation (8c0a68b68ea71ac4343a7f1dc34bd19916beebdc)
Co-authored-by: Niels Rogge <[email protected]>
README.md
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
---
|
2 |
license: apache-2.0
|
|
|
|
|
3 |
---
|
4 |
|
5 |
# MedReason: Eliciting Factual Medical Reasoning Steps in LLMs via Knowledge Graphs
|
@@ -8,16 +10,37 @@ license: apache-2.0
|
|
8 |
📃 <a href="https://arxiv.org/abs/2504.00993" target="_blank">Paper</a> |🤗 <a href="https://huggingface.co/UCSC-VLAA/MedReason-8B" target="_blank">MedReason-8B</a> | 📚 <a href="https://huggingface.co/datasets/UCSC-VLAA/MedReason" target="_blank">MedReason Data</a>
|
9 |
</p>
|
10 |
|
11 |
-
|
12 |
## ⚡Introduction
|
13 |
|
|
|
|
|
14 |
**MedReason** is a large-scale high-quality medical reasoning dataset designed to enable faithful and explainable medical problem-solving in large language models (LLMs).
|
15 |
|
16 |
- We utilize a structured medical knowledge graph (KG) to convert clinical QA pairs into logical chains of reasoning, or “thinking paths”.
|
17 |
-
- Our pipeline generates detailed reasoning for various medical questions from 7 medical datasets, resulting in a dataset of **32,682** question-answer pairs, each with detailed, step-by-step explanations.
|
18 |
- By finetuning with proposed [MedReason dataset](https://huggingface.co/datasets/UCSC-VLAA/MedReason), our best model [MedReason-8B](https://huggingface.co/UCSC-VLAA/MedReason-8B), achieves *state-of-the-art* performance.
|
19 |
|
20 |
-
We open-sourced our
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
## 👨⚕️ Model
|
23 |
|
@@ -29,7 +52,7 @@ We open-sourced our model here.
|
|
29 |
| MedReason-Llama | [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) | [Link](https://huggingface.co/UCSC-VLAA/MedReason-Llama) |
|
30 |
| MedReason-Mistral | [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) | [Link](https://huggingface.co/UCSC-VLAA/MedReason-Mistral) |
|
31 |
|
32 |
-
- **Deploy**: we provide a example code for direct inference with MedReason-8B.
|
33 |
|
34 |
Also, MedReason-8B can be deployed with tools like [vllm](https://github.com/vllm-project/vllm) or [Sglang](https://github.com/sgl-project/sglang), we provide code for model deployment using Sglang in `./src/evaluation/eval.py`
|
35 |
|
@@ -49,6 +72,87 @@ outputs = model.generate(**inputs, max_new_tokens=2048)
|
|
49 |
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
50 |
```
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
## 🙏🏼 Acknowledgement
|
53 |
|
54 |
We gratefully acknowledge the inspiring work of [HuatuoGPT-o1](https://github.com/FreedomIntelligence/HuatuoGPT-o1), which laid important groundwork for this research. We also thank the developers of the excellent tools [curator](https://github.com/bespokelabsai/curator/), [trl](https://github.com/huggingface/trl), and [sglang](https://github.com/sgl-project/sglang) for making this work possible.
|
@@ -57,13 +161,12 @@ We gratefully acknowledge the inspiring work of [HuatuoGPT-o1](https://github.co
|
|
57 |
|
58 |
```
|
59 |
@misc{wu2025medreasonelicitingfactualmedical,
|
60 |
-
title={MedReason: Eliciting Factual Medical Reasoning Steps in LLMs via Knowledge Graphs},
|
61 |
author={Juncheng Wu and Wenlong Deng and Xingxuan Li and Sheng Liu and Taomian Mi and Yifan Peng and Ziyang Xu and Yi Liu and Hyunjin Cho and Chang-In Choi and Yihan Cao and Hui Ren and Xiang Li and Xiaoxiao Li and Yuyin Zhou},
|
62 |
year={2025},
|
63 |
eprint={2504.00993},
|
64 |
archivePrefix={arXiv},
|
65 |
primaryClass={cs.CL},
|
66 |
-
url={https://arxiv.org/abs/2504.00993},
|
67 |
}
|
68 |
-
```
|
69 |
-
|
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
+
library_name: transformers
|
4 |
+
pipeline_tag: question-answering
|
5 |
---
|
6 |
|
7 |
# MedReason: Eliciting Factual Medical Reasoning Steps in LLMs via Knowledge Graphs
|
|
|
10 |
📃 <a href="https://arxiv.org/abs/2504.00993" target="_blank">Paper</a> |🤗 <a href="https://huggingface.co/UCSC-VLAA/MedReason-8B" target="_blank">MedReason-8B</a> | 📚 <a href="https://huggingface.co/datasets/UCSC-VLAA/MedReason" target="_blank">MedReason Data</a>
|
11 |
</p>
|
12 |
|
|
|
13 |
## ⚡Introduction
|
14 |
|
15 |
+
<img src="./assets/main.png" alt="main" style="zoom: 33%;" />
|
16 |
+
|
17 |
**MedReason** is a large-scale high-quality medical reasoning dataset designed to enable faithful and explainable medical problem-solving in large language models (LLMs).
|
18 |
|
19 |
- We utilize a structured medical knowledge graph (KG) to convert clinical QA pairs into logical chains of reasoning, or “thinking paths”.
|
20 |
+
- Our pipeline generates detailed reasoning for various medical questions from 7 medical datasets, resulting in a dataset of **32,682** question-answer pairs, each with detailed, step-by-step explanations.
|
21 |
- By finetuning with proposed [MedReason dataset](https://huggingface.co/datasets/UCSC-VLAA/MedReason), our best model [MedReason-8B](https://huggingface.co/UCSC-VLAA/MedReason-8B), achieves *state-of-the-art* performance.
|
22 |
|
23 |
+
We open-sourced our models, data, and code here.
|
24 |
+
|
25 |
+
## 📚 Data
|
26 |
+
|
27 |
+
- **Data Access**
|
28 |
+
|
29 |
+
| Data | Description | Link |
|
30 |
+
| --------- | --------------------------------- | ----------------------------------------------------------- |
|
31 |
+
| MedReason | Our quality filtered data for SFT | [Link](https://huggingface.co/datasets/UCSC-VLAA/MedReason) |
|
32 |
+
|
33 |
+
- **Data Generation**
|
34 |
+
|
35 |
+
We provide the code for generating Chain-of-Thought reasoning based on medical QA pairs and knowledge-graph (KG) in `./src/data_generation`
|
36 |
+
|
37 |
+
1. Set the file path of each datasets in `./configs/dataset_configs.yml`
|
38 |
+
2. Fill your Azure endpoint and API key in `./src/data_generation/utils.py`
|
39 |
+
3. Run the following script
|
40 |
+
|
41 |
+
```bash
|
42 |
+
python ./src/data_generation/Generate_Reasoning.py --dataset medqa --sample <number_of_samples> --start_idx 0 --batch_size 1&
|
43 |
+
```
|
44 |
|
45 |
## 👨⚕️ Model
|
46 |
|
|
|
52 |
| MedReason-Llama | [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) | [Link](https://huggingface.co/UCSC-VLAA/MedReason-Llama) |
|
53 |
| MedReason-Mistral | [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) | [Link](https://huggingface.co/UCSC-VLAA/MedReason-Mistral) |
|
54 |
|
55 |
+
- **Deploy**: we provide a example code for direct inference with MedReason-8B.
|
56 |
|
57 |
Also, MedReason-8B can be deployed with tools like [vllm](https://github.com/vllm-project/vllm) or [Sglang](https://github.com/sgl-project/sglang), we provide code for model deployment using Sglang in `./src/evaluation/eval.py`
|
58 |
|
|
|
72 |
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
73 |
```
|
74 |
|
75 |
+
## 🚀 Training Piepline
|
76 |
+
|
77 |
+
Simply Supervised-Finetuning (SFT) using MedReason data improves the LLM’s medical reasoning capability.
|
78 |
+
|
79 |
+
Fine-tune the model on 8-GPU:
|
80 |
+
|
81 |
+
```bash
|
82 |
+
# based on Huatuo-o1-8B
|
83 |
+
accelerate launch --config_file ./configs/deepspeed_zero3.yaml \
|
84 |
+
--num_processes 8 \
|
85 |
+
--num_machines 1 \
|
86 |
+
--machine_rank 0 \
|
87 |
+
--deepspeed_multinode_launcher standard ./src/model_training/SFT.py \
|
88 |
+
--model_path FreedomIntelligence/HuatuoGPT-o1-8B \
|
89 |
+
--data_path /path/to/your/data \
|
90 |
+
--n_epochs 3 \
|
91 |
+
--experiment_name huatuo_o1_medreason_8B \
|
92 |
+
--base_model Llama
|
93 |
+
|
94 |
+
# based on DeepSeek-distilled-Llama-8B
|
95 |
+
accelerate launch --config_file ./configs/deepspeed_zero3.yaml \
|
96 |
+
--num_processes 8 \
|
97 |
+
--num_machines 1 \
|
98 |
+
--machine_rank 0 \
|
99 |
+
--deepspeed_multinode_launcher standard ./src/model_training/SFT.py \
|
100 |
+
--model_path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
|
101 |
+
--data_path /path/to/your/data\
|
102 |
+
--n_epochs 3 \
|
103 |
+
--experiment_name distilled_llama_medreason_8B \
|
104 |
+
--base_model Llama
|
105 |
+
```
|
106 |
+
|
107 |
+
## 🧐 Evaluation
|
108 |
+
|
109 |
+
- **Qualitative Results:**
|
110 |
+
|
111 |
+
Case Study on Medbullets Benchmark. **MedReason-8B** generates accurate reasoning with reliable knowledge.
|
112 |
+
|
113 |
+
<img src="./assets/case_v6.png" alt="case_v6" style="zoom: 40%;" />
|
114 |
+
|
115 |
+
- **Performance on medical benchmarks**:
|
116 |
+
|
117 |
+
Results of instruction-tuned LLMs fine-tuned with MedReason data:
|
118 |
+
|
119 |
+
<img src="./assets/tab1.png" alt="tab1" style="zoom:50%;" />
|
120 |
+
|
121 |
+
Performance of MedReason-8B on challenging and common medical QA benchmarks:
|
122 |
+
|
123 |
+
<img src="./assets/tab3.png" alt="tab3" style="zoom:50%;" />
|
124 |
+
|
125 |
+
- **Run evaluation**:
|
126 |
+
|
127 |
+
1. You first need to install [Sglang](https://github.com/sgl-project/sglang). After installation, deploy the model you want to test using Sglang with the following command:
|
128 |
+
|
129 |
+
```bash
|
130 |
+
# deploy on 8 GPUs
|
131 |
+
log_num=0
|
132 |
+
model_name=UCSC-VLAA/MedReason-8B
|
133 |
+
port=28${log_num}35
|
134 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m sglang.launch_server --model-path $model_name --port $port --mem-fraction-static 0.8 --dp 8 --tp 1 > sglang${log_num}.log 2>&1 &
|
135 |
+
```
|
136 |
+
|
137 |
+
2. Wait for the model to be deployed. After deployment, you can run the following code for evaluation. We use prompts that allow the model to respond freely. We find that the extracted results are consistently reliable and broadly cover the intended scope. You can also set the `--strict_prompt` option to use stricter prompts for more precise answer extraction.
|
138 |
+
|
139 |
+
```bash
|
140 |
+
log_num=0
|
141 |
+
task_floder=MedReason-8B-results
|
142 |
+
model_name=UCSC-VLAA/MedReason-8B
|
143 |
+
port=28${log_num}35
|
144 |
+
|
145 |
+
eval_file=./eval_data/medbullets_op4.jsonl
|
146 |
+
python ./src/evaluation/eval.py --model_name $model_name --eval_file $eval_file --port $port --strict_prompt --batch_size 1000 --max_new_tokens 2000 --task_floder $task_floder
|
147 |
+
```
|
148 |
+
|
149 |
+
3. After completing the evaluation, run the following code to stop the Sglang service and release GPU memory.
|
150 |
+
|
151 |
+
```bash
|
152 |
+
pkill -f sglang
|
153 |
+
pkill -f multiprocessing.spawn
|
154 |
+
```
|
155 |
+
|
156 |
## 🙏🏼 Acknowledgement
|
157 |
|
158 |
We gratefully acknowledge the inspiring work of [HuatuoGPT-o1](https://github.com/FreedomIntelligence/HuatuoGPT-o1), which laid important groundwork for this research. We also thank the developers of the excellent tools [curator](https://github.com/bespokelabsai/curator/), [trl](https://github.com/huggingface/trl), and [sglang](https://github.com/sgl-project/sglang) for making this work possible.
|
|
|
161 |
|
162 |
```
|
163 |
@misc{wu2025medreasonelicitingfactualmedical,
|
164 |
+
title={MedReason: Eliciting Factual Medical Reasoning Steps in LLMs via Knowledge Graphs},
|
165 |
author={Juncheng Wu and Wenlong Deng and Xingxuan Li and Sheng Liu and Taomian Mi and Yifan Peng and Ziyang Xu and Yi Liu and Hyunjin Cho and Chang-In Choi and Yihan Cao and Hui Ren and Xiang Li and Xiaoxiao Li and Yuyin Zhou},
|
166 |
year={2025},
|
167 |
eprint={2504.00993},
|
168 |
archivePrefix={arXiv},
|
169 |
primaryClass={cs.CL},
|
170 |
+
url={https://arxiv.org/abs/2504.00993},
|
171 |
}
|
172 |
+
```
|
|