Call it a day

parent 04f78fbcbc
commit dcc3f57596

1 changed file with 212 additions and 27 deletions

train.py (239 changed lines)
@@ -1,5 +1,6 @@
import os
import random
import time
from typing import Optional
from argparse import Namespace
@@ -14,6 +15,9 @@ import numpy as np

from model.transcrowd_gap import VisionTransformerGAP
from arguments import args, ret_args
import dataset
from dataset import *
from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap

logger = logging.getLogger("train")
@@ -33,26 +37,157 @@ def setup_process_group(

    # join point!
    torch.distributed.init_process_group(
-        backend="nccl", rank=rank, world_size=world_size
+        rank=rank, world_size=world_size
    )
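
# --- Illustrative sketch (not part of this commit) ---------------------------------
# The "join point" above only works if every spawned process agrees on a rendezvous.
# setup_process_group is shown only partially in this hunk; a minimal env:// style
# version could look like the helper below (the helper name and the address/port
# defaults are assumptions):
def _sketch_setup_process_group(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # single-node default
    os.environ.setdefault("MASTER_PORT", "29500")      # any free TCP port
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
# ------------------------------------------------------------------------------------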

# TODO:
# The shape of each batch in TransCrowd is [3, 384, 384],
# because images are cropped before training.
# To preserve image semantics w.r.t. the entire layout, we would rather apply the
# cropping inside the inference/training pipeline itself, i.e. right at the encoder input.
# This should be okay since our transformations are all deterministic?
# not sure...
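
# --- Illustrative sketch (not part of this commit) ---------------------------------
# One deterministic way to do what the TODO above describes: fold the 384x384 crop
# into the transform so the encoder always receives a fixed-size view, instead of
# cropping images offline before training. Resize + CenterCrop are deterministic, so
# training and inference see the same mapping. The helper name and the use of
# torchvision's Resize/CenterCrop here are assumptions, not code from this commit.
def _sketch_encoder_transform(size: int = 384):
    return transforms.Compose([
        transforms.Resize(size),      # shorter side -> size, aspect ratio kept
        transforms.CenterCrop(size),  # deterministic [3, size, size] view
        transforms.ToTensor(),
        transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
    ])
# ------------------------------------------------------------------------------------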


def build_train_loader(data_keys, args):
    train_dataset = ListDataset(
        data_keys,
        shuffle=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
        ]),
        train=True,
        batch_size=args.batch_size,
        nr_workers=args.workers,
        args=args
    )
    if args.use_ddp:
        train_dist_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset=train_dataset
        )
    else:
        train_dist_sampler = None
    train_loader = DataLoader(
        dataset=train_dataset,
        sampler=train_dist_sampler,
        batch_size=args.batch_size,
        drop_last=False
    )
    return train_loader


def build_test_loader(data_keys, args):
    test_dataset = ListDataset(
        data_keys,
        shuffle=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
        ]),
        args=args,
        train=False
    )
    if args.use_ddp:
        test_dist_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset=test_dataset
        )
    else:
        test_dist_sampler = None
    test_loader = DataLoader(
        dataset=test_dataset,
        sampler=test_dist_sampler,
        batch_size=1
    )
    return test_loader


-def build_train_loader():
-    pass
def worker(rank: int, args: Namespace):
    world_size = args.ddp_world_size

    # (If DDP) join the process group shared with the other worker processes
    if args.use_ddp:
        setup_process_group(rank, world_size)

-def build_valid_loader():
-    pass

    # Set up the device for this process
    if args.use_ddp:
        device = torch.device(rank if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.gpus if torch.cuda.is_available() else "cpu")
    torch.set_default_device(device)

    # Prepare training data
    train_list, test_list = unpack_npy_data(args)
    train_data = convert_data(train_list, args, train=True)
    test_data = convert_data(test_list, args, train=False)
    train_loader = build_train_loader(train_data, args)
    test_loader = build_test_loader(test_data, args)

    # Instantiate model
    if args.model == "stn":
        model = stn_patch16_384_gap(args.pth_tar).to(device)
    else:
        model = base_patch16_384_gap(args.pth_tar).to(device)

    if args.use_ddp:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[rank], output_device=rank,
            find_unused_parameters=True,
            gradient_as_bucket_view=True  # XXX: vital, otherwise OOM
        )

    # criterion, optimizer, scheduler
    criterion = nn.L1Loss(reduction="sum").to(device)  # sum over the batch (was the deprecated size_average=False)
    optimizer = torch.optim.Adam(
        [{"params": model.parameters(), "lr": args.lr}],
        lr=args.lr,
        weight_decay=args.weight_decay
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[300], gamma=.1, last_epoch=-1
    )

    # Checkpointing
    if (not args.use_ddp) or torch.distributed.get_rank() == 0:
        print("[worker-0] Saving to '{}'".format(args.save_path))
        if not os.path.exists(args.save_path):
            os.makedirs(args.save_path)

    if args.progress:
        if os.path.isfile(args.progress):
            print("=> Loading checkpoint '{}'".format(args.progress))
            checkpoint = torch.load(args.progress)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            args.start_epoch = checkpoint['epoch']
            args.best_pred = checkpoint['best_prec1']
        else:
            print("=> Checkpoint not found!")

    print("[worker-{}] Starting @ epoch {}: pred: {}".format(
        rank, args.start_epoch, args.best_pred
    ))

    # For each epoch:
    for epoch in range(args.start_epoch, args.epochs):
        # Tell the sampler which epoch it is
        if args.use_ddp:
            train_loader.sampler.set_epoch(epoch)
            test_loader.sampler.set_epoch(epoch)

        # Train
        start = time.time()
        train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, device, args)
        end_train = time.time()

        # Validate
        if epoch % 5 == 0:
            prec1 = valid_one_epoch(test_loader, model, device, args)
            end_valid = time.time()
            is_best = prec1 < args.best_pred
            args.best_pred = min(prec1, args.best_pred)

            print("* best MAE {mae:.3f} *".format(mae=args.best_pred))

        # Save checkpoint
        # if not args.use_ddp or torch.distributed.get_rank() == 0:
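        # --- Illustrative sketch (not part of this commit) -----------------------------
        # A save step matching the keys loaded above ('epoch', 'best_prec1', 'state_dict'),
        # written by rank 0 only when DDP is on. The file name and the model.module unwrap
        # are assumptions.
        if (not args.use_ddp) or torch.distributed.get_rank() == 0:
            to_save = model.module.state_dict() if args.use_ddp else model.state_dict()
            torch.save(
                {"epoch": epoch + 1, "best_prec1": args.best_pred, "state_dict": to_save},
                os.path.join(args.save_path, "checkpoint.pth.tar"),
            )
        # --------------------------------------------------------------------------------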

    # cleanup
    if args.use_ddp:  # the process group only exists when DDP is on
        torch.distributed.destroy_process_group()


def train_one_epoch(
@@ -62,6 +197,7 @@ def train_one_epoch(
    optimizer,
    scheduler,
    epoch: int,
    device,
    args: Namespace
):
    # Get learning rate
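    # --- Illustrative sketch (not part of this commit) ---------------------------------
    # The elided context below logs curr_lr; with MultiStepLR it is usually read straight
    # off the optimizer (this exact line is an assumption):
    curr_lr = optimizer.param_groups[0]["lr"]
    # ------------------------------------------------------------------------------------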
@@ -70,27 +206,76 @@ def train_one_epoch(
        (epoch, epoch * len(train_loader.dataset), curr_lr)
    )

-    # Set to train mode (perspective estimator only)
-    revpers_net.train()
-    end = time.time()
+    # Set to train mode
+    model.train()

    # In one epoch, for each training sample
    for i, (fname, img, gt_count) in enumerate(train_loader):
        # move stuff to device
-        # fpass (revpers)
-        img = img.cuda()
+        # fpass
+        img = img.to(device)
        out = model(img)
        gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)

-        # loss wrt revpers
-        loss = criterion()
+        # loss
+        loss = criterion(out, gt_count)

        # free grad from mem
        optimizer.zero_grad()

        # bpass
        loss.backward()

        # optimizer
        optimizer.step()

        # periodic message
        if i % args.print_freq == 0:
            print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))

    scheduler.step()

-    pass


def valid_one_epoch(test_loader, model, device, args):
    print("[valid_one_epoch] Validating...")
    batch_size = 1
    model.eval()

-def valid_one_epoch():
-    pass
    mae = .0
    mse = .0
    visi = []
    index = 0

    for i, (fname, img, gt_count) in enumerate(test_loader):
        img = img.to(device)
        # XXX: what does this do?
        # (squeeze a singleton leading dim on 5-D input; add a batch dim to 3-D input)
        if len(img.shape) == 5:
            img = img.squeeze(0)
        if len(img.shape) == 3:
            img = img.unsqueeze(0)

        with torch.no_grad():
            out = model(img)
            count = torch.sum(out).item()

        gt_count = torch.sum(gt_count).item()
        mae += abs(gt_count - count)
        mse += abs(gt_count - count) ** 2

        if i % 15 == 0:
            print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
                fname[0], gt_count, count
            ))

    mae = mae * 1.0 / (len(test_loader) * batch_size)
    mse = np.sqrt(mse / (len(test_loader) * batch_size))

    nni.report_intermediate_result(mae)
    print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
        mae=mae, mse=mse
    ))

    return mae


def main(rank: int, args: Namespace):
    pass


if __name__ == "__main__":
    tuner_params = nni.get_next_parameter()
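    # --- Illustrative sketch (not part of this commit) ---------------------------------
    # The elided lines presumably merge the NNI suggestions into the CLI arguments to
    # form combined_params; a typical pattern (the ret_args() call and the variable name
    # are assumptions):
    combined_params = Namespace(**{**vars(ret_args()), **tuner_params})
    # ------------------------------------------------------------------------------------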
@@ -103,10 +288,10 @@ if __name__ == "__main__":
    if combined_params.use_ddp:
        # Use DDP, spawn worker processes
        torch_mp.spawn(
-            main,
+            worker,
            args=(combined_params, ),  # rank supplied automatically as 1st param
            nprocs=combined_params.world_size,
        )
    else:
        # No DDP, run in the current process
-        main(None, combined_params)
+        worker(0, combined_params)