mlp-project/checkpoint.py

8 lines
254 B
Python

import os
import shutil

import torch
def save_checkpoint(state, is_best: bool, task_id, fname="checkpoint.pth.tar"):
    """Write *state* to ``./<task_id>/<fname>`` and optionally mark it as best.

    Args:
        state: Object handed to ``torch.save`` unchanged.
        is_best: When true, the file just written is also copied to
            ``best.pth.tar`` in the same directory.
        task_id: Identifier whose ``str()`` form names the sub-directory of
            the current working directory that receives the checkpoint.
        fname: File name of the checkpoint inside the task directory.
    """
    fdir = "./" + str(task_id) + "/"
    # torch.save does not create missing parent directories, so make sure the
    # task directory exists before the first checkpoint is written.
    os.makedirs(fdir, exist_ok=True)

    checkpoint_path = os.path.join(fdir, fname)
    torch.save(state, checkpoint_path)
    if is_best:
        # Keep a separate copy so later (worse) checkpoints written under the
        # same fname do not overwrite the best one.
        shutil.copyfile(checkpoint_path, os.path.join(fdir, "best.pth.tar"))