|
import streamlit as st |
|
import os |
|
from PIL import Image |
|
from torchvision.models import vgg19 |
|
from model import StyleTransferModel |
|
from trainer import trainer_fn |
|
from utils import process_image, tensor_to_image |
|
|
|
|
|
@st.cache_resource
def _load_model():
    """Build the style-transfer model on top of VGG19's feature extractor.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the pretrained VGG19 weights would be reloaded and the
    model rebuilt on every rerun. ``st.cache_resource`` keeps one instance per
    server process, shared by all sessions.

    NOTE(review): sharing assumes ``trainer_fn`` only optimizes the target
    image and does not mutate the model's weights — confirm in trainer.py.
    NOTE(review): ``pretrained=True`` is deprecated in torchvision >= 0.13 in
    favour of ``weights=VGG19_Weights.DEFAULT``; kept here for compatibility
    with older torchvision releases.

    Returns:
        A ``StyleTransferModel`` wrapping ``vgg19(...).features``.
    """
    base_model = vgg19(pretrained=True).features
    return StyleTransferModel(base_model)


final_model = _load_model()
|
|
|
|
|
# Page header: the app title followed by a one-sentence description.
APP_TITLE = 'Style Transfer App'
APP_DESCRIPTION = (
    'This app applies the style of one image to another image. '
    'This can be used to create artistic images.'
)

st.title(APP_TITLE)
st.write(APP_DESCRIPTION)
|
|
|
|
|
# Directory holding the selectable style images (relative to the working dir).
STYLES_DIR = 'styles'
# Leading dots so that e.g. "foopng" is not mistaken for an image file.
STYLE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')

try:
    _style_dir_entries = os.listdir(STYLES_DIR)
except FileNotFoundError:
    st.error(f"Style directory '{STYLES_DIR}/' was not found.")
    st.stop()

# Sorted so the "Style N" numbering is stable: os.listdir returns entries in
# arbitrary, filesystem-dependent order.
image_files = sorted(
    f for f in _style_dir_entries if f.lower().endswith(STYLE_EXTENSIONS)
)

st.write('Select style art to apply into your image:')

num_images = len(image_files)

if num_images == 0:
    # st.columns(0) raises, so stop with a readable message instead.
    st.error(f"No style images found in '{STYLES_DIR}/'.")
    st.stop()

cols = st.columns(num_images)

# Thumbnail size used for the gallery preview only.
resize_width = 300
resize_height = 300

for idx, (col, img_file) in enumerate(zip(cols, image_files)):
    with col:
        st.write(f"Style {idx + 1}")
        img_path = os.path.join(STYLES_DIR, img_file)
        # The context manager closes the file handle; resize() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(img_path) as img:
            img_resized = img.resize((resize_width, resize_height))
        st.image(img_resized, use_container_width=True)
|
|
|
|
|
st.write('Upload the content image:')

# PNG is accepted as well: images are converted to RGB below, so alpha and
# palette modes are normalised before preprocessing.
content_image = st.file_uploader('Content Image', type=['jpg', 'jpeg', 'png'])

# Options are indices into image_files; format_func renders the same
# "Style N" labels without having to parse the number back out of the text.
style_index = st.selectbox(
    'Select the style art:',
    list(range(num_images)),
    format_func=lambda i: f'Style {i + 1}',
)

if st.button('Apply Style Transfer'):
    if content_image is None:
        st.warning('Please upload a content image')
    else:
        # convert('RGB') yields a 3-channel, fully loaded copy (style images
        # may be palette GIFs or RGBA PNGs), so the source handles can close.
        with Image.open(content_image) as raw_content:
            content_img = raw_content.convert('RGB')
        with Image.open(os.path.join('styles', image_files[style_index])) as raw_style:
            style_img = raw_style.convert('RGB')

        # process_image presumably returns a batched tensor (1, C, H, W),
        # given the squeeze(0) below — TODO confirm in utils.py.
        content_img = process_image(content_img)
        style_img = process_image(style_img)

        # The optimization target starts as a trainable copy of the content.
        with st.spinner('Applying Style Transfer...'):
            target_image = trainer_fn(
                content_img, style_img, content_img.clone().requires_grad_(True), final_model
            )

        target_image = tensor_to_image(target_image.squeeze(0))

        st.write('Result:')
        st.image(target_image, use_container_width=True)
|
|
|
|