lemme cook alright?
This commit is contained in:
parent b6d2460060
commit 62df7464e4
9 changed files with 504 additions and 3 deletions
2 .gitignore vendored Normal file
@@ -0,0 +1,2 @@
baseline-experiments/
synchronous/
86 arguments.py Normal file
@@ -0,0 +1,86 @@
import argparse


parser = argparse.ArgumentParser(
    description="Reverse-perspective + (TransCrowd | CSRNet)"
)

# Reproducibility configuration ==============================================
parser.add_argument(
    "--seed", type=int, default=None, help="RNG seed"
)

# Data configuration =========================================================
parser.add_argument(
    "--train_dataset", type=str, default="ShanghaiA", help="Training dataset"
)
parser.add_argument(
    "--test_dataset", type=str, default="ShanghaiA", help="Evaluation dataset"
)
parser.add_argument(
    "--print_freq", type=int, default=1,
    help="Print evaluation data every <print_freq> training epochs"
)
parser.add_argument(
    "--start_epoch", type=int, default=0, help="Epoch to start training from"
)
parser.add_argument(
    "--save_path", type=str, default="./save/default/",
    help="Directory to save checkpoints in"
)

# Model configuration ========================================================
parser.add_argument(
    "--load_revnet_from", type=str, default=None,
    help="Pre-trained reverse-perspective model path"
)
parser.add_argument(
    "--load_csrnet_from", type=str, default=None,
    help="Pre-trained CSRNet model path"
)
parser.add_argument(
    "--load_transcrowd_from", type=str, default=None,
    help="Pre-trained TransCrowd model path"
)

# Optimizer configuration ====================================================
parser.add_argument(
    "--weight_decay", type=float, default=5e-4, help="Weight decay"
)
parser.add_argument(
    "--momentum", type=float, default=0.95, help="Momentum"
)
parser.add_argument(
    "--best_pred", type=float, default=1e5,
    help="Best prediction so far (MAE/MSE etc.)"
)

# Performance configuration ==================================================
parser.add_argument(
    "--batch_size", type=int, default=8, help="Number of images per batch"
)
parser.add_argument(
    "--epochs", type=int, default=250, help="Number of epochs to train"
)
parser.add_argument(
    # argparse cannot use typing.List as a type; take repeated ints instead
    "--gpus", type=int, nargs="+", default=[0],
    help="GPU IDs to be made available to the training runtime"
)

# Runtime configuration ======================================================
parser.add_argument(
    # type=bool would parse any non-empty string as True; use a flag instead
    "--use_ddp", action="store_true", default=False,
    help="Use DistributedDataParallel training"
)
parser.add_argument(
    "--ddp_world_size", type=int, default=1,
    help="DDP: Number of processes in the PyTorch process group"
)

# nni configuration ==========================================================
parser.add_argument(
    "--lr", type=float, default=1e-5, help="Learning rate"
)

args = parser.parse_args()
ret_args, _ = parser.parse_known_args()  # tolerates extra flags injected by NNI
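A quick sanity check for the parser above (a sketch; the flag values are made up, and it assumes arguments.py imports cleanly so `parser` is reachable):

# check_args.py -- exercise the --gpus nargs and --use_ddp flag forms
from arguments import parser

ns = parser.parse_args(["--gpus", "0", "1", "--batch_size", "16", "--use_ddp"])
assert ns.gpus == [0, 1]
assert ns.use_ddp is True
assert ns.batch_size == 16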
0 eval-transcrowd.py Normal file
59 model/csrnet.py Normal file
@@ -0,0 +1,59 @@
# Stolen from https://github.com/leeyeehoo/CSRNet-pytorch.git

import torch.nn as nn
import torch
from torchvision import models
from utils import save_net, load_net


class CSRNet(nn.Module):
    def __init__(self, load_weights=False):
        super(CSRNet, self).__init__()

        # Ref. 2018 paper
        self.seen = 0
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]  # 4-parallel, 1, 2, 2-then-4, 4 dilation rates

        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            mod = models.vgg16(pretrained=True)
            self._initialize_weights()
            # dict_items is not indexable in Python 3; materialize both as lists
            own_items = list(self.frontend.state_dict().items())
            vgg_items = list(mod.state_dict().items())
            for i in range(len(own_items)):
                own_items[i][1].data[:] = vgg_items[i][1].data[:]

    def forward(self, x):
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    if dilation:
        d_rate = 2
    else:
        d_rate = 1
    layers = []
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate, dilation=d_rate)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
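A smoke-test sketch for the module above (assumes the `utils` import in the file resolves; `load_weights=True` merely skips the VGG-16 download here):

import torch
from model.csrnet import CSRNet

net = CSRNet(load_weights=True)
density = net(torch.zeros(1, 3, 256, 256))
# Three MaxPool2d stages in the frontend give an 8x downsample:
print(density.shape)  # torch.Size([1, 1, 32, 32])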
model/reverse_perspective.py
@@ -90,7 +90,7 @@ class PerspectiveEstimator(nn.Module):
                 stride=conv_stride,
                 dilation=conv_dilation,
             ),  # (N, 1, H, W)
-            'revpers_avg_pooling0': nn.AdaptiveAvgPool2d(
+            'revpers_avg_pool0': nn.AdaptiveAvgPool2d(
                 output_size=(pool_capacity, 1)
             ),  # (N, 1, K, 1)
             # [?] Do we need to explicitly translate to (N, K) here?
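On the (N, 1, K, 1) question in the context above: no explicit module is needed; a one-liner in the forward pass suffices (`pooled` is a hypothetical name for the pooling output):

pooled = pooled.flatten(start_dim=1)  # (N, 1, K, 1) -> (N, 1*K*1) == (N, K)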
@@ -108,7 +108,7 @@ class PerspectiveEstimator(nn.Module):
             out = layer.forward(out)
 
         # Normalize in (0, 1]
-        F.relu_(out)  # in-place
+        F.relu(out, inplace=True)
         out = torch.exp(-out) + self.epsilon
 
         return out
@@ -116,4 +116,47 @@ class PerspectiveEstimator(nn.Module):
 # def unsupervised_loss(predictions, targets):
 
 # [TODO] We need a modified loss -- one that takes advantage of attention instead
 # of the feature map. I feel like they should work likewise, but who knows.
+# [XXX] No, forget it -- we are pre-training rev-perspective as the 2020 paper prescribes,
+# i.e., by using CSRNet.
+# Not sure which part the feature map is derived from. Maybe after the front-end?
+# In any case we can always just use the CSRNet output (the inferred density map) as the
+# feature map -- through which we compute, for each image:
+#     criterion = Variance([output.sum(axis=W) * effective_pixel_per_row])
+# In other cases we would sum over channels, i.e., over each feature map, i.e., over each
+# filter output. Not sure what a channel means in this case...
+def warped_output_loss(csrnet_pred):
+    N, H, W = csrnet_pred.shape  # .shape is an attribute, not a method
+
+
+def transform_coordinates(
+    img: torch.Tensor,  # (C, W, H)
+    factor: float,
+    in_place: bool = True
+):
+    dev_of_img = img.device
+
+    # Normalize X coords to [0, pi]
+    min_x = torch.Tensor([0., 0., 0.]).to(dev_of_img)
+    max_x = torch.Tensor([0., np.pi, 0.]).to(dev_of_img)
+    min_xdim = torch.min(img, dim=1, keepdim=True)[0]
+    max_xdim = torch.max(img, dim=1, keepdim=True)[0]
+    (img.sub_(min_xdim)
+        .div_(max_xdim - min_xdim)
+        .mul_(max_x - min_x)
+        .add_(min_x))
+
+    # Normalize Y coords to [0, 1]
+    min_y = torch.Tensor([0., 0., 0.]).to(dev_of_img)
+    max_y = torch.Tensor([0., 1., 0.]).to(dev_of_img)
+    min_ydim = torch.min(img, dim=2, keepdim=True)[0]
+    max_ydim = torch.max(img, dim=2, keepdim=True)[0]
+    (img.sub_(min_ydim)
+        .div_(max_ydim - min_ydim)
+        .mul_(max_y - min_y)
+        .add_(min_y))
+
+    # Do elliptical transformation
+    tmp = img.clone().detach()
+
+    pass
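A minimal sketch of the row-variance criterion the comments above describe, assuming `csrnet_pred` is an (N, H, W) density map and `effective_pixel_per_row` is a hypothetical (H,) per-row scale taken from the comment:

import torch

def warped_output_loss_sketch(csrnet_pred: torch.Tensor,
                              effective_pixel_per_row: torch.Tensor) -> torch.Tensor:
    # Sum each row of the density map over W, then rescale per row.
    row_counts = csrnet_pred.sum(dim=2) * effective_pixel_per_row  # (N, H)
    # If the warp removed the perspective distortion, every row should carry
    # a comparable scaled count, so penalize the variance across rows.
    return row_counts.var(dim=1).mean()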
0 model/revpers_csrnet.py Normal file
204 model/stn.py Normal file
@@ -0,0 +1,204 @@
# Stolen from the PyTorch STN tutorial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np


class StnNet(nn.Module):
    def __init__(self, input_size: torch.Size):
        super(StnNet, self).__init__()

        # Sanity check
        assert 5 > len(input_size) > 2  # single or batch ([N, ]C, H, W)
        if len(input_size) == 3:
            channels, height, width = input_size
        else:
            channels, height, width = input_size[1:]

        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization network
        self.localization_net = nn.Sequential(      # (N, C, H, W)
            nn.Conv2d(channels, 8, kernel_size=7),  # (N, 8, H-6, W-6)
            nn.MaxPool2d(2, stride=2),              # (N, 8, (H-6)/2, (W-6)/2)
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),        # (N, 10, ((H-6)/2)-4, ((W-6)/2)-4)
            nn.MaxPool2d(2, stride=2),              # (N, 10, (((H-6)/2)-4)/2, (((W-6)/2)-4)/2)
            nn.ReLU(True)
        )  # -> (N, 10, (((H-6)/2)-4)/2, (((W-6)/2)-4)/2)
        self._loc_net_out_shape = (
            10,
            (((height - 6) // 2) - 4) // 2,
            (((width - 6) // 2) - 4) // 2
        )  # TODO: PLEASE let me know if there are better ways of doing this...
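        # Alternative (untested sketch): infer the shape with one dummy
        # forward pass instead of hand-computing it:
        #   with torch.no_grad():
        #       _probe = self.localization_net(torch.zeros(1, channels, height, width))
        #       self._loc_net_out_shape = tuple(_probe.shape[1:])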

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(int(np.prod(self._loc_net_out_shape)), 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )  # -> (6,)

        # Initialize the weights/bias with the identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(
            torch.tensor(
                [1, 0, 0,
                 0, 1, 0],
                dtype=torch.float
            )
        )

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization_net(x)
        xs = xs.view(-1, int(np.prod(self._loc_net_out_shape)))  # -> (N, whatever)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)  # -> (N, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x

    def forward(self, x):
        # Transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = torch.utils.data.DataLoader(
        dataset=datasets.MNIST(
            root="./synchronous/",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((.1307, ), (.3081, ))
            ]),
        ),
        batch_size=64,
        shuffle=True,
        num_workers=4,
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset=datasets.MNIST(
            root="./synchronous/",
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((.1307, ), (.3081, ))
            ]),
        ),
        batch_size=64,
        shuffle=True,
        num_workers=4,
    )
    shape_of_input = next(iter(train_loader))[0].shape
    model = StnNet(shape_of_input).to(device)
    optimizer = optim.SGD(model.parameters(), lr=.01)

    def train(epoch):
        model.train()
        for i, (x_, t_) in enumerate(train_loader):
            # XXX: x_.shape == (N, C, H, W)
            # Inference
            x_, t_ = x_.to(device), t_.to(device)
            optimizer.zero_grad()
            y_ = model(x_)
            # Backprop
            l_ = F.nll_loss(y_, t_)
            l_.backward()
            optimizer.step()
            if i % 500 == 0:
                print("Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    i * len(x_),
                    len(train_loader.dataset),
                    100. * i / len(train_loader),
                    l_.item()
                ))

    def valid():
        with torch.no_grad():
            model.eval()
            valid_loss = 0
            correct = 0
            for x_, t_ in valid_loader:
                x_, t_ = x_.to(device), t_.to(device)
                y_ = model(x_)
                # Sum batch loss (size_average=False is deprecated)
                valid_loss += F.nll_loss(y_, t_, reduction="sum").item()
                pred = y_.max(1, keepdim=True)[1]
                correct += pred.eq(t_.view_as(pred)).sum().item()

            valid_loss /= len(valid_loader.dataset)
            print("\nValid set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n"
                  .format(
                      valid_loss, correct, len(valid_loader.dataset),
                      100. * correct / len(valid_loader.dataset)
                  ))

    def convert_image_np(inp):
        """Convert a Tensor to a numpy image."""
        inp = inp.numpy().transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1)
        return inp

    # We want to visualize the output of the spatial transformer layer:
    # after training, we visualize a batch of input images and
    # the corresponding batch transformed by the STN.
    def visualize_stn():
        with torch.no_grad():
            # Get a batch of validation data
            data = next(iter(valid_loader))[0].to(device)

            input_tensor = data.cpu()
            transformed_input_tensor = model.stn(data).cpu()

            in_grid = convert_image_np(
                torchvision.utils.make_grid(input_tensor))

            out_grid = convert_image_np(
                torchvision.utils.make_grid(transformed_input_tensor))

            # Plot the results side-by-side
            f, axarr = plt.subplots(1, 2)
            axarr[0].imshow(in_grid)
            axarr[0].set_title('Dataset Images')

            axarr[1].imshow(out_grid)
            axarr[1].set_title('Transformed Images')

    for epoch in range(1, 20 + 1):
        train(epoch)
        valid()

    # Visualize the STN transformation on some input batch
    visualize_stn()

    plt.ioff()
    plt.show()
101 train-revpers.py Normal file
@@ -0,0 +1,101 @@
from argparse import Namespace
import logging
import time

import timm
import torch
import torch.multiprocessing as torch_mp
from torch.utils.data import DataLoader
import nni
from nni.utils import merge_parameter

from model.csrnet import CSRNet
from model.reverse_perspective import PerspectiveEstimator
from arguments import args, ret_args

logger = logging.getLogger("train-revpers")


# We use 2 separate networks as opposed to 1 whole network --
# this is more flexible, as we only ever train one of them...
def gen_csrnet(pth_tar: str = None) -> CSRNet:
    if pth_tar is not None:
        model = CSRNet(load_weights=True)
        checkpoint = torch.load(pth_tar)
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    else:
        model = CSRNet(load_weights=False)
    return model


def gen_revpers(pth_tar: str = None, **kwargs) -> PerspectiveEstimator:
    model = PerspectiveEstimator(**kwargs)
    if pth_tar is not None:
        checkpoint = torch.load(pth_tar)
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model


def build_train_loader():
    pass


def build_valid_loader():
    pass


def train_one_epoch(
    train_loader: DataLoader,
    revpers_net: PerspectiveEstimator,
    csr_net: CSRNet,
    criterion,
    optimizer,
    scheduler,
    epoch: int,
    args: Namespace
):
    # Get the current learning rate
    curr_lr = optimizer.param_groups[0]["lr"]
    print("Epoch %d, processed %d samples, lr %.10f" %
          (epoch, epoch * len(train_loader.dataset), curr_lr)
          )

    # Set to train mode (perspective estimator only)
    revpers_net.train()
    end = time.time()

    # In one epoch, for each training sample
    for i, (fname, img, gt_count) in enumerate(train_loader):
        # Forward pass (revpers)
        img = img.cuda()
        out_revpers = revpers_net(img)
        # We need to perform the image transformation here...

        img = img.cpu()

        # Forward pass (csrnet -- do not train)
        img = img.cuda()
        out_csrnet = csr_net(img)
        # Loss w.r.t. revpers
        loss = criterion()

    pass


def valid_one_epoch():
    pass


def main(rank: int, args: Namespace):
    pass


if __name__ == "__main__":
    tuner_params = nni.get_next_parameter()
    logger.debug("Generated hyperparameters: %s", tuner_params)
    # merge_parameter already returns a Namespace, which has better ergonomics --
    # notably a struct-like access syntax.
    combined_params = merge_parameter(ret_args, tuner_params)
    logger.debug("Parameters: %s", combined_params)

    if combined_params.use_ddp:
        # Use DDP: spawn one process per rank
        torch_mp.spawn(
            main,
            args=(combined_params, ),  # rank is supplied automatically as the 1st param
            nprocs=combined_params.ddp_world_size,
        )
    else:
        # No DDP, run in the current process
        main(None, combined_params)
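Nothing in train_one_epoch above actually freezes CSRNet yet; a minimal sketch of the "do not train" half, reusing the names from that function (the optimizer wiring is an assumption, fed from the flags in arguments.py):

import torch

# Freeze CSRNet so only the perspective estimator trains.
for p in csr_net.parameters():
    p.requires_grad = False
csr_net.eval()

# Only the estimator's parameters go to the optimizer.
optimizer = torch.optim.SGD(
    revpers_net.parameters(),
    lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
)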
6 transform_img.py Normal file
@@ -0,0 +1,6 @@
# If we cannot get revpersnet running,
# we still ought to do some sort of information-preserving perspective transformation,
# e.g., a randomized transformation,
# and let TransCrowd crunch through these transformed images instead.
# After training, we obtain the attention map and put it in our paper.
# I just want to get things done...
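A minimal sketch of that fallback using torchvision's built-in random perspective transform (the distortion_scale value is a guess):

import torchvision.transforms as T

# Randomized perspective warp applied to each image before TransCrowd.
random_perspective = T.RandomPerspective(distortion_scale=0.4, p=1.0)
warped = random_perspective(img)  # img: PIL image or (C, H, W) tensor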