TEST: train on gt_count instead of kpoint
This commit is contained in:
parent
ee50e84946
commit
83fcc43f0b
3 changed files with 26 additions and 39 deletions
37
dataset.py
37
dataset.py
|
|
@ -56,14 +56,16 @@ def convert_data(train_list, args: Namespace, train: bool):
|
||||||
try:
|
try:
|
||||||
gt_file = h5py.File(gt_path)
|
gt_file = h5py.File(gt_path)
|
||||||
kpoint = np.asarray(gt_file["kpoint"])
|
kpoint = np.asarray(gt_file["kpoint"])
|
||||||
|
gt_count = np.asarray(gt_file["gt_count"])
|
||||||
break
|
break
|
||||||
except OSError:
|
except OSError:
|
||||||
print("[dataset] Load error on \'{}\'", img_path)
|
print("[dataset] Load error on \'{}\'", img_path)
|
||||||
|
|
||||||
img = img.copy()
|
img = img.copy()
|
||||||
kpoint = kpoint.copy()
|
kpoint = kpoint.copy()
|
||||||
|
gt_count = gt_count.copy()
|
||||||
|
|
||||||
return img, kpoint
|
return img, kpoint, gt_count
|
||||||
|
|
||||||
|
|
||||||
print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))
|
print("[dataset] Pre-loading dataset...\n{}".format("-" * 50))
|
||||||
|
|
@ -71,11 +73,12 @@ def convert_data(train_list, args: Namespace, train: bool):
|
||||||
for i in range(len(train_list)):
|
for i in range(len(train_list)):
|
||||||
img_path = train_list[i]
|
img_path = train_list[i]
|
||||||
fname = os.path.basename(img_path)
|
fname = os.path.basename(img_path)
|
||||||
img, kpoint = _load_data(img_path, train, args)
|
img, kpoint, gt_count = _load_data(img_path, train, args)
|
||||||
|
|
||||||
pack = {
|
pack = {
|
||||||
"img": img,
|
"img": img,
|
||||||
"kpoint": kpoint,
|
"kpoint": kpoint,
|
||||||
|
"gt_count": gt_count,
|
||||||
"fname": fname,
|
"fname": fname,
|
||||||
}
|
}
|
||||||
data_keys.append(pack)
|
data_keys.append(pack)
|
||||||
|
|
@ -120,6 +123,7 @@ class ListDataset(Dataset):
|
||||||
fname = self.lines[index]["fname"]
|
fname = self.lines[index]["fname"]
|
||||||
img = self.lines[index]["img"]
|
img = self.lines[index]["img"]
|
||||||
kpoint = self.lines[index]["kpoint"]
|
kpoint = self.lines[index]["kpoint"]
|
||||||
|
gt_count = self.lines[index]["gt_count"]
|
||||||
|
|
||||||
# Data augmentation
|
# Data augmentation
|
||||||
if self.train:
|
if self.train:
|
||||||
|
|
@ -129,35 +133,10 @@ class ListDataset(Dataset):
|
||||||
|
|
||||||
kpoint = kpoint.copy()
|
kpoint = kpoint.copy()
|
||||||
img = img.copy()
|
img = img.copy()
|
||||||
|
gt_count = gt_count.copy()
|
||||||
|
|
||||||
# Custom transform
|
# Custom transform
|
||||||
if self.transform is not None:
|
if self.transform is not None:
|
||||||
img = self.transform(img)
|
img = self.transform(img)
|
||||||
|
|
||||||
|
return fname, img, kpoint, gt_count
|
||||||
return fname, img, kpoint
|
|
||||||
|
|
||||||
# if self.train:
|
|
||||||
# return fname, img, gt_count
|
|
||||||
# else:
|
|
||||||
# device = args.device
|
|
||||||
# height, width = img.shape[1], img.shape[2]
|
|
||||||
# m = int(width / 384)
|
|
||||||
# n = int(height / 384)
|
|
||||||
# for i in range(m):
|
|
||||||
# for j in range(n):
|
|
||||||
# if i == 0 and j == 0:
|
|
||||||
# img_ret = img[
|
|
||||||
# :, # C
|
|
||||||
# j * 384 : 384 * (j + 1), # H
|
|
||||||
# i * 384 : 384 * (i + 1), # W
|
|
||||||
# ].to(device).unsqueeze(0)
|
|
||||||
# else:
|
|
||||||
# cropped = img[
|
|
||||||
# :, # C
|
|
||||||
# j * 384 : 384 * (j + 1), # H
|
|
||||||
# i * 384 : 384 * (i + 1), # W
|
|
||||||
# ].to(device).unsqueeze(0)
|
|
||||||
# img_ret = torch.cat([img_ret, cropped], 0).to(device)
|
|
||||||
# return fname, img_ret, gt_count
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -96,6 +96,9 @@ def pre_dataset_sh():
|
||||||
) # To same shape as image, so i, j flipped wrt. coordinates
|
) # To same shape as image, so i, j flipped wrt. coordinates
|
||||||
kpoint = sparse_mat.toarray()
|
kpoint = sparse_mat.toarray()
|
||||||
|
|
||||||
|
# Sum count as ground truth (we need to train STN, remember?)
|
||||||
|
gt_count = sparse_mat.nnz
|
||||||
|
|
||||||
fname = img_path.split("/")[-1]
|
fname = img_path.split("/")[-1]
|
||||||
root_path = img_path.split("IMG_")[0].replace("images", "images_crop")
|
root_path = img_path.split("IMG_")[0].replace("images", "images_crop")
|
||||||
|
|
||||||
|
|
@ -108,6 +111,7 @@ def pre_dataset_sh():
|
||||||
mode='w'
|
mode='w'
|
||||||
) as hf:
|
) as hf:
|
||||||
hf["kpoint"] = kpoint
|
hf["kpoint"] = kpoint
|
||||||
|
hf["gt_count"] = gt_count
|
||||||
|
|
||||||
|
|
||||||
def make_npydata():
|
def make_npydata():
|
||||||
|
|
|
||||||
24
train.py
24
train.py
|
|
@ -239,16 +239,19 @@ def train_one_epoch(
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
# In one epoch, for each training sample
|
# In one epoch, for each training sample
|
||||||
for i, (fname, img, kpoint) in enumerate(train_loader):
|
for i, (fname, img, kpoint, gt_count) in enumerate(train_loader):
|
||||||
kpoint = kpoint.type(torch.FloatTensor)
|
kpoint = kpoint.type(torch.FloatTensor)
|
||||||
|
gt_count = gt_count.type(torch.FloatTensor)
|
||||||
# fpass
|
# fpass
|
||||||
if device is not None:
|
if device is not None:
|
||||||
img = img.to(device)
|
img = img.to(device)
|
||||||
kpoint = kpoint.to(device)
|
kpoint = kpoint.to(device)
|
||||||
|
gt_count = gt_count.to(device)
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
img = img.cuda()
|
img = img.cuda()
|
||||||
kpoint = kpoint.cuda()
|
kpoint = kpoint.cuda()
|
||||||
out, gt_count = model(img, kpoint)
|
gt_count = gt_count.cuda()
|
||||||
|
out, _ = model(img, kpoint)
|
||||||
|
|
||||||
# loss
|
# loss
|
||||||
loss = criterion(out, gt_count)
|
loss = criterion(out, gt_count)
|
||||||
|
|
@ -282,27 +285,30 @@ def valid_one_epoch(test_loader, model, device, args):
|
||||||
visi = []
|
visi = []
|
||||||
index = 0
|
index = 0
|
||||||
|
|
||||||
for i, (fname, img, kpoint) in enumerate(test_loader):
|
for i, (fname, img, kpoint, gt_count) in enumerate(test_loader):
|
||||||
kpoint = kpoint.type(torch.FloatTensor)
|
kpoint = kpoint.type(torch.FloatTensor)
|
||||||
|
gt_count = gt_count.type(torch.FloatTensor)
|
||||||
if device is not None:
|
if device is not None:
|
||||||
img = img.to(device)
|
img = img.to(device)
|
||||||
kpoint = kpoint.to(device)
|
kpoint = kpoint.to(device)
|
||||||
|
gt_count = gt_count.to(device)
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
img = img.cuda()
|
img = img.cuda()
|
||||||
kpoint = kpoint.cuda()
|
kpoint = kpoint.cuda()
|
||||||
|
gt_count = gt_count.cuda()
|
||||||
|
|
||||||
# XXX: what do this do
|
# XXX: do this even happen?
|
||||||
if len(img.shape) == 5:
|
if len(img.shape) == 5:
|
||||||
img = img.squeeze(0)
|
img = img.squeeze(0)
|
||||||
if len(img.shape) == 3:
|
if len(img.shape) == 3:
|
||||||
img = img.unsqueeze(0)
|
img = img.unsqueeze(0)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
out, gt_count = model(img, kpoint)
|
out, _ = model(img, kpoint)
|
||||||
pred_count = torch.squeeze(out, 1)
|
pred_count = torch.squeeze(out, 1)
|
||||||
gt_count = torch.squeeze(gt_count, 1)
|
# gt_count = torch.squeeze(gt_count, 1)
|
||||||
|
|
||||||
diff = torch.sum(torch.abs(gt_count - pred_count)).item()
|
diff = torch.abs(gt_count - torch.sum(pred_count)).item()
|
||||||
mae += diff
|
mae += diff
|
||||||
mse += diff ** 2
|
mse += diff ** 2
|
||||||
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
||||||
|
|
@ -325,12 +331,10 @@ if __name__ == "__main__":
|
||||||
tuner_params = nni.get_next_parameter()
|
tuner_params = nni.get_next_parameter()
|
||||||
logger.debug("Generated hyperparameters: {}", tuner_params)
|
logger.debug("Generated hyperparameters: {}", tuner_params)
|
||||||
combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
|
combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
|
||||||
|
|
||||||
if args.debug:
|
if args.debug:
|
||||||
os.nice(15)
|
os.nice(15)
|
||||||
|
|
||||||
#combined_params = args
|
|
||||||
#logger.debug("Parameters: {}", combined_params)
|
|
||||||
|
|
||||||
if combined_params.use_ddp:
|
if combined_params.use_ddp:
|
||||||
# Use DDP, spawn threads
|
# Use DDP, spawn threads
|
||||||
torch_mp.spawn(
|
torch_mp.spawn(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue