From b8f7f922f10ca844839101a0636f06e8ef4dca78 Mon Sep 17 00:00:00 2001 From: rubberhead Date: Sun, 25 Feb 2024 21:15:50 +0000 Subject: [PATCH] FIX: valid_loader instead of test_loader --- model/stn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/stn.py b/model/stn.py index e1e7199..0fcf786 100644 --- a/model/stn.py +++ b/model/stn.py @@ -171,7 +171,7 @@ if __name__ == "__main__": def visualize_stn(): with torch.no_grad(): # Get a batch of training data - data = next(iter(test_loader))[0].to(device) + data = next(iter(valid_loader))[0].to(device) input_tensor = data.cpu() transformed_input_tensor = model.stn(data).cpu()