Added save_checkpoint
This commit is contained in:
parent
ab633da4a5
commit
6456d3878e
1 changed files with 37 additions and 28 deletions
13
train.py
13
train.py
|
|
@ -18,6 +18,7 @@ from arguments import args, ret_args
|
||||||
import dataset
|
import dataset
|
||||||
from dataset import *
|
from dataset import *
|
||||||
from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
|
from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
|
||||||
|
from checkpoint import save_checkpoint
|
||||||
|
|
||||||
logger = logging.getLogger("train")
|
logger = logging.getLogger("train")
|
||||||
|
|
||||||
|
|
@ -183,11 +184,19 @@ def worker(rank: int, args: Namespace):
|
||||||
print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
|
print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
|
||||||
|
|
||||||
# Save checkpoint
|
# Save checkpoint
|
||||||
# if not args.use_ddp or torch.distributed.get_rank() == 0:
|
if not args.use_ddp or torch.distributed.get_rank() == 0:
|
||||||
|
save_checkpoint({
|
||||||
|
"epoch": epoch + 1,
|
||||||
|
"arch": args.progress,
|
||||||
|
"state_dict": model.state_dict(),
|
||||||
|
"best_prec1": args.best_pred,
|
||||||
|
"optimizer": optimizer.state_dict(),
|
||||||
|
}, is_best, args.save_path)
|
||||||
|
|
||||||
|
|
||||||
# cleanup
|
# cleanup
|
||||||
torch.distributed.destroy_process_group()
|
if args.use_ddp:
|
||||||
|
torch.distributed.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
def train_one_epoch(
|
def train_one_epoch(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue