import logging
import time
from argparse import Namespace
from typing import Optional

import timm
import torch
import torch.multiprocessing as torch_mp
from torch.utils.data import DataLoader

import nni
from nni.utils import merge_parameter

from model.csrnet import CSRNet
from model.reverse_perspective import PerspectiveEstimator
from arguments import args, ret_args

logger = logging.getLogger("train-revpers")


# We use 2 separate networks as opposed to 1 whole network --
# this is more flexible, as we only train one of them...
def gen_csrnet(pth_tar: Optional[str] = None) -> CSRNet:
    """Build a CSRNet, optionally restoring weights from a checkpoint."""
    if pth_tar is not None:
        model = CSRNet(load_weights=True)
        checkpoint = torch.load(pth_tar)
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    else:
        model = CSRNet(load_weights=False)
    return model


def gen_revpers(pth_tar: Optional[str] = None, **kwargs) -> PerspectiveEstimator:
    """Build a PerspectiveEstimator, optionally restoring weights from a checkpoint."""
    model = PerspectiveEstimator(**kwargs)
    if pth_tar is not None:
        checkpoint = torch.load(pth_tar)
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model


def build_train_loader() -> DataLoader:
    # TODO: construct and return the training DataLoader.
    pass


def build_valid_loader() -> DataLoader:
    # TODO: construct and return the validation DataLoader.
    pass


def train_one_epoch(
    train_loader: DataLoader,
    revpers_net: PerspectiveEstimator,
    csr_net: CSRNet,
    criterion,
    optimizer,
    scheduler,
    epoch: int,
    args: Namespace,
):
    # Get learning rate
    curr_lr = optimizer.param_groups[0]["lr"]
    print(
        "Epoch %d, processed %d samples, lr %.10f"
        % (epoch, epoch * len(train_loader.dataset), curr_lr)
    )

    # Train the perspective estimator only. CSRNet stays frozen (its
    # parameters are not in `optimizer`), but we still backprop *through*
    # it so gradients can reach the perspective estimator.
    revpers_net.train()
    csr_net.eval()

    end = time.time()

    # In one epoch, for each training batch
    for i, (fname, img, gt_count) in enumerate(train_loader):
        # Forward pass (perspective estimator)
        img = img.cuda()
        out_revpers = revpers_net(img)

        # TODO: warp `img` according to `out_revpers` here (the
        # reverse-perspective transform) before handing it to CSRNet.
        img = img.cpu()
        img = img.cuda()

        # Forward pass (CSRNet -- not trained, but kept differentiable so
        # the loss can backpropagate through it)
        out_csrnet = csr_net(img)

        # Loss w.r.t. the perspective estimator. Placeholder: regress the
        # predicted count (sum of the density map) against the ground-truth
        # count; the exact criterion inputs may change once the warp above
        # is implemented.
        loss = criterion(out_csrnet.sum(dim=(1, 2, 3)), gt_count.float().cuda())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def valid_one_epoch():
    # TODO: evaluate on the validation set.
    pass


def main(rank: Optional[int], args: Namespace):
    # TODO: build loaders, models, and optimizer, then run the
    # train/validate loop. `rank` is the DDP process rank, or None when
    # DDP is disabled.
    pass


if __name__ == "__main__":
    tuner_params = nni.get_next_parameter()
    logger.debug("Generated hyperparameters: %s", tuner_params)

    # merge_parameter returns an argparse.Namespace, which has better
    # ergonomics than a dict, notably a struct-like access syntax.
    combined_params = merge_parameter(ret_args, tuner_params)
    logger.debug("Parameters: %s", combined_params)

    if combined_params.use_ddp:
        # Use DDP: spawn one process per rank
        torch_mp.spawn(
            main,
            args=(combined_params,),  # rank supplied automatically as 1st param
            nprocs=combined_params.world_size,
        )
    else:
        # No DDP: run in the current process
        main(None, combined_params)
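
# Usage note (assumptions, not part of the training logic): this script is
# meant to run under an NNI experiment, e.g. `nnictl create --config <your
# experiment config>`, which feeds hyperparameters via get_next_parameter().
# Run standalone, nni.get_next_parameter() returns an empty dict, so
# merge_parameter() leaves the defaults from `ret_args` untouched.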