More working than not
Not sure if validation works, call it a day
This commit is contained in:
parent
4a03211c83
commit
12aabb0d3f
10 changed files with 116 additions and 105 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -1,3 +1,4 @@
|
|||
baseline-experiments/
|
||||
synchronous/
|
||||
npydata/
|
||||
**/__pycache__/**
|
||||
10
_ShanghaiA-train.sh
Normal file
10
_ShanghaiA-train.sh
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
#!/bin/sh
|
||||
#SBATCH -N 1
|
||||
#SBATCH -n 1
|
||||
#SBATCH --partition=Teach-Standard
|
||||
#SBATCH --gres=gpu:6
|
||||
#SBATCH --mem=24000
|
||||
#SBATCH --time=3-00:00:00
|
||||
|
||||
python train.py \
|
||||
--model='stn'
|
||||
|
|
@ -12,13 +12,13 @@ parser.add_argument(
|
|||
|
||||
# Data configuration =========================================================
|
||||
parser.add_argument(
|
||||
"--worker", type=int, default=4, help="Number of data loader processes"
|
||||
"--workers", type=int, default=4, help="Number of data loader processes"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_dataset", type=str, default="ShanghaiA", help="Training dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_dataset", type=str, default="ShanghaiA", help="Evaluation dataset"
|
||||
"--eval_dataset", type=str, default="ShanghaiA", help="Evaluation dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_freq", type=int, default=1,
|
||||
|
|
|
|||
77
dataset.py
77
dataset.py
|
|
@ -6,7 +6,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import datasets, transforms
|
||||
from image import Image
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import numbers
|
||||
import h5py
|
||||
|
|
@ -15,23 +15,23 @@ import cv2
|
|||
|
||||
def unpack_npy_data(args: Namespace):
|
||||
"""Unpack npy data aas np lists at hard-coded paths wrt cwd."""
|
||||
if args.dataset == "ShanghaiA":
|
||||
if args.train_dataset == "ShanghaiA":
|
||||
train_file = "./npydata/ShanghaiA_train.npy"
|
||||
test_file = "./npydata/ShanghaiA_test.npy"
|
||||
elif args.dataset == "ShanghaiB":
|
||||
elif args.train_dataset == "ShanghaiB":
|
||||
train_file = "./npydata/ShanghaiB_train.npy"
|
||||
test_file = "./npydata/ShanghaiB_test.npy"
|
||||
elif args.dataset == "UCF_QNRF":
|
||||
elif args.train_dataset == "UCF_QNRF":
|
||||
train_file = "./npydata/qnrf_train.npy"
|
||||
test_file = "./npydata/qnrf_test.npy"
|
||||
elif args.dataset == "JHU":
|
||||
elif args.train_dataset == "JHU":
|
||||
train_file = "./npydata/jhu_train.npy"
|
||||
test_file = "./npydata/jhu_test.npy"
|
||||
elif args.dataset == "NWPU":
|
||||
elif args.train_dataset == "NWPU":
|
||||
train_file = "./npydata/nwpu_train.npy"
|
||||
test_file = "./npydata/nwpu_test.npy"
|
||||
|
||||
assert any([fdir is not None for fdir in [train_file, test_file]])
|
||||
assert all([fdir is not None for fdir in [train_file, test_file]])
|
||||
|
||||
with open(train_file, "rb") as fd:
|
||||
train_list = np.load(fd).tolist()
|
||||
|
|
@ -39,7 +39,7 @@ def unpack_npy_data(args: Namespace):
|
|||
test_list = np.load(fd).tolist()
|
||||
|
||||
print("[dataset] Loaded \"{}\": train: {} | test: {}".format(
|
||||
args.dataset, len(train_list), len(test_list)
|
||||
args.train_dataset, len(train_list), len(test_list)
|
||||
))
|
||||
|
||||
return train_list, test_list
|
||||
|
|
@ -55,15 +55,15 @@ def convert_data(train_list, args: Namespace, train: bool):
|
|||
while True:
|
||||
try:
|
||||
gt_file = h5py.File(gt_path)
|
||||
gt_count = np.asarray(gt_file["gt_count"])
|
||||
kpoint = np.asarray(gt_file["kpoint"])
|
||||
break
|
||||
except OSError:
|
||||
print("[dataset] Load error on \'{}\'", img_path)
|
||||
|
||||
img = img.copy()
|
||||
gt_count = gt_count.copy()
|
||||
kpoint = kpoint.copy()
|
||||
|
||||
return img, gt_count
|
||||
return img, kpoint
|
||||
|
||||
|
||||
print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))
|
||||
|
|
@ -71,11 +71,11 @@ def convert_data(train_list, args: Namespace, train: bool):
|
|||
for i in range(len(train_list)):
|
||||
img_path = train_list[i]
|
||||
fname = os.path.basename(img_path)
|
||||
img, gt_count = _load_data(img_path, train, args)
|
||||
img, kpoint = _load_data(img_path, train, args)
|
||||
|
||||
pack = {
|
||||
"img": img,
|
||||
"gt_count": gt_count,
|
||||
"kpoint": kpoint,
|
||||
"fname": fname,
|
||||
}
|
||||
data_keys.append(pack)
|
||||
|
|
@ -119,7 +119,7 @@ class ListDataset(Dataset):
|
|||
|
||||
fname = self.lines[index]["fname"]
|
||||
img = self.lines[index]["img"]
|
||||
gt_count = self.lines[index]["gt_count"]
|
||||
kpoint = self.lines[index]["kpoint"]
|
||||
|
||||
# Data augmentation
|
||||
if self.train:
|
||||
|
|
@ -127,7 +127,7 @@ class ListDataset(Dataset):
|
|||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
# XXX: do random noise?
|
||||
|
||||
gt_count = gt_count.copy()
|
||||
kpoint = kpoint.copy()
|
||||
img = img.copy()
|
||||
|
||||
# Custom transform
|
||||
|
|
@ -135,28 +135,29 @@ class ListDataset(Dataset):
|
|||
img = self.transform(img)
|
||||
|
||||
|
||||
if self.train:
|
||||
return fname, img, gt_count
|
||||
else:
|
||||
device = args.device
|
||||
height, width = img.shape[1], img.shape[2]
|
||||
m = int(width / 384)
|
||||
n = int(height / 384)
|
||||
for i in range(m):
|
||||
for j in range(n):
|
||||
if i == 0 and j == 0:
|
||||
img_ret = img[
|
||||
:, # C
|
||||
j * 384 : 384 * (j + 1), # H
|
||||
i * 384 : 384 * (i + 1), # W
|
||||
].to(device).unsqueeze(0)
|
||||
else:
|
||||
cropped = img[
|
||||
:, # C
|
||||
j * 384 : 384 * (j + 1), # H
|
||||
i * 384 : 384 * (i + 1), # W
|
||||
].to(device).unsqueeze(0)
|
||||
img_ret = torch.cat([img_ret, cropped], 0).to(device)
|
||||
return fname, img_ret, gt_count
|
||||
return fname, img, kpoint
|
||||
|
||||
# if self.train:
|
||||
# return fname, img, gt_count
|
||||
# else:
|
||||
# device = args.device
|
||||
# height, width = img.shape[1], img.shape[2]
|
||||
# m = int(width / 384)
|
||||
# n = int(height / 384)
|
||||
# for i in range(m):
|
||||
# for j in range(n):
|
||||
# if i == 0 and j == 0:
|
||||
# img_ret = img[
|
||||
# :, # C
|
||||
# j * 384 : 384 * (j + 1), # H
|
||||
# i * 384 : 384 * (i + 1), # W
|
||||
# ].to(device).unsqueeze(0)
|
||||
# else:
|
||||
# cropped = img[
|
||||
# :, # C
|
||||
# j * 384 : 384 * (j + 1), # H
|
||||
# i * 384 : 384 * (i + 1), # W
|
||||
# ].to(device).unsqueeze(0)
|
||||
# img_ret = torch.cat([img_ret, cropped], 0).to(device)
|
||||
# return fname, img_ret, gt_count
|
||||
|
||||
|
|
@ -46,7 +46,7 @@ class SquareCropTransformLayer(nn.Module):
|
|||
torch.tensor_split(
|
||||
torch.cat(
|
||||
torch.tensor_split(
|
||||
t_,
|
||||
kpoints_,
|
||||
h_split_count,
|
||||
dim=1
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ class STNet(nn.Module):
|
|||
_dummy_size_ = input_size
|
||||
|
||||
# shape checking
|
||||
print("STN: dummy_size {}".format(_dummy_size_))
|
||||
_dummy_x_ = torch.zeros(_dummy_size_)
|
||||
|
||||
# (3.1) Spatial transformer localization-network
|
||||
|
|
@ -81,6 +82,7 @@ class STNet(nn.Module):
|
|||
|
||||
|
||||
def forward(self, x, t):
|
||||
# print("STN: {} | {}".format(x.shape, t.shape))
|
||||
# transform the input, do nothing else
|
||||
return self.stn(x, t)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,16 +25,8 @@ from .stn import STNet
|
|||
from .glue import SquareCropTransformLayer
|
||||
|
||||
class VisionTransformerGAP(VisionTransformer):
|
||||
# [XXX] It might be a bad idea to use vision transformer for small datasets.
|
||||
# ref: ViT paper -- "transformers lack some of the inductive biases inherent
|
||||
# to CNNs, such as translation equivariance and locality".
|
||||
# convolution is specifically equivariant in translation (linear and
|
||||
# shift-equivariant), specifically.
|
||||
# tl;dr: CNNs might perform better for small datasets AND should perform
|
||||
# better for embedded systems.
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, img_size: int, *args, **kwargs):
|
||||
super().__init__(img_size=img_size, *args, **kwargs)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
# That {p_1, p_2, ..., p_N} pos embedding
|
||||
|
|
@ -45,17 +37,17 @@ class VisionTransformerGAP(VisionTransformer):
|
|||
trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
# The "regression head"
|
||||
self.output1 = nn.ModuleDict({
|
||||
"output1.relu0": nn.ReLU(),
|
||||
"output1.linear0": nn.Linear(in_features=6912 * 4, out_features=128),
|
||||
"output1.relu1": nn.ReLU(),
|
||||
"output1.dropout0": nn.Dropout(p=0.5),
|
||||
"output1.linear1": nn.Linear(in_features=128, out_features=1),
|
||||
})
|
||||
self.output1 = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(in_features=6912 * 4, out_features=128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(p=0.5),
|
||||
nn.Linear(in_features=128, out_features=1),
|
||||
)
|
||||
self.output1.apply(self._init_weights)
|
||||
|
||||
# Attention map, which we use to train
|
||||
self.attention_map = torch.Tensor(np.zeros((1152, 768))) # (3, 2) resized imgs
|
||||
# glue layer -- since we delay image cropping here
|
||||
self.glue = SquareCropTransformLayer(img_size)
|
||||
|
||||
def forward_features(self, x):
|
||||
B = x.shape[0]
|
||||
|
|
@ -90,25 +82,26 @@ class VisionTransformerGAP(VisionTransformer):
|
|||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, t):
|
||||
with torch.no_grad():
|
||||
x, t = self.glue(x, t)
|
||||
print(f"Glue: {x.shape} | {t.shape}")
|
||||
x = self.forward_features(x) # Compute encoding
|
||||
x = F.adaptive_avg_pool1d(x, (48))
|
||||
x = x.view(x.shape[0], -1) # Move data for regression head
|
||||
# Resized to ???
|
||||
x = self.output1(x) # Regression head
|
||||
return x
|
||||
return x, t
|
||||
|
||||
|
||||
class STNet_VisionTransformerGAP(VisionTransformerGAP):
|
||||
def __init__(self, img_shape: torch.Size, *args, **kwargs):
|
||||
super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
|
||||
def __init__(self, img_shape: torch.Size, img_size: int, *args, **kwargs):
|
||||
super(STNet_VisionTransformerGAP, self).__init__(img_size, *args, **kwargs)
|
||||
self.stnet = STNet(img_shape)
|
||||
self.glue = SquareCropTransformLayer(img_size)
|
||||
|
||||
def forward(self, x, t):
|
||||
x, t = self.stnet(x, t)
|
||||
x, t = self.glue(x, t)
|
||||
return super(STNet_VisionTransformerGAP, self).forward(x), t
|
||||
return super(STNet_VisionTransformerGAP, self).forward(x, t)
|
||||
|
||||
|
||||
@register_model
|
||||
|
|
@ -131,7 +124,7 @@ def base_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
|
|||
@register_model
|
||||
def stn_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
|
||||
model = STNet_VisionTransformerGAP(
|
||||
img_shape=torch.Size((3, 384, 384)),
|
||||
img_shape=torch.Size((3, 1152, 768)),
|
||||
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
|
||||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
**kwargs
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ Alternatively, you could prob. download from here:
|
|||
|
||||
https://www.kaggle.com/datasets/tthien/shanghaitech?resource=download
|
||||
|
||||
It seems the directories are all wrong, though.
|
||||
|
||||
After downloading, execute:
|
||||
|
||||
$ unzip <downloaded-zip-file> -d <repo-dir>/synchronous/dataset/
|
||||
|
|
@ -59,17 +61,24 @@ def pre_dataset_sh():
|
|||
gt_data = mat["image_info"][0][0][0][0][0]
|
||||
|
||||
# Resize to 1152x768
|
||||
is_portrait = False
|
||||
if img_data.shape[1] >= img_data.shape[0]: # landscape
|
||||
rate_x = 1152.0 / img_data.shape[1]
|
||||
rate_y = 768.0 / img_data.shape[0]
|
||||
else: # portrait
|
||||
rate_x = 768.0 / img_data.shape[1]
|
||||
rate_y = 1152.0 / img_data.shape[0]
|
||||
is_portrait = True
|
||||
|
||||
img_data = cv2.resize(img_data, (0, 0), fx=rate_x, fy=rate_y)
|
||||
gt_data[:, 0] = gt_data[:, 0] * rate_x
|
||||
gt_data[:, 1] = gt_data[:, 1] * rate_y
|
||||
|
||||
if is_portrait:
|
||||
print("Portrait img: \'{}\' -- rotating 90 deg clockwise...".format(img_path))
|
||||
img_data = cv2.rotate(img_data, cv2.ROTATE_90_CLOCKWISE)
|
||||
|
||||
|
||||
# Compute 0/1 counts from density map
|
||||
kpoint = np.zeros((img_data.shape[0], img_data.shape[1]))
|
||||
for i in range(len(gt_data)):
|
||||
23
train.py
23
train.py
|
|
@ -29,7 +29,7 @@ def setup_process_group(
|
|||
master_addr: str = "localhost",
|
||||
master_port: Optional[np.ushort] = None
|
||||
):
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_ADDR"] = master_addr
|
||||
os.environ["MASTER_PORT"] = (
|
||||
str(random.randint(40000, 65545))
|
||||
if master_port is None
|
||||
|
|
@ -121,7 +121,6 @@ def worker(rank: int, args: Namespace):
|
|||
train_loader = build_train_loader(train_data, args)
|
||||
test_loader = build_test_loader(test_data, args)
|
||||
|
||||
|
||||
# Instantiate model
|
||||
if args.model == "stn":
|
||||
model = stn_patch16_384_gap(args.pth_tar).to(device)
|
||||
|
|
@ -229,11 +228,14 @@ def train_one_epoch(
|
|||
model.train()
|
||||
|
||||
# In one epoch, for each training sample
|
||||
for i, (fname, img, gt_count) in enumerate(train_loader):
|
||||
for i, (fname, img, kpoint) in enumerate(train_loader):
|
||||
kpoint = kpoint.type(torch.FloatTensor)
|
||||
print("Training: img {} | kpoint {}".format(img.shape, kpoint.shape))
|
||||
# fpass
|
||||
img = img.to(device)
|
||||
out = model(img)
|
||||
gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
|
||||
kpoint = kpoint.to(device)
|
||||
out, gt_count = model(img, kpoint)
|
||||
# gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
|
||||
|
||||
# loss
|
||||
loss = criterion(out, gt_count)
|
||||
|
|
@ -288,7 +290,7 @@ def valid_one_epoch(test_loader, model, device, args):
|
|||
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
||||
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
|
||||
|
||||
nni.report_intermediate_result(mae)
|
||||
# nni.report_intermediate_result(mae)
|
||||
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
|
||||
mae=mae, mse=mse
|
||||
))
|
||||
|
|
@ -297,11 +299,10 @@ def valid_one_epoch(test_loader, model, device, args):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tuner_params = nni.get_next_parameter()
|
||||
logger.debug("Generated hyperparameters: {}", tuner_params)
|
||||
combined_params = Namespace(
|
||||
nni.utils.merge_parameter(ret_args, tuner_params)
|
||||
) # Namespaces have better ergonomics, notably a struct-like access syntax.
|
||||
# tuner_params = nni.get_next_parameter()
|
||||
# logger.debug("Generated hyperparameters: {}", tuner_params)
|
||||
# combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
|
||||
combined_params = args
|
||||
logger.debug("Parameters: {}", combined_params)
|
||||
|
||||
if combined_params.use_ddp:
|
||||
|
|
|
|||
|
|
@ -1,6 +0,0 @@
|
|||
# If we cannot get revpersnet running,
|
||||
# we still ought to do some sort of information-preserving perspective transformation
|
||||
# e.g., randomized transformation
|
||||
# and let transcrowd to crunch through these transformed image instead.
|
||||
# After training, we obtain the attention map and put it in our paper.
|
||||
# I just want to get things done...
|
||||
Loading…
Add table
Add a link
Reference in a new issue