import os
import random
from argparse import Namespace

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from PIL import Image
import numpy as np
import numbers
import h5py
import cv2


def unpack_npy_data(args: Namespace):
    """Unpack npy data as lists from hard-coded paths relative to the cwd."""
    train_file, test_file = None, None
    if args.train_dataset == "ShanghaiA":
        train_file = "./npydata/ShanghaiA_train.npy"
        test_file = "./npydata/ShanghaiA_test.npy"
    elif args.train_dataset == "ShanghaiB":
        train_file = "./npydata/ShanghaiB_train.npy"
        test_file = "./npydata/ShanghaiB_test.npy"
    elif args.train_dataset == "UCF_QNRF":
        train_file = "./npydata/qnrf_train.npy"
        test_file = "./npydata/qnrf_test.npy"
    elif args.train_dataset == "JHU":
        train_file = "./npydata/jhu_train.npy"
        test_file = "./npydata/jhu_test.npy"
    elif args.train_dataset == "NWPU":
        train_file = "./npydata/nwpu_train.npy"
        test_file = "./npydata/nwpu_test.npy"
    assert all(fdir is not None for fdir in [train_file, test_file]), \
        "Unknown train_dataset '{}'".format(args.train_dataset)

    with open(train_file, "rb") as fd:
        train_list = np.load(fd).tolist()
    with open(test_file, "rb") as fd:
        test_list = np.load(fd).tolist()

    print("[dataset] Loaded \"{}\": train: {} | test: {}".format(
        args.train_dataset, len(train_list), len(test_list)
    ))
    return train_list, test_list


def convert_data(train_list, args: Namespace, train: bool):
    """Pre-load every image and its keypoint map into memory as a list of dicts."""

    def _load_data(img_path: str, train: bool, args: Namespace):
        img = Image.open(img_path).convert("RGB")
        gt_path = (img_path
                   .replace(".jpg", ".h5")
                   .replace("images", "gt_density_map"))
        # Retry on transient I/O errors (e.g. concurrent readers on the .h5 file).
        while True:
            try:
                with h5py.File(gt_path, "r") as gt_file:
                    kpoint = np.asarray(gt_file["kpoint"])
                break
            except OSError:
                print("[dataset] Load error on '{}'".format(img_path))
        img = img.copy()
        kpoint = kpoint.copy()
        return img, kpoint

    print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))
    data_keys = []
    for i in range(len(train_list)):
        img_path = train_list[i]
        fname = os.path.basename(img_path)
        img, kpoint = _load_data(img_path, train, args)
        pack = {
            "img": img,
            "kpoint": kpoint,
            "fname": fname,
        }
        data_keys.append(pack)
    return data_keys


class ListDataset(Dataset):
    """Dataset over the pre-loaded (img, kpoint, fname) packs produced by convert_data."""

    def __init__(
        self,
        root,
        shape=None,
        shuffle: bool = True,
        transform=None,
        train: bool = False,
        seen: int = 0,
        batch_size: int = 1,
        nr_workers: int = 4,
        args: Namespace = None,
    ):
        if train:
            random.shuffle(root)

        self.nr_samples = len(root)
        self.lines = root
        self.transform = transform
        self.train = train
        self.shape = shape
        self.seen = seen
        self.batch_size = batch_size
        self.nr_workers = nr_workers
        self.args = args

    def __len__(self):
        return self.nr_samples

    def __getitem__(self, index):
        assert index < len(self), "Index out-of-bounds"

        fname = self.lines[index]["fname"]
        img = self.lines[index]["img"]
        kpoint = self.lines[index]["kpoint"]

        # Data augmentation
        if self.train:
            if random.random() > .5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                # Flip the point map together with the image so labels stay aligned.
                kpoint = np.fliplr(kpoint)
            # XXX: do random noise?
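            # Sketch of the noise augmentation hinted at above; not part of the
            # original pipeline. `pixel_noise_std` is a hypothetical args field
            # (defaults to 0, i.e. disabled): it perturbs pixel intensities with
            # Gaussian noise before the tensor transform; the point map is untouched.
            noise_std = getattr(self.args, "pixel_noise_std", 0) if self.args else 0
            if noise_std > 0:
                arr = np.asarray(img, dtype=np.float32)
                arr += np.random.normal(0.0, noise_std, size=arr.shape)
                img = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))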
        kpoint = kpoint.copy()
        img = img.copy()

        # Custom transform
        if self.transform is not None:
            img = self.transform(img)

        return fname, img, kpoint

        # if self.train:
        #     return fname, img, gt_count
        # else:
        #     device = args.device
        #     height, width = img.shape[1], img.shape[2]
        #     m = int(width / 384)
        #     n = int(height / 384)
        #     for i in range(m):
        #         for j in range(n):
        #             if i == 0 and j == 0:
        #                 img_ret = img[
        #                     :,                        # C
        #                     j * 384 : 384 * (j + 1),  # H
        #                     i * 384 : 384 * (i + 1),  # W
        #                 ].to(device).unsqueeze(0)
        #             else:
        #                 cropped = img[
        #                     :,                        # C
        #                     j * 384 : 384 * (j + 1),  # H
        #                     i * 384 : 384 * (i + 1),  # W
        #                 ].to(device).unsqueeze(0)
        #                 img_ret = torch.cat([img_ret, cropped], 0).to(device)
        #     return fname, img_ret, gt_count
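# Minimal usage sketch (not part of the training code): wires the helpers above
# into a torch DataLoader. The dataset name and normalization constants are
# assumptions; batch_size stays at 1 because `kpoint` maps differ in size
# across images and would not collate into a larger batch.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    demo_args = Namespace(train_dataset="ShanghaiA")
    train_list, test_list = unpack_npy_data(demo_args)
    train_data = convert_data(train_list, demo_args, train=True)

    dataset = ListDataset(
        train_data,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]),
        train=True,
        args=demo_args,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

    fname, img, kpoint = next(iter(loader))
    print(fname[0], tuple(img.shape), "count:", int(kpoint.sum().item()))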