TEST: train on gt_count instead of kpoint
This commit is contained in:
parent
ee50e84946
commit
83fcc43f0b
3 changed files with 26 additions and 39 deletions
37
dataset.py
37
dataset.py
|
|
@ -56,14 +56,16 @@ def convert_data(train_list, args: Namespace, train: bool):
|
|||
try:
|
||||
gt_file = h5py.File(gt_path)
|
||||
kpoint = np.asarray(gt_file["kpoint"])
|
||||
gt_count = np.asarray(gt_file["gt_count"])
|
||||
break
|
||||
except OSError:
|
||||
print("[dataset] Load error on \'{}\'", img_path)
|
||||
|
||||
img = img.copy()
|
||||
kpoint = kpoint.copy()
|
||||
gt_count = gt_count.copy()
|
||||
|
||||
return img, kpoint
|
||||
return img, kpoint, gt_count
|
||||
|
||||
|
||||
print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))
|
||||
|
|
@ -71,11 +73,12 @@ def convert_data(train_list, args: Namespace, train: bool):
|
|||
for i in range(len(train_list)):
|
||||
img_path = train_list[i]
|
||||
fname = os.path.basename(img_path)
|
||||
img, kpoint = _load_data(img_path, train, args)
|
||||
img, kpoint, gt_count = _load_data(img_path, train, args)
|
||||
|
||||
pack = {
|
||||
"img": img,
|
||||
"kpoint": kpoint,
|
||||
"gt_count": gt_count,
|
||||
"fname": fname,
|
||||
}
|
||||
data_keys.append(pack)
|
||||
|
|
@ -120,6 +123,7 @@ class ListDataset(Dataset):
|
|||
fname = self.lines[index]["fname"]
|
||||
img = self.lines[index]["img"]
|
||||
kpoint = self.lines[index]["kpoint"]
|
||||
gt_count = self.lines[index]["gt_count"]
|
||||
|
||||
# Data augmentation
|
||||
if self.train:
|
||||
|
|
@ -129,35 +133,10 @@ class ListDataset(Dataset):
|
|||
|
||||
kpoint = kpoint.copy()
|
||||
img = img.copy()
|
||||
gt_count = gt_count.copy()
|
||||
|
||||
# Custom transform
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
|
||||
return fname, img, kpoint
|
||||
|
||||
# if self.train:
|
||||
# return fname, img, gt_count
|
||||
# else:
|
||||
# device = args.device
|
||||
# height, width = img.shape[1], img.shape[2]
|
||||
# m = int(width / 384)
|
||||
# n = int(height / 384)
|
||||
# for i in range(m):
|
||||
# for j in range(n):
|
||||
# if i == 0 and j == 0:
|
||||
# img_ret = img[
|
||||
# :, # C
|
||||
# j * 384 : 384 * (j + 1), # H
|
||||
# i * 384 : 384 * (i + 1), # W
|
||||
# ].to(device).unsqueeze(0)
|
||||
# else:
|
||||
# cropped = img[
|
||||
# :, # C
|
||||
# j * 384 : 384 * (j + 1), # H
|
||||
# i * 384 : 384 * (i + 1), # W
|
||||
# ].to(device).unsqueeze(0)
|
||||
# img_ret = torch.cat([img_ret, cropped], 0).to(device)
|
||||
# return fname, img_ret, gt_count
|
||||
|
||||
return fname, img, kpoint, gt_count
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue