import argparse
import json
import textwrap
from os import mkdir
from os.path import join as pjoin, isdir

from data_measurements import dataset_statistics
from data_measurements import dataset_utils


def load_or_prepare_widgets(ds_args, show_embeddings=False, use_cache=False):
    """
    Loader specifically for the widgets used in the app -- does not compute
    intermediate files, unless they are missing and are needed for a file
    used in the UI.
    Does not take specifications from the user; prepares all widgets.
    Args:
        ds_args: Dataset configuration settings (config name, split, etc.)
        show_embeddings: Whether to compute embeddings (slow)
        use_cache: Whether to reuse files that have already been computed
    Returns:
        None. Saves files to disk in cache_dir, if the user has not
        specified another directory.
    """
    if not isdir(ds_args["cache_dir"]):
        print("Creating cache")
        # We need to preprocess everything.
        # This should eventually all go into a prepare_dataset CLI
        mkdir(ds_args["cache_dir"])
    dstats = dataset_statistics.DatasetStatisticsCacheClass(**ds_args,
                                                            use_cache=use_cache)
    # Dataset preparation (also needed by the embeddings widget)
    dstats.load_or_prepare_dataset()
    # Header widget
    dstats.load_or_prepare_dset_peek()
    # General stats widget
    dstats.load_or_prepare_general_stats()
    # Labels widget
    try:
        dstats.set_label_field(ds_args["label_field"])
        dstats.load_or_prepare_labels()
    except Exception:
        # Labels may be missing or unreadable for this dataset; the widget
        # is simply skipped in that case.
        pass
    # Text lengths widget
    dstats.load_or_prepare_text_lengths()
    if show_embeddings:
        # Embeddings widget
        dstats.load_or_prepare_embeddings()
    # Text duplicates widget
    dstats.load_or_prepare_text_duplicates()
    # nPMI widget
    dstats.load_or_prepare_npmi()
    npmi_stats = dstats.npmi_stats
    # Handle all term pairs; in the UI, people select pairs themselves.
    do_npmi(npmi_stats)
    # Zipf widget
    dstats.load_or_prepare_zipf()
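
# Example (hypothetical values): how an app entry point might call the
# widget loader. The dataset, config, and field names here are purely
# illustrative; the dict mirrors the dataset_args assembled in
# get_text_label_df() below.
#
#   ds_args = {
#       "dset_name": "imdb",
#       "dset_config": "plain_text",
#       "split_name": "train",
#       "text_field": "text",
#       "label_field": "label",
#       "label_names": ["neg", "pos"],
#       "calculation": None,
#       "cache_dir": "cache_dir",
#   }
#   load_or_prepare_widgets(ds_args, show_embeddings=False, use_cache=True)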


def load_or_prepare(dataset_args, use_cache=False):
    """
    Users can specify which aspects of the dataset they would like to compute.
    This additionally computes intermediate files not used in the UI.
    If the calculation flag (-w) is not specified by the user, calculates
    everything except embeddings, as those are quite time-consuming and so
    must be requested separately.
    Args:
        dataset_args: Dataset configuration settings (config name, split, etc.)
        use_cache: Whether to reuse files that have already been computed
    Returns:
        None. Saves files to disk in cache_dir, if the user has not
        specified another directory.
    """
    # Named do_all rather than all so as not to shadow the builtin.
    do_all = False
    dstats = dataset_statistics.DatasetStatisticsCacheClass(**dataset_args,
                                                            use_cache=use_cache)
    print("Loading dataset.")
    dstats.load_or_prepare_dataset()
    print("Dataset loaded. Preparing vocab.")
    dstats.load_or_prepare_vocab()
    print("Vocab prepared.")
    if not dataset_args["calculation"]:
        do_all = True
    if do_all or dataset_args["calculation"] == "general":
        print("\n* Calculating general statistics.")
        dstats.load_or_prepare_general_stats()
        print("Done!")
        print("Basic text statistics now available at %s." %
              dstats.general_stats_json_fid)
        print(
            "Text duplicates now available at %s." % dstats.dup_counts_df_fid
        )
    if do_all or dataset_args["calculation"] == "lengths":
        print("\n* Calculating text lengths.")
        dstats.load_or_prepare_text_lengths()
        print("Done!")
    if do_all or dataset_args["calculation"] == "labels":
        if not dstats.label_field:
            print("Warning: You asked for label calculation, but didn't "
                  "provide the label field name. Assuming it is 'label'...")
            dstats.set_label_field("label")
        else:
            print("\n* Calculating label distribution.")
        # Compute labels whether the field was given or defaulted above.
        dstats.load_or_prepare_labels()
        fig_label_html = pjoin(dstats.cache_path, "labels_fig.html")
        fig_label_json = pjoin(dstats.cache_path, "labels.json")
        dstats.fig_labels.write_html(fig_label_html)
        with open(fig_label_json, "w+") as f:
            json.dump(dstats.fig_labels.to_json(), f)
        print("Done!")
        print("Label distribution now available at %s." %
              dstats.label_dset_fid)
        print("Figure saved to %s." % fig_label_html)
    if do_all or dataset_args["calculation"] == "npmi":
        print("\n* Preparing nPMI.")
        npmi_stats = dataset_statistics.nPMIStatisticsCacheClass(
            dstats, use_cache=use_cache
        )
        do_npmi(npmi_stats)
        print("Done!")
        print(
            "nPMI results now available in %s for all identity terms that "
            "occur more than 10 times and all words that "
            "co-occur with both terms."
            % npmi_stats.pmi_cache_path
        )
    if do_all or dataset_args["calculation"] == "zipf":
        print("\n* Preparing Zipf.")
        zipf_fig_fid = pjoin(dstats.cache_path, "zipf_fig.html")
        zipf_json_fid = pjoin(dstats.cache_path, "zipf_fig.json")
        dstats.load_or_prepare_zipf()
        zipf_fig = dstats.zipf_fig
        with open(zipf_json_fid, "w+") as f:
            json.dump(zipf_fig.to_json(), f)
        zipf_fig.write_html(zipf_fig_fid)
        print("Done!")
        print("Zipf results now available at %s." % dstats.zipf_fid)
        print(
            "Figure saved to %s, with corresponding json at %s."
            % (zipf_fig_fid, zipf_json_fid)
        )
    # Don't do this one until someone specifically asks for it -- takes a while.
    if dataset_args["calculation"] == "embeddings":
        print("\n* Preparing text embeddings.")
        dstats.load_or_prepare_embeddings()
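
# Example (hypothetical values): computing a single aspect rather than
# everything. Reusing the illustrative ds_args sketch above, setting the
# "calculation" key selects one branch of load_or_prepare():
#
#   dataset_args = dict(ds_args, calculation="npmi")
#   load_or_prepare(dataset_args, use_cache=True)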


def do_npmi(npmi_stats):
    """Computes joint nPMI for each unordered pair of available terms."""
    available_terms = npmi_stats.load_or_prepare_npmi_terms()
    completed_pairs = {}
    print("Iterating through terms for joint npmi.")
    for term1 in available_terms:
        for term2 in available_terms:
            if term1 != term2:
                # Track pairs by their sorted form so each unordered pair is
                # computed only once. (Note: the loop variables themselves are
                # left untouched so the iteration is not disturbed.)
                sorted_terms = tuple(sorted([term1, term2]))
                if sorted_terms not in completed_pairs:
                    print("Computing nPMI statistics for %s and %s" % sorted_terms)
                    _ = npmi_stats.load_or_prepare_joint_npmi(sorted_terms)
                    completed_pairs[sorted_terms] = {}
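
# Note: the nested loop above visits each unordered pair exactly once. An
# equivalent sketch using the standard library (same assumed npmi_stats API)
# would be:
#
#   from itertools import combinations
#   for pair in combinations(sorted(available_terms), 2):
#       npmi_stats.load_or_prepare_joint_npmi(pair)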


def get_text_label_df(
    ds_name,
    config_name,
    split_name,
    text_field,
    label_field,
    calculation,
    out_dir,
    use_cache=True,
):
    """Assembles the dataset_args dict from the CLI arguments, resolving the
    label field and label names from the dataset info, then runs
    load_or_prepare()."""
    if not use_cache:
        print("Not using any cache; starting afresh.")
    ds_name_to_dict = dataset_utils.get_dataset_info_dicts(ds_name)
    if label_field:
        label_field, label_names = (
            ds_name_to_dict[ds_name][config_name]["features"][label_field][0]
            if len(ds_name_to_dict[ds_name][config_name]["features"][label_field]) > 0
            else ((), [])
        )
    else:
        label_field = ()
        label_names = []
    dataset_args = {
        "dset_name": ds_name,
        "dset_config": config_name,
        "split_name": split_name,
        "text_field": text_field,
        "label_field": label_field,
        "label_names": label_names,
        "calculation": calculation,
        "cache_dir": out_dir,
    }
    load_or_prepare(dataset_args, use_cache=use_cache)
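
# Example (hypothetical values): the equivalent of the IMDB invocation from
# the --help text below, called directly from Python:
#
#   get_text_label_df("imdb", "plain_text", "train", "text",
#                     "label", None, "cache_dir", use_cache=True)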


def main():
    # TODO: Make this the Hugging Face arg parser
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=textwrap.dedent(
            """
            Example for the hate_speech18 dataset:
            python3 run_data_measurements.py --dataset="hate_speech18" --config="default" --split="train" --feature="text"

            Example for the IMDB dataset:
            python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="train" --label_field="label" --feature="text"
            """
        ),
    )
    parser.add_argument(
        "-d", "--dataset", required=True, help="Name of dataset to prepare"
    )
    parser.add_argument(
        "-c", "--config", required=True, help="Dataset configuration to prepare"
    )
    parser.add_argument(
        "-s", "--split", required=True, type=str, help="Dataset split to prepare"
    )
    parser.add_argument(
        "-f",
        "--feature",
        required=True,
        type=str,
        help="Text column to prepare",
    )
    parser.add_argument(
        "-w",
        "--calculation",
        help="""What to calculate (defaults to everything except embeddings).
        Options are:
        - `general` for duplicate counts, missing values, length statistics
        - `lengths` for text length distribution
        - `labels` for label distribution
        - `embeddings` (Warning: slow.)
        - `npmi` for word associations
        - `zipf` for Zipfian statistics
        """,
    )
    parser.add_argument(
        "-l",
        "--label_field",
        type=str,
        required=False,
        default="",
        help="Field name for the label column in the dataset (required if "
             "there is a label field that you want information about)",
    )
    parser.add_argument(
        "--cached",
        default=False,
        required=False,
        action="store_true",
        help="Whether to use cached files (optional)",
    )
    parser.add_argument(
        "--do_html",
        default=False,
        required=False,
        action="store_true",
        help="Whether to write out corresponding HTML files (optional)",
    )
    parser.add_argument("--out_dir", default="cache_dir",
                        help="Directory to write output to.")
    args = parser.parse_args()
    print("Proceeding with the following arguments:")
    print(args)
    # Example: run_data_measurements.py -d hate_speech18 -c default -s train -f text -w npmi
    get_text_label_df(args.dataset, args.config, args.split, args.feature,
                      args.label_field, args.calculation, args.out_dir,
                      use_cache=args.cached)
    print()


if __name__ == "__main__":
    main()