|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
class Chart:
    """Collect loss curves from training checkpoints and plot them on one figure."""

    def __init__(self):
        # One tuple per added checkpoint:
        # (line_name, train_steps, train_losses, val_steps, val_losses)
        self.loss_list = []

    @staticmethod
    def _subsample(values, stride):
        """Return the first element plus every ``stride``-th element (1-based).

        With ``stride=5`` this yields indices 0, 4, 9, 14, ... . An empty input
        yields an empty list instead of raising ``IndexError``.
        """
        if stride < 1:
            raise ValueError(f"val_stride must be >= 1, got {stride}")
        values = list(values)
        if stride == 1:
            # Every element is already kept; avoid duplicating index 0.
            return values
        return values[:1] + values[stride - 1 :: stride]

    def add_ckpt(self, ckpt_path, line_name, val_stride=5):
        """Load loss histories from a checkpoint and register them as one curve.

        Args:
            ckpt_path: Path passed to ``torch.load``. The loaded object must
                provide the keys ``"train_step_list"``, ``"train_loss_list"``,
                ``"val_step_list"`` and ``"val_loss_list"``.
            line_name: Legend label for this curve.
            val_stride: Keep the first validation point plus every
                ``val_stride``-th one after it (default 5, the previous
                hard-coded behaviour).
        """
        ckpt = torch.load(ckpt_path, map_location="cpu")

        train_step_list = ckpt["train_step_list"]
        train_loss_list = ckpt["train_loss_list"]
        # Thin the validation series so the overlaid curve is less noisy.
        val_step_list = self._subsample(ckpt["val_step_list"], val_stride)
        val_loss_list = self._subsample(ckpt["val_loss_list"], val_stride)

        self.loss_list.append((line_name, train_step_list, train_loss_list, val_step_list, val_loss_list))

    def draw(self, save_path, plot_val=True):
        """Plot every registered curve on a single figure and save it.

        Args:
            save_path: Output path passed to ``savefig`` (a transparent
                background is used).
            plot_val: If True, draw each training curve thin and semi-transparent
                and overlay its subsampled validation curve in the same colour.
                If False, draw only the training curves.
        """
        style = {
            "font.size": 14,
            "font.family": "serif",
            "font.sans-serif": ["Arial", "DejaVu Sans", "Lucida Grande"],
            "font.serif": ["Times New Roman", "DejaVu Serif"],
        }
        # rc_context confines these font settings to this figure instead of
        # permanently changing the process-wide rcParams.
        with plt.rc_context(style):
            # 7.766 ~= 4.8 * 1.618, i.e. a golden-ratio aspect.
            fig = plt.figure(figsize=(7.766, 4.8))
            try:
                for name, train_steps, train_losses, val_steps, val_losses in self.loss_list:
                    if plot_val:
                        (line,) = plt.plot(train_steps, train_losses, label=name, linewidth=0.5, alpha=0.5)
                        # Unlabelled, so the legend shows one entry per run.
                        plt.plot(val_steps, val_losses, linewidth=1.5, color=line.get_color())
                    else:
                        plt.plot(train_steps, train_losses, label=name, linewidth=1)

                plt.xlabel("Step")
                plt.ylabel("Loss")
                legend = plt.legend()

                # Thicken legend handles so the thin training lines stay visible there.
                for handle in legend.get_lines():
                    handle.set_linewidth(2)

                fig.savefig(save_path, transparent=True)
            finally:
                # Release the figure even if plotting or saving fails.
                plt.close(fig)
|
|
|
|
|
|
if __name__ == "__main__":
    # (checkpoint path, legend label) pairs; curves are added and plotted in this order.
    runs = [
        ("output/syncnet/train-2024_10_24-21:03:11/checkpoints/checkpoint-10000.pt", "Dim 512"),
        ("output/syncnet/train-2024_10_25-18:21:59/checkpoints/checkpoint-10000.pt", "Dim 2048"),
        ("output/syncnet/train-2024_10_24-22:37:04/checkpoints/checkpoint-10000.pt", "Dim 4096"),
        ("output/syncnet/train-2024_10_25-02:30:17/checkpoints/checkpoint-10000.pt", "Dim 6144"),
    ]

    chart = Chart()
    for ckpt_path, label in runs:
        chart.add_ckpt(ckpt_path, label)

    chart.draw("ablation.pdf", plot_val=True)
|
|
|