Debug

commit ab15419d2f (parent a9dd8dee04)

5 changed files with 63 additions and 19 deletions; only train.py (48 changed lines) is shown below.
@@ -106,13 +106,16 @@ def worker(rank: int, args: Namespace):
     if args.use_ddp and torch.cuda.is_available():
         device = torch.device(rank)
     elif torch.cuda.is_available():
-        device = torch.device(args.gpus)
+        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+        args.gpus = [int(gpu_id) for gpu_id in args.gpus.split(",")]
+        device = None
     elif torch.backends.mps.is_available():
         device = torch.device("mps")
     else:
         print("[!!!] Using CPU for inference. This will be slow...")
         device = torch.device("cpu")
-    torch.set_default_device(device)
+    if device is not None:
+        torch.set_default_device(device)

     # Prepare training data
     train_list, test_list = unpack_npy_data(args)
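As a reference for the hunk above, the selection order can be read as a small standalone helper. This is a minimal sketch of the pattern the commit moves toward, not code from the repository; the name resolve_device and the example arguments are illustrative.

import os
import torch


def resolve_device(use_ddp: bool, rank: int, gpus: str):
    # Illustrative helper mirroring the order used in worker():
    # DDP pins one device per rank; a plain multi-GPU run only narrows
    # CUDA_VISIBLE_DEVICES and returns None so later code can fall back
    # to .cuda(); otherwise try MPS, then CPU.
    if use_ddp and torch.cuda.is_available():
        return torch.device(rank)
    if torch.cuda.is_available():
        os.environ["CUDA_VISIBLE_DEVICES"] = gpus
        return None
    if torch.backends.mps.is_available():
        return torch.device("mps")
    print("[!!!] Using CPU for inference. This will be slow...")
    return torch.device("cpu")


# Only pin a default device when one was actually resolved.
device = resolve_device(use_ddp=False, rank=0, gpus="0,1")
if device is not None:
    torch.set_default_device(device)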
@@ -123,9 +126,9 @@ def worker(rank: int, args: Namespace):

     # Instantiate model
     if args.model == "stn":
-        model = stn_patch16_384_gap(args.pth_tar).to(device)
+        model = stn_patch16_384_gap(args.pth_tar)
     else:
-        model = base_patch16_384_gap(args.pth_tar).to(device)
+        model = base_patch16_384_gap(args.pth_tar)

     if args.use_ddp:
         model = nn.parallel.DistributedDataParallel(
@@ -140,8 +143,17 @@ def worker(rank: int, args: Namespace):
             device_ids=args.gpus
         )

+    if device is not None:
+        model = model.to(device)
+    elif torch.cuda.is_available():
+        model = model.cuda()
+
     # criterion, optimizer, scheduler
-    criterion = nn.L1Loss(size_average=False).to(device)
+    criterion = nn.L1Loss(size_average=False)
+    if device is not None:
+        criterion = criterion.to(device)
+    elif torch.cuda.is_available():
+        criterion = criterion.cuda()
     optimizer = torch.optim.Adam(
         [{"params": model.parameters(), "lr": args.lr}],
         lr=args.lr,
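The guard added here for the model is then repeated verbatim for the criterion. A hedged sketch of how that placement could be factored into one helper; the name place() is hypothetical and not introduced by this commit. (nn.L1Loss(size_average=False) is the deprecated spelling of nn.L1Loss(reduction="sum").)

import torch
import torch.nn as nn


def place(module: nn.Module, device):
    # Move to the explicitly resolved device when there is one,
    # otherwise fall back to .cuda() if any GPU is visible,
    # otherwise leave the module on CPU.
    if device is not None:
        return module.to(device)
    if torch.cuda.is_available():
        return module.cuda()
    return module


# Usage sketch with stand-in modules; the repository wraps its
# transformer model and the L1 criterion the same way.
device = None
model = place(nn.Linear(4, 1), device)
criterion = place(nn.L1Loss(reduction="sum"), device)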
@@ -184,7 +196,7 @@ def worker(rank: int, args: Namespace):
         end_train = time.time()

         # Validate
-        if epoch % 5 == 0:
+        if epoch % 5 == 0 or args.debug:
             prec1 = valid_one_epoch(test_loader, model, device, args)
             end_valid = time.time()
             is_best = prec1 < args.best_pred
@@ -232,8 +244,12 @@ def train_one_epoch(
         kpoint = kpoint.type(torch.FloatTensor)
         print("Training: img {} | kpoint {}".format(img.shape, kpoint.shape))
         # fpass
-        img = img.to(device)
-        kpoint = kpoint.to(device)
+        if device is not None:
+            img = img.to(device)
+            kpoint = kpoint.to(device)
+        elif torch.cuda.is_available():
+            img = img.cuda()
+            kpoint = kpoint.cuda()
         out, gt_count = model(img, kpoint)
         # gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)

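One hedged aside on the per-batch transfer above (not something this commit changes): when the DataLoader is built with pin_memory=True, the host-to-device copies can overlap compute by passing non_blocking=True to .to(). Sketch with toy tensors; the 384x384 shape is only an assumption based on the patch16_384 model names.

import torch

device = torch.device(0) if torch.cuda.is_available() else None
img = torch.rand(1, 3, 384, 384)    # toy batch standing in for the loader output
kpoint = torch.rand(1, 384, 384)

if device is not None:
    # non_blocking only helps when the source tensors sit in pinned memory.
    img = img.to(device, non_blocking=True)
    kpoint = kpoint.to(device, non_blocking=True)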
@@ -266,8 +282,14 @@ def valid_one_epoch(test_loader, model, device, args):
     visi = []
     index = 0

-    for i, (fname, img, gt_count) in enumerate(test_loader):
-        img = img.to(device)
+    for i, (fname, img, kpoint) in enumerate(test_loader):
+        if device is not None:
+            img = img.to(device)
+            kpoint = kpoint.to(device)
+        elif torch.cuda.is_available():
+            img = img.cuda()
+            kpoint = kpoint.cuda()
+
         # XXX: what do this do
         if len(img.shape) == 5:
             img = img.squeeze(0)
@@ -275,12 +297,12 @@ def valid_one_epoch(test_loader, model, device, args):
             img = img.unsqueeze(0)

         with torch.no_grad():
-            out = model(img)
+            out, gt_count = model(img, kpoint)
             count = torch.sum(out).item()

         gt_count = torch.sum(gt_count).item()
-        mae += abs(gt_count - count)
-        mse += abs(gt_count - count) ** 2
+        mae += abs(kpoint - count)
+        mse += abs(kpoint - count) ** 2

         if i % 15 == 0:
             print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
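Note that after this change the error terms subtract a scalar count from the kpoint tensor, while the removed lines compared two scalars. For reference, the usual crowd-counting convention accumulates per-image scalar counts and reports MAE and RMSE at the end; a minimal sketch, independent of the repository's model classes:

import math
import torch


def count_metrics(pairs):
    # pairs: iterable of (predicted_density_map, ground_truth_density_map)
    mae, mse, n = 0.0, 0.0, 0
    for out, gt in pairs:
        count = torch.sum(out).item()      # predicted crowd count
        gt_count = torch.sum(gt).item()    # ground-truth crowd count
        mae += abs(gt_count - count)
        mse += (gt_count - count) ** 2
        n += 1
    return mae / max(n, 1), math.sqrt(mse / max(n, 1))


# Usage sketch with toy density maps.
pairs = [(torch.rand(1, 64, 64), torch.rand(1, 64, 64)) for _ in range(4)]
mae, rmse = count_metrics(pairs)
print(f"[sketch] MAE {mae:.2f} | RMSE {rmse:.2f}")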