TEST: train on gt_count instead of kpoint
This commit is contained in:
parent
ee50e84946
commit
83fcc43f0b
3 changed files with 26 additions and 39 deletions
24
train.py
24
train.py
|
|
@ -239,16 +239,19 @@ def train_one_epoch(
|
|||
model.train()
|
||||
|
||||
# In one epoch, for each training sample
|
||||
for i, (fname, img, kpoint) in enumerate(train_loader):
|
||||
for i, (fname, img, kpoint, gt_count) in enumerate(train_loader):
|
||||
kpoint = kpoint.type(torch.FloatTensor)
|
||||
gt_count = gt_count.type(torch.FloatTensor)
|
||||
# fpass
|
||||
if device is not None:
|
||||
img = img.to(device)
|
||||
kpoint = kpoint.to(device)
|
||||
gt_count = gt_count.to(device)
|
||||
elif torch.cuda.is_available():
|
||||
img = img.cuda()
|
||||
kpoint = kpoint.cuda()
|
||||
out, gt_count = model(img, kpoint)
|
||||
gt_count = gt_count.cuda()
|
||||
out, _ = model(img, kpoint)
|
||||
|
||||
# loss
|
||||
loss = criterion(out, gt_count)
|
||||
|
|
@ -282,27 +285,30 @@ def valid_one_epoch(test_loader, model, device, args):
|
|||
visi = []
|
||||
index = 0
|
||||
|
||||
for i, (fname, img, kpoint) in enumerate(test_loader):
|
||||
for i, (fname, img, kpoint, gt_count) in enumerate(test_loader):
|
||||
kpoint = kpoint.type(torch.FloatTensor)
|
||||
gt_count = gt_count.type(torch.FloatTensor)
|
||||
if device is not None:
|
||||
img = img.to(device)
|
||||
kpoint = kpoint.to(device)
|
||||
gt_count = gt_count.to(device)
|
||||
elif torch.cuda.is_available():
|
||||
img = img.cuda()
|
||||
kpoint = kpoint.cuda()
|
||||
gt_count = gt_count.cuda()
|
||||
|
||||
# XXX: what do this do
|
||||
# XXX: do this even happen?
|
||||
if len(img.shape) == 5:
|
||||
img = img.squeeze(0)
|
||||
if len(img.shape) == 3:
|
||||
img = img.unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
out, gt_count = model(img, kpoint)
|
||||
out, _ = model(img, kpoint)
|
||||
pred_count = torch.squeeze(out, 1)
|
||||
gt_count = torch.squeeze(gt_count, 1)
|
||||
# gt_count = torch.squeeze(gt_count, 1)
|
||||
|
||||
diff = torch.sum(torch.abs(gt_count - pred_count)).item()
|
||||
diff = torch.abs(gt_count - torch.sum(pred_count)).item()
|
||||
mae += diff
|
||||
mse += diff ** 2
|
||||
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
||||
|
|
@ -325,12 +331,10 @@ if __name__ == "__main__":
|
|||
tuner_params = nni.get_next_parameter()
|
||||
logger.debug("Generated hyperparameters: {}", tuner_params)
|
||||
combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
|
||||
|
||||
if args.debug:
|
||||
os.nice(15)
|
||||
|
||||
#combined_params = args
|
||||
#logger.debug("Parameters: {}", combined_params)
|
||||
|
||||
if combined_params.use_ddp:
|
||||
# Use DDP, spawn threads
|
||||
torch_mp.spawn(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue