udman99 committed
Commit ea40a1d · verified · 1 Parent(s): be646b9

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ example.png filter=lfs diff=lfs merge=lfs -text
+ results.png filter=lfs diff=lfs merge=lfs -text
+ Screenshot[[:space:]]2024-01-21[[:space:]]at[[:space:]]11.56.17.png filter=lfs diff=lfs merge=lfs -text
+ T1.png filter=lfs diff=lfs merge=lfs -text
+ T2.png filter=lfs diff=lfs merge=lfs -text
+ t4.png filter=lfs diff=lfs merge=lfs -text
+ example_input.jpg filter=lfs diff=lfs merge=lfs -text
MyConfig.py ADDED
@@ -0,0 +1,13 @@
+ from transformers import PretrainedConfig
+ from typing import List
+
+ class RMBGConfig(PretrainedConfig):
+     model_type = "SegformerForSemanticSegmentation"
+     def __init__(
+         self,
+         in_ch=3,
+         out_ch=1,
+         **kwargs):
+         self.in_ch = in_ch
+         self.out_ch = out_ch
+         super().__init__(**kwargs)
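Since `RMBGConfig` subclasses `PretrainedConfig`, the two custom fields serialize through the standard config round-trip. A minimal sketch, assuming `MyConfig.py` is importable from the working directory (the output path is illustrative):

```python
from MyConfig import RMBGConfig

config = RMBGConfig(in_ch=3, out_ch=1)
config.save_pretrained("rmbg_config_dir")  # writes rmbg_config_dir/config.json
reloaded = RMBGConfig.from_pretrained("rmbg_config_dir")
assert (reloaded.in_ch, reloaded.out_ch) == (3, 1)
```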
MyPipe.py ADDED
@@ -0,0 +1,76 @@
+ import torch, os
+ import torch.nn.functional as F
+ from torchvision.transforms.functional import normalize
+ import numpy as np
+ from transformers import Pipeline
+ from transformers.image_utils import load_image
+ from skimage import io
+ from PIL import Image
+
+ class RMBGPipe(Pipeline):
+     def __init__(self, **kwargs):
+         Pipeline.__init__(self, **kwargs)
+         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         self.model.eval()
+
+     def _sanitize_parameters(self, **kwargs):
+         # route caller kwargs to the right pipeline stage
+         preprocess_kwargs = {}
+         postprocess_kwargs = {}
+         if "model_input_size" in kwargs:
+             preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
+         if "return_mask" in kwargs:
+             postprocess_kwargs["return_mask"] = kwargs["return_mask"]
+         return preprocess_kwargs, {}, postprocess_kwargs
+
+     def preprocess(self, input_image, model_input_size: list = [1024, 1024]):
+         # preprocess the input
+         orig_im = load_image(input_image)
+         orig_im = np.array(orig_im)
+         orig_im_size = orig_im.shape[0:2]
+         preprocessed_image = self.preprocess_image(orig_im, model_input_size).to(self.device)
+         inputs = {
+             "preprocessed_image": preprocessed_image,
+             "orig_im_size": orig_im_size,
+             "input_image": input_image,
+         }
+         return inputs
+
+     def _forward(self, inputs):
+         result = self.model(inputs.pop("preprocessed_image"))
+         inputs["result"] = result
+         return inputs
+
+     def postprocess(self, inputs, return_mask: bool = False):
+         result = inputs.pop("result")
+         orig_im_size = inputs.pop("orig_im_size")
+         input_image = inputs.pop("input_image")
+         result_image = self.postprocess_image(result[0][0], orig_im_size)
+         pil_im = Image.fromarray(result_image)
+         if return_mask:
+             return pil_im
+         input_image = load_image(input_image)
+         no_bg_image = input_image.copy()
+         no_bg_image.putalpha(pil_im)
+         return no_bg_image
+
+     # utility functions
+     def preprocess_image(self, im: np.ndarray, model_input_size: list = [1024, 1024]) -> torch.Tensor:
+         # same as utilities.py with a minor modification (the tensor stays float32 throughout)
+         if len(im.shape) < 3:
+             im = im[:, :, np.newaxis]
+         im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
+         im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear')
+         image = torch.divide(im_tensor, 255.0)
+         image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
+         return image
+
+     def postprocess_image(self, result: torch.Tensor, im_size: list) -> np.ndarray:
+         result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
+         ma = torch.max(result)
+         mi = torch.min(result)
+         result = (result - mi) / (ma - mi)
+         im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
+         im_array = np.squeeze(im_array)
+         return im_array
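Since `_sanitize_parameters` routes `model_input_size` to `preprocess()` and `return_mask` to `postprocess()`, both can be passed straight through the pipeline call. A short usage sketch mirroring the README (the local file name is illustrative):

```python
from transformers import pipeline

pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
# model_input_size is forwarded to preprocess(), return_mask to postprocess()
mask = pipe("example_input.jpg", model_input_size=[1024, 1024], return_mask=True)  # grayscale PIL mask
cutout = pipe("example_input.jpg")  # RGBA cutout: the mask is applied via putalpha()
cutout.save("example_no_bg.png")
```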
README.md ADDED
@@ -0,0 +1,170 @@
+ ---
+ license: other
+ license_name: bria-rmbg-1.4
+ license_link: https://bria.ai/bria-huggingface-model-license-agreement/
+ pipeline_tag: image-segmentation
+ tags:
+ - remove background
+ - background
+ - background-removal
+ - Pytorch
+ - vision
+ - legal liability
+ - transformers
+ - transformers.js
+
+ extra_gated_description: RMBG v1.4 is available as a source-available model for non-commercial use
+ extra_gated_heading: "Fill in this form to get instant access"
+ extra_gated_fields:
+   Name: text
+   Company/Org name: text
+   Org Type (Early/Growth Startup, Enterprise, Academy): text
+   Role: text
+   Country: text
+   Email: text
+   By submitting this form, I agree to BRIA’s Privacy policy and Terms & conditions, see links below: checkbox
+ ---
+
+ # BRIA Background Removal v1.4 Model Card
+
+ RMBG v1.4 is our state-of-the-art background removal model, designed to effectively separate foreground from background across a range of
+ categories and image types. It was trained on a carefully selected dataset that includes
+ general stock images, e-commerce, gaming, and advertising content, making it suitable for commercial use cases that power enterprise content creation at scale.
+ Its accuracy, efficiency, and versatility currently rival those of leading source-available models.
+ It is ideal wherever content safety, legally licensed datasets, and bias mitigation are paramount.
+
+ Developed by BRIA AI, RMBG v1.4 is available as a source-available model for non-commercial use.
+
+
+ To purchase a commercial license, simply click [here](https://go.bria.ai/3D5EGp0).
+
+
+ [CLICK HERE FOR A DEMO](https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4)
+
+ **NOTE:** A new RMBG version is available! Check out [RMBG-2.0](https://huggingface.co/briaai/RMBG-2.0)
+
+ Join our [Discord community](https://discord.gg/Nxe9YW9zHS) for more information, tutorials, tools, and to connect with other users!
+
+
+ ![examples](t4.png)
+
+
+ ### Model Description
+
+ - **Developed by:** [BRIA AI](https://bria.ai/)
+ - **Model type:** Background Removal
+ - **License:** [bria-rmbg-1.4](https://bria.ai/bria-huggingface-model-license-agreement/)
+   - The model is released under a Creative Commons license for non-commercial use.
+   - Commercial use is subject to a commercial agreement with BRIA. To purchase a commercial license, simply click [here](https://go.bria.ai/3B4Asxv).
+
+ - **Model Description:** BRIA RMBG 1.4 is a saliency segmentation model trained exclusively on a professional-grade dataset.
+ - **BRIA:** Resources for more information: [BRIA AI](https://bria.ai/)
+
+
+
+ ## Training data
+ The Bria-RMBG model was trained on over 12,000 high-quality, high-resolution, manually labeled (pixel-wise accuracy), fully licensed images.
+ Our benchmark included balanced gender, balanced ethnicity, and people with different types of disabilities.
+ For clarity, we provide our data distribution across different categories, demonstrating the model’s versatility.
+
+ ### Distribution of images
+
+ | Category | Distribution |
+ | -----------------------------------| -----------------------------------:|
+ | Objects only | 45.11% |
+ | People with objects/animals | 25.24% |
+ | People only | 17.35% |
+ | People/objects/animals with text | 8.52% |
+ | Text only | 2.52% |
+ | Animals only | 1.89% |
+
+ | Category | Distribution |
+ | -----------------------------------| -----------------------------------------:|
+ | Photorealistic | 87.70% |
+ | Non-Photorealistic | 12.30% |
+
+
+ | Category | Distribution |
+ | -----------------------------------| -----------------------------------:|
+ | Non-Solid Background | 52.05% |
+ | Solid Background | 47.95% |
+
+
+ | Category | Distribution |
+ | -----------------------------------| -----------------------------------:|
+ | Single main foreground object | 51.42% |
+ | Multiple objects in the foreground | 48.58% |
+
+
+ ## Qualitative Evaluation
+
+ ![examples](results.png)
+
+
+ ## Architecture
+
+ RMBG v1.4 builds on [IS-Net](https://github.com/xuebinqin/DIS), enhanced with our unique training scheme and proprietary dataset.
+ These modifications significantly improve the model’s accuracy and effectiveness in diverse image-processing scenarios.
+
+ ## Installation
+ ```bash
+ pip install -qr https://huggingface.co/briaai/RMBG-1.4/resolve/main/requirements.txt
+ ```
+
+ ## Usage
+
+ Either load the pipeline
+ ```python
+ from transformers import pipeline
+ image_path = "https://farm5.staticflickr.com/4007/4322154488_997e69e4cf_z.jpg"
+ pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
+ pillow_mask = pipe(image_path, return_mask=True)  # outputs a pillow mask
+ pillow_image = pipe(image_path)  # applies the mask to the input and returns a pillow image
+ ```
+
+ Or load the model
+ ```python
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from skimage import io
+ from PIL import Image
+ from transformers import AutoModelForImageSegmentation
+ from torchvision.transforms.functional import normalize
+
+ model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
+
+ def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
+     if len(im.shape) < 3:
+         im = im[:, :, np.newaxis]
+     im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
+     im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear')
+     image = torch.divide(im_tensor, 255.0)
+     image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
+     return image
+
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
+     result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
+     ma = torch.max(result)
+     mi = torch.min(result)
+     result = (result - mi) / (ma - mi)
+     im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
+     im_array = np.squeeze(im_array)
+     return im_array
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ # prepare input
+ image_path = "https://farm5.staticflickr.com/4007/4322154488_997e69e4cf_z.jpg"
+ model_input_size = [1024, 1024]
+ orig_im = io.imread(image_path)
+ orig_im_size = orig_im.shape[0:2]
+ image = preprocess_image(orig_im, model_input_size).to(device)
+
+ # inference
+ result = model(image)
+
+ # post process
+ result_image = postprocess_image(result[0][0], orig_im_size)
+
+ # save result
+ pil_mask_im = Image.fromarray(result_image)
+ orig_image = Image.fromarray(orig_im)  # image_path is a URL, so build the PIL image from the array already loaded
+ no_bg_image = orig_image.copy()
+ no_bg_image.putalpha(pil_mask_im)
+ no_bg_image.save("no_bg_image.png")
+ ```
briarmbg.py ADDED
@@ -0,0 +1,458 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel
+ from .MyConfig import RMBGConfig
+
+ class REBNCONV(nn.Module):
+     def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
+         super(REBNCONV, self).__init__()
+
+         self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1*dirate, dilation=1*dirate, stride=stride)
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+         return xout
+
+ ## upsample tensor 'src' to have the same spatial size as tensor 'tar'
+ def _upsample_like(src, tar):
+     src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')
+     return src
+
+
+ ### RSU-7 ###
+ class RSU7(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
+         super(RSU7, self).__init__()
+
+         self.in_ch = in_ch
+         self.mid_ch = mid_ch
+         self.out_ch = out_ch
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  ## 1 -> 1/2
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv6d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+
+         hx = x
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+
+         hx6 = self.rebnconv6(hx)
+
+         hx7 = self.rebnconv7(hx6)
+
+         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+         hx6dup = _upsample_like(hx6d, hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-6 ###
+ class RSU6(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+
+         hx6 = self.rebnconv6(hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+ ### RSU-5 ###
+ class RSU5(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+
+         hx5 = self.rebnconv5(hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+ ### RSU-4 ###
+ class RSU4(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+ ### RSU-4F ###
+ class RSU4F(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+         self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+         return hx1d + hxin
+
+
+ class myrebnconv(nn.Module):
+     def __init__(self, in_ch=3,
+                  out_ch=1,
+                  kernel_size=3,
+                  stride=1,
+                  padding=1,
+                  dilation=1,
+                  groups=1):
+         super(myrebnconv, self).__init__()
+
+         self.conv = nn.Conv2d(in_ch,
+                               out_ch,
+                               kernel_size=kernel_size,
+                               stride=stride,
+                               padding=padding,
+                               dilation=dilation,
+                               groups=groups)
+         self.bn = nn.BatchNorm2d(out_ch)
+         self.rl = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         return self.rl(self.bn(self.conv(x)))
+
+
+ class BriaRMBG(PreTrainedModel):
+     config_class = RMBGConfig
+     def __init__(self, config: RMBGConfig = RMBGConfig()):
+         super().__init__(config)
+         in_ch = config.in_ch    # 3
+         out_ch = config.out_ch  # 1
+         self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
+         self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage1 = RSU7(64, 32, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 32, 128)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(128, 64, 256)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(256, 128, 512)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(512, 256, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(512, 256, 512)
+
+         # decoder
+         self.stage5d = RSU4F(1024, 256, 512)
+         self.stage4d = RSU4(1024, 128, 256)
+         self.stage3d = RSU5(512, 64, 128)
+         self.stage2d = RSU6(256, 32, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+         # self.outconv = nn.Conv2d(6*out_ch, out_ch, 1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.conv_in(hx)
+         # hx = self.pool_in(hxin)
+
+         # stage 1
+         hx1 = self.stage1(hxin)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # -------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side outputs
+         d1 = self.side1(hx1d)
+         d1 = _upsample_like(d1, x)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, x)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, x)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, x)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, x)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, x)
+
+         return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
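The forward pass returns six sigmoid side outputs upsampled to the input resolution plus the six decoder feature maps; `d1` is the mask the pipeline consumes. A minimal smoke-test sketch with random input and no pretrained weights (assumes a package-style import; adjust the relative `.MyConfig` import in briarmbg.py if running as a flat script):

```python
import torch
from MyConfig import RMBGConfig
from briarmbg import BriaRMBG

net = BriaRMBG(RMBGConfig(in_ch=3, out_ch=1)).eval()
with torch.no_grad():
    sides, features = net(torch.rand(1, 3, 1024, 1024))
print(sides[0].shape)  # torch.Size([1, 1, 1024, 1024]) -- d1, the full-resolution mask
```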
config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_name_or_path": "briaai/RMBG-1.4",
+   "architectures": [
+     "BriaRMBG"
+   ],
+   "auto_map": {
+     "AutoConfig": "MyConfig.RMBGConfig",
+     "AutoModelForImageSegmentation": "briarmbg.BriaRMBG"
+   },
+   "custom_pipelines": {
+     "image-segmentation": {
+       "impl": "MyPipe.RMBGPipe",
+       "pt": [
+         "AutoModelForImageSegmentation"
+       ],
+       "tf": [],
+       "type": "image"
+     }
+   },
+   "in_ch": 3,
+   "model_type": "SegformerForSemanticSegmentation",
+   "out_ch": 1,
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.0.dev0"
+ }
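The `auto_map` entries are what let the Auto classes resolve this repo's custom code, and `custom_pipelines` wires the `image-segmentation` task to `RMBGPipe`. A sketch of the resolution path (requires `trust_remote_code=True`):

```python
from transformers import AutoConfig, AutoModelForImageSegmentation

config = AutoConfig.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)   # -> MyConfig.RMBGConfig
model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)  # -> briarmbg.BriaRMBG
print(type(config).__name__, type(model).__name__)  # RMBGConfig BriaRMBG
```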
example_inference.py ADDED
@@ -0,0 +1,37 @@
+ from skimage import io
+ import torch, os
+ from PIL import Image
+ from briarmbg import BriaRMBG
+ from utilities import preprocess_image, postprocess_image
+
+ def example_inference():
+
+     im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
+     net.to(device)
+     net.eval()
+
+     # prepare input
+     model_input_size = [1024, 1024]
+     orig_im = io.imread(im_path)
+     orig_im_size = orig_im.shape[0:2]
+     image = preprocess_image(orig_im, model_input_size).to(device)
+
+     # inference
+     result = net(image)
+
+     # post process
+     result_image = postprocess_image(result[0][0], orig_im_size)
+
+     # save result
+     pil_mask_im = Image.fromarray(result_image)
+     orig_image = Image.open(im_path)
+     no_bg_image = orig_image.copy()
+     no_bg_image.putalpha(pil_mask_im)
+     no_bg_image.save("example_image_no_bg.png")
+
+
+ if __name__ == "__main__":
+     example_inference()
example_input.jpg ADDED

Git LFS Details

  • SHA256: 1e9cff13a43d13ec0d0d733a55234e862a35c282cdbfa197c85223a937f28a56
  • Pointer size: 131 Bytes
  • Size of remote file: 327 kB
handler.py ADDED
@@ -0,0 +1,21 @@
+ from typing import Dict, List, Any
+ from transformers import pipeline
+ from PIL import Image
+ class EndpointHandler():
+     def __init__(self, path=""):
+         # Initialize the image segmentation pipeline
+         self.pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         # Extract the image path from the input data
+         image_path = data.get("image_path", "")
+
+         # Perform image segmentation
+         pillow_mask = self.pipe(image_path, return_mask=True)  # outputs a pillow mask
+         pillow_image = self.pipe(image_path)  # outputs the segmented image
+
+         # Save the segmented image at the root folder
+         output_image_path = "segmented_image.png"
+         pillow_image.save(output_image_path)
+
+         # Return the result as a list of dictionaries
+         return [{"image_path": output_image_path, "mask": pillow_mask}]
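For a local sanity check outside an Inference Endpoint, the handler can be driven directly; note it expects a path or URL under the `image_path` key, not raw bytes. A sketch (the local file name is illustrative):

```python
from handler import EndpointHandler

handler = EndpointHandler()  # downloads the model on first use
outputs = handler({"image_path": "example_input.jpg"})
print(outputs[0]["image_path"])  # "segmented_image.png", written to the working directory
```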
model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:893c16c340b1ddafc93e78457a4d94190da9b7179149f8574284c83caebf5e8c
+ size 176718373
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:46ef7fe46f2ae284d8f1aaa24bfa5fca5ef25a34e2c7caa890a0029eb100e87f
+ size 176381984
onnx/model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8cafcf770b06757c4eaced21b1a88e57fd2b66de01b8045f35f01535ba742e0f
+ size 176153355
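The ONNX export can be run without PyTorch via `onnxruntime`. The sketch below is an assumption-laden illustration: it assumes the graph takes a single NCHW image input and that the first output is the full-resolution mask (inspect `session.get_inputs()` / `session.get_outputs()` to confirm), and it reuses the 0.5-mean / 1.0-std normalization from `utilities.py`:

```python
import numpy as np
import onnxruntime as ort
from PIL import Image

session = ort.InferenceSession("onnx/model.onnx")
im = Image.open("example_input.jpg").convert("RGB").resize((1024, 1024))
x = (np.asarray(im, dtype=np.float32) / 255.0 - 0.5).transpose(2, 0, 1)[None]  # NCHW, mean 0.5 / std 1.0
mask = session.run(None, {session.get_inputs()[0].name: x})[0]
print(mask.shape)  # expected (1, 1, 1024, 1024): the first sigmoid side output
```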
onnx/model_fp16.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9fdfdb41866d872e0acf4a010c35c1a8547bf0eebe0d1544406bbf1c824cb59d
+ size 88217533
onnx/model_quantized.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a6648479275dfd0ede0f3a8abc20aa5c437b394681b05e5af6d268250aaf40f3
+ size 44403226
onnx/quantize_config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "per_channel": false,
+   "reduce_range": false,
+   "per_model_config": {
+     "model": {
+       "op_types": [
+         "Concat",
+         "MaxPool",
+         "Resize",
+         "Conv",
+         "Unsqueeze",
+         "Cast",
+         "Shape",
+         "Relu",
+         "Sigmoid",
+         "Gather",
+         "Constant",
+         "Slice",
+         "Add"
+       ],
+       "weight_type": "QUInt8"
+     }
+   }
+ }
preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "do_normalize": true,
+   "do_pad": false,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "feature_extractor_type": "ImageFeatureExtractor",
+   "image_std": [
+     1,
+     1,
+     1
+   ],
+   "resample": 2,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "width": 1024,
+     "height": 1024
+   }
+ }
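These settings mirror the manual preprocessing elsewhere in the repo: resize to 1024×1024 (`resample: 2` is PIL's bilinear filter), rescale by 1/255, then normalize with per-channel mean 0.5 and std 1.0. A sketch of the equivalent tensor math (illustrative, not the repo's own code path):

```python
import numpy as np
from PIL import Image

im = Image.open("example_input.jpg").convert("RGB").resize((1024, 1024), Image.BILINEAR)
x = np.asarray(im, dtype=np.float32) * 0.00392156862745098  # rescale_factor = 1/255
x = (x - 0.5) / 1.0                                          # image_mean / image_std
print(x.min(), x.max())  # roughly -0.5 .. 0.5
```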
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:59569acdb281ac9fc9f78f9d33b6f9f17f68e25086b74f9025c35bb5f2848967
+ size 176574018
requirements.txt ADDED
@@ -0,0 +1,619 @@
+ absl-py==1.4.0
+ accelerate==1.7.0
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.11.15
+ aiosignal==1.3.2
+ alabaster==1.0.0
+ albucore==0.0.24
+ albumentations==2.0.7
+ ale-py==0.11.0
+ altair==5.5.0
+ annotated-types==0.7.0
+ antlr4-python3-runtime==4.9.3
+ anyio==4.9.0
+ argon2-cffi==23.1.0
+ argon2-cffi-bindings==21.2.0
+ array_record==0.7.2
+ arviz==0.21.0
+ astropy==7.1.0
+ astropy-iers-data==0.2025.5.19.0.38.36
+ astunparse==1.6.3
+ atpublic==5.1
+ attrs==25.3.0
+ audioread==3.0.1
+ autograd==1.8.0
+ babel==2.17.0
+ backcall==0.2.0
+ backports.tarfile==1.2.0
+ beautifulsoup4==4.13.4
+ betterproto==2.0.0b6
+ bigframes==2.4.0
+ bigquery-magics==0.9.0
+ bleach==6.2.0
+ blinker==1.9.0
+ blis==1.3.0
+ blobfile==3.0.0
+ blosc2==3.3.3
+ bokeh==3.7.3
+ Bottleneck==1.4.2
+ bqplot==0.12.45
+ branca==0.8.1
+ build==1.2.2.post1
+ CacheControl==0.14.3
+ cachetools==5.5.2
+ catalogue==2.0.10
+ certifi==2025.4.26
+ cffi==1.17.1
+ chardet==5.2.0
+ charset-normalizer==3.4.2
+ chex==0.1.89
+ clarabel==0.10.0
+ click==8.2.1
+ cloudpathlib==0.21.1
+ cloudpickle==3.1.1
+ cmake==3.31.6
+ cmdstanpy==1.2.5
+ colorcet==3.1.0
+ colorlover==0.3.0
+ colour==0.1.5
+ community==1.0.0b1
+ confection==0.1.5
+ cons==0.4.6
+ contourpy==1.3.2
+ cramjam==2.10.0
+ cryptography==43.0.3
+ cuda-python==12.6.2.post1
+ cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-25.2.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
+ cudf-polars-cu12==25.2.2
+ cufflinks==0.17.3
+ cuml-cu12==25.2.1
+ cupy-cuda12x==13.3.0
+ curl_cffi==0.11.1
+ cuvs-cu12==25.2.1
+ cvxopt==1.3.2
+ cvxpy==1.6.5
+ cycler==0.12.1
+ cyipopt==1.5.0
+ cymem==2.0.11
+ Cython==3.0.12
+ dask==2024.12.1
+ dask-cuda==25.2.0
+ dask-cudf-cu12==25.2.2
+ dask-expr==1.1.21
+ dataproc-spark-connect==0.7.4
+ datascience==0.17.6
+ datasets==2.14.4
+ db-dtypes==1.4.3
+ dbus-python==1.2.18
+ debugpy==1.8.0
+ decorator==4.4.2
+ defusedxml==0.7.1
+ diffusers==0.33.1
+ dill==0.3.7
+ distributed==2024.12.1
+ distributed-ucxx-cu12==0.42.0
+ distro==1.9.0
+ dlib==19.24.6
+ dm-tree==0.1.9
+ docker-pycreds==0.4.0
+ docstring_parser==0.16
+ docutils==0.21.2
+ dopamine_rl==4.1.2
+ duckdb==1.2.2
+ earthengine-api==1.5.15
+ easydict==1.13
+ editdistance==0.8.1
+ eerepr==0.1.2
+ einops==0.8.1
+ en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85
+ entrypoints==0.4
+ et_xmlfile==2.0.0
+ etils==1.12.2
+ etuples==0.3.9
+ Farama-Notifications==0.0.4
+ fastai==2.7.19
+ fastcore==1.7.29
+ fastdownload==0.0.7
+ fastjsonschema==2.21.1
+ fastprogress==1.0.3
+ fastrlock==0.8.3
+ filelock==3.18.0
+ firebase-admin==6.8.0
+ Flask==3.1.1
+ flatbuffers==25.2.10
+ flax==0.10.6
+ folium==0.19.6
+ fonttools==4.58.0
+ frozendict==2.4.6
+ frozenlist==1.6.0
+ fsspec==2025.3.2
+ future==1.0.0
+ gast==0.6.0
+ gcsfs==2025.3.2
+ GDAL==3.8.4
+ gdown==5.2.0
+ geemap==0.35.3
+ geocoder==1.38.1
+ geographiclib==2.0
+ geopandas==1.0.1
+ geopy==2.4.1
+ gin-config==0.5.0
+ gitdb==4.0.12
+ GitPython==3.1.44
+ glob2==0.7
+ google==2.0.3
+ google-ai-generativelanguage==0.6.15
+ google-api-core==2.24.2
+ google-api-python-client==2.169.0
+ google-auth==2.38.0
+ google-auth-httplib2==0.2.0
+ google-auth-oauthlib==1.2.2
+ google-cloud-aiplatform==1.93.1
+ google-cloud-bigquery==3.33.0
+ google-cloud-bigquery-connection==1.18.2
+ google-cloud-bigquery-storage==2.31.0
+ google-cloud-core==2.4.3
+ google-cloud-dataproc==5.18.1
+ google-cloud-datastore==2.21.0
+ google-cloud-firestore==2.20.2
+ google-cloud-functions==1.20.3
+ google-cloud-iam==2.19.0
+ google-cloud-language==2.17.1
+ google-cloud-resource-manager==1.14.2
+ google-cloud-spanner==3.54.0
+ google-cloud-storage==2.19.0
+ google-cloud-translate==3.20.2
+ google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz
+ google-crc32c==1.7.1
+ google-genai==1.16.1
+ google-generativeai==0.8.5
+ google-pasta==0.2.0
+ google-resumable-media==2.7.2
+ googleapis-common-protos==1.70.0
+ googledrivedownloader==1.1.0
+ graphviz==0.20.3
+ greenlet==3.2.2
+ grpc-google-iam-v1==0.14.2
+ grpc-interceptor==0.15.4
+ grpcio==1.71.0
+ grpcio-status==1.71.0
+ grpclib==0.4.8
+ gspread==6.2.1
+ gspread-dataframe==4.0.0
+ gym==0.25.2
+ gym-notices==0.0.8
+ gymnasium==1.1.1
+ h11==0.16.0
+ h2==4.2.0
+ h5netcdf==1.6.1
+ h5py==3.13.0
+ hdbscan==0.8.40
+ hf_transfer==0.1.9
+ highspy==1.10.0
+ holidays==0.73
+ holoviews==1.20.2
+ hpack==4.1.0
+ html5lib==1.1
+ httpcore==1.0.9
+ httpimport==1.4.1
+ httplib2==0.22.0
+ httpx==0.28.1
+ huggingface-hub==0.31.4
+ humanize==4.12.3
+ hyperframe==6.1.0
+ hyperopt==0.2.7
+ ibis-framework==9.5.0
+ idna==3.10
+ imageio==2.37.0
+ imageio-ffmpeg==0.6.0
+ imagesize==1.4.1
+ imbalanced-learn==0.13.0
+ immutabledict==4.2.1
+ importlib_metadata==8.7.0
+ importlib_resources==6.5.2
+ imutils==0.5.4
+ inflect==7.5.0
+ iniconfig==2.1.0
+ intel-cmplr-lib-ur==2025.1.1
+ intel-openmp==2025.1.1
+ ipyevents==2.0.2
+ ipyfilechooser==0.6.0
+ ipykernel==6.17.1
+ ipyleaflet==0.19.2
+ ipyparallel==8.8.0
+ ipython==7.34.0
+ ipython-genutils==0.2.0
+ ipython-sql==0.5.0
+ ipytree==0.2.2
+ ipywidgets==7.7.1
+ itsdangerous==2.2.0
+ jaraco.classes==3.4.0
+ jaraco.context==6.0.1
+ jaraco.functools==4.1.0
+ jax==0.5.2
+ jax-cuda12-pjrt==0.5.1
+ jax-cuda12-plugin==0.5.1
+ jaxlib==0.5.1
+ jeepney==0.9.0
+ jieba==0.42.1
+ Jinja2==3.1.6
+ jiter==0.10.0
+ joblib==1.5.0
+ jsonpatch==1.33
+ jsonpickle==4.1.0
+ jsonpointer==3.0.0
+ jsonschema==4.23.0
+ jsonschema-specifications==2025.4.1
+ jupyter-client==6.1.12
+ jupyter-console==6.1.0
+ jupyter-leaflet==0.19.2
+ jupyter-server==1.16.0
+ jupyter_core==5.7.2
+ jupyter_kernel_gateway @ git+https://github.com/googlecolab/kernel_gateway@b134e9945df25c2dcb98ade9129399be10788671
+ jupyterlab_pygments==0.3.0
+ jupyterlab_widgets==3.0.15
+ kaggle==1.7.4.5
+ kagglehub==0.3.12
+ keras==3.8.0
+ keras-hub==0.18.1
+ keras-nlp==0.18.1
+ keyring==25.6.0
+ keyrings.google-artifactregistry-auth==1.1.2
+ kiwisolver==1.4.8
+ langchain==0.3.25
+ langchain-core==0.3.60
+ langchain-text-splitters==0.3.8
+ langcodes==3.5.0
+ langsmith==0.3.42
+ language_data==1.3.0
+ launchpadlib==1.10.16
+ lazr.restfulclient==0.14.4
+ lazr.uri==1.0.6
+ lazy_loader==0.4
+ libclang==18.1.1
+ libcudf-cu12 @ https://pypi.nvidia.com/libcudf-cu12/libcudf_cu12-25.2.1-py3-none-manylinux_2_28_x86_64.whl
+ libcugraph-cu12==25.2.0
+ libcuml-cu12==25.2.1
+ libcuvs-cu12==25.2.1
+ libkvikio-cu12==25.2.1
+ libpysal==4.13.0
+ libraft-cu12==25.2.0
+ librosa==0.11.0
+ libucx-cu12==1.18.1
+ libucxx-cu12==0.42.0
+ lightgbm @ file:///tmp/lightgbm/LightGBM/dist/lightgbm-4.5.0-py3-none-linux_x86_64.whl
+ linkify-it-py==2.0.3
+ llvmlite==0.43.0
+ locket==1.0.0
+ logical-unification==0.4.6
+ lxml==5.4.0
+ Mako==1.1.3
+ marisa-trie==1.2.1
+ Markdown==3.8
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ matplotlib==3.10.0
+ matplotlib-inline==0.1.7
+ matplotlib-venn==1.1.2
+ mdit-py-plugins==0.4.2
+ mdurl==0.1.2
+ miniKanren==1.0.3
+ missingno==0.5.2
+ mistune==3.1.3
+ mizani==0.13.5
+ mkl==2025.0.1
+ ml-dtypes==0.4.1
+ mlxtend==0.23.4
+ more-itertools==10.7.0
+ moviepy==1.0.3
+ mpmath==1.3.0
+ msgpack==1.1.0
+ multidict==6.4.4
+ multipledispatch==1.0.0
+ multiprocess==0.70.15
+ multitasking==0.0.11
+ murmurhash==1.0.12
+ music21==9.3.0
+ namex==0.0.9
+ narwhals==1.40.0
+ natsort==8.4.0
+ nbclassic==1.3.1
+ nbclient==0.10.2
+ nbconvert==7.16.6
+ nbformat==5.10.4
+ ndindex==1.10.0
+ nest-asyncio==1.6.0
+ networkx==3.4.2
+ nibabel==5.3.2
+ nltk==3.9.1
+ notebook==6.5.7
+ notebook_shim==0.2.4
+ numba==0.60.0
+ numba-cuda==0.2.0
+ numexpr==2.10.2
+ numpy==2.0.2
+ nvidia-cublas-cu12==12.5.3.2
+ nvidia-cuda-cupti-cu12==12.5.82
+ nvidia-cuda-nvcc-cu12==12.5.82
+ nvidia-cuda-nvrtc-cu12==12.5.82
+ nvidia-cuda-runtime-cu12==12.5.82
+ nvidia-cudnn-cu12==9.3.0.75
+ nvidia-cufft-cu12==11.2.3.61
+ nvidia-curand-cu12==10.3.6.82
+ nvidia-cusolver-cu12==11.6.3.83
+ nvidia-cusparse-cu12==12.5.1.3
+ nvidia-cusparselt-cu12==0.6.2
+ nvidia-ml-py==12.575.51
+ nvidia-nccl-cu12==2.21.5
+ nvidia-nvcomp-cu12==4.2.0.11
+ nvidia-nvjitlink-cu12==12.5.82
+ nvidia-nvtx-cu12==12.4.127
+ nvtx==0.2.11
+ nx-cugraph-cu12 @ https://pypi.nvidia.com/nx-cugraph-cu12/nx_cugraph_cu12-25.2.0-py3-none-any.whl
+ oauth2client==4.1.3
+ oauthlib==3.2.2
+ omegaconf==2.3.0
+ openai==1.81.0
+ opencv-contrib-python==4.11.0.86
+ opencv-python==4.11.0.86
+ opencv-python-headless==4.11.0.86
+ openpyxl==3.1.5
+ opt_einsum==3.4.0
+ optax==0.2.4
+ optree==0.15.0
+ orbax-checkpoint==0.11.13
+ orjson==3.10.18
+ osqp==1.0.4
+ packaging==24.2
+ pandas==2.2.2
+ pandas-datareader==0.10.0
+ pandas-gbq==0.29.0
+ pandas-stubs==2.2.2.240909
+ pandocfilters==1.5.1
+ panel==1.7.0
+ param==2.2.0
+ parso==0.8.4
+ parsy==2.1
+ partd==1.4.2
+ pathlib==1.0.1
+ patsy==1.0.1
+ peewee==3.18.1
+ peft==0.15.2
+ pexpect==4.9.0
+ pickleshare==0.7.5
+ pillow==11.2.1
+ platformdirs==4.3.8
+ plotly==5.24.1
+ plotnine==0.14.5
+ pluggy==1.6.0
+ ply==3.11
+ polars==1.21.0
+ pooch==1.8.2
+ portpicker==1.5.2
+ preshed==3.0.9
+ prettytable==3.16.0
+ proglog==0.1.12
+ progressbar2==4.5.0
+ prometheus_client==0.22.0
+ promise==2.3
+ prompt_toolkit==3.0.51
+ propcache==0.3.1
+ prophet==1.1.6
+ proto-plus==1.26.1
+ protobuf==5.29.4
+ psutil==5.9.5
+ psycopg2==2.9.10
+ ptyprocess==0.7.0
+ py-cpuinfo==9.0.0
+ py4j==0.10.9.7
+ pyarrow==18.1.0
+ pyasn1==0.6.1
+ pyasn1_modules==0.4.2
+ pycairo==1.28.0
+ pycocotools==2.0.8
+ pycparser==2.22
+ pycryptodomex==3.23.0
+ pydantic==2.11.4
+ pydantic_core==2.33.2
+ pydata-google-auth==1.9.1
+ pydot==3.0.4
+ pydotplus==2.0.2
+ PyDrive==1.3.1
+ PyDrive2==1.21.3
+ pyerfa==2.0.1.5
+ pygame==2.6.1
+ pygit2==1.18.0
+ Pygments==2.19.1
+ PyGObject==3.42.0
+ PyJWT==2.10.1
+ pylibcudf-cu12 @ https://pypi.nvidia.com/pylibcudf-cu12/pylibcudf_cu12-25.2.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
+ pylibcugraph-cu12==25.2.0
+ pylibraft-cu12==25.2.0
+ pymc==5.22.0
+ pymystem3==0.2.0
+ pynndescent==0.5.13
+ pynvjitlink-cu12==0.6.0
+ pynvml==12.0.0
+ pyogrio==0.11.0
+ pyomo==6.9.2
+ PyOpenGL==3.1.9
+ pyOpenSSL==24.2.1
+ pyparsing==3.2.3
+ pyperclip==1.9.0
+ pyproj==3.7.1
+ pyproject_hooks==1.2.0
+ pyshp==2.3.1
+ PySocks==1.7.1
+ pyspark==3.5.1
+ pytensor==2.30.3
+ pytest==8.3.5
+ python-apt==0.0.0
+ python-box==7.3.2
+ python-dateutil==2.9.0.post0
+ python-louvain==0.16
+ python-slugify==8.0.4
+ python-snappy==0.7.3
+ python-utils==3.9.1
+ pytz==2025.2
+ pyviz_comms==3.0.4
+ PyWavelets==1.8.0
+ PyYAML==6.0.2
+ pyzmq==24.0.1
+ raft-dask-cu12==25.2.0
+ rapids-dask-dependency==25.2.0
+ ratelim==0.1.6
+ referencing==0.36.2
+ regex==2024.11.6
+ requests==2.32.3
+ requests-oauthlib==2.0.0
+ requests-toolbelt==1.0.0
+ requirements-parser==0.9.0
+ rich==13.9.4
+ rmm-cu12==25.2.0
+ roman-numerals-py==3.1.0
+ rpds-py==0.25.1
+ rpy2==3.5.17
+ rsa==4.9.1
+ safetensors==0.5.3
+ scikit-image==0.25.2
+ scikit-learn==1.6.1
+ scipy==1.15.3
+ scooby==0.10.1
+ scs==3.2.7.post2
+ seaborn==0.13.2
+ SecretStorage==3.3.3
+ Send2Trash==1.8.3
+ sentence-transformers==4.1.0
+ sentencepiece==0.2.0
+ sentry-sdk==2.29.1
+ setproctitle==1.3.6
+ shap==0.47.2
+ shapely==2.1.1
+ shellingham==1.5.4
+ simple-parsing==0.1.7
+ simplejson==3.20.1
+ simsimd==6.2.1
+ six==1.17.0
+ sklearn-compat==0.1.3
+ sklearn-pandas==2.2.0
+ slicer==0.0.8
+ smart-open==7.1.0
+ smmap==5.0.2
+ sniffio==1.3.1
+ snowballstemmer==3.0.1
+ sortedcontainers==2.4.0
+ soundfile==0.13.1
+ soupsieve==2.7
+ soxr==0.5.0.post1
+ spacy==3.8.6
+ spacy-legacy==3.0.12
+ spacy-loggers==1.0.5
+ spanner-graph-notebook==1.1.6
+ Sphinx==8.2.3
+ sphinxcontrib-applehelp==2.0.0
+ sphinxcontrib-devhelp==2.0.0
+ sphinxcontrib-htmlhelp==2.1.0
+ sphinxcontrib-jsmath==1.0.1
+ sphinxcontrib-qthelp==2.0.0
+ sphinxcontrib-serializinghtml==2.0.0
+ SQLAlchemy==2.0.41
+ sqlglot==25.20.2
+ sqlparse==0.5.3
+ srsly==2.5.1
+ stanio==0.5.1
+ statsmodels==0.14.4
+ stringzilla==3.12.5
+ stumpy==1.13.0
+ sympy==1.13.1
+ tables==3.10.2
+ tabulate==0.9.0
+ tbb==2022.1.0
+ tblib==3.1.0
+ tcmlib==1.3.0
+ tenacity==9.1.2
+ tensorboard==2.18.0
+ tensorboard-data-server==0.7.2
+ tensorflow==2.18.0
+ tensorflow-datasets==4.9.8
+ tensorflow-hub==0.16.1
+ tensorflow-io-gcs-filesystem==0.37.1
+ tensorflow-metadata==1.17.1
+ tensorflow-probability==0.25.0
+ tensorflow-text==2.18.1
+ tensorflow_decision_forests==1.11.0
+ tensorstore==0.1.74
+ termcolor==3.1.0
+ terminado==0.18.1
+ text-unidecode==1.3
+ textblob==0.19.0
+ tf-slim==1.1.0
+ tf_keras==2.18.0
+ thinc==8.3.6
+ threadpoolctl==3.6.0
+ tifffile==2025.5.21
+ tiktoken==0.9.0
+ timm==1.0.15
+ tinycss2==1.4.0
+ tokenizers==0.21.1
+ toml==0.10.2
+ toolz==0.12.1
+ torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp311-cp311-linux_x86_64.whl
+ torchao==0.10.0
+ torchaudio @ https://download.pytorch.org/whl/cu124/torchaudio-2.6.0%2Bcu124-cp311-cp311-linux_x86_64.whl
+ torchdata==0.11.0
+ torchsummary==1.5.1
+ torchtune==0.6.1
+ torchvision @ https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp311-cp311-linux_x86_64.whl
+ tornado==6.4.2
+ tqdm==4.67.1
+ traitlets==5.7.1
+ traittypes==0.2.1
+ transformers==4.52.2
+ treelite==4.4.1
+ treescope==0.1.9
+ triton==3.2.0
+ tsfresh==0.21.0
+ tweepy==4.15.0
+ typeguard==4.4.2
+ typer==0.15.3
+ types-pytz==2025.2.0.20250516
+ types-setuptools==80.8.0.20250521
+ typing-inspection==0.4.1
+ typing_extensions==4.13.2
+ tzdata==2025.2
+ tzlocal==5.3.1
+ uc-micro-py==1.0.3
+ ucx-py-cu12==0.42.0
+ ucxx-cu12==0.42.0
+ umap-learn==0.5.7
+ umf==0.10.0
+ uritemplate==4.1.1
+ urllib3==2.4.0
+ vega-datasets==0.9.0
+ wadllib==1.3.6
+ wandb==0.19.11
+ wasabi==1.1.3
+ wcwidth==0.2.13
+ weasel==0.4.1
+ webcolors==24.11.1
+ webencodings==0.5.1
+ websocket-client==1.8.0
+ websockets==15.0.1
+ Werkzeug==3.1.3
+ widgetsnbextension==3.6.10
+ wordcloud==1.9.4
+ wrapt==1.17.2
+ wurlitzer==3.1.1
+ xarray==2025.3.1
+ xarray-einstats==0.8.0
+ xgboost==2.1.4
+ xlrd==2.0.1
+ xxhash==3.5.0
+ xyzservices==2025.4.0
+ yarl==1.20.0
+ ydf==0.12.0
+ yellowbrick==1.5
+ yfinance==0.2.61
+ zict==3.0.0
+ zipp==3.21.0
+ zstandard==0.23.0
results.png ADDED

Git LFS Details

  • SHA256: 2b7f08fc4c09db56b516186c0629f72523a5cbe328beaedda8b36349af4b04bc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
t4.png ADDED

Git LFS Details

  • SHA256: 43a9453f567d9bff7fe4481205575bbf302499379047ee6073247315452ba8fb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.16 MB
utilities.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+ import torch.nn.functional as F
+ from torchvision.transforms.functional import normalize
+ import numpy as np
+
+ def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
+     if len(im.shape) < 3:
+         im = im[:, :, np.newaxis]
+     im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
+     # note: the uint8 cast truncates the interpolated values to whole intensities
+     im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear').type(torch.uint8)
+     image = torch.divide(im_tensor, 255.0)
+     image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
+     return image
+
+
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
+     result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
+     ma = torch.max(result)
+     mi = torch.min(result)
+     result = (result - mi) / (ma - mi)
+     im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
+     im_array = np.squeeze(im_array)
+     return im_array
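Note the `.type(torch.uint8)` after the bilinear resize: it rounds the interpolated values down to whole intensities before rescaling, which appears to be the "minor modification" that `MyPipe.py`'s copy drops (there the tensor stays float32 throughout). A tiny sketch of the difference on a toy 1×2 "image":

```python
import torch
import torch.nn.functional as F

t = torch.tensor([[[[10.0, 11.0]]]])  # shape (1, 1, 1, 2)
resized = F.interpolate(t, size=[1, 3], mode='bilinear')
print(resized)                     # tensor([[[[10.0000, 10.5000, 11.0000]]]])
print(resized.type(torch.uint8))   # tensor([[[[10, 10, 11]]]]) -- the 10.5 is truncated
```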