jamino30 commited on
Commit
e2f3074
·
verified ·
1 Parent(s): 672203e

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +4 -8
app.py CHANGED
@@ -86,18 +86,14 @@ def inference(content_image, style_name, style_strength, output_quality, progres
86
  content_features = model(content_img)
87
  style_features = cached_style_features[style_name][0 if img_size == 512 else 1]
88
 
89
- scaler = torch.amp.GradScaler('cuda')
90
-
91
  for _ in tqdm(range(iters), desc='The magic is happening ✨'):
92
  optimizer.zero_grad()
93
 
94
- with torch.amp.autocast('cuda'):
95
- generated_features = model(generated_img)
96
- total_loss = compute_loss(generated_features, content_features, style_features, alpha, beta)
97
 
98
- scaler.scale(total_loss).backward()
99
- scaler.step(optimizer)
100
- scaler.update()
101
 
102
  et = time.time()
103
  print('TIME TAKEN:', et-st)
 
86
  content_features = model(content_img)
87
  style_features = cached_style_features[style_name][0 if img_size == 512 else 1]
88
 
 
 
89
  for _ in tqdm(range(iters), desc='The magic is happening ✨'):
90
  optimizer.zero_grad()
91
 
92
+ generated_features = model(generated_img)
93
+ total_loss = compute_loss(generated_features, content_features, style_features, alpha, beta)
 
94
 
95
+ total_loss.backward()
96
+ optimizer.step()
 
97
 
98
  et = time.time()
99
  print('TIME TAKEN:', et-st)