This commit is contained in:
Zhengyi Chen 2024-03-03 19:40:22 +00:00
parent a9dd8dee04
commit ab15419d2f
5 changed files with 63 additions and 19 deletions

View file

@ -106,13 +106,16 @@ def worker(rank: int, args: Namespace):
if args.use_ddp and torch.cuda.is_available():
device = torch.device(rank)
elif torch.cuda.is_available():
device = torch.device(args.gpus)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
args.gpus = [int(gpu_id) for gpu_id in args.gpus.split(",")]
device = None
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
print("[!!!] Using CPU for inference. This will be slow...")
device = torch.device("cpu")
torch.set_default_device(device)
if device is not None:
torch.set_default_device(device)
# Prepare training data
train_list, test_list = unpack_npy_data(args)
@ -123,9 +126,9 @@ def worker(rank: int, args: Namespace):
# Instantiate model
if args.model == "stn":
model = stn_patch16_384_gap(args.pth_tar).to(device)
model = stn_patch16_384_gap(args.pth_tar)
else:
model = base_patch16_384_gap(args.pth_tar).to(device)
model = base_patch16_384_gap(args.pth_tar)
if args.use_ddp:
model = nn.parallel.DistributedDataParallel(
@ -140,8 +143,17 @@ def worker(rank: int, args: Namespace):
device_ids=args.gpus
)
if device is not None:
model = model.to(device)
elif torch.cuda.is_available():
model = model.cuda()
# criterion, optimizer, scheduler
criterion = nn.L1Loss(size_average=False).to(device)
criterion = nn.L1Loss(size_average=False)
if device is not None:
criterion = criterion.to(device)
elif torch.cuda.is_available():
criterion = criterion.cuda()
optimizer = torch.optim.Adam(
[{"params": model.parameters(), "lr": args.lr}],
lr=args.lr,
@ -184,7 +196,7 @@ def worker(rank: int, args: Namespace):
end_train = time.time()
# Validate
if epoch % 5 == 0:
if epoch % 5 == 0 or args.debug:
prec1 = valid_one_epoch(test_loader, model, device, args)
end_valid = time.time()
is_best = prec1 < args.best_pred
@ -232,8 +244,12 @@ def train_one_epoch(
kpoint = kpoint.type(torch.FloatTensor)
print("Training: img {} | kpoint {}".format(img.shape, kpoint.shape))
# fpass
img = img.to(device)
kpoint = kpoint.to(device)
if device is not None:
img = img.to(device)
kpoint = kpoint.to(device)
elif torch.cuda.is_available():
img = img.cuda()
kpoint = kpoint.cuda()
out, gt_count = model(img, kpoint)
# gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
@ -266,8 +282,14 @@ def valid_one_epoch(test_loader, model, device, args):
visi = []
index = 0
for i, (fname, img, gt_count) in enumerate(test_loader):
img = img.to(device)
for i, (fname, img, kpoint) in enumerate(test_loader):
if device is not None:
img = img.to(device)
kpoint = kpoint.to(device)
elif torch.cuda.is_available():
img = img.cuda()
kpoint = kpoint.cuda()
# XXX: what do this do
if len(img.shape) == 5:
img = img.squeeze(0)
@ -275,12 +297,12 @@ def valid_one_epoch(test_loader, model, device, args):
img = img.unsqueeze(0)
with torch.no_grad():
out = model(img)
out, gt_count = model(img, kpoint)
count = torch.sum(out).item()
gt_count = torch.sum(gt_count).item()
mae += abs(gt_count - count)
mse += abs(gt_count - count) ** 2
mae += abs(kpoint - count)
mse += abs(kpoint - count) ** 2
if i % 15 == 0:
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(