diff --git a/train.py b/train.py index 5dd5958..b79dc6e 100644 --- a/train.py +++ b/train.py @@ -148,7 +148,7 @@ def worker(rank: int, args: Namespace): model = model.cuda() # criterion, optimizer, scheduler - criterion = nn.L1Loss(size_average=False) + criterion = nn.L1Loss(reduction="sum") if device is not None: criterion = criterion.to(device) elif torch.cuda.is_available():