ppbrown commited on
Commit
d9aba4b
·
verified ·
1 Parent(s): 8786d2c

Upload create-t5sdxl-v0.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. create-t5sdxl-v0.py +62 -0
create-t5sdxl-v0.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # This code was used to create t5sdxl-v0-bf16
3
+
4
+ from diffusers import StableDiffusionXLPipeline
5
+ from transformers import T5Tokenizer, T5EncoderModel
6
+ from transformers import (
7
+ CLIPImageProcessor,
8
+ CLIPTextModel,
9
+ CLIPTextModelWithProjection,
10
+ CLIPTokenizer,
11
+ CLIPVisionModelWithProjection,
12
+ )
13
+ from typing import Optional
14
+
15
+ import torch.nn as nn, torch, types
16
+
17
+ T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
18
+ SDXL_DIR = "stabilityai/stable-diffusion-xl-base-1.0"
19
+
20
+
21
+
22
+ class T5SDXLPipeline(StableDiffusionXLPipeline):
23
+ def __init__(
24
+ self,
25
+ vae,
26
+ text_encoder,
27
+ text_encoder_2,
28
+ tokenizer,
29
+ tokenizer_2,
30
+ unet,
31
+ scheduler,
32
+ image_encoder: CLIPVisionModelWithProjection = None,
33
+ feature_extractor: CLIPImageProcessor = None,
34
+ force_zeros_for_empty_prompt: bool = True,
35
+ add_watermarker: Optional[bool] = None,
36
+ ):
37
+ super().__init__(
38
+ vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
39
+ unet, scheduler,
40
+ )
41
+ # ----- build T5 + projection -----
42
+ self.tokenizer = T5Tokenizer.from_pretrained(T5_NAME)
43
+ self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME,
44
+ torch_dtype=self.unet.dtype)
45
+ self.t5_projection = nn.Linear(4096, 2048) # trainable
46
+
47
+ # drop CLIP encoders to save VRAM
48
+ self.text_encoder = self.text_encoder_2 = None
49
+ self.tokenizer_2 = None
50
+
51
+
52
+ # --- usage ---
53
+ pipe = T5SDXLPipeline.from_pretrained(
54
+ SDXL_DIR,
55
+ torch_dtype=torch.bfloat16,
56
+ ).to("cuda")
57
+
58
+ pipe.t5_encoder.to(pipe.device, dtype=pipe.unet.dtype)
59
+ pipe.t5_projection.to(pipe.device, dtype=pipe.unet.dtype)
60
+ print("Saving model")
61
+ pipe.save_pretrained("t5-sdxl-model")
62
+