NueralStyleTransfer / trainer.py
sebastiansarasti's picture
Upload 5 files
8b06175 verified
raw
history blame contribute delete
922 Bytes
from torch.optim import Adam
from loss import StyleTransferLoss
import torch
def trainer_fn(
    content, style, target_image, model, *, epochs=100, lr=0.1, device="cpu"
):
    """Optimize ``target_image`` in place toward ``content`` and ``style``.

    Runs ``epochs`` Adam steps directly on the pixels of ``target_image``,
    minimizing ``StyleTransferLoss.total_loss`` between the target's features
    and the fixed features of the content and style images.

    Args:
        content: Content image tensor; moved to ``device`` before feature
            extraction. (Presumably an image batch such as ``(1, C, H, W)``
            — TODO confirm against ``StyleTransferLoss.get_features``.)
        style: Style image tensor, handled the same way as ``content``.
        target_image: Leaf tensor that is optimized in place. It is switched
            to ``requires_grad=True`` if it is not already. It is NOT moved
            to ``device`` (a cross-device ``.to()`` would return a copy the
            optimizer does not track), so it must already live there.
        model: Model handed to ``StyleTransferLoss`` for feature extraction.
        epochs: Number of optimization steps (default 100).
        lr: Adam learning rate (default 0.1).
        device: Device passed to ``StyleTransferLoss`` and used for the
            content/style images (default ``"cpu"``).

    Returns:
        The same ``target_image`` tensor object after optimization. It still
        has ``requires_grad=True``; call ``.detach()`` before converting it
        to an image.
    """
    # Without this, backward() either raises (nothing requires grad) or leaves
    # target_image.grad as None, so Adam would silently skip every update.
    if not target_image.requires_grad:
        target_image.requires_grad_(True)
    optimizer = Adam([target_image], lr=lr)
    loss_fn = StyleTransferLoss(
        model=model, content_img=content, style_img=style, device=device
    )
    # Content/style features are fixed targets: compute them once, without
    # building an autograd graph.
    with torch.no_grad():
        content_features = loss_fn.get_features(content.to(device))
        style_features = loss_fn.get_features(style.to(device))
    # NOTE(review): optimizer.zero_grad() only clears target_image's gradient.
    # If the model's parameters still require grad, backward() accumulates
    # into their .grad every step; freezing them is presumably handled by
    # StyleTransferLoss or the caller — confirm.
    for _ in range(epochs):
        optimizer.zero_grad()
        # Target features must stay in the graph so the loss can be
        # differentiated with respect to the target's pixels.
        target_features = loss_fn.get_features(target_image)
        loss = loss_fn.total_loss(target_features, content_features, style_features)
        loss.backward()
        optimizer.step()  # updates target_image in place
    return target_image