File size: 3,294 Bytes
4c94b0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import webdataset as wds
import soundfile as sf
import io
import os
import random
import copy
from tqdm import tqdm
import shutil
import argparse
import traceback
import logging
import json
from open_clip import tokenize


def parse_args():
    """Build and parse the command-line arguments for the tar-check script.

    Returns:
        argparse.Namespace with ``tar_path`` (str|None), ``start`` (int),
        ``end`` (int), ``exclude`` (list[int]|None), ``batch_size`` (int)
        and ``order`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tar-path",
        type=str,
        default=None,
        help="Path to the tars",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="start from tar-path + start",
    )
    parser.add_argument(
        "--end",
        type=int,
        default=99999,
        help="end with tar-path + end",
    )
    parser.add_argument(
        "--exclude",
        nargs='+',
        # Shard ids are compared against ints from range(start, end); parsing
        # them as strings made idx_list.remove() raise ValueError downstream.
        type=int,
        default=None,
        help="exclude tar-path + exclude",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--order",
        default=False,
        action='store_true',
        help="if keep the search order accendingly",
    )
    args = parser.parse_args()
    return args

def log_and_continue(exn):
    """Warn about a webdataset error and signal the pipeline to keep going.

    Intended as a webdataset ``handler=`` callback: returning True tells the
    pipeline to skip the failing sample instead of aborting.
    """
    # Lazy %-formatting defers the repr() work to the logging machinery.
    logging.warning("Handling webdataset error (%r). Ignoring.", exn)
    return True

def preprocess(
    sample,
):
    """
    Preprocess a single sample for wdsdataloader.

    Expects ``sample["flac"]`` to hold encoded audio bytes and
    ``sample["json"]`` a UTF-8 JSON payload whose ``"text"`` field is either
    a string or a non-empty list of caption strings.

    Adds ``waveform`` (decoded audio), ``raw_text`` (the chosen caption) and
    ``text`` (its tokenization) to the sample and returns it.
    """
    audio_ext = "flac"
    text_ext = "json"
    audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
    json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
    sample["waveform"] = audio_data
    texts = json_dict_raw["text"]
    # Pick one caption when a list is supplied. Check non-emptiness BEFORE
    # indexing texts[0] (the original order raised IndexError on []), and
    # unwrap single-element lists too so raw_text is always a plain string.
    if isinstance(texts, list) and texts and isinstance(texts[0], str):
        texts = random.choice(texts)
    sample["raw_text"] = texts
    sample["text"] = tokenize(texts)
    return sample

if __name__ == "__main__":
    args = parse_args()
    tar_path = args.tar_path
    idx_list = list(range(args.start, args.end))
    if args.exclude is not None:
        # CLI values may arrive as strings; coerce to int so removal matches
        # the int shard ids, and skip ids outside [start, end) instead of
        # letting list.remove() raise ValueError.
        for x in args.exclude:
            idx = int(x)
            if idx in idx_list:
                idx_list.remove(idx)
    if not args.order:
        random.shuffle(idx_list)
    # Shards on AWS are streamed through a pipe (" -" suffix); everything else
    # is treated as a local path. The original only set args.local for the aws
    # case, crashing with AttributeError on local paths.
    args.local = "aws" not in tar_path
    if args.local:
        input_shards = [os.path.join(args.tar_path, str(i) + ".tar") for i in idx_list]
    else:
        input_shards = [os.path.join(args.tar_path, str(i) + ".tar -") for i in idx_list]
    pipeline = [wds.SimpleShardList(input_shards)]
    pipeline.extend(
        [
            wds.split_by_node,
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=log_and_continue),
            wds.map(preprocess),
            wds.to_tuple("__url__", "__key__", "waveform"),
            wds.batched(1),
        ]
    )
    dataset = wds.DataPipeline(*pipeline)
    dataloader = wds.WebLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
    old_k = 0
    old_batch = None
    try:
        # Iterate every batch; the last successfully-read (k, batch) pair is
        # kept so a failure can be localized to the shard right after it.
        for k, batch in tqdm(enumerate(dataloader)):
            print("k:", k)
            print("batch:", batch)
            old_k = k
            old_batch = copy.deepcopy(batch)
    except Exception:
        # Narrowed from a bare except: so KeyboardInterrupt/SystemExit still
        # propagate. Persist the traceback for offline inspection.
        with open("check_tar_log.txt", "a") as file:
            traceback.print_exc(file=file)
        print("old_k:", old_k)
        print("old_batch:", old_batch)