Training invariances in deep nets and their consequences II: Linear networks

In this second post of the series, we go through the three-line proof of a certain training invariance in linear neural networks and its consequences in terms of alignment, weight-matrix rank, and margin.

The training invariance that I will cover has appeared in Arora et al. (2018), Du et al. (2019), and Ji and Telgarsky (2019). I will mainly follow the approach of Ji and Telgarsky (2019) since it is rather intuitive.

Recall from the last post in this series that we have set up a binary classification problem with a dataset \(\lbrace (x_i \in \mathcal{X}, y_i \in \lbrace -1, 1 \rbrace ) \rbrace_{i \in [n]}\) of \(n\) training examples. Fixing a nice loss function \(\ell\), we set up the empirical risk minimization (ERM) problem

\[ \min_{w} \mathcal{R}(w) := \frac{1}{n} \sum_{i=1}^{n} \ell( y_i f( x_i ; w)) \]

for some function class \(\lbrace f(\cdot ; w) : \mathbb{R}^d \to \mathbb{R} \rbrace\) parameterized by \(w\) in some Euclidean vector space.
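For readers who prefer code, here is a minimal numpy sketch of this objective; the logistic loss \(\ell(z) = \log(1 + e^{-z})\) is plugged in purely as an illustrative choice, and the predictor \(f\) is left abstract.

```python
import numpy as np

# Empirical risk R(w) = (1/n) * sum_i ell(y_i * f(x_i; w)),
# sketched with the logistic loss ell(z) = log(1 + exp(-z)).
# `f` is any predictor f(x; w); X has shape (n, d), y has entries in {-1, +1}.
def empirical_risk(f, w, X, y):
    margins = y * np.array([f(x, w) for x in X])
    return np.log1p(np.exp(-margins)).mean()
```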

To simplify the math, we use gradient flow (GF) to solve the optimization problem:

\[ \frac{ d w(t) }{ d t } = - \nabla_w \mathcal{ R }(w(t)). \]

For intuition, unfamiliar readers can think of GF as the limit of the more famous gradient descent algorithm as the step size goes to \(0\). With discrete steps, all results in this post still hold, albeit with extra slack terms.
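Concretely, gradient descent is the forward-Euler discretization of the GF equation above; here is a minimal sketch, with `grad_risk`, the step size, and the step count as illustrative placeholders.

```python
import numpy as np

# Gradient descent as a forward-Euler discretization of gradient flow:
#   w(t + eta) ≈ w(t) - eta * grad R(w(t)).
# `grad_risk` is a placeholder for the gradient of the empirical risk;
# as eta -> 0 (with num_steps * eta fixed) the iterates approach the GF path.
def gradient_flow_euler(w0, grad_risk, eta=1e-3, num_steps=10_000):
    w = np.asarray(w0, dtype=float).copy()
    for _ in range(num_steps):
        w = w - eta * grad_risk(w)
    return w
```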

Linear neural network: a simple (but not that simple) function class

The state of learning theory is such that a lot is known about linear models and classifiers, but few nonlinear models come with provable guarantees. This is more or less true for deep learning theory, where even advanced analyses of nonlinear models depend heavily on their linear counterparts (for example, the recently famous neural tangent kernel analysis). As such, much effort has been made to understand linear neural networks, which we now define. An \(L\)-layer linear network with \(n_i\) neurons in the \(i\)-th layer for all \(i \in [L+1]\) is parameterized by \(w = \lbrace W_k \in \mathbb{R}^{ n_{k+1} \times n_k } \rbrace_{k \in [L]}\) as

\[ f(x;w) = W_L W_{L-1} \ldots W_2 W_1 x. \]

Isn't this just a linear classifier? Yes and no. As a function class, linear neural networks are indeed the same as linear classifiers: they contain exactly the same functions. However, what we are interested in is the trajectory/behavior analysis of gradient-based algorithms, and as an optimization objective, ERM under a linear net is anything but similar to ERM under a linear classifier. Indeed, a linear neural net, despite its name, is not a linear function of its parameters \(w\) but a polynomial of degree \(L\). Solving ERM for a linear net under the squared loss, for example, is an instance of polynomial optimization, which is NP-hard in general; whereas solving ERM for a linear classifier under the squared loss is a version of the ordinary least squares (OLS) problem (had we written our loss as \(\ell(|f(x_i;w) - y_i|)\), it would be exactly OLS).
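To make the "same functions, different parameterization" point concrete, here is a small numpy sketch with illustrative widths and random weights: as a function of \(x\) the network collapses to a single matrix, but as a function of the weights it is a degree-\(L\) polynomial.

```python
import numpy as np

# A sketch of an L-layer linear network f(x; w) = W_L ... W_1 x with L = 3.
# The widths, seed and inputs below are illustrative; the scalar output
# matches the binary-classification setup.
def linear_net(x, weights):
    out = x
    for W in weights:            # weights = [W_1, ..., W_L], innermost layer first
        out = W @ out
    return out

rng = np.random.default_rng(0)
widths = [5, 4, 3, 1]            # n_1 = 5 inputs, two hidden layers, scalar output
weights = [rng.normal(size=(widths[k + 1], widths[k])) * 0.5 for k in range(3)]
x = rng.normal(size=widths[0])

# As a function of x, the network is just one matrix (here a 1 x 5 row) ...
collapsed = weights[2] @ weights[1] @ weights[0]
assert np.allclose(linear_net(x, weights), collapsed @ x)
# ... but as a function of the weights it is a degree-L polynomial, which is
# what makes the optimization trajectory different from plain linear ERM.
```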

Matrix-wise training invariance: a three-line proof for linear neural networks

A mentor once told me: “Never underestimate the power of simply writing down the derivatives”, so we can start with that: for all \(k \in [ L ]\), for all time \(t\) in the training process,

\begin{equation} \frac{d W_k}{dt} = - \nabla_{W_k} \mathcal{R}(w) = - \frac{1}{n} \sum_{i=1}^{n} y_i \ell'(y_i f(x_i; w)) \nabla_{W_k} f(x_i; w) = W_{k+1}^\top \ldots W_L^\top A W_1^\top \ldots W_{k-1}^\top \end{equation}

where \(A := -\nabla_{\left( W_L \ldots W_1 \right)} \mathcal{R}(w) = -\frac{1}{n} \sum_{i=1}^{n} y_i x_i^\top \ell'(y_i f(x_i; w))\). In the first step, we use the definition of GF. In the second step, we use the chain rule to differentiate through the loss function (and thus explicitly assume that the loss is differentiable, although the argument also works for nondifferentiable losses). Finally, in the last step, we use the explicit formula of \(f\) to differentiate \(W_k\) away.

Here we used lots of tricks from multivariate calculus in order to be precise and formal. Unfamiliar readers can imagine all the weight matrices to be scalars and \(f(x) = w_L \ldots w_1 x\). Then differentiating \(f\) with respect to \(w_k\) for some \(k\) just removes \(w_k\) from the product, and we therefore get the form of the last expression. When the weights are matrices, we have to be careful about the ordering of the factors, since matrix multiplication is not commutative.
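If you would rather check the calculus numerically than redo it by hand, here is a small sanity check of eqn (1) for a 3-layer net under the logistic loss; the sizes, seed, and random data are all illustrative.

```python
import numpy as np

# Sanity check of eqn (1) for a 3-layer linear net, assuming the logistic loss
# ell(z) = log(1 + exp(-z)), so ell'(z) = -1 / (1 + exp(z)).  We compare the
# closed form  -W_{k+1}^T ... W_L^T A W_1^T ... W_{k-1}^T  for grad_{W_k} R
# against a finite-difference estimate.  Sizes, seed and data are illustrative.
rng = np.random.default_rng(1)
widths = [5, 4, 3, 1]                                   # n_1, ..., n_{L+1} with L = 3
X = rng.normal(size=(20, widths[0]))
y = rng.choice([-1.0, 1.0], size=20)
Ws = [rng.normal(size=(widths[j + 1], widths[j])) * 0.5 for j in range(3)]

def risk(Ws, X, y):
    prod = Ws[2] @ Ws[1] @ Ws[0]                        # W_3 W_2 W_1
    return np.log1p(np.exp(-y * (X @ prod.ravel()))).mean()

def grad_eqn1(Ws, X, y, k):
    """Gradient of R with respect to Ws[k] (W_{k+1} in 1-indexed notation), via eqn (1)."""
    prod = Ws[2] @ Ws[1] @ Ws[0]
    margins = y * (X @ prod.ravel())
    ell_prime = -1.0 / (1.0 + np.exp(margins))
    A = -((y * ell_prime)[:, None] * X).mean(axis=0)[None, :]       # the matrix A, 1 x n_1
    suffix = [Ws[2] @ Ws[1], Ws[2], np.eye(1)][k]                   # W_L ... W_{k+1}
    prefix = [np.eye(widths[0]), Ws[0], Ws[1] @ Ws[0]][k]           # W_{k-1} ... W_1
    return -(suffix.T @ A @ prefix.T)                               # gradient of R w.r.t. Ws[k]

def grad_numeric(Ws, X, y, k, eps=1e-6):
    G = np.zeros_like(Ws[k])
    for idx in np.ndindex(*Ws[k].shape):
        Wp = [W.copy() for W in Ws]; Wp[k][idx] += eps
        Wm = [W.copy() for W in Ws]; Wm[k][idx] -= eps
        G[idx] = (risk(Wp, X, y) - risk(Wm, X, y)) / (2 * eps)
    return G

for k in range(3):
    assert np.allclose(grad_eqn1(Ws, X, y, k), grad_numeric(Ws, X, y, k), atol=1e-6)
```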

The main observation now is that the final right-hand-side expression contains all the weight matrices multiplied together, except for \(W_k^\top\). One may wonder what happens if we right-multiply \(W_k^\top\) back into the gradient expression. This is exactly the second line: for all time \(t\) in the training process,

\begin{equation} \frac{d W_k}{dt} W_k^\top = W_{k+1}^\top \ldots W_L^\top A W_1^\top \ldots W_{k-1}^\top W_k^\top = W_{k+1}^\top \frac{d W_{k+1}}{dt} \end{equation}

Note that to get the second equality, we just write down the first derivation (eqn (1)) for \(W_{k + 1}\) and notice the missing \(W_{k+1}^\top\) term that we can left-multiply back into the gradient expression.
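For intuition, here is the scalar version of this identity spelled out (all weights scalars, with \(a\) playing the role of \(A\)): each side is just \(a\) times the product of all the weights,

\[ \frac{d w_k}{dt}\, w_k = \Big( a \prod_{j \neq k} w_j \Big) w_k = a \prod_{j=1}^{L} w_j = \Big( a \prod_{j \neq k+1} w_j \Big) w_{k+1} = w_{k+1}\, \frac{d w_{k+1}}{dt}, \]

which is visibly symmetric between \(k\) and \(k+1\).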

Now I will admit that the third and last step requires more justification than I am giving here if one is not familiar with multivariate calculus, but as usual, if this stumps you, imagine the weight matrices are just scalars and take a leap of faith. Add equation (2) to its own transpose to obtain \( \frac{d}{dt} \left( W_k W_k^\top \right) = \frac{d}{dt} \left( W_{k+1}^\top W_{k+1} \right) \), then integrate both sides (via, for example, the Fundamental Theorem of Calculus applied to each matrix entry) from \(0\) to some time \(t > 0\) and rearrange to get

\begin{equation} W_k(t) W_k^\top(t) - W_{k+1}^\top(t) W_{k+1}(t) = W_k(0) W_k^\top(0) - W_{k+1}^\top(0) W_{k+1}(0) \end{equation}

and that is the first training invariance: the difference between the self-adjoint products \(W_k W_k^\top\) and \(W_{k+1}^\top W_{k+1}\) of consecutive layers is set at initialization and remains invariant throughout training.

Notice that we proved this invariance without any assumption on the initialization of GF: in fact the only extra assumption we made was that the loss is differentiable, which is not a problem for lots of popular loss functions like the logistic loss or the square loss.
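Here is a matching numerical sketch with the logistic loss and illustrative sizes: running small-step gradient descent (a crude discretization of GF) on a random 3-layer linear net, the matrices \(W_k W_k^\top - W_{k+1}^\top W_{k+1}\) barely move, and the tiny residual drift is an artifact of the finite step size that vanishes in the GF limit.

```python
import numpy as np

# Numerical sketch of eqn (3), assuming the logistic loss and illustrative sizes:
# run small-step gradient descent (a crude discretization of GF) on a random
# 3-layer linear net and watch W_k W_k^T - W_{k+1}^T W_{k+1} stay (almost) put.
rng = np.random.default_rng(0)
n, widths = 50, [6, 4, 3, 1]                              # n_1, ..., n_4 with L = 3
X = rng.normal(size=(n, widths[0]))
y = rng.choice([-1.0, 1.0], size=n)
Ws = [rng.normal(size=(widths[j + 1], widths[j])) * 0.3 for j in range(3)]

def gradients(Ws, X, y):
    """grad_{W_k} R for every layer, assembled via eqn (1)."""
    prod = Ws[2] @ Ws[1] @ Ws[0]                          # W_3 W_2 W_1, shape 1 x n_1
    margins = y * (X @ prod.ravel())
    ell_prime = -1.0 / (1.0 + np.exp(margins))            # ell'(z) for the logistic loss
    minus_A = ((y * ell_prime)[:, None] * X).mean(axis=0)[None, :]   # -A, shape 1 x n_1
    prefixes = [np.eye(widths[0]), Ws[0], Ws[1] @ Ws[0]]  # W_{k-1} ... W_1
    suffixes = [Ws[2] @ Ws[1], Ws[2], np.eye(widths[-1])] # W_L ... W_{k+1}
    return [suffixes[k].T @ minus_A @ prefixes[k].T for k in range(3)]

def invariants(Ws):
    return [Ws[k] @ Ws[k].T - Ws[k + 1].T @ Ws[k + 1] for k in range(2)]

at_init = invariants(Ws)
eta = 1e-3
for _ in range(5_000):
    Ws = [W - eta * G for W, G in zip(Ws, gradients(Ws, X, y))]

for D0, Dt in zip(at_init, invariants(Ws)):
    print(np.max(np.abs(Dt - D0)))    # tiny drift, an artifact of the finite step size
```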

Why is this important? The rest of the post will give some insights, but here is a hint: the self-adjoint product \(W W^\top\) of a matrix determines its singular values, which in turn control the rank (and rank deficiency) of the matrix.

Consequences of the matrix-wise invariance in linear nets: low rank and max-margin