FIX: valid_loader instead of test_loader

This commit is contained in:
Zhengyi Chen 2024-02-25 21:15:50 +00:00
parent 62df7464e4
commit b8f7f922f1

View file

@ -171,7 +171,7 @@ if __name__ == "__main__":
def visualize_stn(): def visualize_stn():
with torch.no_grad(): with torch.no_grad():
# Get a batch of training data # Get a batch of training data
data = next(iter(test_loader))[0].to(device) data = next(iter(valid_loader))[0].to(device)
input_tensor = data.cpu() input_tensor = data.cpu()
transformed_input_tensor = model.stn(data).cpu() transformed_input_tensor = model.stn(data).cpu()