# app.py by sebastiansarasti
# Last change: "Update app.py" (commit 9bf033b, verified), 2.74 kB
import streamlit as st
import os
from PIL import Image
from torchvision.models import vgg19
from model import StyleTransferModel
from trainer import trainer_fn
from utils import process_image, tensor_to_image
# Build the VGG19 feature extractor and the style-transfer model once per
# server process. Streamlit re-executes this whole script on every widget
# interaction, so constructing the network at module level would reload the
# pretrained weights on each rerun; st.cache_resource keeps one instance.
# NOTE(review): the cached model is shared across user sessions — confirm that
# StyleTransferModel / trainer_fn do not store per-run state on the model.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=VGG19_Weights.DEFAULT`; kept as-is for version compatibility.
@st.cache_resource
def _load_models():
    """Return ``(vgg19_features, StyleTransferModel)`` built from pretrained VGG19."""
    features = vgg19(pretrained=True).features
    return features, StyleTransferModel(features)


base_model, final_model = _load_models()
# define the title of the app
st.title('Style Transfer App')
# define the description of the app
st.write('This app applies the style of one image to another image. This can be used to create artistic images.')
# Collect the style images in the 'styles' folder. Sorting keeps the
# "Style N" numbering stable, since os.listdir returns entries in arbitrary order.
image_files = sorted(
    f for f in os.listdir('styles')
    if f.lower().endswith(('png', 'jpg', 'jpeg', 'gif', 'bmp'))
)
num_images = len(image_files)
# st.columns requires a positive count, so stop with a clear message instead
# of crashing when the folder holds no usable images.
if num_images == 0:
    st.error("No style images found in the 'styles' folder.")
    st.stop()
# display the images
st.write('Select style art to apply into your image:')
cols = st.columns(num_images)
# Size (width, height) each style thumbnail is resized to for display.
resize_width = 300
resize_height = 300
# Show each style image in its own column, labelled with its 1-based number.
for idx, (col, img_file) in enumerate(zip(cols, image_files)):
    with col:
        st.write(f"Style {idx + 1}")
        img_path = os.path.join('styles', img_file)
        # Open in a context manager so the file handle is released; resize()
        # returns a new, fully loaded image that remains usable afterwards.
        with Image.open(img_path) as img:
            img_resized = img.resize((resize_width, resize_height))
        st.image(img_resized, use_container_width=True)
# create a file uploader for the content image
st.write('Upload the content image:')
content_image = st.file_uploader('Content Image', type=['jpg', 'jpeg', 'png'])
# Select a style by its index into image_files; format_func renders the
# options as "Style 1", "Style 2", ... so no label parsing is needed later.
style_index = st.selectbox(
    'Select the style art:',
    list(range(num_images)),
    format_func=lambda i: f'Style {i + 1}',
)
# create a button to run the model
if st.button('Apply Style Transfer'):
    if content_image is not None:
        # Convert both inputs to 3-channel RGB: the VGG19 backbone takes three
        # input channels, while PNG/GIF/grayscale files may be RGBA, P or L mode.
        content_img = Image.open(content_image).convert('RGB')
        # convert() loads the pixels, so the style file can be closed right away.
        with Image.open(os.path.join('styles', image_files[style_index])) as img:
            style_img = img.convert('RGB')
        # preprocess the images (process_image returns tensors; see .clone() below)
        content_img = process_image(content_img)
        style_img = process_image(style_img)
        # The third argument is a grad-enabled copy of the content tensor;
        # presumably trainer_fn optimises it toward the style — confirm in trainer.py.
        with st.spinner('Applying Style Transfer...'):
            target_image = trainer_fn(
                content_img, style_img, content_img.clone().requires_grad_(True), final_model
            )
        # drop the leading batch dimension and convert the tensor to an image
        target_image = tensor_to_image(target_image.squeeze(0))
        # display the result
        st.write('Result:')
        st.image(target_image, use_container_width=True)
    else:
        st.write('Please upload a content image')