Spaces:
Sleeping
Sleeping
Kieran Fraser
commited on
Commit
·
d2635ec
1
Parent(s):
787bffa
First commit.
Browse filesSigned-off-by: Kieran Fraser <[email protected]>
- .gitignore +260 -0
- app.py +430 -0
- carbon_colors.py +173 -0
- carbon_theme.py +102 -0
- requirements.txt +9 -0
.gitignore
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
|
131 |
+
!.vscode/*.code-snippets
|
132 |
+
!.vscode/extensions.json
|
133 |
+
!.vscode/launch.json
|
134 |
+
!.vscode/settings.json
|
135 |
+
!.vscode/tasks.json
|
136 |
+
*$py.class
|
137 |
+
*.code-workspace
|
138 |
+
*.cover
|
139 |
+
*.egg
|
140 |
+
*.egg-info/
|
141 |
+
*.iws
|
142 |
+
*.log
|
143 |
+
*.manifest
|
144 |
+
*.mo
|
145 |
+
*.pot
|
146 |
+
*.py,cover
|
147 |
+
*.py[cod]
|
148 |
+
*.sage.py
|
149 |
+
*.so
|
150 |
+
*.spec
|
151 |
+
*.vsix
|
152 |
+
.Python
|
153 |
+
.cache
|
154 |
+
.coverage
|
155 |
+
.coverage.*
|
156 |
+
.dmypy.json
|
157 |
+
.eggs/
|
158 |
+
.env
|
159 |
+
.history
|
160 |
+
.history/
|
161 |
+
.hypothesis/
|
162 |
+
.idea/$CACHE_FILE$
|
163 |
+
.idea/**/aws.xml
|
164 |
+
.idea/**/azureSettings.xml
|
165 |
+
.idea/**/contentModel.xml
|
166 |
+
.idea/**/dataSources.ids
|
167 |
+
.idea/**/dataSources.local.xml
|
168 |
+
.idea/**/dataSources/
|
169 |
+
.idea/**/dbnavigator.xml
|
170 |
+
.idea/**/dictionaries
|
171 |
+
.idea/**/dynamic.xml
|
172 |
+
.idea/**/gradle.xml
|
173 |
+
.idea/**/libraries
|
174 |
+
.idea/**/markdown-navigator-enh.xml
|
175 |
+
.idea/**/markdown-navigator.xml
|
176 |
+
.idea/**/markdown-navigator/
|
177 |
+
.idea/**/mongoSettings.xml
|
178 |
+
.idea/**/shelf
|
179 |
+
.idea/**/sonarIssues.xml
|
180 |
+
.idea/**/sonarlint/
|
181 |
+
.idea/**/sqlDataSources.xml
|
182 |
+
.idea/**/tasks.xml
|
183 |
+
.idea/**/uiDesigner.xml
|
184 |
+
.idea/**/usage.statistics.xml
|
185 |
+
.idea/**/workspace.xml
|
186 |
+
.idea/caches/build_file_checksums.ser
|
187 |
+
.idea/codestream.xml
|
188 |
+
.idea/httpRequests
|
189 |
+
.idea/replstate.xml
|
190 |
+
.idea/sonarlint/
|
191 |
+
.idea_modules/
|
192 |
+
.installed.cfg
|
193 |
+
.ionide
|
194 |
+
.ipynb_checkpoints
|
195 |
+
.mypy_cache/
|
196 |
+
.nox/
|
197 |
+
.pdm.toml
|
198 |
+
.pybuilder/
|
199 |
+
.pyre/
|
200 |
+
.pytest_cache/
|
201 |
+
.pytype/
|
202 |
+
.ropeproject
|
203 |
+
.scrapy
|
204 |
+
.spyderproject
|
205 |
+
.spyproject
|
206 |
+
.tox/
|
207 |
+
.venv
|
208 |
+
.vscode/*
|
209 |
+
.vscode/*.code-snippets
|
210 |
+
.webassets-cache
|
211 |
+
/site
|
212 |
+
ENV/
|
213 |
+
MANIFEST
|
214 |
+
__pycache__/
|
215 |
+
__pypackages__/
|
216 |
+
atlassian-ide-plugin.xml
|
217 |
+
build/
|
218 |
+
celerybeat-schedule
|
219 |
+
celerybeat.pid
|
220 |
+
cmake-build-*/
|
221 |
+
com_crashlytics_export_strings.xml
|
222 |
+
cover/
|
223 |
+
coverage.xml
|
224 |
+
crashlytics-build.properties
|
225 |
+
crashlytics.properties
|
226 |
+
cython_debug/
|
227 |
+
db.sqlite3
|
228 |
+
db.sqlite3-journal
|
229 |
+
develop-eggs/
|
230 |
+
dist/
|
231 |
+
dmypy.json
|
232 |
+
docs/_build/
|
233 |
+
downloads/
|
234 |
+
eggs/
|
235 |
+
env.bak/
|
236 |
+
env/
|
237 |
+
fabric.properties
|
238 |
+
htmlcov/
|
239 |
+
instance/
|
240 |
+
ipython_config.py
|
241 |
+
lib/
|
242 |
+
lib64/
|
243 |
+
local_settings.py
|
244 |
+
nosetests.xml
|
245 |
+
out/
|
246 |
+
parts/
|
247 |
+
pip-delete-this-directory.txt
|
248 |
+
pip-log.txt
|
249 |
+
profile_default/
|
250 |
+
sdist/
|
251 |
+
share/python-wheels/
|
252 |
+
target/
|
253 |
+
var/
|
254 |
+
venv.bak/
|
255 |
+
venv/
|
256 |
+
wheels/
|
257 |
+
Pipfile
|
258 |
+
.vscode
|
259 |
+
Pipfile.lock
|
260 |
+
Data - DELETE AT THE END OF THE PROJECT
|
app.py
ADDED
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
ART-JATIC Gradio Example App
|
3 |
+
|
4 |
+
To run:
|
5 |
+
- clone the repository
|
6 |
+
- execute: gradio examples/gradio_app.py or python examples/gradio_app.py
|
7 |
+
- navigate to local URL e.g. http://127.0.0.1:7860
|
8 |
+
'''
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
import pandas as pd
|
13 |
+
from carbon_theme import Carbon
|
14 |
+
|
15 |
+
import gradio as gr
|
16 |
+
import os
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
+
|
19 |
+
css = """
|
20 |
+
.input-image { margin: auto !important }
|
21 |
+
.plot-padding { padding: 20px; }
|
22 |
+
"""
|
23 |
+
|
24 |
+
def extract_predictions(predictions_, conf_thresh):
|
25 |
+
coco_labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
|
26 |
+
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
27 |
+
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
|
28 |
+
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
|
29 |
+
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
30 |
+
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
|
31 |
+
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
|
32 |
+
'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
33 |
+
'teddy bear', 'hair drier', 'toothbrush']
|
34 |
+
# Get the predicted class
|
35 |
+
predictions_class = [coco_labels[i] for i in list(predictions_["labels"])]
|
36 |
+
# print("\npredicted classes:", predictions_class)
|
37 |
+
if len(predictions_class) < 1:
|
38 |
+
return [], [], []
|
39 |
+
# Get the predicted bounding boxes
|
40 |
+
predictions_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(predictions_["boxes"])]
|
41 |
+
|
42 |
+
# Get the predicted prediction score
|
43 |
+
predictions_score = list(predictions_["scores"])
|
44 |
+
# print("predicted score:", predictions_score)
|
45 |
+
|
46 |
+
# Get a list of index with score greater than threshold
|
47 |
+
threshold = conf_thresh
|
48 |
+
predictions_t = [predictions_score.index(x) for x in predictions_score if x > threshold]
|
49 |
+
if len(predictions_t) > 0:
|
50 |
+
predictions_t = predictions_t # [-1] #indices where score over threshold
|
51 |
+
else:
|
52 |
+
# no predictions esxceeding threshold
|
53 |
+
return [], [], []
|
54 |
+
# predictions in score order
|
55 |
+
predictions_boxes = [predictions_boxes[i] for i in predictions_t]
|
56 |
+
predictions_class = [predictions_class[i] for i in predictions_t]
|
57 |
+
predictions_scores = [predictions_score[i] for i in predictions_t]
|
58 |
+
return predictions_class, predictions_boxes, predictions_scores
|
59 |
+
|
60 |
+
def plot_image_with_boxes(img, boxes, pred_cls, title):
|
61 |
+
import cv2
|
62 |
+
text_size = 1
|
63 |
+
text_th = 2
|
64 |
+
rect_th = 1
|
65 |
+
|
66 |
+
sections = []
|
67 |
+
for i in range(len(boxes)):
|
68 |
+
cv2.rectangle(img, (int(boxes[i][0][0]), int(boxes[i][0][1])), (int(boxes[i][1][0]), int(boxes[i][1][1])),
|
69 |
+
color=(0, 255, 0), thickness=rect_th)
|
70 |
+
# Write the prediction class
|
71 |
+
cv2.putText(img, pred_cls[i], (int(boxes[i][0][0]), int(boxes[i][0][1])), cv2.FONT_HERSHEY_SIMPLEX, text_size,
|
72 |
+
(0, 255, 0), thickness=text_th)
|
73 |
+
sections.append( ((int(boxes[i][0][0]),
|
74 |
+
int(boxes[i][0][1]),
|
75 |
+
int(boxes[i][1][0]),
|
76 |
+
int(boxes[i][1][1])), (pred_cls[i])) )
|
77 |
+
|
78 |
+
|
79 |
+
return img.astype(np.uint8)
|
80 |
+
|
81 |
+
def filter_boxes(predictions, conf_thresh):
|
82 |
+
dictionary = {}
|
83 |
+
|
84 |
+
boxes_list = []
|
85 |
+
scores_list = []
|
86 |
+
labels_list = []
|
87 |
+
|
88 |
+
for i in range(len(predictions[0]["boxes"])):
|
89 |
+
score = predictions[0]["scores"][i]
|
90 |
+
if score >= conf_thresh:
|
91 |
+
boxes_list.append(predictions[0]["boxes"][i])
|
92 |
+
scores_list.append(predictions[0]["scores"][[i]])
|
93 |
+
labels_list.append(predictions[0]["labels"][[i]])
|
94 |
+
|
95 |
+
dictionary["boxes"] = np.vstack(boxes_list)
|
96 |
+
dictionary["scores"] = np.hstack(scores_list)
|
97 |
+
dictionary["labels"] = np.hstack(labels_list)
|
98 |
+
|
99 |
+
y = [dictionary]
|
100 |
+
|
101 |
+
return y
|
102 |
+
|
103 |
+
def basic_cifar10_model(overfit=False):
|
104 |
+
'''
|
105 |
+
Load an example CIFAR10 model
|
106 |
+
'''
|
107 |
+
from art.estimators.classification.pytorch import PyTorchClassifier
|
108 |
+
|
109 |
+
labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
|
110 |
+
path = './'
|
111 |
+
class Model(torch.nn.Module):
|
112 |
+
"""
|
113 |
+
Create model for pytorch.
|
114 |
+
Here the model does not use maxpooling. Needed for certification tests.
|
115 |
+
"""
|
116 |
+
|
117 |
+
def __init__(self):
|
118 |
+
super(Model, self).__init__()
|
119 |
+
|
120 |
+
self.conv = torch.nn.Conv2d(
|
121 |
+
in_channels=3, out_channels=16, kernel_size=(4, 4), dilation=(1, 1), padding=(0, 0), stride=(3, 3)
|
122 |
+
)
|
123 |
+
|
124 |
+
self.fullyconnected = torch.nn.Linear(in_features=1600, out_features=10)
|
125 |
+
|
126 |
+
self.relu = torch.nn.ReLU()
|
127 |
+
|
128 |
+
w_conv2d = np.load(
|
129 |
+
os.path.join(
|
130 |
+
os.path.dirname(path),
|
131 |
+
"utils/resources/models",
|
132 |
+
"W_CONV2D_NO_MPOOL_CIFAR10.npy",
|
133 |
+
)
|
134 |
+
)
|
135 |
+
b_conv2d = np.load(
|
136 |
+
os.path.join(
|
137 |
+
os.path.dirname(path),
|
138 |
+
"utils/resources/models",
|
139 |
+
"B_CONV2D_NO_MPOOL_CIFAR10.npy",
|
140 |
+
)
|
141 |
+
)
|
142 |
+
w_dense = np.load(
|
143 |
+
os.path.join(
|
144 |
+
os.path.dirname(path),
|
145 |
+
"utils/resources/models",
|
146 |
+
"W_DENSE_NO_MPOOL_CIFAR10.npy",
|
147 |
+
)
|
148 |
+
)
|
149 |
+
b_dense = np.load(
|
150 |
+
os.path.join(
|
151 |
+
os.path.dirname(path),
|
152 |
+
"utils/resources/models",
|
153 |
+
"B_DENSE_NO_MPOOL_CIFAR10.npy",
|
154 |
+
)
|
155 |
+
)
|
156 |
+
|
157 |
+
self.conv.weight = torch.nn.Parameter(torch.Tensor(w_conv2d))
|
158 |
+
self.conv.bias = torch.nn.Parameter(torch.Tensor(b_conv2d))
|
159 |
+
self.fullyconnected.weight = torch.nn.Parameter(torch.Tensor(w_dense))
|
160 |
+
self.fullyconnected.bias = torch.nn.Parameter(torch.Tensor(b_dense))
|
161 |
+
|
162 |
+
# pylint: disable=W0221
|
163 |
+
# disable pylint because of API requirements for function
|
164 |
+
def forward(self, x):
|
165 |
+
"""
|
166 |
+
Forward function to evaluate the model
|
167 |
+
:param x: Input to the model
|
168 |
+
:return: Prediction of the model
|
169 |
+
"""
|
170 |
+
x = self.conv(x)
|
171 |
+
x = self.relu(x)
|
172 |
+
x = x.reshape(-1, 1600)
|
173 |
+
x = self.fullyconnected(x)
|
174 |
+
return x
|
175 |
+
|
176 |
+
# Define the network
|
177 |
+
model = Model()
|
178 |
+
# Define a loss function and optimizer
|
179 |
+
if overfit:
|
180 |
+
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
|
181 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0)
|
182 |
+
else:
|
183 |
+
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
|
184 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
|
185 |
+
|
186 |
+
# Get classifier
|
187 |
+
jptc = PyTorchClassifier(
|
188 |
+
model=model, loss=loss_fn, optimizer=optimizer, input_shape=(3, 32, 32), nb_classes=10, clip_values=(0, 1), labels=labels
|
189 |
+
)
|
190 |
+
return jptc
|
191 |
+
|
192 |
+
def det_evasion_evaluate(*args):
|
193 |
+
'''
|
194 |
+
Run a detection task evaluation
|
195 |
+
'''
|
196 |
+
|
197 |
+
def clf_evasion_evaluate(*args):
|
198 |
+
'''
|
199 |
+
Run a classification task evaluation
|
200 |
+
'''
|
201 |
+
|
202 |
+
def show_model_params(model_type):
|
203 |
+
'''
|
204 |
+
Show model parameters based on selected model type
|
205 |
+
'''
|
206 |
+
if model_type!="Example CIFAR10" and model_type!="Example XView" and model_type!="CIFAR10 Overfit":
|
207 |
+
return gr.Column(visible=True)
|
208 |
+
return gr.Column(visible=False)
|
209 |
+
|
210 |
+
def show_dataset_params(dataset_type):
|
211 |
+
'''
|
212 |
+
Show dataset parameters based on dataset type
|
213 |
+
'''
|
214 |
+
if dataset_type=="Example CIFAR10":
|
215 |
+
return [gr.Column(visible=False), gr.Row(visible=False), gr.Row(visible=False)]
|
216 |
+
elif dataset_type=="local":
|
217 |
+
return [gr.Column(visible=True), gr.Row(visible=True), gr.Row(visible=False)]
|
218 |
+
return [gr.Column(visible=True), gr.Row(visible=False), gr.Row(visible=True)]
|
219 |
+
|
220 |
+
def pgd_show_label_output(dataset_type):
|
221 |
+
'''
|
222 |
+
Show PGD output component based on dataset type
|
223 |
+
'''
|
224 |
+
if dataset_type=="local":
|
225 |
+
return [gr.Label(visible=True), gr.Label(visible=True), gr.Number(visible=False), gr.Number(visible=False), gr.Number(visible=True)]
|
226 |
+
return [gr.Label(visible=False), gr.Label(visible=False), gr.Number(visible=True), gr.Number(visible=True), gr.Number(visible=True)]
|
227 |
+
|
228 |
+
def pgd_update_epsilon(clip_values):
|
229 |
+
'''
|
230 |
+
Update max value of PGD epsilon slider based on model clip values
|
231 |
+
'''
|
232 |
+
if clip_values == 255:
|
233 |
+
return gr.Slider(minimum=0.0001, maximum=255, label="Epslion", value=55)
|
234 |
+
return gr.Slider(minimum=0.0001, maximum=1, label="Epslion", value=0.05)
|
235 |
+
|
236 |
+
def patch_show_label_output(dataset_type):
|
237 |
+
'''
|
238 |
+
Show adversarial patch output components based on dataset type
|
239 |
+
'''
|
240 |
+
if dataset_type=="local":
|
241 |
+
return [gr.Label(visible=True), gr.Label(visible=True), gr.Number(visible=False), gr.Number(visible=False), gr.Number(visible=True)]
|
242 |
+
return [gr.Label(visible=False), gr.Label(visible=False), gr.Number(visible=True), gr.Number(visible=True), gr.Number(visible=True)]
|
243 |
+
|
244 |
+
# e.g. To use a local alternative theme: carbon_theme = Carbon()
|
245 |
+
carbon_theme = Carbon()
|
246 |
+
with gr.Blocks(css=css, theme=carbon_theme) as demo:
|
247 |
+
import art
|
248 |
+
text = art.__version__
|
249 |
+
gr.Markdown(f"<h1>ART (v{text}) Gradio Example</h1>")
|
250 |
+
|
251 |
+
with gr.Tab("Info"):
|
252 |
+
gr.Markdown('This is step 1. Using the tabs, select a task for evaluation.')
|
253 |
+
|
254 |
+
with gr.Tab("Classification", elem_classes="task-tab"):
|
255 |
+
gr.Markdown("Classifying images with a set of categories.")
|
256 |
+
|
257 |
+
# Model and Dataset Selection
|
258 |
+
with gr.Row():
|
259 |
+
# Model and Dataset type e.g. Torchvision, HuggingFace, local etc.
|
260 |
+
with gr.Column():
|
261 |
+
model_type = gr.Radio(label="Model type", choices=["Example CIFAR10", "Huggingface", "torchvision"],
|
262 |
+
value="Example CIFAR10")
|
263 |
+
dataset_type = gr.Radio(label="Dataset", choices=["Example CIFAR10", "Huggingface", "local"],
|
264 |
+
value="Example CIFAR10")
|
265 |
+
# Model parameters e.g. RESNET, VIT, input dimensions, clipping values etc.
|
266 |
+
with gr.Column(visible=False) as model_params:
|
267 |
+
model_path = gr.Textbox(placeholder="URL", label="Model path")
|
268 |
+
with gr.Row():
|
269 |
+
with gr.Column():
|
270 |
+
model_channels = gr.Textbox(placeholder="Integer, 3 for RGB images", label="Input Channels", value=3)
|
271 |
+
with gr.Column():
|
272 |
+
model_width = gr.Textbox(placeholder="Integer", label="Input Width", value=640)
|
273 |
+
with gr.Row():
|
274 |
+
with gr.Column():
|
275 |
+
model_height = gr.Textbox(placeholder="Integer", label="Input Height", value=480)
|
276 |
+
with gr.Column():
|
277 |
+
model_clip = gr.Radio(choices=[1, 255], label="Pixel clip", value=1)
|
278 |
+
# Dataset parameters e.g. Torchvision, HuggingFace, local etc.
|
279 |
+
with gr.Column(visible=False) as dataset_params:
|
280 |
+
with gr.Row() as local_image:
|
281 |
+
image = gr.Image(sources=['upload'], type="pil", height=150, width=150, elem_classes="input-image")
|
282 |
+
with gr.Row() as hosted_image:
|
283 |
+
dataset_path = gr.Textbox(placeholder="URL", label="Dataset path")
|
284 |
+
dataset_split = gr.Textbox(placeholder="test", label="Dataset split")
|
285 |
+
|
286 |
+
model_type.change(show_model_params, model_type, model_params)
|
287 |
+
dataset_type.change(show_dataset_params, dataset_type, [dataset_params, local_image, hosted_image])
|
288 |
+
|
289 |
+
# Attack Selection
|
290 |
+
with gr.Row():
|
291 |
+
|
292 |
+
with gr.Tab("Info"):
|
293 |
+
gr.Markdown("This is step 2. Select the type of attack for evaluation.")
|
294 |
+
|
295 |
+
with gr.Tab("White Box"):
|
296 |
+
gr.Markdown("White box attacks assume the attacker has __full access__ to the model.")
|
297 |
+
|
298 |
+
with gr.Tab("Info"):
|
299 |
+
gr.Markdown("This is step 3. Select the type of white-box attack to evaluate.")
|
300 |
+
|
301 |
+
with gr.Tab("Evasion"):
|
302 |
+
gr.Markdown("Evasion attacks are deployed to cause a model to incorrectly classify or detect items/objects in an image.")
|
303 |
+
|
304 |
+
with gr.Tab("Info"):
|
305 |
+
gr.Markdown("This is step 4. Select the type of Evasion attack to evaluate.")
|
306 |
+
|
307 |
+
with gr.Tab("Projected Gradient Descent"):
|
308 |
+
gr.Markdown("This attack uses PGD to identify adversarial examples.")
|
309 |
+
|
310 |
+
|
311 |
+
with gr.Row():
|
312 |
+
|
313 |
+
with gr.Column():
|
314 |
+
attack = gr.Textbox(visible=True, value="PGD", label="Attack", interactive=False)
|
315 |
+
max_iter = gr.Slider(minimum=1, maximum=5000, label="Max iterations", value=1000)
|
316 |
+
eps = gr.Slider(minimum=0.0001, maximum=1, label="Epslion", value=0.05)
|
317 |
+
eps_steps = gr.Slider(minimum=0.001, maximum=1000, label="Epsilon steps", value=0.1)
|
318 |
+
targeted = gr.Textbox(placeholder="Target label (integer)", label="Target")
|
319 |
+
eval_btn_pgd = gr.Button("Evaluate")
|
320 |
+
model_clip.change(pgd_update_epsilon, model_clip, eps)
|
321 |
+
|
322 |
+
# Evaluation Output. Visualisations of success/failures of running evaluation attacks.
|
323 |
+
with gr.Column():
|
324 |
+
with gr.Row():
|
325 |
+
with gr.Column():
|
326 |
+
original_gallery = gr.Gallery(label="Original", preview=True, show_download_button=True)
|
327 |
+
benign_output = gr.Label(num_top_classes=3, visible=False)
|
328 |
+
clean_accuracy = gr.Number(label="Clean Accuracy", precision=2)
|
329 |
+
quality_plot = gr.LinePlot(label="Gradient Quality", x='iteration', y='value', color='metric',
|
330 |
+
x_title='Iteration', y_title='Avg in Gradients (%)',
|
331 |
+
caption="""Illustrates the average percent of zero, infinity
|
332 |
+
or NaN gradients identified in images
|
333 |
+
across all batches.""", elem_classes="plot-padding", visible=False)
|
334 |
+
|
335 |
+
with gr.Column():
|
336 |
+
adversarial_gallery = gr.Gallery(label="Adversarial", preview=True, show_download_button=True)
|
337 |
+
adversarial_output = gr.Label(num_top_classes=3, visible=False)
|
338 |
+
robust_accuracy = gr.Number(label="Robust Accuracy", precision=2)
|
339 |
+
perturbation_added = gr.Number(label="Perturbation Added", precision=2)
|
340 |
+
|
341 |
+
dataset_type.change(pgd_show_label_output, dataset_type, [benign_output, adversarial_output,
|
342 |
+
clean_accuracy, robust_accuracy, perturbation_added])
|
343 |
+
eval_btn_pgd.click(clf_evasion_evaluate, inputs=[attack, model_type, model_path, model_channels, model_height, model_width,
|
344 |
+
model_clip, max_iter, eps, eps_steps, targeted,
|
345 |
+
dataset_type, dataset_path, dataset_split, image],
|
346 |
+
outputs=[original_gallery, benign_output, adversarial_gallery, adversarial_output, clean_accuracy,
|
347 |
+
robust_accuracy, perturbation_added, quality_plot], api_name='patch')
|
348 |
+
|
349 |
+
with gr.Row():
|
350 |
+
clear_btn = gr.ClearButton([image, targeted, original_gallery, benign_output, clean_accuracy,
|
351 |
+
adversarial_gallery, adversarial_output, robust_accuracy, perturbation_added])
|
352 |
+
|
353 |
+
|
354 |
+
|
355 |
+
with gr.Tab("Adversarial Patch"):
|
356 |
+
gr.Markdown("This attack crafts an adversarial patch that facilitates evasion.")
|
357 |
+
|
358 |
+
with gr.Row():
|
359 |
+
|
360 |
+
with gr.Column():
|
361 |
+
attack = gr.Textbox(visible=True, value="Adversarial Patch", label="Attack", interactive=False)
|
362 |
+
max_iter = gr.Slider(minimum=1, maximum=5000, label="Max iterations", value=100)
|
363 |
+
x_location = gr.Slider(minimum=1, maximum=640, label="Location (x)", value=18)
|
364 |
+
y_location = gr.Slider(minimum=1, maximum=480, label="Location (y)", value=18)
|
365 |
+
patch_height = gr.Slider(minimum=1, maximum=640, label="Patch height", value=18)
|
366 |
+
patch_width = gr.Slider(minimum=1, maximum=480, label="Patch width", value=18)
|
367 |
+
targeted = gr.Textbox(placeholder="Target label (integer)", label="Target")
|
368 |
+
eval_btn_patch = gr.Button("Evaluate")
|
369 |
+
model_clip.change()
|
370 |
+
|
371 |
+
# Evaluation Output. Visualisations of success/failures of running evaluation attacks.
|
372 |
+
with gr.Column():
|
373 |
+
with gr.Row():
|
374 |
+
with gr.Column():
|
375 |
+
original_gallery = gr.Gallery(label="Original", preview=True, show_download_button=True)
|
376 |
+
benign_output = gr.Label(num_top_classes=3, visible=False)
|
377 |
+
clean_accuracy = gr.Number(label="Clean Accuracy", precision=2)
|
378 |
+
|
379 |
+
with gr.Column():
|
380 |
+
adversarial_gallery = gr.Gallery(label="Adversarial", preview=True, show_download_button=True)
|
381 |
+
adversarial_output = gr.Label(num_top_classes=3, visible=False)
|
382 |
+
robust_accuracy = gr.Number(label="Robust Accuracy", precision=2)
|
383 |
+
patch_image = gr.Image(label="Adversarial Patch")
|
384 |
+
|
385 |
+
dataset_type.change(patch_show_label_output, dataset_type, [benign_output, adversarial_output,
|
386 |
+
clean_accuracy, robust_accuracy, patch_image])
|
387 |
+
eval_btn_patch.click(clf_evasion_evaluate, inputs=[attack, model_type, model_path, model_channels, model_height, model_width,
|
388 |
+
model_clip, max_iter, x_location, y_location, patch_height, patch_width, targeted,
|
389 |
+
dataset_type, dataset_path, dataset_split, image],
|
390 |
+
outputs=[original_gallery, benign_output, adversarial_gallery, adversarial_output, clean_accuracy,
|
391 |
+
robust_accuracy, patch_image])
|
392 |
+
|
393 |
+
with gr.Row():
|
394 |
+
clear_btn = gr.ClearButton([image, targeted, original_gallery, benign_output, clean_accuracy,
|
395 |
+
adversarial_gallery, adversarial_output, robust_accuracy, patch_image])
|
396 |
+
|
397 |
+
with gr.Tab("Poisoning"):
|
398 |
+
gr.Markdown("Coming soon.")
|
399 |
+
|
400 |
+
with gr.Tab("Black Box"):
|
401 |
+
gr.Markdown("Black box attacks assume the attacker __does not__ have full access to the model but can query it for predictions.")
|
402 |
+
|
403 |
+
with gr.Tab("Info"):
|
404 |
+
gr.Markdown("This is step 3. Select the type of black-box attack to evaluate.")
|
405 |
+
|
406 |
+
with gr.Tab("Evasion"):
|
407 |
+
|
408 |
+
gr.Markdown("Evasion attacks are deployed to cause a model to incorrectly classify or detect items/objects in an image.")
|
409 |
+
|
410 |
+
with gr.Tab("Info"):
|
411 |
+
gr.Markdown("This is step 4. Select the type of Evasion attack to evaluate.")
|
412 |
+
|
413 |
+
with gr.Tab("HopSkipJump"):
|
414 |
+
gr.Markdown("Coming soon.")
|
415 |
+
|
416 |
+
with gr.Tab("Square Attack"):
|
417 |
+
gr.Markdown("Coming soon.")
|
418 |
+
|
419 |
+
with gr.Tab("AutoAttack"):
|
420 |
+
gr.Markdown("Coming soon.")
|
421 |
+
|
422 |
+
|
423 |
+
if __name__ == "__main__":
|
424 |
+
|
425 |
+
# during development, set debug=True
|
426 |
+
demo.launch(show_api=False, debug=True, share=False,
|
427 |
+
server_name="0.0.0.0",
|
428 |
+
server_port=7777,
|
429 |
+
ssl_verify=False,
|
430 |
+
max_threads=20)
|
carbon_colors.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
|
4 |
+
class Color:
|
5 |
+
all = []
|
6 |
+
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
c50: str,
|
10 |
+
c100: str,
|
11 |
+
c200: str,
|
12 |
+
c300: str,
|
13 |
+
c400: str,
|
14 |
+
c500: str,
|
15 |
+
c600: str,
|
16 |
+
c700: str,
|
17 |
+
c800: str,
|
18 |
+
c900: str,
|
19 |
+
c950: str,
|
20 |
+
name: str | None = None,
|
21 |
+
):
|
22 |
+
self.c50 = c50
|
23 |
+
self.c100 = c100
|
24 |
+
self.c200 = c200
|
25 |
+
self.c300 = c300
|
26 |
+
self.c400 = c400
|
27 |
+
self.c500 = c500
|
28 |
+
self.c600 = c600
|
29 |
+
self.c700 = c700
|
30 |
+
self.c800 = c800
|
31 |
+
self.c900 = c900
|
32 |
+
self.c950 = c950
|
33 |
+
self.name = name
|
34 |
+
Color.all.append(self)
|
35 |
+
|
36 |
+
def expand(self) -> list[str]:
|
37 |
+
return [
|
38 |
+
self.c50,
|
39 |
+
self.c100,
|
40 |
+
self.c200,
|
41 |
+
self.c300,
|
42 |
+
self.c400,
|
43 |
+
self.c500,
|
44 |
+
self.c600,
|
45 |
+
self.c700,
|
46 |
+
self.c800,
|
47 |
+
self.c900,
|
48 |
+
self.c950,
|
49 |
+
]
|
50 |
+
|
51 |
+
|
52 |
+
black = Color(
|
53 |
+
name="black",
|
54 |
+
c50="#000000",
|
55 |
+
c100="#000000",
|
56 |
+
c200="#000000",
|
57 |
+
c300="#000000",
|
58 |
+
c400="#000000",
|
59 |
+
c500="#000000",
|
60 |
+
c600="#000000",
|
61 |
+
c700="#000000",
|
62 |
+
c800="#000000",
|
63 |
+
c900="#000000",
|
64 |
+
c950="#000000",
|
65 |
+
)
|
66 |
+
|
67 |
+
blackHover = Color(
|
68 |
+
name="blackHover",
|
69 |
+
c50="#212121",
|
70 |
+
c100="#212121",
|
71 |
+
c200="#212121",
|
72 |
+
c300="#212121",
|
73 |
+
c400="#212121",
|
74 |
+
c500="#212121",
|
75 |
+
c600="#212121",
|
76 |
+
c700="#212121",
|
77 |
+
c800="#212121",
|
78 |
+
c900="#212121",
|
79 |
+
c950="#212121",
|
80 |
+
)
|
81 |
+
|
82 |
+
white = Color(
|
83 |
+
name="white",
|
84 |
+
c50="#ffffff",
|
85 |
+
c100="#ffffff",
|
86 |
+
c200="#ffffff",
|
87 |
+
c300="#ffffff",
|
88 |
+
c400="#ffffff",
|
89 |
+
c500="#ffffff",
|
90 |
+
c600="#ffffff",
|
91 |
+
c700="#ffffff",
|
92 |
+
c800="#ffffff",
|
93 |
+
c900="#ffffff",
|
94 |
+
c950="#ffffff",
|
95 |
+
)
|
96 |
+
|
97 |
+
whiteHover = Color(
|
98 |
+
name="whiteHover",
|
99 |
+
c50="#e8e8e8",
|
100 |
+
c100="#e8e8e8",
|
101 |
+
c200="#e8e8e8",
|
102 |
+
c300="#e8e8e8",
|
103 |
+
c400="#e8e8e8",
|
104 |
+
c500="#e8e8e8",
|
105 |
+
c600="#e8e8e8",
|
106 |
+
c700="#e8e8e8",
|
107 |
+
c800="#e8e8e8",
|
108 |
+
c900="#e8e8e8",
|
109 |
+
c950="#e8e8e8",
|
110 |
+
)
|
111 |
+
|
112 |
+
red = Color(
|
113 |
+
name="red",
|
114 |
+
c50="#fff1f1",
|
115 |
+
c100="#ffd7d9",
|
116 |
+
c200="#ffb3b8",
|
117 |
+
c300="#ff8389",
|
118 |
+
c400="#fa4d56",
|
119 |
+
c500="#da1e28",
|
120 |
+
c600="#a2191f",
|
121 |
+
c700="#750e13",
|
122 |
+
c800="#520408",
|
123 |
+
c900="#2d0709",
|
124 |
+
c950="#2d0709",
|
125 |
+
)
|
126 |
+
|
127 |
+
redHover = Color(
|
128 |
+
name="redHover",
|
129 |
+
c50="#540d11",
|
130 |
+
c100="#66050a",
|
131 |
+
c200="#921118",
|
132 |
+
c300="#c21e25",
|
133 |
+
c400="#b81922",
|
134 |
+
c500="#ee0713",
|
135 |
+
c600="#ff6168",
|
136 |
+
c700="#ff99a0",
|
137 |
+
c800="#ffc2c5",
|
138 |
+
c900="#ffe0e0",
|
139 |
+
c950="#ffe0e0",
|
140 |
+
)
|
141 |
+
|
142 |
+
blue = Color(
|
143 |
+
name="blue",
|
144 |
+
c50="#edf5ff",
|
145 |
+
c100="#d0e2ff",
|
146 |
+
c200="#a6c8ff",
|
147 |
+
c300="#78a9ff",
|
148 |
+
c400="#4589ff",
|
149 |
+
c500="#0f62fe",
|
150 |
+
c600="#0043ce",
|
151 |
+
c700="#002d9c",
|
152 |
+
c800="#001d6c",
|
153 |
+
c900="#001141",
|
154 |
+
c950="#001141",
|
155 |
+
)
|
156 |
+
|
157 |
+
blueHover = Color(
|
158 |
+
name="blueHover",
|
159 |
+
|
160 |
+
c50="#001f75",
|
161 |
+
c100="#00258a",
|
162 |
+
c200="#0039c7",
|
163 |
+
c300="#0053ff",
|
164 |
+
c400="#0050e6",
|
165 |
+
c500="#1f70ff",
|
166 |
+
c600="#5c97ff",
|
167 |
+
c700="#8ab6ff",
|
168 |
+
c800="#b8d3ff",
|
169 |
+
c900="#dbebff",
|
170 |
+
c950="#dbebff",
|
171 |
+
)
|
172 |
+
|
173 |
+
|
carbon_theme.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Iterable
|
4 |
+
|
5 |
+
from gradio.themes.base import Base
|
6 |
+
from gradio.themes.utils import colors, fonts, sizes
|
7 |
+
import carbon_colors
|
8 |
+
|
9 |
+
|
10 |
+
class Carbon(Base):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
*,
|
14 |
+
primary_hue: carbon_colors.Color | str = carbon_colors.white,
|
15 |
+
secondary_hue: carbon_colors.Color | str = carbon_colors.red,
|
16 |
+
neutral_hue: carbon_colors.Color | str = carbon_colors.blue,
|
17 |
+
spacing_size: sizes.Size | str = sizes.spacing_lg,
|
18 |
+
radius_size: sizes.Size | str = sizes.radius_none,
|
19 |
+
text_size: sizes.Size | str = sizes.text_md,
|
20 |
+
font: fonts.Font
|
21 |
+
| str
|
22 |
+
| Iterable[fonts.Font | str] = (
|
23 |
+
fonts.GoogleFont("IBM Plex Mono"),
|
24 |
+
fonts.GoogleFont("IBM Plex Sans"),
|
25 |
+
fonts.GoogleFont("IBM Plex Serif"),
|
26 |
+
),
|
27 |
+
font_mono: fonts.Font
|
28 |
+
| str
|
29 |
+
| Iterable[fonts.Font | str] = (
|
30 |
+
fonts.GoogleFont("IBM Plex Mono"),
|
31 |
+
),
|
32 |
+
):
|
33 |
+
super().__init__(
|
34 |
+
primary_hue=primary_hue,
|
35 |
+
secondary_hue=secondary_hue,
|
36 |
+
neutral_hue=neutral_hue,
|
37 |
+
spacing_size=spacing_size,
|
38 |
+
radius_size=radius_size,
|
39 |
+
text_size=text_size,
|
40 |
+
font=font,
|
41 |
+
font_mono=font_mono,
|
42 |
+
)
|
43 |
+
self.name = "carbon"
|
44 |
+
super().set(
|
45 |
+
# Colors
|
46 |
+
slider_color="*neutral_900",
|
47 |
+
slider_color_dark="*neutral_500",
|
48 |
+
body_text_color="*neutral_900",
|
49 |
+
block_label_text_color="*body_text_color",
|
50 |
+
block_title_text_color="*body_text_color",
|
51 |
+
body_text_color_subdued="*neutral_700",
|
52 |
+
background_fill_primary_dark="*neutral_900",
|
53 |
+
background_fill_secondary_dark="*neutral_800",
|
54 |
+
block_background_fill_dark="*neutral_800",
|
55 |
+
input_background_fill_dark="*neutral_700",
|
56 |
+
# Button Colors
|
57 |
+
button_primary_background_fill=carbon_colors.blue.c500,
|
58 |
+
button_primary_background_fill_hover="*neutral_300",
|
59 |
+
button_primary_text_color="white",
|
60 |
+
button_primary_background_fill_dark="*neutral_600",
|
61 |
+
button_primary_background_fill_hover_dark="*neutral_600",
|
62 |
+
button_primary_text_color_dark="white",
|
63 |
+
button_secondary_background_fill="*button_primary_background_fill",
|
64 |
+
button_secondary_background_fill_hover="*button_primary_background_fill_hover",
|
65 |
+
button_secondary_text_color="*button_primary_text_color",
|
66 |
+
button_cancel_background_fill="*button_primary_background_fill",
|
67 |
+
button_cancel_background_fill_hover="*button_primary_background_fill_hover",
|
68 |
+
button_cancel_text_color="*button_primary_text_color",
|
69 |
+
checkbox_background_color=carbon_colors.black.c50,
|
70 |
+
checkbox_label_background_fill="*button_primary_background_fill",
|
71 |
+
checkbox_label_background_fill_hover="*button_primary_background_fill_hover",
|
72 |
+
checkbox_label_text_color="*button_primary_text_color",
|
73 |
+
checkbox_background_color_selected=carbon_colors.black.c50,
|
74 |
+
checkbox_border_width="1px",
|
75 |
+
checkbox_border_width_dark="1px",
|
76 |
+
checkbox_border_color=carbon_colors.white.c50,
|
77 |
+
checkbox_border_color_dark=carbon_colors.white.c50,
|
78 |
+
|
79 |
+
checkbox_border_color_focus=carbon_colors.blue.c900,
|
80 |
+
checkbox_border_color_focus_dark=carbon_colors.blue.c900,
|
81 |
+
checkbox_border_color_selected=carbon_colors.white.c50,
|
82 |
+
checkbox_border_color_selected_dark=carbon_colors.white.c50,
|
83 |
+
|
84 |
+
checkbox_background_color_hover=carbon_colors.black.c50,
|
85 |
+
checkbox_background_color_hover_dark=carbon_colors.black.c50,
|
86 |
+
checkbox_background_color_dark=carbon_colors.black.c50,
|
87 |
+
checkbox_background_color_selected_dark=carbon_colors.black.c50,
|
88 |
+
# Padding
|
89 |
+
checkbox_label_padding="16px",
|
90 |
+
button_large_padding="*spacing_lg",
|
91 |
+
button_small_padding="*spacing_sm",
|
92 |
+
# Borders
|
93 |
+
block_border_width="0px",
|
94 |
+
block_border_width_dark="1px",
|
95 |
+
shadow_drop_lg="0 1px 4px 0 rgb(0 0 0 / 0.1)",
|
96 |
+
block_shadow="*shadow_drop_lg",
|
97 |
+
block_shadow_dark="none",
|
98 |
+
# Block Labels
|
99 |
+
block_title_text_weight="600",
|
100 |
+
block_label_text_weight="600",
|
101 |
+
block_label_text_size="*text_md",
|
102 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
adversarial-robustness-toolbox
|
3 |
+
pandas
|
4 |
+
jupyter
|
5 |
+
|
6 |
+
torch==1.13.1
|
7 |
+
tensorflow==2.10.1; sys_platform != "darwin"
|
8 |
+
tensorflow-macos; sys_platform == "darwin"
|
9 |
+
tensorflow-metal; sys_platform == "darwin"
|