You can retrieve all (S[i], T[j]) pairs using range and itertools.product:
>>> indices = torch.tensor(list(product(range(0, 3), range(3, 6))))
tensor([[0, 3],
[0, 4],
[0, 5],
[1, 3],
[1, 4],
[1, 5],
[2, 3],
[2, 4],
[2, 5]])
# indices.shape = (9, 2)
From there, we construct one-hot-encodings of the indices using torch.nn.functional.one_hot:
>>> mask = one_hot(indices).float()
tensor([[[1., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.]],
[[1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0.]],
[[1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1.]],
[[0., 1., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.]],
[[0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0.]],
[[0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1.]],
[[0., 0., 1., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.]],
[[0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 1., 0.]],
[[0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 1.]]])
# mask.shape = (9, 2, 6)
Finally, we compute the matrix multiplication and reshape it to the final form:
>>> (mask@ST).reshape(3, 3, 4, 1)
tensor([[[[0.7792],
[0.0095],
[1.0000],
[1.0000]],
[[0.7792],
[0.0095],
[1.0000],
[1.0000]],
[[0.7792],
[0.0095],
[1.0000],
[1.0000]]],
[[[0.1893],
[0.8159],
[1.0000],
[1.0000]],
[[0.1893],
[0.8159],
[1.0000],
[1.0000]],
[[0.1893],
[0.8159],
[1.0000],
[1.0000]]],
[[[0.0680],
[0.7194],
[1.0000],
[1.0000]],
[[0.0680],
[0.7194],
[1.0000],
[1.0000]],
[[0.0680],
[0.7194],
[1.0000],
[1.0000]]]])