|
from torch.optim import Adam |
|
from loss import StyleTransferLoss |
|
import torch |
|
|
|
def trainer_fn(content, style, target_image, model, *, epochs=100, lr=0.1, device="cpu"):
    """Optimize ``target_image`` in place toward ``content`` and ``style``.

    Builds a ``StyleTransferLoss`` around ``model``. It extracts the content
    and style features once, without tracking gradients. It then runs
    ``epochs`` Adam steps on ``target_image``, minimizing
    ``loss_fn.total_loss(target_features, content_features, style_features)``.

    Args:
        content: Content image; passed to ``StyleTransferLoss`` as
            ``content_img`` and moved to ``device`` for feature extraction.
        style: Style image; passed as ``style_img`` and moved to ``device``
            for feature extraction.
        target_image: Leaf tensor whose values are optimized. Its
            ``requires_grad`` flag is switched on if not already set. It is
            NOT moved to ``device``; the caller must place it there.
        model: Network handed to ``StyleTransferLoss`` (presumably used by
            ``get_features`` -- confirm against ``loss.py``).
        epochs: Number of optimization steps (default 100).
        lr: Adam learning rate (default 0.1).
        device: Device given to ``StyleTransferLoss``, onto which ``content``
            and ``style`` are moved (default ``"cpu"``).

    Returns:
        The same ``target_image`` tensor object, updated in place by Adam.
    """
    # Without requires_grad, loss.backward() raises because no tensor in the
    # graph needs a gradient, and Adam skips params whose .grad is None.
    if not target_image.requires_grad:
        target_image.requires_grad_(True)

    optimizer = Adam([target_image], lr=lr)
    loss_fn = StyleTransferLoss(
        model=model, content_img=content, style_img=style, device=device
    )

    # The reference features are fixed targets: compute them once, outside
    # the autograd graph, instead of on every step.
    with torch.no_grad():
        content_features = loss_fn.get_features(content.to(device))
        style_features = loss_fn.get_features(style.to(device))

    # NOTE(review): if the model's parameters require grad, their .grad
    # buffers also accumulate on every backward(); optimizer.zero_grad() only
    # clears target_image. Consider freezing the model in the caller.
    for _ in range(epochs):
        optimizer.zero_grad()
        target_features = loss_fn.get_features(target_image)
        loss = loss_fn.total_loss(
            target_features, content_features, style_features
        )
        loss.backward()
        optimizer.step()

    return target_image