
RuntimeError: one_hot is only applicable to index tensor. #454

Answered by MKaczkow
talhaanwarch asked this question in Q&A

For anyone looking for an answer: the ground-truth input tensor (y_true) must have type torch.LongTensor (or torch.cuda.LongTensor). If it doesn't, this line:

y_true = F.one_hot(y_true, num_classes)  # N,H*W -> N,H*W, C

from the smp.losses module fails. This is also described on Stack Overflow. It's unclear to me why LongTensor specifically is required by torch.nn.functional.one_hot (any integer type should do, I guess?), but the docs state that it is, and calling .long() on the target solves the problem.
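
A minimal sketch that reproduces the error and the `.long()` fix with plain PyTorch (smp itself isn't needed). The toy mask values, `num_classes = 3` and the variable names are made up for illustration; the mask is already flattened to the N,H*W shape that the line above expects.

```python
import torch
import torch.nn.functional as F

num_classes = 3

# Hypothetical ground-truth mask: batch N=2, flattened H*W=4 pixels.
# Masks loaded from images/numpy often end up as int32 or uint8,
# while the docs say one_hot expects a LongTensor (int64).
y_true = torch.tensor([[0, 1, 2, 1],
                       [2, 2, 0, 1]], dtype=torch.int32)

try:
    F.one_hot(y_true, num_classes)
except RuntimeError as err:
    # This is where "one_hot is only applicable to index tensor" shows up.
    print("int32 target failed:", err)

# Casting the target to int64 before encoding avoids the error.
y_true_long = y_true.long()
one_hot = F.one_hot(y_true_long, num_classes)  # N,H*W -> N,H*W,C
print(one_hot.shape)  # torch.Size([2, 4, 3])
```

In a training loop the same cast would simply be applied to the mask before it reaches the loss, e.g. `loss = loss_fn(logits, mask.long())`, where `loss_fn`, `logits` and `mask` are placeholders for your own objects.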

Answer selected by qubvel