diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..e966480
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,162 @@
+import os
+import random
+from argparse import Namespace
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+from torchvision import datasets, transforms
+from PIL import Image
+import numpy as np
+import numbers
+import h5py
+import cv2
+
+
+def unpack_npy_data(args: Namespace):
+    """Unpack .npy file lists as Python lists from hard-coded paths relative to the cwd."""
+    train_file, test_file = None, None
+    if args.dataset == "ShanghaiA":
+        train_file = "./npydata/ShanghaiA_train.npy"
+        test_file = "./npydata/ShanghaiA_test.npy"
+    elif args.dataset == "ShanghaiB":
+        train_file = "./npydata/ShanghaiB_train.npy"
+        test_file = "./npydata/ShanghaiB_test.npy"
+    elif args.dataset == "UCF_QNRF":
+        train_file = "./npydata/qnrf_train.npy"
+        test_file = "./npydata/qnrf_test.npy"
+    elif args.dataset == "JHU":
+        train_file = "./npydata/jhu_train.npy"
+        test_file = "./npydata/jhu_test.npy"
+    elif args.dataset == "NWPU":
+        train_file = "./npydata/nwpu_train.npy"
+        test_file = "./npydata/nwpu_test.npy"
+
+    assert all(fdir is not None for fdir in [train_file, test_file]), \
+        "Unknown dataset: {}".format(args.dataset)
+
+    with open(train_file, "rb") as fd:
+        train_list = np.load(fd).tolist()
+    with open(test_file, "rb") as fd:
+        test_list = np.load(fd).tolist()
+
+    print("[dataset] Loaded \"{}\": train: {} | test: {}".format(
+        args.dataset, len(train_list), len(test_list)
+    ))
+
+    return train_list, test_list
+
+
+def convert_data(train_list, args: Namespace, train: bool):
+    def _load_data(img_path: str, train: bool, args: Namespace):
+        img = Image.open(img_path).convert("RGB")
+        gt_path = (img_path
+                   .replace(".jpg", ".h5")
+                   .replace("images", "gt_density_map"))
+
+        # Retry on transient I/O errors (e.g. the .h5 file still being written).
+        while True:
+            try:
+                gt_file = h5py.File(gt_path, "r")
+                gt_count = np.asarray(gt_file["gt_count"])
+                break
+            except OSError:
+                print("[dataset] Load error on '{}'".format(img_path))
+
+        img = img.copy()
+        gt_count = gt_count.copy()
+
+        return img, gt_count
+
+    # Pre-load every image and its ground-truth count into memory.
+    print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))
+    data_keys = []
+    for img_path in train_list:
+        fname = os.path.basename(img_path)
+        img, gt_count = _load_data(img_path, train, args)
+
+        pack = {
+            "img": img,
+            "gt_count": gt_count,
+            "fname": fname,
+        }
+        data_keys.append(pack)
+
+    return data_keys
+
+
+class ListDataset(Dataset):
+    def __init__(
+        self,
+        root,
+        shape=None,
+        shuffle: bool = True,
+        transform=None,
+        train: bool = False,
+        seen: int = 0,
+        batch_size: int = 1,
+        nr_workers: int = 4,
+        args: Namespace = None,
+    ):
+        if train and shuffle:
+            random.shuffle(root)
+
+        self.nr_samples = len(root)
+        self.lines = root
+        self.transform = transform
+        self.train = train
+        self.shape = shape
+        self.seen = seen
+        self.batch_size = batch_size
+        self.nr_workers = nr_workers
+        self.args = args
+
+    def __len__(self):
+        return self.nr_samples
+
+    def __getitem__(self, index):
+        assert index < len(self), "Index out-of-bounds"
+
+        fname = self.lines[index]["fname"]
+        img = self.lines[index]["img"]
+        gt_count = self.lines[index]["gt_count"]
+
+        # Data augmentation: random horizontal flip during training.
+        if self.train:
+            if random.random() > .5:
+                img = img.transpose(Image.FLIP_LEFT_RIGHT)
+            # XXX: do random noise?
+
+        gt_count = gt_count.copy()
+        img = img.copy()
+
+        # Custom transform
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.train:
+            return fname, img, gt_count
+
+        # Evaluation: tile the image into non-overlapping 384x384 crops and
+        # stack them along a new batch dimension.
+        device = self.args.device
+        height, width = img.shape[1], img.shape[2]
+        m = int(width / 384)
+        n = int(height / 384)
+        crops = []
+        for i in range(m):
+            for j in range(n):
+                crop = img[
+                    :,                        # C
+                    j * 384 : 384 * (j + 1),  # H
+                    i * 384 : 384 * (i + 1),  # W
+                ].to(device).unsqueeze(0)
+                crops.append(crop)
+        img_ret = torch.cat(crops, 0).to(device)
+        return fname, img_ret, gt_count
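A minimal usage sketch of how the new module might be driven, not part of the diff. It assumes a Namespace carrying dataset/device/batch_size/workers fields and ImageNet normalization constants; none of these are defined in this patch.

    from argparse import Namespace

    from torch.utils.data import DataLoader
    from torchvision import transforms

    import dataset

    # Hypothetical run configuration; adjust to the actual training script.
    args = Namespace(dataset="ShanghaiA", device="cuda:0", batch_size=16, workers=4)

    # Resolve the .npy file lists, then pre-load images and ground-truth counts.
    train_list, test_list = dataset.unpack_npy_data(args)
    train_data = dataset.convert_data(train_list, args, train=True)

    train_set = dataset.ListDataset(
        train_data,
        shuffle=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]),
        train=True,
        batch_size=args.batch_size,
        nr_workers=args.workers,
        args=args,
    )

    # ListDataset already shuffles in train mode, so the loader keeps order.
    train_loader = DataLoader(train_set, batch_size=args.batch_size,
                              shuffle=False, num_workers=args.workers)

    for fname, img, gt_count in train_loader:
        pass  # forward pass and loss against gt_count go here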