import logging
import os
import random
import time
from argparse import Namespace
from typing import Optional

import numpy as np
import timm
import torch
import torch.multiprocessing as torch_mp
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

import nni
from nni.utils import merge_parameter

import dataset
from arguments import args, ret_args
from checkpoint import save_checkpoint
from dataset import *
from model.transcrowd_gap import VisionTransformerGAP
from model.transcrowd_gap import *

logger = logging.getLogger("train")
writer = SummaryWriter(args.save_path + "/tensorboard-run")


def setup_process_group(
    rank: int,
    world_size: int,
    master_addr: str = "localhost",
    master_port: Optional[int] = None
):
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = (
        str(random.randint(40000, 65535)) if master_port is None else str(master_port)
    )
    # join point! blocks until all `world_size` ranks have called init
    torch.distributed.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",  # assumption: NCCL for CUDA, Gloo otherwise
        rank=rank,
        world_size=world_size
    )


def build_train_loader(data_keys, args):
    train_dataset = ListDataset(
        data_keys,
        shuffle=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
        train=True,
        batch_size=args.batch_size,
        nr_workers=args.workers,
        args=args
    )
    if args.use_ddp:
        train_dist_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset=train_dataset
        )
    else:
        train_dist_sampler = None
    train_loader = DataLoader(
        dataset=train_dataset,
        sampler=train_dist_sampler,
        batch_size=args.batch_size,
        drop_last=False
    )
    return train_loader


def build_test_loader(data_keys, args):
    test_dataset = ListDataset(
        data_keys,
        shuffle=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
        args=args,
        train=False
    )
    if args.use_ddp:
        test_dist_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset=test_dataset
        )
    else:
        test_dist_sampler = None
    test_loader = DataLoader(
        dataset=test_dataset,
        sampler=test_dist_sampler,
        batch_size=1
    )
    return test_loader


def worker(rank: int, args: Namespace):
    world_size = args.ddp_world_size

    # (If DDP) join the process group among all worker processes
    if args.use_ddp:
        setup_process_group(rank, world_size)

    # Set up the device for this process
    if args.use_ddp and torch.cuda.is_available():
        device = torch.device(rank)
    elif torch.cuda.is_available():
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        args.gpus = [int(gpu_id) for gpu_id in args.gpus.split(",")]
        device = None
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        print("[!!!] Using CPU for inference. This will be slow...")
        device = torch.device("cpu")

    if device is not None:
        torch.set_default_device(device)

    # Prepare training data
    train_list, test_list = unpack_npy_data(args)
    train_data = convert_data(train_list, args, train=True)
    test_data = convert_data(test_list, args, train=False)
    train_loader = build_train_loader(train_data, args)
    test_loader = build_test_loader(test_data, args)

    # Instantiate model
    if args.model == "stn":
        model = stn_patch16_384_gap(args.pth_tar)
    else:
        model = base_patch16_384_gap(args.pth_tar)

    if args.use_ddp:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[rank],
            output_device=rank,
            find_unused_parameters=True,
            gradient_as_bucket_view=True  # XXX: vital, otherwise OOM
        )
    else:
        model = nn.DataParallel(
            model,
            device_ids=args.gpus
        )

    if device is not None:
        model = model.to(device)
    elif torch.cuda.is_available():
        model = model.cuda()

    # Criterion, optimizer, scheduler
    criterion = nn.L1Loss(reduction="sum")
    if device is not None:
        criterion = criterion.to(device)
    elif torch.cuda.is_available():
        criterion = criterion.cuda()

    optimizer = torch.optim.Adam(
        [{"params": model.parameters(), "lr": args.lr}],
        lr=args.lr,
        weight_decay=args.weight_decay
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[300], gamma=0.1, last_epoch=-1
    )

    # Checkpointing
    if (not args.use_ddp) or torch.distributed.get_rank() == 0:
        print("[worker-0] Saving to '{}'".format(args.save_path))
        if not os.path.exists(args.save_path):
            os.makedirs(args.save_path)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            # Load onto CPU first so every rank doesn't map the checkpoint to GPU 0
            checkpoint = torch.load(args.resume, map_location="cpu")
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            args.start_epoch = checkpoint['epoch']
            args.best_pred = checkpoint['best_prec1']  # sic.
        else:
            print("=> Checkpoint not found!")

    print("[worker-{}] Starting @ epoch {}: pred: {}".format(
        rank, args.start_epoch, args.best_pred
    ))

    # For each epoch:
    for epoch in range(args.start_epoch, args.epochs):
        # Tell the distributed samplers which epoch it is
        if args.use_ddp:
            train_loader.sampler.set_epoch(epoch)
            test_loader.sampler.set_epoch(epoch)

        # Train
        start = time.time()
        train_one_epoch(train_loader, model, criterion, optimizer, scheduler,
                        epoch, device, args)
        end_train = time.time()

        # Validate
        if epoch % 5 == 0 or args.debug:
            prec1 = valid_one_epoch(test_loader, model, device, epoch, args)
            end_valid = time.time()
            is_best = prec1 < args.best_pred
            args.best_pred = min(prec1, args.best_pred)
            print("* best MAE {mae:.3f} *".format(mae=args.best_pred))

            # Save checkpoint
            if not args.use_ddp or torch.distributed.get_rank() == 0:
                save_checkpoint({
                    "epoch": epoch + 1,
                    "arch": args.resume,
                    "state_dict": model.state_dict(),
                    "best_prec1": args.best_pred,
                    "optimizer": optimizer.state_dict(),
                }, is_best, args.save_path)

    # Cleanup
    if args.use_ddp:
        torch.distributed.destroy_process_group()


def train_one_epoch(
    train_loader: DataLoader,
    model: VisionTransformerGAP,
    criterion,
    optimizer,
    scheduler,
    epoch: int,
    device,
    args: Namespace
):
    # Get the current learning rate
    curr_lr = optimizer.param_groups[0]["lr"]
    print("Epoch %d, processed %d samples, lr %.10f" %
          (epoch, epoch * len(train_loader.dataset), curr_lr))

    # Set to train mode
    model.train()

    # In one epoch, for each training batch
    for i, (fname, img, kpoint, gt_count_whole) in enumerate(train_loader):
        kpoint = kpoint.type(torch.FloatTensor)
        gt_count_whole = gt_count_whole.type(torch.FloatTensor).unsqueeze(1)
        batch_size = img.size(0)

        # Send to device
        if device is not None:
            img = img.to(device)
            kpoint = kpoint.to(device)
            gt_count_whole = gt_count_whole.to(device)
            device_type = device.type
        elif torch.cuda.is_available():
            img = img.cuda()
            kpoint = kpoint.cuda()
            gt_count_whole = gt_count_whole.cuda()
            device_type = "cuda"

        # Monotonically increasing TensorBoard step
        global_step = epoch * len(train_loader) + i

        # Desperate measure to reduce mem footprint...
        with torch.autocast(device_type):
            # fpass
            out, gt_count = model(img, kpoint)
            # loss wrt. transformer
            loss = criterion(out, gt_count)
            writer.add_scalar("L1-loss wrt. xformer (train)", loss, global_step)
            loss += (
                F.mse_loss(  # stn: info retainment
                    gt_count.view(batch_size, -1).sum(dim=1, keepdim=True),
                    gt_count_whole
                )
                + F.threshold(  # stn: perspective correction
                    gt_count.view(batch_size, -1).var(dim=1).mean(),
                    threshold=loss.item(),
                    value=loss.item()
                )
            )
            writer.add_scalar("Composite loss (train)", loss, global_step)

        # Free grads from mem
        optimizer.zero_grad(set_to_none=True)
        # bpass
        loss.backward()
        # optimizer
        optimizer.step()

        if args.debug:
            break

    # Flush writer
    writer.flush()
    scheduler.step()


def valid_one_epoch(test_loader, model, device, epoch, args):
    print("[valid_one_epoch] Validating...")
    batch_size = 1
    model.eval()
    mae = 0.0
    mse = 0.0
    visi = []
    index = 0
    xformed = []

    for i, (fname, img, kpoint, gt_count_whole) in enumerate(test_loader):
        kpoint = kpoint.type(torch.FloatTensor)
        gt_count_whole = gt_count_whole.type(torch.FloatTensor)

        if device is not None:
            img = img.to(device)
            kpoint = kpoint.to(device)
            gt_count_whole = gt_count_whole.to(device)
        elif torch.cuda.is_available():
            img = img.cuda()
            kpoint = kpoint.cuda()
            gt_count_whole = gt_count_whole.cuda()

        # XXX: does this ever happen?
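        # (Assumption) Shape normalization: a 5-D batch likely means the test-time
        # collate wrapped pre-cropped patches in an extra leading dim, while a 3-D
        # tensor is a single un-batched image; both are coerced to 4-D NCHW below.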
        if len(img.shape) == 5:
            img = img.squeeze(0)
        if len(img.shape) == 3:
            img = img.unsqueeze(0)

        with torch.no_grad():
            out, gt_count = model(img, kpoint)
            pred_count = torch.squeeze(out, 1)
            gt_count = torch.squeeze(gt_count, 1)

        diff = torch.abs(gt_count_whole - torch.sum(pred_count)).item()
        mae += diff
        mse += diff ** 2

        if i % 5 == 0:
            # Unwrap DDP/DataParallel before checking for the STN submodule
            unwrapped = model.module if hasattr(model, "module") else model
            if isinstance(unwrapped, STNet_VisionTransformerGAP):
                with torch.no_grad():
                    img_xformed = unwrapped.stnet(img).to("cpu")
                xformed.append(img_xformed)
            print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
                fname[0], torch.sum(gt_count_whole).item(), torch.sum(pred_count).item()
            ))

    mae = mae * 1.0 / (len(test_loader) * batch_size)
    mse = np.sqrt(mse / (len(test_loader) * batch_size))
    writer.add_scalar("MAE (valid)", mae, epoch)
    writer.add_scalar("MSE (valid)", mse, epoch)

    if len(xformed) != 0:
        # Each entry is assumed to be a (1, C, H, W) batch; concatenate into one 4-D tensor for make_grid
        img_grid = torchvision.utils.make_grid(torch.cat(xformed, dim=0))
        writer.add_image("STN: transformed image", img_grid, epoch)

    nni.report_intermediate_result(mae)
    print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(mae=mae, mse=mse))
    writer.flush()
    return mae


if __name__ == "__main__":
    tuner_params = nni.get_next_parameter()
    logger.debug("Generated hyperparameters: %s", tuner_params)
    combined_params = merge_parameter(ret_args, tuner_params)

    if args.debug:
        os.nice(15)

    if combined_params.use_ddp:
        # Use DDP: spawn one process per rank
        torch_mp.spawn(
            worker,
            args=(combined_params, ),  # rank is supplied by spawn as the 1st param;
            # also, the above *has* to be a 1-tuple or the runtime expands the Namespace.
            nprocs=combined_params.ddp_world_size,
        )
    else:
        # No DDP: run in the current process
        worker(0, combined_params)