Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 1,734 Bytes
			
			| 0fd2f06 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 | from tqdm import tqdm
import numpy as np
import argparse
import torch
import lmdb
import glob
import os
from utils.lmdb import store_arrays_to_lmdb, process_data_dict
def _build_parser():
    """Return the argument parser for this script's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str,
                        required=True, help="path to ode pairs")
    parser.add_argument("--lmdb_path", type=str,
                        required=True, help="path to lmdb")
    # The previous hard-coded value was 5 TB * 2; keep it as the default so
    # existing invocations behave the same, but let callers adapt it.
    parser.add_argument("--map_size", type=int,
                        default=2 * 5_000_000_000_000,
                        help="LMDB map size in bytes (default: 10 TB)")
    return parser


def _write_shapes(env, shapes, num_entries):
    """Write one ``"<key>_shape"`` record per key of *shapes* into *env*.

    Args:
        env: an open LMDB environment.
        shapes: mapping from array key to that array's shape tuple.
        num_entries: total number of stored entries; replaces dimension 0 of
            every recorded shape.

    Each record's value is the shape as space-separated integers.
    """
    with env.begin(write=True) as txn:
        for key, shape in shapes.items():
            array_shape = np.array(shape)
            array_shape[0] = num_entries
            shape_str = " ".join(map(str, array_shape))
            print(key, shape_str)
            txn.put(f"{key}_shape".encode(), shape_str.encode())


def main():
    """
    Aggregate all ode pairs inside a folder into a lmdb dataset.

    Each ``*.pt`` file in ``--data_path`` (processed in sorted order) should
    contain a (key, value) pair representing a video's ODE trajectories. Every
    file is loaded with ``torch.load``, passed through ``process_data_dict``
    together with the set of prompts seen so far (for deduplication across
    files), and appended to the LMDB at ``--lmdb_path`` starting at the running
    entry count. Finally a ``"<key>_shape"`` record is written per array key,
    with dimension 0 set to the total number of stored entries.

    Exits with a usage error (status 2) when no ``*.pt`` files are found.
    """
    parser = _build_parser()
    args = parser.parse_args()
    all_files = sorted(glob.glob(os.path.join(args.data_path, "*.pt")))
    if not all_files:
        # Fail before creating an (empty) LMDB; there is nothing to aggregate
        # and no array shapes to record.
        parser.error(f"no .pt files found in {args.data_path!r}")

    env = lmdb.open(args.lmdb_path, map_size=args.map_size)
    try:
        num_entries = 0
        seen_prompts = set()  # for deduplication across files
        # Shapes of the most recent file that contributed entries. A file whose
        # prompts were all deduplicated away may come back with empty arrays,
        # whose shapes would not describe a real entry.
        shapes = None
        for file_path in tqdm(all_files):
            # NOTE(review): torch.load unpickles the file; only run this on
            # trusted data.
            data_dict = torch.load(file_path)
            data_dict = process_data_dict(data_dict, seen_prompts)
            num_new = len(data_dict['prompts'])
            store_arrays_to_lmdb(env, data_dict, start_index=num_entries)
            num_entries += num_new
            if num_new > 0:
                shapes = {key: val.shape for key, val in data_dict.items()}

        if shapes is None:
            print("No entries were written; skipping shape metadata.")
        else:
            _write_shapes(env, shapes, num_entries)
    finally:
        env.close()
# Run the aggregation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
 |