Forgot to commit this

This commit is contained in:
Zhengyi Chen 2024-03-03 02:00:03 +00:00
parent b69727b74f
commit 57510503d0

162
dataset.py Normal file
View file

@ -0,0 +1,162 @@
import os
import random
from argparse import Namespace
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from image import Image
import numpy as np
import numbers
import h5py
import cv2
def unpack_npy_data(args: Namespace):
    """Load the train/test sample lists for the dataset named by ``args.dataset``.

    The lists are read from ``.npy`` files at hard-coded paths relative to the
    current working directory: ``./npydata/<prefix>_train.npy`` and
    ``./npydata/<prefix>_test.npy``.

    Args:
        args: Parsed options; only ``args.dataset`` is read. Supported values
            are "ShanghaiA", "ShanghaiB", "UCF_QNRF", "JHU" and "NWPU".

    Returns:
        A ``(train_list, test_list)`` tuple holding the contents of the two
        files converted with ``ndarray.tolist()``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
        OSError: If either ``.npy`` file cannot be opened.
    """
    # Dataset name -> file-name prefix used under ./npydata.
    file_prefixes = {
        "ShanghaiA": "ShanghaiA",
        "ShanghaiB": "ShanghaiB",
        "UCF_QNRF": "qnrf",
        "JHU": "jhu",
        "NWPU": "nwpu",
    }
    try:
        prefix = file_prefixes[args.dataset]
    except KeyError:
        # Fail with a readable message instead of an UnboundLocalError on the
        # path variables further down.
        raise ValueError(
            "Unknown dataset \"{}\"; expected one of: {}".format(
                args.dataset, ", ".join(file_prefixes)
            )
        ) from None

    train_file = "./npydata/{}_train.npy".format(prefix)
    test_file = "./npydata/{}_test.npy".format(prefix)

    with open(train_file, "rb") as fd:
        train_list = np.load(fd).tolist()
    with open(test_file, "rb") as fd:
        test_list = np.load(fd).tolist()

    print("[dataset] Loaded \"{}\": train: {} | test: {}".format(
        args.dataset, len(train_list), len(test_list)
    ))
    return train_list, test_list
def convert_data(train_list, args: Namespace, train: bool):
    """Pre-load every image and its ground-truth count into memory.

    For each image path, the ground-truth file path is derived by replacing
    ``".jpg"`` with ``".h5"`` and ``"images"`` with ``"gt_density_map"``; the
    ``"gt_count"`` entry of that HDF5 file is read as a NumPy array.

    Args:
        train_list: Iterable of image paths.
        args: Accepted for interface compatibility; not read here.
        train: Accepted for interface compatibility; not read here.

    Returns:
        A list with one dict per input path, with keys ``"img"`` (the image
        converted to RGB), ``"gt_count"`` (NumPy array) and ``"fname"`` (base
        name of the image path).

    Raises:
        OSError: If a ground-truth file still cannot be read after the last
            retry.
    """
    import time  # local import: only needed for the retry back-off below

    # Reads that raise OSError are retried, but only a bounded number of
    # times, so a permanently missing/corrupt file cannot hang the pre-load.
    max_attempts = 5
    retry_delay = 1.0  # seconds between attempts

    def _load_data(img_path: str):
        """Return ``(img, gt_count)`` for a single image path."""
        img = Image.open(img_path).convert("RGB")
        gt_path = (img_path
                   .replace(".jpg", ".h5")
                   .replace("images", "gt_density_map"))

        for attempt in range(1, max_attempts + 1):
            try:
                # Read-only, and closed as soon as the value is materialised.
                with h5py.File(gt_path, "r") as gt_file:
                    gt_count = np.asarray(gt_file["gt_count"])
                break
            except OSError:
                print("[dataset] Load error on '{}'".format(img_path))
                if attempt == max_attempts:
                    raise
                time.sleep(retry_delay)

        return img.copy(), gt_count.copy()

    print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))

    data_keys = []
    for img_path in train_list:
        img, gt_count = _load_data(img_path)
        data_keys.append({
            "img": img,
            "gt_count": gt_count,
            "fname": os.path.basename(img_path),
        })
    return data_keys
class ListDataset(Dataset):
    """Dataset over a list of pre-loaded sample dicts.

    Each element of ``root`` is a dict with the keys ``"fname"``, ``"img"``
    and ``"gt_count"``. ``"img"`` and ``"gt_count"`` must both provide a
    ``.copy()`` method.

    In training mode an item is ``(fname, img, gt_count)``, where ``img`` may
    have been randomly flipped left-right and then passed through
    ``transform``.

    In evaluation mode the transformed image is indexed as ``(C, H, W)`` and
    cut into non-overlapping ``CROP_SIZE`` x ``CROP_SIZE`` tiles. The tiles
    are returned stacked along a new leading dimension; any remainder at the
    right/bottom edge is dropped. ``transform`` must therefore produce a
    tensor in evaluation mode.
    """

    # Side length, in pixels, of the square tiles cut from evaluation images.
    CROP_SIZE = 384

    def __init__(
        self,
        root,
        shape=None,
        shuffle: bool = True,
        transform=None,
        train: bool = False,
        seen: int = 0,
        batch_size: int = 1,
        nr_workers: int = 4,
        args: Namespace = None,
    ):
        """Store the sample list and options.

        Args:
            root: List of sample dicts. NOTE: shuffled *in place* when
                ``train`` is true, so the caller's list is reordered too.
            shape: Stored as ``self.shape``; not used by this class.
            shuffle: Accepted but not read; shuffling is governed by ``train``.
            transform: Optional callable applied to the image of every item.
            train: Selects training behaviour (shuffle, random flip, no tiling).
            seen: Stored as ``self.seen``; not used by this class.
            batch_size: Stored as ``self.batch_size``; not used by this class.
            nr_workers: Stored as ``self.nr_workers``; not used by this class.
            args: Options namespace. Its ``device`` attribute, when present,
                is where evaluation tiles are moved.
        """
        super().__init__()
        if train:
            random.shuffle(root)

        self.nr_samples = len(root)
        self.lines = root
        self.transform = transform
        self.train = train
        self.shape = shape
        self.seen = seen
        self.batch_size = batch_size
        self.nr_workers = nr_workers
        self.args = args

    def __len__(self):
        """Return the number of samples captured at construction time."""
        return self.nr_samples

    def __getitem__(self, index):
        """Return ``(fname, img, gt_count)`` for the sample at ``index``.

        Raises:
            IndexError: If ``index`` is past the end of the dataset.
            ValueError: In evaluation mode, if the image is smaller than one
                tile in either spatial dimension.
        """
        # An explicit check (not an assert) so it survives ``python -O``.
        # Negative indices are left to the underlying list.
        if index >= len(self):
            raise IndexError("Index out-of-bounds")

        sample = self.lines[index]
        fname = sample["fname"]
        img = sample["img"]
        gt_count = sample["gt_count"]

        # Data augmentation
        if self.train and random.random() > .5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        # XXX: do random noise?

        # Copy so that later changes never touch the cached sample.
        gt_count = gt_count.copy()
        img = img.copy()

        # Custom transform
        if self.transform is not None:
            img = self.transform(img)

        if self.train:
            return fname, img, gt_count
        return fname, self._tile(img), gt_count

    def _tile(self, img):
        """Cut a ``(C, H, W)`` tensor into stacked ``CROP_SIZE`` square tiles.

        Tiles are ordered column by column (the width index varies slowest).
        The stacked result is moved to ``self.args.device`` when that
        attribute exists; otherwise it stays on the image's own device.
        """
        size = self.CROP_SIZE
        height, width = img.shape[1], img.shape[2]
        nr_cols = width // size
        nr_rows = height // size
        if nr_cols == 0 or nr_rows == 0:
            raise ValueError(
                "Image of size {}x{} (HxW) is smaller than one {}x{} tile"
                .format(height, width, size, size)
            )

        tiles = [
            img[
                :,                                # C
                row * size : (row + 1) * size,    # H
                col * size : (col + 1) * size,    # W
            ]
            for col in range(nr_cols)
            for row in range(nr_rows)
        ]
        # One stack and one transfer rather than a cat/transfer per tile.
        stacked = torch.stack(tiles, 0)

        device = getattr(self.args, "device", None)
        return stacked if device is None else stacked.to(device)