File size: 1,785 Bytes
d9aba4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

# This code was used to create t5sdxl-v0-bf16

from diffusers import StableDiffusionXLPipeline
from transformers import T5Tokenizer, T5EncoderModel
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)
from typing import Optional

import torch.nn as nn, torch, types

# Checkpoint id for both the T5 tokenizer and the T5 encoder; both are loaded
# via from_pretrained in T5SDXLPipeline.__init__.
T5_NAME  = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
# SDXL checkpoint id passed to T5SDXLPipeline.from_pretrained in the usage
# section (despite the "_DIR" suffix it is used here as a Hub-style repo id).
SDXL_DIR = "stabilityai/stable-diffusion-xl-base-1.0"



class T5SDXLPipeline(StableDiffusionXLPipeline):
    """SDXL pipeline whose CLIP text encoders are swapped out for a T5 encoder.

    The pipeline is constructed from the standard SDXL components, so it can be
    created with ``from_pretrained`` on an SDXL checkpoint. After the parent
    constructor runs:

    * ``tokenizer`` is replaced by a ``T5Tokenizer`` loaded from ``T5_NAME``;
    * ``t5_encoder`` is a ``T5EncoderModel`` loaded from ``T5_NAME`` in the
      UNet's dtype;
    * ``t5_projection`` is a freshly initialised, trainable
      ``nn.Linear(4096, 2048)``;
    * ``text_encoder``, ``text_encoder_2`` and ``tokenizer_2`` are set to
      ``None`` to save VRAM.

    NOTE(review): ``t5_encoder`` and ``t5_projection`` are plain attributes,
    not ``__init__`` parameters, so the pipeline's ``to()`` and
    ``save_pretrained()`` may not handle them. The usage script below moves
    them to the device by hand. Confirm they are actually persisted in the
    saved output.

    NOTE(review): because T5 is loaded here, every construction (including
    ``from_pretrained`` of a previously saved copy) fetches ``T5_NAME`` again
    and creates a new, randomly initialised projection.
    """

    def __init__(
        self,
        vae,
        text_encoder,
        text_encoder_2,
        tokenizer,
        tokenizer_2,
        unet,
        scheduler,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        """Build the SDXL pipeline, then attach T5 and drop the CLIP encoders.

        Args:
            vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, unet,
            scheduler: standard SDXL components, forwarded unchanged to
                ``StableDiffusionXLPipeline.__init__``. ``tokenizer``,
                ``text_encoder``, ``text_encoder_2`` and ``tokenizer_2`` are
                replaced or cleared afterwards.
            image_encoder: optional image encoder, forwarded to the parent.
            feature_extractor: optional image processor, forwarded to the
                parent.
            force_zeros_for_empty_prompt: forwarded to the parent.
            add_watermarker: forwarded to the parent.
        """
        # Forward the optional arguments too. Previously they were accepted
        # but dropped, so caller or config values were silently replaced by
        # the parent's defaults.
        super().__init__(
            vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
            unet, scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
            add_watermarker=add_watermarker,
        )
        # ----- build T5 + projection -----
        # Replace the CLIP tokenizer received above with the T5 tokenizer.
        self.tokenizer      = T5Tokenizer.from_pretrained(T5_NAME)
        # Load T5 in the UNet's dtype so both run at the same precision.
        self.t5_encoder     = T5EncoderModel.from_pretrained(T5_NAME,
                                torch_dtype=self.unet.dtype)
        # Trainable projection. It is created with PyTorch's default
        # dtype/device; callers cast and move it explicitly.
        # NOTE(review): 4096 -> 2048 is presumably T5-XXL d_model -> SDXL
        # cross-attention width. These are hard-coded rather than read from
        # the model configs.
        self.t5_projection  = nn.Linear(4096, 2048)   # trainable

        # drop CLIP encoders to save VRAM
        self.text_encoder = self.text_encoder_2 = None
        self.tokenizer_2  = None


# --- usage ---
# Load the SDXL checkpoint into the T5 variant in bfloat16, then put it on the
# GPU. T5SDXLPipeline.__init__ attaches the T5 encoder and projection.
pipe = T5SDXLPipeline.from_pretrained(SDXL_DIR, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

# The T5 parts are plain attributes of the pipeline, so they are moved and cast
# to the UNet's device/dtype explicitly. t5_projection starts in PyTorch's
# default dtype.
for _t5_part in (pipe.t5_encoder, pipe.t5_projection):
    _t5_part.to(pipe.device, dtype=pipe.unet.dtype)
del _t5_part

# NOTE(review): confirm that "t5-sdxl-model" actually contains the T5 encoder
# and projection weights, since they are not __init__-registered components.
print("Saving model")
pipe.save_pretrained("t5-sdxl-model")