FIX: valid_loader instead of test_loader
This commit is contained in:
parent
62df7464e4
commit
b8f7f922f1
1 changed files with 1 additions and 1 deletions
|
|
@ -171,7 +171,7 @@ if __name__ == "__main__":
|
||||||
def visualize_stn():
|
def visualize_stn():
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Get a batch of training data
|
# Get a batch of training data
|
||||||
data = next(iter(test_loader))[0].to(device)
|
data = next(iter(valid_loader))[0].to(device)
|
||||||
|
|
||||||
input_tensor = data.cpu()
|
input_tensor = data.cpu()
|
||||||
transformed_input_tensor = model.stn(data).cpu()
|
transformed_input_tensor = model.stn(data).cpu()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue