101 lines
No EOL
2.8 KiB
Python
101 lines
No EOL
2.8 KiB
Python
import logging
import time
from argparse import Namespace
from typing import Optional

import nni
import timm
import torch
import torch.multiprocessing as torch_mp
from torch.utils.data import DataLoader

from model.csrnet import CSRNet
from model.reverse_perspective import PerspectiveEstimator
from arguments import args, ret_args
|
|
|
|
# Named logger for this training script (used for debug output of parameters).
logger = logging.getLogger(name="train-revpers")
|
|
|
|
# We use two separate networks instead of one combined network: this is more
# flexible, since only one of them (the perspective estimator) is trained.
|
|
def gen_csrnet(pth_tar: Optional[str] = None) -> CSRNet:
    """Build the CSRNet counting network, optionally restoring a checkpoint.

    Args:
        pth_tar: Path to a checkpoint saved with ``torch.save`` that stores the
            model weights under the ``"state_dict"`` key. When ``None``, a fresh
            ``CSRNet(load_weights=False)`` is returned without loading anything.

    Returns:
        The ``CSRNet`` instance. Checkpoint weights are matched non-strictly;
        any missing or unexpected keys are reported through ``logger.warning``.
    """
    if pth_tar is None:
        return CSRNet(load_weights=False)

    # NOTE(review): load_weights=True is only passed when a checkpoint follows;
    # presumably it makes CSRNet skip its own weight initialisation -- confirm
    # against model/csrnet.py.
    model = CSRNet(load_weights=True)
    # load_state_dict copies values into the model's existing parameters, so the
    # device the checkpoint tensors are deserialized onto does not matter.
    # Mapping to CPU avoids failures on CPU-only hosts and avoids every process
    # (e.g. each DDP rank) allocating on the GPU the checkpoint was saved from.
    checkpoint = torch.load(pth_tar, map_location="cpu")
    incompatible = model.load_state_dict(checkpoint["state_dict"], strict=False)
    # strict=False silently ignores mismatched keys (a "module." prefix from a
    # DDP-wrapped save would load nothing at all), so surface any mismatch.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        logger.warning(
            "CSRNet checkpoint %s: missing keys %s, unexpected keys %s",
            pth_tar,
            incompatible.missing_keys,
            incompatible.unexpected_keys,
        )
    return model
|
|
|
|
def gen_revpers(pth_tar: Optional[str] = None, **kwargs) -> PerspectiveEstimator:
    """Build the perspective estimator, optionally restoring a checkpoint.

    Args:
        pth_tar: Path to a checkpoint saved with ``torch.save`` that stores the
            model weights under the ``"state_dict"`` key, or ``None`` to keep
            the freshly constructed weights.
        **kwargs: Forwarded unchanged to the ``PerspectiveEstimator``
            constructor.

    Returns:
        The ``PerspectiveEstimator`` instance. Checkpoint weights are matched
        non-strictly; any missing or unexpected keys are reported through
        ``logger.warning``.
    """
    model = PerspectiveEstimator(**kwargs)
    if pth_tar is not None:
        # load_state_dict copies into the existing parameters, so deserializing
        # onto the CPU is always safe and avoids GPU-specific load failures.
        checkpoint = torch.load(pth_tar, map_location="cpu")
        incompatible = model.load_state_dict(checkpoint["state_dict"], strict=False)
        # strict=False silently ignores mismatched keys; make them visible.
        if incompatible.missing_keys or incompatible.unexpected_keys:
            logger.warning(
                "PerspectiveEstimator checkpoint %s: missing keys %s, "
                "unexpected keys %s",
                pth_tar,
                incompatible.missing_keys,
                incompatible.unexpected_keys,
            )
    return model
|
|
|
|
def build_train_loader():
    """Placeholder, judging by its name, for building the training data loader.

    Not implemented yet: the function has no body and returns ``None``.
    """
    return None
|
|
|
|
def build_valid_loader():
    """Placeholder, judging by its name, for building the validation data loader.

    Not implemented yet: the function has no body and returns ``None``.
    """
    return None
|
|
|
|
def train_one_epoch(
    train_loader: DataLoader,
    revpers_net: PerspectiveEstimator,
    csr_net: CSRNet,
    criterion,
    optimizer,
    scheduler,
    epoch: int,
    args: Namespace,
):
    """Run one training epoch of the perspective estimator (work in progress).

    Only ``revpers_net`` is trained. ``csr_net`` is used as a fixed downstream
    network and is switched to eval mode here.

    Args:
        train_loader: Loader yielding ``(fname, img, gt_count)`` batches. It must
            expose ``.dataset``, whose length is used in the progress line.
        revpers_net: Perspective estimator being trained. It is put in train mode.
        csr_net: Counting network that must not be trained. It is put in eval mode.
        criterion: Loss callable. Its inputs are not decided yet; see the TODO at
            the loss computation.
        optimizer: Optimizer whose first param group's ``"lr"`` is printed.
        scheduler: Currently unused.
        epoch: Epoch index. The progress line reports ``epoch * len(dataset)``
            processed samples, which assumes ``epoch`` counts completed epochs
            (0-based). TODO confirm.
        args: Currently unused.
    """
    # Get learning rate
    curr_lr = optimizer.param_groups[0]["lr"]
    print("Epoch %d, processed %d samples, lr %.10f" %
          (epoch, epoch * len(train_loader.dataset), curr_lr))

    # Train mode for the perspective estimator only. CSRNet goes to eval mode so
    # that any dropout / batch-norm layers it has do not update statistics or
    # randomise its output while it acts as a fixed network.
    revpers_net.train()
    csr_net.eval()
    # NOTE(review): csr_net parameters still require grad, so backward() would
    # also compute (unused) gradients for them. Consider
    # csr_net.requires_grad_(False) once it is settled whether csr_net gets
    # DDP-wrapped, because DDP records which parameters need gradients when it
    # is constructed.

    end = time.time()  # epoch start timestamp, kept for (future) batch timing

    # In one epoch, for each training sample
    for i, (fname, img, gt_count) in enumerate(train_loader):
        # fpass (revpers)
        img = img.cuda()
        out_revpers = revpers_net(img)

        # TODO: transform `img` using `out_revpers` here. The CPU round trip
        # below is only needed if that transformation runs on the CPU.
        img = img.cpu()

        # fpass (csrnet -- do not train)
        img = img.cuda()
        out_csrnet = csr_net(img)

        # loss wrt revpers
        # TODO: placeholder. criterion is called without inputs, which raises
        # TypeError for standard torch.nn loss modules. The loss inputs
        # (e.g. out_csrnet vs gt_count), loss.backward(), optimizer.step() and
        # the scheduler step are not written yet.
        loss = criterion()
|
|
|
|
def valid_one_epoch():
    """Placeholder, judging by its name, for one validation epoch.

    Not implemented yet: the function has no body and returns ``None``.
    """
    return None
|
|
|
|
def main(rank: Optional[int], args: Namespace):
    """Per-worker entry point (not implemented yet).

    Args:
        rank: Worker rank. ``torch.multiprocessing.spawn`` supplies it as the
            first argument when DDP is enabled; it is ``None`` when the script
            runs without DDP.
        args: The merged parameter ``Namespace`` built in the ``__main__`` block.

    Currently a stub that does nothing and returns ``None``.
    """
    return None
|
|
|
|
if __name__ == "__main__":
    from nni.utils import merge_parameter

    tuner_params = nni.get_next_parameter()
    # logging formats lazily with %-style placeholders. A "{}" placeholder plus
    # an extra argument causes a logging formatting error when emitted.
    logger.debug("Generated hyperparameters: %s", tuner_params)

    # merge_parameter overrides entries of its first argument with the tuner's
    # values and returns it, as either an argparse.Namespace or a dict.
    # Namespace() only accepts keyword arguments, so a dict must be unpacked
    # (passing it positionally raises TypeError); a Namespace is used as-is.
    merged = merge_parameter(ret_args, tuner_params)
    combined_params = (
        merged if isinstance(merged, Namespace) else Namespace(**merged)
    )  # Namespaces have better ergonomics, notably a struct-like access syntax.
    logger.debug("Parameters: %s", combined_params)

    if combined_params.use_ddp:
        # Use DDP: spawn one worker process per rank.
        torch_mp.spawn(
            main,
            args=(combined_params,),  # rank supplied automatically as 1st param
            nprocs=combined_params.world_size,
        )
    else:
        # No DDP: run in the current process, without a rank.
        main(None, combined_params)