# t5sdxl-v0-bf16 / create-t5sdxl-v0.py
# Uploaded by ppbrown: "Upload create-t5sdxl-v0.py with huggingface_hub"
# Commit d9aba4b (verified)
# This code was used to create t5sdxl-v0-bf16
from diffusers import StableDiffusionXLPipeline
from transformers import T5Tokenizer, T5EncoderModel
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from typing import Optional
import torch.nn as nn, torch, types
# Model id handed to T5Tokenizer.from_pretrained and T5EncoderModel.from_pretrained
# inside T5SDXLPipeline.__init__.
T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
# Model id handed to T5SDXLPipeline.from_pretrained as the base checkpoint.
SDXL_DIR = "stabilityai/stable-diffusion-xl-base-1.0"
class T5SDXLPipeline(StableDiffusionXLPipeline):
    """SDXL pipeline that carries a T5 encoder and a linear projection layer.

    The constructor takes the same components as ``StableDiffusionXLPipeline``
    and forwards all of them to the parent constructor. Afterwards it:

    * replaces ``self.tokenizer`` with a ``T5Tokenizer`` loaded from ``T5_NAME``;
    * loads a ``T5EncoderModel`` from ``T5_NAME`` into ``self.t5_encoder``,
      using the UNet's dtype;
    * creates ``self.t5_projection`` as ``nn.Linear(4096, 2048)``;
    * sets ``text_encoder``, ``text_encoder_2`` and ``tokenizer_2`` to ``None``.

    Only ``__init__`` is defined here; no inference method is overridden.

    NOTE(review): 4096 -> 2048 presumably maps the T5 hidden size onto the
    width SDXL's UNet expects for text conditioning -- confirm against the
    two model configs.
    NOTE(review): ``t5_encoder`` and ``t5_projection`` are assigned as plain
    attributes rather than passed through the parent constructor; verify that
    ``save_pretrained`` persists them as intended.
    """

    def __init__(
        self,
        vae,
        text_encoder,
        text_encoder_2,
        tokenizer,
        tokenizer_2,
        unet,
        scheduler,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        # Forward the optional arguments as well; previously they were accepted
        # here but never reached the parent, so caller-supplied values were
        # silently replaced by the parent's defaults.
        super().__init__(
            vae,
            text_encoder,
            text_encoder_2,
            tokenizer,
            tokenizer_2,
            unet,
            scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
            add_watermarker=add_watermarker,
        )

        # ----- build T5 + projection -----
        self.tokenizer = T5Tokenizer.from_pretrained(T5_NAME)
        self.t5_encoder = T5EncoderModel.from_pretrained(
            T5_NAME,
            torch_dtype=self.unet.dtype,  # keep the encoder in the UNet's dtype
        )
        self.t5_projection = nn.Linear(4096, 2048)  # trainable

        # drop CLIP encoders to save VRAM
        self.text_encoder = self.text_encoder_2 = None
        self.tokenizer_2 = None
# --- usage ---
# Load the base SDXL checkpoint through the T5 pipeline class in bfloat16,
# then move the pipeline to the GPU.
pipe = T5SDXLPipeline.from_pretrained(SDXL_DIR, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

# Explicitly place the T5 encoder and the projection layer on the pipeline's
# device, cast to the UNet's dtype.
for _extra_module in (pipe.t5_encoder, pipe.t5_projection):
    _extra_module.to(pipe.device, dtype=pipe.unet.dtype)
del _extra_module  # keep the module namespace the same as before

print("Saving model")
pipe.save_pretrained("t5-sdxl-model")