Added save_checkpoint

Zhengyi Chen 2024-02-29 19:04:25 +00:00
parent ab633da4a5
commit 6456d3878e

@@ -18,6 +18,7 @@ from arguments import args, ret_args
 import dataset
 from dataset import *
 from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
+from checkpoint import save_checkpoint
 logger = logging.getLogger("train")
@@ -183,10 +184,18 @@ def worker(rank: int, args: Namespace):
         print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
         # Save checkpoint
         # if not args.use_ddp or torch.distributed.get_rank() == 0:
+        if not args.use_ddp or torch.distributed.get_rank() == 0:
+            save_checkpoint({
+                "epoch": epoch + 1,
+                "arch": args.progress,
+                "state_dict": model.state_dict(),
+                "best_prec1": args.best_pred,
+                "optimizer": optimizer.state_dict(),
+            }, is_best, args.save_path)
     # cleanup
     if args.use_ddp:
         torch.distributed.destroy_process_group()
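
Note: the `checkpoint` module imported in the first hunk is not part of this diff. Below is a minimal sketch of what a `save_checkpoint(state, is_best, save_path)` helper of this kind typically looks like; the default filename, the "model_best.pth.tar" name, and the body are assumptions based on the common PyTorch checkpointing pattern, not the repository's actual implementation.

# A minimal sketch of checkpoint.save_checkpoint; names and behaviour here
# are assumptions, since the real module is not shown in this commit.
import os
import shutil

import torch


def save_checkpoint(state: dict, is_best: bool, save_path: str,
                    filename: str = "checkpoint.pth.tar") -> None:
    """Serialize training state; keep a separate copy of the best model."""
    os.makedirs(save_path, exist_ok=True)
    checkpoint_file = os.path.join(save_path, filename)
    torch.save(state, checkpoint_file)
    if is_best:
        # Copy rather than re-serialize so the best snapshot is byte-identical
        # to the checkpoint that produced the best validation MAE.
        shutil.copyfile(checkpoint_file,
                        os.path.join(save_path, "model_best.pth.tar"))

The rank-0 guard added in the second hunk ensures that under DDP only one process calls this helper, so multiple workers never write to the same files concurrently.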