import os
import random
import time
import logging
from typing import Optional
from argparse import Namespace

import numpy as np
import nni
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F  # used by the composite loss in train_one_epoch
import torch.multiprocessing as torch_mp
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from model.transcrowd_gap import VisionTransformerGAP
from arguments import args, ret_args
import dataset
from dataset import *
from model.transcrowd_gap import *
from checkpoint import save_checkpoint

logger = logging.getLogger("train")
writer = SummaryWriter(args.save_path + "/tensorboard-run")


def setup_process_group(
    rank: int,
    world_size: int,
    master_addr: str = "localhost",
    master_port: Optional[int] = None,
):
    os.environ["MASTER_ADDR"] = master_addr
    # NB: with master_port=None each process draws its *own* random port, so
    # ranks can only rendezvous if they agree on the port; pass an explicit
    # master_port for multi-process runs. (Valid ports end at 65535.)
    os.environ["MASTER_PORT"] = (
        str(random.randint(40000, 65535))
        if master_port is None
        else str(master_port)
    )

    # join point!
    torch.distributed.init_process_group(
        rank=rank, world_size=world_size
    )
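

# Usage sketch (hypothetical values): every rank must agree on the port, e.g.
#   setup_process_group(rank=0, world_size=2, master_port=29500)  # on rank 0
#   setup_process_group(rank=1, world_size=2, master_port=29500)  # on rank 1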


def build_train_loader(data_keys, args):
    train_dataset = ListDataset(
        data_keys,
        shuffle=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
        ]),
        train=True,
        batch_size=args.batch_size,
        nr_workers=args.workers,
        args=args,
    )
    if args.use_ddp:
        train_dist_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset=train_dataset
        )
    else:
        train_dist_sampler = None
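    # NB: `sampler` and `shuffle` are mutually exclusive on DataLoader, so
    # ordering comes from the dataset-level shuffle above (and, under DDP,
    # from DistributedSampler's own per-epoch shuffling via set_epoch()).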
    train_loader = DataLoader(
        dataset=train_dataset,
        sampler=train_dist_sampler,
        batch_size=args.batch_size,
        drop_last=False,
    )
    return train_loader


def build_test_loader(data_keys, args):
    test_dataset = ListDataset(
        data_keys,
        shuffle=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
        ]),
        args=args,
        train=False,
    )
    if args.use_ddp:
        test_dist_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset=test_dataset
        )
    else:
        test_dist_sampler = None
    test_loader = DataLoader(
        dataset=test_dataset,
        sampler=test_dist_sampler,
        batch_size=1,
    )
    return test_loader


def worker(rank: int, args: Namespace):
    world_size = args.ddp_world_size

    # (If DDP) join the process group with the other workers
    if args.use_ddp:
        setup_process_group(rank, world_size)

    # Set up the device for this process
    if args.use_ddp and torch.cuda.is_available():
        device = torch.device(rank)
    elif torch.cuda.is_available():
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        args.gpus = [int(gpu_id) for gpu_id in args.gpus.split(",")]
        device = None  # DataParallel manages placement across args.gpus
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        print("[!!!] Using CPU for training. This will be slow...")
        device = torch.device("cpu")
    if device is not None:
        torch.set_default_device(device)
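    # (torch.set_default_device, available from PyTorch 2.0, makes newly
    # constructed tensors land on `device` by default.)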

    # Prepare training data
    train_list, test_list = unpack_npy_data(args)
    train_data = convert_data(train_list, args, train=True)
    test_data = convert_data(test_list, args, train=False)
    train_loader = build_train_loader(train_data, args)
    test_loader = build_test_loader(test_data, args)

    # Instantiate model
    if args.model == "stn":
        model = stn_patch16_384_gap(args.pth_tar)
    else:
        model = base_patch16_384_gap(args.pth_tar)

    # Move the model onto its device *before* wrapping: DDP expects the
    # parameters to already live on device_ids[0] at construction time.
    if device is not None:
        model = model.to(device)
    elif torch.cuda.is_available():
        model = model.cuda()

    if args.use_ddp:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[rank], output_device=rank,
            find_unused_parameters=True,
            gradient_as_bucket_view=True  # XXX: vital, otherwise OOM
        )
    elif torch.cuda.is_available():
        # Only wrap in DataParallel on CUDA; on MPS/CPU the model runs
        # unwrapped (args.gpus is only parsed into ids on the CUDA path).
        model = nn.DataParallel(
            model,
            device_ids=args.gpus
        )
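    # (DataParallel replicates the module onto each GPU on every forward
    # pass; it is kept only as a single-node fallback, DDP is the faster path.)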

    # criterion, optimizer, scheduler
    criterion = nn.L1Loss(reduction="sum")
    if device is not None:
        criterion = criterion.to(device)
    elif torch.cuda.is_available():
        criterion = criterion.cuda()
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[300], gamma=.1, last_epoch=-1
    )
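    # scheduler.step() runs once per epoch (at the end of train_one_epoch),
    # so the milestone above is epoch 300: the learning rate is multiplied
    # by 0.1 when that epoch is reached.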

    # Checkpointing
    if (not args.use_ddp) or torch.distributed.get_rank() == 0:
        print("[worker-0] Saving to '{}'".format(args.save_path))
        if not os.path.exists(args.save_path):
            os.makedirs(args.save_path)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            args.start_epoch = checkpoint['epoch']
            args.best_pred = checkpoint['best_prec1']  # sic.
        else:
            print("=> Checkpoint not found!")

    print("[worker-{}] Starting @ epoch {}: pred: {}".format(
        rank, args.start_epoch, args.best_pred
    ))

    # For each epoch:
    for epoch in range(args.start_epoch, args.epochs):
        # Tell the samplers which epoch it is
        if args.use_ddp:
            train_loader.sampler.set_epoch(epoch)
            test_loader.sampler.set_epoch(epoch)

        # Train
        start = time.time()
        train_one_epoch(train_loader, model, criterion, optimizer,
                        scheduler, epoch, device, args)
        end_train = time.time()

        # Validate every 5 epochs
        if epoch % 5 == 0 or args.debug:
            prec1 = valid_one_epoch(test_loader, model, device, epoch, args)
            end_valid = time.time()
            is_best = prec1 < args.best_pred
            args.best_pred = min(prec1, args.best_pred)

            print("* best MAE {mae:.3f} *".format(mae=args.best_pred))

            # Save checkpoint (rank 0 only under DDP); `is_best` is only
            # defined on validation epochs, so saving stays in this branch
            if not args.use_ddp or torch.distributed.get_rank() == 0:
                save_checkpoint({
                    "epoch": epoch + 1,
                    "arch": args.resume,
                    "state_dict": model.state_dict(),
                    "best_prec1": args.best_pred,
                    "optimizer": optimizer.state_dict(),
                }, is_best, args.save_path)

    # cleanup
    if args.use_ddp:
        torch.distributed.destroy_process_group()


def train_one_epoch(
    train_loader: DataLoader,
    model: VisionTransformerGAP,
    criterion,
    optimizer,
    scheduler,
    epoch: int,
    device,
    args: Namespace,
):
    # Get learning rate
    curr_lr = optimizer.param_groups[0]["lr"]
    print("Epoch %d, processed %d samples, lr %.10f" %
          (epoch, epoch * len(train_loader.dataset), curr_lr))

    # Set to train mode
    model.train()

    # In one epoch, for each training batch
    for i, (fname, img, kpoint, gt_count_whole) in enumerate(train_loader):
        kpoint = kpoint.type(torch.FloatTensor)
        gt_count_whole = gt_count_whole.type(torch.FloatTensor).unsqueeze(1)
        batch_size = img.size(0)
        # Monotonic TensorBoard step (the original `epoch * i` resets to 0
        # every epoch and is not monotonic)
        global_step = epoch * len(train_loader) + i

        # send to device
        if device is not None:
            img = img.to(device)
            kpoint = kpoint.to(device)
            gt_count_whole = gt_count_whole.to(device)
            device_type = device.type
        elif torch.cuda.is_available():
            img = img.cuda()
            kpoint = kpoint.cuda()
            gt_count_whole = gt_count_whole.cuda()
            device_type = "cuda"

        # Desperate measure to reduce mem footprint: mixed-precision autocast
        with torch.autocast(device_type):
            # fpass
            out, gt_count = model(img, kpoint)
            # loss
            loss = criterion(out, gt_count)  # wrt. transformer
            writer.add_scalar("L1-loss wrt. xformer (train)", loss, global_step)
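            # F.threshold(x, t, v) returns x where x > t and v elsewhere, so
            # the variance term below contributes max(variance, current loss):
            # the perspective-correction penalty only bites once the patch
            # count variance exceeds the L1 loss.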
            loss += (
                F.mse_loss(  # stn: info retainment
                    gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
                    gt_count_whole)
                + F.threshold(  # stn: perspective correction
                    gt_count.view(batch_size, -1).var(dim=1).mean(),
                    threshold=loss.item(),
                    value=loss.item()
                )
            )
            writer.add_scalar("Composite loss (train)", loss, global_step)

        # free grad from mem
        optimizer.zero_grad(set_to_none=True)

        # bpass
        loss.backward()

        # optimizer
        optimizer.step()

        if args.debug:
            break

    # Flush writer
    writer.flush()

    scheduler.step()


def valid_one_epoch(test_loader, model, device, epoch, args):
    print("[valid_one_epoch] Validating...")
    batch_size = 1
    model.eval()

    mae = .0
    mse = .0
    visi = []
    index = 0
    xformed = []

    for i, (fname, img, kpoint, gt_count_whole) in enumerate(test_loader):
        kpoint = kpoint.type(torch.FloatTensor)
        gt_count_whole = gt_count_whole.type(torch.FloatTensor)
        if device is not None:
            img = img.to(device)
            kpoint = kpoint.to(device)
            gt_count_whole = gt_count_whole.to(device)
        elif torch.cuda.is_available():
            img = img.cuda()
            kpoint = kpoint.cuda()
            gt_count_whole = gt_count_whole.cuda()

        # XXX: does this ever happen?
        if len(img.shape) == 5:
            img = img.squeeze(0)
        if len(img.shape) == 3:
            img = img.unsqueeze(0)

        with torch.no_grad():
            out, gt_count = model(img, kpoint)
            pred_count = torch.squeeze(out, 1)
            gt_count = torch.squeeze(gt_count, 1)

        diff = torch.abs(gt_count_whole - torch.sum(pred_count)).item()
        mae += diff
        mse += diff ** 2

        if i % 5 == 0:
            # with torch.no_grad():
            #     img_xformed = model.stnet(img).to("cpu")
            #     xformed.append(img_xformed)
            print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
                fname[0],
                torch.sum(gt_count_whole).item(),
                torch.sum(pred_count).item()
            ))
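
    # The reported "MSE" below is really the RMSE, sqrt(mean squared error),
    # as is conventional in the crowd-counting literature; MAE is the plain
    # mean absolute error.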
    mae = mae * 1.0 / (len(test_loader) * batch_size)
    mse = np.sqrt(mse / (len(test_loader) * batch_size))
    writer.add_scalar("MAE (valid)", mae, epoch)
    writer.add_scalar("MSE (valid)", mse, epoch)
    if len(xformed) != 0:
        img_grid = torchvision.utils.make_grid(xformed)
        writer.add_image("STN: transformed image", img_grid, epoch)
    nni.report_intermediate_result(mae)
    print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
        mae=mae, mse=mse
    ))
    writer.flush()
    return mae
if __name__ == "__main__":
|
|
tuner_params = nni.get_next_parameter()
|
|
logger.debug("Generated hyperparameters: {}", tuner_params)
|
|
combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
|
|
|
|
if args.debug:
|
|
os.nice(15)
|
|
|
|
if combined_params.use_ddp:
|
|
# Use DDP, spawn threads
|
|
torch_mp.spawn(
|
|
worker,
|
|
args=(combined_params, ), # rank supplied at callee as 1st param
|
|
# also above *has* to be 1-tuple else runtime expands Namespace.
|
|
nprocs=combined_params.ddp_world_size,
|
|
)
|
|
else:
|
|
# No DDP, run in current thread
|
|
worker(0, combined_params)
|
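
# Hypothetical launch sketch (flag names depend on arguments.py, which is not
# shown here):
#   python train.py --use_ddp --ddp_world_size 2 --batch_size 8
# Under NNI, the tuner starts this script instead and supplies the
# hyperparameters returned by nni.get_next_parameter().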