Init code and weights
- model.py +464 -0
- pipeline_drum.py +1 -0
- sampling.py +34 -0
- weight/H.safetensors +3 -0
- weight/L.safetensors +3 -0
- weight/T5.safetensors +3 -0
- weight/bigG.safetensors +3 -0
- wrapper.py +304 -0
model.py
ADDED
@@ -0,0 +1,464 @@
import math
import numpy as np
import torch

class MultiheadAttention(torch.nn.Module):
    def __init__(self, d_model, n_head, n_token = 77, dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.n_token = n_token

        self.query = torch.nn.Linear(d_model, d_model)
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.proj = torch.nn.Linear(d_model, d_model)

        self.div = torch.sqrt(torch.tensor(self.d_head, dtype = self.query.weight.dtype))

        self.softmax = torch.nn.Softmax(dim = -1)
        self.dropout = torch.nn.Dropout(dropout)

        self._reset_parameters()

    def _reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.query.weight)
        torch.nn.init.xavier_uniform_(self.key.weight)
        torch.nn.init.xavier_uniform_(self.value.weight)
        torch.nn.init.xavier_uniform_(self.proj.weight)

        torch.nn.init.constant_(self.query.bias, 0.)
        torch.nn.init.constant_(self.key.bias, 0.)
        torch.nn.init.constant_(self.value.bias, 0.)
        torch.nn.init.constant_(self.proj.bias, 0.)

    def forward(self, q, k, v, mask = None, weight = None, alpha = None):
        b, s = q.shape[:2]
        b2, s2 = k.shape[:2]

        q = self.query(q) #b, s, f
        k = self.key(k) #b, s, f
        v = self.value(v) #b, s, f

        q = q.view(-1, s, self.n_head, self.d_head).transpose(1, 2) #b, h, s, hf
        k = k.view(-1, s2, self.n_head, self.d_head).transpose(1, 2) #b, h, s, hf
        v = v.view(-1, s2, self.n_head, self.d_head).transpose(1, 2) #b, h, s, hf

        score = torch.matmul(q, k.transpose(-2, -1)) / self.div #b, h, s, s

        if mask is not None:
            mask = mask.unsqueeze(1) #b, 1, s
            if mask.dim() != score.dim():
                mask = mask.unsqueeze(2) #b, 1, 1, s
            score = score * mask

        if weight is not None:
            weight = weight.unsqueeze(1) #b, 1, s
            if weight.dim() != score.dim():
                weight = weight.unsqueeze(2) #b, 1, 1, s
        if self.n_token == s2:
            w = self.softmax(score) #b, h, s, s2
            if weight is not None:
                w = w * weight
                w = w / (w.sum(dim = -1, keepdim = True) + 1e-12)
        else:
            #context longer than n_token: split the target-prompt scores from the reference scores and mix them with alpha
            target, ref = torch.split(score, [self.n_token, s2 - self.n_token], dim = -1)
            target = self.softmax(target)
            if alpha is None:
                alpha = 0.5
            if weight is not None:
                ws = weight.shape[-1]
                target_weight, ref_weight = torch.split(weight, [self.n_token, ws - self.n_token], dim = -1)
                ref = ref.view(b2, self.n_head, s, ws - self.n_token, self.n_token)
                ref = self.softmax(ref)
                ref = ref * ref_weight.unsqueeze(-1)
                ref = ref.view(b2, self.n_head, s, s2 - self.n_token)
                ref = alpha * (ref / (ref.sum(dim = -1, keepdim = True) + 1e-12))
                target = target * (1 - alpha) * target_weight
            w = torch.cat([target, ref], dim = -1)
            w = w / (w.sum(dim = -1, keepdim = True) + 1e-12)
        w = self.dropout(w)

        out = torch.matmul(w, v) #b, h, s, hf
        out = out.transpose(1, 2).contiguous().view(b, s, self.d_model) #b, s, d
        out = self.proj(out)
        return out

class QuickGELU(torch.nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)

class TransformerBlock(torch.nn.Module):
    def __init__(self, emb_dim, n_head, ff_dim, n_token = 77, activation = "quick_gelu", dropout = 0.1):
        super().__init__()
        self.attn = MultiheadAttention(emb_dim, n_head, n_token = n_token, dropout = dropout)
        if activation is None or activation.lower() == "gelu":
            self.act = torch.nn.GELU()
        elif activation.lower() == "relu":
            self.act = torch.nn.ReLU()
        elif activation.lower() == "quick_gelu":
            self.act = QuickGELU()
        else:
            self.act = activation
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, ff_dim),
            self.act,
            torch.nn.Linear(ff_dim, emb_dim),
        )
        self.norm1 = torch.nn.LayerNorm(emb_dim)
        self.norm2 = torch.nn.LayerNorm(emb_dim)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        self._reset_parameters()

    def _reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.ff[0].weight)
        torch.nn.init.xavier_uniform_(self.ff[2].weight)

        torch.nn.init.constant_(self.ff[0].bias, 0.)
        torch.nn.init.constant_(self.ff[2].bias, 0.)

    def forward(self, x, context = None, mask = None, weight = None, alpha = None):
        context = context if context is not None else x
        out = self.attn(x, context, context, mask = mask, weight = weight, alpha = alpha)
        out = x + self.dropout1(out)
        out = self.norm1(out)

        ff_out = self.ff(out)
        out = out + self.dropout2(ff_out)
        out = self.norm2(out)
        return out

class PersonalizedAdapter(torch.nn.Module):
    def __init__(self, emb_dim, n_head, ff_dim, n_layer = 4, n_token = 77, proj = False, extra_proj = False, pos = True, cls_pos = False, cls_token = True, encode_ratio = None, activation = "quick_gelu", dropout = 0.1):
        super().__init__()
        self.n_layer = n_layer
        self.n_token = n_token
        self.cls_pos = cls_pos
        self.cls_token = cls_token
        self.encode_ratio = encode_ratio

        self.pre_proj = self.post_proj = None
        if encode_ratio and encode_ratio != 1:
            self.pre_proj = torch.nn.Linear(emb_dim, int(emb_dim // encode_ratio))
            self.post_proj = torch.nn.Linear(int(emb_dim // encode_ratio), emb_dim)
            emb_dim = int(emb_dim // encode_ratio)
            n_head = int(n_head // encode_ratio)

        if activation is None or activation.lower() == "gelu":
            self.act = torch.nn.GELU()
        elif activation.lower() == "relu":
            self.act = torch.nn.ReLU()
        elif activation.lower() == "quick_gelu":
            self.act = QuickGELU()
        else:
            self.act = activation
        self.base_query = torch.nn.Parameter(torch.empty(1, n_token + int(cls_token), emb_dim))
        self.pos = torch.nn.Parameter(torch.empty(1, n_token + int(cls_pos and cls_token), emb_dim)) if pos else None
        self.init_query = None

        self.proj = None
        if proj:
            self.proj = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, ff_dim),
                self.act,
                torch.nn.Linear(ff_dim, emb_dim),
            )

        self.extra_proj = None
        self.tf = torch.nn.ModuleList([TransformerBlock(emb_dim, n_head, ff_dim, n_token = n_token, activation = activation, dropout = dropout) for _ in range(n_layer)])
        if extra_proj:
            self.extra_proj = torch.nn.ModuleList([torch.nn.Linear(emb_dim, emb_dim) for _ in range(n_layer)])

        self._reset_parameters()

    def _reset_parameters(self):
        torch.nn.init.normal_(self.base_query, std = 0.02)
        if self.pos is not None:
            torch.nn.init.normal_(self.pos, std = 0.01)

        for proj in [self.pre_proj, self.post_proj]:
            if proj is not None:
                torch.nn.init.xavier_uniform_(proj.weight)
                torch.nn.init.constant_(proj.bias, 0.)
        for proj in [self.proj]:
            if proj is not None:
                torch.nn.init.xavier_uniform_(proj[0].weight)
                torch.nn.init.xavier_uniform_(proj[2].weight)

                torch.nn.init.constant_(proj[0].bias, 0.)
                torch.nn.init.constant_(proj[2].bias, 0.)
        if self.extra_proj is not None:
            for l in self.extra_proj:
                torch.nn.init.xavier_uniform_(l.weight)
                torch.nn.init.constant_(l.bias, 0.)

    def set_base_query(self, x):
        if not torch.is_tensor(x):
            x = torch.tensor(x, dtype = self.base_query.dtype).to(self.base_query.device)
        if x.dim() == 2:
            x = x.unsqueeze(0)
        self.init_query = x

    def normal_forward(self, x, context, mask = None, weight = None, alpha = None):
        out = x
        for i in range(self.n_layer):
            if self.extra_proj is not None:
                _context = self.extra_proj[i](self.act(context))
            else:
                _context = context
            out = self.tf[i](out, _context, mask = mask, weight = weight, alpha = alpha) #n, b, f
        if self.cls_token:
            return out[:, :-1], out[:, -1]
        else:
            return out, None

    def forward(self, context, mask = None, weight = None, alpha = None, base_query = None):
        dtype = self.base_query.dtype
        if base_query is not None:
            x = base_query
        else:
            x = self.base_query if self.init_query is None else self.init_query
        x = x.type(dtype)
        if context is not None:
            context = context.type(dtype)
        if weight is not None:
            weight = weight.type(dtype)
        if self.encode_ratio is not None and x.shape[-1] != self.base_query.shape[-1]:
            x = self.pre_proj(x)
        if self.n_token < x.shape[1]:
            x, cls = x[:, :self.n_token], x[:, self.n_token:]
        else:
            cls = self.base_query[:, self.n_token:] if self.cls_token else None
        if self.pos is not None:
            if self.cls_pos and self.cls_token:
                x = x + self.pos[:, :self.n_token]
                if cls is not None:
                    cls = cls + self.pos[:, self.n_token:]
            else:
                x = x + self.pos
        if self.cls_token:
            x = torch.cat([x, cls], dim = 1)
        x = x.repeat_interleave(context.shape[0], dim = 0)
        if self.encode_ratio is not None:
            if context is not None:
                context = self.pre_proj(context)
        if self.proj is not None:
            context = self.proj(context)
        out = self.normal_forward(x, context, mask = mask, weight = weight, alpha = alpha)
        if self.encode_ratio is not None:
            out = (self.post_proj(out[0]), self.post_proj(out[1]) if out[1] is not None else out[1])
        return out

#wrapper that pairs a frozen text encoder (CLIP or T5) with a trainable PersonalizedAdapter
class DrUM:
    def __init__(self, model, processor, n_layer = 8, proj = False, extra_proj = False, mlp_ratio = 4, pos = True, cls_pos = False, cls_token = True, encode_ratio = None, max_token_size = 256, activation = "quick_gelu", dropout = 0.1):
        config = model.config.text_config if hasattr(model.config, "text_config") else model.config
        if hasattr(config, "model_type") and config.model_type == "t5":
            self.d_model = config.d_model
            self.n_head = config.num_heads
            self.n_token = min(processor.model_max_length, max_token_size)
            self.clip = False
            self.cls_token = False
        else:
            self.d_model = config.hidden_size
            self.n_head = config.num_attention_heads
            self.n_token = config.max_position_embeddings
            self.clip = True
            self.cls_token = cls_token
        self.n_layer = n_layer
        self.proj = proj
        self.extra_proj = extra_proj
        self.mlp_ratio = mlp_ratio
        self.pos = pos
        self.cls_pos = cls_pos
        self.encode_ratio = encode_ratio
        self.activation = activation
        self.dropout = dropout

        self.model = model
        self.processor = processor
        self.adapter = PersonalizedAdapter(self.d_model, self.n_head, self.d_model // mlp_ratio, n_layer, self.n_token, proj = proj, extra_proj = extra_proj, pos = pos, cls_pos = cls_pos, cls_token = self.cls_token, encode_ratio = encode_ratio, activation = activation, dropout = dropout).to(model.device)

        self.train()
        self.to(model.device)

    def preprocess(self, text = None, image = None, return_tensors = "pt", padding = "max_length", truncation = True, **kwargs):
        feed = {"text":([text] if np.ndim(text) == 0 else list(text)) if text is not None else None,
                "return_tensors":return_tensors,
                "max_length":self.n_token,
                "padding":padding,
                "truncation":truncation,
                **kwargs}
        if not self.clip:
            feed["add_special_tokens"] = True
        if image is not None:
            feed["images"] = image
        return self.processor(**feed)

    def pool_text_hidden_state(self, hidden_state, x, padding = "max_length", truncation = True, **kwargs):
        if not self.clip:
            raise TypeError("T5 encoder does not support this function (pool_text_hidden_state).")
        if not hasattr(x, "items"):
            x = self.preprocess(text = x, padding = padding, truncation = truncation, **kwargs)
        if self.model.text_model.eos_token_id == 2:
            out = hidden_state[torch.arange(hidden_state.shape[0], device = hidden_state.device),
                               x["input_ids"].to(dtype = torch.int, device = hidden_state.device).argmax(dim = -1),]
        else:
            out = hidden_state[torch.arange(hidden_state.shape[0], device = hidden_state.device),
                               (x["input_ids"].to(dtype = torch.int, device = hidden_state.device) == self.model.text_model.eos_token_id).int().argmax(dim = -1),]
        return out

    def normalize_text_hidden_state(self, hidden_state):
        out = self.model.text_model.final_layer_norm(hidden_state.type(self.model.dtype)) if self.clip and hasattr(self.model.text_model, "final_layer_norm") else hidden_state
        return out

    def projection_text_hidden_state(self, hidden_state):
        out = self.model.text_projection(hidden_state.type(self.model.dtype)) if self.clip and hasattr(self.model, "text_projection") else hidden_state
        return out

    def encode_prompt(self, x, pooling = True, skip = -1, skip_pool = None, padding = "max_length", truncation = True, use_attn_mask = False, normalize = True, normalize_pool = True, **kwargs):
        if not hasattr(x, "items"):
            x = self.preprocess(text = x, padding = padding, truncation = truncation, **kwargs)
        input_ids = x["input_ids"].to(self.device)
        attention_mask = x["attention_mask"].to(self.device) if use_attn_mask else None
        with torch.no_grad():
            if self.clip:
                hidden_state = self.model.text_model(output_hidden_states = True, input_ids = input_ids, attention_mask = attention_mask)["hidden_states"]
                pool, hidden_state = hidden_state[skip_pool if skip_pool is not None else skip], hidden_state[skip]
                hidden_state = self.normalize_text_hidden_state(hidden_state) if normalize else hidden_state
            else:
                hidden_state = self.model(input_ids = input_ids, attention_mask = attention_mask)[0]
                pool = None
        if pooling:
            if self.clip:
                with torch.no_grad():
                    pool = self.pool_text_hidden_state(self.normalize_text_hidden_state(pool) if normalize_pool else pool, x, **kwargs)
            return (hidden_state, pool)
        return hidden_state

    def get_text_feature(self, x, ref_x = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, padding = "max_length", truncation = True, use_attn_mask = False, **kwargs):
        if not self.clip:
            raise TypeError("T5 encoder does not support this function (get_text_feature).")
        with torch.no_grad():
            pool_hidden_state = self(x, ref_x, weight = weight, alpha = alpha, pooling = True, skip_pool = skip, batch_size = batch_size, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize_pool = True, **kwargs)[1]
            result = self.projection_text_hidden_state(pool_hidden_state)
        return result

    def get_image_feature(self, x, return_tensors = "pt", **kwargs):
        if not self.clip:
            raise TypeError("T5 encoder does not support this function (get_image_feature).")
        if hasattr(x, "items"):
            x = x["pixel_values"]
        elif not torch.is_tensor(x):
            x = self.preprocess(image = x, return_tensors = return_tensors, **kwargs)["pixel_values"]
        with torch.no_grad():
            result = self.model.get_image_features(pixel_values = x.to(self.device))
        return result

    def encode_context(self, ref_x, pooling = False, skip = -1, skip_pool = None, batch_size = 64, padding = "max_length", truncation = True, use_attn_mask = False, normalize = False, normalize_pool = False, **kwargs):
        if not hasattr(ref_x, "items"):
            if np.ndim(ref_x) == 0:
                ref_x = [[ref_x]]
            elif np.ndim(ref_x) == 1:
                ref_x = [ref_x]
            b, ref_size = len(ref_x), len(ref_x[0])
            ref_x = np.reshape(ref_x, [b * ref_size])
            ref_x = self.preprocess(text = list(ref_x), padding = padding, truncation = truncation, **kwargs)
            ref_x = {k:v for k, v in ref_x.items() if k in (["input_ids", "attention_mask"] if use_attn_mask else ["input_ids"])}
        else:
            b, ref_size = ref_x["input_ids"].shape[:2]
            ref_x = {k:v.view(b * ref_size, -1) for k, v in ref_x.items() if k in (["input_ids", "attention_mask"] if use_attn_mask else ["input_ids"])}
        hidden_state, pool_hidden_state = [], []
        batch_indices = [(i * batch_size, min((b * ref_size), (i + 1) * batch_size)) for i in range(int(np.ceil((b * ref_size) / batch_size)))]
        for start, end in batch_indices:
            h, p = self.encode_prompt({k:v[start:end] for k, v in ref_x.items()}, pooling = True, skip = skip, skip_pool = skip_pool, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = normalize, normalize_pool = normalize_pool, **kwargs)
            hidden_state.append(h)
            if p is not None:
                pool_hidden_state.append(p)
        hidden_state = torch.cat(hidden_state, dim = 0) if 1 < len(hidden_state) else hidden_state[0]
        pool_hidden_state = torch.cat(pool_hidden_state, dim = 0) if 1 < len(pool_hidden_state) else (pool_hidden_state[0] if len(pool_hidden_state) == 1 else None)
        with torch.no_grad():
            hidden_state = hidden_state.view(b, ref_size * hidden_state.shape[1], -1)
        if pooling:
            if self.clip:
                pool_hidden_state = pool_hidden_state.view(b, ref_size, -1)
            hidden_state = (hidden_state, pool_hidden_state)
        return hidden_state

    def __call__(self, x, ref_x = None, weight = None, alpha = 0.3, pooling = True, skip = -1, skip_pool = None, batch_size = 64, padding = "max_length", truncation = True, use_attn_mask = False, normalize = True, normalize_pool = True, training = False, **kwargs):
        if ref_x is not None or training:
            if training:
                context = weight = None
            else:
                _context, _context_pool = self.encode_context(ref_x, pooling = True, skip = skip, skip_pool = None, batch_size = batch_size, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = False, normalize_pool = False, **kwargs)
                if weight is not None:
                    if not torch.is_tensor(weight):
                        weight = torch.tensor(weight)
                    if weight.dim() == 0:
                        weight = weight.unsqueeze(0).unsqueeze(0)
                    elif weight.dim() == 1:
                        weight = weight.unsqueeze(0)
                    weight = weight.to(self.device)
                else:
                    weight = torch.ones((1, _context.shape[1] // self.n_token), dtype = torch.float32, device = _context.device)
                context = _context
                del _context, _context_pool
            result = self.encode_personalized_prompt(x, context, weight = weight, alpha = alpha, pooling = pooling, skip = skip, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = normalize, normalize_pool = normalize_pool, **kwargs)
            return result
        else:
            return self.encode_prompt(x, pooling = pooling, skip = skip, skip_pool = skip_pool, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = normalize, normalize_pool = normalize_pool, **kwargs)

    def encode_personalized_prompt(self, x, context = None, weight = None, alpha = 0.3, pooling = True, skip = -1, padding = "max_length", truncation = True, use_attn_mask = False, normalize = True, normalize_pool = True, **kwargs):
        if not torch.is_tensor(x):
            if not hasattr(x, "items"):
                x = self.preprocess(text = x, padding = padding, truncation = truncation, **kwargs)
            x = self.encode_prompt(x, pooling = False, skip = skip, skip_pool = None, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = False, normalize_pool = False, **kwargs)
        if context is None:
            context = x
        else:
            batch_size, n_token = x.shape[:2]
            if context.shape[0] == 1 and batch_size != 1:
                context = context.repeat_interleave(batch_size, dim = 0)
            if weight is not None and weight.shape[0] == 1:
                weight = weight.repeat_interleave(batch_size, dim = 0)
            context_size = context.shape[1]
            context = torch.cat([x, context], dim = 1)
            if weight is not None:
                extra_weight = torch.ones((batch_size, n_token), dtype = torch.float32, device = weight.device)
                weight = torch.cat([extra_weight, weight], dim = 1)
        hidden_state, pool = self.adapter(context, weight = weight, alpha = alpha)
        hidden_state = self.normalize_text_hidden_state(hidden_state) if normalize else hidden_state
        if pooling:
            pool = self.normalize_text_hidden_state(pool) if normalize_pool else pool
            return (hidden_state, pool)
        return hidden_state

    def to(self, device):
        self.model.to(device)
        self.adapter.to(device)
        self.device = device
        return self

    def eval(self):
        self.model.eval()
        if self.clip and hasattr(self.model, "text_projection"):
            self.model.text_model.final_layer_norm.requires_grad_(False)
            self.model.text_projection.requires_grad_(False)
        self.adapter.eval()
        return self

    def train(self):
        self.model.eval()
        if self.clip and hasattr(self.model, "text_projection"):
            self.model.text_model.final_layer_norm.requires_grad_(False)
            self.model.text_projection.requires_grad_(False)
        self.adapter.train()
        return self

    def parameters(self):
        return list(self.adapter.parameters())

    def named_parameters(self):
        return list(self.adapter.named_parameters())
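The DrUM class above wraps a frozen text encoder with a trainable PersonalizedAdapter that blends a target prompt with weighted reference prompts. A minimal usage sketch, assuming transformers is installed and the openai/clip-vit-large-patch14 checkpoint is reachable; the flat import path, the prompts, and the weights are illustrative assumptions, and the adapter here is randomly initialized unless a state dict from the weight/ files is loaded:

#Sketch only: untrained adapter, flat "model.py" import assumed (no package)
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from model import DrUM

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

drum = DrUM(text_encoder, tokenizer, n_layer = 10, pos = False, cls_pos = False, dropout = 0.0).eval()

#blend one target prompt with two weighted reference prompts
with torch.no_grad():
    hidden_state, pool = drum("a photo of a cat",
                              ref_x = ["a watercolor painting", "soft pastel colors"],
                              weight = [1.0, 0.5], alpha = 0.3, pooling = True)
print(hidden_state.shape, pool.shape) #torch.Size([1, 77, 768]) torch.Size([1, 768])

The returned pair has the same shapes as the base CLIP text encoder's sequence and pooled outputs, so it can stand in for them downstream.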
pipeline_drum.py
ADDED
@@ -0,0 +1 @@
from .wrapper import DrUM
sampling.py
ADDED
@@ -0,0 +1,34 @@
import numpy as np
import torch

def clip_score(feature, ref_feature, logit_scale = 100.0, weight = 1, reduce = True):
    ref_feature = np.expand_dims(ref_feature, axis = 0) if np.ndim(ref_feature) == 2 else ref_feature
    batch_size, ref_size = np.shape(ref_feature)[:2]
    feature = feature / np.linalg.norm(feature, axis = -1, keepdims = True)
    ref_feature = ref_feature / np.linalg.norm(ref_feature, axis = -1, keepdims = True)
    sim = logit_scale * np.einsum("bf,btf->bt", feature, ref_feature)
    sim = sim * (np.expand_dims(weight, axis = 0) if np.ndim(weight) == 1 else weight)
    return sim.mean(axis = 1) if reduce else (sim[..., 0] if ref_size == 1 else sim)

def coreset_sampling(data, n_sample = 0.1, weight = 1, n_approximate = 10, logit_scale = 100, seed = 42):
    data = np.array(data) if not isinstance(data, np.ndarray) else data
    n_sample = round(len(data) * n_sample) if isinstance(n_sample, float) or (isinstance(n_sample, int) and n_sample < 1) else n_sample
    n_sample = max(min(n_sample, len(data)), 1 if len(data) != 0 else 0)
    weight = 1 if weight is None else weight
    weight = np.transpose(weight) if np.ndim(weight) == 2 else (np.expand_dims(weight, axis = -1) if np.ndim(weight) == 1 else weight)

    random = ((np.random.RandomState(seed) if isinstance(seed, int) else seed) if seed is not None else np.random)
    if n_sample == len(data):
        indices = np.arange(n_sample)
    else:
        indices = []
        approx_data = data[random.choice(len(data), min(round(len(data) * n_approximate) if isinstance(n_approximate, float) else n_approximate, len(data)), replace = False)]
        dist = clip_score(data, approx_data, weight = weight, logit_scale = logit_scale, reduce = False)
        dist = np.mean(dist, axis = 1, keepdims = True)
        for i in range(n_sample):
            sample_index = np.argmax(dist)
            indices.append(sample_index)
            sample_dist = clip_score(data, data[[sample_index]], weight = weight, logit_scale = logit_scale, reduce = False)
            dist = np.minimum(dist, sample_dist.reshape(dist.shape)) #keep dist as an (n, 1) column so argmax/indexing stay consistent
            dist[sample_index] = -np.inf
    return indices
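coreset_sampling greedily selects a diverse subset of reference features using the CLIP-style similarity computed by clip_score. A small self-contained sketch with random stand-in features (the array shapes and the flat import path are illustrative assumptions):

#Sketch only: random unit-norm vectors stand in for real reference CLIP features
import numpy as np

from sampling import coreset_sampling

rng = np.random.RandomState(0)
features = rng.randn(100, 768).astype(np.float32)               #100 reference features, 768-dim
features /= np.linalg.norm(features, axis = -1, keepdims = True)

indices = coreset_sampling(features, n_sample = 0.1, seed = 42)  #keep roughly 10 diverse references
print(indices)

This is the routine the pipeline wrapper calls when sampling = True, so that only a representative subset of many reference prompts is fed to the adapter.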
weight/H.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3eb561647781fd66c1f84d2c8d5f15f71e5994e8d11c1e38ff7463c6fe84e554
size 189456448
weight/L.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d1e3ffcbe5e21f3e9796c8cdc0450ef90d7a927de242a6e3fe69531071705953
size 106706168
weight/T5.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d6d8413819cc4a2d388b7a0222eff8c161c7bcf459536db2ad10a403756643b
size 286706200
weight/bigG.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1c58a6c1a558599e090f7fe0eaada1e0684dafd472c0a18fbd4067450d54a606
size 295799424
wrapper.py
ADDED
@@ -0,0 +1,304 @@
import os

import torch

from diffusers import DiffusionPipeline, FluxPipeline, StableDiffusion3Pipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from .model import DrUM as backbone
from .sampling import coreset_sampling

def stable_diffusion(large):
    """
    openai/clip-vit-large-patch14, CLIPTextModel, skip -1
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, **kwargs):
        return large(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, **kwargs), None
    return inference

def stable_diffusion_v2(huge):
    """
    openai/clip-vit-huge-patch14, CLIPTextModel, skip -1
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, **kwargs):
        return huge(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, **kwargs), None
    return inference

def stable_diffusion_xl(large, bigG):
    """
    openai/clip-vit-large-patch14, CLIPTextModel, skip -2, unnorm
    laion/CLIP-ViT-bigG-14-laion2B-39B-b160k, CLIPTextModelWithProjection, skip -2, unnorm, pooling + proj
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -2, batch_size = 64, **kwargs):
        hidden_state = large(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
        if skip == -1:
            hidden_state2, pool_hidden_state = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)
        else:
            hidden_state2 = bigG(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
            pool_hidden_state = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
        hidden_state = torch.cat([hidden_state, hidden_state2], dim = -1)
        pool_hidden_state = bigG.projection_text_hidden_state(pool_hidden_state)
        return hidden_state.type(pool_hidden_state.dtype), pool_hidden_state
    return inference

def stable_diffusion_v3(large, bigG, t5):
    """
    openai/clip-vit-large-patch14, CLIPTextModelWithProjection, skip -2, unnorm, pooling + proj
    laion/CLIP-ViT-bigG-14-laion2B-39B-b160k, CLIPTextModelWithProjection, skip -2, unnorm, pooling + proj
    t5-v1_1-xxl, T5EncoderModel
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -2, batch_size = 64, **kwargs):
        if skip == -1:
            hidden_state, pool_hidden_state = large(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)
            hidden_state2, pool_hidden_state2 = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)
        else:
            hidden_state = large(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
            hidden_state2 = bigG(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
            pool_hidden_state = large(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
            pool_hidden_state2 = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
        hidden_state3 = t5(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, batch_size = batch_size, normalize = False, **kwargs)
        hidden_state = torch.cat([hidden_state, hidden_state2], dim = -1)
        pool_hidden_state = large.projection_text_hidden_state(pool_hidden_state)
        pool_hidden_state2 = bigG.projection_text_hidden_state(pool_hidden_state2)
        hidden_state = torch.nn.functional.pad(hidden_state, (0, hidden_state3.shape[-1] - hidden_state.shape[-1]))
        hidden_state = torch.cat([hidden_state, hidden_state3], dim = -2)
        pool_hidden_state = torch.cat([pool_hidden_state, pool_hidden_state2], dim = -1)
        return hidden_state.type(pool_hidden_state.dtype), pool_hidden_state
    return inference

def flux(large, t5):
    """
    openai/clip-vit-large-patch14, CLIPTextModel, pooling
    t5-v1_1-xxl, T5EncoderModel
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = None, batch_size = 64, **kwargs):
        hidden_state = t5(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, batch_size = batch_size, normalize = False, **kwargs)
        pool_hidden_state = large(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
        return hidden_state.type(pool_hidden_state.dtype), pool_hidden_state
    return inference

def peca(pipeline, save_path = "./weight", n_layer = 10):
    if os.path.exists(os.path.join(save_path, "L.pth")) or os.path.exists(os.path.join(save_path, "H.pth")):
        load_func = torch.load
        postfix = "pth"
    else:
        from safetensors.torch import load_file as load_func
        postfix = "safetensors"

    if "flux" in pipeline.config._name_or_path.split("/")[-1].lower():
        model = pipeline.text_encoder
        processor = pipeline.tokenizer
        model2 = pipeline.text_encoder_2
        processor2 = pipeline.tokenizer_2

        large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
        t5 = backbone(model2, processor2, n_layer = n_layer, encode_ratio = 4, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        t5.adapter.load_state_dict(load_func(os.path.join(save_path, "T5.{0}".format(postfix))))
        empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
        empty, pool = t5.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        t5.adapter.set_base_query(empty)

        feature_encoder = large
        encoder = flux(large, t5)
        size = 1024
        num_inference_steps = 28
        skip = -2
    elif "stable-diffusion-3.5" in pipeline.config._name_or_path.split("/")[-1].lower(): #sd v3
        model = pipeline.text_encoder
        processor = pipeline.tokenizer
        model2 = pipeline.text_encoder_2
        processor2 = pipeline.tokenizer_2
        model3 = pipeline.text_encoder_3
        processor3 = pipeline.tokenizer_3

        large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
        bigG = backbone(model2, processor2, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        bigG.adapter.load_state_dict(load_func(os.path.join(save_path, "bigG.{0}".format(postfix))))
        t5 = backbone(model3, processor3, n_layer = n_layer, encode_ratio = 4, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        t5.adapter.load_state_dict(load_func(os.path.join(save_path, "T5.{0}".format(postfix))))
        empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
        empty, pool = bigG.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        bigG.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
        empty, pool = t5.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        t5.adapter.set_base_query(empty)

        feature_encoder = large
        encoder = stable_diffusion_v3(large, bigG, t5)
        size = 1024
        num_inference_steps = 28
        skip = -2
    elif "xl-base" in pipeline.config._name_or_path.split("/")[-1].lower(): #sd xl
        model = pipeline.text_encoder
        processor = pipeline.tokenizer
        model2 = pipeline.text_encoder_2
        processor2 = pipeline.tokenizer_2

        large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
        bigG = backbone(model2, processor2, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        bigG.adapter.load_state_dict(load_func(os.path.join(save_path, "bigG.{0}".format(postfix))))
        empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
        empty, pool = bigG.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        bigG.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))

        feature_encoder = large
        encoder = stable_diffusion_xl(large, bigG)
        size = 1024
        num_inference_steps = 50
        skip = -2
    elif "stable-diffusion-2" in pipeline.config._name_or_path.split("/")[-1].lower():
        model = pipeline.text_encoder
        processor = pipeline.tokenizer

        huge = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        huge.adapter.load_state_dict(load_func(os.path.join(save_path, "H.{0}".format(postfix))))
        empty, pool = huge.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        huge.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))

        feature_encoder = huge
        encoder = stable_diffusion_v2(huge)
        size = 768
        num_inference_steps = 50
        skip = -1
    else: #sd
        model = pipeline.text_encoder
        processor = pipeline.tokenizer

        large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
        empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))

        feature_encoder = large
        encoder = stable_diffusion(large)
        size = 512
        num_inference_steps = 50
        skip = -1
    return encoder, feature_encoder.get_text_feature, size, num_inference_steps, skip

class DrUM(DiffusionPipeline):
    def __init__(self, pipeline, repo_id = "Burf/DrUM"):
        """
        DrUM for various diffusion models

        Args:
            pipeline: Loaded diffusion pipeline
            repo_id: Hugging Face repository containing adapter weights
        """
        self.pipeline = pipeline
        self.repo_id = repo_id

        self.adapter, self.feature_encoder, self.size, self.num_inference_steps, self.skip = self.load_peca(pipeline, repo_id)

    @classmethod
    def from_pretrained(cls, model_id, repo_id = "Burf/DrUM", torch_dtype = torch.bfloat16, device = "cuda"):
        """
        Load DrUM adapter with appropriate pipeline

        Args:
            model_id: Base diffusion model ID
            repo_id: DrUM adapters repository
            torch_dtype: Model precision
            device: Device
        """
        name = model_id.split("/")[-1].lower()

        if "flux" in name:
            pipeline = FluxPipeline.from_pretrained(model_id, torch_dtype = torch_dtype)
        elif "stable-diffusion-3.5" in name:
            pipeline = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype = torch_dtype)
        else:
            pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype = torch_dtype)

        pipeline = pipeline.to(device if torch.cuda.is_available() else "cpu")
        #pipeline.safety_checker = lambda images, clip_input: (images, [False] * len(images))
        return cls(pipeline, repo_id)

    def load_weight(self, pipeline, repo_id = "Burf/DrUM"):
        name = pipeline.config._name_or_path.split("/")[-1].lower()

        weights = []
        if "flux" in name:
            weights = ["L.safetensors", "T5.safetensors"]
        elif "stable-diffusion-3.5" in name:
            weights = ["L.safetensors", "bigG.safetensors", "T5.safetensors"]
        elif "xl-base" in name:
            weights = ["L.safetensors", "bigG.safetensors"]
        elif "stable-diffusion-2" in name:
            weights = ["H.safetensors"]
        else: #sd v1.5
            weights = ["L.safetensors"]

        for weight_file in weights:
            safetensor_path = hf_hub_download(repo_id = repo_id, filename = weight_file)
            weight_path = os.path.dirname(safetensor_path)
        return weight_path

    def load_peca(self, pipeline, repo_id = "Burf/DrUM"):
        adapter, feature_encoder, size, num_inference_steps, skip = peca(pipeline, save_path = self.load_weight(pipeline, repo_id))
        return adapter, feature_encoder, size, num_inference_steps, skip

    def __call__(self, prompt, ref = None, weight = None, alpha = 0.3, skip = None, sampling = False, seed = 42,
                 size = None, num_inference_steps = None, num_images_per_prompt = 1):
        """
        Generate images using DrUM adapter

        Args:
            prompt: Text prompt for generation
            ref: Reference prompts (list of strings)
            weight: Weights for reference prompts (list of floats)
            alpha: Personalization strength (0-1)
            skip: Text condition axis
            sampling: Whether to use coreset sampling for reference selection (default: False)
            seed: Random seed
            size: Image size
            num_inference_steps: Inference steps
            num_images_per_prompt: Number of images to generate

        Returns:
            Personalized images (list of PIL Images)
        """
        size = self.size if size is None else size
        num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
        skip = self.skip if skip is None else skip

        if sampling and isinstance(ref, (tuple, list)) and 1 < len(ref):
            import numpy as np

            with torch.no_grad():
                feature = self.feature_encoder(ref).cpu().float().numpy()

            indices = coreset_sampling(feature, weight = weight, seed = seed)
            if isinstance(weight, (tuple, list)) and len(weight) == len(ref):
                weight = np.array(weight)[indices].tolist() #subset the weights with the same indices before ref is overwritten below
            ref = np.array(ref)[indices].tolist()

        generator = torch.Generator(self.pipeline.device).manual_seed(seed)
        with torch.no_grad():
            cond, pool_cond = self.adapter(prompt, ref, weight = weight, alpha = alpha, skip = skip)

        pipe_kwargs = {
            "num_images_per_prompt": num_images_per_prompt,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "height": size,
            "width": size
        }

        pipe_kwargs["prompt_embeds"] = cond.type(self.pipeline.dtype)
        if pool_cond is not None:
            pipe_kwargs["pooled_prompt_embeds"] = pool_cond.type(self.pipeline.dtype)

        name = self.pipeline.config._name_or_path.split("/")[-1].lower()
        if "flux" in name or "stable-diffusion-3" in name:
            pipe_kwargs["max_sequence_length"] = 256

        images = self.pipeline(**pipe_kwargs).images
        return images
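A minimal end-to-end sketch of the pipeline wrapper above. It assumes the files in this commit are importable as a package (the package name drum is an assumption; pipeline_drum.py re-exports the same class), that a CUDA GPU is available, and that the stabilityai/stable-diffusion-2-1 base checkpoint plus the Burf/DrUM adapter weights can be downloaded; the prompts and weights are illustrative:

#Sketch only: downloads the base pipeline and the matching adapter weight file on first run
import torch

from drum.wrapper import DrUM  #package/import path assumed

drum = DrUM.from_pretrained("stabilityai/stable-diffusion-2-1", repo_id = "Burf/DrUM")

#personalize one prompt toward two weighted reference prompts
images = drum("a portrait of an astronaut",
              ref = ["impressionist brush strokes", "warm golden lighting"],
              weight = [1.0, 0.7], alpha = 0.3)
images[0].save("personalized.png")

For this base model the wrapper picks the H.safetensors adapter, a 768x768 output size, 50 inference steps, and skip = -1; other base models listed in peca select their own weight files and defaults.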