You can use reshape and pointwise multiplication to reduce the operation to a matmul on two temporary tensors $c_1,c_2$. Consider the following transformations:
- pointwise multiplication to turn $a_i * b_i$ into $c_i$
- outer product to turn $a_i * b_j$ into $c_{ij}$
- reshape to turn $\sum_{ij} a_{ijk}$ into $\sum_{l} c_{lk}$ (merging the index pair $(i,j)$ into a single index $l$).
With these operations you can define temporary tensors $c_1,c_2$ such that the final einsum has the form $\sum_{j}(c_1)_{ij}(c_2)_{jk}$, i.e. a matmul.
Basically, out of the indices that remain uncontracted, you pick some of them to be the "left" indices and merge them with a reshape; that's your $i$ index. The rest are your "right" indices, i.e. $k$. Use another reshape to merge the contracted indices into $j$. This determines the dimensions of your final matmul.
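As a minimal sketch of this reduction (using NumPy; the einsum abj,jk->abk and all shapes here are illustrative, not one of the examples below):

import numpy as np

A = np.random.rand(2, 3, 4)   # indices a, b, j
B = np.random.rand(4, 5)      # indices j, k

# merge the uncontracted "left" indices a,b into a single index i,
# do one matmul over the contracted index j, then split i back into a,b
i = A.shape[0] * A.shape[1]
C = (A.reshape(i, 4) @ B).reshape(2, 3, 5)

assert np.allclose(C, np.einsum('abj,jk->abk', A, B))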
There's more than one way of doing this, and some are more efficient than others. Existing implementations use heuristics to map common einsums onto good matmul orders.
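For instance, NumPy exposes the contraction order its heuristic picks via np.einsum_path; a quick sketch with made-up size-2 operands:

import numpy as np

ops = [np.random.rand(2, 2) for _ in range(4)]
# the second return value is a human-readable summary of the chosen order
print(np.einsum_path('ij,kl,iq,kp->ljpq', *ops, optimize='greedy')[1])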
There's a Python package (pip install opt_einsum) which can compute an optimal schedule and print it out.
For instance, for the einsum $ij,kl,iq,kp \rightarrow ljpq$ where all dimensions have size 2, you can get the optimal sequence of BLAS operations as follows:
import opt_einsum as oe

einsum_string = "ij,kl,iq,kp->ljpq"
# build random test operands with every dimension set to 2
views = oe.helpers.build_views(einsum_string, {c: 2 for c in einsum_string})
# 'dp' is the exact dynamic-programming optimizer
path, path_info = oe.contract_path(einsum_string, *views, optimize='dp')
print(path_info)
You should see something like this:
Complete contraction: ij,kl,iq,kp->ljpq
Naive scaling: 6
Optimized scaling: 4
Naive FLOP count: 2.560e+2
Optimized FLOP count: 4.800e+1
Theoretical speedup: 5.333
Largest intermediate: 1.600e+1 elements
--------------------------------------------------------------------------------
scaling BLAS current remaining
--------------------------------------------------------------------------------
3 GEMM iq,ij->qj kl,kp,qj->ljpq
3 GEMM kp,kl->pl qj,pl->ljpq
4 OUTER/EINSUM pl,qj->ljpq ljpq->ljpq
Note that this einsum optimizer finds a schedule that uses two GEMMs followed by an outer product; that's better than the PyTorch/NumPy default, which is to use a single large GEMM.
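To actually execute the discovered schedule, you can hand the path back to opt_einsum; a sketch reusing einsum_string, views and path from the snippet above:

import numpy as np

# contract along the optimized path and check against a plain einsum
result = oe.contract(einsum_string, *views, optimize=path)
assert np.allclose(result, np.einsum(einsum_string, *views))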
See my Colab notebook for runnable examples.
For the theory behind optimal einsum computation, see my write-up on the Wolfram Community forums: Tensor networks, Einsum optimizers and the path towards autodiff.