
I have a 2D tensor with some nonzero elements in each row, like this:

import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)

I want a tensor containing the index of the first nonzero element in each row:

indices = tensor([[2],
                  [3]])

How can I calculate this in PyTorch?
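To pin down the expected output, and to have something to test the vectorized answers against, a plain Python loop is a simple (if slow) sketch. `first_nonzero_loop` is a hypothetical helper name, and returning -1 for an all-zero row is an assumed convention; the question does not say what should happen in that case.

```python
import torch

def first_nonzero_loop(t: torch.Tensor) -> torch.Tensor:
    """Reference version: index of the first nonzero entry in each row.

    Hypothetical helper; rows with no nonzero entry get -1 (assumed convention).
    """
    out = []
    for row in t:
        nz = torch.nonzero(row)  # shape (k, 1): column indices of the nonzeros
        out.append(nz[0, 0].item() if nz.numel() > 0 else -1)
    return torch.tensor(out).unsqueeze(1)  # shape (n_rows, 1), like `indices`

tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
print(first_nonzero_loop(tmp).tolist())  # [[2], [3]]
```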

Hichem BOUSSETTA
Iman Aliabdi

3 Answers


I simplified Iman's approach to the following:

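import torch
# Descending weights n, n-1, ..., 1: earlier columns get larger weights
idx = torch.arange(tmp.shape[1], 0, -1)
tmp2 = tmp * idx
# Assuming every nonzero entry is 1, the row maximum sits at the first nonzero
indices = torch.argmax(tmp2, 1, keepdim=True)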
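The weighting trick relies on every nonzero entry being 1 (or at least equal and positive). With a row such as `[0, 0.1, 5]`, the weighted values are `[0, 0.2, 5]` and argmax picks column 2 instead of column 1. A minimal sketch that weights a 0/1 mask instead of the raw values avoids this; the example tensor is made up for illustration.

```python
import torch

# Hypothetical example with unequal nonzero values
tmp = torch.tensor([[0, 0.1, 5.0],
                    [0, 3.0, 0.0]])

# Descending weights n, n-1, ..., 1 (same idea as above)
idx = torch.arange(tmp.shape[1], 0, -1)

# Weight a 0/1 mask instead of the raw values, so the magnitude or sign
# of the nonzero entries no longer matters
weighted = (tmp != 0).long() * idx
indices = torch.argmax(weighted, 1, keepdim=True)
print(indices.tolist())  # [[1], [1]]
```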
Ash
Shay

I found a somewhat tricky answer to my own question:

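import torch

tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
# Descending weights 7, 6, ..., 1 (one per column)
idx = torch.arange(tmp.shape[1], 0, -1, dtype=torch.float)
print(idx)

# Scale each column b by idx[b]
tmp2 = torch.einsum("ab,b->ab", tmp, idx)
print(tmp2)

# The largest weighted value in each row is at the first nonzero column
indices = torch.argmax(tmp2, 1, keepdim=True)
print(indices)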

The result is:

tensor([7., 6., 5., 4., 3., 2., 1.])
tensor([[0., 0., 5., 0., 3., 0., 0.],
       [0., 0., 0., 4., 3., 0., 0.]])
tensor([[2],
        [3]])
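The einsum `"ab,b->ab"` does nothing more than scale column `b` of every row by `idx[b]` (no index is summed away), which is ordinary broadcasting. A quick check, reusing the question's tensor and assuming float weights:

```python
import torch

tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
idx = torch.arange(tmp.shape[1], 0, -1, dtype=torch.float)  # 7., 6., ..., 1.

# "ab,b->ab": output[a, b] = tmp[a, b] * idx[b], with no summation
via_einsum = torch.einsum("ab,b->ab", tmp, idx)
via_broadcast = tmp * idx  # idx (shape [7]) broadcasts across the rows

print(torch.equal(via_einsum, via_broadcast))  # True
```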
Iman Aliabdi

All the nonzero values are equal, so they all tie for the row maximum, and argmax returns the index of the first maximal value in each row.

import torch

tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]])
indices = tmp.argmax(1)  # tensor([2, 3]); ties resolve to the first maximum
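The same tie-breaking rule gives a version that works for arbitrary nonzero values: convert the tensor to a 0/1 mask first, so every nonzero entry becomes an equal maximum. One catch is that an all-zero row also yields 0. The sketch below marks such rows with -1, which is an assumed convention (any sentinel would do); `first_nonzero` is a hypothetical helper name.

```python
import torch

def first_nonzero(t: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Index of the first nonzero entry along `dim`, or -1 if there is none."""
    nonzero = t != 0                        # bool mask
    idx = nonzero.long().argmax(dim=dim)    # ties -> first 1 in each row
    has_nonzero = nonzero.any(dim=dim)      # False for all-zero rows
    return torch.where(has_nonzero, idx, torch.full_like(idx, -1))

tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0],
                    [0, -2, 0, 7, 0, 0, 0]], dtype=torch.float)
print(first_nonzero(tmp).tolist())  # [2, 3, -1, 1]
```

Call `.unsqueeze(1)` on the result if you need the `(n_rows, 1)` shape from the question.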
Seppo Enarvi