diff --git a/train.py b/train.py index d556077..e79ff5a 100644 --- a/train.py +++ b/train.py @@ -335,10 +335,9 @@ def valid_one_epoch(test_loader, model, device, epoch, args): mse += diff ** 2 if i % 5 == 0: - if isinstance(model, STNet_VisionTransformerGAP): - with torch.no_grad(): - img_xformed = model.stnet(img).to("cpu") - xformed.append(img_xformed) + # with torch.no_grad(): + # img_xformed = model.stnet(img).to("cpu") + # xformed.append(img_xformed) print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format( fname[0], torch.sum(gt_count_whole).item(),