meg-huggingface committed on
Commit 0803ab3
1 Parent(s): a2ae370

Standardizing filenaming a bit.
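In short (a summary inferred from the diff below, not part of the original commit message): the cache-path attributes on DatasetStatisticsCacheClass move toward a suffix convention. Most JSON-serialized caches gain a _json_fid suffix, feather dataframe caches keep _df_fid, and on-disk datasets keep a bare _fid; a few JSON paths such as zipf_fid and fig_tok_length_fid keep their old names, and the label field is dropped from the cache directory name. A minimal sketch, using attribute names taken from the diff:

    self.dset_peek_json_fid = pjoin(self.cache_path, "dset_peek.json")        # JSON cache, renamed
    self.length_stats_json_fid = pjoin(self.cache_path, "length_stats.json")  # JSON cache, renamed
    self.dup_counts_df_fid = pjoin(self.cache_path, "dup_counts_df.feather")  # feather dataframe cache
    self.dset_fid = pjoin(self.cache_path, "base_dset")                       # dataset saved to disk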

data_measurements/dataset_statistics.py CHANGED
@@ -244,38 +244,61 @@ class DatasetStatisticsCacheClass:
         # path to the directory used for caching
         if not isinstance(text_field, str):
             text_field = "-".join(text_field)
-        if isinstance(label_field, str):
-            label_field = label_field
-        else:
-            label_field = "-".join(label_field)
+        #if isinstance(label_field, str):
+        #    label_field = label_field
+        #else:
+        #    label_field = "-".join(label_field)
         self.cache_path = pjoin(
             self.cache_dir,
-            f"{dset_name}_{dset_config}_{split_name}_{text_field}_{label_field}",
+            f"{dset_name}_{dset_config}_{split_name}_{text_field}", #{label_field},
         )
         if not isdir(self.cache_path):
             logs.warning("Creating cache directory %s." % self.cache_path)
             mkdir(self.cache_path)
+
+        # Cache files not needed for UI
         self.dset_fid = pjoin(self.cache_path, "base_dset")
-        self.dset_peek_fid = pjoin(self.cache_path, "dset_peek.json")
-        self.text_dset_fid = pjoin(self.cache_path, "text_dset")
         self.tokenized_df_fid = pjoin(self.cache_path, "tokenized_df.feather")
         self.label_dset_fid = pjoin(self.cache_path, "label_dset")
+
+        # Needed for UI -- embeddings
+        self.text_dset_fid = pjoin(self.cache_path, "text_dset")
+        # Needed for UI
+        self.dset_peek_json_fid = pjoin(self.cache_path, "dset_peek.json")
+
+        ## Label cache files.
+        # Needed for UI
+        self.fig_labels_json_fid = pjoin(self.cache_path, "fig_labels.json")
+
+        ## Length cache files
+        # Needed for UI
         self.length_df_fid = pjoin(self.cache_path, "length_df.feather")
-        self.length_stats_fid = pjoin(self.cache_path, "length_stats.json")
+        # Needed for UI
+        self.length_stats_json_fid = pjoin(self.cache_path, "length_stats.json")
         self.vocab_counts_df_fid = pjoin(self.cache_path, "vocab_counts.feather")
-        self.general_stats_fid = pjoin(self.cache_path, "general_stats_dict.json")
-        self.dup_counts_df_fid = pjoin(
-            self.cache_path, "dup_counts_df.feather"
-        )
+        # Needed for UI
+        self.dup_counts_df_fid = pjoin(self.cache_path, "dup_counts_df.feather")
+        # Needed for UI
+        self.fig_tok_length_fid = pjoin(self.cache_path, "fig_tok_length.json")
+
+        ## General text stats
+        # Needed for UI
+        self.general_stats_json_fid = pjoin(self.cache_path, "general_stats_dict.json")
+        # Needed for UI
         self.sorted_top_vocab_df_fid = pjoin(self.cache_path,
                                              "sorted_top_vocab.feather")
-        self.fig_tok_length_fid = pjoin(self.cache_path, "fig_tok_length.json")
-        self.fig_labels_fid = pjoin(self.cache_path, "fig_labels.json")
-        self.node_list_fid = pjoin(self.cache_path, "node_list.th")
-        self.fig_tree_fid = pjoin(self.cache_path, "fig_tree.json")
+        ## Zipf cache files
+        # Needed for UI
         self.zipf_fid = pjoin(self.cache_path, "zipf_basic_stats.json")
+        # Needed for UI
         self.zipf_fig_fid = pjoin(self.cache_path, "zipf_fig.json")
 
+        ## Embeddings cache files
+        # Needed for UI
+        self.node_list_fid = pjoin(self.cache_path, "node_list.th")
+        # Needed for UI
+        self.fig_tree_json_fid = pjoin(self.cache_path, "fig_tree.json")
+
     def get_base_dataset(self):
         """Gets a pointer to the truncated base dataset object."""
         if not self.dset:
@@ -301,7 +324,7 @@ class DatasetStatisticsCacheClass:
         # General statistics
         if (
             self.use_cache
-            and exists(self.general_stats_fid)
+            and exists(self.general_stats_json_fid)
             and exists(self.dup_counts_df_fid)
             and exists(self.sorted_top_vocab_df_fid)
         ):
@@ -313,7 +336,7 @@ class DatasetStatisticsCacheClass:
             if save:
                 write_df(self.sorted_top_vocab_df, self.sorted_top_vocab_df_fid)
                 write_df(self.dup_counts_df, self.dup_counts_df_fid)
-                write_json(self.general_stats_dict, self.general_stats_fid)
+                write_json(self.general_stats_dict, self.general_stats_json_fid)
 
 
     def load_or_prepare_text_lengths(self, save=True):
@@ -343,8 +366,8 @@ class DatasetStatisticsCacheClass:
                 write_df(self.length_df, self.length_df_fid)
 
         # Text length stats.
-        if self.use_cache and exists(self.length_stats_fid):
-            with open(self.length_stats_fid, "r") as f:
+        if self.use_cache and exists(self.length_stats_json_fid):
+            with open(self.length_stats_json_fid, "r") as f:
                 self.length_stats_dict = json.load(f)
             self.avg_length = self.length_stats_dict["avg length"]
             self.std_length = self.length_stats_dict["std length"]
@@ -352,7 +375,7 @@ class DatasetStatisticsCacheClass:
         else:
            self.prepare_text_length_stats()
            if save:
-                write_json(self.length_stats_dict, self.length_stats_fid)
+                write_json(self.length_stats_dict, self.length_stats_json_fid)
 
     def prepare_length_df(self):
         if self.tokenized_df is None:
@@ -382,15 +405,15 @@ class DatasetStatisticsCacheClass:
         self.fig_tok_length = make_fig_lengths(self.tokenized_df, LENGTH_FIELD)
 
     def load_or_prepare_embeddings(self, save=True):
-        if self.use_cache and exists(self.node_list_fid) and exists(self.fig_tree_fid):
+        if self.use_cache and exists(self.node_list_fid) and exists(self.fig_tree_json_fid):
             self.node_list = torch.load(self.node_list_fid)
-            self.fig_tree = read_plotly(self.fig_tree_fid)
+            self.fig_tree = read_plotly(self.fig_tree_json_fid)
         elif self.use_cache and exists(self.node_list_fid):
             self.node_list = torch.load(self.node_list_fid)
             self.fig_tree = make_tree_plot(self.node_list,
                                            self.text_dset)
             if save:
-                write_plotly(self.fig_tree, self.fig_tree_fid)
+                write_plotly(self.fig_tree, self.fig_tree_json_fid)
         else:
             self.embeddings = Embeddings(self, use_cache=self.use_cache)
             self.embeddings.make_hierarchical_clustering()
@@ -399,7 +422,7 @@ class DatasetStatisticsCacheClass:
                                            self.text_dset)
             if save:
                 torch.save(self.node_list, self.node_list_fid)
-                write_plotly(self.fig_tree, self.fig_tree_fid)
+                write_plotly(self.fig_tree, self.fig_tree_json_fid)
 
     # get vocab with word counts
     def load_or_prepare_vocab(self, save=True):
@@ -457,7 +480,7 @@ class DatasetStatisticsCacheClass:
                 write_df(self.dup_counts_df, self.dup_counts_df_fid)
 
     def load_general_stats(self):
-        self.general_stats_dict = json.load(open(self.general_stats_fid, encoding="utf-8"))
+        self.general_stats_dict = json.load(open(self.general_stats_json_fid, encoding="utf-8"))
         with open(self.sorted_top_vocab_df_fid, "rb") as f:
             self.sorted_top_vocab_df = feather.read_feather(f)
         self.text_nan_count = self.general_stats_dict[TEXT_NAN_CNT]
@@ -520,15 +543,15 @@ class DatasetStatisticsCacheClass:
         self.load_or_prepare_dset_peek(save)
 
     def load_or_prepare_dset_peek(self, save=True):
-        if self.use_cache and exists(self.dset_peek_fid):
-            with open(self.dset_peek_fid, "r") as f:
+        if self.use_cache and exists(self.dset_peek_json_fid):
+            with open(self.dset_peek_json_fid, "r") as f:
                 self.dset_peek = json.load(f)["dset peek"]
         else:
             if self.dset is None:
                 self.get_base_dataset()
             self.dset_peek = self.dset[:100]
             if save:
-                write_json({"dset peek": self.dset_peek}, self.dset_peek_fid)
+                write_json({"dset peek": self.dset_peek}, self.dset_peek_json_fid)
 
     def load_or_prepare_tokenized_df(self, save=True):
         if (self.use_cache and exists(self.tokenized_df_fid)):
@@ -611,8 +634,8 @@ class DatasetStatisticsCacheClass:
         """
         # extracted labels
         if len(self.label_field) > 0:
-            if self.use_cache and exists(self.fig_labels_fid):
-                self.fig_labels = read_plotly(self.fig_labels_fid)
+            if self.use_cache and exists(self.fig_labels_json_fid):
+                self.fig_labels = read_plotly(self.fig_labels_json_fid)
             elif self.use_cache and exists(self.label_dset_fid):
                 # load extracted labels
                 self.label_dset = load_from_disk(self.label_dset_fid)
@@ -621,13 +644,13 @@ class DatasetStatisticsCacheClass:
                     self.label_df, self.label_names, OUR_LABEL_FIELD
                 )
                 if save:
-                    write_plotly(self.fig_labels, self.fig_labels_fid)
+                    write_plotly(self.fig_labels, self.fig_labels_json_fid)
             else:
                 self.prepare_labels()
                 if save:
                     # save extracted label instances
                     self.label_dset.save_to_disk(self.label_dset_fid)
-                    write_plotly(self.fig_labels, self.fig_labels_fid)
+                    write_plotly(self.fig_labels, self.fig_labels_json_fid)
 
     def prepare_labels(self):
         self.get_base_dataset()
run_data_measurements.py ADDED
@@ -0,0 +1,279 @@
+import argparse
+import json
+import textwrap
+from os.path import join as pjoin
+
+from data_measurements import dataset_statistics
+from data_measurements import dataset_utils
+
+
+def load_or_prepare_widgets(ds_args, show_embeddings=False, use_cache=False):
+    """
+    Loader specifically for the widgets used in the app.
+    Args:
+        ds_args:
+        show_embeddings:
+        use_cache:
+
+    Returns:
+
+    """
+    dstats = dataset_statistics.DatasetStatisticsCacheClass(**ds_args,
+                                                            use_cache=use_cache)
+    # Header widget
+    dstats.load_or_prepare_dset_peek()
+    # General stats widget
+    dstats.load_or_prepare_general_stats()
+    # Labels widget
+    dstats.load_or_prepare_labels()
+    # Text lengths widget
+    dstats.load_or_prepare_text_lengths()
+    if show_embeddings:
+        # Embeddings widget
+        dstats.load_or_prepare_embeddings()
+    # Text duplicates widget
+    dstats.load_or_prepare_text_duplicates()
+    # nPMI widget
+    dstats.load_or_prepare_npmi()
+    npmi_stats = dstats.npmi_stats
+    # Handling for all pairs; in the UI, people select.
+    do_npmi(npmi_stats)
+    # Zipf widget
+    dstats.load_or_prepare_zipf()
+
+
+def load_or_prepare(dataset_args, do_html=False, use_cache=False):
+    all = False
+    dstats = dataset_statistics.DatasetStatisticsCacheClass(**dataset_args, use_cache=use_cache)
+    print("Loading dataset.")
+    dstats.load_or_prepare_dataset()
+    print("Dataset loaded. Preparing vocab.")
+    dstats.load_or_prepare_vocab()
+    print("Vocab prepared.")
+
+    if not dataset_args["calculation"]:
+        all = True
+
+    if all or dataset_args["calculation"] == "general":
+        print("\n* Calculating general statistics.")
+        dstats.load_or_prepare_general_stats()
+        print("Done!")
+        print("Basic text statistics now available at %s." % dstats.general_stats_json_fid)
+        print(
+            "Text duplicates now available at %s." % dstats.dup_counts_df_fid
+        )
+
+    if all or dataset_args["calculation"] == "lengths":
+        print("\n* Calculating text lengths.")
+        fig_tok_length_fid = pjoin(dstats.cache_path, "lengths_fig.html")
+        tok_length_json_fid = pjoin(dstats.cache_path, "lengths.json")
+        dstats.load_or_prepare_text_lengths()
+        with open(tok_length_json_fid, "w+") as f:
+            json.dump(dstats.fig_tok_length.to_json(), f)
+        print("Token lengths now available at %s." % tok_length_json_fid)
+        if do_html:
+            dstats.fig_tok_length.write_html(fig_tok_length_fid)
+            print("Figure saved to %s." % fig_tok_length_fid)
+        print("Done!")
+
+    if (all and dstats.label_field) or dataset_args["calculation"] == "labels":
+        if not dstats.label_field:
+            print("Warning: You asked for label calculation, but didn't provide the labels field name. Assuming it is 'label'...")
+            dstats.set_label_field("label")
+        print("\n* Calculating label distribution.")
+        dstats.load_or_prepare_labels()
+        fig_label_html = pjoin(dstats.cache_path, "labels_fig.html")
+        fig_label_json = pjoin(dstats.cache_path, "labels.json")
+        dstats.fig_labels.write_html(fig_label_html)
+        with open(fig_label_json, "w+") as f:
+            json.dump(dstats.fig_labels.to_json(), f)
+        print("Done!")
+        print("Label distribution now available at %s." % dstats.label_dset_fid)
+        print("Figure saved to %s." % fig_label_html)
+
+    if all or dataset_args["calculation"] == "npmi":
+        print("\n* Preparing nPMI.")
+        npmi_stats = dataset_statistics.nPMIStatisticsCacheClass(
+            dstats, use_cache=use_cache
+        )
+        do_npmi(npmi_stats, use_cache=use_cache)
+        print("Done!")
+        print(
+            "nPMI results now available in %s for all identity terms that "
+            "occur more than 10 times and all words that "
+            "co-occur with both terms."
+            % npmi_stats.pmi_cache_path
+        )
+
+    if all or dataset_args["calculation"] == "zipf":
+        print("\n* Preparing Zipf.")
+        zipf_fig_fid = pjoin(dstats.cache_path, "zipf_fig.html")
+        zipf_json_fid = pjoin(dstats.cache_path, "zipf_fig.json")
+        dstats.load_or_prepare_zipf()
+        zipf_fig = dstats.zipf_fig
+        with open(zipf_json_fid, "w+") as f:
+            json.dump(zipf_fig.to_json(), f)
+        zipf_fig.write_html(zipf_fig_fid)
+        print("Done!")
+        print("Zipf results now available at %s." % dstats.zipf_fid)
+        print(
+            "Figure saved to %s, with corresponding json at %s."
+            % (zipf_fig_fid, zipf_json_fid)
+        )
+
+    # Don't do this one until someone specifically asks for it -- takes awhile.
+    if dataset_args["calculation"] == "embeddings":
+        print("\n* Preparing text embeddings.")
+        dstats.load_or_prepare_embeddings()
+
+
+def do_npmi(npmi_stats, use_cache=True):
+    available_terms = npmi_stats.load_or_prepare_npmi_terms()
+    completed_pairs = {}
+    print("Iterating through terms for joint npmi.")
+    for term1 in available_terms:
+        for term2 in available_terms:
+            if term1 != term2:
+                sorted_terms = tuple(sorted([term1, term2]))
+                if sorted_terms not in completed_pairs:
+                    term1, term2 = sorted_terms
+                    print("Computing nPMI statistics for %s and %s" % (term1, term2))
+                    _ = npmi_stats.load_or_prepare_joint_npmi(sorted_terms)
+                    completed_pairs[tuple(sorted_terms)] = {}
+
+
+def get_text_label_df(
+    ds_name,
+    config_name,
+    split_name,
+    text_field,
+    label_field,
+    calculation,
+    out_dir,
+    do_html=False,
+    use_cache=True,
+):
+    if not use_cache:
+        print("Not using any cache; starting afresh")
+    ds_name_to_dict = dataset_utils.get_dataset_info_dicts(ds_name)
+    if label_field:
+        label_field, label_names = (
+            ds_name_to_dict[ds_name][config_name]["features"][label_field][0]
+            if len(ds_name_to_dict[ds_name][config_name]["features"][label_field]) > 0
+            else ((), [])
+        )
+    else:
+        label_field = ()
+        label_names = []
+    dataset_args = {
+        "dset_name": ds_name,
+        "dset_config": config_name,
+        "split_name": split_name,
+        "text_field": text_field,
+        "label_field": label_field,
+        "label_names": label_names,
+        "calculation": calculation,
+        "cache_dir": out_dir,
+    }
+    load_or_prepare_widgets(dataset_args, use_cache=use_cache)
+
+
+def main():
+    # TODO: Make this the Hugging Face arg parser
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        description=textwrap.dedent(
+            """
+
+         Example for hate speech18 dataset:
+         python3 run_data_measurements.py --dataset="hate_speech18" --config="default" --split="train" --feature="text"
+
+         Example for Glue dataset:
+         python3 run_data_measurements.py --dataset="glue" --config="ax" --split="train" --feature="premise"
+
+         Example for IMDB dataset:
+         python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="train" --label_field="label" --feature="text"
+        """
+        ),
+    )
+
+    parser.add_argument(
+        "-d", "--dataset", required=True, help="Name of dataset to prepare"
+    )
+    parser.add_argument(
+        "-c", "--config", required=True, help="Dataset configuration to prepare"
+    )
+    parser.add_argument(
+        "-s", "--split", required=True, type=str, help="Dataset split to prepare"
+    )
+    parser.add_argument(
+        "-f",
+        "--feature",
+        required=True,
+        type=str,
+        default="text",
+        help="Text column to prepare",
+    )
+    parser.add_argument(
+        "-w",
+        "--calculation",
+        help="""What to calculate (defaults to everything except embeddings).\n
+        Options are:\n
+
+        - `general` (for duplicate counts, missing values, length statistics.)\n
+
+        - `lengths` for text length distribution\n
+
+        - `labels` for label distribution\n
+
+        - `embeddings` (Warning: Slow.)\n
+
+        - `npmi` for word associations\n
+
+        - `zipf` for zipfian statistics
+        """,
+    )
+    parser.add_argument(
+        "-l",
+        "--label_field",
+        type=str,
+        required=False,
+        default="",
+        help="Field name for label column in dataset (Required if there is a label field that you want information about)",
+    )
+    parser.add_argument(
+        "--cached",
+        default=False,
+        required=False,
+        action="store_true",
+        help="Whether to use cached files (Optional)",
+    )
+    parser.add_argument(
+        "--do_html",
+        default=False,
+        required=False,
+        action="store_true",
+        help="Whether to write out corresponding HTML files (Optional)",
+    )
+    parser.add_argument("--out_dir", default="cache_dir", help="Where to write out to.")
+
+    args = parser.parse_args()
+    print("Proceeding with the following arguments:")
+    print(args)
+    # run_data_measurements.py -n hate_speech18 -c default -s train -f text -w npmi
+    get_text_label_df(
+        args.dataset,
+        args.config,
+        args.split,
+        args.feature,
+        args.label_field,
+        args.calculation,
+        args.out_dir,
+        do_html=args.do_html,
+        use_cache=args.cached,
+    )
+    print()
+
+
+if __name__ == "__main__":
+    main()
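For reference, a sketch of how the added script might be invoked once these files are in place (the flag names come from the argument parser above; the dataset, config, and output directory values are illustrative, not taken from the commit):

    python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="train" --feature="text" --label_field="label" --calculation="labels" --out_dir="cache_dir" --cached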