TEST: train on gt_count instead of kpoint

This commit is contained in:
Zhengyi Chen 2024-03-04 01:49:09 +00:00
parent ee50e84946
commit 83fcc43f0b
3 changed files with 26 additions and 39 deletions

View file

@ -96,6 +96,9 @@ def pre_dataset_sh():
) # To same shape as image, so i, j flipped wrt. coordinates
kpoint = sparse_mat.toarray()
# Sum count as ground truth (we need to train STN, remember?)
gt_count = sparse_mat.nnz
fname = img_path.split("/")[-1]
root_path = img_path.split("IMG_")[0].replace("images", "images_crop")
@ -108,6 +111,7 @@ def pre_dataset_sh():
mode='w'
) as hf:
hf["kpoint"] = kpoint
hf["gt_count"] = gt_count
def make_npydata():