quantumiracle-git committed on
Commit
a31d8d1
·
1 Parent(s): dfc0cb3

Update hfserver.py

Browse files
Files changed (1) hide show
  1. hfserver.py +304 -101
hfserver.py CHANGED
@@ -5,14 +5,61 @@ import datetime
5
  import io
6
  import json
7
  import os
 
8
  from abc import ABC, abstractmethod
9
  from typing import TYPE_CHECKING, Any, List, Optional
10
 
11
  import gradio as gr
12
  from gradio import encryptor, utils
 
13
 
14
  if TYPE_CHECKING:
15
- from gradio.components import Component
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  class FlaggingCallback(ABC):
@@ -21,7 +68,7 @@ class FlaggingCallback(ABC):
21
  """
22
 
23
  @abstractmethod
24
- def setup(self, components: List[Component], flagging_dir: str):
25
  """
26
  This method should be overridden and ensure that everything is set up correctly for flag().
27
  This method gets called once at the beginning of the Interface.launch() method.
@@ -54,13 +101,24 @@ class FlaggingCallback(ABC):
54
  pass
55
 
56
 
 
57
  class SimpleCSVLogger(FlaggingCallback):
58
  """
59
- A simple example implementation of the FlaggingCallback abstract class
60
- provided for illustrative purposes.
 
 
 
 
 
 
 
61
  """
62
 
63
- def setup(self, components: List[Component], flagging_dir: str):
 
 
 
64
  self.components = components
65
  self.flagging_dir = flagging_dir
66
  os.makedirs(flagging_dir, exist_ok=True)
@@ -77,33 +135,46 @@ class SimpleCSVLogger(FlaggingCallback):
77
 
78
  csv_data = []
79
  for component, sample in zip(self.components, flag_data):
 
 
 
80
  csv_data.append(
81
- component.save_flagged(
82
- flagging_dir,
83
- component.label,
84
  sample,
 
85
  None,
86
  )
87
  )
88
 
89
  with open(log_filepath, "a", newline="") as csvfile:
90
- writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
91
- writer.writerow(csv_data)
92
 
93
  with open(log_filepath, "r") as csvfile:
94
  line_count = len([None for row in csv.reader(csvfile)]) - 1
95
  return line_count
96
 
97
 
 
98
  class CSVLogger(FlaggingCallback):
99
  """
100
- The default implementation of the FlaggingCallback abstract class.
101
- Logs the input and output data to a CSV file. Supports encryption.
 
 
 
 
 
 
 
102
  """
103
 
 
 
 
104
  def setup(
105
  self,
106
- components: List[Component],
107
  flagging_dir: str,
108
  encryption_key: Optional[str] = None,
109
  ):
@@ -125,22 +196,33 @@ class CSVLogger(FlaggingCallback):
125
 
126
  if flag_index is None:
127
  csv_data = []
128
- for component, sample in zip(self.components, flag_data):
129
- csv_data.append(
130
- component.save_flagged(
131
- flagging_dir,
132
- component.label,
133
- sample,
134
- self.encryption_key,
135
- )
136
- if sample is not None
137
- else ""
138
  )
 
 
 
 
 
 
 
 
 
 
 
 
139
  csv_data.append(flag_option if flag_option is not None else "")
140
  csv_data.append(username if username is not None else "")
141
  csv_data.append(str(datetime.datetime.now()))
142
  if is_new:
143
- headers = [component.label for component in self.components] + [
 
 
 
144
  "flag",
145
  "username",
146
  "timestamp",
@@ -153,14 +235,14 @@ class CSVLogger(FlaggingCallback):
153
  flag_col_index = header.index("flag")
154
  content[flag_index][flag_col_index] = flag_option
155
  output = io.StringIO()
156
- writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
157
- writer.writerows(content)
158
  return output.getvalue()
159
 
160
  if self.encryption_key:
161
  output = io.StringIO()
162
  if not is_new:
163
- with open(log_filepath, "rb") as csvfile:
164
  encrypted_csv = csvfile.read()
165
  decrypted_csv = encryptor.decrypt(
166
  self.encryption_key, encrypted_csv
@@ -169,70 +251,70 @@ class CSVLogger(FlaggingCallback):
169
  if flag_index is not None:
170
  file_content = replace_flag_at_index(file_content)
171
  output.write(file_content)
172
- writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
173
  if flag_index is None:
174
  if is_new:
175
- writer.writerow(headers)
176
- writer.writerow(csv_data)
177
- with open(log_filepath, "wb") as csvfile:
178
  csvfile.write(
179
  encryptor.encrypt(self.encryption_key, output.getvalue().encode())
180
  )
181
  else:
182
  if flag_index is None:
183
- with open(log_filepath, "a", newline="") as csvfile:
184
- writer = csv.writer(
185
- csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'"
186
- )
187
  if is_new:
188
- writer.writerow(headers)
189
- writer.writerow(csv_data)
190
  else:
191
- with open(log_filepath) as csvfile:
192
  file_content = csvfile.read()
193
  file_content = replace_flag_at_index(file_content)
194
  with open(
195
- log_filepath, "w", newline=""
196
  ) as csvfile: # newline parameter needed for Windows
197
- csvfile.write(file_content)
198
- with open(log_filepath, "r") as csvfile:
199
  line_count = len([None for row in csv.reader(csvfile)]) - 1
200
  return line_count
201
 
202
 
 
203
  class HuggingFaceDatasetSaver(FlaggingCallback):
204
  """
205
- A FlaggingCallback that saves flagged data to a HuggingFace dataset.
 
 
 
 
 
 
 
 
 
206
  """
207
 
208
  def __init__(
209
  self,
210
- hf_foken: str,
211
  dataset_name: str,
212
  organization: Optional[str] = None,
213
  private: bool = False,
214
- verbose: bool = True,
215
  ):
216
  """
217
- Params:
218
- hf_token (str): The token to use to access the huggingface API.
219
- dataset_name (str): The name of the dataset to save the data to, e.g.
220
- "image-classifier-1"
221
- organization (str): The name of the organization to which to attach
222
- the datasets. If None, the dataset attaches to the user only.
223
- private (bool): If the dataset does not already exist, whether it
224
- should be created as a private dataset or public. Private datasets
225
- may require paid huggingface.co accounts
226
- verbose (bool): Whether to print out the status of the dataset
227
- creation.
228
  """
229
- self.hf_foken = hf_foken
230
  self.dataset_name = dataset_name
231
  self.organization_name = organization
232
  self.dataset_private = private
233
- self.verbose = verbose
234
 
235
- def setup(self, components: List[Component], flagging_dir: str):
236
  """
237
  Params:
238
  flagging_dir (str): local directory where the dataset is cloned,
@@ -246,9 +328,8 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
246
  "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
247
  )
248
  path_to_dataset_repo = huggingface_hub.create_repo(
249
- # name=self.dataset_name, https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py
250
- repo_id=self.dataset_name,
251
- token=self.hf_foken,
252
  private=self.dataset_private,
253
  repo_type="dataset",
254
  exist_ok=True,
@@ -260,9 +341,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
260
  self.repo = huggingface_hub.Repository(
261
  local_dir=self.dataset_dir,
262
  clone_from=path_to_dataset_repo,
263
- use_auth_token=self.hf_foken,
264
  )
265
- self.repo.git_pull()
266
 
267
  # Should filename be user-specified?
268
  self.log_file = os.path.join(self.dataset_dir, "data.csv")
@@ -275,68 +356,190 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
275
  flag_index: Optional[int] = None,
276
  username: Optional[str] = None,
277
  ) -> int:
 
 
278
  is_new = not os.path.exists(self.log_file)
279
- infos = {"flagged": {"features": {}}}
280
 
281
- with open(self.log_file, "a", newline="") as csvfile:
282
  writer = csv.writer(csvfile)
283
 
284
  # File previews for certain input and output types
285
- file_preview_types = {
286
- gr.inputs.Audio: "Audio",
287
- gr.outputs.Audio: "Audio",
288
- gr.inputs.Image: "Image",
289
- gr.outputs.Image: "Image",
290
- }
291
 
292
  # Generate the headers and dataset_infos
293
  if is_new:
294
- headers = []
295
-
296
- for component, sample in zip(self.components, flag_data):
297
- headers.append(component.label)
298
- headers.append(component.label)
299
- infos["flagged"]["features"][component.label] = {
300
- "dtype": "string",
301
- "_type": "Value",
302
- }
303
- if isinstance(component, tuple(file_preview_types)):
304
- headers.append(component.label + " file")
305
- for _component, _type in file_preview_types.items():
306
- if isinstance(component, _component):
307
- infos["flagged"]["features"][
308
- component.label + " file"
309
- ] = {"_type": _type}
310
- break
311
-
312
- headers.append("flag")
313
- infos["flagged"]["features"]["flag"] = {
314
- "dtype": "string",
315
- "_type": "Value",
316
- }
317
-
318
- writer.writerow(headers)
319
 
320
  # Generate the row corresponding to the flagged sample
321
  csv_data = []
322
  for component, sample in zip(self.components, flag_data):
323
- filepath = component.save_flagged(
324
- self.dataset_dir, component.label, sample, None
 
325
  )
 
326
  csv_data.append(filepath)
327
  if isinstance(component, tuple(file_preview_types)):
328
  csv_data.append(
329
  "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
330
  )
331
  csv_data.append(flag_option if flag_option is not None else "")
332
- writer.writerow(csv_data)
333
 
334
  if is_new:
335
  json.dump(infos, open(self.infos_file, "w"))
336
 
337
- with open(self.log_file, "r") as csvfile:
338
  line_count = len([None for row in csv.reader(csvfile)]) - 1
339
 
340
  self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))
341
 
342
- return line_count
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import io
6
  import json
7
  import os
8
+ import uuid
9
  from abc import ABC, abstractmethod
10
  from typing import TYPE_CHECKING, Any, List, Optional
11
 
12
  import gradio as gr
13
  from gradio import encryptor, utils
14
+ from gradio.documentation import document, set_documentation_group
15
 
16
  if TYPE_CHECKING:
17
+ from gradio.components import IOComponent
18
+
19
+ set_documentation_group("flagging")
20
+
21
+
22
def _get_dataset_features_info(is_new: bool, components: "List[IOComponent]"):
    """
    Build the dataset feature description and CSV headers for a list of components.

    Parameters:
    is_new: whether the dataset is new. Headers and feature entries are only
        generated when this is True; otherwise both are returned empty.
    components: list of components; each one is read for its `label` attribute
        and checked against the previewable component classes.
    Returns:
    infos: a dictionary of the form {"flagged": {"features": {...}}} describing
        the dataset features
    file_preview_types: dictionary mapping gradio component classes to the
        feature type string used for their preview column
    headers: list of header strings
    """
    features = {}
    infos = {"flagged": {"features": features}}
    # File previews for certain input and output types
    file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
    headers = []

    # An existing dataset already has its headers and dataset infos written.
    if not is_new:
        return infos, file_preview_types, headers

    for idx, component in enumerate(components):
        # Unlabelled components get a positional name (the same fallback that
        # CSVLogger uses) instead of a None header / a TypeError on `+ " file"`.
        label = component.label or f"component {idx}"
        headers.append(label)
        features[label] = {"dtype": "string", "_type": "Value"}

        # First matching preview class wins, mirroring dict insertion order.
        preview_type = next(
            (
                type_name
                for component_cls, type_name in file_preview_types.items()
                if isinstance(component, component_cls)
            ),
            None,
        )
        if preview_type is not None:
            file_label = label + " file"
            headers.append(file_label)
            features[file_label] = {"_type": preview_type}

    headers.append("flag")
    features["flag"] = {"dtype": "string", "_type": "Value"}

    return infos, file_preview_types, headers
63
 
64
 
65
  class FlaggingCallback(ABC):
 
68
  """
69
 
70
  @abstractmethod
71
+ def setup(self, components: List[IOComponent], flagging_dir: str):
72
  """
73
  This method should be overridden and ensure that everything is set up correctly for flag().
74
  This method gets called once at the beginning of the Interface.launch() method.
 
101
  pass
102
 
103
 
104
+ @document()
105
  class SimpleCSVLogger(FlaggingCallback):
106
  """
107
+ A simplified implementation of the FlaggingCallback abstract class
108
+ provided for illustrative purposes. Each flagged sample (both the input and output data)
109
+ is logged to a CSV file on the machine running the gradio app.
110
+ Example:
111
+ import gradio as gr
112
+ def image_classifier(inp):
113
+ return {'cat': 0.3, 'dog': 0.7}
114
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
115
+ flagging_callback=SimpleCSVLogger())
116
  """
117
 
118
+ def __init__(self):
119
+ pass
120
+
121
+ def setup(self, components: List[IOComponent], flagging_dir: str):
122
  self.components = components
123
  self.flagging_dir = flagging_dir
124
  os.makedirs(flagging_dir, exist_ok=True)
 
135
 
136
  csv_data = []
137
  for component, sample in zip(self.components, flag_data):
138
+ save_dir = os.path.join(
139
+ flagging_dir, utils.strip_invalid_filename_characters(component.label)
140
+ )
141
  csv_data.append(
142
+ component.deserialize(
 
 
143
  sample,
144
+ save_dir,
145
  None,
146
  )
147
  )
148
 
149
  with open(log_filepath, "a", newline="") as csvfile:
150
+ writer = csv.writer(csvfile)
151
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
152
 
153
  with open(log_filepath, "r") as csvfile:
154
  line_count = len([None for row in csv.reader(csvfile)]) - 1
155
  return line_count
156
 
157
 
158
+ @document()
159
  class CSVLogger(FlaggingCallback):
160
  """
161
+ The default implementation of the FlaggingCallback abstract class. Each flagged
162
+ sample (both the input and output data) is logged to a CSV file with headers on the machine running the gradio app.
163
+ Example:
164
+ import gradio as gr
165
+ def image_classifier(inp):
166
+ return {'cat': 0.3, 'dog': 0.7}
167
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
168
+ flagging_callback=CSVLogger())
169
+ Guides: using_flagging
170
  """
171
 
172
+ def __init__(self):
173
+ pass
174
+
175
  def setup(
176
  self,
177
+ components: List[IOComponent],
178
  flagging_dir: str,
179
  encryption_key: Optional[str] = None,
180
  ):
 
196
 
197
  if flag_index is None:
198
  csv_data = []
199
+ for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
200
+ save_dir = os.path.join(
201
+ flagging_dir,
202
+ utils.strip_invalid_filename_characters(
203
+ component.label or f"component {idx}"
204
+ ),
 
 
 
 
205
  )
206
+ if utils.is_update(sample):
207
+ csv_data.append(str(sample))
208
+ else:
209
+ csv_data.append(
210
+ component.deserialize(
211
+ sample,
212
+ save_dir=save_dir,
213
+ encryption_key=self.encryption_key,
214
+ )
215
+ if sample is not None
216
+ else ""
217
+ )
218
  csv_data.append(flag_option if flag_option is not None else "")
219
  csv_data.append(username if username is not None else "")
220
  csv_data.append(str(datetime.datetime.now()))
221
  if is_new:
222
+ headers = [
223
+ component.label or f"component {idx}"
224
+ for idx, component in enumerate(self.components)
225
+ ] + [
226
  "flag",
227
  "username",
228
  "timestamp",
 
235
  flag_col_index = header.index("flag")
236
  content[flag_index][flag_col_index] = flag_option
237
  output = io.StringIO()
238
+ writer = csv.writer(output)
239
+ writer.writerows(utils.sanitize_list_for_csv(content))
240
  return output.getvalue()
241
 
242
  if self.encryption_key:
243
  output = io.StringIO()
244
  if not is_new:
245
+ with open(log_filepath, "rb", encoding="utf-8") as csvfile:
246
  encrypted_csv = csvfile.read()
247
  decrypted_csv = encryptor.decrypt(
248
  self.encryption_key, encrypted_csv
 
251
  if flag_index is not None:
252
  file_content = replace_flag_at_index(file_content)
253
  output.write(file_content)
254
+ writer = csv.writer(output)
255
  if flag_index is None:
256
  if is_new:
257
+ writer.writerow(utils.sanitize_list_for_csv(headers))
258
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
259
+ with open(log_filepath, "wb", encoding="utf-8") as csvfile:
260
  csvfile.write(
261
  encryptor.encrypt(self.encryption_key, output.getvalue().encode())
262
  )
263
  else:
264
  if flag_index is None:
265
+ with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
266
+ writer = csv.writer(csvfile)
 
 
267
  if is_new:
268
+ writer.writerow(utils.sanitize_list_for_csv(headers))
269
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
270
  else:
271
+ with open(log_filepath, encoding="utf-8") as csvfile:
272
  file_content = csvfile.read()
273
  file_content = replace_flag_at_index(file_content)
274
  with open(
275
+ log_filepath, "w", newline="", encoding="utf-8"
276
  ) as csvfile: # newline parameter needed for Windows
277
+ csvfile.write(utils.sanitize_list_for_csv(file_content))
278
+ with open(log_filepath, "r", encoding="utf-8") as csvfile:
279
  line_count = len([None for row in csv.reader(csvfile)]) - 1
280
  return line_count
281
 
282
 
283
+ @document()
284
  class HuggingFaceDatasetSaver(FlaggingCallback):
285
  """
286
+ A callback that saves each flagged sample (both the input and output data)
287
+ to a HuggingFace dataset.
288
+ Example:
289
+ import gradio as gr
290
+ hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
291
+ def image_classifier(inp):
292
+ return {'cat': 0.3, 'dog': 0.7}
293
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
294
+ allow_flagging="manual", flagging_callback=hf_writer)
295
+ Guides: using_flagging
296
  """
297
 
298
  def __init__(
299
  self,
300
+ hf_token: str,
301
  dataset_name: str,
302
  organization: Optional[str] = None,
303
  private: bool = False,
 
304
  ):
305
  """
306
+ Parameters:
307
+ hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset.
308
+ dataset_name: The name of the dataset to save the data to, e.g. "image-classifier-1"
309
+ organization: The organization to save the dataset under. The hf_token must provide write access to this organization. If not provided, saved under the name of the user corresponding to the hf_token.
310
+ private: Whether the dataset should be private (defaults to False).
 
 
 
 
 
 
311
  """
312
+ self.hf_token = hf_token
313
  self.dataset_name = dataset_name
314
  self.organization_name = organization
315
  self.dataset_private = private
 
316
 
317
+ def setup(self, components: List[IOComponent], flagging_dir: str):
318
  """
319
  Params:
320
  flagging_dir (str): local directory where the dataset is cloned,
 
328
  "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
329
  )
330
  path_to_dataset_repo = huggingface_hub.create_repo(
331
+ name=self.dataset_name,
332
+ token=self.hf_token,
 
333
  private=self.dataset_private,
334
  repo_type="dataset",
335
  exist_ok=True,
 
341
  self.repo = huggingface_hub.Repository(
342
  local_dir=self.dataset_dir,
343
  clone_from=path_to_dataset_repo,
344
+ use_auth_token=self.hf_token,
345
  )
346
+ self.repo.git_pull(lfs=True)
347
 
348
  # Should filename be user-specified?
349
  self.log_file = os.path.join(self.dataset_dir, "data.csv")
 
356
  flag_index: Optional[int] = None,
357
  username: Optional[str] = None,
358
  ) -> int:
359
+ self.repo.git_pull(lfs=True)
360
+
361
  is_new = not os.path.exists(self.log_file)
 
362
 
363
+ with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile:
364
  writer = csv.writer(csvfile)
365
 
366
  # File previews for certain input and output types
367
+ infos, file_preview_types, headers = _get_dataset_features_info(
368
+ is_new, self.components
369
+ )
 
 
 
370
 
371
  # Generate the headers and dataset_infos
372
  if is_new:
373
+ writer.writerow(utils.sanitize_list_for_csv(headers))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
  # Generate the row corresponding to the flagged sample
376
  csv_data = []
377
  for component, sample in zip(self.components, flag_data):
378
+ save_dir = os.path.join(
379
+ self.dataset_dir,
380
+ utils.strip_invalid_filename_characters(component.label),
381
  )
382
+ filepath = component.deserialize(sample, save_dir, None)
383
  csv_data.append(filepath)
384
  if isinstance(component, tuple(file_preview_types)):
385
  csv_data.append(
386
  "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
387
  )
388
  csv_data.append(flag_option if flag_option is not None else "")
389
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
390
 
391
  if is_new:
392
  json.dump(infos, open(self.infos_file, "w"))
393
 
394
+ with open(self.log_file, "r", encoding="utf-8") as csvfile:
395
  line_count = len([None for row in csv.reader(csvfile)]) - 1
396
 
397
  self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))
398
 
399
+ return line_count
400
+
401
+
402
class HuggingFaceDatasetJSONSaver(FlaggingCallback):
    """
    A FlaggingCallback that saves flagged data to a Hugging Face dataset in JSONL format.
    Each data sample is saved in a different JSONL file,
    allowing multiple users to use flagging simultaneously.
    Saving to a single CSV would cause errors as only one user can edit at the same time.
    """

    def __init__(
        self,
        hf_foken: str,
        dataset_name: str,
        organization: Optional[str] = None,
        private: bool = False,
        verbose: bool = True,
    ):
        """
        Params:
        hf_foken (str): The token to use to access the huggingface API. The
            parameter keeps its historical spelling ("foken") so that existing
            keyword callers continue to work.
        dataset_name (str): The name of the dataset to save the data to, e.g.
            "image-classifier-1"
        organization (str): The name of the organization to which to attach
            the datasets. If None, the dataset attaches to the user only.
        private (bool): If the dataset does not already exist, whether it
            should be created as a private dataset or public. Private datasets
            may require paid huggingface.co accounts
        verbose (bool): Whether to print out the status of the dataset
            creation.
        """
        self.hf_foken = hf_foken
        self.dataset_name = dataset_name
        self.organization_name = organization
        self.dataset_private = private
        self.verbose = verbose

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """
        Create (or reuse) the dataset repo, clone it locally and pull the latest state.

        Params:
        components List[Component]: list of components for flagging
        flagging_dir (str): local directory where the dataset is cloned,
            updated, and pushed from.
        Raises:
        ImportError: if the `huggingface_hub` package is not installed.
        """
        try:
            import huggingface_hub
        except ImportError as err:
            raise ImportError(
                "Package `huggingface_hub` not found is needed "
                "for HuggingFaceDatasetJSONSaver. Try 'pip install huggingface_hub'."
            ) from err
        path_to_dataset_repo = huggingface_hub.create_repo(
            name=self.dataset_name,
            token=self.hf_foken,
            private=self.dataset_private,
            repo_type="dataset",
            exist_ok=True,
        )
        self.path_to_dataset_repo = path_to_dataset_repo  # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
        self.components = components
        self.flagging_dir = flagging_dir
        self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
        self.repo = huggingface_hub.Repository(
            local_dir=self.dataset_dir,
            clone_from=path_to_dataset_repo,
            use_auth_token=self.hf_foken,
        )
        self.repo.git_pull(lfs=True)

        self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> str:
        """
        Save one flagged sample into its own uniquely named folder and push it to the hub.

        Params:
        flag_data (List[Any]): one value per component passed to setup().
        flag_option (str): the flag label to record; stored as "" when None.
        flag_index (int): accepted for interface compatibility; not used here.
        username (str): accepted for interface compatibility; not used here.
        Returns:
        (str) the unique folder name generated for this sample.
        """
        self.repo.git_pull(lfs=True)

        # Generate unique folder for the flagged sample
        unique_name = self.get_unique_name()
        folder_name = os.path.join(self.dataset_dir, unique_name)
        os.makedirs(folder_name)

        # The existence of `dataset_infos.json` determines whether the dataset is new
        is_new = not os.path.exists(self.infos_file)

        # File previews for certain input and output types
        infos, file_preview_types, _ = _get_dataset_features_info(
            is_new, self.components
        )
        preview_classes = tuple(file_preview_types)

        # Generate the row and header corresponding to the flagged sample
        csv_data = []
        headers = []

        for component, sample in zip(self.components, flag_data):
            headers.append(component.label)

            try:
                # Same save call as the other loggers in this file
                # (`deserialize` into a per-component directory) instead of the
                # older `save_flagged`, whose absence would otherwise be hidden
                # by the best-effort handler below.
                save_dir = os.path.join(
                    folder_name,
                    utils.strip_invalid_filename_characters(component.label),
                )
                filepath = component.deserialize(sample, save_dir, None)
            except Exception:
                # Best effort: the sample could not be saved (mostly because it
                # was None and the component does not handle None), so record
                # an empty entry rather than failing the whole flag operation.
                filepath = None

            if isinstance(component, preview_classes):
                headers.append(component.label + " file")
                # NOTE(review): the URL is appended before the plain path, so in
                # the metadata below the URL is keyed by `<label>` and the path
                # by `<label> file`; the CSV-based saver pairs them the other
                # way round - confirm which is intended.
                csv_data.append(
                    "{}/resolve/main/{}/{}".format(
                        self.path_to_dataset_repo, unique_name, filepath
                    )
                    if filepath is not None
                    else None
                )

            csv_data.append(filepath)
        headers.append("flag")
        csv_data.append(flag_option if flag_option is not None else "")

        # Creates metadata dict from row data and dumps it
        metadata_dict = dict(zip(headers, csv_data))
        self.dump_json(metadata_dict, os.path.join(folder_name, "metadata.jsonl"))

        if is_new:
            # Context manager so the handle is flushed and closed before pushing.
            with open(self.infos_file, "w") as infos_fp:
                json.dump(infos, infos_fp)

        self.repo.push_to_hub(commit_message="Flagged sample {}".format(unique_name))
        return unique_name

    def get_unique_name(self):
        """Return a fresh random UUID (version 4) as a string."""
        return str(uuid.uuid4())

    def dump_json(self, thing: dict, file_path: str) -> None:
        """Serialize `thing` as JSON into `file_path`, overwriting any existing file."""
        with open(file_path, "w+", encoding="utf8") as f:
            json.dump(thing, f)