Spaces:
Runtime error
Runtime error
| from utils.lmdb import get_array_shape_from_lmdb, retrieve_row_from_lmdb | |
| from torch.utils.data import Dataset | |
| import numpy as np | |
| import torch | |
| import lmdb | |
| import json | |
| from pathlib import Path | |
| from PIL import Image | |
| import os | |
class TextDataset(Dataset):
    """Map-style dataset over a text file holding one prompt per line.

    Args:
        prompt_path: Path to a UTF-8 text file, one prompt per line.
        extended_prompt_path: Optional path to a UTF-8 text file whose lines
            pair up one-to-one with the lines of ``prompt_path``.

    Raises:
        ValueError: If both files are given and their line counts differ.
    """

    def __init__(self, prompt_path, extended_prompt_path=None):
        self.prompt_list = self._read_lines(prompt_path)
        if extended_prompt_path is not None:
            self.extended_prompt_list = self._read_lines(extended_prompt_path)
            # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
            if len(self.extended_prompt_list) != len(self.prompt_list):
                raise ValueError(
                    f"{extended_prompt_path} has {len(self.extended_prompt_list)} lines "
                    f"but {prompt_path} has {len(self.prompt_list)}; they must match"
                )
        else:
            self.extended_prompt_list = None

    @staticmethod
    def _read_lines(path):
        """Return the lines of a UTF-8 text file with trailing whitespace stripped."""
        with open(path, encoding="utf-8") as f:
            return [line.rstrip() for line in f]

    def __len__(self):
        return len(self.prompt_list)

    def __getitem__(self, idx):
        """Return ``{"prompts", "idx"}``, plus ``"extended_prompts"`` when loaded."""
        batch = {
            "prompts": self.prompt_list[idx],
            "idx": idx,
        }
        if self.extended_prompt_list is not None:
            batch["extended_prompts"] = self.extended_prompt_list[idx]
        return batch
class ODERegressionLMDBDataset(Dataset):
    """Read-only dataset over one LMDB environment with ``latents`` and ``prompts`` rows.

    Args:
        data_path: Path handed to ``lmdb.open`` (opened read-only, without locking).
        max_pair: Upper bound on ``len(self)``; rows beyond it are not exposed
            through ``__len__``.
    """

    def __init__(self, data_path: str, max_pair: int = int(1e8)):
        self.env = lmdb.open(data_path, readonly=True,
                             lock=False, readahead=False, meminit=False)
        # Shape of the whole stored latents array: axis 0 counts rows and the
        # remaining axes give the per-row shape used when decoding a row.
        self.latents_shape = get_array_shape_from_lmdb(self.env, 'latents')
        self.max_pair = max_pair

    def __len__(self):
        # Clamp by max_pair so callers can train on a prefix of the data.
        return min(self.latents_shape[0], self.max_pair)

    def __getitem__(self, idx):
        """Fetch row ``idx``.

        Returns:
            dict with:
              - ``"prompts"``: the value ``retrieve_row_from_lmdb`` returns for the
                ``prompts`` key with ``str`` dtype (presumably this row's prompt
                text; confirm against ``utils.lmdb``).
              - ``"ode_latent"``: float32 tensor of shape
                (num_denoising_steps, num_frames, num_channels, height, width),
                ordered from pure noise to clean image. A row stored as a 4-D
                array is given a leading axis of size 1.
        """
        latents = retrieve_row_from_lmdb(
            self.env,
            "latents", np.float16, idx, shape=self.latents_shape[1:]
        )
        # Promote a single-step (4-D) row to the 5-D layout described above.
        if len(latents.shape) == 4:
            latents = latents[None, ...]
        prompts = retrieve_row_from_lmdb(
            self.env,
            "prompts", str, idx
        )
        return {
            "prompts": prompts,
            # Rows are stored as float16; upcast for training.
            "ode_latent": torch.tensor(latents, dtype=torch.float32)
        }
class ShardingLMDBDataset(Dataset):
    """Concatenation of several LMDB shards stored as sub-directories of one folder.

    Each sub-directory of ``data_path`` (visited in sorted name order) is opened
    as a read-only LMDB environment. ``self.index`` maps a global position to a
    ``(shard_id, local_idx)`` pair.

    Args:
        data_path: Directory whose sub-directories are the LMDB shards.
        max_pair: Upper bound on ``len(self)``, matching the behaviour of
            ``ODERegressionLMDBDataset``.
    """

    def __init__(self, data_path: str, max_pair: int = int(1e8)):
        self.envs = []
        self.index = []
        for fname in sorted(os.listdir(data_path)):
            path = os.path.join(data_path, fname)
            # lmdb.open is called without subdir=False, so each shard must be a
            # directory; a stray regular file (README, .DS_Store, ...) would only
            # make lmdb.open fail, so skip it.
            if not os.path.isdir(path):
                continue
            env = lmdb.open(path,
                            readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
            self.envs.append(env)
        # Per-shard shape of the stored latents array; axis 0 counts that shard's rows.
        self.latents_shape = [get_array_shape_from_lmdb(env, 'latents') for env in self.envs]
        for shard_id, shape in enumerate(self.latents_shape):
            self.index.extend((shard_id, local_i) for local_i in range(shape[0]))
        self.max_pair = max_pair

    def __len__(self):
        # max_pair was previously stored but ignored; clamp by it as the
        # single-LMDB dataset does.
        return min(len(self.index), self.max_pair)

    def __getitem__(self, idx):
        """Fetch global row ``idx``, resolved to a row of one shard.

        Returns:
            dict with:
              - ``"prompts"``: the value ``retrieve_row_from_lmdb`` returns for the
                ``prompts`` key with ``str`` dtype (presumably this row's prompt
                text; confirm against ``utils.lmdb``).
              - ``"ode_latent"``: float32 tensor of shape
                (num_denoising_steps, num_frames, num_channels, height, width),
                ordered from pure noise to clean image. A row stored as a 4-D
                array is given a leading axis of size 1.
        """
        shard_id, local_idx = self.index[idx]
        env = self.envs[shard_id]
        latents = retrieve_row_from_lmdb(
            env,
            "latents", np.float16, local_idx,
            shape=self.latents_shape[shard_id][1:]
        )
        # Promote a single-step (4-D) row to the 5-D layout described above.
        if len(latents.shape) == 4:
            latents = latents[None, ...]
        prompts = retrieve_row_from_lmdb(
            env,
            "prompts", str, local_idx
        )
        return {
            "prompts": prompts,
            "ode_latent": torch.tensor(latents, dtype=torch.float32)
        }
class TextImagePairDataset(Dataset):
    """Image/caption pairs described by a single ``target_crop_info_*.json`` file.

    Expected layout of ``data_dir``::

        data_dir/
            target_crop_info_<ratio>.json   # exactly one metadata file
            <ratio>/                        # images named by each entry's 'file_name'

    where ``<ratio>`` is the last ``_``-separated token of the metadata file
    stem (e.g. ``target_crop_info_26-15.json`` -> ``26-15``).
    """

    def __init__(
        self,
        data_dir,
        transform=None,
        eval_first_n=-1,
        pad_to_multiple_of=None
    ):
        """
        Args:
            data_dir (str | Path): Directory laid out as described on the class.
            transform (callable, optional): Applied to the RGB PIL image in
                ``__getitem__``.
            eval_first_n (int): Keep only the first ``eval_first_n`` metadata
                entries; ``-1`` keeps them all.
            pad_to_multiple_of (int, optional): If given, repeat the last entry
                until ``len(self)`` is a multiple of this value. The unpadded
                length is kept in ``self.pre_pad_len``.

        Raises:
            FileNotFoundError: If the metadata file, the image sub-directory, or
                any referenced image is missing.
            ValueError: If several metadata files exist, or if
                ``pad_to_multiple_of`` is not positive.
        """
        # Validate before any I/O: 0 would divide by zero and negatives would
        # silently skip padding.
        if pad_to_multiple_of is not None and pad_to_multiple_of <= 0:
            raise ValueError(
                f"pad_to_multiple_of must be a positive integer, got {pad_to_multiple_of}"
            )
        self.transform = transform
        data_dir = Path(data_dir)
        metadata_path = self._find_metadata_file(data_dir)
        # e.g. target_crop_info_26-15.json -> "26-15"
        aspect_ratio = metadata_path.stem.split('_')[-1]
        self.image_dir = data_dir / aspect_ratio
        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
        # Decode as UTF-8 explicitly (as TextDataset does) so captions do not
        # depend on the platform's default encoding.
        with open(metadata_path, encoding="utf-8") as f:
            self.metadata = json.load(f)
        eval_first_n = eval_first_n if eval_first_n != -1 else len(self.metadata)
        self.metadata = self.metadata[:eval_first_n]
        # Fail at construction time rather than mid-iteration.
        for item in self.metadata:
            image_path = self.image_dir / item['file_name']
            if not image_path.exists():
                raise FileNotFoundError(f"Image not found: {image_path}")
        # Not read within this class; presumably used by callers.
        self.dummy_prompt = "DUMMY PROMPT"
        self.pre_pad_len = len(self.metadata)
        if pad_to_multiple_of is not None:
            remainder = len(self.metadata) % pad_to_multiple_of
            if remainder:
                # A non-zero remainder implies at least one entry, so [-1] is safe.
                self.metadata += [self.metadata[-1]] * (pad_to_multiple_of - remainder)

    @staticmethod
    def _find_metadata_file(data_dir):
        """Return the single ``target_crop_info_*.json`` in ``data_dir`` or raise."""
        metadata_files = list(data_dir.glob('target_crop_info_*.json'))
        if not metadata_files:
            raise FileNotFoundError(f"No metadata file found in {data_dir}")
        if len(metadata_files) > 1:
            raise ValueError(f"Multiple metadata files found in {data_dir}")
        return metadata_files[0]

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """
        Returns:
            dict: A dictionary containing:
                - image: RGB PIL Image, or the output of ``transform`` if one was given
                - prompts: the entry's 'caption'
                - target_bbox: target_crop['target_bbox'] (documented as [x1, y1, x2, y2])
                - target_ratio: target_crop['target_ratio']
                - type: the entry's 'type'
                - origin_size: (origin_width, origin_height)
                - idx: ``idx`` as passed in (padded entries report their own index)
        """
        item = self.metadata[idx]
        image_path = self.image_dir / item['file_name']
        # The context manager releases the file handle; convert() returns an
        # image that does not depend on the open file.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        crop = item['target_crop']
        return {
            'image': image,
            'prompts': item['caption'],
            'target_bbox': crop['target_bbox'],
            'target_ratio': crop['target_ratio'],
            'type': item['type'],
            'origin_size': (item['origin_width'], item['origin_height']),
            'idx': idx
        }
def cycle(dl):
    """Yield items from ``dl`` forever, restarting iteration each time it is exhausted.

    Each pass calls ``iter(dl)`` afresh, so re-iterable objects (lists,
    DataLoaders, ...) are replayed from the start.

    Raises:
        ValueError: If a whole pass over ``dl`` yields nothing. This covers an
            empty iterable or an already-exhausted one-shot iterator, which
            would otherwise spin forever without producing a value.
    """
    while True:
        produced = False
        for data in dl:
            produced = True
            yield data
        if not produced:
            raise ValueError("cycle(): iterable yielded no items; refusing to loop forever")