Added sth more
This commit is contained in:
parent
49a913a328
commit
99266d9c92
2 changed files with 77 additions and 30 deletions
112
train.py
Normal file
112
train.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
import os
|
||||
import random
|
||||
from typing import Optional
|
||||
from argparse import Namespace
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.multiprocessing as torch_mp
|
||||
from torch.utils.data import DataLoader
|
||||
import nni
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from model.transcrowd_gap import VisionTransformerGAP
|
||||
from arguments import args, ret_args
|
||||
|
||||
logger = logging.getLogger("train")
|
||||
|
||||
|
||||
def setup_process_group(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
master_addr: str = "localhost",
|
||||
master_port: Optional[np.ushort] = None
|
||||
):
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = (
|
||||
str(random.randint(40000, 65545))
|
||||
if master_port is None
|
||||
else str(master_port)
|
||||
)
|
||||
|
||||
# join point!
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl", rank=rank, world_size=world_size
|
||||
)
|
||||
|
||||
# TODO:
|
||||
# The shape for each batch in transcrowd is [3, 384, 384],
|
||||
# this is due to images being cropped before training.
|
||||
# To preserve image semantics wrt the entire layout, we want to apply cropping
|
||||
# i.e., as encoder input during the inference/training pipeline.
|
||||
# This should be okay since our transformations are all deterministic?
|
||||
# not sure...
|
||||
|
||||
|
||||
|
||||
|
||||
def build_train_loader():
|
||||
pass
|
||||
|
||||
|
||||
def build_valid_loader():
|
||||
pass
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
train_loader: DataLoader,
|
||||
model: VisionTransformerGAP,
|
||||
criterion,
|
||||
optimizer,
|
||||
scheduler,
|
||||
epoch: int,
|
||||
args: Namespace
|
||||
):
|
||||
# Get learning rate
|
||||
curr_lr = optimizer.param_groups[0]["lr"]
|
||||
print("Epoch %d, processed %d samples, lr %.10f" %
|
||||
(epoch, epoch * len(train_loader.dataset), curr_lr)
|
||||
)
|
||||
|
||||
# Set to train mode (perspective estimator only)
|
||||
revpers_net.train()
|
||||
end = time.time()
|
||||
|
||||
# In one epoch, for each training sample
|
||||
for i, (fname, img, gt_count) in enumerate(train_loader):
|
||||
# move stuff to device
|
||||
# fpass (revpers)
|
||||
img = img.cuda()
|
||||
|
||||
# loss wrt revpers
|
||||
loss = criterion()
|
||||
|
||||
|
||||
pass
|
||||
|
||||
def valid_one_epoch():
|
||||
pass
|
||||
|
||||
def main(rank: int, args: Namespace):
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
tuner_params = nni.get_next_parameter()
|
||||
logger.debug("Generated hyperparameters: {}", tuner_params)
|
||||
combined_params = Namespace(
|
||||
nni.utils.merge_parameter(ret_args, tuner_params)
|
||||
) # Namespaces have better ergonomics, notably a struct-like access syntax.
|
||||
logger.debug("Parameters: {}", combined_params)
|
||||
|
||||
if combined_params.use_ddp:
|
||||
# Use DDP, spawn threads
|
||||
torch_mp.spawn(
|
||||
main,
|
||||
args=(combined_params, ), # rank supplied automatically as 1st param
|
||||
nprocs=combined_params.world_size,
|
||||
)
|
||||
else:
|
||||
# No DDP, run in current thread
|
||||
main(None, combined_params)
|
||||
Loading…
Add table
Add a link
Reference in a new issue