More working than not
Not sure if validation works, call it a day
This commit is contained in:
parent
4a03211c83
commit
12aabb0d3f
10 changed files with 116 additions and 105 deletions
23
train.py
23
train.py
|
|
@@ -29,7 +29,7 @@ def setup_process_group(
|
|||
master_addr: str = "localhost",
|
||||
master_port: Optional[np.ushort] = None
|
||||
):
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_ADDR"] = master_addr
|
||||
os.environ["MASTER_PORT"] = (
|
||||
str(random.randint(40000, 65545))
|
||||
if master_port is None
|
||||
|
|
@@ -121,7 +121,6 @@ def worker(rank: int, args: Namespace):
|
|||
train_loader = build_train_loader(train_data, args)
|
||||
test_loader = build_test_loader(test_data, args)
|
||||
|
||||
|
||||
# Instantiate model
|
||||
if args.model == "stn":
|
||||
model = stn_patch16_384_gap(args.pth_tar).to(device)
|
||||
|
|
@@ -229,11 +228,14 @@ def train_one_epoch(
|
|||
model.train()
|
||||
|
||||
# In one epoch, for each training sample
|
||||
for i, (fname, img, gt_count) in enumerate(train_loader):
|
||||
for i, (fname, img, kpoint) in enumerate(train_loader):
|
||||
kpoint = kpoint.type(torch.FloatTensor)
|
||||
print("Training: img {} | kpoint {}".format(img.shape, kpoint.shape))
|
||||
# fpass
|
||||
img = img.to(device)
|
||||
out = model(img)
|
||||
gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
|
||||
kpoint = kpoint.to(device)
|
||||
out, gt_count = model(img, kpoint)
|
||||
# gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
|
||||
|
||||
# loss
|
||||
loss = criterion(out, gt_count)
|
||||
|
|
@@ -288,7 +290,7 @@ def valid_one_epoch(test_loader, model, device, args):
|
|||
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
||||
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
|
||||
|
||||
nni.report_intermediate_result(mae)
|
||||
# nni.report_intermediate_result(mae)
|
||||
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
|
||||
mae=mae, mse=mse
|
||||
))
|
||||
|
|
@@ -297,11 +299,10 @@ def valid_one_epoch(test_loader, model, device, args):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tuner_params = nni.get_next_parameter()
|
||||
logger.debug("Generated hyperparameters: {}", tuner_params)
|
||||
combined_params = Namespace(
|
||||
nni.utils.merge_parameter(ret_args, tuner_params)
|
||||
) # Namespaces have better ergonomics, notably a struct-like access syntax.
|
||||
# tuner_params = nni.get_next_parameter()
|
||||
# logger.debug("Generated hyperparameters: {}", tuner_params)
|
||||
# combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
|
||||
combined_params = args
|
||||
logger.debug("Parameters: {}", combined_params)
|
||||
|
||||
if combined_params.use_ddp:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue