More working than not

Not sure if validation works, call it a day
This commit is contained in:
Zhengyi Chen 2024-03-03 03:16:54 +00:00
parent 4a03211c83
commit 12aabb0d3f
10 changed files with 116 additions and 105 deletions

View file

@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from image import Image
from PIL import Image
import numpy as np
import numbers
import h5py
@ -15,23 +15,23 @@ import cv2
def unpack_npy_data(args: Namespace):
"""Unpack npy data aas np lists at hard-coded paths wrt cwd."""
if args.dataset == "ShanghaiA":
if args.train_dataset == "ShanghaiA":
train_file = "./npydata/ShanghaiA_train.npy"
test_file = "./npydata/ShanghaiA_test.npy"
elif args.dataset == "ShanghaiB":
elif args.train_dataset == "ShanghaiB":
train_file = "./npydata/ShanghaiB_train.npy"
test_file = "./npydata/ShanghaiB_test.npy"
elif args.dataset == "UCF_QNRF":
elif args.train_dataset == "UCF_QNRF":
train_file = "./npydata/qnrf_train.npy"
test_file = "./npydata/qnrf_test.npy"
elif args.dataset == "JHU":
elif args.train_dataset == "JHU":
train_file = "./npydata/jhu_train.npy"
test_file = "./npydata/jhu_test.npy"
elif args.dataset == "NWPU":
elif args.train_dataset == "NWPU":
train_file = "./npydata/nwpu_train.npy"
test_file = "./npydata/nwpu_test.npy"
assert any([fdir is not None for fdir in [train_file, test_file]])
assert all([fdir is not None for fdir in [train_file, test_file]])
with open(train_file, "rb") as fd:
train_list = np.load(fd).tolist()
@ -39,9 +39,9 @@ def unpack_npy_data(args: Namespace):
test_list = np.load(fd).tolist()
print("[dataset] Loaded \"{}\": train: {} | test: {}".format(
args.dataset, len(train_list), len(test_list)
args.train_dataset, len(train_list), len(test_list)
))
return train_list, test_list
@ -55,28 +55,28 @@ def convert_data(train_list, args: Namespace, train: bool):
while True:
try:
gt_file = h5py.File(gt_path)
gt_count = np.asarray(gt_file["gt_count"])
kpoint = np.asarray(gt_file["kpoint"])
break
except OSError:
print("[dataset] Load error on \'{}\'", img_path)
img = img.copy()
gt_count = gt_count.copy()
kpoint = kpoint.copy()
return img, kpoint
return img, gt_count
print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))
data_keys = []
for i in range(len(train_list)):
img_path = train_list[i]
fname = os.path.basename(img_path)
img, gt_count = _load_data(img_path, train, args)
img, kpoint = _load_data(img_path, train, args)
pack = {
"img": img,
"gt_count": gt_count,
"fname": fname,
"img": img,
"kpoint": kpoint,
"fname": fname,
}
data_keys.append(pack)
@ -85,16 +85,16 @@ def convert_data(train_list, args: Namespace, train: bool):
class ListDataset(Dataset):
def __init__(
self,
root,
self,
root,
shape = None,
shuffle: bool = True,
transform = None,
transform = None,
train: bool = False,
seen: int = 0,
batch_size: int = 1,
nr_workers: int = 4,
args: Namespace = None,
seen: int = 0,
batch_size: int = 1,
nr_workers: int = 4,
args: Namespace = None,
):
if train:
random.shuffle(root)
@ -109,25 +109,25 @@ class ListDataset(Dataset):
self.nr_workers = nr_workers
self.args = args
def __len__(self):
return self.nr_samples
def __getitem__(self, index):
assert index <= len(self), "Index out-of-bounds"
fname = self.lines[index]["fname"]
img = self.lines[index]["img"]
gt_count = self.lines[index]["gt_count"]
kpoint = self.lines[index]["kpoint"]
# Data augmentation
if self.train:
if random.random() > .5:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
# XXX: do random noise?
gt_count = gt_count.copy()
kpoint = kpoint.copy()
img = img.copy()
# Custom transform
@ -135,28 +135,29 @@ class ListDataset(Dataset):
img = self.transform(img)
if self.train:
return fname, img, gt_count
else:
device = args.device
height, width = img.shape[1], img.shape[2]
m = int(width / 384)
n = int(height / 384)
for i in range(m):
for j in range(n):
if i == 0 and j == 0:
img_ret = img[
:, # C
j * 384 : 384 * (j + 1), # H
i * 384 : 384 * (i + 1), # W
].to(device).unsqueeze(0)
else:
cropped = img[
:, # C
j * 384 : 384 * (j + 1), # H
i * 384 : 384 * (i + 1), # W
].to(device).unsqueeze(0)
img_ret = torch.cat([img_ret, cropped], 0).to(device)
return fname, img_ret, gt_count
return fname, img, kpoint
# if self.train:
# return fname, img, gt_count
# else:
# device = args.device
# height, width = img.shape[1], img.shape[2]
# m = int(width / 384)
# n = int(height / 384)
# for i in range(m):
# for j in range(n):
# if i == 0 and j == 0:
# img_ret = img[
# :, # C
# j * 384 : 384 * (j + 1), # H
# i * 384 : 384 * (i + 1), # W
# ].to(device).unsqueeze(0)
# else:
# cropped = img[
# :, # C
# j * 384 : 384 * (j + 1), # H
# i * 384 : 384 * (i + 1), # W
# ].to(device).unsqueeze(0)
# img_ret = torch.cat([img_ret, cropped], 0).to(device)
# return fname, img_ret, gt_count