Margerie commited on
Commit
db48033
β€’
1 Parent(s): c641d0e
Files changed (1) hide show
  1. 1_πŸ“_form.py +366 -372
1_πŸ“_form.py CHANGED
@@ -1,372 +1,366 @@
1
- # from yaml import load
2
- from persist import persist, load_widget_state
3
- import streamlit as st
4
- from io import StringIO
5
- import tempfile
6
- from pathlib import Path
7
- import requests
8
- from huggingface_hub import hf_hub_download, upload_file
9
- import pandas as pd
10
- from huggingface_hub import create_repo
11
- import os
12
- from middleMan import parse_into_jinja_markdown as pj
13
-
14
- import requests
15
-
16
- @st.cache
17
- def get_icd():
18
- # Get ICD10 list
19
- token_endpoint = 'https://icdaccessmanagement.who.int/connect/token'
20
- client_id = '3bc9c811-7f2e-4dab-a2dc-940e47a38fef_a6108252-4503-4ff7-90ab-300fd27392aa'
21
- client_secret = 'xPj7mleWf1Bilu9f7P10UQmBPvL5F6Wgd8/rJhO1T04='
22
- scope = 'icdapi_access'
23
- grant_type = 'client_credentials'
24
- # set data to post
25
- payload = {'client_id': client_id,
26
- 'client_secret': client_secret,
27
- 'scope': scope,
28
- 'grant_type': grant_type}
29
- # make request
30
- r = requests.post(token_endpoint, data=payload, verify=False).json()
31
- token = r['access_token']
32
- # access ICD API
33
- uri = 'https://id.who.int/icd/release/10/2019/C00-C75'
34
- # HTTP header fields to set
35
- headers = {'Authorization': 'Bearer '+token,
36
- 'Accept': 'application/json',
37
- 'Accept-Language': 'en',
38
- 'API-Version': 'v2'}
39
- # make request
40
- r = requests.get(uri, headers=headers, verify=False)
41
- print("icd",r.json())
42
- icd_map =[]
43
- for child in r.json()['child']:
44
- r_child = requests.get(child, headers=headers,verify=False).json()
45
- icd_map.append(r_child["code"]+" "+r_child["title"]["@value"])
46
- return icd_map
47
-
48
- @st.cache
49
- def get_treatment_mod():
50
- url = "https://clinicaltables.nlm.nih.gov/loinc_answers?loinc_num=21964-2"
51
- r = requests.get(url).json()
52
- treatment_mod = [treatment['DisplayText'] for treatment in r]
53
- return treatment_mod
54
-
55
-
56
- @st.cache
57
- def get_cached_data():
58
- languages_df = pd.read_html("https://hf.co/languages")[0]
59
- languages_map = pd.Series(languages_df["Language"].values, index=languages_df["ISO code"]).to_dict()
60
-
61
- license_df = pd.read_html("https://huggingface.co/docs/hub/repositories-licenses")[0]
62
- license_map = pd.Series(
63
- license_df["License identifier (to use in repo card)"].values, index=license_df.Fullname
64
- ).to_dict()
65
-
66
- available_metrics = [x['id'] for x in requests.get('https://huggingface.co/api/metrics').json()]
67
-
68
- r = requests.get('https://huggingface.co/api/models-tags-by-type')
69
- tags_data = r.json()
70
- libraries = [x['id'] for x in tags_data['library']]
71
- tasks = [x['id'] for x in tags_data['pipeline_tag']]
72
-
73
- icd_map = get_icd()
74
- treatment_mod = get_treatment_mod()
75
- return languages_map, license_map, available_metrics, libraries, tasks, icd_map, treatment_mod
76
-
77
-
78
- def card_upload(card_info,repo_id,token):
79
- #commit_message=None,
80
- repo_type = "space"
81
- commit_description=None,
82
- revision=None,
83
- create_pr=None
84
- with tempfile.TemporaryDirectory() as tmpdir:
85
- tmp_path = Path(tmpdir) / "README.md"
86
- tmp_path.write_text(str(card_info))
87
- url = upload_file(
88
- path_or_fileobj=str(tmp_path),
89
- path_in_repo="README.md",
90
- repo_id=repo_id,
91
- token=token,
92
- repo_type=repo_type,
93
- identical_ok=True,
94
- revision=revision,
95
- )
96
- return url
97
-
98
- def validate(self, repo_type="model"):
99
- """Validates card against Hugging Face Hub's model card validation logic.
100
- Using this function requires access to the internet, so it is only called
101
- internally by `modelcards.ModelCard.push_to_hub`.
102
- Args:
103
- repo_type (`str`, *optional*):
104
- The type of Hugging Face repo to push to. Defaults to None, which will use
105
- use "model". Other options are "dataset" and "space".
106
- """
107
- if repo_type is None:
108
- repo_type = "model"
109
-
110
- # TODO - compare against repo types constant in huggingface_hub if we move this object there.
111
- if repo_type not in ["model", "space", "dataset"]:
112
- raise RuntimeError(
113
- "Provided repo_type '{repo_type}' should be one of ['model', 'space',"
114
- " 'dataset']."
115
- )
116
-
117
- body = {
118
- "repoType": repo_type,
119
- "content": str(self),
120
- }
121
- headers = {"Accept": "text/plain"}
122
-
123
- try:
124
- r = requests.post(
125
- "https://huggingface.co/api/validate-yaml", body, headers=headers
126
- )
127
- r.raise_for_status()
128
- except requests.exceptions.HTTPError as exc:
129
- if r.status_code == 400:
130
- raise RuntimeError(r.text)
131
- else:
132
- raise exc
133
-
134
-
135
- ## Save uploaded [markdown] file to directory to be used by jinja parser function
136
- def save_uploadedfile(uploadedfile):
137
- with open(os.path.join("temp_uploaded_filed_Dir",uploadedfile.name),"wb") as f:
138
- f.write(uploadedfile.getbuffer())
139
- st.success("Saved File:{} to temp_uploaded_filed_Dir".format(uploadedfile.name))
140
- return uploadedfile.name
141
-
142
-
143
- def main_page():
144
-
145
-
146
- if "model_name" not in st.session_state:
147
- # Initialize session state.
148
- st.session_state.update({
149
- "input_model_name": "",
150
- "languages": [],
151
- "license": "",
152
- "library_name": "",
153
- "datasets": "",
154
- "metrics": [],
155
- "task": "",
156
- "tags": "",
157
- "model_description": "Some cool model...",
158
- "the_authors":"",
159
- "Shared_by":"",
160
- "Model_details_text": "",
161
- "Model_developers": "",
162
- "blog_url":"",
163
- "Parent_Model_url":"",
164
- "Parent_Model_name":"",
165
-
166
- "Model_how_to": "",
167
-
168
- "Model_uses": "",
169
- "Direct_Use": "",
170
- "Downstream_Use":"",
171
- "Out-of-Scope_Use":"",
172
-
173
- "Model_Limits_n_Risks": "",
174
- "Recommendations":"",
175
-
176
- "training_Data": "",
177
- "model_preprocessing":"",
178
- "Speeds_Sizes_Times":"",
179
-
180
-
181
-
182
- "Model_Eval": "",
183
- "Testing_Data":"",
184
- "Factors":"",
185
- "Metrics":"",
186
- "Model_Results":"",
187
-
188
- "Model_c02_emitted": "",
189
- "Model_hardware":"",
190
- "hours_used":"",
191
- "Model_cloud_provider":"",
192
- "Model_cloud_region":"",
193
-
194
- "Model_cite": "",
195
- "paper_url": "",
196
- "github_url": "",
197
- "bibtex_citation": "",
198
- "APA_citation":"",
199
-
200
- "Model_examin":"",
201
- "Model_card_contact":"",
202
- "Model_card_authors":"",
203
- "Glossary":"",
204
- "More_info":"",
205
-
206
- "Model_specs":"",
207
- "compute_infrastructure":"",
208
- "technical_specs_software":"",
209
-
210
- "check_box": bool,
211
- "markdown_upload":" ",
212
- "legal_view":bool,
213
- "researcher_view":bool,
214
- "beginner_technical_view":bool,
215
- "markdown_state":"",
216
- })
217
- ## getting cache for each warnings
218
- languages_map, license_map, available_metrics, libraries, tasks, icd_map, treatment_mod = get_cached_data()
219
-
220
- ## form UI setting
221
- st.header("Model basic information (Dose prediction)")
222
-
223
- warning_placeholder = st.empty()
224
-
225
- st.text_input("Model Name", key=persist("model_name"))
226
- st.number_input("Version",key=persist("version"),step=0.1)
227
- st.text("Intended use:")
228
- left, right = st.columns([4,2])
229
- left.multiselect("Treatment site ICD10",list(icd_map), help="Reference ICD10 WHO: https://icd.who.int/icdapi")
230
- right.multiselect("Treatment modality", list(treatment_mod), help="Reference LOINC Modality Radiation treatment: https://loinc.org/21964-2" )
231
- left, right = st.columns(2)
232
- nlines = left.number_input("Number of prescription levels", 0, 20, 1)
233
- # cols = st.columns(ncol)
234
- for i in range(nlines):
235
- right.number_input(f"Prescription [Gy] # {i}", key=i)
236
- st.text_area("Additional information", placeholder = "Bilateral cases only", help="E.g. Bilateral cases only", key=persist('additional_information'))
237
- st.text_area("Motivation for development", key=persist('motivation'))
238
- st.text_area("Class", placeholder="RULE 11, FROM MDCG 2021-24", key=persist('class'))
239
- st.date_input("Creation date", key=persist('creation_date'))
240
- st.text_area("Type of architecture",value="UNet", key=persist('architecture'))
241
-
242
- st.text("Developed by:")
243
- left, middle, right = st.columns(3)
244
- left.text_input("Name", key=persist('dev_name'))
245
- middle.text_input("Institution", placeholder = "University/clinic/company", key=persist('dev_institution'))
246
- right.text_input("Email", key=persist('dev_email'))
247
-
248
- st.text_area("Funded by", key=persist('fund'))
249
- st.text_area("Shared by", key=persist('shared'))
250
- st.selectbox("License", [""] + list(license_map.values()), help="The license associated with this model.", key=persist("license"))
251
- st.text_area("Fine tuned from model", key=persist('fine_tuned_from'))
252
- st.text_input("Related Research Paper", help="Research paper related to this model.", key=persist("paper_url"))
253
- st.text_input("Related GitHub Repository", help="Link to a GitHub repository used in the development of this model", key=persist("github_url"))
254
- st.text_area("Bibtex Citation", help="Bibtex citations for related work", key=persist("bibtex_citations"))
255
- # st.selectbox("Library Name", [""] + libraries, help="The name of the library this model came from (Ex. pytorch, timm, spacy, keras, etc.). This is usually automatically detected in model repos, so it is not required.", key=persist('library_name'))
256
- # st.text_input("Parent Model (URL)", help="If this model has another model as its base, please provide the URL link to the parent model", key=persist("Parent_Model_name"))
257
- # st.text_input("Datasets (comma separated)", help="The dataset(s) used to train this model. Use dataset id from https://hf.co/datasets.", key=persist("datasets"))
258
- # st.multiselect("Metrics", available_metrics, help="Metrics used in the training/evaluation of this model. Use metric id from https://hf.co/metrics.", key=persist("metrics"))
259
- # st.selectbox("Task", [""] + tasks, help="What task does this model aim to solve?", key=persist('task'))
260
- # st.text_input("Tags (comma separated)", help="Additional tags to add which will be filterable on https://hf.co/models. (Ex. image-classification, vision, resnet)", key=persist("tags"))
261
- # st.text_input("Author(s) (comma separated)", help="The authors who developed this model. If you trained this model, the author is you.", key=persist("the_authors"))
262
- # s
263
- # st.text_input("Carbon Emitted:", help="You can estimate carbon emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700)", key=persist("Model_c02_emitted"))
264
-
265
- # st.header("Technical specifications")
266
- # st.header("Training data, methodology, and results")
267
- # st.header("Evaluation data, methodology, and results / commissioning")
268
- # st.header("Ethical use considerations")
269
-
270
- # warnings setting
271
- languages=st.session_state.languages or None
272
- license=st.session_state.license or None
273
- task = st.session_state.task or None
274
- markdown_upload = st.session_state.markdown_upload
275
- #uploaded_model_card = st.session_state.uploaded_model
276
- # Handle any warnings...
277
- do_warn = False
278
- warning_msg = "Warning: The following fields are required but have not been filled in: "
279
- if not languages:
280
- warning_msg += "\n- Languages"
281
- do_warn = True
282
- if not license:
283
- warning_msg += "\n- License"
284
- do_warn = True
285
- if not task or not markdown_upload:
286
- warning_msg += "\n- Please choose a task or upload a model card"
287
- do_warn = True
288
- if do_warn:
289
- warning_placeholder.error(warning_msg)
290
-
291
- with st.sidebar:
292
-
293
- ######################################################
294
- ### Uploading a model card from local drive
295
- ######################################################
296
- st.markdown("## Upload Model Card")
297
-
298
- st.markdown("#### Model Card must be in markdown (.md) format.")
299
-
300
- # Read a single file
301
- uploaded_file = st.file_uploader("Choose a file", type = ['md'], help = 'Please choose a markdown (.md) file type to upload')
302
- if uploaded_file is not None:
303
-
304
- file_details = {"FileName":uploaded_file.name,"FileType":uploaded_file.type}
305
- name_of_uploaded_file = save_uploadedfile(uploaded_file)
306
-
307
- st.session_state.markdown_upload = name_of_uploaded_file ## uploaded model card
308
-
309
- elif st.session_state.task =='fill-mask' or 'translation' or 'token-classification' or ' sentence-similarity' or 'summarization' or 'question-answering' or 'text2text-generation' or 'text-classification' or 'text-generation' or 'conversational':
310
- #st.session_state.markdown_upload = open(
311
- # "language_model_template1.md", "r+"
312
- #).read()
313
- st.session_state.markdown_upload = "language_model_template1.md" ## language model template
314
-
315
- elif st.session_state.task:
316
-
317
- st.session_state.markdown_upload = "current_card.md" ## default non language model template
318
-
319
- #########################################
320
- ### Uploading model card to HUB
321
- #########################################
322
- out_markdown =open( st.session_state.markdown_upload, "r+"
323
- ).read()
324
- print_out_final = f"{out_markdown}"
325
- st.markdown("## Export Loaded Model Card to Hub")
326
- with st.form("Upload to πŸ€— Hub"):
327
- st.markdown("Use a token with write access from [here](https://hf.co/settings/tokens)")
328
- token = st.text_input("Token", type='password')
329
- repo_id = st.text_input("Repo ID")
330
- submit = st.form_submit_button('Upload to πŸ€— Hub', help='The current model card will be uploaded to a branch in the supplied repo ')
331
-
332
- if submit:
333
- if len(repo_id.split('/')) == 2:
334
- repo_url = create_repo(repo_id, exist_ok=True, token=token)
335
- new_url = card_upload(pj(),repo_id, token=token)
336
- st.success(f"Pushed the card to the repo [here]({new_url})!") # note: was repo_url
337
- else:
338
- st.error("Repo ID invalid. It should be username/repo-name. For example: nateraw/food")
339
-
340
-
341
- #########################################
342
- ### Download model card
343
- #########################################
344
-
345
-
346
- st.markdown("## Download current Model Card")
347
-
348
- if st.session_state.model_name is None or st.session_state.model_name== ' ':
349
- downloaded_file_name = 'current_model_card.md'
350
- else:
351
- downloaded_file_name = st.session_state.model_name+'_'+'model_card.md'
352
- download_status = st.download_button(label = 'Download Model Card', data = pj(), file_name = downloaded_file_name, help = "The current model card will be downloaded as a markdown (.md) file")
353
- if download_status == True:
354
- st.success("Your current model card, successfully downloaded πŸ€—")
355
-
356
-
357
- def page_switcher(page):
358
- st.session_state.runpage = page
359
-
360
- def main():
361
-
362
- st.header("About Model Cards")
363
- st.markdown(Path('about.md').read_text(), unsafe_allow_html=True)
364
- btn = st.button('Create a Model Card πŸ“',on_click=page_switcher,args=(main_page,))
365
- if btn:
366
- st.experimental_rerun() # rerun is needed to clear the page
367
-
368
- if __name__ == '__main__':
369
- load_widget_state()
370
- if 'runpage' not in st.session_state :
371
- st.session_state.runpage = main
372
- st.session_state.runpage()
 
1
+ # from yaml import load
2
+ from persist import persist, load_widget_state
3
+ import streamlit as st
4
+ from io import StringIO
5
+ import tempfile
6
+ from pathlib import Path
7
+ import requests
8
+ from huggingface_hub import hf_hub_download, upload_file
9
+ import pandas as pd
10
+ from huggingface_hub import create_repo
11
+ import os
12
+ from middleMan import parse_into_jinja_markdown as pj
13
+
14
+ import requests
15
+
16
+ @st.cache
17
+ def get_icd():
18
+ # Get ICD10 list
19
+ token_endpoint = 'https://icdaccessmanagement.who.int/connect/token'
20
+ client_id = '3bc9c811-7f2e-4dab-a2dc-940e47a38fef_a6108252-4503-4ff7-90ab-300fd27392aa'
21
+ client_secret = 'xPj7mleWf1Bilu9f7P10UQmBPvL5F6Wgd8/rJhO1T04='
22
+ scope = 'icdapi_access'
23
+ grant_type = 'client_credentials'
24
+ # set data to post
25
+ payload = {'client_id': client_id,
26
+ 'client_secret': client_secret,
27
+ 'scope': scope,
28
+ 'grant_type': grant_type}
29
+ # make request
30
+ r = requests.post(token_endpoint, data=payload, verify=False).json()
31
+ token = r['access_token']
32
+ # access ICD API
33
+ uri = 'https://id.who.int/icd/release/10/2019/C00-C75'
34
+ # HTTP header fields to set
35
+ headers = {'Authorization': 'Bearer '+token,
36
+ 'Accept': 'application/json',
37
+ 'Accept-Language': 'en',
38
+ 'API-Version': 'v2'}
39
+ # make request
40
+ r = requests.get(uri, headers=headers, verify=False)
41
+ print("icd",r.json())
42
+ icd_map =[]
43
+ for child in r.json()['child']:
44
+ r_child = requests.get(child, headers=headers,verify=False).json()
45
+ icd_map.append(r_child["code"]+" "+r_child["title"]["@value"])
46
+ return icd_map
47
+
48
+ @st.cache
49
+ def get_treatment_mod():
50
+ url = "https://clinicaltables.nlm.nih.gov/loinc_answers?loinc_num=21964-2"
51
+ r = requests.get(url).json()
52
+ treatment_mod = [treatment['DisplayText'] for treatment in r]
53
+ return treatment_mod
54
+
55
+
56
+ @st.cache
57
+ def get_cached_data():
58
+ languages_df = pd.read_html("https://hf.co/languages")[0]
59
+ languages_map = pd.Series(languages_df["Language"].values, index=languages_df["ISO code"]).to_dict()
60
+
61
+ license_df = pd.read_html("https://huggingface.co/docs/hub/repositories-licenses")[0]
62
+ license_map = pd.Series(
63
+ license_df["License identifier (to use in repo card)"].values, index=license_df.Fullname
64
+ ).to_dict()
65
+
66
+ available_metrics = [x['id'] for x in requests.get('https://huggingface.co/api/metrics').json()]
67
+
68
+ r = requests.get('https://huggingface.co/api/models-tags-by-type')
69
+ tags_data = r.json()
70
+ libraries = [x['id'] for x in tags_data['library']]
71
+ tasks = [x['id'] for x in tags_data['pipeline_tag']]
72
+
73
+ icd_map = get_icd()
74
+ treatment_mod = get_treatment_mod()
75
+ return languages_map, license_map, available_metrics, libraries, tasks, icd_map, treatment_mod
76
+
77
+
78
+ def card_upload(card_info,repo_id,token):
79
+ #commit_message=None,
80
+ repo_type = "space"
81
+ commit_description=None,
82
+ revision=None,
83
+ create_pr=None
84
+ with tempfile.TemporaryDirectory() as tmpdir:
85
+ tmp_path = Path(tmpdir) / "README.md"
86
+ tmp_path.write_text(str(card_info))
87
+ url = upload_file(
88
+ path_or_fileobj=str(tmp_path),
89
+ path_in_repo="README.md",
90
+ repo_id=repo_id,
91
+ token=token,
92
+ repo_type=repo_type,
93
+ # identical_ok=True,
94
+ revision=revision,
95
+ )
96
+ return url
97
+
98
+ def validate(self, repo_type="model"):
99
+ """Validates card against Hugging Face Hub's model card validation logic.
100
+ Using this function requires access to the internet, so it is only called
101
+ internally by `modelcards.ModelCard.push_to_hub`.
102
+ Args:
103
+ repo_type (`str`, *optional*):
104
+ The type of Hugging Face repo to push to. Defaults to None, which will use
105
+ use "model". Other options are "dataset" and "space".
106
+ """
107
+ if repo_type is None:
108
+ repo_type = "model"
109
+
110
+ # TODO - compare against repo types constant in huggingface_hub if we move this object there.
111
+ if repo_type not in ["model", "space", "dataset"]:
112
+ raise RuntimeError(
113
+ "Provided repo_type '{repo_type}' should be one of ['model', 'space',"
114
+ " 'dataset']."
115
+ )
116
+
117
+ body = {
118
+ "repoType": repo_type,
119
+ "content": str(self),
120
+ }
121
+ headers = {"Accept": "text/plain"}
122
+
123
+ try:
124
+ r = requests.post(
125
+ "https://huggingface.co/api/validate-yaml", body, headers=headers
126
+ )
127
+ r.raise_for_status()
128
+ except requests.exceptions.HTTPError as exc:
129
+ if r.status_code == 400:
130
+ raise RuntimeError(r.text)
131
+ else:
132
+ raise exc
133
+
134
+
135
+ ## Save uploaded [markdown] file to directory to be used by jinja parser function
136
+ def save_uploadedfile(uploadedfile):
137
+ with open(os.path.join("temp_uploaded_filed_Dir",uploadedfile.name),"wb") as f:
138
+ f.write(uploadedfile.getbuffer())
139
+ st.success("Saved File:{} to temp_uploaded_filed_Dir".format(uploadedfile.name))
140
+ return uploadedfile.name
141
+
142
+
143
+ def main_page():
144
+
145
+
146
+ if "model_name" not in st.session_state:
147
+ # Initialize session state.
148
+ st.session_state.update({
149
+ "input_model_name": "",
150
+ "license": "",
151
+ "library_name": "",
152
+ "datasets": "",
153
+ "metrics": [],
154
+ "task": "",
155
+ "tags": "",
156
+ "model_description": "Some cool model...",
157
+ "the_authors":"",
158
+ "Shared_by":"",
159
+ "Model_details_text": "",
160
+ "Model_developers": "",
161
+ "blog_url":"",
162
+ "Parent_Model_url":"",
163
+ "Parent_Model_name":"",
164
+
165
+ "Model_how_to": "",
166
+
167
+ "Model_uses": "",
168
+ "Direct_Use": "",
169
+ "Downstream_Use":"",
170
+ "Out-of-Scope_Use":"",
171
+
172
+ "Model_Limits_n_Risks": "",
173
+ "Recommendations":"",
174
+
175
+ "training_Data": "",
176
+ "model_preprocessing":"",
177
+ "Speeds_Sizes_Times":"",
178
+
179
+
180
+
181
+ "Model_Eval": "",
182
+ "Testing_Data":"",
183
+ "Factors":"",
184
+ "Metrics":"",
185
+ "Model_Results":"",
186
+
187
+ "Model_c02_emitted": "",
188
+ "Model_hardware":"",
189
+ "hours_used":"",
190
+ "Model_cloud_provider":"",
191
+ "Model_cloud_region":"",
192
+
193
+ "Model_cite": "",
194
+ "paper_url": "",
195
+ "github_url": "",
196
+ "bibtex_citation": "",
197
+ "APA_citation":"",
198
+
199
+ "Model_examin":"",
200
+ "Model_card_contact":"",
201
+ "Model_card_authors":"",
202
+ "Glossary":"",
203
+ "More_info":"",
204
+
205
+ "Model_specs":"",
206
+ "compute_infrastructure":"",
207
+ "technical_specs_software":"",
208
+
209
+ "check_box": bool,
210
+ "markdown_upload":" ",
211
+ "legal_view":bool,
212
+ "researcher_view":bool,
213
+ "beginner_technical_view":bool,
214
+ "markdown_state":"",
215
+ })
216
+ ## getting cache for each warnings
217
+ languages_map, license_map, available_metrics, libraries, tasks, icd_map, treatment_mod = get_cached_data()
218
+
219
+ ## form UI setting
220
+ st.header("Model basic information (Dose prediction)")
221
+
222
+ warning_placeholder = st.empty()
223
+
224
+ st.text_input("Model Name", key=persist("model_name"))
225
+ st.number_input("Version",key=persist("version"),step=0.1)
226
+ st.text("Intended use:")
227
+ left, right = st.columns([4,2])
228
+ left.multiselect("Treatment site ICD10",list(icd_map), help="Reference ICD10 WHO: https://icd.who.int/icdapi")
229
+ right.multiselect("Treatment modality", list(treatment_mod), help="Reference LOINC Modality Radiation treatment: https://loinc.org/21964-2" )
230
+ left, right = st.columns(2)
231
+ nlines = int(left.number_input("Number of prescription levels", 0, 20, 1))
232
+ # cols = st.columns(ncol)
233
+ for i in range(nlines):
234
+ right.number_input(f"Prescription [Gy] # {i}", key=i)
235
+ st.text_area("Additional information", placeholder = "Bilateral cases only", help="E.g. Bilateral cases only", key=persist('additional_information'))
236
+ st.text_area("Motivation for development", key=persist('motivation'))
237
+ st.text_area("Class", placeholder="RULE 11, FROM MDCG 2021-24", key=persist('class'))
238
+ st.date_input("Creation date", key=persist('creation_date'))
239
+ st.text_area("Type of architecture",value="UNet", key=persist('architecture'))
240
+
241
+ st.text("Developed by:")
242
+ left, middle, right = st.columns(3)
243
+ left.text_input("Name", key=persist('dev_name'))
244
+ middle.text_input("Institution", placeholder = "University/clinic/company", key=persist('dev_institution'))
245
+ right.text_input("Email", key=persist('dev_email'))
246
+
247
+ st.text_area("Funded by", key=persist('fund'))
248
+ st.text_area("Shared by", key=persist('shared'))
249
+ st.selectbox("License", [""] + list(license_map.values()), help="The license associated with this model.", key=persist("license"))
250
+ st.text_area("Fine tuned from model", key=persist('fine_tuned_from'))
251
+ st.text_input("Related Research Paper", help="Research paper related to this model.", key=persist("paper_url"))
252
+ st.text_input("Related GitHub Repository", help="Link to a GitHub repository used in the development of this model", key=persist("github_url"))
253
+ st.text_area("Bibtex Citation", help="Bibtex citations for related work", key=persist("bibtex_citations"))
254
+ # st.selectbox("Library Name", [""] + libraries, help="The name of the library this model came from (Ex. pytorch, timm, spacy, keras, etc.). This is usually automatically detected in model repos, so it is not required.", key=persist('library_name'))
255
+ # st.text_input("Parent Model (URL)", help="If this model has another model as its base, please provide the URL link to the parent model", key=persist("Parent_Model_name"))
256
+ # st.text_input("Datasets (comma separated)", help="The dataset(s) used to train this model. Use dataset id from https://hf.co/datasets.", key=persist("datasets"))
257
+ # st.multiselect("Metrics", available_metrics, help="Metrics used in the training/evaluation of this model. Use metric id from https://hf.co/metrics.", key=persist("metrics"))
258
+ # st.selectbox("Task", [""] + tasks, help="What task does this model aim to solve?", key=persist('task'))
259
+ # st.text_input("Tags (comma separated)", help="Additional tags to add which will be filterable on https://hf.co/models. (Ex. image-classification, vision, resnet)", key=persist("tags"))
260
+ # st.text_input("Author(s) (comma separated)", help="The authors who developed this model. If you trained this model, the author is you.", key=persist("the_authors"))
261
+ # s
262
+ # st.text_input("Carbon Emitted:", help="You can estimate carbon emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700)", key=persist("Model_c02_emitted"))
263
+
264
+ # st.header("Technical specifications")
265
+ # st.header("Training data, methodology, and results")
266
+ # st.header("Evaluation data, methodology, and results / commissioning")
267
+ # st.header("Ethical use considerations")
268
+
269
+ # warnings setting
270
+ # languages=st.session_state.languages or None
271
+ license=st.session_state.license or None
272
+ task = st.session_state.task or None
273
+ markdown_upload = st.session_state.markdown_upload
274
+ #uploaded_model_card = st.session_state.uploaded_model
275
+ # Handle any warnings...
276
+ do_warn = False
277
+ warning_msg = "Warning: The following fields are required but have not been filled in: "
278
+ if not license:
279
+ warning_msg += "\n- License"
280
+ do_warn = True
281
+ if do_warn:
282
+ warning_placeholder.error(warning_msg)
283
+
284
+ with st.sidebar:
285
+
286
+ ######################################################
287
+ ### Uploading a model card from local drive
288
+ ######################################################
289
+ st.markdown("## Upload Model Card")
290
+
291
+ st.markdown("#### Model Card must be in markdown (.md) format.")
292
+
293
+ # Read a single file
294
+ uploaded_file = st.file_uploader("Choose a file", type = ['md'], help = 'Please choose a markdown (.md) file type to upload')
295
+ if uploaded_file is not None:
296
+
297
+ file_details = {"FileName":uploaded_file.name,"FileType":uploaded_file.type}
298
+ name_of_uploaded_file = save_uploadedfile(uploaded_file)
299
+
300
+ st.session_state.markdown_upload = name_of_uploaded_file ## uploaded model card
301
+
302
+ elif st.session_state.task =='fill-mask' or 'translation' or 'token-classification' or ' sentence-similarity' or 'summarization' or 'question-answering' or 'text2text-generation' or 'text-classification' or 'text-generation' or 'conversational':
303
+ print("YO",st.session_state.task)
304
+ st.session_state.markdown_upload = "language_model_template1.md" ## language model template
305
+
306
+ elif st.session_state.task:
307
+
308
+ st.session_state.markdown_upload = "current_card.md" ## default non language model template
309
+ print("st.session_state.markdown_upload",st.session_state.markdown_upload)
310
+ #########################################
311
+ ### Uploading model card to HUB
312
+ #########################################
313
+ out_markdown =open( st.session_state.markdown_upload, "r+"
314
+ ).read()
315
+ print_out_final = f"{out_markdown}"
316
+ st.markdown("## Export Loaded Model Card to Hub")
317
+ with st.form("Upload to πŸ€— Hub"):
318
+ st.markdown("Use a token with write access from [here](https://hf.co/settings/tokens)")
319
+ token = st.text_input("Token", type='password')
320
+ repo_id = st.text_input("Repo ID")
321
+ submit = st.form_submit_button('Upload to πŸ€— Hub', help='The current model card will be uploaded to a branch in the supplied repo ')
322
+
323
+ if submit:
324
+ if len(repo_id.split('/')) == 2:
325
+ repo_url = "repo"#create_repo(repo_id, exist_ok=True, token=token)
326
+ print("repo_url",repo_url)
327
+ card_info = pj()
328
+ print(card_info)
329
+ new_url = card_upload(card_info,repo_id, token=token)
330
+ st.success(f"Pushed the card to the repo [here]({new_url})!") # note: was repo_url
331
+ else:
332
+ st.error("Repo ID invalid. It should be username/repo-name. For example: nateraw/food")
333
+
334
+
335
+ #########################################
336
+ ### Download model card
337
+ #########################################
338
+
339
+
340
+ st.markdown("## Download current Model Card")
341
+
342
+ if st.session_state.model_name is None or st.session_state.model_name== ' ':
343
+ downloaded_file_name = 'current_model_card.md'
344
+ else:
345
+ downloaded_file_name = st.session_state.model_name+'_'+'model_card.md'
346
+ download_status = st.download_button(label = 'Download Model Card', data = pj(), file_name = downloaded_file_name, help = "The current model card will be downloaded as a markdown (.md) file")
347
+ if download_status == True:
348
+ st.success("Your current model card, successfully downloaded πŸ€—")
349
+
350
+
351
+ def page_switcher(page):
352
+ st.session_state.runpage = page
353
+
354
+ def main():
355
+
356
+ st.header("About Model Cards")
357
+ st.markdown(Path('about.md').read_text(), unsafe_allow_html=True)
358
+ btn = st.button('Create a Model Card πŸ“',on_click=page_switcher,args=(main_page,))
359
+ if btn:
360
+ st.experimental_rerun() # rerun is needed to clear the page
361
+
362
+ if __name__ == '__main__':
363
+ load_widget_state()
364
+ if 'runpage' not in st.session_state :
365
+ st.session_state.runpage = main
366
+ st.session_state.runpage()