Call it a day
This commit is contained in:
parent 04f78fbcbc
commit dcc3f57596
1 changed file with 212 additions and 27 deletions
train.py 239
@@ -1,5 +1,6 @@
 import os
 import random
+import time
 from typing import Optional
 from argparse import Namespace
 
@@ -14,6 +15,9 @@ import numpy as np
 
 from model.transcrowd_gap import VisionTransformerGAP
 from arguments import args, ret_args
+import dataset
+from dataset import *
+from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
 
 logger = logging.getLogger("train")
 
@@ -33,26 +37,157 @@ def setup_process_group(
 
     # join point!
     torch.distributed.init_process_group(
-        backend="nccl", rank=rank, world_size=world_size
+        rank=rank, world_size=world_size
     )
 
-# TODO:
-# The shape for each batch in transcrowd is [3, 384, 384],
-# this is due to images being cropped before training.
-# To preserve image semantics wrt the entire layout, we want to apply cropping
-# i.e., as encoder input during the inference/training pipeline.
-# This should be okay since our transformations are all deterministic?
-# not sure...
-
-def build_train_loader():
-    pass
-
-
-def build_valid_loader():
-    pass
+def build_train_loader(data_keys, args):
+    train_dataset = ListDataset(
+        data_keys,
+        shuffle = True,
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
+        ]),
+        train = True,
+        batch_size = args.batch_size,
+        nr_workers = args.workers,
+        args = args
+    )
+    if args.use_ddp:
+        train_dist_sampler = torch.utils.data.distributed.DistributedSampler(
+            dataset=train_dataset
+        )
+    else:
+        train_dist_sampler = None
+    train_loader = DataLoader(
+        dataset=train_dataset,
+        sampler=train_dist_sampler,
+        batch_size=args.batch_size,
+        drop_last=False
+    )
+    return train_loader
+
+
+def build_test_loader(data_keys, args):
+    test_dataset = ListDataset(
+        data_keys,
+        shuffle=False,
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
+        ]),
+        args=args,
+        train=False
+    )
+    if args.use_ddp:
+        test_dist_sampler = torch.utils.data.distributed.DistributedSampler(
+            dataset=test_dataset
+        )
+    else:
+        test_dist_sampler = None
+    test_loader = DataLoader(
+        dataset=test_dataset,
+        sampler=test_dist_sampler,
+        batch_size=1
+    )
+    return test_loader
+
+
+def worker(rank: int, args: Namespace):
+    world_size = args.ddp_world_size
+
+    # (If DDP) join after process group among processes
+    if args.use_ddp:
+        setup_process_group(rank, world_size)
+
+    # Setup device for one proc
+    if args.use_ddp:
+        device = torch.device(rank if torch.cuda.is_available() else "cpu")
+    else:
+        device = torch.device(args.gpus if torch.cuda.is_available() else "cpu")
+    torch.set_default_device(device)
+
+    # Prepare training data
+    train_list, test_list = unpack_npy_data(args)
+    train_data = convert_data(train_list, args, train=True)
+    test_data = convert_data(test_list, args, train=False)
+    train_loader = build_train_loader(train_data, args)
+    test_loader = build_test_loader(test_data, args)
+
+    # Instantiate model
+    if args.model == "stn":
+        model = stn_patch16_384_gap(args.pth_tar).to(device)
+    else:
+        model = base_patch16_384_gap(args.pth_tar).to(device)
+
+    if args.use_ddp:
+        model = nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[rank], output_device=rank,
+            find_unused_parameters=True,
+            gradient_as_bucket_view=True  # XXX: vital, otherwise OOM
+        )
+
+    # criterion, optimizer, scheduler
+    criterion = nn.L1Loss(size_average=False).to(device)
+    optimizer = torch.optim.Adam(
+        [{"params": model.parameters(), "lr": args.lr}],
+        lr=args.lr,
+        weight_decay=args.weight_decay
+    )
+    scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        optimizer, milestones=[300], gamma=.1, last_epoch=-1
+    )
+
+    # Checkpointing
+    if (not args.use_ddp) or torch.distributed.get_rank() == 0:
+        print("[worker-0] Saving to \'{}\'".format(args.save_path))
+        if not os.path.exists(args.save_path):
+            os.makedirs(args.save_path)
+
+    if args.progress:
+        if os.path.isfile(args.progress):
+            print("=> Loading checkpoint \'{}\'".format(args.progress))
+            checkpoint = torch.load(args.progress)
+            model.load_state_dict(checkpoint['state_dict'], strict=False)
+            args.start_epoch = checkpoint['epoch']
+            args.best_pred = checkpoint['best_prec1']
+        else:
+            print("=> Checkpoint not found!")
+
+    print("[worker-{}] Starting @ epoch {}: pred: {}".format(
+        rank, args.start_epoch, args.best_pred
+    ))
+
+    # For each epoch:
+    for epoch in range(args.start_epoch, args.epochs):
+        # Tell sampler which epoch it is
+        if args.use_ddp:
+            train_loader.sampler.set_epoch(epoch)
+            test_loader.sampler.set_epoch(epoch)
+
+        # Train
+        start = time.time()
+        train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, device, args)
+        end_train = time.time()
+
+        # Validate
+        if epoch % 5 == 0:
+            prec1 = valid_one_epoch(test_loader, model, device, args)
+            end_valid = time.time()
+            is_best = prec1 < args.best_pred
+            args.best_pred = min(prec1, args.best_pred)
+
+            print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
+
+        # Save checkpoint
+        # if not args.use_ddp or torch.distributed.get_rank() == 0:
+
+    # cleanup
+    torch.distributed.destroy_process_group()
 
 
 def train_one_epoch(
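For reference, the DDP wiring that `worker()` introduces in this hunk follows the usual spawn, `init_process_group`, `DistributedSampler`, DDP-wrap, `destroy_process_group` sequence. Below is a minimal, self-contained sketch of that sequence under assumed placeholders: a toy linear model and random tensors stand in for the repo's `ListDataset` and `base_patch16_384_gap`, and the `gloo` backend is used so it runs on CPU. It also uses `nn.L1Loss(reduction="sum")`, the current PyTorch spelling of the deprecated `size_average=False` argument that the hunk itself passes. This is an illustration of the pattern, not the repo's code.

```python
# Minimal sketch of the spawn -> init_process_group -> DistributedSampler -> DDP
# pattern used by worker() above. Toy model and random data are placeholders.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def demo_worker(rank: int, world_size: int):
    # Join point: every process blocks here until all ranks have arrived.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
    sampler = DistributedSampler(dataset)          # shards the data across ranks
    loader = DataLoader(dataset, sampler=sampler, batch_size=8)

    model = DDP(torch.nn.Linear(4, 1))             # gradients all-reduced on backward
    criterion = torch.nn.L1Loss(reduction="sum")   # modern form of size_average=False
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(2):
        sampler.set_epoch(epoch)                   # reshuffle the shards each epoch
        for x, y in loader:
            optimizer.zero_grad()
            criterion(model(x), y).backward()
            optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    # Two CPU processes; spawn() prepends the rank as the first argument.
    mp.spawn(demo_worker, args=(2,), nprocs=2)
```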
@@ -62,6 +197,7 @@ def train_one_epoch(
     optimizer,
     scheduler,
     epoch: int,
+    device,
     args: Namespace
 ):
     # Get learning rate
@@ -70,27 +206,76 @@
         (epoch, epoch * len(train_loader.dataset), curr_lr)
     )
 
-    # Set to train mode (perspective estimator only)
-    revpers_net.train()
-    end = time.time()
+    # Set to train mode
+    model.train()
 
     # In one epoch, for each training sample
     for i, (fname, img, gt_count) in enumerate(train_loader):
-        # move stuff to device
-        # fpass (revpers)
-        img = img.cuda()
+        # fpass
+        img = img.to(device)
+        out = model(img)
+        gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
 
-        # loss wrt revpers
-        loss = criterion()
+        # loss
+        loss = criterion(out, gt_count)
 
+        # free grad from mem
+        optimizer.zero_grad()
+
+        # bpass
+        loss.backward()
+
+        # optimizer
+        optimizer.step()
+
+        # periodic message
+        if i % args.print_freq == 0:
+            print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
+
+    scheduler.step()
 
-    pass
-
-
-def valid_one_epoch():
-    pass
+
+def valid_one_epoch(test_loader, model, device, args):
+    print("[valid_one_epoch] Validating...")
+    batch_size = 1
+    model.eval()
+
+    mae = .0
+    mse = .0
+    visi = []
+    index = 0
+
+    for i, (fname, img, gt_count) in enumerate(test_loader):
+        img = img.to(device)
+        # XXX: what do this do
+        if len(img.shape) == 5:
+            img = img.squeeze(0)
+        if len(img.shape) == 3:
+            img = img.unsqueeze(0)
+
+        with torch.no_grad():
+            out = model(img)
+            count = torch.sum(out).item()
+
+        gt_count = torch.sum(gt_count).item()
+        mae += abs(gt_count - count)
+        mse += abs(gt_count - count) ** 2
+
+        if i % 15 == 0:
+            print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
+                fname[0], gt_count, count
+            ))
+
+    mae = mae * 1.0 / (len(test_loader) * batch_size)
+    mse = np.sqrt(mse / (len(test_loader)) * batch_size)
+
+    nni.report_intermediate_result(mae)
+    print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
+        mae=mae, mse=mse
+    ))
+
+    return mae
 
-
-def main(rank: int, args: Namespace):
-    pass
 
 if __name__ == "__main__":
     tuner_params = nni.get_next_parameter()
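The new `valid_one_epoch` accumulates per-image count errors and then normalizes: MAE is the mean absolute count error, and the value printed as "MSE" is its root mean squared counterpart. With `batch_size = 1`, multiplying or dividing by the batch size does not change the result. A small self-contained sketch of those formulas follows; the ground-truth and predicted counts are made up for illustration and are not from the repo.

```python
# Sketch of the count-error metrics accumulated by valid_one_epoch above,
# written out directly: MAE is the mean absolute count error over images,
# and the reported "MSE" is the root mean squared error (RMSE).
import numpy as np

gt_counts   = np.array([120.0,  85.0, 230.0,  40.0])   # illustrative values
pred_counts = np.array([112.0,  90.0, 221.0,  47.0])

abs_err = np.abs(gt_counts - pred_counts)
mae  = abs_err.mean()                  # sum |gt - pred| / N
rmse = np.sqrt((abs_err ** 2).mean())  # sqrt(sum (gt - pred)^2 / N)

print("MAE  {:.3f}".format(mae))   # 7.250
print("RMSE {:.3f}".format(rmse))  # 7.399
```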
@@ -103,10 +288,10 @@ if __name__ == "__main__":
     if combined_params.use_ddp:
         # Use DDP, spawn threads
         torch_mp.spawn(
-            main,
+            worker,
             args=(combined_params, ), # rank supplied automatically as 1st param
             nprocs=combined_params.world_size,
         )
     else:
         # No DDP, run in current thread
-        main(None, combined_params)
+        worker(0, combined_params)
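The switch from `main` to `worker` in the spawn call relies on `torch.multiprocessing.spawn` passing the process index as the first positional argument, which is why only `(combined_params,)` appears in `args`, and why the non-DDP branch now calls `worker(0, ...)` explicitly. A tiny standalone illustration of that calling convention, with placeholder names:

```python
# Standalone illustration of how torch.multiprocessing.spawn invokes its target:
# the process index is prepended, so fn(rank, *args) is what actually runs.
import torch.multiprocessing as mp


def demo(rank, message):
    print("process {} received: {}".format(rank, message))


if __name__ == "__main__":
    # Launches 2 processes; each runs demo(rank, "hello") with rank 0 or 1.
    mp.spawn(demo, args=("hello",), nprocs=2)
```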