hlky (HF staff) committed
Commit 9d666aa · verified · 1 Parent(s): 39c3dbe

new handler

Files changed (1):
  handler.py +46 -59
handler.py CHANGED
@@ -1,85 +1,72 @@
-from typing import Dict, List, Any
+from typing import cast, Union
+
+import PIL.Image
 import torch
-from base64 import b64decode
+
 from diffusers import AutoencoderKL
 from diffusers.image_processor import VaeImageProcessor
 
+
 class EndpointHandler:
     def __init__(self, path=""):
         self.device = "cuda"
         self.dtype = torch.float16
-        self.vae = AutoencoderKL.from_pretrained(path, torch_dtype=self.dtype).to(self.device, self.dtype).eval()
+        self.vae = cast(AutoencoderKL, AutoencoderKL.from_pretrained(path, torch_dtype=self.dtype).to(self.device, self.dtype).eval())
 
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     @torch.no_grad()
-    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
+    def __call__(self, data) -> Union[torch.Tensor, PIL.Image.Image]:
         """
         Args:
             data (:obj:):
                 includes the input data and the parameters for the inference.
         """
-        tensor = data["inputs"]
-        tensor = b64decode(tensor.encode("utf-8"))
-        parameters = data.get("parameters", {})
-        if "shape" not in parameters:
-            raise ValueError("Expected `shape` in parameters.")
-        if "dtype" not in parameters:
-            raise ValueError("Expected `dtype` in parameters.")
-
-        DTYPE_MAP = {
-            "float16": torch.float16,
-            "float32": torch.float32,
-            "bfloat16": torch.bfloat16,
-        }
+        tensor = cast(torch.Tensor, data["inputs"])
+        parameters = cast(dict, data.get("parameters", {}))
+        do_scaling = cast(bool, parameters.get("do_scaling", True))
+        output_type = cast(str, parameters.get("output_type", "pil"))
+        partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
+        if partial_postprocess and output_type != "pt":
+            output_type = "pt"
 
-        shape = parameters.get("shape")
-        dtype = DTYPE_MAP.get(parameters.get("dtype"))
-        tensor = torch.frombuffer(bytearray(tensor), dtype=dtype).reshape(shape)
+        tensor = tensor.to(self.device, self.dtype)
 
-        needs_upcasting = (
-            self.vae.dtype == torch.float16 and self.vae.config.force_upcast
-        )
-        if needs_upcasting:
-            self.vae = self.vae.to(torch.float32)
-            tensor = tensor.to(self.device, torch.float32)
-        else:
-            tensor = tensor.to(self.device, self.dtype)
-
-        # unscale/denormalize the latents
-        # denormalize with the mean and std if available and not None
-        has_latents_mean = (
-            hasattr(self.vae.config, "latents_mean")
-            and self.vae.config.latents_mean is not None
-        )
-        has_latents_std = (
-            hasattr(self.vae.config, "latents_std")
-            and self.vae.config.latents_std is not None
-        )
-        if has_latents_mean and has_latents_std:
-            latents_mean = (
-                torch.tensor(self.vae.config.latents_mean)
-                .view(1, 4, 1, 1)
-                .to(tensor.device, tensor.dtype)
-            )
-            latents_std = (
-                torch.tensor(self.vae.config.latents_std)
-                .view(1, 4, 1, 1)
-                .to(tensor.device, tensor.dtype)
-            )
-            tensor = (
-                tensor * latents_std / self.vae.config.scaling_factor + latents_mean
-            )
-        else:
-            tensor = tensor / self.vae.config.scaling_factor
+        if do_scaling:
+            has_latents_mean = (
+                hasattr(self.vae.config, "latents_mean")
+                and self.vae.config.latents_mean is not None
+            )
+            has_latents_std = (
+                hasattr(self.vae.config, "latents_std")
+                and self.vae.config.latents_std is not None
+            )
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean)
+                    .view(1, 4, 1, 1)
+                    .to(tensor.device, tensor.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std)
+                    .view(1, 4, 1, 1)
+                    .to(tensor.device, tensor.dtype)
+                )
+                tensor = (
+                    tensor * latents_std / self.vae.config.scaling_factor + latents_mean
+                )
+            else:
+                tensor = tensor / self.vae.config.scaling_factor
 
         with torch.no_grad():
-            image = self.vae.decode(tensor, return_dict=False)[0]
-
-        if needs_upcasting:
-            self.vae.to(dtype=torch.float16)
+            image = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])
 
-        image = self.image_processor.postprocess(image, output_type="pil")
+        if partial_postprocess:
+            image = (image * 0.5 + 0.5).clamp(0, 1)
+            image = image.permute(0, 2, 3, 1).contiguous().float()
+            image = (image * 255).round().to(torch.uint8)
+        elif output_type == "pil":
+            image = cast(PIL.Image.Image, self.image_processor.postprocess(image, output_type="pil")[0])
 
-        return image[0]
+        return image
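
For reference, a minimal sketch of how a caller might exercise the new handler in-process; this is an assumption about intended usage, not part of the commit. The VAE repo id and the latent shape are illustrative placeholders, and in the deployed endpoint `data` would arrive as a deserialized request payload rather than being built by hand.

    import torch

    from handler import EndpointHandler

    # Placeholder VAE path; any AutoencoderKL checkpoint with 4 latent
    # channels fits the .view(1, 4, 1, 1) denormalization above.
    handler = EndpointHandler(path="stabilityai/sdxl-vae")

    # Placeholder latents: batch of 1, 4 channels, 128x128 spatial.
    # With a vae_scale_factor of 8 this decodes to a 1024x1024 image.
    # The handler hardcodes device="cuda", so a GPU is assumed.
    latents = torch.randn(1, 4, 128, 128, dtype=torch.float16)

    # partial_postprocess=True forces output_type="pt" and returns a
    # denormalized uint8 tensor of shape (1, 1024, 1024, 3), leaving
    # image encoding to the client.
    image = handler({
        "inputs": latents,
        "parameters": {"do_scaling": True, "partial_postprocess": True},
    })

With the defaults (do_scaling=True, output_type="pil", partial_postprocess=False) the handler instead returns a single PIL.Image.Image, matching the old behavior. The notable interface change is that "inputs" is now expected to be a torch.Tensor directly, replacing the previous base64-encoded buffer plus `shape`/`dtype` parameters.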