diff --git a/dataset.py b/dataset.py index d11fee6..ffa78d3 100644 --- a/dataset.py +++ b/dataset.py @@ -56,14 +56,16 @@ def convert_data(train_list, args: Namespace, train: bool): try: gt_file = h5py.File(gt_path) kpoint = np.asarray(gt_file["kpoint"]) + gt_count = np.asarray(gt_file["gt_count"]) break except OSError: print("[dataset] Load error on \'{}\'", img_path) img = img.copy() kpoint = kpoint.copy() + gt_count = gt_count.copy() - return img, kpoint + return img, kpoint, gt_count print("[dataset] Pre-loading dataset...\n{}".format("-" * 50)) @@ -71,11 +73,12 @@ def convert_data(train_list, args: Namespace, train: bool): for i in range(len(train_list)): img_path = train_list[i] fname = os.path.basename(img_path) - img, kpoint = _load_data(img_path, train, args) + img, kpoint, gt_count = _load_data(img_path, train, args) pack = { "img": img, "kpoint": kpoint, + "gt_count": gt_count, "fname": fname, } data_keys.append(pack) @@ -120,6 +123,7 @@ class ListDataset(Dataset): fname = self.lines[index]["fname"] img = self.lines[index]["img"] kpoint = self.lines[index]["kpoint"] + gt_count = self.lines[index]["gt_count"] # Data augmentation if self.train: @@ -129,35 +133,10 @@ class ListDataset(Dataset): kpoint = kpoint.copy() img = img.copy() + gt_count = gt_count.copy() # Custom transform if self.transform is not None: img = self.transform(img) - - return fname, img, kpoint - - # if self.train: - # return fname, img, gt_count - # else: - # device = args.device - # height, width = img.shape[1], img.shape[2] - # m = int(width / 384) - # n = int(height / 384) - # for i in range(m): - # for j in range(n): - # if i == 0 and j == 0: - # img_ret = img[ - # :, # C - # j * 384 : 384 * (j + 1), # H - # i * 384 : 384 * (i + 1), # W - # ].to(device).unsqueeze(0) - # else: - # cropped = img[ - # :, # C - # j * 384 : 384 * (j + 1), # H - # i * 384 : 384 * (i + 1), # W - # ].to(device).unsqueeze(0) - # img_ret = torch.cat([img_ret, cropped], 0).to(device) - # return fname, img_ret, gt_count - + return fname, img, kpoint, gt_count diff --git a/preprocess_data.py b/preprocess_data.py index d1066c8..5bfe122 100644 --- a/preprocess_data.py +++ b/preprocess_data.py @@ -96,6 +96,9 @@ def pre_dataset_sh(): ) # To same shape as image, so i, j flipped wrt. coordinates kpoint = sparse_mat.toarray() + # Sum count as ground truth (we need to train STN, remember?) + gt_count = sparse_mat.nnz + fname = img_path.split("/")[-1] root_path = img_path.split("IMG_")[0].replace("images", "images_crop") @@ -108,6 +111,7 @@ def pre_dataset_sh(): mode='w' ) as hf: hf["kpoint"] = kpoint + hf["gt_count"] = gt_count def make_npydata(): diff --git a/train.py b/train.py index b85ef42..62c1000 100644 --- a/train.py +++ b/train.py @@ -239,16 +239,19 @@ def train_one_epoch( model.train() # In one epoch, for each training sample - for i, (fname, img, kpoint) in enumerate(train_loader): + for i, (fname, img, kpoint, gt_count) in enumerate(train_loader): kpoint = kpoint.type(torch.FloatTensor) + gt_count = gt_count.type(torch.FloatTensor) # fpass if device is not None: img = img.to(device) kpoint = kpoint.to(device) + gt_count = gt_count.to(device) elif torch.cuda.is_available(): img = img.cuda() kpoint = kpoint.cuda() - out, gt_count = model(img, kpoint) + gt_count = gt_count.cuda() + out, _ = model(img, kpoint) # loss loss = criterion(out, gt_count) @@ -282,27 +285,30 @@ def valid_one_epoch(test_loader, model, device, args): visi = [] index = 0 - for i, (fname, img, kpoint) in enumerate(test_loader): + for i, (fname, img, kpoint, gt_count) in enumerate(test_loader): kpoint = kpoint.type(torch.FloatTensor) + gt_count = gt_count.type(torch.FloatTensor) if device is not None: img = img.to(device) kpoint = kpoint.to(device) + gt_count = gt_count.to(device) elif torch.cuda.is_available(): img = img.cuda() kpoint = kpoint.cuda() + gt_count = gt_count.cuda() - # XXX: what do this do + # XXX: do this even happen? if len(img.shape) == 5: img = img.squeeze(0) if len(img.shape) == 3: img = img.unsqueeze(0) with torch.no_grad(): - out, gt_count = model(img, kpoint) + out, _ = model(img, kpoint) pred_count = torch.squeeze(out, 1) - gt_count = torch.squeeze(gt_count, 1) + # gt_count = torch.squeeze(gt_count, 1) - diff = torch.sum(torch.abs(gt_count - pred_count)).item() + diff = torch.abs(gt_count - torch.sum(pred_count)).item() mae += diff mse += diff ** 2 mae = mae * 1.0 / (len(test_loader) * batch_size) @@ -325,12 +331,10 @@ if __name__ == "__main__": tuner_params = nni.get_next_parameter() logger.debug("Generated hyperparameters: {}", tuner_params) combined_params = nni.utils.merge_parameter(ret_args, tuner_params) + if args.debug: os.nice(15) - #combined_params = args - #logger.debug("Parameters: {}", combined_params) - if combined_params.use_ddp: # Use DDP, spawn threads torch_mp.spawn(