Added save_checkpoint

Zhengyi Chen 2024-02-29 19:04:25 +00:00
parent ab633da4a5
commit 6456d3878e


@@ -18,6 +18,7 @@ from arguments import args, ret_args
 import dataset
 from dataset import *
 from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
+from checkpoint import save_checkpoint
 logger = logging.getLogger("train")
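The checkpoint module itself is not part of this diff. Judging from the call site further down, save_checkpoint(state, is_best, save_path), a minimal sketch in the usual PyTorch style would be the following; the filename and the model_best copy are assumptions, not the actual module contents:

    # checkpoint.py -- hypothetical sketch, inferred from the call site below
    import os
    import shutil
    import torch

    def save_checkpoint(state, is_best, save_path, filename="checkpoint.pth.tar"):
        # Always write the most recent state.
        path = os.path.join(save_path, filename)
        torch.save(state, path)
        if is_best:
            # Keep a separate copy of the best-scoring checkpoint.
            shutil.copyfile(path, os.path.join(save_path, "model_best.pth.tar"))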
@@ -44,14 +45,14 @@ def setup_process_group(
 def build_train_loader(data_keys, args):
     train_dataset = ListDataset(
         data_keys,
-        shuffle = True,
+        shuffle = True,
         transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
-        ]),
-        train = True,
-        batch_size = args.batch_size,
-        nr_workers = args.workers,
+        ]),
+        train = True,
+        batch_size = args.batch_size,
+        nr_workers = args.workers,
         args = args
     )
     if args.use_ddp:
@@ -61,9 +62,9 @@ def build_train_loader(data_keys, args):
     else:
         train_dist_sampler = None
     train_loader = DataLoader(
-        dataset=train_dataset,
-        sampler=train_dist_sampler,
-        batch_size=args.batch_size,
+        dataset=train_dataset,
+        sampler=train_dist_sampler,
+        batch_size=args.batch_size,
         drop_last=False
     )
     return train_loader
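The branch that builds train_dist_sampler under DDP is cut off by the hunk above; presumably it follows the standard pattern, sketched here purely as an assumption:

    # Assumed shape of the elided if-branch (not shown in this diff):
    from torch.utils.data.distributed import DistributedSampler
    train_dist_sampler = DistributedSampler(train_dataset, shuffle=True)

Note that DataLoader must not be given shuffle=True together with an explicit sampler; with a DistributedSampler, per-epoch shuffling is driven by set_epoch instead (see the note at the epoch loop below).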
@@ -71,13 +72,13 @@ def build_train_loader(data_keys, args):
 def build_test_loader(data_keys, args):
     test_dataset = ListDataset(
-        data_keys,
-        shuffle=False,
+        data_keys,
+        shuffle=False,
         transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
-        ]),
-        args=args,
+        ]),
+        args=args,
         train=False
     )
     if args.use_ddp:
@@ -87,8 +88,8 @@ def build_test_loader(data_keys, args):
     else:
         test_dist_sampler = None
     test_loader = DataLoader(
-        dataset=test_dataset,
-        sampler=test_dist_sampler,
+        dataset=test_dataset,
+        sampler=test_dist_sampler,
         batch_size=1
     )
     return test_loader
@@ -114,27 +115,27 @@ def worker(rank: int, args: Namespace):
     test_data = convert_data(test_list, args, train=False)
     train_loader = build_train_loader(train_data, args)
     test_loader = build_test_loader(test_data, args)
     # Instantiate model
     if args.model == "stn":
         model = stn_patch16_384_gap(args.pth_tar).to(device)
     else:
         model = base_patch16_384_gap(args.pth_tar).to(device)
     if args.use_ddp:
         model = nn.parallel.DistributedDataParallel(
-            model,
-            device_ids=[rank], output_device=rank,
-            find_unused_parameters=True,
+            model,
+            device_ids=[rank], output_device=rank,
+            find_unused_parameters=True,
             gradient_as_bucket_view=True  # XXX: vital, otherwise OOM
         )
     # criterion, optimizer, scheduler
     criterion = nn.L1Loss(size_average=False).to(device)
     optimizer = torch.optim.Adam(
-        [{"params": model.parameters(), "lr": args.lr}],
-        lr=args.lr,
+        [{"params": model.parameters(), "lr": args.lr}],
+        lr=args.lr,
         weight_decay=args.weight_decay
     )
     scheduler = torch.optim.lr_scheduler.MultiStepLR(
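An aside on the loss line above: the size_average argument to nn.L1Loss has long been deprecated in PyTorch. On current versions the equivalent spelling is:

    # Equivalent, non-deprecated form of the summed L1 loss:
    criterion = nn.L1Loss(reduction="sum").to(device)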
@@ -147,7 +148,7 @@ def worker(rank: int, args: Namespace):
     if not os.path.exists(args.save_path):
         os.makedirs(args.save_path)
-    if args.progress:
+    if args.progress:
         if os.path.isfile(args.progress):
             print("=> Loading checkpoint '{}'".format(args.progress))
             checkpoint = torch.load(args.progress)
@@ -161,7 +162,7 @@ def worker(rank: int, args: Namespace):
                 rank, args.start_epoch, args.best_pred
             ))
-    # For each epoch:
+    # For each epoch:
     for epoch in range(args.start_epoch, args.epochs):
         # Tell sampler which epoch it is
         if args.use_ddp:
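The "Tell sampler which epoch it is" comment points at the call the hunk cuts off. The standard pattern, assuming the loader's sampler is a DistributedSampler, is:

    # DistributedSampler shuffles deterministically from (seed, epoch);
    # without set_epoch every epoch would reuse the same ordering.
    train_loader.sampler.set_epoch(epoch)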
@@ -183,11 +184,19 @@ def worker(rank: int, args: Namespace):
         print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
         # Save checkpoint
-        # if not args.use_ddp or torch.distributed.get_rank() == 0:
+        if not args.use_ddp or torch.distributed.get_rank() == 0:
+            save_checkpoint({
+                "epoch": epoch + 1,
+                "arch": args.progress,
+                "state_dict": model.state_dict(),
+                "best_prec1": args.best_pred,
+                "optimizer": optimizer.state_dict(),
+            }, is_best, args.save_path)
     # cleanup
-    torch.distributed.destroy_process_group()
+    if args.use_ddp:
+        torch.distributed.destroy_process_group()
 def train_one_epoch(
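Restricting the save to rank 0 is the right move: under DDP every rank holds an identical replica, so letting all ranks write would only race on the same file. (Incidentally, "arch" being set to args.progress, the resume path, rather than a model name looks like a copy-over; args.model may have been intended.) One caveat, as an aside: state_dict() taken from a DDP-wrapped model prefixes every key with "module.", so loading this checkpoint into an unwrapped model needs something like:

    # Hypothetical loading snippet for a checkpoint saved from a DDP model:
    state = torch.load("checkpoint.pth.tar", map_location="cpu")["state_dict"]
    state = {k.removeprefix("module."): v for k, v in state.items()}  # Python 3.9+
    model.load_state_dict(state)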
@@ -197,7 +206,7 @@ def train_one_epoch(
     optimizer,
     scheduler,
     epoch: int,
-    device,
+    device,
     args: Namespace
 ):
     # Get learning rate
@@ -256,7 +265,7 @@ def valid_one_epoch(test_loader, model, device, args):
         with torch.no_grad():
             out = model(img)
         count = torch.sum(out).item()
         gt_count = torch.sum(gt_count).item()
         mae += abs(gt_count - count)
         mse += abs(gt_count - count) ** 2
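The hunk ends before the accumulated sums become the reported metrics; presumably valid_one_epoch closes with the usual normalization, sketched here as an assumption (with batch_size=1 in the test loader, len(test_loader) equals the image count):

    import math
    # Assumed closing lines of valid_one_epoch (outside this hunk):
    mae = mae / len(test_loader)              # mean absolute error
    rmse = math.sqrt(mse / len(test_loader))  # reported as MSE in crowd-counting papers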