I recently came across this exception while training a model with PyTorch and computing the loss with the negative log-likelihood criterion (nn.NLLLoss):
RuntimeError: multi-target not supported at /pytorch/aten/src/THNN/generic/ClassNLLCriterion.c:20
The exception occurred because I passed labels with the wrong dimensions to the criterion (NLLLoss). According to the documentation, if the input to NLLLoss has size N x C (where N is the batch size and C is the number of classes), the target (second argument) must have size N: one class index for each element of the batch, with every value satisfying 0 <= value < C.
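To see why the target needs exactly one class index per row, here is a minimal sketch that recomputes NLLLoss by hand. It assumes the default settings (reduction='mean', no class weights, no ignored labels): for each row i the loss picks the log-probability at column target[i], negates it, and averages over the batch. An N x C target has no single column to pick per row, which is what the error complains about.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

log_probs = nn.LogSoftmax(dim=1)(torch.randn(3, 5))  # N x C = 3 x 5
target = torch.tensor([0, 0, 4])                      # size N = 3, one class per row

# Default NLLLoss (reduction='mean', no weights): mean of -log_probs[i, target[i]]
picked = log_probs[torch.arange(log_probs.size(0)), target]
manual = -picked.mean()

builtin = nn.NLLLoss()(log_probs, target)
print(torch.allclose(manual, builtin))  # True
```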
# Valid code
import torch
import torch.nn as nn
m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size N x C = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
# Correct dimension: one class index for each row (size N)
target = torch.tensor([0, 0, 4])
output = loss(m(input), target)
output.backward()
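A cheap way to catch this mistake earlier is to validate the target before calling the loss. The sketch below uses a hypothetical helper, check_nll_target (not part of PyTorch), that checks the rules stated above for the N x C case: a 1-D long tensor of size N with values in [0, C). For simplicity it ignores NLLLoss's ignore_index option (default -100), so a target that deliberately uses that value would be rejected here.

```python
import torch

def check_nll_target(log_probs: torch.Tensor, target: torch.Tensor) -> None:
    """Hypothetical helper: validate NLLLoss inputs when log_probs is N x C.

    Does not account for NLLLoss's ignore_index (default -100).
    """
    n, c = log_probs.shape
    # Target must be 1-D with one entry per batch element
    if target.dim() != 1 or target.size(0) != n:
        raise ValueError(f"target must have shape ({n},), got {tuple(target.shape)}")
    # Class indices must be integers of dtype torch.long
    if target.dtype != torch.long:
        raise TypeError(f"target must be torch.long, got {target.dtype}")
    # Every index must satisfy 0 <= value < C
    if target.min() < 0 or target.max() >= c:
        raise ValueError(f"target values must satisfy 0 <= value < {c}")

check_nll_target(torch.randn(3, 5), torch.tensor([0, 0, 4]))  # passes silently
```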
# Error code
m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size N x C = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
# WRONG: target is an N x C matrix, the same size as the input (as with one-hot labels)
target = torch.randn(3, 5).long()
print(m(input).shape)  # torch.Size([3, 5])
print(target.shape)    # torch.Size([3, 5]), but NLLLoss expects torch.Size([3])
output = loss(m(input), target)  # raises here, so backward() is never reached
output.backward()
# Error thrown
RuntimeError: multi-target not supported at /pytorch/aten/src/THNN/generic/ClassNLLCriterion.c:20
# Newer PyTorch releases word it as: RuntimeError: 0D or 1D target tensor expected, multi-target not supported
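If your labels really are one-hot encoded (an N x C matrix), a common fix is to collapse each row to its class index with argmax before calling the loss. A minimal sketch, assuming every row of one_hot contains exactly one 1:

```python
import torch
import torch.nn as nn

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()

input = torch.randn(3, 5, requires_grad=True)  # N x C = 3 x 5
# One-hot labels for classes 0, 0 and 4: size N x C, which NLLLoss rejects
one_hot = torch.tensor([[1, 0, 0, 0, 0],
                        [1, 0, 0, 0, 0],
                        [0, 0, 0, 0, 1]])

target = one_hot.argmax(dim=1)  # tensor([0, 0, 4]), size N
output = loss(m(input), target)
output.backward()
print(target.shape, output.shape)  # torch.Size([3]) torch.Size([])
```

The argmax picks the column holding the 1 in each row, which is exactly the class index NLLLoss wants; the loss is then a scalar and backward() runs normally.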