
RossBencina
Can you suggest a good reference for understanding which algorithms map well onto the regular grid systolic arrays used by TPUs? The fine article says dense matmul and convolution are good, but is there anything else? Eigendecomposition? SVD? matrix exponential? Solving Ax = b or AX = B? Cholesky?

You can do all of these in terms of matmul to some extent:

Solving AX=B can be done with Newton's method to invert A, which boils down to matmuls.
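A minimal sketch of that idea in JAX (the Newton-Schulz/Hotelling-Bodewig iteration for the inverse; the starting guess and iteration count here are illustrative, not tuned):

    import jax.numpy as jnp

    def newton_inverse(A, iters=30):
        # Newton-Schulz iteration X <- X (2I - A X); converges
        # quadratically once X is close enough. X0 = A^T / (||A||_1 ||A||_inf)
        # is a standard safe starting guess.
        norm1 = jnp.max(jnp.sum(jnp.abs(A), axis=0))
        norminf = jnp.max(jnp.sum(jnp.abs(A), axis=1))
        X = A.T / (norm1 * norminf)
        I2 = 2.0 * jnp.eye(A.shape[0], dtype=A.dtype)
        for _ in range(iters):
            X = X @ (I2 - A @ X)  # two matmuls per step
        return X

    # Solving AX = B then reduces to one more matmul:
    # X = newton_inverse(A) @ B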

Matrix exponential is normally done with matmuls: the scale-down, Taylor/Padé, and square approach.
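For example (a hedged sketch; real implementations such as SciPy's expm pick the scaling and the Padé order from the norm of A rather than hard-coding them):

    import jax.numpy as jnp

    def expm_scale_square(A, squarings=6, terms=10):
        # exp(A) = exp(A / 2^s)^(2^s): scale down so the truncated
        # Taylor series is accurate, then square s times -- matmuls only.
        B = A / (2.0 ** squarings)
        E = jnp.eye(A.shape[0], dtype=A.dtype)
        term = jnp.eye(A.shape[0], dtype=A.dtype)
        for k in range(1, terms + 1):
            term = term @ B / k  # accumulates B^k / k!
            E = E + term
        for _ in range(squarings):
            E = E @ E
        return E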

Why do you need Cholesky? It's typically a means to an end, and when matmul is your primitive, you reach for it much less often.

Eigendecomposition is hard. If we limit ourselves to symmetric matrices, we can use a blocked Jacobi algorithm: run a non-matmul Jacobi to do 128x128 off-diagonal blocks, then use the matmul unit to apply the resulting rotations to the whole matrix. For large enough matrices this is still bottlenecked on matmul.
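A sketch of one such sweep (jnp.linalg.eigh on the small subproblem stands in for the scalar Jacobi kernel; convergence checks are omitted):

    import jax.numpy as jnp

    def blocked_jacobi_sweep(A, b=128):
        # One sweep over off-diagonal block pairs of a symmetric A.
        # The small 2b x 2b subproblem runs off the matmul unit; applying
        # its rotation to the full matrix is two skinny matmuls.
        nb = A.shape[0] // b
        for i in range(nb):
            for j in range(i + 1, nb):
                idx = jnp.concatenate([jnp.arange(i * b, (i + 1) * b),
                                       jnp.arange(j * b, (j + 1) * b)])
                _, V = jnp.linalg.eigh(A[jnp.ix_(idx, idx)])
                A = A.at[idx, :].set(V.T @ A[idx, :])
                A = A.at[:, idx].set(A[:, idx] @ V)
        return A  # repeat sweeps until off-diagonal blocks vanish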

SVD we can get from a polar decomposition, which has purely-matmul iterations, followed by a symmetric eigendecomposition.
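Concretely (a Newton-Schulz polar iteration; the input must be scaled so its singular values land in the convergence region, and eigh below again stands in for the blocked-Jacobi eigensolver):

    import jax.numpy as jnp

    def polar_orthogonal_factor(A, iters=30):
        # X <- X (3I - X^T X) / 2 converges to the orthogonal factor U
        # of A = U P when the singular values of X0 lie in (0, sqrt(3)).
        X = A / jnp.linalg.norm(A)  # Frobenius norm keeps sigma_max <= 1
        I = jnp.eye(A.shape[1], dtype=A.dtype)
        for _ in range(iters):
            X = 0.5 * X @ (3.0 * I - X.T @ X)
        return X

    def svd_via_polar(A):
        U_p = polar_orthogonal_factor(A)
        P = U_p.T @ A               # symmetric positive semi-definite
        w, V = jnp.linalg.eigh(P)   # symmetric eigendecomposition
        return U_p @ V, w, V        # A ~= (U_p V) diag(w) V^T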

One does have to watch out for numerical stability and precision very carefully when doing all these!

cdavid
SVD/eigendecomposition will often boil down to many matmuls (e.g. when using Krylov-based methods such as Arnoldi, Krylov-Schur, etc.), so I would expect TPUs to work well there. GMRES, one method for solving Ax = b, is also based on the Arnoldi decomposition.
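For reference, the core Arnoldi loop those methods build on (a plain single-vector sketch; blocked variants trade the matvec for a matmul, which is what the TPU actually wants):

    import jax.numpy as jnp

    def arnoldi(A, v0, m):
        # Build an orthonormal Krylov basis Q and Hessenberg H with
        # A @ Q[:, :m] = Q @ H.
        n = v0.shape[0]
        Q = jnp.zeros((n, m + 1), dtype=A.dtype)
        H = jnp.zeros((m + 1, m), dtype=A.dtype)
        Q = Q.at[:, 0].set(v0 / jnp.linalg.norm(v0))
        for j in range(m):
            w = A @ Q[:, j]              # the expensive product
            h = Q[:, :j + 1].T @ w       # Gram-Schmidt against the basis
            w = w - Q[:, :j + 1] @ h
            beta = jnp.linalg.norm(w)
            H = H.at[:j + 1, j].set(h).at[j + 1, j].set(beta)
            Q = Q.at[:, j + 1].set(w / beta)
        return Q, H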
musebox35
I think https://jax-ml.github.io/scaling-book/ is one of the best references to go through. It details how single-device and distributed computations map to TPU hardware features. The emphasis is on mapping transformer computations, both forwards and backwards, so it requires some familiarity with how transformer networks are structured.
WithinReason
Anything that you can express as 128x128 (but ideally much larger) dense matrix multiplication and nothing else
