Aakash-Tripathi commited on
Commit
a38637c
·
verified ·
1 Parent(s): 3c6b89e

Upload configuration_mirai.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_mirai.py +130 -0
configuration_mirai.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration class for MIRAI model."""
2
+
3
+ from transformers import PretrainedConfig
4
+ from typing import List, Dict, Any, Optional
5
+
6
+
7
+ class MiraiConfig(PretrainedConfig):
8
+ """
9
+ Configuration class for MIRAI breast cancer risk prediction model.
10
+
11
+ Args:
12
+ num_classes: Number of prediction classes (default: 5 for 5-year predictions)
13
+ img_size: Input image size [height, width] (default: [1664, 2048])
14
+ num_chan: Number of image channels (default: 3)
15
+ num_images: Number of mammogram views (default: 4)
16
+ multi_image: Whether to use multiple images (default: True)
17
+ encoder_config: Configuration for the ResNet encoder
18
+ transformer_config: Configuration for the transformer module
19
+ risk_factors: Configuration for clinical risk factors
20
+ preprocessing: Image preprocessing parameters
21
+ **kwargs: Additional configuration parameters
22
+ """
23
+
24
+ model_type = "mirai"
25
+
26
+ def __init__(
27
+ self,
28
+ num_classes: int = 5,
29
+ img_size: List[int] = None,
30
+ num_chan: int = 3,
31
+ num_images: int = 4,
32
+ multi_image: bool = True,
33
+ encoder_config: Dict[str, Any] = None,
34
+ transformer_config: Dict[str, Any] = None,
35
+ risk_factors: Dict[str, Any] = None,
36
+ preprocessing: Dict[str, Any] = None,
37
+ model_version: str = "1.0.0",
38
+ **kwargs
39
+ ):
40
+ super().__init__(**kwargs)
41
+
42
+ # Model architecture
43
+ self.num_classes = num_classes
44
+ self.img_size = img_size or [1664, 2048]
45
+ self.num_chan = num_chan
46
+ self.num_images = num_images
47
+ self.multi_image = multi_image
48
+ self.model_version = model_version
49
+
50
+ # Encoder configuration
51
+ self.encoder_config = encoder_config or {
52
+ "architecture": "resnet",
53
+ "block_layout": [
54
+ ["BasicBlock", 2],
55
+ ["BasicBlock", 2],
56
+ ["BasicBlock", 2],
57
+ ["BasicBlock", 2]
58
+ ],
59
+ "pool_name": "GlobalMaxPool",
60
+ "img_only_dim": 2048,
61
+ "dropout": 0.25,
62
+ "pretrained_on_imagenet": False
63
+ }
64
+
65
+ # Transformer configuration
66
+ self.transformer_config = transformer_config or {
67
+ "hidden_dim": 256,
68
+ "num_heads": 8,
69
+ "num_layers": 6,
70
+ "dropout": 0.1,
71
+ "max_seq_length": 4
72
+ }
73
+
74
+ # Risk factors configuration
75
+ self.risk_factors = risk_factors or {
76
+ "use_risk_factors": True,
77
+ "num_risk_factors": 34,
78
+ "risk_factor_keys": [
79
+ "density", "binary_family_history", "binary_biopsy_benign", "binary_biopsy_LCIS",
80
+ "binary_biopsy_atypical_hyperplasia", "age", "menarche_age", "menopause_age",
81
+ "first_pregnancy_age", "prior_hist", "race", "parous", "menopausal_status",
82
+ "weight", "height", "ovarian_cancer", "ovarian_cancer_age", "ashkenazi",
83
+ "brca", "mom_bc_cancer_history", "m_aunt_bc_cancer_history",
84
+ "p_aunt_bc_cancer_history", "m_grandmother_bc_cancer_history",
85
+ "p_grantmother_bc_cancer_history", "sister_bc_cancer_history",
86
+ "mom_oc_cancer_history", "m_aunt_oc_cancer_history",
87
+ "p_aunt_oc_cancer_history", "m_grandmother_oc_cancer_history",
88
+ "p_grantmother_oc_cancer_history", "sister_oc_cancer_history",
89
+ "hrt_type", "hrt_duration", "hrt_years_ago_stopped"
90
+ ]
91
+ }
92
+
93
+ # Preprocessing configuration
94
+ self.preprocessing = preprocessing or {
95
+ "img_mean": 7047.99,
96
+ "img_std": 12005.5,
97
+ "normalize_method": "imagenet",
98
+ "imagenet_mean": [0.485, 0.456, 0.406],
99
+ "imagenet_std": [0.229, 0.224, 0.225]
100
+ }
101
+
102
+ @property
103
+ def use_risk_factors(self) -> bool:
104
+ """Whether the model uses clinical risk factors."""
105
+ return self.risk_factors.get("use_risk_factors", True)
106
+
107
+ @property
108
+ def num_risk_factors(self) -> int:
109
+ """Number of clinical risk factors."""
110
+ return self.risk_factors.get("num_risk_factors", 34)
111
+
112
+ @property
113
+ def risk_factor_keys(self) -> List[str]:
114
+ """List of risk factor keys."""
115
+ return self.risk_factors.get("risk_factor_keys", [])
116
+
117
+ @property
118
+ def img_height(self) -> int:
119
+ """Image height."""
120
+ return self.img_size[0]
121
+
122
+ @property
123
+ def img_width(self) -> int:
124
+ """Image width."""
125
+ return self.img_size[1]
126
+
127
+ def to_dict(self) -> Dict[str, Any]:
128
+ """Convert configuration to dictionary."""
129
+ output = super().to_dict()
130
+ return output