Added save_checkpoint
This commit is contained in:
parent
ab633da4a5
commit
6456d3878e
1 changed files with 37 additions and 28 deletions
65
train.py
65
train.py
|
|
@ -18,6 +18,7 @@ from arguments import args, ret_args
|
|||
import dataset
|
||||
from dataset import *
|
||||
from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
|
||||
from checkpoint import save_checkpoint
|
||||
|
||||
logger = logging.getLogger("train")
|
||||
|
||||
|
|
@ -44,14 +45,14 @@ def setup_process_group(
|
|||
def build_train_loader(data_keys, args):
|
||||
train_dataset = ListDataset(
|
||||
data_keys,
|
||||
shuffle = True,
|
||||
shuffle = True,
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
|
||||
]),
|
||||
train = True,
|
||||
batch_size = args.batch_size,
|
||||
nr_workers = args.workers,
|
||||
]),
|
||||
train = True,
|
||||
batch_size = args.batch_size,
|
||||
nr_workers = args.workers,
|
||||
args = args
|
||||
)
|
||||
if args.use_ddp:
|
||||
|
|
@ -61,9 +62,9 @@ def build_train_loader(data_keys, args):
|
|||
else:
|
||||
train_dist_sampler = None
|
||||
train_loader = DataLoader(
|
||||
dataset=train_dataset,
|
||||
sampler=train_dist_sampler,
|
||||
batch_size=args.batch_size,
|
||||
dataset=train_dataset,
|
||||
sampler=train_dist_sampler,
|
||||
batch_size=args.batch_size,
|
||||
drop_last=False
|
||||
)
|
||||
return train_loader
|
||||
|
|
@ -71,13 +72,13 @@ def build_train_loader(data_keys, args):
|
|||
|
||||
def build_test_loader(data_keys, args):
|
||||
test_dataset = ListDataset(
|
||||
data_keys,
|
||||
shuffle=False,
|
||||
data_keys,
|
||||
shuffle=False,
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
|
||||
]),
|
||||
args=args,
|
||||
]),
|
||||
args=args,
|
||||
train=False
|
||||
)
|
||||
if args.use_ddp:
|
||||
|
|
@ -87,8 +88,8 @@ def build_test_loader(data_keys, args):
|
|||
else:
|
||||
test_dist_sampler = None
|
||||
test_loader = DataLoader(
|
||||
dataset=test_dataset,
|
||||
sampler=test_dist_sampler,
|
||||
dataset=test_dataset,
|
||||
sampler=test_dist_sampler,
|
||||
batch_size=1
|
||||
)
|
||||
return test_loader
|
||||
|
|
@ -114,27 +115,27 @@ def worker(rank: int, args: Namespace):
|
|||
test_data = convert_data(test_list, args, train=False)
|
||||
train_loader = build_train_loader(train_data, args)
|
||||
test_loader = build_test_loader(test_data, args)
|
||||
|
||||
|
||||
|
||||
# Instantiate model
|
||||
if args.model == "stn":
|
||||
model = stn_patch16_384_gap(args.pth_tar).to(device)
|
||||
else:
|
||||
model = base_patch16_384_gap(args.pth_tar).to(device)
|
||||
|
||||
|
||||
if args.use_ddp:
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[rank], output_device=rank,
|
||||
find_unused_parameters=True,
|
||||
model,
|
||||
device_ids=[rank], output_device=rank,
|
||||
find_unused_parameters=True,
|
||||
gradient_as_bucket_view=True # XXX: vital, otherwise OOM
|
||||
)
|
||||
|
||||
# criterion, optimizer, scheduler
|
||||
criterion = nn.L1Loss(size_average=False).to(device)
|
||||
optimizer = torch.optim.Adam(
|
||||
[{"params": model.parameters(), "lr": args.lr}],
|
||||
lr=args.lr,
|
||||
[{"params": model.parameters(), "lr": args.lr}],
|
||||
lr=args.lr,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
|
|
@ -147,7 +148,7 @@ def worker(rank: int, args: Namespace):
|
|||
if not os.path.exists(args.save_path):
|
||||
os.makedirs(args.save_path)
|
||||
|
||||
if args.progress:
|
||||
if args.progress:
|
||||
if os.path.isfile(args.progress):
|
||||
print("=> Loading checkpoint \'{}\'".format(args.progress))
|
||||
checkpoint = torch.load(args.progress)
|
||||
|
|
@ -161,7 +162,7 @@ def worker(rank: int, args: Namespace):
|
|||
rank, args.start_epoch, args.best_pred
|
||||
))
|
||||
|
||||
# For each epoch:
|
||||
# For each epoch:
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
# Tell sampler which epoch it is
|
||||
if args.use_ddp:
|
||||
|
|
@ -183,11 +184,19 @@ def worker(rank: int, args: Namespace):
|
|||
print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
|
||||
|
||||
# Save checkpoint
|
||||
# if not args.use_ddp or torch.distributed.get_rank() == 0:
|
||||
|
||||
if not args.use_ddp or torch.distributed.get_rank() == 0:
|
||||
save_checkpoint({
|
||||
"epoch": epoch + 1,
|
||||
"arch": args.progress,
|
||||
"state_dict": model.state_dict(),
|
||||
"best_prec1": args.best_pred,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
}, is_best, args.save_path)
|
||||
|
||||
|
||||
# cleanup
|
||||
torch.distributed.destroy_process_group()
|
||||
if args.use_ddp:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
|
|
@ -197,7 +206,7 @@ def train_one_epoch(
|
|||
optimizer,
|
||||
scheduler,
|
||||
epoch: int,
|
||||
device,
|
||||
device,
|
||||
args: Namespace
|
||||
):
|
||||
# Get learning rate
|
||||
|
|
@ -256,7 +265,7 @@ def valid_one_epoch(test_loader, model, device, args):
|
|||
with torch.no_grad():
|
||||
out = model(img)
|
||||
count = torch.sum(out).item()
|
||||
|
||||
|
||||
gt_count = torch.sum(gt_count).item()
|
||||
mae += abs(gt_count - count)
|
||||
mse += abs(gt_count - count) ** 2
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue