import os import random import time from typing import Optional from argparse import Namespace import timm import torch import torch.nn as nn import torch.multiprocessing as torch_mp from torch.utils.data import DataLoader import nni import logging import numpy as np from model.transcrowd_gap import VisionTransformerGAP from arguments import args, ret_args import dataset from dataset import * from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap from checkpoint import save_checkpoint logger = logging.getLogger("train") def setup_process_group( rank: int, world_size: int, master_addr: str = "localhost", master_port: Optional[np.ushort] = None ): os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_PORT"] = ( str(random.randint(40000, 65545)) if master_port is None else str(master_port) ) # join point! torch.distributed.init_process_group( rank=rank, world_size=world_size ) def build_train_loader(data_keys, args): train_dataset = ListDataset( data_keys, shuffle = True, transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225]) ]), train = True, batch_size = args.batch_size, nr_workers = args.workers, args = args ) if args.use_ddp: train_dist_sampler = torch.utils.data.distributed.DistributedSampler( dataset=train_dataset ) else: train_dist_sampler = None train_loader = DataLoader( dataset=train_dataset, sampler=train_dist_sampler, batch_size=args.batch_size, drop_last=False ) return train_loader def build_test_loader(data_keys, args): test_dataset = ListDataset( data_keys, shuffle=False, transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225]) ]), args=args, train=False ) if args.use_ddp: test_dist_sampler = torch.utils.data.distributed.DistributedSampler( dataset=test_dataset ) else: test_dist_sampler = None test_loader = DataLoader( dataset=test_dataset, sampler=test_dist_sampler, batch_size=1 ) return test_loader def worker(rank: int, args: Namespace): world_size = args.ddp_world_size # (If DDP) join after process group among processes if args.use_ddp: setup_process_group(rank, world_size) # Setup device for one proc if args.use_ddp and torch.cuda.is_available(): device = torch.device(rank) elif torch.cuda.is_available(): os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus args.gpus = [int(gpu_id) for gpu_id in args.gpus.split(",")] device = None elif torch.backends.mps.is_available(): device = torch.device("mps") else: print("[!!!] Using CPU for inference. This will be slow...") device = torch.device("cpu") if device is not None: torch.set_default_device(device) # Prepare training data train_list, test_list = unpack_npy_data(args) train_data = convert_data(train_list, args, train=True) test_data = convert_data(test_list, args, train=False) train_loader = build_train_loader(train_data, args) test_loader = build_test_loader(test_data, args) # Instantiate model if args.model == "stn": model = stn_patch16_384_gap(args.pth_tar) else: model = base_patch16_384_gap(args.pth_tar) if args.use_ddp: model = nn.parallel.DistributedDataParallel( model, device_ids=[rank], output_device=rank, find_unused_parameters=True, gradient_as_bucket_view=True # XXX: vital, otherwise OOM ) else: model = nn.DataParallel( model, device_ids=args.gpus ) if device is not None: model = model.to(device) elif torch.cuda.is_available(): model = model.cuda() # criterion, optimizer, scheduler criterion = nn.L1Loss(reduction="sum") if device is not None: criterion = criterion.to(device) elif torch.cuda.is_available(): criterion = criterion.cuda() optimizer = torch.optim.Adam( [{"params": model.parameters(), "lr": args.lr}], lr=args.lr, weight_decay=args.weight_decay ) scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones=[300], gamma=.1, last_epoch=-1 ) # Checkpointing if (not args.use_ddp) or torch.distributed.get_rank() == 0: print("[worker-0] Saving to \'{}\'".format(args.save_path)) if not os.path.exists(args.save_path): os.makedirs(args.save_path) if args.resume: if os.path.isfile(args.resume): print("=> Loading checkpoint \'{}\'".format(args.resume)) checkpoint = torch.load(args.resume) model.load_state_dict(checkpoint['state_dict'], strict=False) args.start_epoch = checkpoint['epoch'] args.best_pred = checkpoint['best_prec1'] # sic. else: print("=> Checkpoint not found!") print("[worker-{}] Starting @ epoch {}: pred: {}".format( rank, args.start_epoch, args.best_pred )) # For each epoch: for epoch in range(args.start_epoch, args.epochs): # Tell sampler which epoch it is if args.use_ddp: train_loader.sampler.set_epoch(epoch) test_loader.sampler.set_epoch(epoch) # Train start = time.time() train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, device, args) end_train = time.time() # Validate if epoch % 5 == 0 or args.debug: prec1 = valid_one_epoch(test_loader, model, device, args) end_valid = time.time() is_best = prec1 < args.best_pred args.best_pred = min(prec1, args.best_pred) print("* best MAE {mae:.3f} *".format(mae=args.best_pred)) # Save checkpoint if not args.use_ddp or torch.distributed.get_rank() == 0: save_checkpoint({ "epoch": epoch + 1, "arch": args.resume, "state_dict": model.state_dict(), "best_prec1": args.best_pred, "optimizer": optimizer.state_dict(), }, is_best, args.save_path) # cleanup if args.use_ddp: torch.distributed.destroy_process_group() def train_one_epoch( train_loader: DataLoader, model: VisionTransformerGAP, criterion, optimizer, scheduler, epoch: int, device, args: Namespace ): # Get learning rate curr_lr = optimizer.param_groups[0]["lr"] print("Epoch %d, processed %d samples, lr %.10f" % (epoch, epoch * len(train_loader.dataset), curr_lr) ) # Set to train mode model.train() # In one epoch, for each training sample for i, (fname, img, kpoint) in enumerate(train_loader): kpoint = kpoint.type(torch.FloatTensor) # fpass if device is not None: img = img.to(device) kpoint = kpoint.to(device) elif torch.cuda.is_available(): img = img.cuda() kpoint = kpoint.cuda() out, gt_count = model(img, kpoint) # loss loss = criterion(out, gt_count) # free grad from mem optimizer.zero_grad() # bpass loss.backward() # optimizer optimizer.step() # periodic message if i % args.print_freq == 0: print("Epoch {}: {}/{}".format(epoch, i, len(train_loader))) if args.debug: break scheduler.step() def valid_one_epoch(test_loader, model, device, args): print("[valid_one_epoch] Validating...") batch_size = 1 model.eval() mae = .0 mse = .0 visi = [] index = 0 for i, (fname, img, kpoint) in enumerate(test_loader): kpoint = kpoint.type(torch.FloatTensor) if device is not None: img = img.to(device) kpoint = kpoint.to(device) elif torch.cuda.is_available(): img = img.cuda() kpoint = kpoint.cuda() # XXX: what do this do if len(img.shape) == 5: img = img.squeeze(0) if len(img.shape) == 3: img = img.unsqueeze(0) with torch.no_grad(): out, gt_count = model(img, kpoint) pred_count = torch.squeeze(out, 1) gt_count = torch.squeeze(gt_count, 1) diff = torch.sum(torch.abs(gt_count - pred_count)).item() mae += diff mse += diff ** 2 mae = mae * 1.0 / (len(test_loader) * batch_size) mse = np.sqrt(mse / (len(test_loader)) * batch_size) if i % 5 == 0: print("[valid_one_epoch] {}\t| Gt {:.2f} Pred {:.4f}\t| mae {:.4f} mse {:.4f} |".format( fname[0], torch.sum(gt_count).item(), torch.sum(pred_count).item(), mae, mse )) nni.report_intermediate_result() print("* MAE {mae:.3f} | MSE {mse:.3f} *".format( mae=mae, mse=mse )) return mae if __name__ == "__main__": tuner_params = nni.get_next_parameter() logger.debug("Generated hyperparameters: {}", tuner_params) combined_params = nni.utils.merge_parameter(ret_args, tuner_params) if args.debug: os.nice(15) #combined_params = args #logger.debug("Parameters: {}", combined_params) if combined_params.use_ddp: # Use DDP, spawn threads torch_mp.spawn( worker, args=(combined_params, ), # rank supplied at callee as 1st param # also above *has* to be 1-tuple else runtime expands Namespace. nprocs=combined_params.world_size, ) else: # No DDP, run in current thread worker(0, combined_params)