More working than not

Not sure if validation works, call it a day
Zhengyi Chen 2024-03-03 03:16:54 +00:00
parent 4a03211c83
commit 12aabb0d3f
10 changed files with 116 additions and 105 deletions


@@ -29,7 +29,7 @@ def setup_process_group(
     master_addr: str = "localhost",
     master_port: Optional[np.ushort] = None
 ):
-    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_ADDR"] = master_addr
     os.environ["MASTER_PORT"] = (
         str(random.randint(40000, 65535))
         if master_port is None
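
For context, a minimal sketch of how these environment variables would typically feed torch.distributed initialization; the NCCL backend and the rank/world_size parameters are assumptions not shown in this hunk (note that random.randint is inclusive on both ends, so 65535 is the highest valid upper bound for a port):

    import os
    import random
    from typing import Optional

    import torch.distributed as dist

    def setup_process_group_sketch(
        rank: int,
        world_size: int,
        master_addr: str = "localhost",
        master_port: Optional[int] = None,
    ):
        # Publish the rendezvous address/port for all ranks.
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = (
            str(random.randint(40000, 65535))
            if master_port is None
            else str(master_port)
        )
        # Backend choice ("nccl") is an assumption for GPU training.
        dist.init_process_group("nccl", rank=rank, world_size=world_size)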
@@ -121,7 +121,6 @@ def worker(rank: int, args: Namespace):
     train_loader = build_train_loader(train_data, args)
     test_loader = build_test_loader(test_data, args)
     # Instantiate model
     if args.model == "stn":
         model = stn_patch16_384_gap(args.pth_tar).to(device)
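
The surrounding worker code dispatches on args.model; only the "stn" branch is visible here. A hedged, table-driven sketch of that dispatch with an explicit failure mode for unknown names (the error handling and the single-entry table are assumptions, not part of the commit):

    # Hypothetical: table-driven version of the visible if/elif dispatch.
    # Only "stn" -> stn_patch16_384_gap appears in this diff.
    MODEL_BUILDERS = {
        "stn": lambda args: stn_patch16_384_gap(args.pth_tar),
    }

    def build_model(args, device):
        try:
            builder = MODEL_BUILDERS[args.model]
        except KeyError:
            raise ValueError(f"unknown model {args.model!r}")
        return builder(args).to(device)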
@@ -229,11 +228,14 @@ def train_one_epoch(
     model.train()
     # In one epoch, for each training sample
-    for i, (fname, img, gt_count) in enumerate(train_loader):
+    for i, (fname, img, kpoint) in enumerate(train_loader):
+        kpoint = kpoint.type(torch.FloatTensor)
+        print("Training: img {} | kpoint {}".format(img.shape, kpoint.shape))
         # fpass
         img = img.to(device)
-        out = model(img)
-        gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
+        kpoint = kpoint.to(device)
+        out, gt_count = model(img, kpoint)
+        # gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
         # loss
         loss = criterion(out, gt_count)
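
This hunk rewires the training step: the loader now yields a keypoint map instead of a precomputed count, and per the + lines the model takes (img, kpoint) and returns both the prediction and the ground-truth count. A self-contained sketch under those assumptions (criterion and optimizer are stand-ins for whatever the script builds elsewhere):

    import torch

    def train_step(model, criterion, optimizer, img, kpoint, device):
        img = img.to(device)
        # .float() has the same effect as the .type(torch.FloatTensor) cast above.
        kpoint = kpoint.float().to(device)
        # Per the + lines, the model now derives gt_count from kpoint itself.
        out, gt_count = model(img, kpoint)
        loss = criterion(out, gt_count)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()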
@@ -288,7 +290,7 @@ def valid_one_epoch(test_loader, model, device, args):
     mae = mae * 1.0 / (len(test_loader) * batch_size)
     mse = np.sqrt(mse / (len(test_loader) * batch_size))
-    nni.report_intermediate_result(mae)
+    # nni.report_intermediate_result(mae)
     print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
         mae=mae, mse=mse
     ))
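
These lines normalize accumulated errors into MAE and RMSE over the whole test set (and, with report_intermediate_result commented out, NNI no longer sees per-epoch MAE). The accumulation itself is outside the hunk; a sketch of the presumed bookkeeping, reusing the (img, kpoint) model signature from the training hunk (the sum-based count readout is an assumption):

    import numpy as np
    import torch

    @torch.no_grad()
    def evaluate(model, test_loader, batch_size, device):
        mae, mse = 0.0, 0.0
        for fname, img, kpoint in test_loader:
            out, gt_count = model(img.to(device), kpoint.float().to(device))
            err = (out.sum() - gt_count.sum()).item()  # counting error per batch
            mae += abs(err)
            mse += err ** 2
        n = len(test_loader) * batch_size  # total sample count, as in the hunk
        return mae / n, float(np.sqrt(mse / n))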
@@ -297,11 +299,10 @@ def valid_one_epoch(test_loader, model, device, args):
 if __name__ == "__main__":
-    tuner_params = nni.get_next_parameter()
-    logger.debug("Generated hyperparameters: {}", tuner_params)
-    combined_params = Namespace(
-        nni.utils.merge_parameter(ret_args, tuner_params)
-    ) # Namespaces have better ergonomics, notably a struct-like access syntax.
+    # tuner_params = nni.get_next_parameter()
+    # logger.debug("Generated hyperparameters: {}", tuner_params)
+    # combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
+    combined_params = args
     logger.debug("Parameters: {}", combined_params)
     if combined_params.use_ddp:
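
Note that the removed lines wrapped merge_parameter's result in Namespace(...), which is redundant (merge_parameter already returns the merged Namespace) and would in fact raise, since argparse.Namespace accepts only keyword arguments. For reference, a sketch of the disabled NNI path based on the commented-out lines, assuming ret_args is the parsed CLI Namespace:

    import nni
    from nni.utils import merge_parameter

    # Ask the NNI tuner for this trial's hyperparameters, then overlay
    # them onto the CLI defaults held in ret_args.
    tuner_params = nni.get_next_parameter()
    combined_params = merge_parameter(ret_args, tuner_params)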