import torch
from torch.optim import Adam

from loss import StyleTransferLoss


def trainer_fn(
    content,
    style,
    target_image,
    model,
    *,
    epochs=100,
    lr=0.1,
    device="cpu",
):
    """Optimize ``target_image`` in place against a style-transfer loss.

    The features of ``content`` and ``style`` are extracted once, without
    gradient tracking, and then reused on every iteration. Each iteration
    extracts the features of ``target_image``, computes the total loss from
    the three feature sets, backpropagates, and takes one Adam step on
    ``target_image``.

    Args:
        content: Content image; moved to ``device`` before feature
            extraction.
        style: Style image; moved to ``device`` before feature extraction.
        target_image: Tensor being optimized. It is the only parameter
            handed to the optimizer and is updated in place. It is NOT moved
            by this function, so the caller presumably has to create it on
            ``device`` already (verify against ``StyleTransferLoss``).
        model: Passed through unchanged to ``StyleTransferLoss``.
        epochs: Number of optimization steps to run. Defaults to 100.
        lr: Learning rate for Adam. Defaults to 0.1.
        device: Device passed to ``StyleTransferLoss`` and used for the
            content/style images. Defaults to ``"cpu"``.

    Returns:
        The same ``target_image`` object after ``epochs`` update steps.
    """
    # Without gradient tracking on the target, backward() either fails or
    # leaves the image with no gradient, in which case Adam skips it.
    if not target_image.requires_grad:
        target_image.requires_grad_(True)

    optimizer = Adam([target_image], lr=lr)
    loss_fn = StyleTransferLoss(
        model=model,
        content_img=content,
        style_img=style,
        device=device,
    )

    # The reference features never change during optimization, so compute
    # them once and keep them out of the autograd graph.
    with torch.no_grad():
        content_features = loss_fn.get_features(content.to(device))
        style_features = loss_fn.get_features(style.to(device))

    for _ in range(epochs):
        # Gradients accumulate across backward() calls; clear them first.
        optimizer.zero_grad()

        target_features = loss_fn.get_features(target_image)
        loss = loss_fn.total_loss(
            target_features, content_features, style_features
        )

        loss.backward()
        optimizer.step()

    return target_image