Burf committed
Commit 541e9bd · 1 Parent(s): f8522ce

Init code and weights

model.py ADDED
@@ -0,0 +1,464 @@
+ import math
+ import numpy as np
+ import torch
+
+ class MultiheadAttention(torch.nn.Module):
+     def __init__(self, d_model, n_head, n_token = 77, dropout = 0.1):
+         super().__init__()
+         self.d_model = d_model
+         self.n_head = n_head
+         self.d_head = d_model // n_head
+         self.n_token = n_token
+
+         self.query = torch.nn.Linear(d_model, d_model)
+         self.key = torch.nn.Linear(d_model, d_model)
+         self.value = torch.nn.Linear(d_model, d_model)
+         self.proj = torch.nn.Linear(d_model, d_model)
+
+         self.div = torch.sqrt(torch.tensor(self.d_head, dtype = self.query.weight.dtype))
+
+         self.softmax = torch.nn.Softmax(dim = -1)
+         self.dropout = torch.nn.Dropout(dropout)
+
+         self._reset_parameters()
+
+     def _reset_parameters(self):
+         torch.nn.init.xavier_uniform_(self.query.weight)
+         torch.nn.init.xavier_uniform_(self.key.weight)
+         torch.nn.init.xavier_uniform_(self.value.weight)
+         torch.nn.init.xavier_uniform_(self.proj.weight)
+
+         torch.nn.init.constant_(self.query.bias, 0.)
+         torch.nn.init.constant_(self.key.bias, 0.)
+         torch.nn.init.constant_(self.value.bias, 0.)
+         torch.nn.init.constant_(self.proj.bias, 0.)
+
+     def forward(self, q, k, v, mask = None, weight = None, alpha = None):
+         b, s = q.shape[:2]
+         b2, s2 = k.shape[:2]
+
+         q = self.query(q) #b, s, f
+         k = self.key(k) #b, s, f
+         v = self.value(v) #b, s, f
+
+         q = q.view(-1, s, self.n_head, self.d_head).transpose(1, 2) #b, h, s, hf
+         k = k.view(-1, s2, self.n_head, self.d_head).transpose(1, 2) #b, h, s, hf
+         v = v.view(-1, s2, self.n_head, self.d_head).transpose(1, 2) #b, h, s, hf
+
+         score = torch.matmul(q, k.transpose(-2, -1)) / self.div #b, h, s, s
+
+         if mask is not None:
+             mask = mask.unsqueeze(1) #b, 1, s
+             if mask.dim() != score.dim():
+                 mask = mask.unsqueeze(2) #b, 1, 1, s
+             score = score * mask
+
+         if weight is not None:
+             weight = weight.unsqueeze(1) #b, 1, s
+             if weight.dim() != score.dim():
+                 weight = weight.unsqueeze(2) #b, 1, 1, s
+         if self.n_token == s2:
+             w = self.softmax(score) #b, h, s, s2
+             if weight is not None:
+                 w = w * weight
+                 w = w / (w.sum(dim = -1, keepdim = True) + 1e-12)
+         else:
+             target, ref = torch.split(score, [self.n_token, s2 - self.n_token], dim = -1)
+             target = self.softmax(target)
+             if alpha is None:
+                 alpha = 0.5
+             if weight is not None:
+                 ws = weight.shape[-1]
+                 target_weight, ref_weight = torch.split(weight, [self.n_token, ws - self.n_token], dim = -1)
+                 ref = ref.view(b2, self.n_head, s, ws - self.n_token, self.n_token)
+                 ref = self.softmax(ref)
+                 ref = ref * ref_weight.unsqueeze(-1)
+                 ref = ref.view(b2, self.n_head, s, s2 - self.n_token)
+                 ref = alpha * (ref / (ref.sum(dim = -1, keepdim = True) + 1e-12))
+                 target = target * (1 - alpha) * target_weight
+             w = torch.cat([target, ref], dim = -1)
+             w = w / (w.sum(dim = -1, keepdim = True) + 1e-12)
+         w = self.dropout(w)
+
+         out = torch.matmul(w, v) #b, h, s, hf
+         out = out.transpose(1, 2).contiguous().view(b, s, self.d_model) #b, s, d
+         out = self.proj(out)
+         return out
+
+ class QuickGELU(torch.nn.Module):
+     def forward(self, x):
+         return x * torch.sigmoid(1.702 * x)
+
+ class TransformerBlock(torch.nn.Module):
+     def __init__(self, emb_dim, n_head, ff_dim, n_token = 77, activation = "quick_gelu", dropout = 0.1):
+         super().__init__()
+         self.attn = MultiheadAttention(emb_dim, n_head, n_token = n_token, dropout = dropout)
+         if activation.lower() == "gelu" or activation is None:
+             self.act = torch.nn.GELU()
+         elif activation.lower() == "relu":
+             self.act = torch.nn.ReLU()
+         elif activation.lower() == "quick_gelu":
+             self.act = QuickGELU()
+         else:
+             self.act = activation
+         self.ff = torch.nn.Sequential(
+             torch.nn.Linear(emb_dim, ff_dim),
+             self.act,
+             torch.nn.Linear(ff_dim, emb_dim),
+         )
+         self.norm1 = torch.nn.LayerNorm(emb_dim)
+         self.norm2 = torch.nn.LayerNorm(emb_dim)
+         self.dropout1 = torch.nn.Dropout(dropout)
+         self.dropout2 = torch.nn.Dropout(dropout)
+
+         self._reset_parameters()
+
+     def _reset_parameters(self):
+         torch.nn.init.xavier_uniform_(self.ff[0].weight)
+         torch.nn.init.xavier_uniform_(self.ff[2].weight)
+
+         torch.nn.init.constant_(self.ff[0].bias, 0.)
+         torch.nn.init.constant_(self.ff[2].bias, 0.)
+
+     def forward(self, x, context = None, mask = None, weight = None, alpha = None):
+         context = context if context is not None else x
+         out = self.attn(x, context, context, mask = mask, weight = weight, alpha = alpha)
+         out = x + self.dropout1(out)
+         out = self.norm1(out)
+
+         ff_out = self.ff(out)
+         out = out + self.dropout2(ff_out)
+         out = self.norm2(out)
+         return out
+
+ class PersonalizedAdapter(torch.nn.Module):
+     def __init__(self, emb_dim, n_head, ff_dim, n_layer = 4, n_token = 77, proj = False, extra_proj = False, pos = True, cls_pos = False, cls_token = True, encode_ratio = None, activation = "quick_gelu", dropout = 0.1):
+         super().__init__()
+         self.n_layer = n_layer
+         self.n_token = n_token
+         self.cls_pos = cls_pos
+         self.cls_token = cls_token
+         self.encode_ratio = encode_ratio
+
+         self.pre_proj = self.post_proj = None
+         if encode_ratio and encode_ratio != 1:
+             self.pre_proj = torch.nn.Linear(emb_dim, int(emb_dim // encode_ratio))
+             self.post_proj = torch.nn.Linear(int(emb_dim // encode_ratio), emb_dim)
+             emb_dim = int(emb_dim // encode_ratio)
+             n_head = int(n_head // encode_ratio)
+
+         if activation.lower() == "gelu" or activation is None:
+             self.act = torch.nn.GELU()
+         elif activation.lower() == "relu":
+             self.act = torch.nn.ReLU()
+         elif activation.lower() == "quick_gelu":
+             self.act = QuickGELU()
+         else:
+             self.act = activation
+         self.base_query = torch.nn.Parameter(torch.empty(1, n_token + int(cls_token), emb_dim))
+         self.pos = torch.nn.Parameter(torch.empty(1, n_token + int(cls_pos and cls_token), emb_dim)) if pos else None
+         self.init_query = None
+
+         self.proj = None
+         if proj:
+             self.proj = torch.nn.Sequential(
+                 torch.nn.Linear(emb_dim, ff_dim),
+                 self.act,
+                 torch.nn.Linear(ff_dim, emb_dim),
+             )
+
+         self.extra_proj = None
+         self.tf = torch.nn.ModuleList([TransformerBlock(emb_dim, n_head, ff_dim, n_token = n_token, activation = activation, dropout = dropout) for _ in range(n_layer)])
+         if extra_proj:
+             self.extra_proj = torch.nn.ModuleList([torch.nn.Linear(emb_dim, emb_dim) for _ in range(n_layer)])
+
+         self._reset_parameters()
+
+     def _reset_parameters(self):
+         torch.nn.init.normal_(self.base_query, std = 0.02)
+         if self.pos is not None:
+             torch.nn.init.normal_(self.pos, std = 0.01)
+
+         for proj in [self.pre_proj, self.post_proj]:
+             if proj is not None:
+                 torch.nn.init.xavier_uniform_(proj.weight)
+                 torch.nn.init.constant_(proj.bias, 0.)
+         for proj in [self.proj]:
+             if proj is not None:
+                 torch.nn.init.xavier_uniform_(proj[0].weight)
+                 torch.nn.init.xavier_uniform_(proj[2].weight)
+
+                 torch.nn.init.constant_(proj[0].bias, 0.)
+                 torch.nn.init.constant_(proj[2].bias, 0.)
+         if self.extra_proj is not None:
+             for l in self.extra_proj:
+                 torch.nn.init.xavier_uniform_(l.weight)
+                 torch.nn.init.constant_(l.bias, 0.)
+
+     def set_base_query(self, x):
+         if not torch.is_tensor(x):
+             x = torch.tensor(x, dtype=self.base_query.dtype).to(self.base_query.device)
+         if x.dim() == 2:
+             x = x.unsqueeze(0)
+         self.init_query = x
+
+     def normal_forward(self, x, context, mask = None, weight = None, alpha = None):
+         out = x
+         for i in range(self.n_layer):
+             if self.extra_proj is not None:
+                 _context = self.extra_proj[i](self.act(context))
+             else:
+                 _context = context
+             out = self.tf[i](out, _context, mask = mask, weight = weight, alpha = alpha) #n, b, f
+         if self.cls_token:
+             return out[:, :-1], out[:, -1]
+         else:
+             return out, None
+
+     def forward(self, context, mask = None, weight = None, alpha = None, base_query = None):
+         dtype = self.base_query.dtype
+         if base_query is not None:
+             x = base_query
+         else:
+             x = self.base_query if self.init_query is None else self.init_query
+         x = x.type(dtype)
+         if context is not None:
+             context = context.type(dtype)
+         if weight is not None:
+             weight = weight.type(dtype)
+         if self.encode_ratio is not None and x.shape[-1] != self.base_query.shape[-1]:
+             x = self.pre_proj(x)
+         if self.n_token < x.shape[1]:
+             x, cls = x[:, :self.n_token], x[:, self.n_token:]
+         else:
+             cls = self.base_query[:, self.n_token:] if self.cls_token else None
+         if self.pos is not None:
+             if self.cls_pos and self.cls_token:
+                 x = x + self.pos[:, :self.n_token]
+                 if cls is not None:
+                     cls = cls + self.pos[:, self.n_token:]
+             else:
+                 x = x + self.pos
+         if self.cls_token:
+             x = torch.cat([x, cls], dim = 1)
+         x = x.repeat_interleave(context.shape[0], dim = 0)
+         if self.encode_ratio is not None:
+             if context is not None:
+                 context = self.pre_proj(context)
+         if self.proj is not None:
+             context = self.proj(context)
+         out = self.normal_forward(x, context, mask = mask, weight = weight, alpha = alpha)
+         if self.encode_ratio is not None:
+             out = (self.post_proj(out[0]), self.post_proj(out[1]) if out[1] is not None else out[1])
+         return out
+
+ class DrUM:
+     def __init__(self, model, processor, n_layer = 8, proj = False, extra_proj = False, mlp_ratio = 4, pos = True, cls_pos = False, cls_token = True, encode_ratio = None, max_token_size = 256, activation = "quick_gelu", dropout = 0.1):
+         config = model.config.text_config if hasattr(model.config, "text_config") else model.config
+         if hasattr(config, "model_type") and config.model_type == "t5":
+             self.d_model = config.d_model
+             self.n_head = config.num_heads
+             self.n_token = min(processor.model_max_length, max_token_size)
+             self.clip = False
+             self.cls_token = False
+         else:
+             self.d_model = config.hidden_size
+             self.n_head = config.num_attention_heads
+             self.n_token = config.max_position_embeddings
+             self.clip = True
+             self.cls_token = cls_token
+         self.n_layer = n_layer
+         self.proj = proj
+         self.extra_proj = extra_proj
+         self.mlp_ratio = mlp_ratio
+         self.pos = pos
+         self.cls_pos = cls_pos
+         self.encode_ratio = encode_ratio
+         self.activation = activation
+         self.dropout = dropout
+
+         self.model = model
+         self.processor = processor
+         self.adapter = PersonalizedAdapter(self.d_model, self.n_head, self.d_model // mlp_ratio, n_layer, self.n_token, proj = proj, extra_proj = extra_proj, pos = pos, cls_pos = cls_pos, cls_token = self.cls_token, encode_ratio = encode_ratio, activation = activation, dropout = dropout).to(model.device)
+
+         self.train()
+         self.to(model.device)
+
+     def preprocess(self, text = None, image = None, return_tensors = "pt", padding = "max_length", truncation = True, **kwargs):
+         feed = {"text":([text] if np.ndim(text) == 0 else list(text)) if text is not None else None,
+                 "return_tensors":return_tensors,
+                 "max_length":self.n_token,
+                 "padding":padding,
+                 "truncation":truncation,
+                 **kwargs}
+         if not self.clip:
+             feed["add_special_tokens"] = True
+         if image is not None:
+             feed["images"] = image
+         return self.processor(**feed)
+
+     def pool_text_hidden_state(self, hidden_state, x, padding = "max_length", truncation = True, **kwargs):
+         if not self.clip:
+             raise TypeError("T5 encoder does not support this function (pool_text_hidden_state).")
+         if not hasattr(x, "items"):
+             x = self.preprocess(text = x, padding = padding, truncation = truncation, **kwargs)
+         if self.model.text_model.eos_token_id == 2:
+             out = hidden_state[torch.arange(hidden_state.shape[0], device = hidden_state.device),
+                                x["input_ids"].to(dtype = torch.int, device = hidden_state.device).argmax(dim = -1),]
+         else:
+             out = hidden_state[torch.arange(hidden_state.shape[0], device = hidden_state.device),
+                                (x["input_ids"].to(dtype = torch.int, device = hidden_state.device) == self.model.text_model.eos_token_id).int().argmax(dim = -1),]
+         return out
+
+     def normalize_text_hidden_state(self, hidden_state):
+         out = self.model.text_model.final_layer_norm(hidden_state.type(self.model.dtype)) if self.clip and hasattr(self.model.text_model, "final_layer_norm") else hidden_state
+         return out
+
+     def projection_text_hidden_state(self, hidden_state):
+         out = self.model.text_projection(hidden_state.type(self.model.dtype)) if self.clip and hasattr(self.model, "text_projection") else hidden_state
+         return out
+
+     def encode_prompt(self, x, pooling = True, skip = -1, skip_pool = None, padding = "max_length", truncation = True, use_attn_mask = False, normalize = True, normalize_pool = True, **kwargs):
+         if not hasattr(x, "items"):
+             x = self.preprocess(text = x, padding = padding, truncation = truncation, **kwargs)
+         input_ids = x["input_ids"].to(self.device)
+         attention_mask = x["attention_mask"].to(self.device) if use_attn_mask else None
+         with torch.no_grad():
+             if self.clip:
+                 hidden_state = self.model.text_model(output_hidden_states = True, input_ids = input_ids, attention_mask = attention_mask)["hidden_states"]
+                 pool, hidden_state = hidden_state[skip_pool if skip_pool is not None else skip], hidden_state[skip]
+                 hidden_state = self.normalize_text_hidden_state(hidden_state) if normalize else hidden_state
+             else:
+                 hidden_state = self.model(input_ids = input_ids, attention_mask = attention_mask)[0]
+                 pool = None
+         if pooling:
+             if self.clip:
+                 with torch.no_grad():
+                     pool = self.pool_text_hidden_state(self.normalize_text_hidden_state(pool) if normalize_pool else pool, x, **kwargs)
+             return (hidden_state, pool)
+         return hidden_state
+
+     def get_text_feature(self, x, ref_x = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, padding = "max_length", truncation = True, use_attn_mask = False, **kwargs):
+         if not self.clip:
+             raise TypeError("T5 encoder does not support this function (get_text_feature).")
+         with torch.no_grad():
+             pool_hidden_state = self(x, ref_x, weight = weight, alpha = alpha, pooling = True, skip_pool = skip, batch_size = batch_size, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize_pool = True, **kwargs)[1]
+             result = self.projection_text_hidden_state(pool_hidden_state)
+         return result
+
+     def get_image_feature(self, x, return_tensors = "pt", **kwargs):
+         if not self.clip:
+             raise TypeError("T5 encoder does not support this function (get_image_feature).")
+         if hasattr(x, "items"):
+             x = x["pixel_values"]
+         elif not torch.is_tensor(x):
+             x = self.preprocess(image = x, return_tensors = return_tensors, **kwargs)["pixel_values"]
+         with torch.no_grad():
+             result = self.model.get_image_features(pixel_values = x.to(self.device))
+         return result
+
+     def encode_context(self, ref_x, pooling = False, skip = -1, skip_pool = None, batch_size = 64, padding = "max_length", truncation = True, use_attn_mask = False, normalize = False, normalize_pool = False, **kwargs):
+         if not hasattr(ref_x, "items"):
+             if np.ndim(ref_x) == 0:
+                 ref_x = [[ref_x]]
+             elif np.ndim(ref_x) == 1:
+                 ref_x = [ref_x]
+             b, ref_size = len(ref_x), len(ref_x[0])
+             ref_x = np.reshape(ref_x, [b * ref_size])
+             ref_x = self.preprocess(text = list(ref_x), padding = padding, truncation = truncation, **kwargs)
+             ref_x = {k:v for k, v in ref_x.items() if k in (["input_ids", "attention_mask"] if use_attn_mask else ["input_ids"])}
+         else:
+             b, ref_size = ref_x["input_ids"].shape[:2]
+             ref_x = {k:v.view(b * ref_size, -1) for k, v in ref_x.items() if k in (["input_ids", "attention_mask"] if use_attn_mask else ["input_ids"])}
+         hidden_state, pool_hidden_state = [], []
+         batch_indices = [(i * batch_size, min((b * ref_size), (i + 1) * batch_size)) for i in range(int(np.ceil((b * ref_size) / batch_size)))]
+         for start, end in batch_indices:
+             h, p = self.encode_prompt({k:v[start:end] for k, v in ref_x.items()}, pooling = True, skip = skip, skip_pool = skip_pool, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = normalize, normalize_pool = normalize_pool, **kwargs)
+             hidden_state.append(h)
+             if p is not None:
+                 pool_hidden_state.append(p)
+         hidden_state = torch.cat(hidden_state, dim = 0) if 1 < len(hidden_state) else hidden_state[0]
+         pool_hidden_state = torch.cat(pool_hidden_state, dim = 0) if 1 < len(pool_hidden_state) else (pool_hidden_state[0] if len(pool_hidden_state) == 1 else None)
+         with torch.no_grad():
+             hidden_state = hidden_state.view(b, ref_size * hidden_state.shape[1], -1)
+             if pooling:
+                 if self.clip:
+                     pool_hidden_state = pool_hidden_state.view(b, ref_size, -1)
+                 hidden_state = (hidden_state, pool_hidden_state)
+         return hidden_state
+
+     def __call__(self, x, ref_x = None, weight = None, alpha = 0.3, pooling = True, skip = -1, skip_pool = None, batch_size = 64, padding = "max_length", truncation = True, use_attn_mask = False, normalize = True, normalize_pool = True, training = False, **kwargs):
+         if ref_x is not None or training:
+             if training:
+                 context = weight = None
+             else:
+                 _context, _context_pool = self.encode_context(ref_x, pooling = True, skip = skip, skip_pool = None, batch_size = batch_size, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = False, normalize_pool = False, **kwargs)
+                 if weight is not None:
+                     if not torch.is_tensor(weight):
+                         weight = torch.tensor(weight)
+                     if weight.dim() == 0:
+                         weight = weight.unsqueeze(0).unsqueeze(0)
+                     elif weight.dim() == 1:
+                         weight = weight.unsqueeze(0)
+                     weight = weight.to(self.device)
+                 else:
+                     weight = torch.ones((1, _context.shape[1] // self.n_token), dtype = torch.float32, device = _context.device)
+                 context = _context
+                 del _context, _context_pool
+             result = self.encode_personalized_prompt(x, context, weight = weight, alpha = alpha, pooling = pooling, skip = skip, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = normalize, normalize_pool = normalize_pool, **kwargs)
+             return result
+         else:
+             return self.encode_prompt(x, pooling = pooling, skip = skip, skip_pool = skip_pool, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = normalize, normalize_pool = normalize_pool, **kwargs)
+
+     def encode_personalized_prompt(self, x, context = None, weight = None, alpha = 0.3, pooling = True, skip = -1, padding = "max_length", truncation = True, use_attn_mask = False, normalize = True, normalize_pool = True, **kwargs):
+         if not torch.is_tensor(x):
+             if not hasattr(x, "items"):
+                 x = self.preprocess(text = x, padding = padding, truncation = truncation, **kwargs)
+             x = self.encode_prompt(x, pooling = False, skip = skip, skip_pool = None, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = False, normalize_pool = False, **kwargs)
+         if context is None:
+             context = x
+         else:
+             batch_size, n_token = x.shape[:2]
+             if context.shape[0] == 1 and batch_size != 1:
+                 context = context.repeat_interleave(batch_size, dim = 0)
+             if weight is not None and weight.shape[0] == 1:
+                 weight = weight.repeat_interleave(batch_size, dim = 0)
+             context_size = context.shape[1]
+             context = torch.cat([x, context], dim = 1)
+             if weight is not None:
+                 extra_weight = torch.ones((batch_size, n_token), dtype = torch.float32, device = weight.device)
+                 weight = torch.cat([extra_weight, weight], dim = 1)
+         hidden_state, pool = self.adapter(context, weight = weight, alpha = alpha)
+         hidden_state = self.normalize_text_hidden_state(hidden_state) if normalize else hidden_state
+         if pooling:
+             pool = self.normalize_text_hidden_state(pool) if normalize_pool else pool
+             return (hidden_state, pool)
+         return hidden_state
+
+     def to(self, device):
+         self.model.to(device)
+         self.adapter.to(device)
+         self.device = device
+         return self
+
+     def eval(self):
+         self.model.eval()
+         if self.clip and hasattr(self.model, "text_projection"):
+             self.model.text_model.final_layer_norm.requires_grad_(False)
+             self.model.text_projection.requires_grad_(False)
+         self.adapter.eval()
+         return self
+
+     def train(self):
+         self.model.eval()
+         if self.clip and hasattr(self.model, "text_projection"):
+             self.model.text_model.final_layer_norm.requires_grad_(False)
+             self.model.text_projection.requires_grad_(False)
+         self.adapter.train()
+         return self
+
+     def parameters(self):
+         return list(self.adapter.parameters())
+
+     def named_parameters(self):
+         return list(self.adapter.named_parameters())
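A minimal usage sketch for the DrUM backbone defined in model.py above. The transformers checkpoint name and the import path below are assumptions (not part of this commit), and without loading the shipped adapter weights the PersonalizedAdapter is randomly initialized, so the personalized outputs are only illustrative.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

from model import DrUM  # assumes model.py is importable on your path

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
drum = DrUM(text_model, tokenizer, n_layer = 8).eval()

# Plain prompt encoding: returns the selected hidden state and the pooled token.
hidden_state, pool = drum("a photo of a cat", pooling = True)

# Personalized encoding: the prompt is blended with weighted reference prompts
# inside the adapter's attention; alpha controls the personalization strength.
ref = ["a watercolor painting", "soft pastel colors"]
hidden_state, pool = drum("a photo of a cat", ref, weight = [1.0, 0.5], alpha = 0.3)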
pipeline_drum.py ADDED
@@ -0,0 +1 @@
+ from .wrapper import DrUM
sampling.py ADDED
@@ -0,0 +1,34 @@
+ import numpy as np
+ import torch
+
+ def clip_score(feature, ref_feature, logit_scale = 100.0, weight = 1, reduce = True):
+     ref_feature = np.expand_dims(ref_feature, axis = 0) if np.ndim(ref_feature) == 2 else ref_feature
+     batch_size, ref_size = np.shape(ref_feature)[:2]
+     feature = feature / np.linalg.norm(feature, axis = -1, keepdims=True)
+     ref_feature = ref_feature / np.linalg.norm(ref_feature, axis = -1, keepdims=True)
+     sim = logit_scale * np.einsum("bf,btf->bt", feature, ref_feature)
+     sim = sim * (np.expand_dims(weight, axis = 0) if np.ndim(weight) == 1 else weight)
+     return sim.mean(axis = 1) if reduce else (sim[..., 0] if ref_size == 1 else sim)
+
+ def coreset_sampling(data, n_sample = 0.1, weight = 1, n_approximate = 10, logit_scale = 100, seed = 42):
+     data = np.array(data) if not isinstance(data, np.ndarray) else data
+     n_sample = round(len(data) * n_sample) if isinstance(n_sample, float) or (isinstance(n_sample, int) and n_sample < 1) else n_sample
+     n_sample = max(min(n_sample, len(data)), 1 if len(data) != 0 else 0)
+     weight = 1 if weight is None else weight
+     weight = np.transpose(weight) if np.ndim(weight) == 2 else (np.expand_dims(weight, axis = -1) if np.ndim(weight) == 1 else weight)
+
+     random = ((np.random.RandomState(seed) if isinstance(seed, int) else seed) if seed is not None else np.random)
+     if n_sample == len(data):
+         indices = np.arange(n_sample)
+     else:
+         indices = []
+         approx_data = data[random.choice(len(data), min(round(len(data) * n_approximate) if isinstance(n_approximate, float) else n_approximate, len(data)), replace = False)]
+         dist = clip_score(data, approx_data, weight = weight, logit_scale = logit_scale, reduce = False)
+         dist = np.mean(dist, axis = 1) #keep dist 1-D so the per-sample minimum below stays elementwise
+         for i in range(n_sample):
+             sample_index = np.argmax(dist)
+             indices.append(sample_index)
+             sample_dist = clip_score(data, data[[sample_index]], weight = weight, logit_scale = logit_scale, reduce = False)
+             dist = np.minimum(dist, sample_dist)
+             dist[sample_index] = -np.inf
+     return indices
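A small sketch of how coreset_sampling above can be used on its own to shrink a large reference set before personalization; the random features stand in for CLIP text features and are purely illustrative.

import numpy as np

from sampling import coreset_sampling  # assumes sampling.py is importable on your path

features = np.random.randn(100, 768).astype(np.float32)  # stand-in for CLIP text features
indices = coreset_sampling(features, n_sample = 10, seed = 42)  # greedy, similarity-based selection
subset = features[indices]  # 10 selected reference features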
weight/H.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3eb561647781fd66c1f84d2c8d5f15f71e5994e8d11c1e38ff7463c6fe84e554
+ size 189456448
weight/L.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d1e3ffcbe5e21f3e9796c8cdc0450ef90d7a927de242a6e3fe69531071705953
+ size 106706168
weight/T5.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d6d8413819cc4a2d388b7a0222eff8c161c7bcf459536db2ad10a403756643b
+ size 286706200
weight/bigG.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1c58a6c1a558599e090f7fe0eaada1e0684dafd472c0a18fbd4067450d54a606
+ size 295799424
wrapper.py ADDED
@@ -0,0 +1,304 @@
+ import os
+
+ import torch
+
+ from diffusers import DiffusionPipeline, FluxPipeline, StableDiffusion3Pipeline
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+
+ from .model import DrUM as backbone
+ from .sampling import coreset_sampling
+
+ def stable_diffusion(large):
+     """
+     openai/clip-vit-large-patch14, CLIPTextModel, skip -1
+     """
+     def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, **kwargs):
+         return large(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, **kwargs), None
+     return inference
+
+ def stable_diffusion_v2(huge):
+     """
+     openai/clip-vit-huge-patch14, CLIPTextModel, skip -1
+     """
+     def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, **kwargs):
+         return huge(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, **kwargs), None
+     return inference
+
+ def stable_diffusion_xl(large, bigG):
+     """
+     openai/clip-vit-large-patch14, CLIPTextModel, skip -2, unnorm
+     laion/CLIP-ViT-bigG-14-laion2B-39B-b160k, CLIPTextModelWithProjection, skip -2, unnorm, pooling + proj
+     """
+     def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -2, batch_size = 64, **kwargs):
+         hidden_state = large(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
+         if skip == -1:
+             hidden_state2, pool_hidden_state = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)
+         else:
+             hidden_state2 = bigG(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
+             pool_hidden_state = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
+         hidden_state = torch.cat([hidden_state, hidden_state2], dim = -1)
+         pool_hidden_state = bigG.projection_text_hidden_state(pool_hidden_state)
+         return hidden_state.type(pool_hidden_state.dtype), pool_hidden_state
+     return inference
+
+ def stable_diffusion_v3(large, bigG, t5):
+     """
+     openai/clip-vit-large-patch14, CLIPTextModelWithProjection, skip -2, unnorm, pooling + proj
+     laion/CLIP-ViT-bigG-14-laion2B-39B-b160k, CLIPTextModelWithProjection, skip -2, unnorm, pooling + proj
+     t5-v1_1-xxl, T5EncoderModel
+     """
+     def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -2, batch_size = 64, **kwargs):
+         if skip == -1:
+             hidden_state, pool_hidden_state = large(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)
+             hidden_state2, pool_hidden_state2 = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)
+         else:
+             hidden_state = large(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
+             hidden_state2 = bigG(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
+             pool_hidden_state = large(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
+             pool_hidden_state2 = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
+         hidden_state3 = t5(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, batch_size = batch_size, normalize = False, **kwargs)
+         hidden_state = torch.cat([hidden_state, hidden_state2], dim = -1)
+         pool_hidden_state = large.projection_text_hidden_state(pool_hidden_state)
+         pool_hidden_state2 = bigG.projection_text_hidden_state(pool_hidden_state2)
+         hidden_state = torch.nn.functional.pad(hidden_state, (0, hidden_state3.shape[-1] - hidden_state.shape[-1]))
+         hidden_state = torch.cat([hidden_state, hidden_state3], dim = -2)
+         pool_hidden_state = torch.cat([pool_hidden_state, pool_hidden_state2], dim = -1)
+         return hidden_state.type(pool_hidden_state.dtype), pool_hidden_state
+     return inference
+
+ def flux(large, t5):
+     """
+     openai/clip-vit-large-patch14, CLIPTextModel, pooling
+     t5-v1_1-xxl, T5EncoderModel
+     """
+     def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = None, batch_size = 64, **kwargs):
+         hidden_state = t5(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, batch_size = batch_size, normalize = False, **kwargs)
+         pool_hidden_state = large(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
+         return hidden_state.type(pool_hidden_state.dtype), pool_hidden_state
+     return inference
+
+ def peca(pipeline, save_path = "./weight", n_layer = 10):
+     if os.path.exists(os.path.join(save_path, "L.pth")) or os.path.exists(os.path.join(save_path, "H.pth")):
+         load_func = torch.load
+         postfix = "pth"
+     else:
+         from safetensors.torch import load_file as load_func
+         postfix = "safetensors"
+
+     if "flux" in pipeline.config._name_or_path.split("/")[-1].lower():
+         model = pipeline.text_encoder
+         processor = pipeline.tokenizer
+         model2 = pipeline.text_encoder_2
+         processor2 = pipeline.tokenizer_2
+
+         large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
+         large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
+         t5 = backbone(model2, processor2, n_layer = n_layer, encode_ratio = 4, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
+         t5.adapter.load_state_dict(load_func(os.path.join(save_path, "T5.{0}".format(postfix))))
+         empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
+         large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
+         empty, pool = t5.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
+         t5.adapter.set_base_query(empty)
+
+         feature_encoder = large
+         encoder = flux(large, t5)
+         size = 1024
+         num_inference_steps = 28
+         skip = -2
+     elif "stable-diffusion-3.5" in pipeline.config._name_or_path.split("/")[-1].lower(): #sd v3
+         model = pipeline.text_encoder
+         processor = pipeline.tokenizer
+         model2 = pipeline.text_encoder_2
+         processor2 = pipeline.tokenizer_2
+         model3 = pipeline.text_encoder_3
+         processor3 = pipeline.tokenizer_3
+
+         large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
+         large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
+         bigG = backbone(model2, processor2, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
+         bigG.adapter.load_state_dict(load_func(os.path.join(save_path, "bigG.{0}".format(postfix))))
+         t5 = backbone(model3, processor3, n_layer = n_layer, encode_ratio = 4, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
+         t5.adapter.load_state_dict(load_func(os.path.join(save_path, "T5.{0}".format(postfix))))
+         empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
+         large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
+         empty, pool = bigG.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
+         bigG.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
+         empty, pool = t5.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
+         t5.adapter.set_base_query(empty)
+
+         feature_encoder = large
+         encoder = stable_diffusion_v3(large, bigG, t5)
+         size = 1024
+         num_inference_steps = 28
+         skip = -2
+     elif "xl-base" in pipeline.config._name_or_path.split("/")[-1].lower(): #sd xl
+         model = pipeline.text_encoder
+         processor = pipeline.tokenizer
+         model2 = pipeline.text_encoder_2
+         processor2 = pipeline.tokenizer_2
+
+         large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
+         large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
+         bigG = backbone(model2, processor2, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
+         bigG.adapter.load_state_dict(load_func(os.path.join(save_path, "bigG.{0}".format(postfix))))
+         empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
+         large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
+         empty, pool = bigG.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
+         bigG.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
+
+         feature_encoder = large
+         encoder = stable_diffusion_xl(large, bigG)
+         size = 1024
+         num_inference_steps = 50
+         skip = -2
+     elif "stable-diffusion-2" in pipeline.config._name_or_path.split("/")[-1].lower():
+         model = pipeline.text_encoder
+         processor = pipeline.tokenizer
+
+         huge = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
+         huge.adapter.load_state_dict(load_func(os.path.join(save_path, "H.{0}".format(postfix))))
+         empty, pool = huge.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
+         huge.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
+
+         feature_encoder = huge
+         encoder = stable_diffusion_v2(huge)
+         size = 768
+         num_inference_steps = 50
+         skip = -1
+     else: #sd
+         model = pipeline.text_encoder
+         processor = pipeline.tokenizer
+
+         large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
+         large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
+         empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
+         large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
+
+         feature_encoder = large
+         encoder = stable_diffusion(large)
+         size = 512
+         num_inference_steps = 50
+         skip = -1
+     return encoder, feature_encoder.get_text_feature, size, num_inference_steps, skip
+
+ class DrUM(DiffusionPipeline):
+     def __init__(self, pipeline, repo_id = "Burf/DrUM"):
+         """
+         DrUM for various diffusion models
+
+         Args:
+             pipeline: Loaded diffusion pipeline
+             repo_id: Hugging Face repository containing adapter weights
+         """
+         self.pipeline = pipeline
+         self.repo_id = repo_id
+
+         self.adapter, self.feature_encoder, self.size, self.num_inference_steps, self.skip = self.load_peca(pipeline, repo_id)
+
+     @classmethod
+     def from_pretrained(cls, model_id, repo_id = "Burf/DrUM", torch_dtype = torch.bfloat16, device = "cuda"):
+         """
+         Load DrUM adapter with appropriate pipeline
+
+         Args:
+             model_id: Base diffusion model ID
+             repo_id: DrUM adapters repository
+             torch_dtype: Model precision
+             device: Device
+         """
+         name = model_id.split("/")[-1].lower()
+
+         if "flux" in name:
+             pipeline = FluxPipeline.from_pretrained(model_id, torch_dtype = torch_dtype)
+         elif "stable-diffusion-3.5" in name:
+             pipeline = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype = torch_dtype)
+         else:
+             pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype = torch_dtype)
+
+         pipeline = pipeline.to(device if torch.cuda.is_available() else "cpu")
+         #pipeline.safety_checker = lambda images, clip_input: (images, [False] * len(images))
+         return cls(pipeline, repo_id)
+
+     def load_weight(self, pipeline, repo_id = "Burf/DrUM"):
+         name = pipeline.config._name_or_path.split("/")[-1].lower()
+
+         weights = []
+         if "flux" in name:
+             weights = ["L.safetensors", "T5.safetensors"]
+         elif "stable-diffusion-3.5" in name:
+             weights = ["L.safetensors", "bigG.safetensors", "T5.safetensors"]
+         elif "xl-base" in name:
+             weights = ["L.safetensors", "bigG.safetensors"]
+         elif "stable-diffusion-2" in name:
+             weights = ["H.safetensors"]
+         else: # SD v1.5
+             weights = ["L.safetensors"]
+
+         for weight_file in weights:
+             safetensor_path = hf_hub_download(repo_id = repo_id, filename = weight_file)
+             weight_path = os.path.dirname(safetensor_path)
+         return weight_path
+
+     def load_peca(self, pipeline, repo_id = "Burf/DrUM"):
+         adapter, feature_encoder, size, num_inference_steps, skip = peca(pipeline, save_path = self.load_weight(pipeline, repo_id))
+         return adapter, feature_encoder, size, num_inference_steps, skip
+
+     def __call__(self, prompt, ref = None, weight = None, alpha = 0.3, skip = None, sampling = False, seed = 42,
+                  size = None, num_inference_steps = None, num_images_per_prompt = 1):
+         """
+         Generate images using DrUM adapter
+
+         Args:
+             prompt: Text prompt for generation
+             ref: Reference prompts (list of strings)
+             weight: Weights for reference prompts (list of floats)
+             alpha: Personalization strength (0-1)
+             skip: Text-encoder hidden-state index used for conditioning (e.g. -1 for the last layer, -2 for the penultimate)
+             sampling: Whether to use coreset sampling for reference selection (default: False)
+             seed: Random seed
+             size: Image size
+             num_inference_steps: Inference steps
+             num_images_per_prompt: Number of images to generate
+
+         Returns:
+             Personalized images (list of PIL Images)
+         """
+         size = self.size if size is None else size
+         num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
+         skip = self.skip if skip is None else skip
+
+         if sampling and isinstance(ref, (tuple, list)) and 1 < len(ref):
+             import numpy as np
+
+             with torch.no_grad():
+                 feature = self.feature_encoder(ref).cpu().float().numpy()
+
+             indices = coreset_sampling(feature, weight = weight, seed = seed)
+             if isinstance(weight, (tuple, list)) and len(weight) == len(ref):
+                 weight = np.array(weight)[indices].tolist()
+
+             ref = np.array(ref)[indices].tolist()
+
+         generator = torch.Generator(self.pipeline.device).manual_seed(seed)
+         with torch.no_grad():
+             cond, pool_cond = self.adapter(prompt, ref, weight = weight, alpha = alpha, skip = skip)
+
+         pipe_kwargs = {
+             "num_images_per_prompt": num_images_per_prompt,
+             "num_inference_steps": num_inference_steps,
+             "generator": generator,
+             "height": size,
+             "width": size
+         }
+
+         pipe_kwargs["prompt_embeds"] = cond.type(self.pipeline.dtype)
+         if pool_cond is not None:
+             pipe_kwargs["pooled_prompt_embeds"] = pool_cond.type(self.pipeline.dtype)
+
+         name = self.pipeline.config._name_or_path.split("/")[-1].lower()
+         if "flux" in name or "stable-diffusion-3" in name:
+             pipe_kwargs["max_sequence_length"] = 256
+
+         images = self.pipeline(**pipe_kwargs).images
+         return images
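For reference, an end-to-end sketch of the DrUM wrapper above. The base model id and the package import path are assumptions (the modules use relative imports, so they must live inside a package); the adapter weights themselves are fetched from the Burf/DrUM repository by load_weight.

import torch

from drum.wrapper import DrUM  # hypothetical package containing model.py, sampling.py and wrapper.py

drum = DrUM.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",  # assumed SDXL base model
                            repo_id = "Burf/DrUM",
                            torch_dtype = torch.bfloat16,
                            device = "cuda")

images = drum("a cozy cabin in the woods",
              ref = ["warm golden hour lighting", "watercolor illustration"],  # reference prompts
              weight = [1.0, 0.7],                                             # per-reference weights
              alpha = 0.3)                                                     # personalization strength
images[0].save("personalized.png")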