import logging
import random

import webdataset as wds
from webdataset.tariterators import group_by_keys, tar_file_expander, url_opener

from m4.training.types import DatasetTypes

meta_prefix = "__"
meta_suffix = "__"

logger = logging.getLogger(__name__)
trace = False


def webdoc_valid_sample(sample):
    """Check whether a sample is valid.

    :param sample: sample to be checked
    """
    return (
        sample is not None
        and isinstance(sample, dict)
        and len(list(sample.keys())) > 0
        and not sample.get("__bad__", False)
        and sample_has_all_files(sample)
    )


def sample_has_all_files(current_sample):
    meta = current_sample.get("metadata.value", None)
    if meta is None:
        return False
    meta = meta.decode("utf-8")
    if len(meta) == 0:
        return False
    target_file_list = meta.split("\n")
    fname_keys = [key for key in current_sample.keys() if key.endswith(".fname")]
    fnames = [current_sample[key] for key in fname_keys]
    check = all([fname in fnames for fname in target_file_list])
    if not check:
        return False
    return True


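# Illustrative (hypothetical) layout assumed by sample_has_all_files: the "metadata.value"
# entry lists, one file name per line, every member the webdocument is expected to contain,
# and each "<suffix>.fname" entry records the tar member name actually seen for that suffix.
# For example:
#
#   current_sample = {
#       "metadata.value": b"doc0001.0.text.txt\ndoc0001.1.image.jpg",
#       "0.fname": "doc0001.0.text.txt",
#       "1.fname": "doc0001.1.image.jpg",
#       ...
#   }
#
# The function returns True only if every name in the metadata listing appears among the
# recorded fnames. The concrete names above are invented; only the
# "<prefix>.metadata.txt" / "<prefix>.<idx>.<type>.<ext>" pattern is implied by the parsing
# code below.

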
class ImageDecoder:
    def __call__(self, bytes_):
        import io

        import PIL.Image

        img = PIL.Image.open(io.BytesIO(bytes_))
        img.load()
        return img


# Taken from https://github.com/mlfoundations/open_clip/blob/c48111dacac55db24878af229d8a5662c03e6f1c/src/training/data.py#L180-L183
def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
    return True


# Adapt group_by_keys to our webdocument format in which each sample contains several text and image files
# https://github.com/webdataset/webdataset/blob/039d74319ae55e5696dcef89829be9671802cf70/webdataset/tariterators.py#L195-L250
def group_by_keys_interleaved(data, handler=log_and_continue):
    """Return function over iterator that groups key, value pairs into samples."""
    current_sample = None
    for filesample in data:
        try:
            assert isinstance(filesample, dict)
            fname, value = filesample["fname"], filesample["data"]
            fname = fname.strip("./")
            if fname.endswith(".metadata.txt"):
                prefix, data_type, extension = fname.split(".")
                suffix = data_type
                idx = None  # metadata files carry no interleaving index
            else:
                prefix, idx, data_type, extension = fname.split(".")
                if data_type not in ["text", "image"]:
                    raise ValueError(f"{fname}: unknown data type {data_type}")
                suffix = idx
            if trace:
                print(
                    f"prefix: {prefix}, idx: {idx}, data_type: {data_type}, extension: {extension}, keys:"
                    f" {current_sample.keys() if isinstance(current_sample, dict) else None}"
                )
            if prefix is None:
                continue
            if current_sample is None or prefix != current_sample["__key__"]:
                valid = webdoc_valid_sample(current_sample)
                if valid:
                    yield current_sample
                elif current_sample is not None:
                    logging.warning(f"{fname}: invalid sample {current_sample} ignored")
                current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
            if f"{suffix}.value" in current_sample:  # check the stored ".value" key, not the bare suffix
                raise ValueError(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
            current_sample[f"{suffix}.value"] = value
            current_sample[f"{suffix}.type"] = data_type
            current_sample[f"{suffix}.fname"] = fname
        except Exception as exn:
            exn.args = exn.args + (filesample.get("stream"), filesample.get("url"))
            if handler(exn):
                continue
            else:
                break
    if webdoc_valid_sample(current_sample):
        yield current_sample


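# Illustrative (hypothetical) walk-through: given tar members named
#
#   doc0001.0.text.txt, doc0001.1.image.jpg, doc0001.2.text.txt, doc0001.metadata.txt
#
# group_by_keys_interleaved yields a single sample of the form
#
#   {
#       "__key__": "doc0001",
#       "__url__": "<shard url>",
#       "0.value": b"...", "0.type": "text", "0.fname": "doc0001.0.text.txt",
#       "1.value": b"...", "1.type": "image", "1.fname": "doc0001.1.image.jpg",
#       "2.value": b"...", "2.type": "text", "2.fname": "doc0001.2.text.txt",
#       "metadata.value": b"...", "metadata.type": "metadata", "metadata.fname": "doc0001.metadata.txt",
#   }
#
# The member names are invented for illustration; only the naming pattern is implied by the parsing above.

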
def _tarfile_to_webdocument_samples(src, handler=log_and_continue):
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_interleaved(files, handler=handler)
    return samples


tarfile_to_webdocument_samples = wds.filters.pipelinefilter(_tarfile_to_webdocument_samples)


def _collate_texts_and_images_webdocument(data, handler=log_and_continue):
    for sample in data:
        try:
            max_example_indices = max(
                [int(key.split(".")[0]) for key in sample.keys() if key.endswith(".value") and key != "metadata.value"]
            )
            texts = [None for _ in range(max_example_indices + 1)]
            images = [None for _ in range(max_example_indices + 1)]
            for idx in range(max_example_indices + 1):
                if f"{idx}.value" not in sample:
                    continue
                if "text" in sample[f"{idx}.type"]:
                    texts[idx] = sample[f"{idx}.value"]
                elif "image" in sample[f"{idx}.type"]:
                    images[idx] = sample[f"{idx}.value"]
                else:
                    raise ValueError(f"Unknown data type: {sample[f'{idx}.type']}")
            example = {"__key__": sample["__key__"], "__url__": sample["__url__"], "texts": texts, "images": images}
            yield example
        except Exception as exn:
            exn.args = exn.args + (sample.get("stream"), sample.get("url"))
            if handler(exn):
                continue
            else:
                break


collate_texts_and_images_webdocument = wds.filters.pipelinefilter(_collate_texts_and_images_webdocument)


def _decode_image_and_text_webdocument(data, handler=log_and_continue):
    image_decoder = ImageDecoder()
    for sample in data:
        try:
            sample["images"] = [image_decoder(image) if image is not None else None for image in sample["images"]]
            sample["texts"] = [text.decode("utf-8") if text is not None else None for text in sample["texts"]]
            yield sample
        except Exception as exn:
            exn.args = exn.args + (sample.get("stream"), sample.get("url"))
            if handler(exn):
                continue
            else:
                break


decode_image_and_text_webdocument = wds.filters.pipelinefilter(_decode_image_and_text_webdocument)


def collate_dicts(samples):
    keys = samples[0].keys()
    batched_samples = {key: [sample[key] for sample in samples] for key in keys}
    return batched_samples


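# Illustrative example (values invented): collate_dicts turns a list of per-sample dicts
# into a single dict of lists, and is used as the collation_fn for wds.batched below:
#
#   collate_dicts([{"texts": ["a"], "images": [None]}, {"texts": ["b"], "images": [None]}])
#   == {"texts": [["a"], ["b"]], "images": [[None], [None]]}
#
# Keys are taken from the first sample, so all samples in a batch are assumed to share the same keys.

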
def get_webdocuments_webdataset(
    urls,
    batch_size,
    shuffle_initial_urls_list=False,
    shuffle_before_split_by_node_buffer_size=100,
    shuffle_before_split_by_worker_buffer_size=100,
    shuffle_after_tarfile_to_samples_buffer_size=100,
    shuffle_after_batching_buffer_size=1000,
):
    if shuffle_initial_urls_list:
        random.shuffle(urls)
    pipeline_list = [wds.SimpleShardList(urls)]
    if shuffle_before_split_by_node_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_before_split_by_node_buffer_size))
    pipeline_list.append(wds.split_by_node)
    if shuffle_before_split_by_worker_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_before_split_by_worker_buffer_size))
    pipeline_list.extend(
        [
            wds.split_by_worker,
            tarfile_to_webdocument_samples(),
        ]
    )
    if shuffle_after_tarfile_to_samples_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_after_tarfile_to_samples_buffer_size))
    pipeline_list.extend(
        [
            collate_texts_and_images_webdocument(),
            decode_image_and_text_webdocument(),
            wds.batched(batch_size, collation_fn=collate_dicts, partial=True),
        ]
    )
    if shuffle_after_batching_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_after_batching_buffer_size))
    dataset = wds.DataPipeline(pipeline_list)
    return dataset


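# Minimal usage sketch (shard paths invented, not part of the original source):
#
#   urls = [f"webdocs/shard-{i:05d}.tar" for i in range(10)]
#   dataset = get_webdocuments_webdataset(urls, batch_size=8, shuffle_initial_urls_list=True)
#   for batch in dataset:
#       # batch["texts"] and batch["images"] are lists of per-document lists, aligned by
#       # position; an entry is None where the other modality occupies that slot.
#       ...
#
# The returned wds.DataPipeline is an IterableDataset, so it can also be wrapped in a
# torch DataLoader with batch_size=None if worker-level parallelism is wanted.

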
def split_keep_2(x):
    x = x.strip("./")
    x_splitter = x.split(".")
    return x_splitter[0], x_splitter[1]


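# Illustrative example (member names invented): split_keep_2 keeps only the first two
# dot-separated components of a tar member name, e.g. "00001.image.jpg" -> ("00001", "image")
# and "00001.text.txt" -> ("00001", "text"). Used as the `keys` function for group_by_keys
# below, this groups both members into one sample with "image" and "text" entries, which is
# the layout _decode_image_and_text_pairs expects.

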
def _tarfile_to_pair_samples(src, handler=log_and_continue):
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys(files, keys=split_keep_2, handler=handler)
    return samples


tarfile_to_pair_samples = wds.filters.pipelinefilter(_tarfile_to_pair_samples)


def _decode_image_and_text_pairs(data, handler=log_and_continue):
    image_decoder = ImageDecoder()
    for sample in data:
        try:
            sample["image"] = image_decoder(sample["image"])
            sample["text"] = sample["text"].decode("utf-8")
            yield sample
        except Exception as exn:
            exn.args = exn.args + (sample.get("stream"), sample.get("url"))
            if handler(exn):
                continue
            else:
                break


decode_image_and_text_pairs = wds.filters.pipelinefilter(_decode_image_and_text_pairs)


def get_image_caption_pairs_webdataset(
    urls,
    batch_size,
    shuffle_initial_urls_list=False,
    shuffle_before_split_by_node_buffer_size=100,
    shuffle_before_split_by_worker_buffer_size=100,
    shuffle_after_tarfile_to_samples_buffer_size=100,
    shuffle_after_batching_buffer_size=1000,
):
    if shuffle_initial_urls_list:
        random.shuffle(urls)
    pipeline_list = [wds.SimpleShardList(urls)]
    if shuffle_before_split_by_node_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_before_split_by_node_buffer_size))
    pipeline_list.append(wds.split_by_node)
    if shuffle_before_split_by_worker_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_before_split_by_worker_buffer_size))
    pipeline_list.extend(
        [
            wds.split_by_worker,
            tarfile_to_pair_samples(handler=log_and_continue),
        ]
    )
    if shuffle_after_tarfile_to_samples_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_after_tarfile_to_samples_buffer_size))
    pipeline_list.extend(
        [
            decode_image_and_text_pairs(),
            wds.batched(batch_size, collation_fn=collate_dicts, partial=True),  # TODO: check if partial is needed
        ]
    )
    if shuffle_after_batching_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_after_batching_buffer_size))
    dataset = wds.DataPipeline(pipeline_list)
    return dataset


def get_webdataset(
    urls,
    ds_type: DatasetTypes,
    batch_size: int,
    shuffle_initial_urls_list,
    shuffle_before_split_by_node_buffer_size,
    shuffle_before_split_by_worker_buffer_size,
    shuffle_after_tarfile_to_samples_buffer_size,
    shuffle_after_batching_buffer_size,
):
    if ds_type == DatasetTypes.WEB_DOCUMENTS:
        return get_webdocuments_webdataset(
            urls,
            batch_size,
            shuffle_initial_urls_list,
            shuffle_before_split_by_node_buffer_size,
            shuffle_before_split_by_worker_buffer_size,
            shuffle_after_tarfile_to_samples_buffer_size,
            shuffle_after_batching_buffer_size,
        )
    elif ds_type == DatasetTypes.IMAGE_CAPTION_PAIRS:
        return get_image_caption_pairs_webdataset(
            urls,
            batch_size,
            shuffle_initial_urls_list,
            shuffle_before_split_by_node_buffer_size,
            shuffle_before_split_by_worker_buffer_size,
            shuffle_after_tarfile_to_samples_buffer_size,
            shuffle_after_batching_buffer_size,
        )
    else:
        raise ValueError(f"Unknown dataset type: {ds_type}")


def check_webdataset_command(command):
    if "s3:/" not in command:
        return True
    command = command.strip()
    if not command.startswith("pipe:bash"):
        return False
    if not command.endswith(".tar"):
        return False
    if "get_file.sh" not in command:
        return False
    return True


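# Illustrative inputs (paths invented): check_webdataset_command only constrains S3-backed
# shard specs; anything without "s3:/" is accepted as-is.
#
#   check_webdataset_command("webdocs/shard-00000.tar")                            -> True
#   check_webdataset_command("pipe:bash get_file.sh s3://bucket/shard-00000.tar")  -> True
#   check_webdataset_command("pipe:aws s3 cp s3://bucket/shard-00000.tar -")       -> False (must start with "pipe:bash")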