twodgirl committed · verified
Commit f9c53a5 · 1 Parent(s): 183b8fd

Upload folder using huggingface_hub

scheduler/scheduler_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_class_name": "FlowMatchEulerDiscreteScheduler",
+   "_diffusers_version": "0.33.1",
+   "dynamic_time_shift": true,
+   "num_train_timesteps": 1000
+ }
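
For reference, a config like this is what diffusers' standard `SchedulerMixin.from_pretrained` reads back. A minimal loading sketch, assuming the repo has been downloaded locally and the custom module added below is on the Python path (the local path is a placeholder):

    from scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

    # `subfolder` points at the scheduler/ directory holding scheduler_config.json.
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        "path/to/local/checkout",  # placeholder for the downloaded repo root
        subfolder="scheduler",
    )
    assert scheduler.config.dynamic_time_shift
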
scheduler/scheduling_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,229 @@
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.utils import BaseOutput, logging
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ @dataclass
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
+     """
+     Output class for the scheduler's `step` function output.
+
+     Args:
+         prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+             denoising loop.
+     """
+
+     prev_sample: torch.FloatTensor
+
+
+ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+     """
+     Euler scheduler.
+
+     This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+     methods the library implements for all schedulers such as loading and saving.
+
+     Args:
+         num_train_timesteps (`int`, defaults to 1000):
+             The number of diffusion steps to train the model.
+         dynamic_time_shift (`bool`, defaults to `False`):
+             Whether to shift the inference timestep schedule based on the number of input tokens
+             (see `set_timesteps`).
+     """
+
+     _compatibles = []
+     order = 1
+
+     @register_to_config
+     def __init__(
+         self,
+         num_train_timesteps: int = 1000,
+         dynamic_time_shift: bool = False,
+     ):
+         timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
+
+         self.timesteps = timesteps
+
+         self._step_index = None
+         self._begin_index = None
+
+     @property
+     def step_index(self):
+         """
+         The index counter for the current timestep. It increases by 1 after each scheduler step.
+         """
+         return self._step_index
+
+     @property
+     def begin_index(self):
+         """
+         The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
+         """
+         return self._begin_index
+
+     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+     def set_begin_index(self, begin_index: int = 0):
+         """
+         Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
+
+         Args:
+             begin_index (`int`):
+                 The begin index for the scheduler.
+         """
+         self._begin_index = begin_index
+
+     def index_for_timestep(self, timestep, schedule_timesteps=None):
+         if schedule_timesteps is None:
+             # `self._timesteps` is populated by `set_timesteps`, which must run before stepping.
+             schedule_timesteps = self._timesteps
+
+         indices = (schedule_timesteps == timestep).nonzero()
+
+         # The sigma index that is taken for the **very** first `step`
+         # is always the second index (or the last index if there is only 1)
+         # This way we can ensure we don't accidentally skip a sigma in
+         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+         pos = 1 if len(indices) > 1 else 0
+
+         return indices[pos].item()
+
+     # def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+     #     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+     def set_timesteps(
+         self,
+         num_inference_steps: Optional[int] = None,
+         device: Optional[Union[str, torch.device]] = None,
+         timesteps: Optional[List[float]] = None,
+         num_tokens: Optional[int] = None,
+     ):
+         """
+         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+         Args:
+             num_inference_steps (`int`):
+                 The number of diffusion steps used when generating samples with a pre-trained model.
+             device (`str` or `torch.device`, *optional*):
+                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+             timesteps (`List[float]`, *optional*):
+                 Custom ascending timesteps in [0, 1) to use instead of the uniform schedule. If given,
+                 `num_inference_steps` is ignored.
+             num_tokens (`int`, *optional*):
+                 The number of input tokens; used to compute the shift factor when `dynamic_time_shift`
+                 is enabled.
+         """
+
+         if timesteps is None:
+             self.num_inference_steps = num_inference_steps
+             timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
+             if self.config.dynamic_time_shift and num_tokens is not None:
+                 # m = 1 when the input resolution is 320 x 320; m = 3.2 when it is 1024 x 1024.
+                 m = np.sqrt(num_tokens) / 40
+                 # Maps t -> t / (m + (1 - m) * t): the identity when m = 1; for m > 1 it bends
+                 # the uniform grid toward small t, spending more steps early in the schedule.
+                 timesteps = timesteps / (m - m * timesteps + timesteps)
+
+         # `np.asarray` also handles a user-provided list of custom timesteps.
+         timesteps = torch.from_numpy(np.asarray(timesteps, dtype=np.float32)).to(device=device)
+         _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
+
+         self.timesteps = timesteps
+         self._timesteps = _timesteps
+         self._step_index = None
+         self._begin_index = None
+
+     def _init_step_index(self, timestep):
+         if self.begin_index is None:
+             if isinstance(timestep, torch.Tensor):
+                 timestep = timestep.to(self.timesteps.device)
+             self._step_index = self.index_for_timestep(timestep)
+         else:
+             self._step_index = self._begin_index
+
+     def step(
+         self,
+         model_output: torch.FloatTensor,
+         timestep: Union[float, torch.FloatTensor],
+         sample: torch.FloatTensor,
+         generator: Optional[torch.Generator] = None,
+         return_dict: bool = True,
+     ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
+         """
+         Predict the sample at the next timestep by integrating the flow ODE with an explicit Euler step,
+         propagating the diffusion process from the learned model output.
+
+         Args:
+             model_output (`torch.FloatTensor`):
+                 The direct output from the learned diffusion model.
+             timestep (`float`):
+                 The current discrete timestep in the diffusion chain.
+             sample (`torch.FloatTensor`):
+                 A current instance of a sample created by the diffusion process.
+             generator (`torch.Generator`, *optional*):
+                 A random number generator. Unused by this deterministic Euler step; kept for API compatibility.
+             return_dict (`bool`):
+                 Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
+
+         Returns:
+             [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
+                 If `return_dict` is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a
+                 tuple is returned where the first element is the sample tensor.
+         """
+
+         if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
+             raise ValueError(
+                 "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                 " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+                 " one of the `scheduler.timesteps` as a timestep."
+             )
+
+         if self.step_index is None:
+             self._init_step_index(timestep)
+
+         # Upcast to avoid precision issues when computing prev_sample
+         sample = sample.to(torch.float32)
+         t = self._timesteps[self.step_index]
+         t_next = self._timesteps[self.step_index + 1]
+
+         # Euler step: advance the sample along the model prediction over the interval (t_next - t).
+         prev_sample = sample + (t_next - t) * model_output
+
+         # Cast sample back to model compatible dtype
+         prev_sample = prev_sample.to(model_output.dtype)
+
+         # upon completion increase step index by one
+         self._step_index += 1
+
+         if not return_dict:
+             return (prev_sample,)
+
+         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
+
+     def __len__(self):
+         return self.config.num_train_timesteps
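
Taken together, the intended usage is one `set_timesteps` call followed by one `step` per timestep. A minimal sketch (the latent shape, token count, and zero model output are hypothetical stand-ins for a real denoiser; note that this schedule runs forward from t = 0 toward t = 1, with `step` computing sample + (t_next - t) * model_output):

    import torch
    from scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, dynamic_time_shift=True)

    # num_tokens = 1600 (a 40 x 40 token grid) gives shift factor m = sqrt(1600) / 40 = 1,
    # i.e. the uniform schedule; larger inputs yield m > 1 and a shifted schedule.
    scheduler.set_timesteps(num_inference_steps=30, device="cpu", num_tokens=1600)

    sample = torch.randn(1, 16, 40, 40)  # placeholder latent
    for t in scheduler.timesteps:
        model_output = torch.zeros_like(sample)  # stand-in for the model's velocity prediction
        sample = scheduler.step(model_output, t, sample).prev_sample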