Added sth more

Zhengyi Chen 2024-02-27 21:27:02 +00:00
parent 49a913a328
commit 99266d9c92
2 changed files with 77 additions and 30 deletions

train.py (new normal file, 112 additions)

@@ -0,0 +1,112 @@
import os
import random
import time
from typing import Optional
from argparse import Namespace
import timm
import torch
import torch.distributed
import torch.nn as nn
import torch.multiprocessing as torch_mp
from torch.utils.data import DataLoader
import nni
import logging
import numpy as np
from model.transcrowd_gap import VisionTransformerGAP
from arguments import args, ret_args
logger = logging.getLogger("train")
def setup_process_group(
rank: int,
world_size: int,
master_addr: str = "localhost",
master_port: Optional[np.ushort] = None
):
    # Configure the rendezvous address for torch.distributed
    os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = (
        str(random.randint(40000, 65535))
if master_port is None
else str(master_port)
)
# join point!
torch.distributed.init_process_group(
backend="nccl", rank=rank, world_size=world_size
)
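# Teardown sketch (not called anywhere yet): once training finishes, each
# worker should call torch.distributed.destroy_process_group() to leave the
# NCCL process group cleanly.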
# TODO:
# The shape of each batch in transcrowd is [3, 384, 384] because images are
# cropped before training. To preserve the semantics of the full layout, we
# want to apply the cropping inside the training/inference pipeline instead,
# i.e., right where the encoder input is prepared.
# This should be okay since our transformations are all deterministic?
# Not sure...
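# A possible deterministic preprocessing pipeline for the idea above
# (hypothetical sketch; torchvision is not imported in this file): resize the
# full image to the encoder's 384x384 input instead of cropping it, so the
# whole layout is preserved.
#
#   from torchvision import transforms
#   full_image_transform = transforms.Compose([
#       transforms.Resize((384, 384)),
#       transforms.ToTensor(),
#       transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                            std=[0.229, 0.224, 0.225]),
#   ])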
def build_train_loader():
pass
def build_valid_loader():
pass
def train_one_epoch(
train_loader: DataLoader,
model: VisionTransformerGAP,
criterion,
optimizer,
scheduler,
epoch: int,
args: Namespace
):
# Get learning rate
curr_lr = optimizer.param_groups[0]["lr"]
print("Epoch %d, processed %d samples, lr %.10f" %
(epoch, epoch * len(train_loader.dataset), curr_lr)
)
    # Set the model to train mode
    model.train()
    end = time.time()
    # In one epoch, for each training sample
    for i, (fname, img, gt_count) in enumerate(train_loader):
        # Move inputs to the GPU
        img = img.cuda()
        # Forward pass and loss (sketch: assumes the criterion compares the
        # model's predicted count against the ground-truth count)
        pred_count = model(img)
        loss = criterion(pred_count, gt_count.float().cuda())
        # TODO: backward pass, optimizer/scheduler step, batch timing, logging
        pass
def valid_one_epoch():
pass
def main(rank: Optional[int], args: Namespace):
pass
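    # Rough sketch of the eventual flow here (an assumption, not implemented
    # yet; the args attributes below are hypothetical):
    #   if rank is not None:
    #       setup_process_group(rank, args.world_size)
    #   model = VisionTransformerGAP(...).cuda()
    #   criterion = nn.L1Loss()  # e.g. a count-regression loss
    #   optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    #   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=..., gamma=...)
    #   train_loader = build_train_loader()
    #   for epoch in range(args.epochs):
    #       train_one_epoch(train_loader, model, criterion, optimizer,
    #                       scheduler, epoch, args)
    #       valid_one_epoch()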
if __name__ == "__main__":
tuner_params = nni.get_next_parameter()
logger.debug("Generated hyperparameters: {}", tuner_params)
    # merge_parameter updates ret_args with the tuner's values and returns the
    # merged Namespace; Namespaces have better ergonomics, notably a
    # struct-like attribute access syntax.
    combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
    logger.debug("Parameters: %s", combined_params)
if combined_params.use_ddp:
        # Use DDP: spawn one worker process per rank
torch_mp.spawn(
main,
args=(combined_params, ), # rank supplied automatically as 1st param
nprocs=combined_params.world_size,
)
else:
        # No DDP, run in the current process
main(None, combined_params)
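# Launch sketch (assumption): run directly with `python train.py`, or under an
# NNI experiment via `nnictl create --config <experiment>.yml`, in which case
# nni.get_next_parameter() above receives the tuner's suggested hyperparameters.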