Feature Extraction
Transformers
PyTorch
bbsnet
custom_code
thinh-huynh-re commited on
Commit
7f682aa
·
1 Parent(s): b207ec3

Upload model

Browse files
Files changed (4) hide show
  1. config.json +12 -0
  2. configuration_bbsnet.py +47 -0
  3. modeling_bbsnet.py +48 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BBSNetModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_bbsnet.BBSNetConfig",
7
+ "AutoModel": "modeling_bbsnet.BBSNetModel"
8
+ },
9
+ "model_type": "bbsnet",
10
+ "torch_dtype": "float32",
11
+ "transformers_version": "4.26.1"
12
+ }
configuration_bbsnet.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+ """
6
+ The configuration of a model is an object that
7
+ will contain all the necessary information to build the model.
8
+
9
+ The three important things to remember when writing you own configuration are the following:
10
+
11
+ - you have to inherit from PretrainedConfig,
12
+ - the __init__ of your PretrainedConfig must accept any kwargs,
13
+ - those kwargs need to be passed to the superclass __init__.
14
+ """
15
+
16
+
17
+ class BBSNetConfig(PretrainedConfig):
18
+
19
+ """
20
+ Defining a model_type for your configuration is not mandatory,
21
+ unless you want to register your model with the auto classes."""
22
+
23
+ model_type = "bbsnet"
24
+
25
+ def __init__(self, **kwargs):
26
+ super().__init__(**kwargs)
27
+
28
+
29
+ if __name__ == "__main__":
30
+ """
31
+ With this done, you can easily create and save your configuration like
32
+ you would do with any other model config of the library.
33
+ Here is how we can create a resnet50d config and save it:
34
+ """
35
+ bbsnet_config = BBSNetConfig()
36
+ bbsnet_config.save_pretrained("custom-bbsnet")
37
+
38
+ """
39
+ This will save a file named config.json inside the folder custom-resnet.
40
+ You can then reload your config with the from_pretrained method:
41
+ """
42
+ bbsnet_config = BBSNetConfig.from_pretrained("custom-bbsnet")
43
+
44
+ """
45
+ You can also use any other method of the PretrainedConfig class,
46
+ like push_to_hub() to directly upload your config to the Hub.
47
+ """
modeling_bbsnet.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+
3
+ from torch import Tensor, nn
4
+ from transformers import PreTrainedModel
5
+
6
+ from models.BBSNet_model import BBSNet
7
+
8
+ from .configuration_bbsnet import BBSNetConfig
9
+
10
+
11
+ class BBSNetModel(PreTrainedModel):
12
+ """
13
+ The line that sets the config_class is not mandatory,
14
+ unless you want to register your model with the auto classes
15
+ """
16
+
17
+ config_class = BBSNetConfig
18
+
19
+ def __init__(self, config: BBSNetConfig):
20
+ super().__init__(config)
21
+ self.model = BBSNet()
22
+ self.loss = nn.BCEWithLogitsLoss()
23
+
24
+ """
25
+ You can have your model return anything you want,
26
+ but returning a dictionary with the loss included when labels are passed,
27
+ will make your model directly usable inside the Trainer class.
28
+ Using another output format is fine as long as you are planning on
29
+ using your own training loop or another library for training.
30
+ """
31
+
32
+ def forward(
33
+ self, rgbs: Tensor, depths: Tensor, gts: Optional[Tensor] = None
34
+ ) -> Dict[str, Tensor]:
35
+ _, logits = self.model(rgbs, depths)
36
+ if gts is not None:
37
+ loss = self.loss(logits, gts)
38
+ return {"loss": loss, "logits": logits}
39
+ return {"logits": logits}
40
+
41
+
42
+ if __name__ == "__main__":
43
+ resnet50d_config = ResnetConfig.from_pretrained("custom-resnet")
44
+ resnet50d = ResnetModelForImageClassification(resnet50d_config)
45
+
46
+ # Load pretrained weights from timm
47
+ pretrained_model: nn.Module = timm.create_model("resnet50d", pretrained=True)
48
+ resnet50d.model.load_state_dict(pretrained_model.state_dict())
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:483479ab2cc3cd42ff1be16e1ba76ac09afcb0c010487221b81c96aa22f75106
3
+ size 199976498