|
import torch, multiprocessing, itertools, os, shutil, PIL, argparse, numpy |
|
from collections import OrderedDict |
|
from numbers import Number |
|
from torch.nn.functional import mse_loss, l1_loss |
|
from seeing import pbar |
|
from seeing import zdataset, seededsampler |
|
from seeing import proggan, customnet, parallelfolder |
|
from seeing import encoder_net, encoder_loss, setting |
|
from torchvision import transforms, models |
|
from torchvision.models.vgg import model_urls |
|
from seeing.pidfile import exit_if_job_done, mark_job_done |
|
from seeing import nethook, LBFGS |
|
from seeing.encoder_loss import cor_square_error |
|
from seeing.nethook import InstrumentedModel |
|
|
|
torch.backends.cudnn.benchmark = True |
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--image_number', type=int, help='Image number', |
|
default=95) |
|
parser.add_argument('--image_source', |
|
default='test') |
|
parser.add_argument('--redo', type=int, help='Nonzero to delete done.txt', |
|
default=0) |
|
parser.add_argument('--model', type=str, help='Dataset being modeled', |
|
default='church') |
|
parser.add_argument('--halfsize', type=int, |
|
help='Set to 1 for half size enoder', |
|
default=0) |
|
parser.add_argument('--lambda_f', type=float, help='Feature regularizer', |
|
default=0.25) |
|
parser.add_argument('--num_steps', type=int, |
|
help='run for n steps', |
|
default=3000) |
|
parser.add_argument('--snapshot_every', type=int, |
|
help='only generate snapshots every n iterations', |
|
default=1000) |
|
args = parser.parse_args() |
|
|
|
|
|
num_steps = args.num_steps |
|
global_seed = 1 |
|
image_number = args.image_number |
|
expgroup = 'optimize_lbfgs' |
|
imagetypecode = (dict(val='i', train='n', gan='z', test='t') |
|
.get(args.image_source, args.image_source[0])) |
|
expname = 'opt_%s_%d' % (imagetypecode, image_number) |
|
expdir = os.path.join('results', args.model, expgroup, 'cases', expname) |
|
sumdir = os.path.join('results', args.model, expgroup, |
|
'summary_%s' % imagetypecode) |
|
os.makedirs(expdir, exist_ok=True) |
|
os.makedirs(sumdir, exist_ok=True) |
|
|
|
|
|
|
|
def main(): |
|
pbar.print('Running %s' % expdir) |
|
delete_log() |
|
|
|
|
|
dirname = os.path.join(expdir, 'images') |
|
os.makedirs(dirname, exist_ok=True) |
|
loaded_x, loaded_z = setting.load_test_image(image_number, |
|
args.image_source, model=args.model) |
|
visualize_results((image_number, 'target'), |
|
loaded_x[0], summarize=True) |
|
|
|
|
|
G = setting.load_proggan(args.model) |
|
|
|
|
|
E = nethook.InstrumentedModel( |
|
encoder_net.HybridLayerNormEncoder(halfsize=args.halfsize)) |
|
E.load_state_dict(torch.load(os.path.join('results', args.model, |
|
'invert_hybrid_cse/snapshots/epoch_1000.pth.tar'))['state_dict']) |
|
E.eval() |
|
|
|
G.cuda() |
|
E.cuda() |
|
F = E |
|
|
|
torch.set_grad_enabled(False) |
|
|
|
|
|
true_x = loaded_x.cuda() |
|
|
|
init_z = E(true_x) |
|
|
|
if loaded_z is None: |
|
true_z = None |
|
else: |
|
true_z = loaded_z.cuda() |
|
|
|
current_z = init_z.clone() |
|
target_x = loaded_x.clone().cuda() |
|
target_f = F(loaded_x.cuda()) |
|
parameters = [current_z] |
|
show_every = args.snapshot_every |
|
|
|
nethook.set_requires_grad(False, G, E) |
|
nethook.set_requires_grad(True, *parameters) |
|
optimizer = LBFGS.FullBatchLBFGS(parameters) |
|
|
|
def compute_all_loss(): |
|
current_x = G(current_z) |
|
all_loss = {} |
|
all_loss['x'] = l1_loss(target_x, current_x) |
|
all_loss['z'] = 0.0 if not args.lambda_f else ( |
|
mse_loss(target_f, F(current_x)) * args.lambda_f) |
|
return current_x, all_loss |
|
|
|
def closure(): |
|
optimizer.zero_grad() |
|
_, all_loss = compute_all_loss() |
|
return sum(all_loss.values()) |
|
|
|
with torch.enable_grad(): |
|
for step_num in pbar(range(num_steps + 1)): |
|
if step_num == 0: |
|
loss = closure() |
|
loss.backward() |
|
else: |
|
options = {'closure': closure, 'current_loss': loss, |
|
'max_ls': 10} |
|
loss, _, lr, _, _, _, _, _ = optimizer.step(options) |
|
if step_num % show_every == 0: |
|
with torch.no_grad(): |
|
current_x, all_loss = compute_all_loss() |
|
log_progress('%d ' % step_num + ' '.join( |
|
'%s=%.3f' % (k, all_loss[k]) |
|
for k in sorted(all_loss.keys())), phase='a') |
|
visualize_results((image_number, 'a', step_num), current_x, |
|
summarize=(step_num in [0, num_steps])) |
|
checkpoint_dict = OrderedDict(all_loss) |
|
checkpoint_dict['init_z'] = init_z |
|
checkpoint_dict['target_x'] = target_x |
|
checkpoint_dict['current_z'] = target_x |
|
save_checkpoint( |
|
phase='a', |
|
step=step_num, |
|
optimizer=optimizer.state_dict(), |
|
**checkpoint_dict) |
|
|
|
def delete_log(): |
|
try: |
|
os.remove(os.path.join(expdir, 'log.txt')) |
|
except: |
|
pass |
|
|
|
def log_progress(s, phase='a'): |
|
with open(os.path.join(expdir, 'log.txt'), 'a') as f: |
|
f.write(phase + ' ' + s + '\n') |
|
pbar.print(s) |
|
|
|
def save_checkpoint(**kwargs): |
|
dirname = os.path.join(expdir, 'snapshots') |
|
os.makedirs(dirname, exist_ok=True) |
|
filename = 'step_%s_%d.pth.tar' % (kwargs['phase'], kwargs['step']) |
|
torch.save(kwargs, os.path.join(dirname, filename)) |
|
|
|
numeric_data = { |
|
k: v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v |
|
for k, v in kwargs.items() |
|
if isinstance(v, (Number, numpy.ndarray, torch.Tensor))} |
|
filename = 'step_%s_%d.npz' % (kwargs['phase'], kwargs['step']) |
|
numpy.savez(os.path.join(dirname, filename), **numeric_data) |
|
|
|
def visualize_results(step, img, summarize=False): |
|
|
|
if isinstance(step, tuple): |
|
filename = '%s.png' % ('_'.join(str(i) for i in step)) |
|
else: |
|
filename = '%s.png' % str(step) |
|
dirname = os.path.join(expdir, 'images') |
|
os.makedirs(dirname, exist_ok=True) |
|
save_tensor_image(img, os.path.join(dirname, filename)) |
|
lbname = os.path.join(dirname, '+lightbox.html') |
|
if not os.path.exists(lbname): |
|
shutil.copy('seeing/lightbox.html', lbname) |
|
if summarize: |
|
save_tensor_image(img, os.path.join(sumdir, filename)) |
|
lbname = os.path.join(sumdir, '+lightbox.html') |
|
if not os.path.exists(lbname): |
|
shutil.copy('seeing/lightbox.html', lbname) |
|
|
|
def save_tensor_image(img, filename): |
|
if len(img.shape) == 4: |
|
img = img[0] |
|
np_data = ((img.permute(1, 2, 0) / 2 + 0.5) * 255 |
|
).clamp(0, 255).byte().cpu().numpy() |
|
PIL.Image.fromarray(np_data).save(filename) |
|
|
|
def set_requires_grad(requires_grad, *models): |
|
for model in models: |
|
if isinstance(model, torch.nn.Module): |
|
for param in model.parameters(): |
|
param.requires_grad = requires_grad |
|
elif isintance(model, torch.nn.Parameter): |
|
model.requires_grad = requires_grad |
|
else: |
|
assert False, 'unknown type %r' % type(model) |
|
|
|
if __name__ == '__main__': |
|
exit_if_job_done(expdir, redo=args.redo) |
|
main() |
|
mark_job_done(expdir) |
|
|