More working than not
Not sure if validation works, call it a day
parent 4a03211c83
commit 12aabb0d3f

10 changed files with 116 additions and 105 deletions

.gitignore (vendored) | 3
@@ -1,3 +1,4 @@
 baseline-experiments/
 synchronous/
 npydata/
+**/__pycache__/**

_ShanghaiA-train.sh (new file) | 10
@@ -0,0 +1,10 @@
+#!/bin/sh
+#SBATCH -N 1
+#SBATCH -n 1
+#SBATCH --partition=Teach-Standard
+#SBATCH --gres=gpu:6
+#SBATCH --mem=24000
+#SBATCH --time=3-00:00:00
+
+python train.py \
+    --model='stn'

@@ -12,13 +12,13 @@ parser.add_argument(
 
 # Data configuration =========================================================
 parser.add_argument(
-    "--worker", type=int, default=4, help="Number of data loader processes"
+    "--workers", type=int, default=4, help="Number of data loader processes"
 )
 parser.add_argument(
     "--train_dataset", type=str, default="ShanghaiA", help="Training dataset"
 )
 parser.add_argument(
-    "--test_dataset", type=str, default="ShanghaiA", help="Evaluation dataset"
+    "--eval_dataset", type=str, default="ShanghaiA", help="Evaluation dataset"
 )
 parser.add_argument(
     "--print_freq", type=int, default=1,
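
For reference, a minimal sketch of how the renamed options parse. Only the three argument definitions above are taken from this hunk; the surrounding parser setup and the example command line are assumptions, and dataset.py below switches to args.train_dataset accordingly.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--workers", type=int, default=4, help="Number of data loader processes"
    )
    parser.add_argument(
        "--train_dataset", type=str, default="ShanghaiA", help="Training dataset"
    )
    parser.add_argument(
        "--eval_dataset", type=str, default="ShanghaiA", help="Evaluation dataset"
    )

    # Attribute names follow the flag names, so downstream code reads
    # args.workers, args.train_dataset and args.eval_dataset after this rename.
    args = parser.parse_args(["--workers", "8", "--eval_dataset", "ShanghaiB"])
    print(args.workers, args.train_dataset, args.eval_dataset)  # 8 ShanghaiA ShanghaiB
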

dataset.py | 109
@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F
 from torch.utils.data import Dataset
 from torchvision import datasets, transforms
-from image import Image
+from PIL import Image
 import numpy as np
 import numbers
 import h5py
@@ -15,23 +15,23 @@ import cv2
 
 def unpack_npy_data(args: Namespace):
     """Unpack npy data aas np lists at hard-coded paths wrt cwd."""
-    if args.dataset == "ShanghaiA":
+    if args.train_dataset == "ShanghaiA":
         train_file = "./npydata/ShanghaiA_train.npy"
         test_file = "./npydata/ShanghaiA_test.npy"
-    elif args.dataset == "ShanghaiB":
+    elif args.train_dataset == "ShanghaiB":
         train_file = "./npydata/ShanghaiB_train.npy"
         test_file = "./npydata/ShanghaiB_test.npy"
-    elif args.dataset == "UCF_QNRF":
+    elif args.train_dataset == "UCF_QNRF":
         train_file = "./npydata/qnrf_train.npy"
         test_file = "./npydata/qnrf_test.npy"
-    elif args.dataset == "JHU":
+    elif args.train_dataset == "JHU":
         train_file = "./npydata/jhu_train.npy"
         test_file = "./npydata/jhu_test.npy"
-    elif args.dataset == "NWPU":
+    elif args.train_dataset == "NWPU":
         train_file = "./npydata/nwpu_train.npy"
         test_file = "./npydata/nwpu_test.npy"
 
-    assert any([fdir is not None for fdir in [train_file, test_file]])
+    assert all([fdir is not None for fdir in [train_file, test_file]])
 
     with open(train_file, "rb") as fd:
         train_list = np.load(fd).tolist()
@@ -39,9 +39,9 @@ def unpack_npy_data(args: Namespace):
         test_list = np.load(fd).tolist()
 
     print("[dataset] Loaded \"{}\": train: {} | test: {}".format(
-        args.dataset, len(train_list), len(test_list)
+        args.train_dataset, len(train_list), len(test_list)
    ))
 
     return train_list, test_list
 
 
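
In unpack_npy_data above, the assertion tightens from any to all: any passes as long as at least one of the two paths is set, while all requires both the train and test .npy files to be resolved. A one-line illustration of the difference, with a deliberately missing test path:

    train_file, test_file = "./npydata/ShanghaiA_train.npy", None
    print(any(f is not None for f in (train_file, test_file)))  # True  -- old check passes
    print(all(f is not None for f in (train_file, test_file)))  # False -- new check catches it
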
@@ -55,28 +55,28 @@ def convert_data(train_list, args: Namespace, train: bool):
         while True:
             try:
                 gt_file = h5py.File(gt_path)
-                gt_count = np.asarray(gt_file["gt_count"])
+                kpoint = np.asarray(gt_file["kpoint"])
                 break
             except OSError:
                 print("[dataset] Load error on \'{}\'", img_path)
 
         img = img.copy()
-        gt_count = gt_count.copy()
+        kpoint = kpoint.copy()
 
-        return img, gt_count
+        return img, kpoint
 
 
     print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))
     data_keys = []
     for i in range(len(train_list)):
         img_path = train_list[i]
         fname = os.path.basename(img_path)
-        img, gt_count = _load_data(img_path, train, args)
+        img, kpoint = _load_data(img_path, train, args)
 
         pack = {
             "img": img,
-            "gt_count": gt_count,
+            "kpoint": kpoint,
             "fname": fname,
         }
         data_keys.append(pack)
@@ -85,16 +85,16 @@ def convert_data(train_list, args: Namespace, train: bool):
 
 class ListDataset(Dataset):
     def __init__(
         self,
         root,
         shape = None,
         shuffle: bool = True,
         transform = None,
         train: bool = False,
         seen: int = 0,
         batch_size: int = 1,
         nr_workers: int = 4,
         args: Namespace = None,
     ):
         if train:
             random.shuffle(root)
@@ -109,25 +109,25 @@ class ListDataset(Dataset):
         self.nr_workers = nr_workers
         self.args = args
 
 
     def __len__(self):
         return self.nr_samples
 
 
     def __getitem__(self, index):
         assert index <= len(self), "Index out-of-bounds"
 
         fname = self.lines[index]["fname"]
         img = self.lines[index]["img"]
-        gt_count = self.lines[index]["gt_count"]
+        kpoint = self.lines[index]["kpoint"]
 
         # Data augmentation
         if self.train:
             if random.random() > .5:
                 img = img.transpose(Image.FLIP_LEFT_RIGHT)
                 # XXX: do random noise?
 
-            gt_count = gt_count.copy()
+            kpoint = kpoint.copy()
             img = img.copy()
 
         # Custom transform
@@ -135,28 +135,29 @@ class ListDataset(Dataset):
             img = self.transform(img)
 
 
-        if self.train:
-            return fname, img, gt_count
-        else:
-            device = args.device
-            height, width = img.shape[1], img.shape[2]
-            m = int(width / 384)
-            n = int(height / 384)
-            for i in range(m):
-                for j in range(n):
-                    if i == 0 and j == 0:
-                        img_ret = img[
-                            :, # C
-                            j * 384 : 384 * (j + 1), # H
-                            i * 384 : 384 * (i + 1), # W
-                        ].to(device).unsqueeze(0)
-                    else:
-                        cropped = img[
-                            :, # C
-                            j * 384 : 384 * (j + 1), # H
-                            i * 384 : 384 * (i + 1), # W
-                        ].to(device).unsqueeze(0)
-                        img_ret = torch.cat([img_ret, cropped], 0).to(device)
-            return fname, img_ret, gt_count
+        return fname, img, kpoint
+
+        # if self.train:
+        #     return fname, img, gt_count
+        # else:
+        #     device = args.device
+        #     height, width = img.shape[1], img.shape[2]
+        #     m = int(width / 384)
+        #     n = int(height / 384)
+        #     for i in range(m):
+        #         for j in range(n):
+        #             if i == 0 and j == 0:
+        #                 img_ret = img[
+        #                     :, # C
+        #                     j * 384 : 384 * (j + 1), # H
+        #                     i * 384 : 384 * (i + 1), # W
+        #                 ].to(device).unsqueeze(0)
+        #             else:
+        #                 cropped = img[
+        #                     :, # C
+        #                     j * 384 : 384 * (j + 1), # H
+        #                     i * 384 : 384 * (i + 1), # W
+        #                 ].to(device).unsqueeze(0)
+        #                 img_ret = torch.cat([img_ret, cropped], 0).to(device)
+        #         return fname, img_ret, gt_count
 
 

@@ -46,7 +46,7 @@ class SquareCropTransformLayer(nn.Module):
             torch.tensor_split(
                 torch.cat(
                     torch.tensor_split(
-                        t_,
+                        kpoints_,
                        h_split_count,
                         dim=1
                     )
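
The hunk above only renames the tensor being split inside SquareCropTransformLayer. For intuition, here is a standalone sketch (not the project's actual layer) of the tiling that the commented-out __getitem__ loop used to do by hand: splitting a (C, 1152, 768) image into a batch of 384x384 crops with torch.tensor_split. The helper name and the stacking order are assumptions for illustration.

    import torch

    def tile_384(img: torch.Tensor, tile: int = 384) -> torch.Tensor:
        """Split a (C, H, W) tensor into a (N, C, tile, tile) batch of crops."""
        _, h, w = img.shape
        rows = torch.tensor_split(img, h // tile, dim=1)                     # split along H
        cells = [torch.tensor_split(row, w // tile, dim=2) for row in rows]  # then along W
        return torch.stack([c for row in cells for c in row], dim=0)

    crops = tile_384(torch.zeros(3, 1152, 768))
    print(crops.shape)  # torch.Size([6, 3, 384, 384]) -- a 3 x 2 grid of crops
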

@@ -24,6 +24,7 @@ class STNet(nn.Module):
         _dummy_size_ = input_size
 
         # shape checking
+        print("STN: dummy_size {}".format(_dummy_size_))
         _dummy_x_ = torch.zeros(_dummy_size_)
 
         # (3.1) Spatial transformer localization-network
@@ -81,6 +82,7 @@ class STNet(nn.Module):
 
 
     def forward(self, x, t):
+        # print("STN: {} | {}".format(x.shape, t.shape))
        # transform the input, do nothing else
         return self.stn(x, t)
 

@@ -25,16 +25,8 @@ from .stn import STNet
 from .glue import SquareCropTransformLayer
 
 class VisionTransformerGAP(VisionTransformer):
-    # [XXX] It might be a bad idea to use vision transformer for small datasets.
-    # ref: ViT paper -- "transformers lack some of the inductive biases inherent
-    # to CNNs, such as translation equivariance and locality".
-    # convolution is specifically equivariant in translation (linear and
-    # shift-equivariant), specifically.
-    # tl;dr: CNNs might perform better for small datasets AND should perform
-    # better for embedded systems.
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, img_size: int, *args, **kwargs):
+        super().__init__(img_size=img_size, *args, **kwargs)
         num_patches = self.patch_embed.num_patches
 
         # That {p_1, p_2, ..., p_N} pos embedding
@@ -45,17 +37,17 @@ class VisionTransformerGAP(VisionTransformer):
         trunc_normal_(self.pos_embed, std=.02)
 
         # The "regression head"
-        self.output1 = nn.ModuleDict({
-            "output1.relu0": nn.ReLU(),
-            "output1.linear0": nn.Linear(in_features=6912 * 4, out_features=128),
-            "output1.relu1": nn.ReLU(),
-            "output1.dropout0": nn.Dropout(p=0.5),
-            "output1.linear1": nn.Linear(in_features=128, out_features=1),
-        })
+        self.output1 = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(in_features=6912 * 4, out_features=128),
+            nn.ReLU(),
+            nn.Dropout(p=0.5),
+            nn.Linear(in_features=128, out_features=1),
+        )
         self.output1.apply(self._init_weights)
 
-        # Attention map, which we use to train
-        self.attention_map = torch.Tensor(np.zeros((1152, 768))) # (3, 2) resized imgs
+        # glue layer -- since we delay image cropping here
+        self.glue = SquareCropTransformLayer(img_size)
 
     def forward_features(self, x):
         B = x.shape[0]
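
The regression head changes from nn.ModuleDict to nn.Sequential. A ModuleDict only registers named sub-modules and defines no forward of its own, so calling self.output1(x) directly would raise; a Sequential applies its children in order, which is what the self.output1(x) call in forward (next hunk) relies on. A minimal check of the new head in isolation -- layer sizes are taken from the diff, the batch size is arbitrary:

    import torch
    import torch.nn as nn

    head = nn.Sequential(
        nn.ReLU(),
        nn.Linear(in_features=6912 * 4, out_features=128),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(in_features=128, out_features=1),
    )

    x = torch.randn(2, 6912 * 4)  # (batch, flattened features)
    print(head(x).shape)          # torch.Size([2, 1]) -- one count per image
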
@@ -90,25 +82,26 @@ class VisionTransformerGAP(VisionTransformer):
 
         return x
 
-    def forward(self, x):
+    def forward(self, x, t):
+        with torch.no_grad():
+            x, t = self.glue(x, t)
+        print(f"Glue: {x.shape} | {t.shape}")
         x = self.forward_features(x)  # Compute encoding
         x = F.adaptive_avg_pool1d(x, (48))
         x = x.view(x.shape[0], -1)  # Move data for regression head
         # Resized to ???
         x = self.output1(x)  # Regression head
-        return x
+        return x, t
 
 
 class STNet_VisionTransformerGAP(VisionTransformerGAP):
-    def __init__(self, img_shape: torch.Size, *args, **kwargs):
-        super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
+    def __init__(self, img_shape: torch.Size, img_size: int, *args, **kwargs):
+        super(STNet_VisionTransformerGAP, self).__init__(img_size, *args, **kwargs)
         self.stnet = STNet(img_shape)
-        self.glue = SquareCropTransformLayer(img_size)
 
     def forward(self, x, t):
         x, t = self.stnet(x, t)
-        x, t = self.glue(x, t)
-        return super(STNet_VisionTransformerGAP, self).forward(x), t
+        return super(STNet_VisionTransformerGAP, self).forward(x, t)
 
 
 @register_model
@@ -131,7 +124,7 @@ def base_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
 @register_model
 def stn_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
     model = STNet_VisionTransformerGAP(
-        img_shape=torch.Size((3, 384, 384)),
+        img_shape=torch.Size((3, 1152, 768)),
         img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
         mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
         **kwargs

@@ -1,17 +1,19 @@
 """
 The TransCrowd paper lists ShanghaiTech dataset as from here:
 
 https://drive.google.com/file/d/1CkYppr_IqR1s6wi53l2gKoGqm7LkJ-Lc/view
 
 Alternatively, you could prob. download from here:
 
 https://www.kaggle.com/datasets/tthien/shanghaitech?resource=download
 
+It seems the directories are all wrong, though.
+
 After downloading, execute:
 
 $ unzip <downloaded-zip-file> -d <repo-dir>/synchronous/dataset/
 
 To unzip the dataset correctly prior to running this script.
 """
 
 import os
@@ -59,17 +61,24 @@ def pre_dataset_sh():
         gt_data = mat["image_info"][0][0][0][0][0]
 
         # Resize to 1152x768
+        is_portrait = False
         if img_data.shape[1] >= img_data.shape[0]:  # landscape
             rate_x = 1152.0 / img_data.shape[1]
             rate_y = 768.0 / img_data.shape[0]
         else:  # portrait
             rate_x = 768.0 / img_data.shape[1]
             rate_y = 1152.0 / img_data.shape[0]
+            is_portrait = True
 
         img_data = cv2.resize(img_data, (0, 0), fx=rate_x, fy=rate_y)
         gt_data[:, 0] = gt_data[:, 0] * rate_x
         gt_data[:, 1] = gt_data[:, 1] * rate_y
 
+        if is_portrait:
+            print("Portrait img: \'{}\' -- rotating 90 deg clockwise...".format(img_path))
+            img_data = cv2.rotate(img_data, cv2.ROTATE_90_CLOCKWISE)
+
+
         # Compute 0/1 counts from density map
         kpoint = np.zeros((img_data.shape[0], img_data.shape[1]))
         for i in range(len(gt_data)):

train.py | 23
@@ -29,7 +29,7 @@ def setup_process_group(
     master_addr: str = "localhost",
     master_port: Optional[np.ushort] = None
 ):
-    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_ADDR"] = master_addr
     os.environ["MASTER_PORT"] = (
         str(random.randint(40000, 65545))
         if master_port is None
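
setup_process_group only exports MASTER_ADDR and MASTER_PORT; those are the variables the default env:// rendezvous of torch.distributed reads when the process group is initialised. A minimal sketch of how such a setup is typically consumed -- the backend, rank and world size here are assumptions, not taken from this diff:

    import os
    import torch.distributed as dist

    def setup(rank: int, world_size: int, addr: str = "localhost", port: int = 29500):
        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)
        # The default env:// init method picks up MASTER_ADDR / MASTER_PORT set above.
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    def cleanup():
        dist.destroy_process_group()
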
@@ -121,7 +121,6 @@ def worker(rank: int, args: Namespace):
     train_loader = build_train_loader(train_data, args)
     test_loader = build_test_loader(test_data, args)
 
-
     # Instantiate model
     if args.model == "stn":
         model = stn_patch16_384_gap(args.pth_tar).to(device)
@@ -229,11 +228,14 @@ def train_one_epoch(
     model.train()
 
     # In one epoch, for each training sample
-    for i, (fname, img, gt_count) in enumerate(train_loader):
+    for i, (fname, img, kpoint) in enumerate(train_loader):
+        kpoint = kpoint.type(torch.FloatTensor)
+        print("Training: img {} | kpoint {}".format(img.shape, kpoint.shape))
         # fpass
         img = img.to(device)
-        out = model(img)
-        gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
+        kpoint = kpoint.to(device)
+        out, gt_count = model(img, kpoint)
+        # gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
 
         # loss
         loss = criterion(out, gt_count)
@@ -288,7 +290,7 @@ def valid_one_epoch(test_loader, model, device, args):
     mae = mae * 1.0 / (len(test_loader) * batch_size)
     mse = np.sqrt(mse / (len(test_loader)) * batch_size)
 
-    nni.report_intermediate_result(mae)
+    # nni.report_intermediate_result(mae)
     print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
         mae=mae, mse=mse
     ))
@@ -297,11 +299,10 @@ def valid_one_epoch(test_loader, model, device, args):
 
 
 if __name__ == "__main__":
-    tuner_params = nni.get_next_parameter()
-    logger.debug("Generated hyperparameters: {}", tuner_params)
-    combined_params = Namespace(
-        nni.utils.merge_parameter(ret_args, tuner_params)
-    )  # Namespaces have better ergonomics, notably a struct-like access syntax.
+    # tuner_params = nni.get_next_parameter()
+    # logger.debug("Generated hyperparameters: {}", tuner_params)
+    # combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
+    combined_params = args
     logger.debug("Parameters: {}", combined_params)
 
     if combined_params.use_ddp:

@@ -1,6 +0,0 @@
-# If we cannot get revpersnet running,
-# we still ought to do some sort of information-preserving perspective transformation
-# e.g., randomized transformation
-# and let transcrowd to crunch through these transformed image instead.
-# After training, we obtain the attention map and put it in our paper.
-# I just want to get things done...