Added save_checkpoint

This commit is contained in:
Zhengyi Chen 2024-02-29 19:04:25 +00:00
parent ab633da4a5
commit 6456d3878e

View file

@@ -18,6 +18,7 @@ from arguments import args, ret_args
 import dataset
 from dataset import *
 from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
+from checkpoint import save_checkpoint
 logger = logging.getLogger("train")
@@ -44,14 +45,14 @@ def setup_process_group(
 def build_train_loader(data_keys, args):
     train_dataset = ListDataset(
         data_keys,
         shuffle = True,
         transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
         ]),
         train = True,
         batch_size = args.batch_size,
         nr_workers = args.workers,
         args = args
     )
     if args.use_ddp:
@@ -61,9 +62,9 @@ def build_train_loader(data_keys, args):
     else:
         train_dist_sampler = None
     train_loader = DataLoader(
         dataset=train_dataset,
         sampler=train_dist_sampler,
         batch_size=args.batch_size,
         drop_last=False
     )
     return train_loader
@@ -71,13 +72,13 @@ def build_train_loader(data_keys, args):
 def build_test_loader(data_keys, args):
     test_dataset = ListDataset(
         data_keys,
         shuffle=False,
         transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
         ]),
         args=args,
         train=False
     )
     if args.use_ddp:
@@ -87,8 +88,8 @@ def build_test_loader(data_keys, args):
     else:
         test_dist_sampler = None
     test_loader = DataLoader(
         dataset=test_dataset,
         sampler=test_dist_sampler,
         batch_size=1
     )
     return test_loader
@@ -114,27 +115,27 @@ def worker(rank: int, args: Namespace):
     test_data = convert_data(test_list, args, train=False)
     train_loader = build_train_loader(train_data, args)
     test_loader = build_test_loader(test_data, args)
     # Instantiate model
     if args.model == "stn":
         model = stn_patch16_384_gap(args.pth_tar).to(device)
     else:
         model = base_patch16_384_gap(args.pth_tar).to(device)
     if args.use_ddp:
         model = nn.parallel.DistributedDataParallel(
             model,
             device_ids=[rank], output_device=rank,
             find_unused_parameters=True,
             gradient_as_bucket_view=True # XXX: vital, otherwise OOM
         )
     # criterion, optimizer, scheduler
     criterion = nn.L1Loss(size_average=False).to(device)
     optimizer = torch.optim.Adam(
         [{"params": model.parameters(), "lr": args.lr}],
         lr=args.lr,
         weight_decay=args.weight_decay
     )
     scheduler = torch.optim.lr_scheduler.MultiStepLR(
@@ -147,7 +148,7 @@ def worker(rank: int, args: Namespace):
     if not os.path.exists(args.save_path):
         os.makedirs(args.save_path)
     if args.progress:
         if os.path.isfile(args.progress):
             print("=> Loading checkpoint \'{}\'".format(args.progress))
             checkpoint = torch.load(args.progress)
@@ -161,7 +162,7 @@ def worker(rank: int, args: Namespace):
                 rank, args.start_epoch, args.best_pred
             ))
     # For each epoch:
     for epoch in range(args.start_epoch, args.epochs):
         # Tell sampler which epoch it is
         if args.use_ddp:
@@ -183,11 +184,19 @@ def worker(rank: int, args: Namespace):
             print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
         # Save checkpoint
-        # if not args.use_ddp or torch.distributed.get_rank() == 0:
+        if not args.use_ddp or torch.distributed.get_rank() == 0:
+            save_checkpoint({
+                "epoch": epoch + 1,
+                "arch": args.progress,
+                "state_dict": model.state_dict(),
+                "best_prec1": args.best_pred,
+                "optimizer": optimizer.state_dict(),
+            }, is_best, args.save_path)
     # cleanup
-    torch.distributed.destroy_process_group()
+    if args.use_ddp:
+        torch.distributed.destroy_process_group()
 def train_one_epoch(
@@ -197,7 +206,7 @@ def train_one_epoch(
     optimizer,
     scheduler,
     epoch: int,
     device,
     args: Namespace
 ):
     # Get learning rate
@@ -256,7 +265,7 @@ def valid_one_epoch(test_loader, model, device, args):
         with torch.no_grad():
             out = model(img)
         count = torch.sum(out).item()
         gt_count = torch.sum(gt_count).item()
         mae += abs(gt_count - count)
         mse += abs(gt_count - count) ** 2