# mlp-project/dataset.py

import os
import random
from argparse import Namespace

import torch  # kept for the commented-out eval-time tiling below
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import h5py


def unpack_npy_data(args: Namespace):
    """Unpack .npy image-path lists from hard-coded paths relative to cwd."""
    train_file, test_file = None, None
    if args.train_dataset == "ShanghaiA":
        train_file = "./npydata/ShanghaiA_train.npy"
        test_file = "./npydata/ShanghaiA_test.npy"
    elif args.train_dataset == "ShanghaiB":
        train_file = "./npydata/ShanghaiB_train.npy"
        test_file = "./npydata/ShanghaiB_test.npy"
    elif args.train_dataset == "UCF_QNRF":
        train_file = "./npydata/qnrf_train.npy"
        test_file = "./npydata/qnrf_test.npy"
    elif args.train_dataset == "JHU":
        train_file = "./npydata/jhu_train.npy"
        test_file = "./npydata/jhu_test.npy"
    elif args.train_dataset == "NWPU":
        train_file = "./npydata/nwpu_train.npy"
        test_file = "./npydata/nwpu_test.npy"
    # Fail loudly on an unrecognized dataset name instead of a NameError.
    assert all(f is not None for f in (train_file, test_file)), \
        "Unknown dataset: {}".format(args.train_dataset)
    with open(train_file, "rb") as fd:
        train_list = np.load(fd).tolist()
    with open(test_file, "rb") as fd:
        test_list = np.load(fd).tolist()
    print("[dataset] Loaded \"{}\": train: {} | test: {}".format(
        args.train_dataset, len(train_list), len(test_list)
    ))
    return train_list, test_list
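
# Usage sketch (illustrative): assumes an argparse Namespace carrying a
# `train_dataset` field, with each .npy file holding an array of image paths.
#   args = Namespace(train_dataset="ShanghaiA")
#   train_list, test_list = unpack_npy_data(args)
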
def convert_data(train_list, args: Namespace, train: bool):
    """Pre-load images and ground-truth keypoint maps into memory."""
    def _load_data(img_path: str, train: bool, args: Namespace):
        img = Image.open(img_path).convert("RGB")
        # Ground truth sits beside the images: images/x.jpg -> gt_density_map/x.h5
        gt_path = (img_path
                   .replace(".jpg", ".h5")
                   .replace("images", "gt_density_map"))
        while True:
            try:
                # h5py reads can fail transiently (e.g. file locking); retry.
                with h5py.File(gt_path, "r") as gt_file:
                    kpoint = np.asarray(gt_file["kpoint"])
                break
            except OSError:
                print("[dataset] Load error on '{}'".format(img_path))
        img = img.copy()
        kpoint = kpoint.copy()
        return img, kpoint

    print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))
    data_keys = []
    for img_path in train_list:
        fname = os.path.basename(img_path)
        img, kpoint = _load_data(img_path, train, args)
        pack = {
            "img": img,
            "kpoint": kpoint,
            "fname": fname,
        }
        data_keys.append(pack)
    return data_keys
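
# Illustrative contents of one pack (the 2-D (H, W) kpoint layout is an
# assumption about the upstream .h5 ground truth):
#   {"img": <PIL RGB image>, "kpoint": np.ndarray (H, W), "fname": "IMG_1.jpg"}
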
class ListDataset(Dataset):
    def __init__(
        self,
        root,
        shape=None,
        shuffle: bool = True,
        transform=None,
        train: bool = False,
        seen: int = 0,
        batch_size: int = 1,
        nr_workers: int = 4,
        args: Namespace = None,
    ):
        if train:
            random.shuffle(root)
        self.nr_samples = len(root)
        self.lines = root
        self.transform = transform
        self.train = train
        self.shape = shape
        self.seen = seen
        self.batch_size = batch_size
        self.nr_workers = nr_workers
        self.args = args

    def __len__(self):
        return self.nr_samples

    def __getitem__(self, index):
        assert index < len(self), "Index out-of-bounds"
        fname = self.lines[index]["fname"]
        img = self.lines[index]["img"]
        kpoint = self.lines[index]["kpoint"]
        # Data augmentation
        if self.train:
            if random.random() > 0.5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                # Mirror the keypoint map too, so it stays aligned with the image.
                kpoint = np.fliplr(kpoint)
            # XXX: do random noise?
        kpoint = kpoint.copy()
        img = img.copy()
        # Custom transform
        if self.transform is not None:
            img = self.transform(img)
        return fname, img, kpoint
        # Disabled eval-time tiling into 384x384 crops; references gt_count,
        # which __getitem__ no longer produces.
        # if self.train:
        #     return fname, img, gt_count
        # else:
        #     device = args.device
        #     height, width = img.shape[1], img.shape[2]
        #     m = int(width / 384)
        #     n = int(height / 384)
        #     for i in range(m):
        #         for j in range(n):
        #             if i == 0 and j == 0:
        #                 img_ret = img[
        #                     :,  # C
        #                     j * 384 : 384 * (j + 1),  # H
        #                     i * 384 : 384 * (i + 1),  # W
        #                 ].to(device).unsqueeze(0)
        #             else:
        #                 cropped = img[
        #                     :,  # C
        #                     j * 384 : 384 * (j + 1),  # H
        #                     i * 384 : 384 * (i + 1),  # W
        #                 ].to(device).unsqueeze(0)
        #                 img_ret = torch.cat([img_ret, cropped], 0).to(device)
        #     return fname, img_ret, gt_count
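

# Minimal wiring sketch (illustrative, not invoked by the training code): shows
# how unpack_npy_data, convert_data, and ListDataset compose. The ToTensor-only
# transform and the loader settings are placeholder assumptions, not values
# taken from this repo.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision import transforms

    args = Namespace(train_dataset="ShanghaiA")  # assumes ./npydata is populated
    train_list, test_list = unpack_npy_data(args)
    train_data = convert_data(train_list, args, train=True)
    dataset = ListDataset(train_data, transform=transforms.ToTensor(),
                          train=True, args=args)
    loader = DataLoader(dataset, batch_size=1, num_workers=0)
    for fname, img, kpoint in loader:
        # img: (1, 3, H, W) float tensor; kpoint: (1, H, W), assuming 2-D maps.
        print(fname[0], tuple(img.shape), tuple(kpoint.shape))
        break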