FIX: valid_loader instead of test_loader
This commit is contained in:
parent
62df7464e4
commit
b8f7f922f1
1 changed files with 1 additions and 1 deletions
|
|
@ -171,7 +171,7 @@ if __name__ == "__main__":
|
|||
def visualize_stn():
|
||||
with torch.no_grad():
|
||||
# Get a batch of training data
|
||||
data = next(iter(test_loader))[0].to(device)
|
||||
data = next(iter(valid_loader))[0].to(device)
|
||||
|
||||
input_tensor = data.cpu()
|
||||
transformed_input_tensor = model.stn(data).cpu()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue