Spaces:
Sleeping
Sleeping
File size: 2,018 Bytes
eee507e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import streamlit as st
import torch
from model import Generator
import torchvision.utils as vutils
import os
from math import log2
def generate_images(num_images=6, img_size=256, alpha=0.5):
    """Sample a batch of images from the pretrained ProGAN generator.

    The generator weights are read from ``generator.pth`` (mapped to the CPU)
    on every call, then ``num_images`` latent vectors are drawn from a
    standard normal distribution and decoded.

    Args:
        num_images: How many images to sample. Defaults to 6, the batch size
            the app has always produced. ``0`` returns an empty list without
            touching the checkpoint.
        img_size: Square output resolution. Must be a power of two >= 4; the
            number of generator steps passed to the model is
            ``log2(img_size / 4)``. Defaults to 256, the resolution the app
            says the checkpoint was trained for.
        alpha: Value forwarded to the generator's ``alpha=`` keyword.
            Defaults to 0.5 as before. NOTE(review): in ProGAN this is usually
            the fade-in weight between the newest block and the upscaled
            previous output, and a fully trained model is typically sampled
            with alpha=1.0 — confirm against model.py.

    Returns:
        A list of ``num_images`` tensors, each shaped
        ``(3, img_size, img_size)``, with values clamped to ``[0, 1]``.

    Raises:
        ValueError: If ``num_images`` is negative or ``img_size`` is not a
            power of two >= 4.
        RuntimeError: If the generator output does not have the expected shape.
    """
    z_dim = 256
    in_channels = 256

    if num_images < 0:
        raise ValueError(f"num_images must be non-negative, got {num_images}")
    if img_size < 4 or img_size & (img_size - 1):
        raise ValueError(f"img_size must be a power of two >= 4, got {img_size}")
    if num_images == 0:
        return []

    # Load pretrained generator weights onto the CPU.
    # NOTE(review): the checkpoint is re-read on every call (i.e. every button
    # click); consider caching the loaded model (e.g. st.cache_resource).
    checkpoint = torch.load("generator.pth", map_location=torch.device("cpu"))
    generator = Generator(z_dim, in_channels, img_channels=3)
    # Only the 'state_dict' entry is used; any other checkpoint entries are ignored.
    generator.load_state_dict(checkpoint["state_dict"])
    generator.eval()

    # Kept for backward compatibility: the directory is created, but nothing
    # in this function writes to it.
    os.makedirs("generated_images", exist_ok=True)

    # `steps` counts resolution doublings from a 4x4 base: log2(img_size / 4).
    num_steps = int(log2(img_size / 4))
    noise = torch.randn((num_images, z_dim, 1, 1))
    with torch.no_grad():
        fake = generator(noise, alpha=alpha, steps=num_steps)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    expected_shape = (num_images, 3, img_size, img_size)
    if tuple(fake.shape) != expected_shape:
        raise RuntimeError(
            f"Generator returned shape {tuple(fake.shape)}, expected {expected_shape}"
        )

    # Rescale from [-1, 1] to [0, 1] (assumes a tanh-style output head; model.py
    # is not visible here). Clamp so float rounding or a different output range
    # cannot produce values outside [0, 1], which st.image rejects for floats.
    fake = ((fake + 1) / 2).clamp(0.0, 1.0)
    # One (3, H, W) tensor per image.
    return list(fake.unbind(0))
def main():
    """Render the Streamlit page and show freshly sampled images on demand.

    Clicking the button calls ``generate_images()`` and displays each returned
    ``(C, H, W)`` tensor as an ``(H, W, C)`` NumPy array via ``st.image``.
    """
    st.title('Image Generation with pro-gan 🤖')
    st.write("Click the button below to generate images.")
    st.write("Trained on the CelebA-HQ dataset.")
    # Explain up front why the output resolution is fixed.
    st.write("Note: Due to limited resources, the model has been trained to generate 256x256 size images. They are still awesome!")

    if st.button('Generate Images'):
        # Loading the checkpoint and sampling on CPU may take a while; show progress.
        with st.spinner("Generating images..."):
            images = generate_images()
        for number, image in enumerate(images, start=1):
            # st.image expects channels-last arrays, so move the channel axis last.
            st.image(
                image.permute(1, 2, 0).cpu().numpy(),
                caption=f'Generated Image {number}',
                use_column_width=True,
            )
# Entry point: run the app only when executed as a script (as `streamlit run` does).
if __name__ == "__main__":
    main()
|