---
license: apache-2.0
pipeline_tag: text-to-image
---

# Simple Diffusion XS

*XS Size, Excess Quality*

At AiArtLab, we strive to create a free, compact (1.7B parameters), and fast (~3 seconds per image) model that can be trained on consumer graphics cards.

- We use a U-Net for its high efficiency.
- We chose the multilingual/multimodal encoder Mexma-SigLIP, which supports 80 languages (see the sketch below).
- We use the AuraDiffusion 16ch-VAE, which preserves details and anatomy.
- The model was trained for about one month on 4x A5000 GPUs on approximately one million images of various resolutions and styles, including anime and realistic photos.
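
As a quick illustration of the multilingual claim, the sketch below encodes an English prompt and its Russian translation with the same `visheratin/mexma-siglip` checkpoint and `encode_texts` helper used in the example at the end of this card, then checks that the two embeddings land close together. The prompts and the cosine-similarity check are purely illustrative:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
text_model = AutoModel.from_pretrained(
    "visheratin/mexma-siglip", trust_remote_code=True
).eval()

# The same concept in English and Russian; a multilingual encoder
# should map both prompts to nearby embeddings.
prompts = ["a cat sleeping on a windowsill", "кошка спит на подоконнике"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
with torch.no_grad():
    emb = text_model.encode_texts(inputs.input_ids, inputs.attention_mask)

emb = emb.reshape(2, -1).float()
print(torch.nn.functional.cosine_similarity(emb[:1], emb[1:]))  # expect a value near 1
```
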
### Model Limitations:

- Limited concept coverage due to the small dataset.
- The Image2Image functionality requires further training (a rough sketch of one possible wiring appears after the example below).

## Acknowledgments

- **[Stan](https://t.me/Stangle)** — Key investor. Thank you for believing in us when others called it madness.
- **Captainsaturnus**
- **Love. Death. Transformers.**

## Datasets

- **[CaptionEmporium](https://huggingface.co/CaptionEmporium)**

## Training budget

Around $1k spent so far, out of a research budget of roughly $10k.

## Donations

Please contact us if you can provide GPUs or funding for training.

DOGE: DEw2DR8C7BnF8GgcrfTzUjSnGkuMeJhg83

BTC: 3JHv9Hb8kEW8zMAccdgCdZGfrHeMhH1rpN

## Contacts

[recoilme](https://t.me/recoilme)

Training status (in progress): [wandb](https://wandb.ai/recoilme/micro)



## Example

The script below runs the full text-to-image pipeline by hand: it encodes the prompt (and a negative prompt) with Mexma-SigLIP, denoises random latents with the U-Net under classifier-free guidance, and decodes the result with the VAE.

```python
import os

import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from PIL import Image
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer


def encode_prompt(prompt, negative_prompt, device, dtype):
    if negative_prompt is None:
        negative_prompt = ""

    with torch.no_grad():
        positive_inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=512,
            truncation=True,
        ).to(device)
        positive_embeddings = text_model.encode_texts(
            positive_inputs.input_ids, positive_inputs.attention_mask
        )
        if positive_embeddings.ndim == 2:
            positive_embeddings = positive_embeddings.unsqueeze(1)
        positive_embeddings = positive_embeddings.to(device, dtype=dtype)

        negative_inputs = tokenizer(
            negative_prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=512,
            truncation=True,
        ).to(device)
        negative_embeddings = text_model.encode_texts(
            negative_inputs.input_ids, negative_inputs.attention_mask
        )
        if negative_embeddings.ndim == 2:
            negative_embeddings = negative_embeddings.unsqueeze(1)
        negative_embeddings = negative_embeddings.to(device, dtype=dtype)
    # Stack [negative, positive] so one UNet pass covers both guidance branches.
    return torch.cat([negative_embeddings, positive_embeddings], dim=0)


def generate_latents(embeddings, height=576, width=576, num_inference_steps=50, guidance_scale=5.5):
    with torch.no_grad():
        device, dtype = embeddings.device, embeddings.dtype
        half = embeddings.shape[0] // 2
        # 16 latent channels and 8x spatial downsampling (AuraDiffusion 16ch-VAE).
        latent_shape = (half, 16, height // 8, width // 8)
        latents = torch.randn(latent_shape, device=device, dtype=dtype)
        embeddings = embeddings.repeat_interleave(half, dim=0)

        scheduler.set_timesteps(num_inference_steps)

        for t in tqdm(scheduler.timesteps, desc="Generating"):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            noise_pred = unet(latent_model_input, t, embeddings).sample
            # Classifier-free guidance: move the prediction away from the
            # unconditional branch, toward the text-conditioned one.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents


def decode_latents(latents, vae, output_type="pil"):
    # Undo the shift/scale that was applied when the latents were encoded.
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    with torch.no_grad():
        images = vae.decode(latents).sample
    images = (images / 2 + 0.5).clamp(0, 1)
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    if output_type == "pil":
        images = (images * 255).round().astype("uint8")
        images = [Image.fromarray(image) for image in images]
    return images


# Example usage:
if __name__ == "__main__":
    device = "cuda"
    dtype = torch.float16

    prompt = "girl"
    negative_prompt = "bad quality"
    tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
    text_model = AutoModel.from_pretrained(
        "visheratin/mexma-siglip", torch_dtype=dtype, trust_remote_code=True
    ).to(device, dtype=dtype).eval()

    embeddings = encode_prompt(prompt, negative_prompt, device, dtype)

    pipeid = "AiArtLab/sdxs"
    variant = "fp16"

    unet = UNet2DConditionModel.from_pretrained(pipeid, subfolder="unet", variant=variant).to(device, dtype=dtype).eval()
    vae = AutoencoderKL.from_pretrained(pipeid, subfolder="vae", variant=variant).to(device, dtype=dtype).eval()
    scheduler = DDPMScheduler.from_pretrained(pipeid, subfolder="scheduler")

    height, width = 576, 384
    num_inference_steps = 40
    output_folder, project_name = "samples", "sdxs"
    latents = generate_latents(
        embeddings=embeddings,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
    )

    images = decode_latents(latents, vae)

    os.makedirs(output_folder, exist_ok=True)
    for idx, image in enumerate(images):
        image.save(f"{output_folder}/{project_name}_{idx}.jpg")

    print("Images generated and saved to:", output_folder)
```
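
Image2Image is noted above as still requiring training, but for experimentation, here is a rough, untested SDEdit-style sketch of how it could be wired from the same components the example defines (`vae`, `unet`, `scheduler`, and `embeddings` from `encode_prompt`). The preprocessing and the `strength` parameter are our assumptions, not part of the released model:

```python
import numpy as np
import torch
from PIL import Image
from tqdm.auto import tqdm

def img2img(init_image: Image.Image, embeddings, num_inference_steps=40,
            strength=0.6, guidance_scale=5.5):
    """Untested sketch: init_image should be RGB with sides divisible by 8."""
    device, dtype = embeddings.device, embeddings.dtype

    # Scale pixels to [-1, 1] and encode into the 16-channel latent space,
    # applying the same shift/scale that decode_latents undoes.
    pixels = torch.from_numpy(np.asarray(init_image).astype("float32") / 127.5 - 1.0)
    pixels = pixels.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
    with torch.no_grad():
        latents = vae.encode(pixels).latent_dist.sample()
    latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor

    # SDEdit: noise the latents to an intermediate timestep, then denoise from
    # there. Higher strength = more noise = less resemblance to the input.
    scheduler.set_timesteps(num_inference_steps)
    timesteps = scheduler.timesteps[int(num_inference_steps * (1 - strength)):]
    latents = scheduler.add_noise(latents, torch.randn_like(latents), timesteps[:1])

    with torch.no_grad():
        for t in tqdm(timesteps, desc="img2img"):
            latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), t)
            noise_pred = unet(latent_model_input, t, embeddings).sample
            uncond, cond = noise_pred.chunk(2)
            noise_pred = uncond + guidance_scale * (cond - uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents  # decode with decode_latents(latents, vae)
```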