Discovering faster matrix multiplication algorithms with reinforcement learning

TensorGame

TensorGame is played as follows. The start position \({\mathscr{S}}_{0}\) of the game corresponds to the tensor \({\mathscr{T}}\) representing the bilinear operation of interest, expressed in some basis. In each step t of the game, the player writes down three vectors \(({\bf{u}}^{(t)},{\bf{v}}^{(t)},{\bf{w}}^{(t)})\), which specify the rank-1 tensor \({\bf{u}}^{(t)}\otimes {\bf{v}}^{(t)}\otimes {\bf{w}}^{(t)}\), and the state of the game is updated by subtracting the newly written down factor:

$${\mathscr{S}}_{t}\leftarrow {\mathscr{S}}_{t-1}-{\bf{u}}^{(t)}\otimes {\bf{v}}^{(t)}\otimes {\bf{w}}^{(t)}.$$
(2)

The game ends when the state reaches the zero tensor, \({\mathscr{S}}_{R}={\bf{0}}\). This means that the factors written down throughout the game form a factorization of the start tensor \({\mathscr{S}}_{0}\), that is, \({\mathscr{S}}_{0}={\sum }_{t=1}^{R}{\bf{u}}^{(t)}\otimes {\bf{v}}^{(t)}\otimes {\bf{w}}^{(t)}\).
(3)
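The game dynamics can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the helper names `matmul_tensor` and `play_tensor_game` and the row-major indexing of matrix entries are assumptions made for illustration. The moves are Strassen's seven products for 2 × 2 matrices, and playing them from the start position should reach the zero tensor.

```python
import numpy as np


def matmul_tensor(n):
    """Return T_n with T[a, b, c] = 1 iff a_ik * b_kj contributes to c_ij.

    Entries are indexed row-major (an assumed convention): a_ik -> i*n + k.
    """
    tensor = np.zeros((n * n, n * n, n * n), dtype=np.int64)
    for i in range(n):
        for j in range(n):
            for k in range(n):
                tensor[i * n + k, k * n + j, i * n + j] = 1
    return tensor


def play_tensor_game(start, actions):
    """Subtract one rank-1 tensor u (x) v (x) w per move; return the final state."""
    state = start.copy()
    for u, v, w in actions:
        state -= np.einsum("i,j,k->ijk", u, v, w)
    return state


# Strassen's seven products, one row per move; columns are [x11, x12, x21, x22].
U = np.array([[1, 0, 0, 1], [0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 0, 1],
              [1, 1, 0, 0], [-1, 0, 1, 0], [0, 1, 0, -1]])
V = np.array([[1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, -1], [-1, 0, 1, 0],
              [0, 0, 0, 1], [1, 1, 0, 0], [0, 0, 1, 1]])
W = np.array([[1, 0, 0, 1], [0, 0, 1, -1], [0, 1, 0, 1], [1, 0, 1, 0],
              [-1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]])

actions = list(zip(U, V, W))
final_state = play_tensor_game(matmul_tensor(2), actions)
game_solved = not final_state.any()  # True when the zero tensor is reached
```

The number of moves played (seven here) is the rank of the factorization, so shorter games correspond to algorithms with fewer scalar multiplications.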

Hence, exhibiting a rank-R decomposition of the matrix multiplication tensor \({\mathscr{T}}_{n}\) expressed in any basis proves that the product of two n × n matrices can be computed using R scalar multiplications. Moreover, it is straightforward to convert such a rank-R decomposition into a rank-R decomposition in the canonical basis, thus yielding a practical algorithm of the form shown in Algorithm 1. We leverage this observation by expressing the matrix multiplication tensor \({\mathscr{T}}_{n}\) in a large number of randomly generated bases (typically 100,000) in addition to the canonical basis, and letting AlphaTensor play games in all bases in parallel.

This approach has three appealing properties: (1) it provides a natural exploration mechanism as playing games in different bases automatically injects diversity into the games played by the agent; (2) it exploits properties of the problem as the agent need not succeed in all bases—it is sufficient to find a low-rank decomposition in any of the bases; (3) it enlarges coverage of the algorithm space because a decomposition with entries in a finite set F = {−2, −1, 0, 1, 2} found in a different basis need not have entries in the same set when converted back into the canonical basis.

In full generality, a basis change for a 3D tensor of size S × S × S is specified by three invertible S × S matrices A, B and C. However, in our procedure, we sample bases at random and impose two restrictions: (1) A = B = C, as this performed better in early experiments, and (2) unimodularity (\(\det {\bf{A}}\in \{-1,+1\}\)), which ensures that after converting an integral factorization into the canonical basis it still contains integer entries only (this is for representational convenience and numerical stability of the resulting algorithm). See Supplementary Information for the exact algorithm.
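The effect of a unimodular basis change can be sketched as follows. The sampling scheme here (a product of unit lower-triangular and unit upper-triangular integer matrices, which has determinant 1 by construction) is an assumption chosen for simplicity and is not the algorithm from the Supplementary Information. The helper names are hypothetical.

```python
import numpy as np


def sample_unimodular(size, rng):
    """Assumed sampler: unit lower-triangular times unit upper-triangular."""
    identity = np.eye(size, dtype=np.int64)
    lower = np.tril(rng.integers(-1, 2, size=(size, size)), k=-1) + identity
    upper = np.triu(rng.integers(-1, 2, size=(size, size)), k=1) + identity
    return lower @ upper


def change_basis(tensor, a):
    """Apply the same matrix to all three modes (the A = B = C restriction)."""
    return np.einsum("ia,jb,kc,abc->ijk", a, a, a, tensor)


def reconstruct(factors):
    """Sum the rank-1 tensors u (x) v (x) w of a factorization."""
    return sum(np.einsum("i,j,k->ijk", u, v, w) for u, v, w in factors)


rng = np.random.default_rng(0)
size = 4

# A toy integer tensor given by two rank-1 terms (not a matmul tensor).
factors = [tuple(rng.integers(-2, 3, size=size) for _ in range(3))
           for _ in range(2)]
tensor = reconstruct(factors)

basis = sample_unimodular(size, rng)
# The inverse of a unimodular integer matrix is an integer matrix, so
# rounding the floating-point inverse only removes numerical noise.
basis_inv = np.rint(np.linalg.inv(basis)).astype(np.int64)

transformed = change_basis(tensor, basis)
factors_in_new_basis = [(basis @ u, basis @ v, basis @ w) for u, v, w in factors]
# Converting a factorization back to the canonical basis keeps integer entries.
factors_back = [(basis_inv @ u, basis_inv @ v, basis_inv @ w)
                for u, v, w in factors_in_new_basis]
```

Because the factors transform mode by mode, a rank-R factorization in one basis maps to a rank-R factorization in any other, which is why success in a single basis suffices.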

Signed permutations

In addition to playing (and training on) games in different bases, we also utilize a data augmentation mechanism whenever the neural network is queried in a new MCTS node. At acting time, when the network is queried, we transform the input tensor by applying a change of basis—where the change of basis matrix is set to a random signed permutation. We then query the network on this transformed input tensor, and finally invert the transformation in the network’s policy predictions. Although this data augmentation procedure can be applied with any generic change of basis matrix (that is, it is not restricted to signed permutation matrices), we use signed permutations mainly for computational efficiency. At training time, whenever the neural network is trained on an (input, policy targets, value target) triplet (Fig. 2), we apply a randomly chosen signed permutation to both the input and the policy targets, and train the network on this transformed triplet. In practice, we sample 100 signed permutations at the beginning of an experiment, and use them thereafter.
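The acting-time augmentation described above can be sketched in NumPy. The network itself is not modelled: `predicted` is a stand-in for a move expressed in the transformed frame, and all function names are hypothetical. The sketch relies on the fact that a signed permutation matrix is orthogonal, so its inverse is its transpose.

```python
import numpy as np


def random_signed_permutation(size, rng):
    """One nonzero entry (+1 or -1) per row and per column."""
    matrix = np.zeros((size, size), dtype=np.int64)
    matrix[np.arange(size), rng.permutation(size)] = rng.choice([-1, 1], size=size)
    return matrix


def transform_tensor(tensor, p):
    """Change of basis applied to all three modes of the state."""
    return np.einsum("ia,jb,kc,abc->ijk", p, p, p, tensor)


def transform_action(action, p):
    """Apply the same change of basis to each factor of a move."""
    return tuple(p @ x for x in action)


def outer(action):
    u, v, w = action
    return np.einsum("i,j,k->ijk", u, v, w)


rng = np.random.default_rng(0)
size = 4
state = rng.integers(-2, 3, size=(size, size, size))
action = tuple(rng.integers(-2, 3, size=size) for _ in range(3))
p = random_signed_permutation(size, rng)

transformed_state = transform_tensor(state, p)
# Stand-in for the network output: the same move seen in the transformed frame.
predicted = transform_action(action, p)
# Invert the transformation on the prediction using the transpose.
restored = transform_action(predicted, p.T)
```

Sampling a fixed pool of such matrices once, as the text describes, amounts to calling `random_signed_permutation` 100 times at start-up and drawing from that list afterwards.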

Action canonicalization

For any λ1, λ2, λ3 ∈ {−1, +1} such that λ1λ2λ3 = 1, the actions (λ1u, λ2v, λ3w) and (u, v, w) are equivalent because they lead to the same rank-one tensor (λ1u) ⊗ (λ2v) ⊗ (λ3w) = u ⊗ v ⊗ w. To prevent the network from wasting capacity on predicting multiple equivalent actions, during training we always present targets (u, v, w) for the policy head in a canonical form, defined as having the first non-zero element of u and the first non-zero element of v strictly positive. This is well defined because neither u nor v can be all zeros (if they are to be part of a minimal rank decomposition), and for any (u, v, w) there are unique λ1, λ2, λ3 ∈ {−1, +1} (with λ1λ2λ3 = 1) that transform it into canonical form. If the network nevertheless predicts multiple equivalent actions, we merge them (summing their empirical policy probabilities) before inserting them into the MCTS tree.
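The canonical form and the merging step can be written as a short sketch. The function names `canonicalize` and `merge_equivalent` are hypothetical, and the dictionary-based merge is one possible way to sum probabilities of equivalent actions rather than the authors' code.

```python
import numpy as np


def first_nonzero_sign(x):
    """Sign (+1 or -1) of the first non-zero entry of x."""
    nonzero = np.flatnonzero(x)
    if nonzero.size == 0:
        raise ValueError("factor must not be all zeros")
    return int(np.sign(x[nonzero[0]]))


def canonicalize(u, v, w):
    """Flip signs so the first non-zero entries of u and v are positive."""
    lam1 = first_nonzero_sign(u)
    lam2 = first_nonzero_sign(v)
    lam3 = lam1 * lam2  # forces lam1 * lam2 * lam3 == 1
    return lam1 * u, lam2 * v, lam3 * w


def merge_equivalent(actions, probabilities):
    """Sum the probabilities of actions that share a canonical form."""
    merged = {}
    for action, prob in zip(actions, probabilities):
        key = tuple(tuple(int(x) for x in part) for part in canonicalize(*action))
        merged[key] = merged.get(key, 0.0) + prob
    return merged


u = np.array([0, -1, 1])
v = np.array([1, 0, 2])
w = np.array([1, 1, 0])
other = (np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1]))

# (u, v, w) and (-u, v, -w) describe the same rank-one tensor.
merged = merge_equivalent([(u, v, w), (-u, v, -w), other], [0.5, 0.25, 0.25])
```

Once λ1 and λ2 are fixed by the signs of u and v, the constraint λ1λ2λ3 = 1 determines λ3, which is why the canonical form is unique.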

Training regime

We train AlphaTensor on a TPU v3, with a total batch size of 2,048. We use 64 TPU cores, and train for 600,000 iterations. On the actor side, the games are played on standalone TPU v4, and we use 1,600 actors. In practice, the procedure takes a week to converge.

Neural network

The architecture is composed of a torso, followed by a policy head that predicts a distribution over actions, and a value head that predicts a distribution of the returns from the current state (see Extended Data Fig. 3).

Input

The input to the network contains all the relevant information of the current state and is composed of a list of tensors and a list of scalars. The most important piece of information is the current 3D tensor \({\mathscr{S}}_{t}\) of size S × S × S. (For simplicity, in the description here we assume that all three dimensions of the tensor are equal in size. The generalization to different sizes is straightforward.) In addition, the model is given access to the last h actions (h being a hyperparameter usually set to 7), represented as h rank-1 tensors that are concatenated to the input. The list of scalars includes the time index t of the current action (where 0 ≤ t < R_limit).

Torso

The torso of the network is in charge of mapping both scalars and tensors from the input to a representation that is useful to both policy and value heads. Its architecture is based on a modification of transformers (ref. 23), and its main signature is that it operates over three S × S grids projected from the S × S × S input tensors. Each grid represents two out of the three modes of the tensor. Defining the modes of the tensor as \({\mathcal{U}},{\mathcal{V}},{\mathcal{W}}\), the rows and columns of the first grid are associated with \({\mathcal{U}}\) and \({\mathcal{V}}\), respectively, the rows and columns of the second grid are associated with \({\mathcal{W}}\) and \({\mathcal{U}}\), and the rows and columns of the third grid are associated with \({\mathcal{V}}\) and \({\mathcal{W}}\). Each element of each grid is a feature vector, and its initial value is given by the elements of the input tensors along the grid’s missing mode. These feature vectors are enriched by concatenating an S × S × 1 linear projection from the scalars. This is followed by a linear layer projecting these feature vectors into a 512-dimensional space.
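The input construction and the projection onto three grids can be sketched at the level of array shapes. This is an illustration under assumptions: the ordering of features within each grid cell, the use of a single shared projection matrix, and the random untrained weights are placeholders, and the function names are hypothetical.

```python
import numpy as np


def build_input(state, last_actions):
    """Stack the current state with the last h actions as rank-1 tensors."""
    history = [np.einsum("i,j,k->ijk", u, v, w) for u, v, w in last_actions]
    return np.stack([state] + history)  # shape (h + 1, S, S, S)


def project_to_grids(x):
    """Map (T, S, S, S) tensors, modes (U, V, W), to three S x S grids.

    Each grid cell holds the entries along the grid's missing mode,
    for all T input tensors, flattened into one feature vector.
    """
    t, s = x.shape[0], x.shape[1]
    grid_uv = np.transpose(x, (1, 2, 3, 0)).reshape(s, s, s * t)  # missing W
    grid_wu = np.transpose(x, (3, 1, 2, 0)).reshape(s, s, s * t)  # missing V
    grid_vw = np.transpose(x, (2, 3, 1, 0)).reshape(s, s, s * t)  # missing U
    return grid_uv, grid_wu, grid_vw


rng = np.random.default_rng(0)
size, history_len, width = 4, 7, 512
state = rng.integers(-2, 3, size=(size, size, size))
last_actions = [tuple(rng.integers(-2, 3, size=size) for _ in range(3))
                for _ in range(history_len)]
scalars = np.array([3.0])  # for example, the time index t

x = build_input(state, last_actions)
grids = project_to_grids(x)

# Placeholder (untrained) weights: scalars -> S x S x 1 plane, features -> 512.
scalar_proj = rng.normal(size=(scalars.size, size * size))
scalar_plane = (scalars @ scalar_proj).reshape(size, size, 1)
feature_proj = rng.normal(size=(grids[0].shape[-1] + 1, width))
embedded = [np.concatenate([g, scalar_plane], axis=-1) @ feature_proj
            for g in grids]
```

With S = 4 and h = 7, each grid cell starts with (h + 1) · S = 32 features, gains one from the scalar plane, and is then mapped to 512 dimensions.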

The rest of the torso is a sequence of attention-based blocks with the objective of propagating information between the three grids. Each of those blocks has three stages, one for every pair of grids. In each stage, the grids involved are concatenated, and axial attention (ref. 24) is performed over the columns. Note that in each stage we perform S self-attention operations in parallel, each over 2S elements. The representation sent to the policy head corresponds to the 3S² 512-dimensional feature vectors produced by the last layer of the torso. A detailed description of the structure of the torso is specified in Extended Data Fig. 4 (top) and Appendix A.1.1 in Supplementary Information.
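One attention stage over a pair of grids can be sketched as below. This is a loose illustration rather than the architecture in Appendix A.1.1: it uses a single attention head, random untrained weights, a plain residual connection, and it assumes that the second grid is transposed before concatenation so that the shared mode lines up. All names are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def self_attention(x, wq, wk, wv):
    """Single-head self-attention over axis 1 of x with shape (S, L, c)."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v


def attention_stage(grid_a, grid_b, wq, wk, wv):
    """Concatenate two (S, S, c) grids and attend within each of the S rows."""
    s = grid_a.shape[0]
    # Assumption: grid_b is transposed so both halves share their row mode.
    joint = np.concatenate([grid_a, np.swapaxes(grid_b, 0, 1)], axis=1)
    out = joint + self_attention(joint, wq, wk, wv)  # S sequences of 2S elements
    return out[:, :s], np.swapaxes(out[:, s:], 0, 1)


rng = np.random.default_rng(0)
size, channels = 4, 8  # the text uses 512 channels; 8 keeps the example small
grid_a = rng.normal(size=(size, size, channels))
grid_b = rng.normal(size=(size, size, channels))
wq, wk, wv = (rng.normal(size=(channels, channels)) for _ in range(3))

new_a, new_b = attention_stage(grid_a, grid_b, wq, wk, wv)
```

The leading axis of `joint` acts as a batch dimension, which is what makes the S attention operations independent and parallel.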

Policy head

The policy head uses the transformer architecture (ref. 23) to model an autoregressive policy. Factors are decomposed into k tokens of dimensionality d such that k × d = 3S. The transformer conditions on the tokens already generated and cross-attends to the features produced by the torso. At training time, we use teacher-forcing, that is, the ground truth actions are decomposed into tokens and taken as inputs into the causal transformer in such a way that the prediction of a token depends only on the previous tokens. At inference time, K actions are sampled from the head. The feature representation before the last linear layer of the initial step (that is, the only step that is not conditioned on the ground truth) is used as an input to the value head, described below. Details of the architecture are presented in Extended Data Fig. 4 (centre) and Appendix A.1.2 in Supplementary Information.
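The decomposition of an action into k tokens of dimensionality d can be illustrated as follows. The integer encoding of each token (reading its d entries as digits in base 5 for entries in {−2, …, 2}) is an assumption for illustration, as are the function names; the source only states that k × d = 3S.

```python
import numpy as np


def action_to_tokens(u, v, w, d, low=-2, high=2):
    """Split the concatenated factors (length 3S) into k = 3S / d token ids."""
    flat = np.concatenate([u, v, w])
    if flat.size % d != 0:
        raise ValueError("d must divide 3S")
    base = high - low + 1
    digits = flat.reshape(-1, d) - low   # shift entries into {0, ..., base - 1}
    weights = base ** np.arange(d)       # assumed positional encoding
    return digits @ weights


def tokens_to_action(tokens, d, size, low=-2, high=2):
    """Invert action_to_tokens and split the result back into (u, v, w)."""
    base = high - low + 1
    digits = (tokens[:, None] // base ** np.arange(d)) % base + low
    flat = digits.reshape(-1)
    return flat[:size], flat[size:2 * size], flat[2 * size:]


size, d = 4, 3                           # 3S = 12 entries, so k = 4 tokens
u = np.array([1, 0, -1, 2])
v = np.array([0, 0, 1, -2])
w = np.array([2, -1, 0, 1])

tokens = action_to_tokens(u, v, w, d)
decoded = tokens_to_action(tokens, d, size)
```

A larger d shortens the autoregressive sequence but enlarges the per-token vocabulary (here 5^d ids), which is the trade-off behind choosing k and d.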

Value head

The value head is composed of a four-layer multilayer perceptron whose last layer produces q outputs corresponding to the \(\frac{1}{2q},\frac{3}{2q},\ldots ,\frac{2q-1}{2q}\) quantiles. In this way, the value head predicts the distribution of returns from this state in the form of values predicted for the aforementioned quantiles (ref. 34). At inference time, we encourage the agent to be risk-seeking by using the average of the predicted values for quantiles over 75%. A detailed description of the value head is presented in Extended Data Fig. 4 (bottom) and Appendix A.1.3 in Supplementary Information.
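The risk-seeking readout can be sketched in a few lines. The function names are hypothetical, the quantile values are made up for the example, and treating "over 75%" as a strict inequality on the quantile level is an assumption.

```python
import numpy as np


def quantile_levels(q):
    """Levels 1/(2q), 3/(2q), ..., (2q - 1)/(2q)."""
    return (2 * np.arange(q) + 1) / (2 * q)


def risk_seeking_value(quantile_values, threshold=0.75):
    """Average the predicted values whose quantile level exceeds the threshold."""
    values = np.asarray(quantile_values, dtype=float)
    levels = quantile_levels(values.size)
    return float(values[levels > threshold].mean())


# Made-up predictions for q = 8 quantiles of the return (negative rank).
predicted = [-12.0, -11.0, -10.0, -9.0, -8.0, -7.0, -6.0, -5.0]
optimistic = risk_seeking_value(predicted)
mean_value = float(np.mean(predicted))
```

For q = 8 only the levels 13/16 and 15/16 exceed 0.75, so the readout averages the two most optimistic predictions and is higher than the plain mean.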

Related work

The quest for efficient matrix multiplication algorithms started with Strassen’s breakthrough in ref. 2, which showed that one can multiply 2 × 2 matrices using 7 scalar multiplications, leading to an algorithm of complexity \({\mathcal{O}}({n}^{2.81})\). This led to the development of a very active field of mathematics attracting worldwide interest, which studies the asymptotic complexity of matrix multiplication (see refs. 3,4,5,6). So far, the best known complexity for matrix multiplication is \({\mathcal{O}}({n}^{2.37286})\) (ref. 12), which improves over ref. 11, and builds on top of fundamental results in the field (refs. 8,9,10). However, this does not yield practical algorithms, as such approaches become advantageous only for astronomical matrix sizes. Hence, a significant body of work aims at exhibiting explicit factorizations of matrix multiplication tensors, as these factorizations provide practical algorithms. After Strassen’s breakthrough showing that \(\text{rank}\,({\mathscr{T}}_{2})\le 7\), efficient algorithms for larger matrix sizes were found (refs. 15,16,18,26,38). Most notably, Laderman showed in ref. 15 that 3 × 3 matrix multiplications can be performed with 23 scalar multiplications. Beyond exhibiting individual low-rank factorizations, an important research direction aims at understanding the space of matrix multiplication algorithms by studying the symmetry groups and diversity of factorizations (see ref. 5 and references therein). For example, the symmetries of 2 × 2 matrix multiplication were studied in refs. 39,40,41,42, where Strassen’s algorithm was shown to be essentially unique. The case of 3 × 3 was studied in ref. 43, whereas a symmetric factorization for all n is provided in ref. 44.

On the computational front, continuous optimization has been the main workhorse for decomposing tensors (refs. 17,45,46), and in particular matrix multiplication tensors. Such continuous optimization procedures (for example, alternating least squares), however, yield approximate solutions, which correspond to inexact matrix multiplication algorithms with floating point operations. To circumvent this issue, regularization procedures have been proposed, such as ref. 18, to extract exact decompositions. Unfortunately, such approaches often require substantial human intervention and expertise to decompose large tensors. A different line of attack was explored in refs. 47,48, based on learning the continuous weights of a two-layer network that mimics the structure of the matrix multiplication operation. This method, which is trained through supervised learning of matrix multiplication examples, finds approximate solutions to 2 × 2 and 3 × 3 matrix multiplications. In ref. 48, a quantization procedure is further used to obtain an exact decomposition for 2 × 2. Unlike continuous optimization-based approaches, AlphaTensor directly produces algorithms from the desired set of valid algorithms, and is flexible in that it allows us to optimize a wide range of (even non-differentiable) objectives. This unlocks tackling broader settings (for example, optimization in finite fields, optimization of runtime), as well as larger problems (for example, \({\mathscr{T}}_{4}\) and \({\mathscr{T}}_{5}\)) than those previously considered. Different from continuous optimization, a Boolean satisfiability (SAT) based formulation of the problem of decomposing 3 × 3 matrix multiplication was recently proposed in ref. 20, which adds thousands of new decompositions of rank 23 to the list of known 3 × 3 factorizations. The approach relies on a state-of-the-art SAT solving procedure, where several assumptions and simplifications are made on the factorizations to reduce the search space. As is, this approach is, however, unlikely to scale to larger tensors, as the search space grows very quickly with the size.

On the practical implementation front, ref. 31 proposed several ideas to speed up implementation of fast matrix multiplication algorithms on central processing units (CPUs). Different fast algorithms are then compared and benchmarked, and the potential speed-up of such algorithms is shown against standard multiplication. Other works focused on getting the maximal performance out of a particular fast matrix multiplication algorithm (Strassen’s algorithm with one or two levels of recursion) on a CPU (ref. 32) or a GPU (ref. 49). These works show that, despite popular belief, such algorithms are of practical value. We see writing a custom low-level implementation of a given algorithm as distinct from the focus of this paper, which is developing new efficient algorithms, and we believe that the algorithms we discovered can further benefit from a more efficient implementation by experts.

Beyond matrix multiplication and bilinear operations, a growing amount of research studies the use of optimization and machine learning to improve the efficiency of computational operations. There are three levels of abstraction at which this can be done: (1) in the hardware design, for example, chip floor planning (ref. 50), (2) at the hardware–software interface, for example, program super-optimization of a reference implementation for specific hardware (ref. 51), and (3) on the algorithmic level, for example, program induction (ref. 52), algorithm selection (ref. 53) or meta-learning (ref. 54). Our work focuses on the algorithmic level of abstraction, although AlphaTensor is also flexible enough to discover efficient algorithms for specific hardware. Different from previous works, we focus on discovering matrix multiplication algorithms that are provably correct, without requiring initial reference implementations. We conclude by relating our work broadly to existing reinforcement learning methods for scientific discovery. Within mathematics, reinforcement learning was applied, for example, to theorem proving (refs. 55,56,57,58), and to finding counterexamples refuting conjectures in combinatorics and graph theory (ref. 59). Reinforcement learning was further shown to be useful in many areas in science, such as molecular design (refs. 60,61) and synthesis (ref. 62) and optimizing quantum dynamics (ref. 63).
