Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import random | |
| import unittest | |
| from typing import Any, Iterable, Iterator, Tuple | |
| from densepose.data import CombinedDataLoader | |
| def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]: | |
| """ | |
| Group elements of an iterable by chunks of size `n`, e.g. | |
| grouper(range(9), 4) -> | |
| (0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None) | |
| """ | |
| it = iter(iterable) | |
| while True: | |
| values = [] | |
| for _ in range(n): | |
| try: | |
| value = next(it) | |
| except StopIteration: | |
| values.extend([fillvalue] * (n - len(values))) | |
| yield tuple(values) | |
| return | |
| values.append(value) | |
| yield tuple(values) | |
class TestCombinedDataLoader(unittest.TestCase):
    """Tests for `CombinedDataLoader` mixing batches from several loaders."""

    def test_combine_loaders_1(self) -> None:
        """Mix two chunked loaders with a fixed seed and compare against known batches."""
        loader1 = _grouper([f"1_{i}" for i in range(10)], 2)
        loader2 = _grouper([f"2_{i}" for i in range(11)], 3)
        batch_size = 4
        ratios = (0.1, 0.9)
        # Seeding the global RNG makes the mixing order deterministic; restore the
        # previous state afterwards so other tests are not affected by this seed.
        self.addCleanup(random.setstate, random.getstate())
        random.seed(43)
        combined = CombinedDataLoader((loader1, loader2), batch_size, ratios)
        BATCHES_GT = [
            ["1_0", "1_1", "2_0", "2_1"],
            ["2_2", "2_3", "2_4", "2_5"],
            ["1_2", "1_3", "2_6", "2_7"],
            ["2_8", "2_9", "2_10", None],
        ]
        batches = []
        for batch in combined:
            # Fail with a clear message (instead of an IndexError) if the
            # combined loader produces more batches than expected.
            self.assertLess(len(batches), len(BATCHES_GT), "unexpected extra batch")
            self.assertEqual(len(batch), batch_size)
            batches.append(batch)
        # Comparing the full list also catches too few batches, which a
        # per-index comparison inside the loop would silently accept.
        self.assertEqual(batches, BATCHES_GT)