Forgot to commit this
This commit is contained in:
parent
b69727b74f
commit
57510503d0
1 changed files with 162 additions and 0 deletions
162
dataset.py
Normal file
162
dataset.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
import os
|
||||
import random
|
||||
from argparse import Namespace
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import datasets, transforms
|
||||
from image import Image
|
||||
import numpy as np
|
||||
import numbers
|
||||
import h5py
|
||||
import cv2
|
||||
|
||||
|
||||
def unpack_npy_data(args: Namespace):
|
||||
"""Unpack npy data aas np lists at hard-coded paths wrt cwd."""
|
||||
if args.dataset == "ShanghaiA":
|
||||
train_file = "./npydata/ShanghaiA_train.npy"
|
||||
test_file = "./npydata/ShanghaiA_test.npy"
|
||||
elif args.dataset == "ShanghaiB":
|
||||
train_file = "./npydata/ShanghaiB_train.npy"
|
||||
test_file = "./npydata/ShanghaiB_test.npy"
|
||||
elif args.dataset == "UCF_QNRF":
|
||||
train_file = "./npydata/qnrf_train.npy"
|
||||
test_file = "./npydata/qnrf_test.npy"
|
||||
elif args.dataset == "JHU":
|
||||
train_file = "./npydata/jhu_train.npy"
|
||||
test_file = "./npydata/jhu_test.npy"
|
||||
elif args.dataset == "NWPU":
|
||||
train_file = "./npydata/nwpu_train.npy"
|
||||
test_file = "./npydata/nwpu_test.npy"
|
||||
|
||||
assert any([fdir is not None for fdir in [train_file, test_file]])
|
||||
|
||||
with open(train_file, "rb") as fd:
|
||||
train_list = np.load(fd).tolist()
|
||||
with open(test_file, "rb") as fd:
|
||||
test_list = np.load(fd).tolist()
|
||||
|
||||
print("[dataset] Loaded \"{}\": train: {} | test: {}".format(
|
||||
args.dataset, len(train_list), len(test_list)
|
||||
))
|
||||
|
||||
return train_list, test_list
|
||||
|
||||
|
||||
def convert_data(train_list, args: Namespace, train: bool):
|
||||
def _load_data(img_path: str, train: bool, args: Namespace):
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
gt_path = (img_path
|
||||
.replace(".jpg", ".h5")
|
||||
.replace("images", "gt_density_map"))
|
||||
|
||||
while True:
|
||||
try:
|
||||
gt_file = h5py.File(gt_path)
|
||||
gt_count = np.asarray(gt_file["gt_count"])
|
||||
break
|
||||
except OSError:
|
||||
print("[dataset] Load error on \'{}\'", img_path)
|
||||
|
||||
img = img.copy()
|
||||
gt_count = gt_count.copy()
|
||||
|
||||
return img, gt_count
|
||||
|
||||
|
||||
print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))
|
||||
data_keys = []
|
||||
for i in range(len(train_list)):
|
||||
img_path = train_list[i]
|
||||
fname = os.path.basename(img_path)
|
||||
img, gt_count = _load_data(img_path, train, args)
|
||||
|
||||
pack = {
|
||||
"img": img,
|
||||
"gt_count": gt_count,
|
||||
"fname": fname,
|
||||
}
|
||||
data_keys.append(pack)
|
||||
|
||||
return data_keys
|
||||
|
||||
|
||||
class ListDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
shape = None,
|
||||
shuffle: bool = True,
|
||||
transform = None,
|
||||
train: bool = False,
|
||||
seen: int = 0,
|
||||
batch_size: int = 1,
|
||||
nr_workers: int = 4,
|
||||
args: Namespace = None,
|
||||
):
|
||||
if train:
|
||||
random.shuffle(root)
|
||||
|
||||
self.nr_samples = len(root)
|
||||
self.lines = root
|
||||
self.transform = transform
|
||||
self.train = train
|
||||
self.shape = shape
|
||||
self.seen = seen
|
||||
self.batch_size = batch_size
|
||||
self.nr_workers = nr_workers
|
||||
self.args = args
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.nr_samples
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
assert index <= len(self), "Index out-of-bounds"
|
||||
|
||||
fname = self.lines[index]["fname"]
|
||||
img = self.lines[index]["img"]
|
||||
gt_count = self.lines[index]["gt_count"]
|
||||
|
||||
# Data augmentation
|
||||
if self.train:
|
||||
if random.random() > .5:
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
# XXX: do random noise?
|
||||
|
||||
gt_count = gt_count.copy()
|
||||
img = img.copy()
|
||||
|
||||
# Custom transform
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
|
||||
if self.train:
|
||||
return fname, img, gt_count
|
||||
else:
|
||||
device = args.device
|
||||
height, width = img.shape[1], img.shape[2]
|
||||
m = int(width / 384)
|
||||
n = int(height / 384)
|
||||
for i in range(m):
|
||||
for j in range(n):
|
||||
if i == 0 and j == 0:
|
||||
img_ret = img[
|
||||
:, # C
|
||||
j * 384 : 384 * (j + 1), # H
|
||||
i * 384 : 384 * (i + 1), # W
|
||||
].to(device).unsqueeze(0)
|
||||
else:
|
||||
cropped = img[
|
||||
:, # C
|
||||
j * 384 : 384 * (j + 1), # H
|
||||
i * 384 : 384 * (i + 1), # W
|
||||
].to(device).unsqueeze(0)
|
||||
img_ret = torch.cat([img_ret, cropped], 0).to(device)
|
||||
return fname, img_ret, gt_count
|
||||
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue