
I noticed that libraries like numpy and pytorch are able to perform arbitrary tensor contractions at speeds similar to those of comparably sized matrix multiplications. This leads me to believe that, under the hood, they somehow express these operations in terms of BLAS routines.

However, I can't think of a general algorithm for doing so, especially for arbitrary tensor contractions. It is further complicated if one wants to avoid tensor transpositions. I'm curious how this is done.
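For concreteness, here is a minimal timing sketch of the kind of comparison I mean (the shapes are illustrative; both operations perform roughly the same number of multiply-adds):

import time
import numpy as np

A = np.random.rand(64, 64, 64)
B = np.random.rand(64, 64, 64)
M = np.random.rand(256, 256)
N = np.random.rand(256, 256)

t0 = time.perf_counter()
np.tensordot(A, B, axes=([1, 2], [0, 1]))   # contract two axis pairs, result is (64, 64)
t1 = time.perf_counter()
M @ N                                        # plain matmul with the same flop count
t2 = time.perf_counter()
print(f"tensordot: {t1 - t0:.4f} s   matmul: {t2 - t1:.4f} s")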

Anton Menshov
ilya
  • Are you referring to the einsum function in numpy? It operates by computing a sequence of basic tensor operations and optional intermediates required to perform the complex operation specified. You can have numpy tell you the operation sequence using einsum_path – helloworld922 Dec 30 '23 at 20:01
  • I'm referring to the numpy.tensordot function... it has a transpose in its source, but transposing into the new shape doesn't seem to add ANY computational overhead compared to a matmul, since numpy.tensordot takes about as long as a similarly sized matmul, so I'm wondering what the trick is – ilya Dec 30 '23 at 21:22
  • 3
    Numpy stores transposed arrays as a flag, so transpose is O(1). From there you just need an efficient dot product function that handles both memory layouts (see the quick check after these comments). – helloworld922 Dec 31 '23 at 02:18
  • Could you recommend any resources/write-ups where I could read more about how to do something like that? – ilya Dec 31 '23 at 11:11
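A quick check of the O(1)-transpose point (array shape chosen arbitrarily): transposing in numpy just swaps the strides and returns a view of the same buffer, so no data moves.

import numpy as np

a = np.random.rand(1000, 2000)
b = a.T                       # a view: no data is copied, only metadata changes
print(a.strides, b.strides)   # (16000, 8) vs (8, 16000): the strides are simply swapped
print(b.base is a)            # True: b reuses a's buffer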

1 Answer


You can use reshape and pointwise multiplication to reduce the operation to a matmul between two temporary tensors $c_1, c_2$. Consider the following transformations:

  1. pointwise mul to turn $a_i * b_i$ into $c_i$
  2. outer product to turn $a_i * b_j$ into $c_{ij}$
  3. reshape to turn $\sum_{ij} a_{ijk}$ into $\sum_{l}c_{lk}$.

With these operations you can define temporary tensors $a, b$ such that the final einsum is of the form $\sum_{j} a_{ij} b_{jk}$ (a matmul).

Basically, out of the indices that remain uncontracted, you select some of them to be "left" indices and merge them with a reshape; that's your $i$ index. The rest are your "right" indices, i.e. $k$. Use another reshape to merge the contracted indices into $j$. This determines the dimensions of your final matmul.
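As a minimal sketch of this reduction (shapes picked purely for illustration), the contraction einsum('ijk,jkl->il', A, B) becomes one matmul once the contracted indices j, k are merged into a single index of size 3*4 = 12:

import numpy as np

A = np.random.rand(2, 3, 4)   # i is the "left" index; j, k will be contracted
B = np.random.rand(3, 4, 5)   # j, k are contracted; l is the "right" index

lhs = A.reshape(2, 12)        # merge the contracted pair (j, k) into one index of size 12
rhs = B.reshape(12, 5)

C = lhs @ rhs                 # a single GEMM; C has shape (2, 5)
assert np.allclose(C, np.einsum('ijk,jkl->il', A, B))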

There's more than one way of doing this, and some are more efficient than others. Existing implementations use heuristics to pick an ordering so that common einsums map to a good matmul.

There's a Python package (pip install opt_einsum) which can compute an optimal schedule and print it out.

For instance, you can get the optimal sequence of BLAS operations as follows for the einsum ij,kl,iq,kp->ljpq, where all dimensions have size 2:

import opt_einsum as oe

# Build random operands matching the einsum string, with every index of size 2,
# then ask opt_einsum's dynamic-programming optimizer for a contraction path.
einsum_string = "ij,kl,iq,kp->ljpq"
views = oe.helpers.build_views(einsum_string, {c: 2 for c in einsum_string})
path, path_info = oe.contract_path(einsum_string, *views, optimize='dp')
print(path_info)

You should see something like this:

Complete contraction:  ij,kl,iq,kp->ljpq
         Naive scaling:  6
     Optimized scaling:  4
      Naive FLOP count:  2.560e+2
  Optimized FLOP count:  4.800e+1
   Theoretical speedup:  5.333
  Largest intermediate:  1.600e+1 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3           GEMM              iq,ij->qj                        kl,kp,qj->ljpq
   3           GEMM              kp,kl->pl                           qj,pl->ljpq
   4   OUTER/EINSUM            pl,qj->ljpq                            ljpq->ljpq

Note that this einsum optimizer finds a schedule that uses two GEMMs followed by an outer product, which is better than what PyTorch/numpy would do by default, namely a single GEMM.
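As a comparison point (a small illustrative snippet, not part of the schedule above), numpy's einsum_path reports which contraction order its own greedy optimizer would pick for the same expression:

import numpy as np

ops = [np.random.rand(2, 2) for _ in range(4)]
path, info = np.einsum_path('ij,kl,iq,kp->ljpq', *ops, optimize='greedy')
print(info)   # reports the scaling, flop count, and pairwise contraction order chosen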

See my colab notebook for runnable examples.

For explanations of the theory behind optimal einsum computation, see my write-up on Wolfram Community forums: Tensor networks, Einsum optimizers and the path towards autodiff

Yaroslav Bulatov
  • For tensor contractions like 'ijk,jl->ilk' it still schedules one tensordot call. Do you know how numpy.tensordot works internally in this case? – Vladimir Lysikov Jan 02 '24 at 18:00
  • @VladimirLysikov that's interesting, it writes "TDOT" under the "BLAS" column, but that's not a BLAS routine, is it? – Yaroslav Bulatov Jan 02 '24 at 18:27
  • No, it's not a BLAS routine. I assume it stands for tensordot; when I search "TDOT" in the documentation of opt_einsum, it does not give anything useful. – Vladimir Lysikov Jan 02 '24 at 18:30
  • OK, I'm assuming tensordot is a small variation of matmul, like batch-matmul. Note that you could transform this contraction to ijk,jlk->ilk by first adding a k dimension to the second argument (np.broadcast_to(b.reshape((2,2,1)), a.shape)). This new contraction is equivalent to doing k GEMM calls in parallel – Yaroslav Bulatov Jan 02 '24 at 18:34
  • Yes. It's also possible to first make a transpose ijk->kij and then use GEMM. It is interesting what Numpy actually does in this case. The internals are too confusing for me. – Vladimir Lysikov Jan 02 '24 at 18:38
  • @VladimirLysikov OK, from this change, it appears that's exactly the use-case for TDOT -- for operations that are one transpose away from matmul – Yaroslav Bulatov Jan 02 '24 at 18:52
  • Thank you for the reply :) I'm still a little confused about how you'd do those steps for a tensor contraction where the contracted dimensions aren't at the "inner" ends of the tensors (like (2,3,4) and (3,4,5)). For example, if I wanted to contract tensors with dimensions (2,3,5) and (2,7,11), or maybe (2,3,5) and (7,2,11), I don't understand how to reshape/flatten them to do it in a single GEMM without transposing. – ilya Jan 02 '24 at 20:44
  • @ilya you are allowed to transpose; an optional transpose is part of GEMM, see https://oneapi-src.github.io/oneMKL/domains/blas/gemm.html#onemkl-blas-gemm . Also, reshaping can merge arbitrary indices, not just adjacent ones (tensors are stored as linear arrays, so the shape just determines the strides you take) – Yaroslav Bulatov Jan 02 '24 at 20:59
  • @YaroslavBulatov but if I understand correctly, for numpy/PyTorch transposing isn't actually moving any data around, it's just a flag that they store, so it's an O(1) operation; they must be doing something special inside the GEMM step… do you by chance know what that is? – ilya Jan 02 '24 at 23:35
  • @ilya are you asking how GEMM implements transposed matmul? I don't know the details, but AB and AB^T should have similar (atomic) implementations inside GEMM; either operation can be done as a nested loop, just with different indexing. – Yaroslav Bulatov Jan 02 '24 at 23:56
  • @YaroslavBulatov I guess overall my question is more about the best way to keep track of the transpositions I'd have to remember later on (during the GEMM, or afterwards when I reshape the resulting matrix into the correct tensor), because when you merge/flatten non-neighbouring dimensions you have to transpose them back later, because otherwise it would be a different tensor... – ilya Jan 03 '24 at 01:13