|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
class Chart:
    """Collect loss curves from training checkpoints and plot them on one figure.

    Each call to ``add_ckpt`` appends one run to ``loss_list``; ``draw``
    renders every collected run and writes the figure to disk.
    """

    def __init__(self):
        # One tuple per run:
        # (line_name, train_steps, train_losses, val_steps, val_losses)
        self.loss_list = []

    @staticmethod
    def _subsample(values, stride):
        """Return the first element of *values* plus every *stride*-th element.

        Keeps indices ``0, stride-1, 2*stride-1, ...``. With ``stride <= 1``
        the whole sequence is kept. An empty input yields an empty list.
        """
        if stride <= 1:
            return list(values)
        # Slicing with [:1] (instead of indexing [0]) makes an empty history
        # produce [] rather than raising IndexError.
        return list(values[:1]) + list(values[stride - 1 :: stride])

    def add_ckpt(self, ckpt_path, line_name, val_stride=5):
        """Load the loss history stored in a checkpoint and queue it for plotting.

        Args:
            ckpt_path: Path to a file loadable with ``torch.load``. It must
                contain the keys ``train_step_list``, ``train_loss_list``,
                ``val_step_list`` and ``val_loss_list``.
            line_name: Legend label for this run.
            val_stride: The validation series is thinned to its first point
                plus every ``val_stride``-th point. The default of 5 matches
                the previous hard-coded behaviour; ``1`` keeps every point.

        Raises:
            KeyError: If one of the expected keys is missing from the checkpoint.
        """
        # NOTE: torch.load may unpickle arbitrary objects; only load
        # checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location="cpu")

        train_steps = ckpt["train_step_list"]
        train_losses = ckpt["train_loss_list"]
        val_steps = self._subsample(ckpt["val_step_list"], val_stride)
        val_losses = self._subsample(ckpt["val_loss_list"], val_stride)

        self.loss_list.append((line_name, train_steps, train_losses, val_steps, val_losses))

    def draw(self, save_path, plot_val=True):
        """Plot every collected run and save the figure to *save_path*.

        Args:
            save_path: Output file. matplotlib picks the format from the
                extension (e.g. ``.pdf``).
            plot_val: If true, draw each training curve thin and translucent,
                with its (subsampled) validation curve on top in the same
                colour. Otherwise draw only the training curves.
        """
        # Font settings are scoped to this call so that drawing does not
        # change global matplotlib styling for later plots.
        rc = {
            "font.size": 14,
            "font.family": "serif",
            "font.sans-serif": ["Arial", "DejaVu Sans", "Lucida Grande"],
            "font.serif": ["Times New Roman", "DejaVu Serif"],
        }

        with plt.rc_context(rc):
            fig, ax = plt.subplots(figsize=(7.766, 4.8))
            try:
                for name, train_steps, train_losses, val_steps, val_losses in self.loss_list:
                    if plot_val:
                        (train_line,) = ax.plot(
                            train_steps, train_losses, label=name, linewidth=0.5, alpha=0.5
                        )
                        # The validation line is unlabelled, so only the training line
                        # appears in the legend; sharing its colour pairs the two curves.
                        ax.plot(val_steps, val_losses, linewidth=1.5, color=train_line.get_color())
                    else:
                        ax.plot(train_steps, train_losses, label=name, linewidth=1)

                ax.set_xlabel("Step")
                ax.set_ylabel("Loss")

                if self.loss_list:
                    legend = ax.legend()
                    # Legend handles copy the plotted (thin) line width; thicken
                    # them so the run colours are easy to tell apart.
                    for handle in legend.get_lines():
                        handle.set_linewidth(2)

                fig.savefig(save_path, transparent=True)
            finally:
                # Release the figure even if plotting or saving failed.
                plt.close(fig)
|
|
|
|
|
if __name__ == "__main__":
    # (checkpoint path, legend label) for each ablation run. Runs are plotted
    # in list order, which also fixes their default colour assignment.
    runs = [
        ("output/syncnet/train-2024_10_24-21:03:11/checkpoints/checkpoint-10000.pt", "Dim 512"),
        ("output/syncnet/train-2024_10_25-18:21:59/checkpoints/checkpoint-10000.pt", "Dim 2048"),
        ("output/syncnet/train-2024_10_24-22:37:04/checkpoints/checkpoint-10000.pt", "Dim 4096"),
        ("output/syncnet/train-2024_10_25-02:30:17/checkpoints/checkpoint-10000.pt", "Dim 6144"),
    ]

    chart = Chart()
    for ckpt_file, label in runs:
        chart.add_ckpt(ckpt_file, label)

    chart.draw("ablation.pdf", plot_val=True)
|
|