from typing import Optional

import torch
import torch.nn as nn

from diffusers import StableDiffusionXLPipeline
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
    T5EncoderModel,
    T5Tokenizer,
)

T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
SDXL_DIR = "stabilityai/stable-diffusion-xl-base-1.0"


class T5SDXLPipeline(StableDiffusionXLPipeline):
    def __init__(
        self,
        vae,
        text_encoder,
        text_encoder_2,
        tokenizer,
        tokenizer_2,
        unet,
        scheduler,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        super().__init__(
            vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
            unet, scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
            add_watermarker=add_watermarker,
        )

        # Replace SDXL's CLIP tokenizer with the T5 tokenizer, and pair the
        # T5 encoder with a linear projection from its 4096-dim hidden states
        # down to the 2048-dim context the SDXL UNet cross-attends over.
        self.tokenizer = T5Tokenizer.from_pretrained(T5_NAME)
        self.t5_encoder = T5EncoderModel.from_pretrained(
            T5_NAME, torch_dtype=self.unet.dtype
        )
        self.t5_projection = nn.Linear(4096, 2048)

        # Drop the CLIP text encoders and the second tokenizer entirely.
        self.text_encoder = self.text_encoder_2 = None
        self.tokenizer_2 = None


pipe = T5SDXLPipeline.from_pretrained(
    SDXL_DIR,
    torch_dtype=torch.bfloat16,
).to("cuda")

# The T5 encoder and projection are plain attributes rather than registered
# pipeline components, so .to("cuda") above does not move them; do it here.
pipe.t5_encoder.to(pipe.device, dtype=pipe.unet.dtype)
pipe.t5_projection.to(pipe.device, dtype=pipe.unet.dtype)

print("Saving model")
pipe.save_pretrained("t5-sdxl-model")
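
# ---------------------------------------------------------------------------
# Sketch, not part of the original script: the class above removes the CLIP
# encoders but does not override prompt encoding, so something like the helper
# below would be needed before the pipeline can actually be called. The name
# `encode_prompt_with_t5` and the max_length of 77 are assumptions. It only
# produces the 2048-dim token embeddings for cross-attention; SDXL's UNet also
# expects 1280-dim pooled text embeddings, which would need a separate
# projection head and is not covered here.
@torch.no_grad()
def encode_prompt_with_t5(pipe, prompt, max_length=77):
    tokens = pipe.tokenizer(
        prompt,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(pipe.device)
    hidden = pipe.t5_encoder(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
    ).last_hidden_state                 # (batch, max_length, 4096)
    return pipe.t5_projection(hidden)   # (batch, max_length, 2048)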