# File size: 922 Bytes
# 8b06175
from torch.optim import Adam
from loss import StyleTransferLoss
import torch

def trainer_fn(content, style, target_image, model, epochs=100, lr=0.1):
    """Optimize ``target_image`` in place by style transfer and return it.

    Only the pixels of ``target_image`` are handed to the Adam optimizer;
    ``model`` is used through ``StyleTransferLoss`` to extract features.
    Content and style features are computed once, without gradient
    tracking, and reused on every step.

    Args:
        content: Content image tensor (moved to CPU for feature extraction).
        style: Style image tensor (moved to CPU for feature extraction).
        target_image: Leaf tensor to optimize; it is updated in place. If it
            does not already require gradients, ``requires_grad`` is enabled
            on it so that the optimizer actually updates it.
        model: Network passed to ``StyleTransferLoss``.
        epochs: Number of optimization steps (default 100).
        lr: Adam learning rate (default 0.1).

    Returns:
        The same ``target_image`` tensor object, after optimization.

    Raises:
        ValueError: Raised by ``Adam`` if ``target_image`` is a non-leaf
            tensor, which cannot be optimized directly.
    """
    # Without requires_grad, backward() never populates target_image.grad and
    # optimizer.step() silently skips it, so the image would never change.
    # Only leaf tensors can lack requires_grad, so this is safe to toggle.
    if not target_image.requires_grad:
        target_image.requires_grad_(True)

    optimizer = Adam([target_image], lr=lr)
    loss_fn = StyleTransferLoss(
        model=model, content_img=content, style_img=style, device="cpu"
    )

    # Reference features are fixed targets; no gradients are needed for them.
    with torch.no_grad():
        content_features = loss_fn.get_features(content.to("cpu"))
        style_features = loss_fn.get_features(style.to("cpu"))

    # NOTE(review): target_image is not moved to CPU here (moving it would
    # create a non-leaf copy the optimizer cannot update) — presumably the
    # caller already supplies it on CPU; confirm.
    # NOTE(review): the optimizer only owns target_image, so if model's
    # parameters require grad, backward() keeps accumulating into their .grad
    # without ever clearing it; consider freezing the model upstream.
    for _ in range(epochs):
        optimizer.zero_grad()

        # Features must be recomputed each step because target_image changes.
        target_features = loss_fn.get_features(target_image)
        loss = loss_fn.total_loss(target_features, content_features, style_features)

        loss.backward()
        optimizer.step()

    return target_image