142 lines
4 KiB
Python
142 lines
4 KiB
Python
import os
|
|
import random
|
|
from argparse import Namespace
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import Dataset
|
|
from torchvision import datasets, transforms
|
|
from PIL import Image
|
|
import numpy as np
|
|
import numbers
|
|
import h5py
|
|
import cv2
|
|
|
|
|
|
def unpack_npy_data(args: Namespace, npy_dir: str = "./npydata"):
    """Load the train/test image-path lists for ``args.train_dataset``.

    The lists are stored as ``.npy`` arrays inside ``npy_dir`` (by default
    ``./npydata``, resolved relative to the current working directory) and
    are returned as plain Python lists.

    Args:
        args: Parsed CLI arguments; only ``args.train_dataset`` is read. It
            must be one of ``"ShanghaiA"``, ``"ShanghaiB"``, ``"UCF_QNRF"``,
            ``"JHU"`` or ``"NWPU"``.
        npy_dir: Directory containing the ``<prefix>_train.npy`` /
            ``<prefix>_test.npy`` files.

    Returns:
        A ``(train_list, test_list)`` tuple.

    Raises:
        ValueError: if ``args.train_dataset`` is not a known dataset name.
    """
    # Dataset name -> file-name prefix used for its .npy files.
    prefixes = {
        "ShanghaiA": "ShanghaiA",
        "ShanghaiB": "ShanghaiB",
        "UCF_QNRF": "qnrf",
        "JHU": "jhu",
        "NWPU": "nwpu",
    }
    try:
        prefix = prefixes[args.train_dataset]
    except KeyError:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(
            "Unknown train_dataset {!r}; expected one of {}".format(
                args.train_dataset, sorted(prefixes)
            )
        ) from None

    train_file = os.path.join(npy_dir, prefix + "_train.npy")
    test_file = os.path.join(npy_dir, prefix + "_test.npy")

    with open(train_file, "rb") as fd:
        train_list = np.load(fd).tolist()
    with open(test_file, "rb") as fd:
        test_list = np.load(fd).tolist()

    print("[dataset] Loaded \"{}\": train: {} | test: {}".format(
        args.train_dataset, len(train_list), len(test_list)
    ))

    return train_list, test_list
|
|
|
|
|
|
def convert_data(train_list, args: Namespace, train: bool):
    """Eagerly load every image and its ground truth into memory.

    For each image path, the ground-truth HDF5 path is derived by replacing
    ``".jpg"`` with ``".h5"`` and ``"images"`` with ``"gt_density_map"``
    (plain substring replacement over the whole path). The ``kpoint`` and
    ``gt_count`` datasets are read from that file alongside the RGB image.

    Args:
        train_list: Sequence of image file paths.
        args: Parsed CLI arguments; forwarded to the loader, currently unused.
        train: Training flag; forwarded to the loader, currently unused.

    Returns:
        A list of dicts, one per path and in input order, with keys
        ``"img"`` (RGB ``PIL.Image``), ``"kpoint"`` and ``"gt_count"``
        (numpy arrays) and ``"fname"`` (basename of the image path).

    Raises:
        OSError: if a ground-truth file still cannot be read after
            ``max_attempts`` tries (the original code retried forever).
    """
    max_attempts = 5

    def _load_data(img_path: str, train: bool, args: Namespace):
        """Load one RGB image plus its ``kpoint`` / ``gt_count`` arrays."""
        # convert() fully decodes the pixels, so the source file can be
        # closed immediately.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        gt_path = (img_path
                   .replace(".jpg", ".h5")
                   .replace("images", "gt_density_map"))

        for attempt in range(1, max_attempts + 1):
            try:
                # Read-only mode, and the handle is closed once the arrays
                # have been materialised by np.asarray.
                with h5py.File(gt_path, "r") as gt_file:
                    kpoint = np.asarray(gt_file["kpoint"])
                    gt_count = np.asarray(gt_file["gt_count"])
                break
            except OSError:
                print("[dataset] Load error on '{}' (gt: '{}', attempt {}/{})".format(
                    img_path, gt_path, attempt, max_attempts))
                if attempt == max_attempts:
                    raise

        return img, kpoint.copy(), gt_count.copy()

    print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))

    data_keys = []
    for img_path in train_list:
        img, kpoint, gt_count = _load_data(img_path, train, args)
        data_keys.append({
            "img": img,
            "kpoint": kpoint,
            "gt_count": gt_count,
            "fname": os.path.basename(img_path),
        })

    return data_keys
|
|
|
|
|
|
class ListDataset(Dataset):
    """In-memory dataset over pre-loaded sample dicts.

    Each element of ``root`` must be a dict with the keys ``"fname"``,
    ``"img"``, ``"kpoint"`` and ``"gt_count"`` (the layout produced by
    ``convert_data``). Items are returned as
    ``(fname, img, kpoint, gt_count)``, with ``transform`` applied to ``img``
    when one is given.
    """

    def __init__(
        self,
        root,
        shape=None,
        shuffle: bool = True,
        transform=None,
        train: bool = False,
        seen: int = 0,
        batch_size: int = 1,
        nr_workers: int = 4,
        args: Namespace = None,
    ):
        """Store the samples and configuration.

        Args:
            root: Sequence of sample dicts (see class docstring). It is
                copied, so the caller's list is never reordered.
            shape: Stored as ``self.shape``; not used by this class.
            shuffle: When true together with ``train``, the sample order is
                shuffled once here (previously this flag was ignored).
            transform: Optional callable applied to each image in
                ``__getitem__``.
            train: Enables shuffling and random horizontal flips.
            seen, batch_size, nr_workers, args: Stored as attributes; not used
                by this class.
        """
        # Copy so shuffling does not mutate the caller's list as a side effect.
        lines = list(root)
        if train and shuffle:
            random.shuffle(lines)

        self.nr_samples = len(lines)
        self.lines = lines
        self.transform = transform
        self.train = train
        self.shape = shape
        self.seen = seen
        self.batch_size = batch_size
        self.nr_workers = nr_workers
        self.args = args

    def __len__(self):
        """Return the number of stored samples."""
        return self.nr_samples

    def __getitem__(self, index):
        """Return ``(fname, img, kpoint, gt_count)`` for sample ``index``.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        # Real exception instead of an assert (stripped under -O), and the
        # old `<=` check let index == len(self) through.
        if not -self.nr_samples <= index < self.nr_samples:
            raise IndexError("Index out-of-bounds")

        sample = self.lines[index]
        fname = sample["fname"]
        img = sample["img"]
        kpoint = sample["kpoint"]
        gt_count = sample["gt_count"]

        # Data augmentation
        if self.train:
            if random.random() > .5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                # NOTE(review): kpoint is not flipped with the image. If it is
                # a spatial map/coordinates aligned to img, it should be
                # mirrored too -- confirm against how kpoint is consumed.
            # XXX: do random noise?

        # Copies keep the cached samples safe from in-place edits downstream.
        kpoint = kpoint.copy()
        img = img.copy()
        gt_count = gt_count.copy()

        # Custom transform
        if self.transform is not None:
            img = self.transform(img)

        return fname, img, kpoint, gt_count
|