import os
import random
import time
from argparse import Namespace
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from PIL import Image
import numpy as np
import numbers
import h5py
import cv2


# Hard-coded (train, test) .npy split files per dataset name, relative to the
# current working directory.
_NPY_SPLITS = {
    "ShanghaiA": ("./npydata/ShanghaiA_train.npy", "./npydata/ShanghaiA_test.npy"),
    "ShanghaiB": ("./npydata/ShanghaiB_train.npy", "./npydata/ShanghaiB_test.npy"),
    "UCF_QNRF": ("./npydata/qnrf_train.npy", "./npydata/qnrf_test.npy"),
    "JHU": ("./npydata/jhu_train.npy", "./npydata/jhu_test.npy"),
    "NWPU": ("./npydata/nwpu_train.npy", "./npydata/nwpu_test.npy"),
}


def unpack_npy_data(args: Namespace):
    """Unpack npy data as Python lists from hard-coded paths wrt cwd.

    The split files for ``args.train_dataset`` are looked up in
    ``_NPY_SPLITS``, loaded with ``np.load`` and converted with ``tolist()``.
    The entries are presumably image paths (``convert_data`` treats each
    entry as one) -- confirm against the .npy generation script.

    Args:
        args: namespace whose ``train_dataset`` attribute is one of the keys
            of ``_NPY_SPLITS``.

    Returns:
        ``(train_list, test_list)``.

    Raises:
        ValueError: if ``args.train_dataset`` is not a known dataset name.
    """
    try:
        train_file, test_file = _NPY_SPLITS[args.train_dataset]
    except KeyError:
        # Previously an unknown name left the paths unbound and crashed with
        # UnboundLocalError; fail with an actionable message instead.
        raise ValueError(
            "[dataset] Unknown train_dataset {!r}; expected one of {}".format(
                args.train_dataset, sorted(_NPY_SPLITS)
            )
        ) from None

    with open(train_file, "rb") as fd:
        train_list = np.load(fd).tolist()
    with open(test_file, "rb") as fd:
        test_list = np.load(fd).tolist()

    print("[dataset] Loaded \"{}\": train: {} | test: {}".format(
        args.train_dataset, len(train_list), len(test_list)
    ))
    return train_list, test_list


def _load_sample(img_path: str, max_load_attempts: int, retry_delay: float):
    """Load one RGB image plus its ``kpoint``/``gt_count`` arrays.

    The ground-truth file path is derived from ``img_path`` by replacing
    ".jpg" with ".h5" and "images" with "gt_density_map". Opening the .h5
    file is retried on ``OSError`` up to ``max_load_attempts`` times (at
    least once), sleeping ``retry_delay`` seconds between attempts; the last
    ``OSError`` is re-raised.

    Returns:
        ``(img, kpoint, gt_count)`` with ``img`` a PIL RGB image and the
        other two as numpy arrays read fully into memory.
    """
    # Close the underlying file handle; convert() returns an independent,
    # fully loaded image.
    with Image.open(img_path) as raw:
        img = raw.convert("RGB")

    gt_path = (img_path
               .replace(".jpg", ".h5")
               .replace("images", "gt_density_map"))

    attempts = max(1, max_load_attempts)
    for attempt in range(1, attempts + 1):
        try:
            # Explicit read-only mode: never create/modify the GT file.
            # np.asarray reads the datasets into memory before the file closes.
            with h5py.File(gt_path, "r") as gt_file:
                kpoint = np.asarray(gt_file["kpoint"])
                gt_count = np.asarray(gt_file["gt_count"])
            break
        except OSError:
            print("[dataset] Load error on '{}' (attempt {}/{})".format(
                gt_path, attempt, attempts))
            if attempt == attempts:
                raise
            # Give transient (e.g. network filesystem) errors time to clear.
            time.sleep(retry_delay)

    return img, kpoint.copy(), gt_count.copy()


def convert_data(train_list, args: Namespace, train: bool,
                 max_load_attempts: int = 5, retry_delay: float = 1.0):
    """Eagerly pre-load every sample of ``train_list`` into memory.

    Args:
        train_list: iterable of image paths.
        args: unused here; kept for interface compatibility.
        train: unused here; kept for interface compatibility.
        max_load_attempts: how many times to try opening each .h5 GT file
            before re-raising the ``OSError``.
        retry_delay: seconds to sleep between failed attempts.

    Returns:
        A list of dicts with keys "img", "kpoint", "gt_count" and "fname"
        (the image's basename), in the order of ``train_list``.
    """
    print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))
    data_keys = []
    for img_path in train_list:
        img, kpoint, gt_count = _load_sample(
            img_path, max_load_attempts, retry_delay)
        data_keys.append({
            "img": img,
            "kpoint": kpoint,
            "gt_count": gt_count,
            "fname": os.path.basename(img_path),
        })
    return data_keys


class ListDataset(Dataset):
    """Map-style dataset over pre-loaded sample dicts.

    Each element of ``root`` must be a dict with "fname", "img", "kpoint"
    and "gt_count" keys (as produced by ``convert_data``). ``__getitem__``
    returns ``(fname, img, kpoint, gt_count)``; in training mode the image is
    randomly flipped left-right with probability 0.5.

    NOTE(review): ``shuffle``, ``shape``, ``seen``, ``batch_size``,
    ``nr_workers`` and ``args`` are stored (or ignored) but not used by this
    class; shuffling is controlled by ``train`` alone.
    """

    def __init__(
        self,
        root,
        shape = None,
        shuffle: bool = True,
        transform = None,
        train: bool = False,
        seen: int = 0,
        batch_size: int = 1,
        nr_workers: int = 4,
        args: Namespace = None,
    ):
        if train:
            # Shuffle a copy so the caller's list is not reordered in place.
            root = list(root)
            random.shuffle(root)

        self.nr_samples = len(root)
        self.lines = root
        self.transform = transform
        self.train = train
        self.shape = shape
        self.seen = seen
        self.batch_size = batch_size
        self.nr_workers = nr_workers
        self.args = args

    def __len__(self):
        return self.nr_samples

    def __getitem__(self, index):
        # Was `assert index <= len(self)`: off-by-one and stripped under -O.
        # Negative indices remain supported, as with plain list indexing.
        if not -self.nr_samples <= index < self.nr_samples:
            raise IndexError("Index out-of-bounds")

        sample = self.lines[index]
        fname = sample["fname"]
        img = sample["img"]
        kpoint = sample["kpoint"]
        gt_count = sample["gt_count"]

        # Data augmentation
        if self.train and random.random() > .5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            # NOTE(review): kpoint is NOT flipped with the image. If it is a
            # spatial map aligned with img, it should be mirrored too --
            # confirm its layout before relying on it in training.
        # XXX: do random noise?

        # Copy so downstream in-place edits cannot corrupt the cached samples.
        kpoint = kpoint.copy()
        img = img.copy()
        gt_count = gt_count.copy()

        # Custom transform
        if self.transform is not None:
            img = self.transform(img)

        return fname, img, kpoint, gt_count