TBurdairon commited on
Commit
8d81eef
·
verified ·
1 Parent(s): a4196d6

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. config.json +0 -1
  2. enhancer.py +55 -1
  3. esrgan_model.py +305 -1
  4. inference.py +23 -19
  5. requirements.txt +0 -1
config.json CHANGED
@@ -1,4 +1,3 @@
1
-
2
  {
3
  "pipeline_tag": "image-to-image",
4
  "model_type": "esrgan",
 
 
1
  {
2
  "pipeline_tag": "image-to-image",
3
  "model_type": "esrgan",
enhancer.py CHANGED
@@ -1 +1,55 @@
1
- <copied from uploaded enhancer.py>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
8
+ MultiUpscaler,
9
+ UpscalerCheckpoints,
10
+ )
11
+
12
+ from esrgan_model import UpscalerESRGAN
13
+
14
+
15
+ @dataclass(kw_only=True)
16
+ class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
17
+ """Extends the SD-1.5 MultiUpscaler checkpoints to hold an extra ESRGAN file."""
18
+ esrgan: Path
19
+
20
+
21
+ class ESRGANUpscaler(MultiUpscaler):
22
+ """
23
+ Multi-stage image enhancer that:
24
+ 1. Runs ESRGAN 4× super-resolution first (tiling to avoid VRAM overflow),
25
+ 2. Passes the up-scaled image to Stable-Diffusion 1.5 MultiUpscaler for refinement.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ checkpoints: ESRGANUpscalerCheckpoints,
31
+ device: torch.device,
32
+ dtype: torch.dtype,
33
+ ) -> None:
34
+ super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
35
+ self.esrgan = UpscalerESRGAN(
36
+ checkpoints.esrgan, device=self.device, dtype=self.dtype
37
+ )
38
+
39
+ # ---- automatically called by HF when the model is moved to another device ----
40
+ def to(self, device: torch.device, dtype: torch.dtype):
41
+ self.esrgan.to(device=device, dtype=dtype)
42
+ self.sd = self.sd.to(device=device, dtype=dtype)
43
+ self.device = device
44
+ self.dtype = dtype
45
+
46
+ # ---- hook that runs *before* SD-1.5 up-scaling ----
47
+ def pre_upscale(
48
+ self,
49
+ image: Image.Image,
50
+ upscale_factor: float,
51
+ **_: Any,
52
+ ) -> Image.Image:
53
+ # 4× ESRGAN first, then the SD-1.5 stage handles the residual upscale
54
+ image = self.esrgan.upscale_with_tiling(image)
55
+ return super().pre_upscale(image=image, upscale_factor=upscale_factor / 4)
esrgan_model.py CHANGED
@@ -1 +1,305 @@
1
- <copied from uploaded esrgan_model.py>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified from https://github.com/philz1337x/clarity-upscaler
3
+ which is a copy of https://github.com/AUTOMATIC1111/stable-diffusion-webui
4
+ which is a copy of https://github.com/victorca25/iNNfer
5
+ which is a copy of https://github.com/xinntao/ESRGAN
6
+ """
7
+
8
+ import math
9
+ from pathlib import Path
10
+ from typing import NamedTuple
11
+
12
+ import numpy as np
13
+ import numpy.typing as npt
14
+ import torch
15
+ import torch.nn as nn
16
+ from PIL import Image
17
+
18
+
19
+ def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
20
+ return nn.Sequential(
21
+ nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
22
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
23
+ )
24
+
25
+
26
+ class ResidualDenseBlock_5C(nn.Module):
27
+ """
28
+ Residual Dense Block
29
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
30
+ Modified options that can be used:
31
+ - "Partial Convolution based Padding" arXiv:1811.11718
32
+ - "Spectral normalization" arXiv:1802.05957
33
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
34
+ {Rakotonirina} and A. {Rasoanaivo}
35
+ """
36
+
37
+ def __init__(self, nf: int = 64, gc: int = 32) -> None:
38
+ super().__init__() # type: ignore[reportUnknownMemberType]
39
+
40
+ self.conv1 = conv_block(nf, gc)
41
+ self.conv2 = conv_block(nf + gc, gc)
42
+ self.conv3 = conv_block(nf + 2 * gc, gc)
43
+ self.conv4 = conv_block(nf + 3 * gc, gc)
44
+ # Wrapped in Sequential because of key in state dict.
45
+ self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
46
+
47
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
48
+ x1 = self.conv1(x)
49
+ x2 = self.conv2(torch.cat((x, x1), 1))
50
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
51
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
52
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
53
+ return x5 * 0.2 + x
54
+
55
+
56
+ class RRDB(nn.Module):
57
+ """
58
+ Residual in Residual Dense Block
59
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
60
+ """
61
+
62
+ def __init__(self, nf: int) -> None:
63
+ super().__init__() # type: ignore[reportUnknownMemberType]
64
+ self.RDB1 = ResidualDenseBlock_5C(nf)
65
+ self.RDB2 = ResidualDenseBlock_5C(nf)
66
+ self.RDB3 = ResidualDenseBlock_5C(nf)
67
+
68
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
69
+ out = self.RDB1(x)
70
+ out = self.RDB2(out)
71
+ out = self.RDB3(out)
72
+ return out * 0.2 + x
73
+
74
+
75
+ class Upsample2x(nn.Module):
76
+ """Upsample 2x."""
77
+
78
+ def __init__(self) -> None:
79
+ super().__init__() # type: ignore[reportUnknownMemberType]
80
+
81
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
82
+ return nn.functional.interpolate(x, scale_factor=2.0) # type: ignore
83
+
84
+
85
+ class ShortcutBlock(nn.Module):
86
+ """Elementwise sum the output of a submodule to its input"""
87
+
88
+ def __init__(self, submodule: nn.Module) -> None:
89
+ super().__init__() # type: ignore[reportUnknownMemberType]
90
+ self.sub = submodule
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ return x + self.sub(x)
94
+
95
+
96
+ class RRDBNet(nn.Module):
97
+ def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
98
+ super().__init__() # type: ignore[reportUnknownMemberType]
99
+ assert in_nc % 4 != 0 # in_nc is 3
100
+
101
+ self.model = nn.Sequential(
102
+ nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
103
+ ShortcutBlock(
104
+ nn.Sequential(
105
+ *(RRDB(nf) for _ in range(nb)),
106
+ nn.Conv2d(nf, nf, kernel_size=3, padding=1),
107
+ )
108
+ ),
109
+ Upsample2x(),
110
+ nn.Conv2d(nf, nf, kernel_size=3, padding=1),
111
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
112
+ Upsample2x(),
113
+ nn.Conv2d(nf, nf, kernel_size=3, padding=1),
114
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
115
+ nn.Conv2d(nf, nf, kernel_size=3, padding=1),
116
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
117
+ nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
118
+ )
119
+
120
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
121
+ return self.model(x)
122
+
123
+
124
+ def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
125
+ # this code is adapted from https://github.com/victorca25/iNNfer
126
+ scale2x = 0
127
+ scalemin = 6
128
+ n_uplayer = 0
129
+ out_nc = 0
130
+ nb = 0
131
+
132
+ for block in list(state_dict):
133
+ parts = block.split(".")
134
+ n_parts = len(parts)
135
+ if n_parts == 5 and parts[2] == "sub":
136
+ nb = int(parts[3])
137
+ elif n_parts == 3:
138
+ part_num = int(parts[1])
139
+ if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
140
+ scale2x += 1
141
+ if part_num > n_uplayer:
142
+ n_uplayer = part_num
143
+ out_nc = state_dict[block].shape[0]
144
+ assert "conv1x1" not in block # no ESRGANPlus
145
+
146
+ nf = state_dict["model.0.weight"].shape[0]
147
+ in_nc = state_dict["model.0.weight"].shape[1]
148
+ scale = 2**scale2x
149
+
150
+ assert out_nc > 0
151
+ assert nb > 0
152
+
153
+ return in_nc, out_nc, nf, nb, scale # 3, 3, 64, 23, 4
154
+
155
+
156
+ Tile = tuple[int, int, Image.Image]
157
+ Tiles = list[tuple[int, int, list[Tile]]]
158
+
159
+
160
+ # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
161
+ class Grid(NamedTuple):
162
+ tiles: Tiles
163
+ tile_w: int
164
+ tile_h: int
165
+ image_w: int
166
+ image_h: int
167
+ overlap: int
168
+
169
+
170
+ # adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
171
+ def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
172
+ w = image.width
173
+ h = image.height
174
+
175
+ non_overlap_width = tile_w - overlap
176
+ non_overlap_height = tile_h - overlap
177
+
178
+ cols = max(1, math.ceil((w - overlap) / non_overlap_width))
179
+ rows = max(1, math.ceil((h - overlap) / non_overlap_height))
180
+
181
+ dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
182
+ dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
183
+
184
+ grid = Grid([], tile_w, tile_h, w, h, overlap)
185
+ for row in range(rows):
186
+ row_images: list[Tile] = []
187
+ y1 = max(min(int(row * dy), h - tile_h), 0)
188
+ y2 = min(y1 + tile_h, h)
189
+ for col in range(cols):
190
+ x1 = max(min(int(col * dx), w - tile_w), 0)
191
+ x2 = min(x1 + tile_w, w)
192
+ tile = image.crop((x1, y1, x2, y2))
193
+ row_images.append((x1, tile_w, tile))
194
+ grid.tiles.append((y1, tile_h, row_images))
195
+
196
+ return grid
197
+
198
+
199
+ # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
200
+ def combine_grid(grid: Grid):
201
+ def make_mask_image(r: npt.NDArray[np.float32]) -> Image.Image:
202
+ r = r * 255 / grid.overlap
203
+ return Image.fromarray(r.astype(np.uint8), "L")
204
+
205
+ mask_w = make_mask_image(
206
+ np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
207
+ )
208
+ mask_h = make_mask_image(
209
+ np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)
210
+ )
211
+
212
+ combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
213
+ for y, h, row in grid.tiles:
214
+ combined_row = Image.new("RGB", (grid.image_w, h))
215
+ for x, w, tile in row:
216
+ if x == 0:
217
+ combined_row.paste(tile, (0, 0))
218
+ continue
219
+
220
+ combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
221
+ combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
222
+
223
+ if y == 0:
224
+ combined_image.paste(combined_row, (0, 0))
225
+ continue
226
+
227
+ combined_image.paste(
228
+ combined_row.crop((0, 0, combined_row.width, grid.overlap)),
229
+ (0, y),
230
+ mask=mask_h,
231
+ )
232
+ combined_image.paste(
233
+ combined_row.crop((0, grid.overlap, combined_row.width, h)),
234
+ (0, y + grid.overlap),
235
+ )
236
+
237
+ return combined_image
238
+
239
+
240
+ class UpscalerESRGAN:
241
+ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
242
+ self.model_path = model_path
243
+ self.device = device
244
+ self.model = self.load_model(model_path)
245
+ self.to(device, dtype)
246
+
247
+ def __call__(self, img: Image.Image) -> Image.Image:
248
+ return self.upscale_without_tiling(img)
249
+
250
+ def to(self, device: torch.device, dtype: torch.dtype):
251
+ self.device = device
252
+ self.dtype = dtype
253
+ self.model.to(device=device, dtype=dtype)
254
+
255
+ def load_model(self, path: Path) -> RRDBNet:
256
+ filename = path
257
+ state_dict: dict[str, torch.Tensor] = torch.load(filename, weights_only=True, map_location=self.device) # type: ignore
258
+ in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
259
+ assert upscale == 4, "Only 4x upscaling is supported"
260
+ model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
261
+ model.load_state_dict(state_dict)
262
+ model.eval()
263
+
264
+ return model
265
+
266
+ def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
267
+ img_np = np.array(img)
268
+ img_np = img_np[:, :, ::-1]
269
+ img_np = np.ascontiguousarray(np.transpose(img_np, (2, 0, 1))) / 255
270
+ img_t = torch.from_numpy(img_np).float() # type: ignore
271
+ img_t = img_t.unsqueeze(0).to(device=self.device, dtype=self.dtype)
272
+ with torch.no_grad():
273
+ output = self.model(img_t)
274
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
275
+ output = 255.0 * np.moveaxis(output, 0, 2)
276
+ output = output.astype(np.uint8)
277
+ output = output[:, :, ::-1]
278
+ return Image.fromarray(output, "RGB")
279
+
280
+ # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
281
+ def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
282
+ img = img.convert("RGB")
283
+ grid = split_grid(img)
284
+ newtiles: Tiles = []
285
+ scale_factor: int = 1
286
+
287
+ for y, h, row in grid.tiles:
288
+ newrow: list[Tile] = []
289
+ for tiledata in row:
290
+ x, w, tile = tiledata
291
+ output = self.upscale_without_tiling(tile)
292
+ scale_factor = output.width // tile.width
293
+ newrow.append((x * scale_factor, w * scale_factor, output))
294
+ newtiles.append((y * scale_factor, h * scale_factor, newrow))
295
+
296
+ newgrid = Grid(
297
+ newtiles,
298
+ grid.tile_w * scale_factor,
299
+ grid.tile_h * scale_factor,
300
+ grid.image_w * scale_factor,
301
+ grid.image_h * scale_factor,
302
+ grid.overlap * scale_factor,
303
+ )
304
+ output = combine_grid(newgrid)
305
+ return output
inference.py CHANGED
@@ -1,11 +1,11 @@
1
-
2
  from pathlib import Path
3
  import torch
4
  from PIL import Image
5
- import base64
6
- import io
7
  from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints
8
 
 
9
  checkpoints = ESRGANUpscalerCheckpoints(
10
  esrgan=Path("checkpoints/4x-UltraSharp.pth")
11
  )
@@ -13,30 +13,34 @@ checkpoints = ESRGANUpscalerCheckpoints(
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
15
 
16
- enhancer = ESRGANUpscaler(
17
- checkpoints=checkpoints,
18
- device=device,
19
- dtype=dtype
20
- )
21
 
 
22
  def inference(inputs: dict) -> dict:
 
 
 
 
 
 
23
  if "image" not in inputs:
24
  return {"error": "No image provided"}
25
 
26
- image_data = inputs["image"]
27
- if image_data.startswith("data:image"):
28
- image_data = image_data.split(",")[1]
29
- image_bytes = base64.b64decode(image_data)
30
- input_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
31
 
32
- enhanced_image = enhancer.upscale(input_image)
 
33
 
34
  buf = io.BytesIO()
35
- enhanced_image.save(buf, format="PNG")
36
- b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
37
 
38
  return {
39
- "enhanced_image": b64,
40
- "original_size": input_image.size,
41
- "enhanced_size": enhanced_image.size
42
  }
 
 
1
  from pathlib import Path
2
  import torch
3
  from PIL import Image
4
+ import base64, io
5
+
6
  from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints
7
 
8
+ # -------- initialise model once at cold-start --------
9
  checkpoints = ESRGANUpscalerCheckpoints(
10
  esrgan=Path("checkpoints/4x-UltraSharp.pth")
11
  )
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
15
 
16
+ enhancer = ESRGANUpscaler(checkpoints=checkpoints, device=device, dtype=dtype)
 
 
 
 
17
 
18
+ # -------- entry-point for Hugging Face Hosted Inference API --------
19
  def inference(inputs: dict) -> dict:
20
+ """
21
+ Expected payload:
22
+ {"image": "<BASE64-STRING>"}
23
+ Returns:
24
+ { "enhanced_image": "<BASE64-PNG>", "original_size": [w,h], "enhanced_size": [w,h] }
25
+ """
26
  if "image" not in inputs:
27
  return {"error": "No image provided"}
28
 
29
+ # decode base64
30
+ data = inputs["image"]
31
+ if data.startswith("data:image"):
32
+ data = data.split(",")[1]
33
+ img = Image.open(io.BytesIO(base64.b64decode(data))).convert("RGB")
34
 
35
+ # run ESRGAN ➜ SD-1.5 upscale
36
+ result = enhancer.upscale(img)
37
 
38
  buf = io.BytesIO()
39
+ result.save(buf, format="PNG")
40
+ result_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
41
 
42
  return {
43
+ "enhanced_image": result_b64,
44
+ "original_size": img.size,
45
+ "enhanced_size": result.size,
46
  }
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
-
2
  git+https://github.com/finegrain-ai/refiners@cfe8b66ba4f8a906583850ac25e9e89cb83a44b9
3
  numpy<2.0.0
4
  pillow>=10.4.0
 
 
1
  git+https://github.com/finegrain-ai/refiners@cfe8b66ba4f8a906583850ac25e9e89cb83a44b9
2
  numpy<2.0.0
3
  pillow>=10.4.0