mlp-project/train.py

import logging
import os
import random
import time
from argparse import Namespace
from typing import Optional

import nni
import timm
import torch
import torch.distributed
import torch.multiprocessing as torch_mp
import torch.nn as nn
from torch.utils.data import DataLoader

from arguments import args, ret_args
from model.transcrowd_gap import VisionTransformerGAP

logger = logging.getLogger("train")

def setup_process_group(
    rank: int,
    world_size: int,
    master_addr: str = "localhost",
    master_port: Optional[int] = None,
):
    os.environ["MASTER_ADDR"] = master_addr
    # Valid TCP ports top out at 65535; pick one from the dynamic range
    # when none is given.
    os.environ["MASTER_PORT"] = (
        str(random.randint(40000, 65535))
        if master_port is None
        else str(master_port)
    )
    # join point!
    torch.distributed.init_process_group(
        backend="nccl", rank=rank, world_size=world_size
    )
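

# Matching teardown for setup_process_group: each rank should tear the process
# group down once training finishes so NCCL resources are released cleanly.
def cleanup_process_group():
    torch.distributed.destroy_process_group()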


# TODO:
# The shape of each batch in transcrowd is [3, 384, 384] because images are
# cropped before training. To preserve the semantics of the whole image
# layout, we want to apply the cropping inside the training/inference
# pipeline instead, i.e., right before the crops become encoder input.
# This should be okay since our transformations are all deterministic?
# Not sure... (a sketch of the idea follows below)
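

# Sketch of the in-pipeline cropping idea above. The tiling coordinates are
# fixed, so the transform is deterministic; torchvision is assumed available
# (it ships alongside torch), and the 384 tile size matches the batch shape
# noted in the TODO. Assumes H and W are both >= 384.
import torchvision.transforms.functional as TF


def tile_384(img: torch.Tensor) -> torch.Tensor:
    """Split a [3, H, W] tensor into an [N, 3, 384, 384] stack of fixed tiles."""
    _, h, w = img.shape
    tiles = [
        TF.crop(img, top, left, 384, 384)
        for top in range(0, h - 383, 384)
        for left in range(0, w - 383, 384)
    ]
    return torch.stack(tiles)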


def build_train_loader():
    pass


def build_valid_loader():
    pass
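

# Hedged sketch of how build_train_loader / build_valid_loader above could be
# filled in. The dataset object and the batch_size / workers arguments are
# placeholders (they would come from arguments.py); the DistributedSampler
# wiring is the standard recipe when use_ddp is on.
def _build_loader_sketch(dataset, args: Namespace, shuffle: bool) -> DataLoader:
    sampler = (
        torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
        if args.use_ddp
        else None
    )
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=(shuffle and sampler is None),  # the sampler owns shuffling under DDP
        sampler=sampler,
        num_workers=args.workers,
        pin_memory=True,
    )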


def train_one_epoch(
    train_loader: DataLoader,
    model: VisionTransformerGAP,
    criterion,
    optimizer,
    scheduler,
    epoch: int,
    args: Namespace,
):
    # Get learning rate
    curr_lr = optimizer.param_groups[0]["lr"]
    print("Epoch %d, processed %d samples, lr %.10f" %
          (epoch, epoch * len(train_loader.dataset), curr_lr))
    # Set to train mode
    model.train()
    end = time.time()
    # In one epoch, for each training sample
    for i, (fname, img, gt_count) in enumerate(train_loader):
        # Move inputs to the device, then forward pass
        img = img.cuda()
        pred_count = model(img)
        # Regression loss between predicted and ground-truth counts
        loss = criterion(pred_count.squeeze(), gt_count.cuda().float())
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
    print("Epoch %d finished in %.1fs" % (epoch, time.time() - end))


def valid_one_epoch():
    pass
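

# Hedged sketch for valid_one_epoch: mean absolute error between predicted and
# ground-truth counts, the usual crowd-counting metric. Assumes model(img)
# returns one count per image, mirroring the forward pass in train_one_epoch.
@torch.no_grad()
def _valid_one_epoch_sketch(valid_loader: DataLoader, model: nn.Module) -> float:
    model.eval()
    abs_err, n_samples = 0.0, 0
    for _fname, img, gt_count in valid_loader:
        pred_count = model(img.cuda())
        abs_err += (pred_count.squeeze() - gt_count.cuda().float()).abs().sum().item()
        n_samples += img.size(0)
    return abs_err / max(n_samples, 1)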


def main(rank: Optional[int], args: Namespace):
    pass
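

# Hedged sketch of how main() could wire everything together under the flags
# used in __main__ below. The model constructor arguments, loss, optimizer,
# and the lr / epochs hyperparameters are placeholders, not settled choices.
def _main_sketch(rank: Optional[int], args: Namespace):
    if args.use_ddp:
        torch.cuda.set_device(rank)
        setup_process_group(rank, args.world_size)
    model = VisionTransformerGAP().cuda()
    if args.use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    criterion = nn.L1Loss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    train_loader = build_train_loader()
    valid_loader = build_valid_loader()
    mae = float("inf")
    for epoch in range(args.epochs):
        train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, args)
        mae = _valid_one_epoch_sketch(valid_loader, model)
        nni.report_intermediate_result(mae)  # surface per-epoch MAE to the NNI tuner
    nni.report_final_result(mae)
    if args.use_ddp:
        cleanup_process_group()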


if __name__ == "__main__":
    tuner_params = nni.get_next_parameter()
    logger.debug("Generated hyperparameters: %s", tuner_params)
    # merge_parameter returns the merged Namespace directly; Namespaces have
    # better ergonomics, notably struct-like attribute access.
    combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
    logger.debug("Parameters: %s", combined_params)
    if combined_params.use_ddp:
        # Use DDP: spawn one process per rank
        torch_mp.spawn(
            main,
            args=(combined_params,),  # rank is supplied automatically as the 1st param
            nprocs=combined_params.world_size,
        )
    else:
        # No DDP: run in the current process
        main(None, combined_params)