progressive-GAN / app.py
mlgawd's picture
Create app.py
eee507e verified
raw
history blame contribute delete
2.02 kB
import streamlit as st
import torch
from model import Generator
import torchvision.utils as vutils
import os
from math import log2
# Latent size and base channel width the checkpoint was trained with; they
# must match the architecture stored in "generator.pth".
_Z_DIM = 256
_IN_CHANNELS = 256


def _load_generator(checkpoint_path):
    """Build a Generator and load its weights from *checkpoint_path* (on CPU)."""
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    generator = Generator(_Z_DIM, _IN_CHANNELS, img_channels=3)
    # Only the checkpoint's 'state_dict' entry is used; any other entries
    # stored alongside it are ignored.
    generator.load_state_dict(checkpoint["state_dict"])
    generator.eval()
    return generator


def _steps_for_size(img_size):
    """Return the generator ``steps`` value for *img_size*: log2(img_size / 4).

    Raises:
        ValueError: If *img_size* is not a power of two >= 4, for which the
            step count would otherwise be silently truncated.
    """
    if img_size < 4 or img_size & (img_size - 1):
        raise ValueError(f"img_size must be a power of two >= 4, got {img_size}")
    return int(log2(img_size // 4))


# Function to generate images
def generate_images(num_images=6, img_sizes=(256,), alpha=0.5,
                    checkpoint_path="generator.pth"):
    """Sample random images from the pretrained generator.

    Args:
        num_images: Number of images sampled per resolution (default 6).
        img_sizes: Output resolutions; each must be a power of two >= 4.
        alpha: Value passed as ``alpha`` to the generator.
            NOTE(review): 0.5 is kept for backward compatibility; a fully
            trained final stage usually uses 1.0 -- confirm against
            model.Generator.
        checkpoint_path: Checkpoint file containing a 'state_dict' entry.

    Returns:
        A list of ``num_images * len(img_sizes)`` tensors of shape
        (3, size, size), with values mapped to and clamped into [0, 1].

    Raises:
        ValueError: If an entry of *img_sizes* is not a power of two >= 4.
        RuntimeError: If the generator returns a batch of unexpected shape.
    """
    # Validate every size before paying for the checkpoint load.
    steps_per_size = [(size, _steps_for_size(size)) for size in img_sizes]

    generator = _load_generator(checkpoint_path)

    # Kept from the original app; nothing is currently written to this directory.
    os.makedirs("generated_images", exist_ok=True)

    images = []
    for img_size, num_steps in steps_per_size:
        noise = torch.randn((num_images, _Z_DIM, 1, 1))
        with torch.no_grad():
            batch = generator(noise, alpha=alpha, steps=num_steps)
        # Map [-1, 1] to [0, 1]. Clamp because nothing here guarantees the raw
        # output stays within [-1, 1], and st.image errors on float pixels
        # outside [0, 1] unless clamping is requested.
        batch = ((batch + 1) / 2).clamp(0.0, 1.0)
        expected = (num_images, 3, img_size, img_size)
        if tuple(batch.shape) != expected:
            raise RuntimeError(
                f"generator returned shape {tuple(batch.shape)}, expected {expected}"
            )
        # Split the batch into per-image (3, H, W) tensors.
        images.extend(batch.unbind(0))
    return images
# Main function to create Streamlit web app
def main():
    """Render the Streamlit page and show newly generated images on button click."""
    st.title('Image Generation with pro-gan 🤖')
    st.write("Click the button below to generate images.")
    st.write("Trained on CelebHQ dataset.")
    # Prompt message about image size
    st.write("Note: Due to limited resources, the model has been trained to generate 256x256 size images. They are still awesome!")
    # st.button is truthy only on the rerun triggered by the click.
    if not st.button('Generate Images'):
        return
    images = generate_images()
    # Display the generated images
    for number, image in enumerate(images, start=1):
        # Each image is a (C, H, W) tensor; st.image wants channel-last (H, W, C).
        pixels = image.permute(1, 2, 0).cpu().numpy()
        # clamp=True: without it, st.image raises on float pixels outside [0, 1].
        st.image(
            pixels,
            caption=f'Generated Image {number}',
            use_column_width=True,
            clamp=True,
        )
# Build the page only when this file is executed as the main script, not on import.
if __name__ == "__main__":
    main()