Spaces:
Paused
Paused
| # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
| # // | |
| # // Licensed under the Apache License, Version 2.0 (the "License"); | |
| # // you may not use this file except in compliance with the License. | |
| # // You may obtain a copy of the License at | |
| # // | |
| # // http://www.apache.org/licenses/LICENSE-2.0 | |
| # // | |
| # // Unless required by applicable law or agreed to in writing, software | |
| # // distributed under the License is distributed on an "AS IS" BASIS, | |
| # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # // See the License for the specific language governing permissions and | |
| # // limitations under the License. | |
| """ | |
| Utility functions for creating schedules and samplers from config. | |
| """ | |
| import torch | |
| from omegaconf import DictConfig | |
| from .samplers.base import Sampler | |
| from .samplers.euler import EulerSampler | |
| from .schedules.base import Schedule | |
| from .schedules.lerp import LinearInterpolationSchedule | |
| from .timesteps.base import SamplingTimesteps | |
| from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps | |
| def create_schedule_from_config( | |
| config: DictConfig, | |
| device: torch.device, | |
| dtype: torch.dtype = torch.float32, | |
| ) -> Schedule: | |
| """ | |
| Create a schedule from configuration. | |
| """ | |
| if config.type == "lerp": | |
| return LinearInterpolationSchedule(T=config.get("T", 1.0)) | |
| raise NotImplementedError | |
| def create_sampler_from_config( | |
| config: DictConfig, | |
| schedule: Schedule, | |
| timesteps: SamplingTimesteps, | |
| ) -> Sampler: | |
| """ | |
| Create a sampler from configuration. | |
| """ | |
| if config.type == "euler": | |
| return EulerSampler( | |
| schedule=schedule, | |
| timesteps=timesteps, | |
| prediction_type=config.prediction_type, | |
| ) | |
| raise NotImplementedError | |
| def create_sampling_timesteps_from_config( | |
| config: DictConfig, | |
| schedule: Schedule, | |
| device: torch.device, | |
| dtype: torch.dtype = torch.float32, | |
| ) -> SamplingTimesteps: | |
| if config.type == "uniform_trailing": | |
| return UniformTrailingSamplingTimesteps( | |
| T=schedule.T, | |
| steps=config.steps, | |
| shift=config.get("shift", 1.0), | |
| device=device, | |
| ) | |
| raise NotImplementedError |