Pierre Chapuis commited on
Commit
badbac0
1 Parent(s): 08db796

simplify enhancer code

Browse files
Files changed (1) hide show
  1. src/enhancer.py +6 -69
src/enhancer.py CHANGED
@@ -4,7 +4,6 @@ from typing import Any
4
 
5
  import torch
6
  from PIL import Image
7
- from refiners.foundationals.clip.concepts import ConceptExtender
8
  from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
9
  MultiUpscaler,
10
  UpscalerCheckpoints,
@@ -15,7 +14,7 @@ from esrgan_model import UpscalerESRGAN
15
 
16
  @dataclass(kw_only=True)
17
  class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
18
- esrgan: Path | None = None
19
 
20
 
21
  class ESRGANUpscaler(MultiUpscaler):
@@ -26,7 +25,8 @@ class ESRGANUpscaler(MultiUpscaler):
26
  dtype: torch.dtype,
27
  ) -> None:
28
  super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
29
- self.esrgan = self.load_esrgan(checkpoints.esrgan)
 
30
 
31
  def to(self, device: torch.device, dtype: torch.dtype):
32
  self.esrgan.to(device=device, dtype=dtype)
@@ -34,69 +34,6 @@ class ESRGANUpscaler(MultiUpscaler):
34
  self.device = device
35
  self.dtype = dtype
36
 
37
- def load_esrgan(self, path: Path | None) -> UpscalerESRGAN | None:
38
- if path is None:
39
- return None
40
- return UpscalerESRGAN(path, device=self.device, dtype=self.dtype)
41
-
42
- def load_negative_embedding(self, path: Path | None, key: str | None) -> str:
43
- if path is None:
44
- return ""
45
-
46
- embeddings: torch.Tensor | dict[str, Any] = torch.load( # type: ignore
47
- path, weights_only=True, map_location=self.device
48
- )
49
-
50
- if isinstance(embeddings, dict):
51
- assert (
52
- key is not None
53
- ), "Key must be provided to access the negative embedding."
54
- key_sequence = key.split(".")
55
- for key in key_sequence:
56
- assert (
57
- key in embeddings
58
- ), f"Key {key} not found in the negative embedding dictionary. Available keys: {list(embeddings.keys())}"
59
- embeddings = embeddings[key]
60
-
61
- assert isinstance(
62
- embeddings, torch.Tensor
63
- ), f"The negative embedding must be a tensor, found {type(embeddings)}."
64
- assert (
65
- embeddings.ndim == 2
66
- ), f"The negative embedding must be a 2D tensor, found {embeddings.ndim}D tensor."
67
-
68
- extender = ConceptExtender(self.sd.clip_text_encoder)
69
- negative_embedding_token = ", "
70
- for i, embedding in enumerate(embeddings):
71
- embedding = embedding.to(device=self.device, dtype=self.dtype)
72
- extender.add_concept(token=f"<{i}>", embedding=embedding)
73
- negative_embedding_token += f"<{i}> "
74
- extender.inject()
75
-
76
- return negative_embedding_token
77
-
78
- def pre_upscale(
79
- self,
80
- image: Image.Image,
81
- upscale_factor: float,
82
- use_esrgan: bool = True,
83
- use_esrgan_tiling: bool = True,
84
- **_: Any,
85
- ) -> Image.Image:
86
- if self.esrgan is None or not use_esrgan:
87
- return super().pre_upscale(image=image, upscale_factor=upscale_factor)
88
-
89
- width, height = image.size
90
-
91
- if use_esrgan_tiling:
92
- image = self.esrgan.upscale_with_tiling(image)
93
- else:
94
- image = self.esrgan.upscale_without_tiling(image)
95
-
96
- return image.resize(
97
- size=(
98
- int(width * upscale_factor),
99
- int(height * upscale_factor),
100
- ),
101
- resample=Image.LANCZOS,
102
- )
 
4
 
5
  import torch
6
  from PIL import Image
 
7
  from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
8
  MultiUpscaler,
9
  UpscalerCheckpoints,
 
14
 
15
  @dataclass(kw_only=True)
16
  class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
17
+ esrgan: Path
18
 
19
 
20
  class ESRGANUpscaler(MultiUpscaler):
 
25
  dtype: torch.dtype,
26
  ) -> None:
27
  super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
28
+ self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)
29
+ self.esrgan.to(device=device, dtype=dtype)
30
 
31
  def to(self, device: torch.device, dtype: torch.dtype):
32
  self.esrgan.to(device=device, dtype=dtype)
 
34
  self.device = device
35
  self.dtype = dtype
36
 
37
+ def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
38
+ image = self.esrgan.upscale_with_tiling(image)
39
+ return super().pre_upscale(image=image, upscale_factor=upscale_factor / 4)