update model and InfiMed.py

- InfiMed.py +5 -2
- README.md +89 -3
- config.json +9 -4
- model-00001-of-00002.safetensors +1 -1
- model-00002-of-00002.safetensors +1 -1
InfiMed.py
CHANGED
@@ -181,14 +181,16 @@ class InfiMed(PreTrainedModel):
         if vision_model is not None:
             self.vision_model = vision_model
         else:
-            self.vision_model = SiglipVisionModel.from_pretrained(config.vision_config._name_or_path, hidden_act = "gelu")
+            # self.vision_model = SiglipVisionModel.from_pretrained(config.vision_config._name_or_path, hidden_act = "gelu")
+            self.vision_model = SiglipVisionModel(config.vision_config)
 
         if language_model is not None:
             self.language_model = language_model
             self.config.llm_config = language_model.config
         else:
             if config.llm_config.architectures[0] == 'Qwen3ForCausalLM':
-                self.language_model = Qwen3ForCausalLM.from_pretrained(config.llm_config._name_or_path, pad_token_id = 151670, bos_token_id = 128245, eos_token_id = 151645, tie_word_embeddings = False)
+                # self.language_model = Qwen3ForCausalLM.from_pretrained(config.llm_config._name_or_path, pad_token_id = 151670, bos_token_id = 128245, eos_token_id = 151645, tie_word_embeddings = False)
+                self.language_model = Qwen3ForCausalLM(config.llm_config)
             else:
                 raise NotImplementedError(
                     f'{config.llm_config.architectures[0]} is not implemented.')
@@ -520,3 +522,4 @@ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX
             return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
         raise ValueError(f'Unsupported tensor type: {return_tensors}')
     return input_ids, labels
+
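The change above swaps `from_pretrained()` calls on the upstream SigLIP/Qwen repos for plain constructors driven by the nested configs, so `InfiMed.from_pretrained()` can fill in all weights from this repo's safetensors shards without extra downloads. A minimal, self-contained sketch of that pattern (tiny made-up config values so it runs quickly; this is not the repository's code, and it assumes transformers >= 4.51 for Qwen3 support):

```Python
# Sketch only (not the repository's code): tiny configs stand in for the real
# vision_config / llm_config stored in this repo's config.json.
from transformers import SiglipVisionConfig, SiglipVisionModel
from transformers import Qwen3Config, Qwen3ForCausalLM

vision_cfg = SiglipVisionConfig(hidden_size=64, intermediate_size=128,
                                num_hidden_layers=2, num_attention_heads=4,
                                image_size=64, patch_size=16)
llm_cfg = Qwen3Config(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                      num_attention_heads=4, num_key_value_heads=2, vocab_size=1024)

# Randomly initialised skeletons: no network access and no upstream checkpoints.
# In InfiMed, the composite from_pretrained() call then loads the real weights
# for both submodules from the local safetensors shards.
vision_model = SiglipVisionModel(vision_cfg)
language_model = Qwen3ForCausalLM(llm_cfg)
print(type(vision_model).__name__, type(language_model).__name__)
```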
README.md
CHANGED
@@ -1,3 +1,89 @@
----
-license:
-
+---
+license: apache-2.0
+language:
+- en
+- zh
+base_model:
+- google/siglip-so400m-patch14-384
+- Qwen/Qwen3-4B
+---
+## Introduction
+
+InfiMed-4B is a medical Multimodal Large Language Model (MLLM) developed by the InfiXAI team. Our model outperforms HuatuoGPT-V-7B and MedGemma-4B-IT. The goal of InfiMed-4B is to develop a high-performance medical MLLM that remains accessible and affordable for a broad audience. We welcome you to explore its capabilities, and feel free to contact us with any questions or opportunities.
+
+## Model Card
+
+### Model Architecture:
+
+| Architecture | ViT | LLM | Adapter | Resolution |
+| --- | --- | --- | --- | --- |
+| 🤗InfiMed-4B | [🤗siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) | [🤗Qwen3-4B](https://huggingface.co/Qwen/Qwen3-4B) | 2-layer MLP | 384x384xN |
+
+## Evaluation
+
+InfiMed-4B not only outperforms HuatuoGPT-V-7B and MedGemma-4B-IT, but is also competitive with recently released SoTA models.
+
+
+### Detailed Evaluations:
+
+| Model | Size | MMMU-Med | VQA-RAD | SLAKE | PathVQA | PMC-VQA | OMVQA | MedXVQA | Avg. |
+|---------------------|------|----------|---------|-------|---------|---------|-------|---------|-------|
+| **Proprietary Models** | | | | | | | | | |
+| GPT-5 | | 83.6 | 67.8 | 78.1 | 52.8 | 60.0 | 76.4 | 71.0 | 70.0 |
+| GPT-5-mini | | 80.5 | 66.3 | 76.1 | 52.4 | 57.6 | 70.9 | 60.1 | 66.3 |
+| GPT-5-nano | | 74.1 | 55.4 | 69.3 | 45.4 | 51.3 | 66.5 | 45.1 | 58.2 |
+| GPT-4.1 | | 75.2 | 65.0 | 72.2 | 55.5 | 55.2 | 75.5 | 45.2 | 63.4 |
+| Claude Sonnet 4 | | 74.6 | 67.6 | 70.6 | 54.2 | 54.4 | 65.5 | 43.3 | 61.5 |
+| Gemini-2.5-Flash | | 76.9 | 68.5 | 75.8 | 55.4 | 55.4 | 71.0 | 52.8 | 65.1 |
+| **General Open-source Models** | | | | | | | | | |
+| Qwen2.5VL-3B | 3B | 51.3 | 56.8 | 63.2 | 37.1 | 50.6 | 64.5 | 20.7 | 49.2 |
+| Qwen2.5VL-7B | 7B | 50.6 | 64.5 | 67.2 | 44.1 | 51.9 | 63.6 | 22.3 | 52.0 |
+| InternVL3-8B | 8B | 59.2 | 65.4 | 72.8 | 48.6 | 53.8 | 79.1 | 22.4 | 57.3 |
+| **Medical Open-source Models** | | | | | | | | | |
+| MedGemma-4B-IT | 4B | 43.7 | 72.5 | 76.4 | 48.8 | 49.9 | 69.8 | 22.3 | 54.8 |
+| LLaVA-Med-7B | 7B | 29.3 | 53.7 | 48.0 | 38.8 | 30.5 | 44.3 | 20.3 | 37.8 |
+| HuatuoGPT-V-7B | 7B | 47.3 | 67.0 | 67.8 | 48.0 | 53.3 | 74.2 | 21.6 | 54.2 |
+| Lingshu-7B | 7B | 54.0 | 67.9 | 83.1 | 61.9 | 56.3 | 82.9 | 26.7 | 61.8 |
+| BioMediX2-8B | 8B | 39.8 | 49.2 | 57.7 | 37.0 | 43.5 | 63.3 | 21.8 | 44.6 |
+| Infi-Med-1.7B | 1.7B | 34.7 | 56.3 | 75.3 | 60.7 | 48.1 | 58.9 | 21.8 | 50.8 |
+| Infi-Med-4B | 4B | 43.3 | 57.9 | 77.7 | 63.4 | 56.6 | 76.8 | 21.9 | 56.4 |
+
+
+### Code:
+
+```Python
+from InfiMed import InfiMed
+from PIL import Image
+import torch
+
+# Define the path to the pretrained checkpoint
+pretrained_model_path = "."
+
+# Load the model from the pretrained checkpoint
+model = InfiMed.from_pretrained(pretrained_model_path, device_map="auto", torch_dtype=torch.bfloat16)
+
+image_path = ""  # Replace with the path to your image file
+image = Image.open(image_path).convert("RGB")  # Ensure the image is in RGB format
+
+# Prepare input messages
+messages = {
+    "prompt": "What modality is used to take this image?",
+    "image": image  # PIL image loaded above; set to None for text-only prompts
+}
+
+# Generate output
+output_text = model.generate_output(messages)
+
+# Print the result
+print("Model Response:", output_text)
+
+```
+<br>
+
+## Acknowledgements
+
+Our model is built upon numerous outstanding open-source projects, and we are grateful for their contributions. We extend special thanks to the Google team and the Qwen team for their great base models.
+
+## License
+
+This project is licensed under [Apache License 2.0](LICENSE).
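The architecture table in the new README names a 2-layer MLP adapter between the SigLIP ViT and the LLM. Purely as an illustration (this is not the repository's adapter code), a projector of that shape could look like the sketch below; the 1152 input width comes from the vision_config in config.json, while the 2560 output width is an assumed Qwen3-4B embedding size and should be checked against the checkpoint.

```Python
import torch
import torch.nn as nn

class MLPAdapter(nn.Module):
    """Hypothetical 2-layer MLP projector: SigLIP patch features -> LLM embeddings."""
    def __init__(self, vision_dim: int = 1152, llm_dim: int = 2560):  # 2560 is an assumption
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(vision_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, patch_features: torch.Tensor) -> torch.Tensor:
        # (batch, num_patches, vision_dim) -> (batch, num_patches, llm_dim)
        return self.proj(patch_features)

# Illustrative patch-token count for a single 384x384 image.
print(MLPAdapter()(torch.randn(1, 729, 1152)).shape)  # torch.Size([1, 729, 2560])
```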
config.json
CHANGED
@@ -3,10 +3,13 @@
     "InfiMed"
   ],
   "llm_config": {
-    "_name_or_path": "
+    "_name_or_path": ".",
     "architectures": [
       "Qwen3ForCausalLM"
     ],
+    "pad_token_id": 151670,
+    "bos_token_id": 128245,
+    "eos_token_id": 151645,
     "attention_bias": false,
     "attention_dropout": 0.0,
     "bos_token_id": 151643,
@@ -30,7 +33,8 @@
     "torch_dtype": "bfloat16",
     "use_cache": true,
     "use_sliding_window": false,
-    "vocab_size": 151936
+    "vocab_size": 151936,
+    "tie_word_embeddings": false
   },
   "load_precision": "bf16",
   "max_length": 32,
@@ -47,10 +51,11 @@
   ],
   "transformers_version": "4.52.4",
   "vision_config": {
-    "_name_or_path": "
+    "_name_or_path": ".",
     "architectures": [
       "SiglipModel"
     ],
+    "hidden_act": "gelu",
     "attention_dropout": 0.0,
     "hidden_act": "gelu_pytorch_tanh",
     "hidden_size": 1152,
@@ -67,4 +72,4 @@
   },
   "wandb_entity": null,
   "wandb_project": "mmpretrain"
-}
+}
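The token-ID and embedding-tying settings that the old InfiMed.py passed to `Qwen3ForCausalLM.from_pretrained()` (pad 151670, bos 128245, eos 151645, `tie_word_embeddings=False`) now live in the nested `llm_config`, which is what the bare `Qwen3ForCausalLM(config.llm_config)` constructor reads. A quick check from a local checkout of this repo (working directory assumed to contain config.json):

```Python
# Reads the nested llm_config straight from config.json in the current directory.
import json

with open("config.json") as f:
    cfg = json.load(f)

llm = cfg["llm_config"]
print(llm["architectures"])        # ['Qwen3ForCausalLM']
print(llm["pad_token_id"])         # 151670
print(llm["eos_token_id"])         # 151645
print(llm["tie_word_embeddings"])  # False
```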
model-00001-of-00002.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:69b2009323a7164d5b24aa2829f27dd67a64c588444a37220a976e17bc2ab9a3
 size 4966471968
model-00002-of-00002.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:cbc5e171fac4c829b458fd9eb6e6233444209a40bb3f5d5963aa481b4880dadc
 size 4731957576
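Both safetensors shards are git-lfs pointers whose sha256 values now match the uploaded weights. To verify a local download against the hashes above, a standard-library check is enough (shard filenames assumed to sit in the current directory):

```Python
import hashlib

def sha256_of(path: str, chunk: int = 1 << 20) -> str:
    """Stream the file so multi-GB shards don't need to fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk), b""):
            h.update(block)
    return h.hexdigest()

print(sha256_of("model-00001-of-00002.safetensors"))
# expected: 69b2009323a7164d5b24aa2829f27dd67a64c588444a37220a976e17bc2ab9a3
print(sha256_of("model-00002-of-00002.safetensors"))
# expected: cbc5e171fac4c829b458fd9eb6e6233444209a40bb3f5d5963aa481b4880dadc
```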