from tqdm import tqdm
import numpy as np
import argparse
import torch
import lmdb
import glob
import os

from utils.lmdb import store_arrays_to_lmdb, process_data_dict


def main():
    """
    Aggregate all ODE pairs inside a folder into an LMDB dataset.
    Each .pt file should contain a (key, value) pair representing a
    video's ODE trajectories.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str,
                        required=True, help="path to the folder of ODE pair .pt files")
    parser.add_argument("--lmdb_path", type=str,
                        required=True, help="path to the output LMDB dataset")
    args = parser.parse_args()

    all_files = sorted(glob.glob(os.path.join(args.data_path, "*.pt")))

    # figure out the maximum map size needed
    total_array_size = 5_000_000_000_000  # adapt to your needs; set to 5 TB by default

    env = lmdb.open(args.lmdb_path, map_size=total_array_size * 2)

    counter = 0
    seen_prompts = set()  # for deduplication

    for index, file in enumerate(tqdm(all_files)):
        # read from disk
        data_dict = torch.load(file)
        data_dict = process_data_dict(data_dict, seen_prompts)

        # write to lmdb file
        store_arrays_to_lmdb(env, data_dict, start_index=counter)
        counter += len(data_dict['prompts'])

    # save each entry's shape to lmdb so readers can reconstruct the arrays
    with env.begin(write=True) as txn:
        for key, val in data_dict.items():
            print(key, val)
            array_shape = np.array(val.shape)
            # the first dimension is the total number of entries written
            array_shape[0] = counter

            shape_key = f"{key}_shape".encode()
            shape_str = " ".join(map(str, array_shape))
            txn.put(shape_key, shape_str.encode())
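
    # Reader-side sketch (hypothetical, not part of this script): the shape
    # metadata written above could later be recovered with something like
    #   with env.begin() as txn:
    #       prompts_shape = tuple(
    #           map(int, txn.get(b"prompts_shape").decode().split()))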


if __name__ == "__main__":
    main()
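

# Example invocation (a sketch; the script filename and paths below are placeholders):
#   python create_lmdb.py --data_path /data/ode_pairs --lmdb_path /data/ode_pairs_lmdb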