jadechoghari committed on
Commit
7b5beb5
1 Parent(s): 7c8fd9c

Create controlnet_utils.py

Files changed (1)
  1. controlnet_utils.py +99 -0
controlnet_utils.py ADDED
@@ -0,0 +1,99 @@
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+ import transformers
+ from controlnet_aux.processor import Processor
+
+ # Hugging Face Hub repo IDs for the Stable Diffusion 1.5 ControlNet v1.1
+ # checkpoints, keyed by control type.
+ CONTROLNET_DICT = {
+     "tile": "lllyasviel/control_v11f1e_sd15_tile",
+     "ip2p": "lllyasviel/control_v11e_sd15_ip2p",
+     "openpose": "lllyasviel/control_v11p_sd15_openpose",
+     "softedge": "lllyasviel/control_v11p_sd15_softedge",
+     "depth": "lllyasviel/control_v11f1p_sd15_depth",
+     "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime",
+     "canny": "lllyasviel/control_v11p_sd15_canny",
+ }
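+
+ # Hedged usage sketch, not exercised by this module: a repo ID from
+ # CONTROLNET_DICT is what you would hand to diffusers when building a
+ # ControlNet pipeline, e.g.:
+ #     from diffusers import ControlNetModel
+ #     controlnet = ControlNetModel.from_pretrained(CONTROLNET_DICT["canny"])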
+
+ # Cache of instantiated controlnet_aux processors, keyed by processor id,
+ # so repeated calls do not reload model weights.
+ processor_cache = dict()
+
+ def process(image, processor_id):
+     # Run a controlnet_aux processor over a batch of images given as a
+     # (B, C, H, W) float tensor in [0, 1]; returns a tensor of the same shape.
+     process_ls = []
+     H, W = image.shape[2:]
+     if processor_id in processor_cache:
+         processor = processor_cache[processor_id]
+     else:
+         processor = Processor(processor_id, {"output_type": "numpy"})
+         processor_cache[processor_id] = processor
+     for img in image:
+         # (C, H, W) float in [0, 1] -> (H, W, C) uint8 array for the processor.
+         img = (img.clone().cpu().permute(1, 2, 0) * 255).numpy().astype(np.uint8)
+         processed_image = processor(img)
+         # Processors may change the resolution; resize back to the input size.
+         processed_image = cv2.resize(processed_image, (W, H), interpolation=cv2.INTER_LINEAR)
+         processed_image = torch.tensor(processed_image).to(image).permute(2, 0, 1) / 255
+         process_ls.append(processed_image)
+     return torch.stack(process_ls)
+
+ def tile_preprocess(image, resample_rate=1.0, **kwargs):
+     # Downsample then upsample back to the original size, so the "tile"
+     # ControlNet is conditioned on a low-detail version of the input.
+     cond_image = F.interpolate(image, scale_factor=resample_rate, mode="bilinear")
+     cond_image = F.interpolate(cond_image, scale_factor=1 / resample_rate)
+     return cond_image
+
+ def ip2p_preprocess(image, **kwargs):
+     # InstructPix2Pix conditions directly on the unmodified input image.
+     return image
+
+ def openpose_preprocess(image, **kwargs):
+     processor_id = 'openpose'
+     return process(image, processor_id)
+
+ def softedge_preprocess(image, proc="pidsafe", **kwargs):
+     processor_id = f'softedge_{proc}'
+     return process(image, processor_id)
+
+ def depth_preprocess(image, **kwargs):
+     # Estimate a depth map per frame and replicate it to three channels.
+     # Note: this builds a new depth-estimation pipeline on every call.
+     image_ls = [T.ToPILImage()(img) for img in image]
+     depth_estimator = transformers.pipeline('depth-estimation')
+     ret = depth_estimator(image_ls)
+     depth_ls = [T.ToTensor()(r['depth']) for r in ret]
+     depth = torch.cat(depth_ls)
+     depth = torch.stack([depth, depth, depth], dim=1)
+     return depth
+
+ def lineart_anime_preprocess(image, proc="anime", **kwargs):
+     processor_id = f'lineart_{proc}'
+     return process(image, processor_id)
+
+ def canny_preprocess(image, **kwargs):
+     processor_id = 'canny'
+     return process(image, processor_id)
+
+ # Dispatch table mapping each control type to its preprocessing function.
+ PREPROCESS_DICT = {
+     "tile": tile_preprocess,
+     "ip2p": ip2p_preprocess,
+     "openpose": openpose_preprocess,
+     "softedge": softedge_preprocess,
+     "depth": depth_preprocess,
+     "lineart_anime": lineart_anime_preprocess,
+     "canny": canny_preprocess,
+ }
+
+ def control_preprocess(images, control_type, **kwargs):
+     # Entry point: turn a batch of images into the control signal for `control_type`.
+     return PREPROCESS_DICT[control_type](images, **kwargs)
+
+ def empty_cache():
+     # Drop cached processors and release any GPU memory they held.
+     global processor_cache
+     processor_cache = dict()
+     torch.cuda.empty_cache()
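+
+ if __name__ == "__main__":
+     # Minimal usage sketch, not part of the original commit: run the canny
+     # preprocessor on a random batch just to exercise the dispatch path.
+     dummy = torch.rand(2, 3, 512, 512)
+     cond = control_preprocess(dummy, "canny")
+     print(cond.shape)  # expected: torch.Size([2, 3, 512, 512])
+     empty_cache()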