import os
import random
import time
import logging
from typing import Optional
from argparse import Namespace

import timm
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as torch_mp
from torch.utils.data import DataLoader

import nni
from nni.utils import merge_parameter
import numpy as np

from model.transcrowd_gap import VisionTransformerGAP
from arguments import args, ret_args

logger = logging.getLogger("train")


def setup_process_group(
    rank: int,
    world_size: int,
    master_addr: str = "localhost",
    master_port: Optional[int] = None
):
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = (
        str(random.randint(40000, 65535)) if master_port is None else str(master_port)
    )
    # Join point: blocks until all `world_size` processes have called in.
    dist.init_process_group(
        backend="nccl", rank=rank, world_size=world_size
    )


# TODO:
# The shape of each batch in TransCrowd is [3, 384, 384], because images are
# cropped before training. To preserve the semantics of the full image layout,
# we want to apply the cropping inside the training/inference pipeline itself,
# i.e., produce the encoder input on the fly. This should be fine as long as
# all of our transformations are deterministic (not entirely sure yet).
def build_train_loader():
    pass


def build_valid_loader():
    pass


def train_one_epoch(
    train_loader: DataLoader,
    model: VisionTransformerGAP,
    criterion,
    optimizer,
    scheduler,
    epoch: int,
    args: Namespace
):
    # Get the current learning rate.
    curr_lr = optimizer.param_groups[0]["lr"]
    print("Epoch %d, processed %d samples, lr %.10f"
          % (epoch, epoch * len(train_loader.dataset), curr_lr))

    # Set the model to train mode.
    model.train()
    end = time.time()

    # In one epoch, for each training sample:
    for i, (fname, img, gt_count) in enumerate(train_loader):
        # Move inputs and targets to the device.
        img = img.cuda()
        gt_count = gt_count.float().cuda()

        # Forward pass, then loss w.r.t. the predicted count.
        pred_count = model(img)
        loss = criterion(pred_count, gt_count)

        # Standard backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def valid_one_epoch():
    pass


def main(rank: int, args: Namespace):
    pass


if __name__ == "__main__":
    tuner_params = nni.get_next_parameter()
    logger.debug("Generated hyperparameters: %s", tuner_params)

    # Namespaces have better ergonomics, notably struct-like attribute access.
    combined_params = merge_parameter(ret_args, tuner_params)
    logger.debug("Parameters: %s", combined_params)

    if combined_params.use_ddp:
        # Use DDP: spawn one process per rank.
        torch_mp.spawn(
            main,
            args=(combined_params,),  # rank is supplied automatically as the 1st param
            nprocs=combined_params.world_size,
        )
    else:
        # No DDP: run in the current process.
        main(None, combined_params)
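
# ---------------------------------------------------------------------------
# Hedged sketch (not from the original pipeline): one way to implement the
# deterministic cropping described in the TODO above. The helper name
# `crop_to_patches` and the padding + unfold approach are assumptions; the
# 384x384 patch size comes from the TransCrowd input shape noted in the TODO.
# Usage might look like: patches = crop_to_patches(img); out = model(patches.cuda())
# ---------------------------------------------------------------------------
def crop_to_patches(img: torch.Tensor, patch_size: int = 384) -> torch.Tensor:
    """Deterministically tile a [C, H, W] image into [N, C, patch_size, patch_size] crops.

    The image is zero-padded on the bottom/right so H and W become multiples
    of `patch_size`. No randomness is involved, so the same input always
    produces the same crops, which is the property the TODO asks for.
    """
    c, h, w = img.shape
    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size
    # F.pad pads the last dims first: (left, right, top, bottom).
    img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h))
    # Unfold over H then W, giving [C, nH, nW, patch_size, patch_size],
    # then flatten the (nH, nW) grid of tiles into a single batch dimension.
    patches = img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    patches = patches.permute(1, 2, 0, 3, 4).reshape(-1, c, patch_size, patch_size)
    return patches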