Discovering faster matrix multiplication algorithms with reinforcement learning

TensorGame

TensorGame is played as follows. The start position $\mathscr{S}_0$ of the game corresponds to the tensor $\mathscr{T}$ representing the bilinear operation of interest, expressed in some basis. In each step $t$ of the game, the player writes down three vectors $(\mathbf{u}^{(t)}, \mathbf{v}^{(t)}, \mathbf{w}^{(t)})$, which specify the rank-1 tensor $\mathbf{u}^{(t)} \otimes \mathbf{v}^{(t)} \otimes \mathbf{w}^{(t)}$, and the state of the game is updated by subtracting the newly written down factor:

$$\mathscr{S}_t \leftarrow \mathscr{S}_{t-1} - \mathbf{u}^{(t)} \otimes \mathbf{v}^{(t)} \otimes \mathbf{w}^{(t)}. \qquad (2)$$

The game ends when the state reaches the zero tensor, $\mathscr{S}_R = \mathbf{0}$. This means that the factors written down throughout the game form a factorization of the start tensor $\mathscr{S}_0$, that is,

$$\mathscr{S}_0 = \sum_{t=1}^{R} \mathbf{u}^{(t)} \otimes \mathbf{v}^{(t)} \otimes \mathbf{w}^{(t)}. \qquad (3)$$
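A toy version of TensorGame can be sketched in plain Python. The index convention for $\mathscr{T}_n$ below (entry $(i, j)$ of the first matrix flattened row-major, and likewise for the other two modes) is an assumption made for illustration, and the actions are Strassen's well-known rank-7 factorization rather than moves chosen by AlphaTensor. Playing them drives $\mathscr{T}_2$ to the zero tensor in $R = 7$ moves.

```python
from itertools import product

def matmul_tensor(n):
    """Matrix multiplication tensor T_n of shape n^2 x n^2 x n^2.

    Assumed convention: T[a][b][c] = 1 iff entry a of A (row-major) times
    entry b of B contributes to entry c of C = A @ B.
    """
    s = n * n
    t = [[[0] * s for _ in range(s)] for _ in range(s)]
    for i, j, k in product(range(n), repeat=3):
        t[i * n + j][j * n + k][i * n + k] = 1
    return t

def play_move(state, u, v, w):
    """Return S_t = S_{t-1} - u ⊗ v ⊗ w as a new tensor (eq. 2)."""
    s = len(state)
    return [[[state[a][b][c] - u[a] * v[b] * w[c] for c in range(s)]
             for b in range(s)] for a in range(s)]

def is_zero(state):
    return all(x == 0 for plane in state for row in plane for x in row)

# Strassen's rank-7 factorization of T_2 (indices 0..3 = entries 11, 12, 21, 22).
STRASSEN = [
    ([1, 0, 0, 1], [1, 0, 0, 1], [1, 0, 0, 1]),
    ([0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 1, -1]),
    ([1, 0, 0, 0], [0, 1, 0, -1], [0, 1, 0, 1]),
    ([0, 0, 0, 1], [-1, 0, 1, 0], [1, 0, 1, 0]),
    ([1, 1, 0, 0], [0, 0, 0, 1], [-1, 1, 0, 0]),
    ([-1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 0, 1]),
    ([0, 1, 0, -1], [0, 0, 1, 1], [1, 0, 0, 0]),
]

def play_game(start, actions):
    """Apply actions in order; return R if the state hits zero, else None."""
    state = start
    for t, (u, v, w) in enumerate(actions, start=1):
        state = play_move(state, u, v, w)
        if is_zero(state):
            return t
    return None

rank = play_game(matmul_tensor(2), STRASSEN)
print(rank)  # 7
```

Because the rank of $\mathscr{T}_2$ is 7, no proper prefix of these actions can end the game early.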

Hence, exhibiting a rank-$R$ decomposition of the matrix multiplication tensor $\mathscr{T}_n$ expressed in any basis proves that the product of two n × n matrices can be computed using $R$ scalar multiplications. Moreover, it is straightforward to convert such a rank-$R$ decomposition into a rank-$R$ decomposition in the canonical basis, thus yielding a practical algorithm of the form shown in Algorithm 1. We leverage this observation by expressing the matrix multiplication tensor $\mathscr{T}_n$ in a large number of randomly generated bases (typically 100,000) in addition to the canonical basis, and letting AlphaTensor play games in all bases in parallel.
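The basis change and the conversion back to the canonical basis can be sketched in plain Python. The sketch assumes the convention $\mathscr{T}'_{ijk} = \sum_{p,q,r} A_{ip} B_{jq} C_{kr} \mathscr{T}_{pqr}$ and the same row-major layout for $\mathscr{T}_n$ as in the TensorGame example. Under this convention each rank-1 term $\mathbf{u} \otimes \mathbf{v} \otimes \mathbf{w}$ maps to $(\mathbf{A}\mathbf{u}) \otimes (\mathbf{B}\mathbf{v}) \otimes (\mathbf{C}\mathbf{w})$, so a decomposition found in the new basis is mapped back by applying $\mathbf{A}^{-1}$, $\mathbf{B}^{-1}$, $\mathbf{C}^{-1}$ to its factors, without changing $R$. The matrix `A` is a hypothetical example, and the schoolbook rank-8 decomposition stands in for one found by the agent.

```python
from itertools import product

def matmul_tensor(n):
    s = n * n
    t = [[[0] * s for _ in range(s)] for _ in range(s)]
    for i, j, k in product(range(n), repeat=3):
        t[i * n + j][j * n + k][i * n + k] = 1
    return t

def naive_factors(n):
    """Rank-n^3 decomposition of T_n given by the schoolbook algorithm."""
    s = n * n
    def e(idx):
        vec = [0] * s
        vec[idx] = 1
        return vec
    return [(e(i * n + j), e(j * n + k), e(i * n + k))
            for i, j, k in product(range(n), repeat=3)]

def matvec(m, x):
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in m]

def change_basis(t, a, b, c):
    """T'[i][j][k] = sum_{p,q,r} A[i][p] B[j][q] C[k][r] T[p][q][r]."""
    s = len(t)
    return [[[sum(a[i][p] * b[j][q] * c[k][r] * t[p][q][r]
                  for p, q, r in product(range(s), repeat=3))
              for k in range(s)] for j in range(s)] for i in range(s)]

def reconstruct(factors, s):
    """Sum of the rank-1 tensors u ⊗ v ⊗ w over all factors."""
    t = [[[0] * s for _ in range(s)] for _ in range(s)]
    for u, v, w in factors:
        for p, q, r in product(range(s), repeat=3):
            t[p][q][r] += u[p] * v[q] * w[r]
    return t

# Hypothetical unimodular basis change (det = 1) and its integer inverse.
A = [[1, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
A_INV = [[1, -1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]

T = matmul_tensor(2)
T_new = change_basis(T, A, A, A)  # T_2 expressed in another basis, with A = B = C
factors_new = [tuple(matvec(A, x) for x in f) for f in naive_factors(2)]
assert reconstruct(factors_new, 4) == T_new  # rank-8 decomposition in the new basis

# Convert back to the canonical basis by applying A^{-1} to every factor.
factors_back = [tuple(matvec(A_INV, x) for x in f) for f in factors_new]
assert reconstruct(factors_back, 4) == T
```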

This approach has three appealing properties: (1) it provides a natural exploration mechanism as playing games in different bases automatically injects diversity into the games played by the agent; (2) it exploits properties of the problem as the agent need not succeed in all bases—it is sufficient to find a low-rank decomposition in any of the bases; (3) it enlarges coverage of the algorithm space because a decomposition with entries in a finite set F = {−2, −1, 0, 1, 2} found in a different basis need not have entries in the same set when converted back into the canonical basis.

In full generality, a basis change for a 3D tensor of size S × S × S is specified by three invertible S × S matrices A, B and C. However, in our procedure, we sample bases at random and impose two restrictions: (1) A = B = C, as this performed better in early experiments, and (2) unimodularity ($\det \mathbf{A} \in \{-1, +1\}$), which ensures that after converting an integral factorization into the canonical basis it still contains integer entries only (this is for representational convenience and numerical stability of the resulting algorithm). See Supplementary Information for the exact algorithm.
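One simple way to sample such matrices is to compose random elementary integer row operations and row sign flips. This sampler is a hypothetical stand-in, not the exact algorithm from the Supplementary Information. Tracking the inverse alongside the matrix makes the relevant property visible: both the matrix and its inverse are integral, which for an integer matrix is equivalent to $\det \mathbf{A} \in \{-1, +1\}$, so factors mapped back with $\mathbf{A}^{-1}$ keep integer entries.

```python
import random

def matmul(x, y):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]

def random_unimodular(s, num_ops=20, rng=None):
    """Sample an integer S x S matrix with det in {-1, +1}, plus its inverse.

    Hypothetical sampler: start from the identity and apply random operations
    E on the left (m <- E @ m), while updating m_inv <- m_inv @ E^{-1}, so the
    invariant m_inv @ m = I holds throughout. Requires s >= 2.
    """
    rng = rng or random.Random()
    m = [[int(i == j) for j in range(s)] for i in range(s)]
    m_inv = [row[:] for row in m]
    for _ in range(num_ops):
        if rng.random() < 0.2:
            # Sign flip of row i: E = E^{-1} = diag(1, ..., -1, ..., 1).
            i = rng.randrange(s)
            m[i] = [-x for x in m[i]]
            for row in m_inv:  # right-multiplying by E negates column i
                row[i] = -row[i]
        else:
            # Row operation "row i += c * row j": E = I + c e_i e_j^T.
            i, j = rng.sample(range(s), 2)
            c = rng.choice([-1, 1])
            m[i] = [x + c * y for x, y in zip(m[i], m[j])]
            for row in m_inv:  # right-multiplying by E^{-1}: column j -= c * column i
                row[j] -= c * row[i]
    return m, m_inv

rng = random.Random(0)
A, A_inv = random_unimodular(4, rng=rng)
assert matmul(A, A_inv) == [[int(i == j) for j in range(4)] for i in range(4)]
```

With the restriction A = B = C, the same sampled matrix would be used for all three modes of the tensor.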

Signed permutations

In addition to playing (and training on) games in different bases, we also utilize a data augmentation mechanism whenever the neural network is queried in a new MCTS node. At acting time, when the network is queried, we transform the input tensor by applying a change of basis—where the change of basis matrix is set to a random signed permutation. We then query the network on this transformed input tensor, and finally invert the transformation in the network’s policy predictions. Although this data augmentation procedure can be applied with any generic change of basis matrix (that is, it is not restricted to signed permutation matrices), we use signed permutations mainly for computational efficiency. At training time, whenever the neural network is trained on an (input, policy targets, value target) triplet (Fig. 2), we apply a randomly chosen signed permutation to both the input and the policy targets, and train the network on this transformed triplet. In practice, we sample 100 signed permutations at the beginning of an experiment, and use them thereafter.
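The acting-time augmentation can be sketched as follows. A signed permutation is stored as a permutation plus a sign vector, so applying it to a tensor, or inverting it on a predicted action, reduces to indexing and sign flips with no matrix products; this is one concrete reading of the computational-efficiency argument above. The `network` argument of `query_with_augmentation` is a hypothetical stand-in for the policy network, assumed to return a list of (u, v, w) actions.

```python
import random

def outer3(u, v, w):
    """Rank-1 tensor u ⊗ v ⊗ w as nested lists."""
    return [[[a * b * c for c in w] for b in v] for a in u]

def random_signed_permutation(s, rng):
    """Signed permutation P stored as (perm, signs), with P[i][perm[i]] = signs[i]."""
    perm = list(range(s))
    rng.shuffle(perm)
    signs = [rng.choice([-1, 1]) for _ in range(s)]
    return perm, signs

def apply_to_vector(p, x):
    """P @ x without forming P: (P x)[i] = signs[i] * x[perm[i]]."""
    perm, signs = p
    return [signs[i] * x[perm[i]] for i in range(len(x))]

def invert_on_vector(p, y):
    """P^{-1} @ y (= P^T @ y), undoing apply_to_vector."""
    perm, signs = p
    x = [0] * len(y)
    for i in range(len(y)):
        x[perm[i]] = signs[i] * y[i]
    return x

def apply_to_tensor(p, t):
    """Change of basis with A = B = C = P."""
    perm, signs = p
    s = len(t)
    return [[[signs[i] * signs[j] * signs[k] * t[perm[i]][perm[j]][perm[k]]
              for k in range(s)] for j in range(s)] for i in range(s)]

def query_with_augmentation(network, state, pool, rng):
    """Query a (hypothetical) network on a transformed state, then map actions back."""
    p = rng.choice(pool)
    predicted = network(apply_to_tensor(p, state))  # actions in the transformed basis
    return [tuple(invert_on_vector(p, vec) for vec in action) for action in predicted]

rng = random.Random(0)
POOL = [random_signed_permutation(4, rng) for _ in range(100)]  # sampled once, reused
```

The same helpers cover the training-time variant: `apply_to_tensor` on the input and `apply_to_vector` on each factor of the policy targets.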

Action canonicalization

For any $\lambda_1, \lambda_2, \lambda_3 \in \{-1, +1\}$ such that $\lambda_1\lambda_2\lambda_3 = 1$, the actions $(\lambda_1\mathbf{u}, \lambda_2\mathbf{v}, \lambda_3\mathbf{w})$ and $(\mathbf{u}, \mathbf{v}, \mathbf{w})$ are equivalent because they lead to the same rank-one tensor $(\lambda_1\mathbf{u}) \otimes (\lambda_2\mathbf{v}) \otimes (\lambda_3\mathbf{w}) = \mathbf{u} \otimes \mathbf{v} \otimes \mathbf{w}$. To prevent the network from wasting capacity on predicting multiple equivalent actions, during training we always present targets $(\mathbf{u}, \mathbf{v}, \mathbf{w})$ for the policy head in a canonical form, defined as having the first non-zero element of $\mathbf{u}$ and the first non-zero element of $\mathbf{v}$ strictly positive. This is well defined because neither $\mathbf{u}$ nor $\mathbf{v}$ can be all zeros (if they are to be part of a minimal rank decomposition), and for any $(\mathbf{u}, \mathbf{v}, \mathbf{w})$ there are unique $\lambda_1, \lambda_2, \lambda_3 \in \{-1, +1\}$ (with $\lambda_1\lambda_2\lambda_3 = 1$) that transform it into canonical form. In case the network predicts multiple equivalent actions anyway, we merge them (summing their empirical policy probabilities) before inserting them into the MCTS tree.
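A minimal sketch of the canonical form and of the merge step, assuming actions are represented as tuples of integer tuples (a representation chosen here for illustration). The signs $\lambda_1$ and $\lambda_2$ are fixed by the first non-zero entries of $\mathbf{u}$ and $\mathbf{v}$, and $\lambda_3 = \lambda_1\lambda_2$ is then forced by the constraint.

```python
from collections import defaultdict

def first_nonzero_sign(x):
    for value in x:
        if value != 0:
            return 1 if value > 0 else -1
    raise ValueError("factor must not be all zeros")

def canonicalize(u, v, w):
    """Return the unique equivalent action whose u and v start with a positive entry.

    lambda3 = lambda1 * lambda2, so lambda1 * lambda2 * lambda3 = 1 and the
    rank-one tensor u ⊗ v ⊗ w is unchanged.
    """
    l1 = first_nonzero_sign(u)
    l2 = first_nonzero_sign(v)
    l3 = l1 * l2
    return (tuple(l1 * x for x in u), tuple(l2 * x for x in v), tuple(l3 * x for x in w))

def merge_equivalent(actions, probs):
    """Merge equivalent sampled actions, summing their empirical probabilities."""
    merged = defaultdict(float)
    for action, p in zip(actions, probs):
        merged[canonicalize(*action)] += p
    return dict(merged)
```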

Training regime

We train AlphaTensor on a TPU v3, with a total batch size of 2,048. We use 64 TPU cores, and train for 600,000 iterations. On the actor side, the games are played on standalone TPU v4, and we use 1,600 actors. In practice, the procedure takes a week to converge.
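For scale, and assuming (the text does not say so) that the batch is split evenly across the 64 TPU cores, these numbers work out to:

$$\frac{2{,}048}{64} = 32 \ \text{examples per core per iteration}, \qquad 600{,}000 \times 2{,}048 = 1{,}228{,}800{,}000 \approx 1.2 \times 10^{9} \ \text{examples processed in total, counting repeats}.$$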

Neural network

The architecture is composed of a torso, followed by a policy head that predicts a distribution over actions, and a value head that predicts a distribution of the returns from the current state (see Extended Data Fig. 3).

Input

The input to the network contains all the relevant information of the current state and is composed of a list of tensors and a list of scalars. The most important piece of information is the current 3D tensor $\mathscr{S}_t$ of size S × S × S. (For simplicity, in the description here we assume that all three dimensions of the tensor are equal in size. The generalization to different sizes is straightforward.) In addition, the model is given access to the last h actions (h being a hyperparameter usually set to 7), represented as h rank-1 tensors that are concatenated to the input. The list of scalars includes the time index $t$ of the current action (where $0 \le t < R_{\text{limit}}$).
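A sketch of how the tensor part of the input could be assembled. Two details are assumptions not specified in the text and are labeled as such in the code: the most recent action comes first, and missing history at the start of a game is zero-padded.

```python
def outer3(u, v, w):
    """Rank-1 tensor u ⊗ v ⊗ w as nested lists."""
    return [[[a * b * c for c in w] for b in v] for a in u]

def build_input(state, past_actions, t, h=7):
    """Stack S_t with the last h actions (as rank-1 tensors): h + 1 tensors in total.

    Assumptions (hypothetical): most recent action first; zero-padding when
    fewer than h actions have been played.
    """
    s = len(state)
    zeros = [[[0] * s for _ in range(s)] for _ in range(s)]  # shared, never mutated
    recent = list(reversed(past_actions[-h:])) if h > 0 else []
    history = [outer3(*a) for a in recent]
    history += [zeros] * (h - len(history))
    tensors = [state] + history
    scalars = [t]  # time index of the current action
    return tensors, scalars
```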

Torso

The torso of the network is in charge of mapping both scalars and tensors from the input to a representation that is useful to both the policy and value heads. Its architecture is based on a modification of transformers23, and its main signature is that it operates over three S × S grids projected from the S × S × S input tensors. Each grid represents two of the three modes of the tensor. Defining the modes of the tensor as $\mathcal{U}, \mathcal{V}, \mathcal{W}$, the rows and columns of the first grid are associated with $\mathcal{U}$ and $\mathcal{V}$, respectively, the rows and columns of the second grid with $\mathcal{W}$ and $\mathcal{U}$, and the rows and columns of the third grid with $\mathcal{V}$ and $\mathcal{W}$. Each element of each grid is a feature vector, and its initial value is given by the elements of the input tensors along the grid's missing mode. These feature vectors are enriched by concatenating an S × S × 1 linear projection from the scalars. This is followed by a linear layer projecting these feature vectors into a 512-dimensional space.
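The projection from the stacked input tensors to the three grids, before any learned layer, is pure indexing and can be sketched as below. The sketch assumes the h + 1 input tensors arrive as a list of nested lists indexed `[u][v][w]`; it omits the concatenated scalar projection and the 512-dimensional linear layer, which are learned.

```python
def project_to_grids(tensors):
    """Turn stacked S x S x S tensors into three S x S grids of feature vectors.

    Grid 1: rows U, columns V; features run along the missing mode W.
    Grid 2: rows W, columns U; features run along V.
    Grid 3: rows V, columns W; features run along U.
    Each feature vector concatenates the fibres of all input tensors,
    so its length is (h + 1) * S.
    """
    s = len(tensors[0])
    r = range(s)
    grid_uv = [[[x[i][j][k] for x in tensors for k in r] for j in r] for i in r]
    grid_wu = [[[x[i][j][k] for x in tensors for j in r] for i in r] for k in r]
    grid_vw = [[[x[i][j][k] for x in tensors for i in r] for k in r] for j in r]
    return grid_uv, grid_wu, grid_vw
```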

The rest of the torso is a sequence of attention-based blocks whose objective is to propagate information between the three grids. Each of those blocks has three stages, one for every pair of grids. In each stage, the grids involved are concatenated, and axial attention24 is performed over the columns; each stage therefore performs S self-attention operations in parallel, each over 2S elements. The representation sent to the policy head corresponds to the $3S^2$ 512-dimensional feature vectors produced by the last layer of the torso. A detailed description of the structure of the torso is specified in Extended Data Fig. 4 (top) and Appendix A.1.1 in Supplementary Information.
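The sketch below illustrates only the data flow of one stage: two S × S grids are stacked into a 2S × S grid, and self-attention runs independently down each of the S columns. It is a simplification under stated assumptions: the attention has no learned query/key/value projections, no multiple heads, no residual connections and no normalization, and the grids are assumed to be already oriented so that they can be stacked along their rows.

```python
import math

def self_attention(seq):
    """Parameter-free single-head dot-product self-attention (q = k = v = input)."""
    d = len(seq[0])
    out = []
    for q in seq:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in seq]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(x - m) for x in scores]
        total = sum(weights)
        out.append([sum(wt * v[c] for wt, v in zip(weights, seq)) / total for c in range(d)])
    return out

def attention_stage(grid_a, grid_b):
    """Stack two S x S grids into 2S x S and attend down each column.

    This runs S independent attention operations, each over 2S elements.
    """
    s = len(grid_a)
    stacked = grid_a + grid_b  # 2S rows, S columns
    new_cols = [self_attention([stacked[r][c] for r in range(2 * s)]) for c in range(s)]
    merged = [[new_cols[c][r] for c in range(s)] for r in range(2 * s)]
    return merged[:s], merged[s:]
```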

The policy head uses the transformer architecture23 to model an autoregressive policy. Factors are decomposed into k tokens of dimensionality d such that k × d = 3S. The transformer conditions on the tokens already generated and cross-attends to the features produced by the torso. At training time, we use teacher-forcing, that is, the ground truth actions are decomposed into tokens and taken as inputs into the causal transformer in such a way that the prediction of a token depends only on the previous tokens. At inference time, K actions are sampled from the head. The feature representation before the last linear layer of the initial step (that is, the only step that is not conditioned on the ground truth) is used as an input to the value head, described below. Details of the architecture are presented in Extended Data Fig. 4 (centre) and Appendix A.1.2 in Supplementary Information.
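The tokenization, teacher forcing and sampling steps can be sketched as follows. The token encoding (each token packs d consecutive entries from F = {−2, −1, 0, 1, 2} into one of 5^d ids), the `START` token and the `next_token_probs` callable standing in for the transformer are all hypothetical choices made for illustration.

```python
import random

F = (-2, -1, 0, 1, 2)  # entry set from the text; a token of d entries has len(F)**d ids
START = -1             # hypothetical start-of-sequence token

def encode_token(entries):
    """Map d entries from F to one integer id (base-len(F) digits)."""
    idx = 0
    for e in entries:
        idx = idx * len(F) + F.index(e)
    return idx

def decode_token(idx, d):
    entries = []
    for _ in range(d):
        idx, r = divmod(idx, len(F))
        entries.append(F[r])
    return entries[::-1]

def tokenize(action, d):
    """Split the concatenated factor (u, v, w) of length 3S into k = 3S / d tokens."""
    flat = [x for vec in action for x in vec]
    assert len(flat) % d == 0
    return [encode_token(flat[i:i + d]) for i in range(0, len(flat), d)]

def detokenize(tokens, d, s):
    flat = [x for t in tokens for x in decode_token(t, d)]
    return flat[:s], flat[s:2 * s], flat[2 * s:]

def teacher_forcing_inputs(tokens):
    """Decoder inputs: ground-truth tokens shifted right, so token i sees only tokens < i."""
    return [START] + tokens[:-1]

def sample_actions(next_token_probs, d, s, num_samples, rng):
    """Draw num_samples (K) actions token by token from a hypothetical policy head."""
    k = 3 * s // d
    actions = []
    for _ in range(num_samples):
        prefix = []
        for _ in range(k):
            probs = next_token_probs(prefix)  # distribution over len(F)**d ids
            prefix.append(rng.choices(range(len(probs)), weights=probs)[0])
        actions.append(detokenize(prefix, d, s))
    return actions
```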