diff --git a/model/stn.py b/model/stn.py index e1e7199..0fcf786 100644 --- a/model/stn.py +++ b/model/stn.py @@ -171,7 +171,7 @@ if __name__ == "__main__": def visualize_stn(): with torch.no_grad(): # Get a batch of training data - data = next(iter(test_loader))[0].to(device) + data = next(iter(valid_loader))[0].to(device) input_tensor = data.cpu() transformed_input_tensor = model.stn(data).cpu()