Added save_checkpoint
parent ab633da4a5
commit 6456d3878e
1 changed file with 37 additions and 28 deletions
train.py
@@ -18,6 +18,7 @@ from arguments import args, ret_args
 import dataset
 from dataset import *
 from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
+from checkpoint import save_checkpoint

 logger = logging.getLogger("train")

@@ -183,10 +184,18 @@ def worker(rank: int, args: Namespace):
     print("* best MAE {mae:.3f} *".format(mae=args.best_pred))

     # Save checkpoint
     # if not args.use_ddp or torch.distributed.get_rank() == 0:
+    if not args.use_ddp or torch.distributed.get_rank() == 0:
+        save_checkpoint({
+            "epoch": epoch + 1,
+            "arch": args.progress,
+            "state_dict": model.state_dict(),
+            "best_prec1": args.best_pred,
+            "optimizer": optimizer.state_dict(),
+        }, is_best, args.save_path)

     # cleanup
     if args.use_ddp:
         torch.distributed.destroy_process_group()
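
The checkpoint.py module imported here is not part of this diff, so its implementation is unknown. A minimal sketch of a save_checkpoint helper with the signature used above (state dict, is_best flag, save path); the filenames "checkpoint.pth.tar" and "model_best.pth.tar" are assumed conventions, not taken from the commit:

    # Hypothetical checkpoint.py -- assumed implementation, not shown in this diff.
    import os
    import shutil

    import torch


    def save_checkpoint(state, is_best, save_path, filename="checkpoint.pth.tar"):
        """Write the training state to disk and keep a copy of the best model."""
        os.makedirs(save_path, exist_ok=True)
        checkpoint_file = os.path.join(save_path, filename)
        torch.save(state, checkpoint_file)
        if is_best:
            # "model_best.pth.tar" is an assumed filename convention.
            shutil.copyfile(checkpoint_file,
                            os.path.join(save_path, "model_best.pth.tar"))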
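For completeness, a resume sketch showing how the fields written above ("epoch", "state_dict", "best_prec1", "optimizer") would typically be restored; this is not part of the commit, and the function name is hypothetical:

    # Hypothetical resume helper -- illustrates reading the saved fields back.
    import torch


    def resume_from_checkpoint(path, model, optimizer):
        # map_location="cpu" avoids pinning the load to the GPU that saved it.
        checkpoint = torch.load(path, map_location="cpu")
        # Note: if the model was saved while wrapped in DistributedDataParallel,
        # the state_dict keys carry a "module." prefix.
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        return checkpoint["epoch"], checkpoint["best_prec1"]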