Gabriele Campanella committed · f1a9699
Parent(s): efb420c
added download instructions
README.md CHANGED
@@ -14,3 +14,57 @@ tags:
ViT-large (300M parameters) trained on a diverse neuropathology dataset.

## Model Usage
To get started, first clone the repository, skipping the large `.bin` weight files (the pretrained weights are fetched separately by `from_pretrained` below):
```bash
git clone --no-checkout https://huggingface.co/MountSinaiCompPath/neuroFM_HE20x
cd neuroFM_HE20x
git sparse-checkout init --no-cone
git sparse-checkout set '/*' '!*.bin'
git checkout
```
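If you prefer to stay in Python, the same selective download can be sketched with `huggingface_hub`'s `snapshot_download`. This is an alternative, not the documented method; it assumes you only need the repository code (e.g. `vision_transformer.py`) locally, since the weights are pulled later by `from_pretrained`:

```python
# Sketch: mirror the git sparse-checkout above with huggingface_hub.
# Assumption: only the repo code is needed locally; '*.bin' checkpoints
# are skipped here and downloaded later by from_pretrained.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="MountSinaiCompPath/neuroFM_HE20x",
    ignore_patterns=["*.bin"],  # same effect as the '!*.bin' sparse-checkout rule
)
print(local_dir)  # folder containing vision_transformer.py
```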
Now, from inside the cloned repository (so that `import vision_transformer` resolves to the bundled `vision_transformer.py`), you can use the following code:
```python
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from huggingface_hub import PyTorchModelHubMixin

import vision_transformer  # vision_transformer.py from the cloned repo

class neuroFM_HE20x(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        vit_kwargs = dict(
            img_size=224,
            patch_size=14,
            init_values=1.0e-05,
            ffn_layer='swiglufused',
            block_chunks=4,
            qkv_bias=True,
            proj_bias=True,
            ffn_bias=True,
        )
        self.encoder = vision_transformer.vit_large(**vit_kwargs)

    def forward(self, x):
        return self.encoder(x)

# Download the pretrained weights from the Hub
model = neuroFM_HE20x.from_pretrained("MountSinaiCompPath/neuroFM_HE20x")
model.eval()

# Set up the transform (ImageNet mean/std normalization)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# A random 224x224 RGB image as a stand-in for an H&E tile
img = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
img = Image.fromarray(img)
img = transform(img).unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)

# Inference: h holds the tile embedding
with torch.no_grad():
    h = model(img)
```
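In practice you will usually embed batches of tiles rather than a single random image. Below is a minimal batching sketch reusing the `model` and `transform` defined above; the tile filenames are hypothetical, and the output should have shape `(N, 1024)` for this ViT-large encoder:

```python
# Minimal batched feature-extraction sketch (hypothetical tile paths),
# reusing `model` and `transform` from the snippet above.
import torch
from PIL import Image

tile_paths = ["tile_0.png", "tile_1.png"]  # hypothetical 224x224 H&E tiles
batch = torch.stack([transform(Image.open(p).convert("RGB")) for p in tile_paths])

with torch.no_grad():
    feats = model(batch)

print(feats.shape)  # expected: torch.Size([2, 1024]) for ViT-large
```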