Added save_checkpoint

parent ab633da4a5
commit 6456d3878e

1 changed file with 37 additions and 28 deletions
train.py | 65 changed lines

@@ -18,6 +18,7 @@ from arguments import args, ret_args
 import dataset
 from dataset import *
 from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
+from checkpoint import save_checkpoint
 
 logger = logging.getLogger("train")
 
@@ -44,14 +45,14 @@ def setup_process_group(
 def build_train_loader(data_keys, args):
     train_dataset = ListDataset(
         data_keys,
         shuffle = True,
         transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
         ]),
         train = True,
         batch_size = args.batch_size,
         nr_workers = args.workers,
         args = args
     )
     if args.use_ddp:
@@ -61,9 +62,9 @@ def build_train_loader(data_keys, args):
     else:
         train_dist_sampler = None
     train_loader = DataLoader(
         dataset=train_dataset,
         sampler=train_dist_sampler,
         batch_size=args.batch_size,
         drop_last=False
     )
     return train_loader
@@ -71,13 +72,13 @@ def build_train_loader(data_keys, args):
 
 def build_test_loader(data_keys, args):
     test_dataset = ListDataset(
         data_keys,
         shuffle=False,
         transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
         ]),
         args=args,
         train=False
     )
     if args.use_ddp:
@@ -87,8 +88,8 @@ def build_test_loader(data_keys, args):
     else:
         test_dist_sampler = None
     test_loader = DataLoader(
         dataset=test_dataset,
         sampler=test_dist_sampler,
         batch_size=1
     )
     return test_loader
@@ -114,27 +115,27 @@ def worker(rank: int, args: Namespace):
     test_data = convert_data(test_list, args, train=False)
     train_loader = build_train_loader(train_data, args)
     test_loader = build_test_loader(test_data, args)
 
 
     # Instantiate model
     if args.model == "stn":
         model = stn_patch16_384_gap(args.pth_tar).to(device)
     else:
         model = base_patch16_384_gap(args.pth_tar).to(device)
 
     if args.use_ddp:
         model = nn.parallel.DistributedDataParallel(
             model,
             device_ids=[rank], output_device=rank,
             find_unused_parameters=True,
             gradient_as_bucket_view=True # XXX: vital, otherwise OOM
         )
 
     # criterion, optimizer, scheduler
     criterion = nn.L1Loss(size_average=False).to(device)
     optimizer = torch.optim.Adam(
         [{"params": model.parameters(), "lr": args.lr}],
         lr=args.lr,
         weight_decay=args.weight_decay
     )
     scheduler = torch.optim.lr_scheduler.MultiStepLR(
@@ -147,7 +148,7 @@ def worker(rank: int, args: Namespace):
     if not os.path.exists(args.save_path):
         os.makedirs(args.save_path)
 
     if args.progress:
         if os.path.isfile(args.progress):
             print("=> Loading checkpoint '{}'".format(args.progress))
             checkpoint = torch.load(args.progress)
@@ -161,7 +162,7 @@ def worker(rank: int, args: Namespace):
                 rank, args.start_epoch, args.best_pred
             ))
 
     # For each epoch:
     for epoch in range(args.start_epoch, args.epochs):
         # Tell sampler which epoch it is
         if args.use_ddp:
@@ -183,11 +184,19 @@ def worker(rank: int, args: Namespace):
             print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
 
         # Save checkpoint
-        # if not args.use_ddp or torch.distributed.get_rank() == 0:
+        if not args.use_ddp or torch.distributed.get_rank() == 0:
+            save_checkpoint({
+                "epoch": epoch + 1,
+                "arch": args.progress,
+                "state_dict": model.state_dict(),
+                "best_prec1": args.best_pred,
+                "optimizer": optimizer.state_dict(),
+            }, is_best, args.save_path)
 
 
     # cleanup
-    torch.distributed.destroy_process_group()
+    if args.use_ddp:
+        torch.distributed.destroy_process_group()
 
 
 def train_one_epoch(
@@ -197,7 +206,7 @@ def train_one_epoch(
     optimizer,
     scheduler,
     epoch: int,
     device,
     args: Namespace
 ):
     # Get learning rate
@@ -256,7 +265,7 @@ def valid_one_epoch(test_loader, model, device, args):
         with torch.no_grad():
             out = model(img)
         count = torch.sum(out).item()
 
         gt_count = torch.sum(gt_count).item()
         mae += abs(gt_count - count)
         mse += abs(gt_count - count) ** 2
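The commit imports save_checkpoint from a new checkpoint module, but that module's body is not part of this diff. Below is a minimal sketch of what such a helper could look like, assuming only the (state, is_best, save_path) call signature visible above; the file names and the best-model copy are illustrative, not taken from the repository.

# checkpoint.py -- hypothetical sketch, not the actual module from this commit.
# Assumes the (state, is_best, save_path) signature used by the call in train.py.
import os
import shutil

import torch


def save_checkpoint(state, is_best, save_path, filename="checkpoint.pth.tar"):
    """Persist the latest training state; keep a separate copy of the best one."""
    ckpt_path = os.path.join(save_path, filename)
    torch.save(state, ckpt_path)  # state is the dict built in worker()
    if is_best:
        # Snapshot the best-so-far weights under a fixed name for easy reloading.
        shutil.copyfile(ckpt_path, os.path.join(save_path, "model_best.pth.tar"))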
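For reference, a checkpoint written with the keys above can be resumed roughly as follows. This is a hedged sketch, not the resume code actually in train.py (only fragments of it appear in the hunks at -147 and -161); it assumes the variables already defined inside worker().

# Inside worker(), after model/optimizer are built -- hypothetical resume path.
checkpoint = torch.load(args.progress, map_location=device)
args.start_epoch = checkpoint["epoch"]
args.best_pred = checkpoint["best_prec1"]
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])

Gating the save on torch.distributed.get_rank() == 0, as the diff does, keeps multiple DDP workers from writing the same checkpoint file concurrently.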