mlp-project/make_dataset.py
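"""Prepare the ShanghaiTech dataset (part A and part B).

Every image is resized to 1152x768 (landscape) or 768x1152 (portrait), the
head annotations are rescaled to match, and the per-image ground-truth count
is written to an .h5 file. Sorted .npy lists of the processed image paths are
then saved for the train and test splits.
"""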

import os
import glob
import random
import numpy as np
import cv2
import scipy.io as io
import h5py

CWD = os.getcwd()


def pre_dataset_sh():
    # ShanghaiTech is laid out as part_{A,B}_final/{train,test}_data/images
    root = CWD + "/synchronous/dataset/ShanghaiTech/"
    part_A_train = os.path.join(root, "part_A_final/train_data", "images")
    part_A_test = os.path.join(root, "part_A_final/test_data", "images")
    part_B_train = os.path.join(root, "part_B_final/train_data", "images")
    part_B_test = os.path.join(root, "part_B_final/test_data", "images")
    # Create cropped (to 1152x768) dataset directories
    for base_path in (part_A_train, part_A_test, part_B_train, part_B_test):
        for replacement in ("images_crop", "gt_density_map_crop"):
            if not os.path.exists(base_path.replace("images", replacement)):
                os.makedirs(base_path.replace("images", replacement))
    # Gather all jpg paths in part A & B, train & test
    img_paths = []
    for path in (part_A_train, part_A_test, part_B_train, part_B_test):
        for img_path in glob.glob(os.path.join(path, "*.jpg")):
            img_paths.append(img_path)
    # np.random.seed(0)
    # random.seed(0)
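    # Process each image: resize it, rescale its annotations, and write out
    # the resized image plus its ground-truth head count.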
    for img_path in img_paths:
        img_data = cv2.imread(img_path)
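        # The matching annotation file is ground_truth/GT_IMG_<n>.mat.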
        mat = io.loadmat(
            img_path
            .replace(".jpg", ".mat")
            .replace("images", "ground_truth")
            .replace("IMG_", "GT_IMG_")
        )
        gt_data = mat["image_info"][0][0][0][0][0]
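        # gt_data holds the (x, y) head coordinates, one row per annotated person.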
        # Resize to 1152x768 (landscape) or 768x1152 (portrait)
        if img_data.shape[1] >= img_data.shape[0]:  # landscape
            rate_x = 1152.0 / img_data.shape[1]
            rate_y = 768.0 / img_data.shape[0]
        else:  # portrait
            rate_x = 768.0 / img_data.shape[1]
            rate_y = 1152.0 / img_data.shape[0]
        img_data = cv2.resize(img_data, (0, 0), fx=rate_x, fy=rate_y)
        # Rescale the annotation coordinates by the same factors
        gt_data[:, 0] = gt_data[:, 0] * rate_x
        gt_data[:, 1] = gt_data[:, 1] * rate_y
        # Build a binary point map: 1 at every annotated head location that
        # still falls inside the resized image, 0 elsewhere. Its sum is the
        # ground-truth count for this image.
        kpoint = np.zeros((img_data.shape[0], img_data.shape[1]))
        for i in range(len(gt_data)):
            if (int(gt_data[i][1]) < img_data.shape[0]
                    and int(gt_data[i][0]) < img_data.shape[1]):
                kpoint[int(gt_data[i][1]), int(gt_data[i][0])] = 1
        # Likewise, we do not crop to patched sequences here...
        # Skip directly to saving the fixed-size image & gt_count.
        img_path = img_path.replace("images", "images_crop")
        cv2.imwrite(img_path, img_data)
        gt_count = np.sum(kpoint)
        # Replacing 'images' with 'gt_density_map' turns '.../images_crop/...'
        # into '.../gt_density_map_crop/...', the directory created above.
        with h5py.File(
            img_path.replace('.jpg', '.h5').replace('images', 'gt_density_map'),
            'w'
        ) as hf:
            hf["gt_count"] = gt_count


def make_npydata():
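    # Build sorted .npy lists of the processed (images_crop) file paths so
    # downstream training / evaluation code can load them directly.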
    if not os.path.exists("./npydata"):
        os.makedirs("./npydata")
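    # ShanghaiTech part A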
    try:
        sh_A_train_path = CWD + '/synchronous/dataset/ShanghaiTech/part_A_final/train_data/images_crop/'
        sh_A_test_path = CWD + '/synchronous/dataset/ShanghaiTech/part_A_final/test_data/images_crop/'
        train_fpaths = []
        for fname in os.listdir(sh_A_train_path):
            if fname.endswith(".jpg"):
                train_fpaths.append(sh_A_train_path + fname)
        train_fpaths.sort()
        np.save("./npydata/ShanghaiA_train.npy", train_fpaths)
        test_fpaths = []
        for fname in os.listdir(sh_A_test_path):
            if fname.endswith(".jpg"):
                test_fpaths.append(sh_A_test_path + fname)
        test_fpaths.sort()
        np.save("./npydata/ShanghaiA_test.npy", test_fpaths)
        print("Saved ShanghaiA image list (test: {} | train: {})".format(
            len(test_fpaths), len(train_fpaths)
        ))
    except OSError:
        print("The ShanghaiA dataset path is wrong.")
    try:
        sh_B_train_path = CWD + '/synchronous/dataset/ShanghaiTech/part_B_final/train_data/images_crop/'
        sh_B_test_path = CWD + '/synchronous/dataset/ShanghaiTech/part_B_final/test_data/images_crop/'
        train_fpaths = []
        for fname in os.listdir(sh_B_train_path):
            if fname.endswith(".jpg"):
                train_fpaths.append(sh_B_train_path + fname)
        train_fpaths.sort()
        np.save("./npydata/ShanghaiB_train.npy", train_fpaths)
        test_fpaths = []
        for fname in os.listdir(sh_B_test_path):
            if fname.endswith(".jpg"):
                test_fpaths.append(sh_B_test_path + fname)
        test_fpaths.sort()
        np.save("./npydata/ShanghaiB_test.npy", test_fpaths)
        print("Saved ShanghaiB image list (test: {} | train: {})".format(
            len(test_fpaths), len(train_fpaths)
        ))
    except OSError:
        print("The ShanghaiB dataset path is wrong.")
if __name__ == "__main__":
# Download manually...
pre_dataset_sh() # XXX: preliminary
make_npydata()