
Self-consistent dynamical field theory of kernel evolution in wide neural networks*

Blake Bordelon and Cengiz Pehlevan

Published 15 November 2023 • © 2023 IOP Publishing Ltd and SISSA Medialab srl

Citation: Blake Bordelon and Cengiz Pehlevan, J. Stat. Mech. (2023) 114009 (Machine Learning 2023). DOI: 10.1088/1742-5468/ad01b0

Abstract

We analyze feature learning in infinite-width neural networks trained with gradient flow through a self-consistent dynamical field theory. We construct a collection of deterministic dynamical order parameters which are inner-product kernels for hidden unit activations and gradients in each layer at pairs of time points, providing a reduced description of network activity through training. These kernel order parameters collectively define the hidden layer activation distribution, the evolution of the neural tangent kernel (NTK), and consequently, output predictions. We show that the field theory derivation recovers the recursive stochastic process of infinite-width feature learning networks obtained by Yang and Hu with tensor programs. For deep linear networks, these kernels satisfy a set of algebraic matrix equations. For nonlinear networks, we provide an alternating sampling procedure to self-consistently solve for the kernel order parameters. We provide comparisons of the self-consistent solution to various approximation schemes including the static NTK approximation, gradient independence assumption, and leading order perturbation theory, showing that each of these approximations can break down in regimes where general self-consistent solutions still provide an accurate description. Lastly, we provide experiments in more realistic settings which demonstrate that the loss and kernel dynamics of convolutional neural networks at fixed feature learning strength are preserved across different widths on an image classification task.


1. Introduction

Deep learning has emerged as a successful paradigm for solving challenging machine learning and computational problems across a variety of domains [1, 2]. However, theoretical understanding of the training and generalization of modern deep learning methods lags behind current practice. Ideally, a theory of deep learning would be analytically tractable, efficiently computable, capable of predicting network performance and the internal features that the network learns, and interpretable through a reduced description, ideally in terms of initialization-independent quantities.

Several recent theoretical advances have fruitfully considered the idealization of wide neural networks, where the number of hidden units in each layer is taken to be large. Under certain parameterizations, Bayesian neural networks and gradient descent (GD) trained networks converge to Gaussian processes (NNGPs) [3–5] and neural tangent kernel (NTK) machines [6–8] in their respective infinite-width limits. These limits provide both analytic tractability and detailed training and generalization analyses [9–16]. However, in this limit, with these parameterizations, data representations are fixed and do not adapt to data; this is termed the lazy regime of NN training, to contrast it with the rich regime where NNs significantly alter their internal features while fitting the data [17, 18]. The fact that the representation of data is fixed renders these kernel-based theories incapable of explaining feature learning, an ingredient which is crucial to the success of deep learning in practice [19, 20]. Thus, alternative theories capable of modeling feature learning dynamics are needed.

Recently developed alternative parameterizations such as the mean field [21] and the $\mu P$ [22] parameterizations allow feature learning in infinite-width NNs trained with GD. Using the tensor programs (TPs) framework, Yang and Hu identified a stochastic process that describes the evolution of preactivation features in infinite-width $\mu P$ NNs [22]. In this work, we study an equivalent parameterization to $\mu P$ with self-consistent dynamical mean field theory (DMFT) and recover the stochastic process description of infinite NNs using this alternative technique. In the same large width scaling, we include a scalar parameter γ0 that allows smooth interpolation between lazy and rich behavior [17]. We provide a new computational procedure to sample this stochastic process and demonstrate its predictive power for wide NNs.

Our novel contributions in this paper are the following:

  • (i)  
    We develop a path integral formulation of gradient flow dynamics in infinite-width networks in the feature learning regime. Our parameterization includes a scalar parameter γ0 to allow interpolation between rich and lazy regimes and comparison to perturbative methods.
  • (ii)  
    Using a stationary action argument, we identify a set of saddle point equations that the kernels satisfy at infinite-width, relating the stochastic processes that define hidden activation evolution to the kernels and vice versa. We show that, at $\gamma_0 = 1$, our saddle point equations recover, via an alternative method, the same stochastic process obtained previously with TPs [22].
  • (iii)  
    We develop a polynomial-time numerical procedure to solve the saddle point equations for deep networks. In numerical experiments, we demonstrate that solutions to these self-consistency equations are predictive of network training at a variety of feature learning strengths, widths and depths. We provide comparisons of our theory to various approximate methods, such as perturbation theory.

Code to reproduce our experiments can be found on our Github.

1.1. Related works

A natural extension to the lazy NTK/NNGP limit that allows the study of feature learning is to calculate finite width corrections to the infinite-width limit. Finite width corrections to Bayesian inference in wide networks have been obtained with various perturbative [23–29] and self-consistent techniques [30–33]. In the GD based setting, leading order corrections to the NTK dynamics have been analyzed to study finite width effects [27, 34–36]. These methods give approximate corrections which are accurate provided the strength of feature learning is small. In very rich feature learning regimes, however, the leading order corrections can give incorrect predictions [37, 38].

Another approach to studying feature learning is to alter NN parameterization in gradient-based learning to allow significant feature evolution even at infinite-width, the mean field limit [21, 39]. Works on mean field NNs have yielded formal loss convergence results [40, 41] and shown equivalences of gradient flow dynamics to a partial differential equation (PDE) [42–44].

Our results are most closely related to a set of recent works which studied infinite-width NNs trained with GD using the TPs framework [22]. We show that our discrete time field theory at unit feature learning strength $\gamma_0 = 1$ recovers the stochastic process which was derived from TP. The stochastic process derived from TP has provided insights into practical issues in NN training such as hyper-parameter search [45]. Computing the exact infinite-width limit of GD has exponential time requirements [22], which we show can be circumvented with an alternating sampling procedure. A projected variant of GD training has provided an infinite-width theory that could be scaled to realistic datasets like CIFAR-10 [46]. Inspired by Chizat and Bach's work on mechanisms of lazy and rich training [17], our theory interpolates between lazy and rich behavior in the mean field limit for varying γ0 and allows comparison of DMFT to perturbative analysis near small γ0. Further, our derivation of a DMFT action allows the possibility of pursuing finite width effects.

Our theory is inspired by self-consistent DMFT from statistical physics [47–53]. This framework has been utilized in the theory of random recurrent networks [54–59], tensor PCA [60, 61], phase retrieval [62], and high-dimensional linear classifiers [63–66], but has yet to be developed for deep feature learning. By developing a self-consistent DMFT of deep NNs, we gain insight into how features evolve in the rich regime of network training, while retaining many pleasant analytic properties of the infinite-width limit.

2. Problem setup and definitions

Our theory applies to infinite-width networks, both fully-connected and convolutional. For notational ease we will relegate convolutional results to later sections. For input $\boldsymbol{x}_{\mu} \in \mathbb{R}^D$, we define the hidden pre-activation vectors $\boldsymbol{h}^{\ell} \in \mathbb{R}^{N}$ for layers $\ell \in \{1,\ldots ,L\}$ as

Equation (1): $\boldsymbol{h}^{1}_{\mu} = \frac{1}{\sqrt{D}}\,\boldsymbol{W}^{0}\boldsymbol{x}_{\mu} \ , \qquad \boldsymbol{h}^{\ell+1}_{\mu} = \frac{1}{\sqrt{N}}\,\boldsymbol{W}^{\ell}\phi(\boldsymbol{h}^{\ell}_{\mu}) \ , \qquad f_{\mu} = \frac{1}{\gamma\sqrt{N}}\,\boldsymbol{w}^{L}\cdot\phi(\boldsymbol{h}^{L}_{\mu})$

where $\boldsymbol{\theta} = \text{Vec}\{\boldsymbol{W}^0,\ldots ,\boldsymbol{w}^L\}$ are the trainable parameters of the network and φ is a twice differentiable activation function. Inspired by previous works on the mechanisms of lazy gradient based training, the parameter γ will control the laziness or richness of the training dynamics [17, 18, 22, 42]. Each of the trainable parameters is initialized as a Gaussian random variable with unit variance $W^{\ell}_{ij} \sim \mathcal{N}(0,1)$. They evolve under gradient flow $\frac{\mathrm{d}}{\mathrm{d}t} \boldsymbol{\theta} = - \gamma^2 \nabla_{\boldsymbol{\theta}} \mathcal{L}$. The choice of learning rate γ2 causes $\frac{\mathrm{d}}{\mathrm{d}t} \mathcal{L}|_{t = 0}$ to be independent of γ. To characterize the evolution of weights, we introduce back-propagation variables

Equation (2): $\boldsymbol{g}^{\ell}_{\mu} = \gamma\sqrt{N}\,\frac{\partial f_{\mu}}{\partial \boldsymbol{h}^{\ell}_{\mu}} = \dot\phi(\boldsymbol{h}^{\ell}_{\mu})\odot \boldsymbol{z}^{\ell}_{\mu} \ , \qquad \boldsymbol{z}^{\ell}_{\mu} = \frac{1}{\sqrt{N}}\,\boldsymbol{W}^{\ell\top}\boldsymbol{g}^{\ell+1}_{\mu}$

where $\boldsymbol{z}^{\ell}_{\mu}$ is the pre-gradient signal.
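To make the parameterization concrete, the following is a minimal numpy sketch of the forward pass and back-propagation variables under the scalings written above. The function names (forward, backprop_fields) and the readout factor $1/(\gamma\sqrt{N})$ reflect our reading of equations (1) and (2); this is an illustrative sketch, not code from the paper's repository.

import numpy as np

def forward(params, X, gamma, phi=np.tanh):
    # params = (W0, Ws, wL) with W0: (N, D), Ws: list of (N, N) matrices, wL: (N,)
    # X: (D, P) input matrix. Returns preactivations [h^1, ..., h^L] and outputs f: (P,)
    W0, Ws, wL = params
    D = X.shape[0]
    N = W0.shape[0]
    hs = [W0 @ X / np.sqrt(D)]                      # h^1 = W^0 x / sqrt(D)
    for W in Ws:
        hs.append(W @ phi(hs[-1]) / np.sqrt(N))     # h^{l+1} = W^l phi(h^l) / sqrt(N)
    f = wL @ phi(hs[-1]) / (gamma * np.sqrt(N))     # readout, assumed 1/(gamma sqrt(N)) scaling
    return hs, f

def backprop_fields(params, hs, phi_dot=lambda h: 1.0 - np.tanh(h) ** 2):
    # Gradient features g^l = phi'(h^l) * z^l with pre-gradients z^l = W^{l T} g^{l+1} / sqrt(N).
    _, Ws, wL = params
    N = wL.shape[0]
    gs = [phi_dot(hs[-1]) * wL[:, None]]            # g^L (last-layer case: z^L = w^L)
    for W, h in zip(reversed(Ws), reversed(hs[:-1])):
        z = W.T @ gs[0] / np.sqrt(N)                # pre-gradient z^l
        gs.insert(0, phi_dot(h) * z)
    return gs

With these fields in hand, the kernels and the NTK defined next are simple inner products over the N hidden units.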

The relevant dynamical objects to characterize feature learning are feature and gradient kernels for each hidden layer $\ell \in \{1,\ldots ,L\}$, defined as

Equation (3): $\Phi^{\ell}_{\mu\alpha}(t,s) = \frac{1}{N}\,\phi(\boldsymbol{h}^{\ell}_{\mu}(t))\cdot\phi(\boldsymbol{h}^{\ell}_{\alpha}(s)) \ , \qquad G^{\ell}_{\mu\alpha}(t,s) = \frac{1}{N}\,\boldsymbol{g}^{\ell}_{\mu}(t)\cdot\boldsymbol{g}^{\ell}_{\alpha}(s)$

From the kernels $\{\Phi^{\ell},G^{\ell} \}_{\ell = 1}^L$, we can compute the NTK $K^{\mathrm{NTK}}_{\mu\alpha}(t,s) = \nabla_{\theta} f_{\mu}(t)\cdot\nabla_{\theta} f_{\alpha}(s) = \sum_{\ell = 0}^{L} G^{\ell+1}_{\mu\alpha}(t,s) \Phi^{\ell}_{\mu\alpha}(t,s)$ [6], and the dynamics of the network function $f_{\mu}$

Equation (4): $\frac{\mathrm{d}}{\mathrm{d}t} f_{\mu}(t) = \sum_{\alpha} K^{\mathrm{NTK}}_{\mu\alpha}(t,t)\,\Delta_{\alpha}(t)$

where we define base cases $G_{\mu\alpha}^{L+1}(t,s) = 1, \Phi^0_{\mu\alpha}(t,s) = K^x_{\mu\alpha} = \frac 1D \boldsymbol{x}_{\mu}\cdot \boldsymbol{x}_{\alpha}$. In prior work, $\Phi^{\ell},G^{\ell}$ were termed forward and backward kernels and were theoretically computed at initialization and empirically measured through training [67]. Our DMFT will provide exact formulas for these kernels throughout the full dynamics of feature learning. We note that the above formula holds for any data point µ which may or may not be in the set of P training examples. The above expressions demonstrate that knowledge of the temporal trajectory of the NTK on the t = s diagonal gives the temporal trajectory of the network predictions $f_{\mu}(t)$.
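As an illustration of how the kernels determine the predictions, here is a small numpy sketch that assembles $K^{\mathrm{NTK}}(t,t)$ from $\{\Phi^{\ell}, G^{\ell}\}$ and integrates equation (4) with explicit Euler steps. It assumes MSE loss (so $\Delta = y - f$) and that the network output starts near zero at large width; both are assumptions of this sketch rather than restrictions of the theory.

import numpy as np

def ntk_from_kernels(Phis, Gs, Kx):
    # K^NTK_{mu alpha} = sum_{l=0}^{L} G^{l+1}_{mu alpha} * Phi^l_{mu alpha} (elementwise),
    # with base cases Phi^0 = K^x and G^{L+1} = 1.
    P = Kx.shape[0]
    Phi_all = [Kx] + list(Phis)                 # Phi^0, ..., Phi^L
    G_all = list(Gs) + [np.ones((P, P))]        # G^1, ..., G^{L+1}
    return sum(G * Phi for G, Phi in zip(G_all, Phi_all))

def integrate_predictions(ntk_diag, y, dt):
    # Euler steps for df/dt = K^NTK(t,t) Delta(t), with Delta = y - f for MSE loss.
    f = np.zeros_like(y)                        # outputs start near zero at large width
    trajectory = [f.copy()]
    for K in ntk_diag:                          # ntk_diag: list of K^NTK(t,t) on the time grid
        f = f + dt * K @ (y - f)
        trajectory.append(f.copy())
    return np.array(trajectory)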

Following prior works on infinite-width networks [18, 21, 22, 40], we study the mean field limit

Equation (5): $N \to \infty \ , \qquad \gamma = \gamma_0 \sqrt{N} \ , \qquad \gamma_0 = \mathcal{O}_N(1)$

As we demonstrate in appendices D and N, this is the only N-scaling which allows feature learning as $N \to \infty$. The $\gamma_0 = 0$ limit recovers the static NTK limit [6]. We discuss other scalings and parameterizations in appendix N, relating our work to the $\mu P$-parameterization and TP analysis of [22] and showing that they have identical feature dynamics in the infinite-width limit. We also analyze the effect of different hidden layer widths and initialization variances in appendix D.8. We focus on equal widths and NTK parameterization (as in equation (1)) in the main text to reduce complexity.

3. Self-consistent DMFT

Next, we derive our self-consistent DMFT in a limit where $t, P = \mathcal{O}_N(1)$. Our goal is to build a description of training dynamics purely in terms of representations, independent of the weights. The infinite-width limit enjoys several analytical properties that simplify the study of feature learning:

  • The kernel order parameters $\Phi^{\ell},G^{\ell}$ concentrate over random initializations but are dynamical, allowing flexible adaptation of features to the task structure.
  • In each layer $\ell$, each neuron's preactivation $h_i^{\ell}$ and pregradient $z^{\ell}_i$ become i.i.d. draws from a distribution characterized by a set of order parameters $\{\Phi^{\ell},G^{\ell},A^{\ell},B^{\ell}\}$.
  • The kernels are defined as self-consistent averages (denoted by $\left\langle \right\rangle$) over this distribution of neurons in each layer $\Phi^{\ell}_{\mu\alpha}(t,s) = \left\langle \phi(h_{\mu}^{\ell}(t)) \phi(h_{\alpha}^{\ell}(s)) \right\rangle$ and $G_{\mu\alpha}^{\ell}(t,s) = \left\langle g_{\mu}^{\ell}(t) g_{\alpha}^{\ell}(s) \right\rangle$.

The next section derives these facts from a path-integral formulation of gradient flow dynamics.

3.1. Path integral construction

Gradient flow from a random initialization of the weights defines a high-dimensional stochastic process over initializations for the variables $\{\boldsymbol{h},\boldsymbol{g}\}$. We therefore utilize the DMFT formalism to obtain a reduced description of network activity during training. For a simplified derivation of the DMFT in the two-layer (L = 1) case, see appendix D.2. Generally, we separate the contribution on each forward/backward pass between the initial condition and the gradient updates to the weight matrix $\boldsymbol{W}^{\ell}$, defining new stochastic variables $\boldsymbol{\chi}^{\ell} ,\boldsymbol{\xi}^{\ell} \in \mathbb{R}^{N}$ as

Equation (6): $\boldsymbol{\chi}^{\ell+1}_{\mu}(t) = \frac{1}{\sqrt{N}}\,\boldsymbol{W}^{\ell}(0)\,\phi(\boldsymbol{h}^{\ell}_{\mu}(t)) \ , \qquad \boldsymbol{\xi}^{\ell}_{\mu}(t) = \frac{1}{\sqrt{N}}\,\boldsymbol{W}^{\ell}(0)^{\top}\boldsymbol{g}^{\ell+1}_{\mu}(t)$

We let Z denote the moment generating functional (MGF) for these stochastic fields, which, by construction, satisfies the normalization condition $Z[\{\boldsymbol{0},\boldsymbol{0} \}] = 1$. We enforce our definition of $\boldsymbol{\chi},\boldsymbol{\xi}$ using an integral representation of the delta-function. Thus for each sample $\mu \in [P]$ and each time $t \in \mathbb{R}_+$, we multiply Z by

Equation (7)

for χ and the corresponding expression for ξ. After making such substitutions, we perform the integration over the initial Gaussian weight matrices to arrive at an integral expression for Z, which we derive in appendix D.4. We show that Z can be described by a set of order parameters $\{\Phi^{\ell} ,\hat\Phi^{\ell} , G^{\ell}, \hat G^{\ell}, A^{\ell}, B^{\ell}\}$

Equation (8)

Equation (9)

where S is the DMFT action and $\mathcal Z$ is a single-site MGF, which defines the distribution of fields $\{\chi^{\ell},\xi^{\ell}\}$ over the neural population in each layer. The order parameters A and B are related to the correlations between feedforward and feedback signals. We provide a detailed formula for $\mathcal Z$ in appendix D.4 and show that it factorizes over different layers $\mathcal Z = \prod_{\ell = 1}^L \mathcal{Z}_{\ell}$. Each of the single site MGFs has the form

Equation (10)

where $\mathcal{H}_{\ell}$ is a single-site Hamiltonian that depends on the order parameters and defines the probability density over fields $\{\chi^{\ell}, \xi^{\ell} ,\hat\chi^{\ell}, \hat\xi^{\ell} \}$. We introduce the single site average $\left\langle O \right\rangle$ of observable O

Equation (11)

In the next section, we express the DMFT saddle-point equations defining $\{\Phi^{\ell}, G^{\ell} \}$ in terms of such single site averages.

3.2. Deriving the DMFT equations from the path integral saddle point

As $N \to \infty$, the moment-generating function Z is exponentially dominated by the saddle point of S. The equations that define this saddle point also define our DMFT. We thus identify the kernels that render S locally stationary ($\delta S = 0$). The most important equations are those which define $\{\Phi^{\ell},G^{\ell}\}$

Equation (12): $\Phi^{\ell}_{\mu\alpha}(t,s) = \left\langle \phi(h^{\ell}_{\mu}(t))\,\phi(h^{\ell}_{\alpha}(s)) \right\rangle \ , \qquad G^{\ell}_{\mu\alpha}(t,s) = \left\langle g^{\ell}_{\mu}(t)\,g^{\ell}_{\alpha}(s) \right\rangle$

where $\left\langle \right\rangle$ denotes an average over the stochastic process induced by $\mathcal Z$, which is defined below

Equation (13)

where we define base cases $\Phi^0_{\mu\alpha}(t,s) = K^x_{\mu\alpha}$ and $G^{L+1}_{\mu\alpha}(t,s) = 1$, $A^0 = B^L = 0$. We see that the fields $\{h^{\ell},z^{\ell} \}$, which represent the single site preactivations and pre-gradients, are implicit functionals of the mean-zero Gaussian processes $\{u^{\ell},r^{\ell}\}$ which have covariances $\left\langle u^{\ell}_{\mu}(t) u^{\ell}_{\alpha}(s) \right\rangle = \Phi^{\ell-1}_{\mu\alpha}(t,s) , \left\langle r^{\ell}_{\mu}(t) r^{\ell}_{\alpha}(s) \right\rangle = G^{\ell+1}_{\mu\alpha}(t,s)$. The other saddle point equations give the linear response functions

Equation (14)

which arise due to dependence between the feedforward and feedback signals. We note that, in the lazy limit $\gamma_0 \to 0$, the fields approach Gaussian processes $h^{\ell} \to u^{\ell}$, $z^{\ell} \to r^{\ell}$. Lastly, the final saddle point equations $\frac{\delta S}{\delta \Phi^{\ell}} = 0 ,\frac{\delta S}{\delta G^{\ell}} = 0$ imply that $\hat\Phi^{\ell} = \hat G^{\ell} = 0$. The full set of equations that define the DMFT are given in appendix D.7.

This theory is easily extended to more general architectures and training setups: networks with layer-dependent widths (appendix D.8), trainable bias parameters (appendix H), multiple (but $\mathcal{O}_N(1)$) output channels (appendix I), convolutional architectures (appendix G), training with weight decay (appendix J), Langevin sampling (appendix K), momentum (appendix L), and discrete time training (appendix M). In appendix N, we discuss parameterizations which give equivalent feature and predictor dynamics and show that our derived stochastic process is equivalent to the $\mu P$ scheme of Yang and Hu [22].

4. Solving the self-consistent DMFT

The saddle point equations obtained from the field theory of the previous section must be solved self-consistently. By this we mean that, given knowledge of the kernels, we can characterize the distribution of $\{h^{\ell}, z^{\ell}\}$, and given the distribution of $\{h^{\ell},z^{\ell}\}$, we can compute the kernels [64, 68]. In appendix B, we provide algorithm 1, a numerical procedure based on this idea to efficiently solve for the kernels with an alternating Monte Carlo strategy. The outputs of the algorithm are the dynamical kernels $\Phi^{\ell}_{\mu\alpha}(t,s), G^{\ell}_{\mu\alpha}(t,s), A^{\ell}_{\mu\alpha}(t,s), B^{\ell}_{\mu\alpha}(t,s)$, from which any network observable can be computed, as we discuss in appendix D. We provide an example of the solution to the saddle point equations compared to training a finite NN in figure 1. We plot $\Phi^{\ell}, G^{\ell}$ at the end of training and the sample-trace of these kernels through time. Additionally, we compare the kernels of a finite width-N network to the DMFT-predicted kernels using a cosine-similarity alignment metric $A(\boldsymbol{\Phi}^\mathrm{DMFT},\boldsymbol{\Phi}^\mathrm{NN}) = \frac{\text{Tr} \ \boldsymbol{\Phi}^\mathrm{DMFT} \boldsymbol{\Phi}^\mathrm{NN}}{|\boldsymbol{\Phi}^\mathrm{DMFT}||\boldsymbol{\Phi}^\mathrm{NN}|}$. Additional examples are shown in appendix figures A1 and A2.
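For reference, the alignment metric is straightforward to compute; below is a short numpy sketch in which we take $|\cdot|$ to be the Frobenius norm (the function name kernel_alignment is ours).

import numpy as np

def kernel_alignment(K_dmft, K_nn):
    # Cosine similarity A = Tr(K_dmft K_nn) / (|K_dmft| |K_nn|) with Frobenius norms.
    return np.trace(K_dmft @ K_nn) / (np.linalg.norm(K_dmft) * np.linalg.norm(K_nn))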


Figure 1. Neural network feature learning dynamics are captured by self-consistent dynamical mean field theory (DMFT). (a) Training loss curves on a subsample of P = 10 CIFAR-10 training points in a depth 4 (L = 3, N = 2500) tanh network ($\phi(h) = \tanh(h)$) trained with MSE. Increasing γ0 accelerates training. (b), (c) The distribution of preactivations at the beginning and end of training matches the predictions of the DMFT. (d) The final $\Phi^{\ell}$ kernel order parameters (at t = 100) match the finite width network. (e) The temporal dynamics of the sample-traced kernels $\sum_{\mu} \Phi_{\mu\mu}^{\ell}(t,s)$ matches experiment and reveals rich dynamics across layers. (f) The alignment $A(\boldsymbol{\Phi}^{\ell}_\mathrm{DMFT}, \boldsymbol{\Phi}^{\ell}_\mathrm{NN})$, defined as cosine similarity, between the kernel $\Phi^{\ell}_{\mu\alpha}(t,s)$ predicted by theory (DMFT) and that of width-N networks for different N but fixed $\gamma_0 = \gamma/\sqrt{N}$. Error bars show the standard deviation computed over 10 repeats. Around $N \sim 500$, DMFT begins to show near perfect agreement with the NN. (g)–(i) The same plots but for the gradient kernel $\boldsymbol{G}^{\ell}$. Whereas finite width effects for $\boldsymbol{\Phi}^{\ell}$ are larger at later layers $\ell$, since variance accumulates on the forward pass, fluctuations in $\boldsymbol{G}^{\ell}$ are large in early layers.

Algorithm 1. Alternating Monte–Carlo solution to saddle point equations.
     Data: $\boldsymbol{K}^x, \boldsymbol{y}$, Initial Guesses $\{\boldsymbol{\Phi}^{\ell} ,\boldsymbol{G}^{\ell} \}_{\ell = 1}^L$, $\{\boldsymbol{A}^{\ell},\boldsymbol{B}^{\ell}\}_{\ell = 1}^{L-1}$, Sample count $\mathcal S$, Update Speed β
   Result: Final Kernels $\{\boldsymbol{\Phi}^{\ell} ,\boldsymbol{G}^{\ell} \}_{\ell = 1}^L$, $\{\boldsymbol{A}^{\ell},\boldsymbol{B}^{\ell}\}_{\ell = 1}^{L-1}$, Network predictions through training $f_{\mu}(t)$
1  $\boldsymbol{\Phi}^0 = \boldsymbol{K}^x \otimes \boldsymbol{1} \boldsymbol{1}^{\top}$, $\boldsymbol{G}^{L+1} = \boldsymbol{1} \boldsymbol{1}^{\top}$   
2  while Kernels Not Converged do
3    From $\{\boldsymbol{\Phi}^{\ell}, \boldsymbol{G}^{\ell}\}$ compute $\boldsymbol{K}^{\mathrm{NTK}}(t,t)$ and solve $\frac{\mathrm{d}}{\mathrm{d}t} f_{\mu}(t) = \sum_{\alpha} \Delta_{\alpha}(t) K^{\mathrm{NTK}}_{\mu\alpha}(t,t)$
4    $\ell = 1$
5    while $\ell \lt L+1$ do
6      Draw $\mathcal S$ samples $\{u^{\ell}_{\mu,n}(t) \}_{n = 1}^{\mathcal S} \sim \mathcal{GP}(0,\boldsymbol{\Phi}^{\ell-1})$, $\{r^{\ell}_{\mu,n}(t) \}_{n = 1}^{\mathcal S} \sim \mathcal{GP}(0,\boldsymbol{G}^{\ell+1})$  
7      Solve equation (13) for each sample to get $\{h^{\ell}_{\mu,n}(t),z^{\ell}_{\mu,n}(t)\}_{n = 1}^{\mathcal S}$  
8      Compute new $\boldsymbol{\Phi}^{\ell},\boldsymbol{G}^{\ell}$ estimates:
9      $\tilde\Phi_{\mu\alpha}^{\ell}(t,s) = \frac{1}{\mathcal S} \sum_{n \in [\mathcal S]} \phi(h_{\mu,n}^{\ell}(t) ) \phi(h_{\alpha,n}^{\ell}(s))$, $\tilde{G}_{\mu\alpha}^{\ell}(t,s) = \frac{1}{\mathcal S} \sum_{n \in [\mathcal S]} g^{\ell}_{\mu,n}(t) g^{\ell}_{\alpha,n}(s) $   
10      Solve for Jacobians on each sample $\frac{\partial \phi(\boldsymbol{h}_{n}^{\ell})}{\partial \boldsymbol{r}^{\ell \top}_n}, \frac{\partial \boldsymbol{g}_n^{\ell}}{\partial \boldsymbol{u}^{\ell\top}_n}$   
11      Compute new $\boldsymbol{A}^{\ell},\boldsymbol{B}^{\ell-1}$ estimates:
12      $\tilde{\boldsymbol{A}}^{\ell} = \frac{1}{\mathcal S} \sum_{n \in [\mathcal S]} \frac{\partial \phi(\boldsymbol{h}_{n}^{\ell})}{\partial \boldsymbol{r}^{\ell \top}_n} \ , \tilde{\boldsymbol{B}}^{\ell-1} = \frac{1}{\mathcal S} \sum_{n \in [\mathcal S]} \frac{\partial \boldsymbol{g}_n^{\ell}}{\partial \boldsymbol{u}^{\ell\top}_n}$   
13      $\ell \gets \ell+1$  
14    end
15    $\ell = 1$  
16    while $\ell \lt L+1$ do
17      Update feature kernels: $\boldsymbol{\Phi}^{\ell} \gets (1-\beta) \boldsymbol{\Phi}^{\ell} + \beta \tilde{\boldsymbol{\Phi}}^{\ell}$, $\boldsymbol{G}^{\ell} \gets (1-\beta) \boldsymbol{G}^{\ell} + \beta \tilde{\boldsymbol{G}}^{\ell}$   
18      if $\ell \lt L$ then
19        Update $\boldsymbol{A}^{\ell} \gets (1-\beta) \boldsymbol{A}^{\ell} + \beta \tilde{\boldsymbol{A}}^{\ell}, \boldsymbol{B}^{\ell} \gets (1-\beta)\boldsymbol{B}^{\ell} + \beta \tilde{\boldsymbol{B}}^{\ell}$
20      end
21      $\ell \gets \ell+1$
22    end
23  end
24  return $\{\boldsymbol{\Phi}^{\ell}, \boldsymbol{G}^{\ell} \}_{\ell = 1}^L, \{\boldsymbol{A}^{\ell},\boldsymbol{B}^{\ell}\}_{\ell = 1}^{L-1}, \{f_{\mu}(t)\}_{\mu = 1}^P$
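A simplified numpy sketch of the alternating loop in algorithm 1 is given below. It tracks only the Φ and G kernels and abstracts the single-site dynamics of equation (13) into a user-supplied function solve_single_site; the A, B response kernels and the prediction dynamics are omitted, and all function and argument names are our own.

import numpy as np

def dmft_kernels(Kx_grid, L, solve_single_site, n_samples=2000, beta=0.6, n_iters=20, seed=0):
    # Kx_grid: (P*T, P*T) matrix for Phi^0 = K^x (x) 1 1^T on the discretized time grid.
    # solve_single_site(u, r, l) -> (phi_h, g), each of shape (n_samples, P*T),
    # standing in for the solution of equation (13) sample by sample.
    rng = np.random.default_rng(seed)
    M = Kx_grid.shape[0]
    Phi = [np.eye(M) for _ in range(L)]              # crude initial guesses for Phi^1..Phi^L
    G = [np.eye(M) for _ in range(L)]                # crude initial guesses for G^1..G^L
    G_top = np.ones((M, M))                          # base case G^{L+1}(t,s) = 1
    zero = np.zeros(M)
    for _ in range(n_iters):
        Phi_new, G_new = [], []
        for l in range(L):
            cov_u = Kx_grid if l == 0 else Phi[l - 1]    # u^l ~ GP(0, Phi^{l-1})
            cov_r = G_top if l == L - 1 else G[l + 1]    # r^l ~ GP(0, G^{l+1})
            u = rng.multivariate_normal(zero, cov_u, size=n_samples, method='svd')
            r = rng.multivariate_normal(zero, cov_r, size=n_samples, method='svd')
            phi_h, g = solve_single_site(u, r, l)
            Phi_new.append(phi_h.T @ phi_h / n_samples)  # Monte Carlo estimate of <phi(h) phi(h)>
            G_new.append(g.T @ g / n_samples)            # Monte Carlo estimate of <g g>
        Phi = [(1 - beta) * K + beta * Kn for K, Kn in zip(Phi, Phi_new)]   # damped update
        G = [(1 - beta) * K + beta * Kn for K, Kn in zip(G, G_new)]
    return Phi, G

In the full algorithm, the β-damped updates of the $\boldsymbol{A}^{\ell}, \boldsymbol{B}^{\ell}$ kernels and the re-solving of the prediction dynamics would be interleaved in the same outer loop.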

4.1. Deep linear networks: closed form self-consistent equations

Deep linear networks ($\phi(h) = h$) are of theoretical interest since they are simpler to analyze than nonlinear networks but preserve nontrivial training dynamics and feature learning [23, 25, 32, 69–73]. In a deep linear network, we can simplify our saddle point equations to algebraic formulas that close in terms of the kernels $H^{\ell}_{\mu\alpha}(t,s) = \left\langle h_{\mu}^{\ell}(t) h_{\alpha}^{\ell}(s) \right\rangle$, $G^{\ell}_{\mu\alpha}(t,s) = \left\langle g^{\ell}_{\mu}(t) g^{\ell}_{\alpha}(s) \right\rangle$ [22]. This is a significant simplification since it allows the solution of the saddle point equations without a sampling procedure.

To describe the result, we first introduce a vectorization notation $\boldsymbol{h}^{\ell} = \text{Vec}\{h_{\mu}^{\ell}(t) \}_{\mu\in [P], t \in \mathbb{R}_+}$. Likewise we convert kernels $\boldsymbol{H}^{\ell} = \text{Mat}\{H^{\ell}_{\mu\alpha}(t,s) \}_{\mu,\alpha \in [P], t,s \in \mathbb{R}_+}$ into matrices. The inner product under this vectorization is defined as $\boldsymbol{a} \cdot \boldsymbol{b} = \int_0^{\infty} \mathrm{d}t \sum_{\mu = 1}^P a_{\mu}(t) b_{\mu}(t)$. In a practical computational implementation, the theory would be evaluated on a grid of T time points with discrete time GD, so these kernels $\boldsymbol{H}^{\ell} \in \mathbb{R}^{PT \times PT}$ would indeed be matrices of the appropriate size. The fields $\boldsymbol{h}^{\ell},\boldsymbol{g}^{\ell}$ are linear functionals of independent Gaussian processes $\boldsymbol{u}^{\ell},\boldsymbol{r}^{\ell}$, giving $(\mathbf{I} - \gamma_0^2 \boldsymbol{C}^{\ell} \boldsymbol{D}^{\ell}) \boldsymbol{h}^{\ell} = \boldsymbol{u}^{\ell} + \gamma_0 \boldsymbol{C}^{\ell} \boldsymbol{r}^{\ell} \ , \ (\mathbf{I} - \gamma_0^2 \boldsymbol{D}^{\ell} \boldsymbol{C}^{\ell}) \boldsymbol{g}^{\ell} = \boldsymbol{r}^{\ell} + \gamma_0 \boldsymbol{D}^{\ell} \boldsymbol{u}^{\ell}$. The matrices $\boldsymbol{C}^{\ell}$ and $\boldsymbol{D}^{\ell}$ are causal integral operators which depend on $\{\boldsymbol{A}^{\ell-1}, \boldsymbol{H}^{\ell-1}\}$ and $\{\boldsymbol{B}^{\ell}, \boldsymbol{G}^{\ell+1}\}$ respectively which we define in appendix F. The saddle point equations which define the kernels are

Equation (15)

Examples of the predictions obtained by solving this system of equations are provided in figure 2. We see that these DMFT equations describe kernel evolution for networks of a variety of depths and that the change in each layer's kernel increases with the depth of the network.
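For deep linear networks, the single-site average can be taken in closed form layer by layer. The sketch below propagates the Gaussian statistics through the stated linear relation, assuming (as in the sampling of algorithm 1) that $u^{\ell}$ and $r^{\ell}$ are drawn independently with covariances $H^{\ell-1}$ and $G^{\ell+1}$; the construction of the causal operators $\boldsymbol{C}^{\ell}, \boldsymbol{D}^{\ell}$ from appendix F is not reproduced, so they enter as inputs.

import numpy as np

def linear_layer_kernels(H_prev, G_next, C, D, gamma0):
    # Propagate second moments through
    #   (I - g0^2 C D) h = u + g0 C r,   (I - g0^2 D C) g = r + g0 D u,
    # with u ~ GP(0, H_prev), r ~ GP(0, G_next) taken to be independent (our assumption here).
    I = np.eye(H_prev.shape[0])
    Mh = np.linalg.inv(I - gamma0 ** 2 * C @ D)
    Mg = np.linalg.inv(I - gamma0 ** 2 * D @ C)
    H = Mh @ (H_prev + gamma0 ** 2 * C @ G_next @ C.T) @ Mh.T   # H^l = <h h^T>
    G = Mg @ (G_next + gamma0 ** 2 * D @ H_prev @ D.T) @ Mg.T   # G^l = <g g^T>
    return H, G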


Figure 2. Deep linear network with the full DMFT. (a) The train loss for NNs of varying L. (b) For an $L = 5, N = 1000$ NN, the kernels $H^{\ell}$ at the end of training compared to the DMFT on P = 20 datapoints. (c) Average displacement of feature kernels for networks of different depth at the same γ0 value. For equal values of γ0, deeper networks exhibit larger changes to their features, manifested in lower alignment with their initial t = 0 kernels H. (d) The solution to the temporal components of the $G^{\ell}(t,s)$ and $\sum_{\mu}H^{\ell}_{\mu\mu}(t,s)$ kernels obtained from the self-consistent equations.


Unlike many prior results [69–72], our DMFT does not require any restrictions on the structure of the input data but holds for any $\boldsymbol{K}^x, \boldsymbol{y}$. However, for whitened data $\boldsymbol{K}^x = \mathbf{I}$, we show in appendices F.1.1 and F.2 that our DMFT learning curves interpolate between NTK dynamics and the sigmoidal trajectories of prior works [69, 70] as γ0 is increased. For example, in the two layer (L = 1) linear network with $\boldsymbol{K}^x = \mathbf{I}$, the dynamics of the error norm $\Delta(t) = ||\boldsymbol{\Delta}(t)||$ take the form $\frac{\partial}{\partial t} \Delta(t) = - 2 \sqrt{1 + \gamma_0^2 (y - \Delta(t))^2}\Delta(t)$ where $y = ||\boldsymbol{y}||$. These dynamics give the linear convergence rate of the NTK as $\gamma_0 \to 0$ but approach the logistic dynamics of [70] as $\gamma_0 \to \infty$. Further, $\boldsymbol{H}(t) = \left\langle \boldsymbol{h}^1(t) \boldsymbol{h}^1(t)^{\top} \right\rangle \in \mathbb{R}^{P\times P}$ only grows in the $\boldsymbol{y} \boldsymbol{y}^{\top}$ direction with $H_y(t) = \frac{1}{y^2} \boldsymbol{y}^{\top} \boldsymbol{H}(t) \boldsymbol{y} = \sqrt{1+ \gamma_0^2 (y-\Delta(t))^2 }$. At the end of training, $\boldsymbol{H}(t) \to \mathbf{I} + \frac{1}{y^2}[\sqrt{1+\gamma_0^2 y^2}-1] \boldsymbol{y}\boldsymbol{y}^{\top}$, recovering the rank one spike which was recently obtained in the small initialization limit [74]. We show this one dimensional system in figure A3.
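The two-layer whitened-data dynamics above reduce to a one-dimensional ODE; a short Euler-integration sketch is shown below. It assumes the output starts at zero so that $\Delta(0) = y = ||\boldsymbol{y}||$ (the large-width initial condition), and returns both the error $\Delta(t)$ and the kernel projection $H_y(t)$.

import numpy as np

def two_layer_linear_dynamics(y_norm, gamma0, dt=1e-3, n_steps=20000):
    # Euler integration of d Delta/dt = -2 sqrt(1 + g0^2 (y - Delta)^2) Delta, Delta(0) = y.
    Delta = y_norm
    deltas, H_y = [], []
    for _ in range(n_steps):
        deltas.append(Delta)
        H_y.append(np.sqrt(1.0 + gamma0 ** 2 * (y_norm - Delta) ** 2))
        Delta += -dt * 2.0 * np.sqrt(1.0 + gamma0 ** 2 * (y_norm - Delta) ** 2) * Delta
    return np.array(deltas), np.array(H_y)

Sweeping γ0 here traces the interpolation between exponential and sigmoidal trajectories described above.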

4.2. Feature learning with L2 regularization

As we show in appendix J, the DMFT can be extended to networks trained with weight decay $\frac{\mathrm{d}\boldsymbol{\theta}}{\mathrm{d}t} = - \gamma^2 \nabla_{\boldsymbol{\theta}} \mathcal{L} - \lambda \boldsymbol{\theta}$. If the neural network is homogeneous in its parameters so that $f(c\boldsymbol{\theta}) = c^\kappa f(\boldsymbol{\theta})$ (examples include networks with linear, ReLU, or quadratic activations), then the final network predictor is a kernel regressor with the final NTK, $\lim_{t\to\infty} f(\boldsymbol{x},t) = \boldsymbol{k}(\boldsymbol{x})^{\top} [ \boldsymbol{K} + \lambda \kappa \mathbf{I}]^{-1} \boldsymbol{y}$, where $K(\boldsymbol{x},\boldsymbol{x}')$ is the final NTK, $[\boldsymbol{k}(\boldsymbol{x})]_{\mu} = K(\boldsymbol{x},\boldsymbol{x}_{\mu})$ and $[\boldsymbol{K}]_{\mu\alpha} = K(\boldsymbol{x}_{\mu},\boldsymbol{x}_{\alpha})$. We note that the effective regularization λκ increases with depth L. In NTK parameterization, weight decay in infinite-width homogeneous networks gives a trivial fixed point $K(\boldsymbol{x},\boldsymbol{x}') \to 0$ and consequently a zero predictor f → 0 [75]. However, as we show in figure 3, increasing the feature learning strength γ0 can prevent convergence to the trivial fixed point, allowing a non-zero fixed point for $K,f$ even at infinite width. The kernel and function dynamics can be predicted with DMFT. The fixed point is a nontrivial function of the hyperparameters $\lambda, \kappa, L, \gamma_0$.
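The final-predictor formula above is ordinary kernel ridge regression with ridge λκ; a minimal numpy sketch (function name ours) is:

import numpy as np

def final_predictor(K_train, k_test, y, lam, kappa):
    # f(x) = k(x)^T [K + lam * kappa * I]^{-1} y, with K the final NTK on the training set.
    # k_test: vector [k(x)]_mu of shape (P,) or matrix of shape (n_test, P).
    alpha = np.linalg.solve(K_train + lam * kappa * np.eye(K_train.shape[0]), y)
    return k_test @ alpha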


Figure 3. Width N = 1000 ReLU networks trained with L2 regularization have a nontrivial fixed point in the DMFT limit ($\gamma_0 \gt 0$). (a) Training loss dynamics for an L = 1 ReLU network with λ = 1. In the $\gamma_0 \to 0$ limit, the fixed point is trivial $f = K = 0$. The final loss is a decreasing function of γ0. (b) The final kernel is more aligned with the target with increasing γ0. Networks with homogeneous activations enjoy a representer theorem at infinite-width, as we show in appendix J.


5. Approximation schemes

We now compare our exact DMFT with approximations of prior work, providing an explanation of when these approximations give accurate predictions and when they break down.

5.1. Gradient independence ansatz

We can study the accuracy of the ansatz $ \boldsymbol{A}^{\ell} = \boldsymbol{B}^{\ell} = 0$, which is equivalent to treating the weight matrices $\boldsymbol{W}^{\ell}(0)$ and $\boldsymbol{W}^{\ell}(0)^{\top}$ which appear in forward and backward passes respectively as independent Gaussian matrices. This assumption was utilized in prior works on signal propagation in deep networks in the lazy regime [76–80]. A consequence of this approximation is the Gaussianity and statistical independence of $\chi^{\ell}$ and $\xi^{\ell}$ (conditional on $\{\boldsymbol{\Phi}^{\ell},\boldsymbol{G}^{\ell}\}$) in each layer as we show in appendix O. This ansatz works very well near $\gamma_0 \approx 0$ (the static kernel regime) since $\frac{\mathrm{d}\boldsymbol{h}}{\mathrm{d}\boldsymbol{r}}, \frac{\mathrm{d}\boldsymbol{z}}{\mathrm{d}\boldsymbol{u}} \sim \mathcal{O}(\gamma_0)$ or around initialization t ≈ 0 but begins to fail at larger values of $\gamma_0, t$ (figures 4 and A4).


Figure 4. Comparison of DMFT to various approximation schemes in an L = 5 hidden layer, width N = 1000 linear network with $\gamma_0 = 1.0$ and P = 100. (a) The losses for the various approximations do not track the true trajectory induced by gradient descent in the large γ0 regime. (b), (c) The feature kernels $H^{\ell}_{\mu\alpha}(t,s)$ across each of the L = 5 hidden layers for each of the theories are compared to a width 1000 neural network. Again, we plot the sample-traced dynamics $\sum_{\mu} H^{\ell}_{\mu\mu}(t,s)$. (d) The alignment of $\boldsymbol{H}^{\ell}$ compared to the finite NN, $A(\boldsymbol{H}^{\ell}, \boldsymbol{H}^{\ell}_\mathrm{NN})$, averaged across $\ell \in \{1,\ldots ,5\}$ for varying γ0. The predictions of all of these theories coincide in the $\gamma_0 = 0$ limit but begin to deviate in the feature learning regime. Only the nonperturbative DMFT is accurate over a wide range of γ0.


5.2. Small-feature learning perturbation theory at infinite-width

In the $\gamma_0 \to 0$ limit, we recover static kernels, giving linear dynamics identical to the NTK limit [6]. Corrections to this lazy limit can be extracted at small but finite γ0. This is conceptually similar to recent works which consider perturbation series for the NTK in powers of $1/N$ [27, 28, 35] (though not identical; see [81] for finite N effects in mean-field parameterization). We expand all observables $q(\gamma_0)$ in a power series in γ0, giving $q(\gamma_0) = q^{(0)} + \gamma_0 q^{(1)} + \gamma_0^2 q^{(2)} + \ldots $ and compute corrections up to $\mathcal O(\gamma_0^2)$. We show that the $\mathcal{O}(\gamma_0)$ and $\mathcal{O}(\gamma_0^3)$ corrections to the kernels vanish, giving leading order expansions of the form $\boldsymbol{\Phi} = \boldsymbol{\Phi}^0 + \gamma_0^2 \boldsymbol{\Phi}^2 + \mathcal{O}(\gamma_0^4)$ and $\boldsymbol{G} = \boldsymbol{G}^0 + \gamma_0^2 \boldsymbol{G}^2 + \mathcal{O}(\gamma_0^4)$ (see appendix P.2). Further, we show that the NTK has a relative change at leading order which scales linearly with depth, $|\Delta K^{\mathrm{NTK}}|/|K^\mathrm{NTK,0}| \sim\mathcal{O}_{\gamma_0, L}(L \gamma_0^2) = \mathcal{O}_{N,\gamma,L}(\frac{\gamma^2 L}{N})$, which is consistent with finite width effective field theory at $\gamma = \mathcal{O}_N(1)$ [26–28] (appendix P.6). Further, at the leading order correction, all temporal dependencies are controlled by $P(P+1)$ functions $v_{\alpha}(t) = \int_0^t \mathrm{d}s \Delta^0_{\alpha}(s)$ and $v_{\alpha\beta}(t) = \int_0^t \mathrm{d}s \Delta^0_{\alpha}(s) \int_0^s \mathrm{d}s^{\prime} \Delta^0_{\beta}(s^{\prime})$, which is consistent with results derived for finite width NNs using a truncation of the neural tangent hierarchy [27, 34, 35]. To lighten notation, in the main text we compare our non-perturbative DMFT to perturbation theory only in the deep linear case; the full perturbation theory is in appendix P.2.
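Since the leading-order temporal structure is carried entirely by $v_{\alpha}(t)$ and $v_{\alpha\beta}(t)$, they are cheap to tabulate from the lazy-limit errors $\Delta^0_{\alpha}(s)$; the following numpy sketch uses simple left-Riemann sums on a uniform grid (the discretization choice is ours).

import numpy as np

def perturbative_time_functions(Delta0, dt):
    # Delta0: array of shape (T, P) holding the lazy-limit errors Delta^0_alpha(s) on a grid.
    # Returns v_alpha(t) of shape (T, P) and v_{alpha beta}(t) of shape (T, P, P).
    v = np.cumsum(Delta0, axis=0) * dt                      # v_a(t) = int_0^t ds Delta^0_a(s)
    inner = np.cumsum(Delta0, axis=0) * dt                  # int_0^s ds' Delta^0_b(s')
    integrand = Delta0[:, :, None] * inner[:, None, :]      # Delta^0_a(s) * inner_b(s)
    v2 = np.cumsum(integrand, axis=0) * dt                  # outer integral over s
    return v, v2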

Using the timescales derived in the previous section, we find that the leading order correction to the kernels in an infinite-width deep linear network takes the form

Equation (16)

We see that the relative change in the NTK $|\boldsymbol{K}^{\mathrm{NTK}} - \boldsymbol{K}^{\mathrm{NTK}}(0)|/|\boldsymbol{K}^{\mathrm{NTK}}(0)| \sim \mathcal{O}( \gamma_0^2 L) = \mathcal{O}( \gamma^2 L /N)$, so that large depth L networks exhibit more significant kernel evolution, which agrees with other perturbative studies [25, 27, 35] as well as the nonperturbative results in figure 2. However, at large γ0 and large L, this theory begins to break down, as we show in figure 4.

6. Feature learning dynamics is preserved across widths

Our DMFT suggests that for networks sufficiently wide for their kernels to concentrate, the dynamics of loss and kernels should be invariant under the rescaling $N \to R N, \gamma \to \gamma / \sqrt{R}$, which keeps γ0 fixed. To evaluate how well this idea holds in a realistic deep learning problem, we trained convolutional neural networks (CNNs) of varying channel counts N on two-class CIFAR classification [82]. We tracked the dynamics of the loss and the last layer $\Phi^L$ kernel. The results are provided in figure 5. We see that dynamics are largely independent of rescaling as predicted. Further, as expected, larger γ0 leads to larger changes in kernel norm and faster alignment to the target function y, as was also found in [83]. Consequently, the higher γ0 networks train more rapidly. The trend is consistent for width N = 250 and N = 500. More details about the experiment can be found in appendix C.2 and figure A5.
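The rescaling test is simple to set up: for each width one chooses γ = γ0√N so that γ0 is held fixed, and (as in the experiments of appendix C) the raw learning rate is scaled as η = η0γ². The snippet below is a small helper illustrating that recipe; the function name and the default η0 value are placeholders.

import numpy as np

def width_consistent_hyperparams(N, gamma0, eta0=2e-4):
    # Hold gamma0 = gamma / sqrt(N) fixed across widths; scale the raw learning rate
    # as eta = eta0 * gamma^2 = eta0 * gamma0^2 * N (the appendix C convention).
    gamma = gamma0 * np.sqrt(N)
    eta = eta0 * gamma ** 2
    return gamma, eta

# e.g. widths 250 and 500 at the same gamma0 should give matching loss/kernel dynamics
for N in (250, 500):
    print(N, width_consistent_hyperparams(N, gamma0=1.0))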


Figure 5. The dynamics of depth 5 (L = 4 hidden layer) CNNs trained on the first two classes of CIFAR (boat vs plane) exhibit consistency for different channel counts $N \in \{250,500\}$ at fixed $\gamma_0 = \gamma / \sqrt{N}$. (a) We plot the test loss (MSE) and (b) the test classification error. Networks with higher γ0 train more rapidly. Time is measured in units of 100 update steps. (c) The dynamics of the last layer feature kernel $\Phi^L$, shown as alignment to the target function. As predicted by the DMFT, higher γ0 corresponds to more active kernel evolution, evidenced by a larger change in the alignment.


7. Discussion

We provided a unifying DMFT derivation of feature dynamics in infinite-width networks trained with gradient based optimization. Our theory interpolates between the lazy infinite-width behavior of a static NTK as $\gamma_0 \to 0$ and rich feature learning. At $\gamma_0 = 1$, our DMFT construction agrees with the stochastic process derived previously with the TPs framework [22]. Our saddle point equations give self-consistency conditions which relate the stochastic fields to the kernels. These equations are exactly solvable in deep linear networks and can be efficiently solved with a numerical method in the nonlinear case. Comparisons with other approximation schemes show that DMFT remains accurate over a much wider range of γ0. We believe our framework could be a useful perspective for future theoretical analyses of feature learning and generalization in wide networks.

Though our DMFT is quite general with regards to the data and architecture, the technique is not entirely rigorous and relies on heuristic physics techniques. Our theory holds in the regime $T,P = \mathcal{O}_N(1)$ and may break down otherwise; other asymptotic regimes (such as $P/N, T/\log(N) = \mathcal{O}_N(1)$, etc) may exhibit phenomena relevant to deep learning practice [32, 84]. Indeed, many experiments find that finite width effects appear to grow dynamically during learning (with T and P) and hinder the performance of models [45, 81, 85, 86]. The computational requirements of our method, while smaller than the exponential time complexity of the exact solution [22], are still significant for large $P T$. In table 1, we compare the time taken for various theories to compute the feature kernels throughout T steps of GD. For a width N network, computation of each forward pass on all P data points takes $\mathcal{O}(P N^2)$ operations. The static NTK requires computation of $\mathcal{O}(P^2)$ entries in the kernel which do not need to be recomputed. However, the DMFT requires matrix multiplications on PT×PT matrices, giving an $\mathcal{O}(P^3 T^3)$ time scaling. Future work could aim to improve the computational overhead of the algorithm, for example by considering data averaged theories [64] or one pass SGD [22]. Alternative projected versions of GD have also enabled much better computational scaling in the evaluation of the theoretical predictions [46], allowing evaluation on full CIFAR-10.

Table 1. Computational requirements for computing kernel dynamics and trained network predictions for a width-N neural network trained with P data points on a grid of T time points, for various theories. DMFT is faster and less memory intensive than a width N network only if $N \gg PT$. It is more computationally efficient to compute full DMFT kernels than leading order perturbation theory when $T \ll \sqrt{P}$. The expensive scaling with both samples and time is the cost of a full-batch non-perturbative theory of gradient based feature learning dynamics.

Requirements | Width-N NN | Static NTK | Perturbative | Full DMFT
Memory for kernels | $\mathcal O(N^2)$ | $\mathcal O(P^2)$ | $\mathcal O(P^4 T)$ | $\mathcal O(P^2 T^2)$
Time for kernels | $\mathcal O(P N^2 T)$ | $\mathcal O(P^2)$ | $\mathcal O(P^4 T)$ | $\mathcal O(P^3 T^3)$
Time for final outputs | $\mathcal{O}(P N^2 T)$ | $\mathcal{O}(P^3)$ | $\mathcal{O}(P^4)$ | $\mathcal{O}(P^3 T^3)$

Since the first appearance of our work in conference proceedings [87], we have extended our DMFT technique beyond GD-based training on a loss function to study the dynamics of other, more biologically-plausible learning rules such as feedback alignment and Hebbian learning [88]. Such rules follow updates with pseudo-gradient fields $\tilde{\boldsymbol{g}}^{\ell}_{\mu}(t)$ which provide a biologically plausible approximation to the true backpropagation signals. In this case, the key order parameters to consider are the feature kernels $\Phi^{\ell}_{\mu\nu}(t,s)$ and the gradient-pseudogradient correlators $\tilde{G}^{\ell}_{\mu\nu}(t,s) = \frac{1}{N} \boldsymbol{g}^{\ell}_{\mu}(t) \cdot \tilde{\boldsymbol{g}}^{\ell}_\nu(s)$. Successful feature learning enhances the gradient-pseudogradient alignment measured with $\tilde{G}$. As in the present work, the kernels $\{\Phi^{\ell}, \tilde{G}^{\ell} \}$ and the distribution of preactivations and pregradients are related self-consistently at infinite width.

It remains an open question how much deep learning phenomena can be captured by this infinite width feature learning limit of network dynamics. A recent empirical study analyzed the loss dynamics, individual network logits, and internal feature kernels and preactivation distributions of networks trained at different widths, finding that for simple tasks like CIFAR-10, networks across widths exhibit consistency across these observables in the mean field/µP parameterization [86]. However, for harder tasks such as ImageNet or token prediction on the C4 dataset, wider networks exhibit distinct dynamics, often training faster and updating features more rapidly. The differences across widths in performance and learned representations motivate the development of theoretical methods beyond the mean-field analysis presented here, which can characterize finite size effects on learning dynamics in the feature learning regime [28, 29, 81].

Acknowledgments

This work was supported by NSF Grant DMS-2134157 and an award from the Harvard Data Science Initiative Competitive Research Fund. B B acknowledges additional support from the NSF-Simons Center for Mathematical and Statistical Analysis of Biology at Harvard (Award #1764269) and the Harvard Q-Bio Initiative.

We thank Jacob Zavatone-Veth, Alex Atanasov, Abdulkadir Canatar, and Ben Ruben for comments on this manuscript as well as Greg Yang, Boris Hanin, Yasaman Bahri, and Jascha Sohl-Dickstein for useful discussions.

Appendix A: Additional figures


Figure A1. Self-consistent DMFT reproduces two layer (L = 1 hidden layer, width N = 2000) ReLU NN's preactivation density, loss dynamics and learned kernel. (a) The loss is obtained by taking saddle point results for $\Phi,G$ and calculating the NTK's dynamics. The $\gamma_0 \to 0$ limit is governed by a static NTK, while the $\gamma_0 \gt 0$ network exhibits kernel evolution and accelerated training. (b) We plot the preactivation h distribution for neurons in the hidden layer of the trained NN against the theoretical densities defined by $\mathcal Z[\Phi,G]$. For small γ0, the final distribution is approximately Gaussian, but becomes non-Gaussian, asymmetric, and heavy tailed for large γ0. The DMFT estimate of the distribution is noisy due to the finite sampling error. (c) The pre-gradient distribution p(z) in the trained network has larger final variance for large γ0. (d), (e) The final $\Phi,G$ are accurately predicted by the field theory and exhibit a block structure that increases with γ0 due to feature learning.


Figure A2. Self-consistent DMFT reproduces the loss dynamics and kernels through time in an L = 3 tanh network. (a) The loss when training on synthetic data is obtained by taking the saddle point results for $\Phi,G$ and calculating the NTK's dynamics. The $\gamma_0 \to 0$ limit is governed by a static NTK, while the $\gamma_0 \gt 0$ network exhibits kernel evolution and accelerated training. Solid lines are an N = 2000 NN and dashed lines are from solving the DMFT equations. (b), (c) The final learned kernels Φ (b) and G (c) are accurately predicted by the field theory and exhibit block structure due to clustering by class identity. (d) The temporal components of $\Phi,G$ reveal nontrivial dynamical structure.


Figure A3. Error and kernel dynamics obtained by solving a one-dimensional ODE system for a depth-2 linear network. (a) $\Delta(t)$ error dynamics from appendix F.1.1 allows one to solve for $\boldsymbol{H}(t)$ by solving a one-dimensional ODE at each value of γ0. The learning curves interpolate between exponential convergence at small γ0 and logistic sigmoidal trajectories at large γ0. (b) The projection of the kernel $\boldsymbol{H}(t)$ along the task relevant subspace $\boldsymbol{y} \in \mathbb{R}^P$.


Figure A4. Gradient independence fails to characterize feature learning dynamics in networks with L > 1 and large γ0. (a) Loss curves for deep linear networks predicted under the gradient independence ansatz for $\gamma_0 = 1.5$. (b) The predicted and experimental feature kernels $\boldsymbol{H}^{\ell}$ for the L = 5 hidden layer network demonstrate that gradient independence underestimates the size of kernel adaptation.


Figure A5. Repeating the experiment of figure 5 with depth 7 (L = 6 hidden layer) CNN trained on two class CIFAR over a wide range of γ0 with $N \in \{250,500\}$. We find consistent agreement of loss and prediction dynamics across widths but finite size effects become more significant when computing feature kernels of deeper layers. We note that, while higher γ0 is associated with faster convergence, the final test accuracy for this model is roughly insensitive to choice of γ0.


Appendix B: Algorithmic implementation

The alternating sample-and-solve procedure we develop and describe below for nonlinear networks is based on numerical recipes used in dynamical mean field simulations in computational physics [68]. The basic principle is to leverage the fact that, conditional on the kernels, we can easily draw samples $\{u_{\mu}^{\ell}(t), r_{\mu}^{\ell}(t)\}$ from their appropriate GPs. From these sampled fields, we can estimate the kernel order parameters by computing the appropriate moments.

The parameter β controls the recency weighting of the samples obtained at each iteration. If β = 1, then the rank of the kernel estimates is limited to the number of samples $\mathcal S$ used in a single iteration, but with β < 1, smaller sample sizes $\mathcal S$ can still yield accurate results. We used β = 0.6 in our deep network experiments. Convergence is usually achieved within ∼15 steps for a depth 4 (L = 3 hidden layer) network such as the one in figures 1 and A2.

Appendix C: Experimental details

All NN training is performed with a Jax GD optimizer [89] with a fixed learning rate.

C.1. MLP experiments

For the MLP experiments, we perform full batch GD. Networks are initialized with Gaussian weights with unit standard deviation $W^{\ell}_{ij} \sim \mathcal{N}(0,1)$. The learning rate is chosen as $\eta = \eta_0 \gamma^2 = \eta_0 \gamma_0^2 N$ for a network of width N. The hidden features $\boldsymbol{h}^{\ell}_{\mu}(t) \in \mathbb{R}^N$ are stored throughout training and used to compute the kernels $\Phi^{\ell}_{\mu\alpha}(t,s) = \frac{1}{N} \phi(\boldsymbol{h}_{\mu}^{\ell}(t)) \cdot \phi(\boldsymbol{h}_{\alpha}^{\ell}(s))$. These experiments can be reproduced with the provided Jupyter notebooks on our Github.
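For completeness, the kernel measurement described above amounts to an inner product of stored feature matrices; a one-function numpy sketch (names ours) is:

import numpy as np

def feature_kernel(H_t, H_s, phi=np.tanh):
    # Phi^l_{mu alpha}(t, s) = phi(h^l_mu(t)) . phi(h^l_alpha(s)) / N for stored
    # feature matrices H_t, H_s of shape (P, N) at checkpoints t and s.
    N = H_t.shape[1]
    return phi(H_t) @ phi(H_s).T / N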

C.2. CNN experiments on CIFAR-10

We define a depth-L CNN model with ReLU activations and stride 1, which is implemented as a pytree of parameters in JAX [89]. We apply global average pooling in the final layer before a dense readout layer. The code to initialize and evaluate the model is provided on our Github in the file titled scratch_cnn_expt.ipynb.

After constructing a CNN model, we train using MSE loss with base learning rate $\eta_0 = 2.0 \times 10^{-4}$ and batch size 250. The learning rate passed to the optimizer is thus $\eta = \eta_0 \gamma^2 = \eta_0 \gamma_0^2 N$. We optimize the loss function scaled appropriately as $\ell( \gamma_0^{-1} f, y)$. Throughout training, we compute the last layer's embedding $\phi(\boldsymbol{h}^L)$ on the test set to calculate the alignment $A(\boldsymbol{\Phi}^L, \boldsymbol{y}\boldsymbol{y}^{\top})$. Training is performed on 4 NVIDIA GPUs. Training an L = 3 network of width 500 takes roughly 1 h.

Appendix D: Derivation of self-consistent dynamical field theory

In this section, we introduce the dynamical field theory setup and the saddle point equations. The path integral theory we develop is based on the Martin–Siggia–Rose–De Dominicis–Janssen (MSRDJ) framework [47], of which a useful review in the context of random recurrent networks can be found in [54]. Similar computations can be found in recent works which consider typical behavior in high-dimensional classification on random data [63, 64].

D.1. Deep network field definitions and scaling

As discussed in the main text, we consider the following wide network architecture parameterized by trainable weights $\boldsymbol{\theta} = \text{Vec}\{\boldsymbol{W}^0 , \boldsymbol{W}^1 , \ldots \boldsymbol{w}^L\}$, giving network output fµ defined as

Equation (D.1): $\boldsymbol{h}^{1}_{\mu} = \frac{1}{\sqrt{D}}\,\boldsymbol{W}^{0}\boldsymbol{x}_{\mu} \ , \qquad \boldsymbol{h}^{\ell+1}_{\mu} = \frac{1}{\sqrt{N}}\,\boldsymbol{W}^{\ell}\phi(\boldsymbol{h}^{\ell}_{\mu}) \ , \qquad f_{\mu} = \frac{1}{\gamma\sqrt{N}}\,\boldsymbol{w}^{L}\cdot\phi(\boldsymbol{h}^{L}_{\mu})$

We train with gradient flow with learning rate η on the cost $\mathcal L = \sum_{\mu} \ell(f_{\mu},y_{\mu})$ for a loss function $\ell$. Introducing the error signals $\Delta_{\mu} = - \frac{\partial \mathcal L}{\partial f_{\mu}}$, gradient flow induces the following dynamics

Equation (D.2)

Since $K_\mathrm{NTK}$ is $O_{\gamma}(1)$ at initialization, it is clear that to have $O_{\gamma}(1)$ evolution of the network output at initialization we need $\eta = \gamma^2$. With this scaling, we have the following

Equation (D.3)

Now, to build a valid field theory, we want to express everything in terms of the features $\boldsymbol{h}_{\mu}^{\ell}$ rather than the parameters θ. We therefore define the gradient features $\boldsymbol{g}^{\ell}_{\mu} = \sqrt{N} \frac{\partial h^{L+1}_{\mu} }{\partial \boldsymbol{h}^{\ell}_{\mu}}$, which admit the following recursion and base case

Equation (D.4)

We define the pre-gradient field $\boldsymbol{z}^{\ell}_{\mu} = \frac{1}{\sqrt N} \boldsymbol{W}^{\ell \top} \boldsymbol{g}^{\ell+1}_{\mu}$ so that $\boldsymbol{g}^{\ell}_{\mu} = \dot\phi(\boldsymbol{h}^{\ell}_{\mu}) \odot \boldsymbol{z}^{\ell}_{\mu}(t)$. From these quantities, we can derive the gradients with respect to parameters

Equation (D.5)

which allows us to compute the NTK in terms of these features

Equation (D.6): $K^{\mathrm{NTK}}_{\mu\alpha} = \nabla_{\boldsymbol{\theta}} f_{\mu}\cdot\nabla_{\boldsymbol{\theta}} f_{\alpha} = \sum_{\ell = 0}^{L} G^{\ell+1}_{\mu\alpha}\,\Phi^{\ell}_{\mu\alpha}$

where $K^x_{\mu\alpha} = \frac{1}{D} \boldsymbol{x}_{\mu} \cdot \boldsymbol{x}_{\alpha}$ is the input Gram matrix. We see that the NTK can be built out of the following primitive kernels

Equation (D.7): $\Phi^{\ell}_{\mu\alpha}(t,s) = \frac{1}{N}\,\phi(\boldsymbol{h}^{\ell}_{\mu}(t))\cdot\phi(\boldsymbol{h}^{\ell}_{\alpha}(s)) \ , \qquad G^{\ell}_{\mu\alpha}(t,s) = \frac{1}{N}\,\boldsymbol{g}^{\ell}_{\mu}(t)\cdot\boldsymbol{g}^{\ell}_{\alpha}(s)$

We utilize the parameter space dynamics to express $\boldsymbol{W}^{\ell}$ in terms of the $\{\boldsymbol{g},\boldsymbol{h}\}$ fields

Equation (D.8)

Using the field recurrences $\boldsymbol{h}^{\ell+1}_{\mu}(t) = \frac{1}{\sqrt N} \boldsymbol{W}^{\ell}(t) \phi(\boldsymbol{h}^{\ell}_{\mu}(t))$ we can derive the following recursive dynamics for the features

Equation (D.9)

In the above, we implicitly utilize the base cases for the feature kernels $\Phi^0_{\mu\nu}(t,s) = K^x_{\mu\nu}$ and $G^{L+1}_{\mu\nu}(t,s) = 1$. We also introduced the following random fields $\boldsymbol{\chi}_{\mu}^{\ell}(t), \boldsymbol{\xi}_{\mu}^{\ell}(t)$ which involve the random initial conditions

Equation (D.10): $\boldsymbol{\chi}^{\ell+1}_{\mu}(t) = \frac{1}{\sqrt{N}}\,\boldsymbol{W}^{\ell}(0)\,\phi(\boldsymbol{h}^{\ell}_{\mu}(t)) \ , \qquad \boldsymbol{\xi}^{\ell}_{\mu}(t) = \frac{1}{\sqrt{N}}\,\boldsymbol{W}^{\ell}(0)^{\top}\boldsymbol{g}^{\ell+1}_{\mu}(t)$

We observe that the dynamics of the hidden features are controlled by the factor $\frac{\gamma}{\sqrt N}$. If $\gamma = O_N(1)$, then we recover the static NTK limit as $N\to\infty$. However, if $\gamma = O_N(\sqrt N)$, then we obtain $O_N(1)$ evolution of our features and we reach a new rich regime. We choose the scaling $\gamma = \gamma_0 \sqrt N$ for our field theory so that $\gamma_0 \gt 0$ gives a feature learning network.

D.2. Warmup: DMFT for one hidden layer NN

In this section, we provide a warmup problem of an L = 1 hidden layer network which allows us to illustrate the mechanics of the MSRDJ formalism. A more detailed computation can be found in the next section. Though many of the interesting dynamical aspects of the deep network case are missing in the two layer case, our aim is to show a simple application of the ideas. The fields of interest are $\boldsymbol{\chi}_{\mu} = \frac{1}{\sqrt D} \boldsymbol{W}^0(0) \boldsymbol{x}_{\mu}$ and $\boldsymbol{\xi} = \boldsymbol{w}^1(0)$. Unlike the deeper $L \geqslant 2$ case, both of these fields are time invariant since $\boldsymbol{x}_{\mu}$ does not vary in time. These random fields provide initial conditions for the preactivation and pre-gradient fields $\boldsymbol{h}_{\mu}(t), \boldsymbol{z}(t) \in \mathbb{R}^N$, which evolve according to

Equation (D.11)

where the network predictions evolve as $\frac{\partial}{\partial t} f_{\mu}(t) = \sum_{\alpha} [\Phi_{\mu\alpha}(t,t) + G_{\mu\alpha}(t,t) K^x_{\mu\alpha} ] \Delta_{\alpha}(t)$ for kernels $\Phi_{\mu\alpha}(t,t) = \frac{1}{N} \phi(\boldsymbol{h}_{\mu}(t)) \cdot \phi(\boldsymbol{h}_{\alpha}(t))$ and $G_{\mu\alpha}(t,t) = \frac{1}{N} \boldsymbol{g}_{\mu}(t) \cdot \boldsymbol{g}_{\alpha}(t)$. At finite N, the kernels $\Phi, G$ will depend on the random initial conditions $\boldsymbol{\chi}, \boldsymbol{\xi}$, leading to a predictor fµ which varies over initializations. If we can establish that the kernels $\Phi,G$ concentrate at infinite-width $N \to \infty$, then $\Delta_{\mu}$ are deterministic. We now study the moment generating function for the fields

Equation (D.12)

To perform the average over $\boldsymbol{\theta}_0 = \{\boldsymbol{W}^0(0),\boldsymbol{w}^1(0)\}$, we enforce the definition of $\boldsymbol{\chi}_{\mu},\boldsymbol{\xi}$ with delta functions

Equation (D.13)

Though this step may seem redundant in this example, it will be very helpful in the deep network case, so we pursue it for illustration. After multiplying by these factors of unity and performing the Gaussian integrals, we obtain

Equation (D.14)

We now aim to enforce the definitions of the kernel order parameters with delta functions

Equation (D.15)

where the fields $\boldsymbol{h}_{\mu}(t), \boldsymbol{g}_{\mu}(t)$ are regarded as functions of $\{\boldsymbol{\chi}_{\mu}\}_{\mu},\boldsymbol{\xi}$ (see equation (D.11)) and the $\hat{\Phi}, \hat{G}$ integrals run over the imaginary axis $(-i \infty, i \infty)$. After this step, we can write

Equation (D.16)

where the DMFT action $S[\Phi,\hat{\Phi},G,\hat{G}]$ is $\mathcal{O}_N(1)$ and has the form

Equation (D.17)

The single site moment generating function $\mathcal Z[j,v]$ arises from the factorization of the integrals over N different fields in the hidden layer and takes the form

Equation (D.18)

where, again, we must regard $h_{\mu}(t), g_{\mu}(t)$ as functions of $\chi,\xi$. The variables in the above are no longer vectors in $\mathbb{R}^N$ but rather are scalars. We can write $\mathcal Z[j,v] = \int \prod_{\mu} \mathrm{d}\chi_{\mu} \mathrm{d}\hat{\chi}_{\mu} \mathrm{d}\xi \mathrm{d}\hat{\xi} \exp\left( - \mathcal{H}[\{\chi_{\mu},\hat\chi_{\mu}\},\xi,\hat\xi, j, v] \right)$ where $\mathcal H$ is the negative logarithm of the integrand above. Since the full MGF takes the form $Z \propto \int \mathrm{d}\Phi \mathrm{d}\hat\Phi \mathrm{d}G\mathrm{d}\hat{G} \exp\left( N S[\Phi,\hat\Phi,G,\hat G] \right)$, characterization of the $N \to \infty$ limit requires one to identify the saddle point of S, where $\delta S = 0$ for any variation of these four order parameters. Setting these variations to zero yields the saddle point equations

Equation (D.19)

where the ith single site average $\left\langle \right\rangle _i$ of an observable $O(\chi,\hat\chi,\xi,\hat\xi)$ is defined as

Equation (D.20)

Since $\hat{\Phi} = \hat{G} = 0$ the single site MGF reveals that the initial fields are independent Gaussians $\{\chi_{\mu}\} \sim \mathcal{N}(0,\boldsymbol{K}^x)$ and $\xi \sim\mathcal{N}(0,1)$. At zero source $\boldsymbol{j}, \boldsymbol{v} \to 0$, all single site averages $\left\langle \right\rangle _i$ are equivalent and we may merely write $\Phi_{\mu\alpha}(t,s) = \left\langle \phi(h_{\mu}(t))\phi(h_{\alpha}(s)) \right\rangle \ , \ G_{\mu\alpha}(t,s) = \left\langle g_{\mu}(t) g_{\alpha}(s) \right\rangle$, where $\left\langle \right\rangle$ is the average over the single site distributions for $\boldsymbol{j},\boldsymbol{v} \to 0$.

D.2.1. Final L = 1 DMFT equations.

Putting all of the saddle point equations together, we arrive at the following DMFT

Equation (D.21)

We see that for L = 1 networks, it suffices to solve for the kernels on the time-time diagonal. Further, in this two-layer case, $\chi, \xi$ are independent and do not vary in time. These facts do not hold in general for $L \geqslant 2$ networks, which require a more intricate analysis, as we show in the next section.

D.3. Path integral formulation for deep networks

As discussed in the main text, we study the distribution over fields by computing the moment generating functional for the stochastic processes $\{\boldsymbol{\chi}^{\ell}, \boldsymbol{\xi}^{\ell} \}_{\ell = 1}^L$

Equation (D.22)

Moments of these stochastic fields can be computed through differentiation of Z near zero-source

Equation (D.23)

To perform the average over the initial parameters, we enforce the definitions of the fields $\boldsymbol{\chi}^{\ell+1}_{\mu}(t) = \frac{1}{\sqrt N} \boldsymbol{W}^{\ell}(0) \phi(\boldsymbol{h}^{\ell}_{\mu}(t))$ and $\boldsymbol{\xi}^{\ell}_{\mu}(t) = \frac{1}{\sqrt N} \boldsymbol{W}^{\ell}(0)^{\top} \boldsymbol{g}^{\ell+1}_{\mu}(t)$ by inserting the following terms into the definition of $Z[\{\boldsymbol{j},\boldsymbol{v}\}]$, which allows us to perform the average over the weights $\boldsymbol{\theta}_0$ more easily. We enforce these definitions with an integral representation of the Dirac delta function $1 = \int_{\mathbb{R}} \mathrm{d}x \ \delta(x) = \frac{1}{2\pi} \int_{\mathbb{R}} \mathrm{d}x \int_{\mathbb{R}} \mathrm{d}\hat{x} \exp\left( i x \hat{x} \right)$. We note that we are implicitly working in the Ito scheme, where the Jacobian determinant factors are equal to one [54, 90, 91] (note that $\boldsymbol{h}^{\ell}_{\mu}(t)$ does not causally depend on $\boldsymbol{\chi}^{\ell+1}_{\mu}(t)$ and $\boldsymbol{g}^{\ell}_{\mu}(t)$ does not causally depend on $\boldsymbol{\xi}^{\ell}_{\mu}(t)$). Applying this to the fields $\boldsymbol{\chi},\boldsymbol{\xi}$, we have

Equation (D.24)

where $\{h^{\ell}, g^{\ell} \}$ are understood to be stochastic processes which are causally determined by the $\{\chi^{\ell},\xi^{\ell}\}$ fields, in the sense that $h^{\ell}(t)$ only depends on $\chi^{\ell}(s)$ for s < t. We thus have an expression of the form

Equation (D.25)

Since $\boldsymbol{W}^{\ell}(0)$ are all Gaussian random variables, these averages can be performed quite easily, yielding

Equation (D.26)

D.4. Order parameters and action definition

We define the following order parameters which we will show concentrate in the $N \to \infty$ limit

Equation (D.27)

The NTK depends only on $\{\Phi^{\ell},G^{\ell}\}$, so from these order parameters we can compute the function evolution. The parameter $\boldsymbol{A}^{\ell}$ arises from the coupling of the fields across a single layer's initial weight matrix $\boldsymbol{W}^{\ell}(0)$. We can again enforce these definitions with integral representations of the Dirac delta function. For each pair of samples $\mu,\alpha$ and each pair of times $t,s$, we multiply by

Equation (D.28)

for all $\ell \in \{1,\ldots ,L\}$ and analogously

Equation (D.29)

for $\ell \in \{1,\ldots ,L-1\}$. After introducing these order parameters into the definition of the partition function, we have a factorization of the integrals over each of the N sites in each hidden layer. This gives the following partition function

Equation (D.30)

Equation (D.31)

We thus see that the action S consists of inner-products between order parameters $\{\Phi,G,A\}$ and their duals $\{\hat\Phi,\hat G, B\}$ as well as a single site MGF $\mathcal{Z}[\{\Phi, \hat\Phi, G,\hat G, A, B , j, v\}]$, which is defined as

Equation (D.32)

D.5. Saddle point equations

Since the integrand in the moment generating function Z takes the form $e^{N S[\{\Phi,\hat\Phi,G,\hat G, A, B\}]}$, the $N\to \infty$ limit can be obtained from saddle point integration, also known as the method of steepest descent [92]. This consists of finding order parameters $\{\Phi,\hat\Phi,G,\hat G, A, B\}$ which render the action S locally stationary. Concretely, this leads to the following saddle point equations.

Equation (D.33)

We use the notation $\langle \rangle$ to denote an average over the self-consistent distribution on fields induced by the single-site moment generating function $\mathcal{Z}$ at the saddle point. Concretely if $\mathcal Z = \int \mathrm{d}\chi \mathrm{d}\xi \mathrm{d}\hat\chi \mathrm{d}\hat\xi \exp ( - \mathcal H[\chi, \xi ,\hat\chi, \hat\xi] )$ then the single-site self-consistent average of observable $O([\chi, \xi ,\hat\chi, \hat\xi])$ is defined as

Equation (D.34)

To calculate the averages of the dual variables such as $\left\langle \hat\chi^{\ell+1} \hat\chi^{\ell+1} \right\rangle$, it will be convenient to work with vector and matrix notation. We let $\boldsymbol{\chi}^{\ell} = \text{Vec}\{\chi_{\mu}^{\ell}(t) \}_{\mu\in[P], t\in \mathbb{R}_+}$ represent the vectorization of the stochastic process over different samples and times and define the dot product between two of these vectors as $\boldsymbol{a} \cdot \boldsymbol{b} = \sum_{\mu = 1}^P \int_0^{\infty} \mathrm{d}t \ a_{\mu}(t) b_{\mu}(t)$. We also apply this procedure on the kernels so that $\boldsymbol{\Phi} = \text{Mat}\{\Phi_{\mu\alpha}(t,s)\}_{\mu\alpha \in [P], t,s \in \mathbb{R}_+}$. Matrix vector products take the form $[\boldsymbol{A} \boldsymbol{b}]_{\mu,t} = \int_0^{\infty} \mathrm{d}s \sum_{\alpha} A_{\mu\alpha}(t,s) b_{\alpha}(s)$. We can obtain the behavior of $\langle \boldsymbol{\hat \chi}^{\ell+1}_{\mu} \boldsymbol{\hat \chi}^{\ell+1 \top}_{\mu} \rangle$ in terms of primal fields $\{\chi,\xi, h, z\}$ by insertion of a dummy source u into the effective partition function.

Equation (D.35)

Similarly, we can obtain the equation for $\left\langle \boldsymbol{\hat\xi}^{\ell} \boldsymbol{\hat\xi}^{\ell \top} \right\rangle$ by inserting a dummy source r and differentiating near zero source

Equation (D.36)

As we will demonstrate in the next subsection, these correlators must vanish. Lastly, we can calculate the remaining correlators in terms of primal variables

Equation (D.37)
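As a purely illustrative aside, the continuous-time vectorization conventions used above discretize naturally on a uniform time grid. The sketch below implements the stated dot product and matrix–vector product conventions; the grid size, time step, and array values are arbitrary toy choices.

```python
import numpy as np

P, T = 3, 50                       # number of samples and time-grid points
dt = 0.1                           # time step of the (illustrative) grid
rng = np.random.default_rng(1)

a = rng.standard_normal((P, T))    # a_mu(t) on the grid
b = rng.standard_normal((P, T))    # b_mu(t) on the grid
M = rng.standard_normal((P, P, T, T))   # a generic kernel M_{mu alpha}(t, s)

# a . b = sum_mu int dt a_mu(t) b_mu(t)
dot = np.sum(a * b) * dt

# [M b]_{mu, t} = int ds sum_alpha M_{mu alpha}(t, s) b_alpha(s)
Mb = np.einsum('mats,as->mt', M, b) * dt

print(dot, Mb.shape)   # scalar, (P, T)
```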

D.6. Single site stochastic process: Hubbard trick

To get a better sense of this distribution, we can now simplify the quadratic forms appearing in $\mathcal Z$ using the Hubbard trick [93], which merely relates a Gaussian function to its Fourier transform.

Equation (D.38)

Applying this to the quadratic forms in the single-site MGF $\mathcal{Z}$, we get

Equation (D.39)

Next, we integrate over all $\hat\chi^{\ell}, \hat\xi^{\ell}$ variables, which yields Dirac delta functions

Equation (D.40)

To remedy the notational asymmetry, we redefine $\boldsymbol{B}^{\ell}$ as its transpose $\boldsymbol{B}^{\ell} \to \boldsymbol{B}^{\ell \top}$. The presence of these delta functions in the MGF $\mathcal Z$ indicates the constraints $\boldsymbol{u}^{\ell} = \boldsymbol{\chi}^{\ell} - \boldsymbol{A}^{\ell-1} \boldsymbol{g}^{\ell}$ and $\boldsymbol{r}^{\ell} = \boldsymbol{\xi}^{\ell} - \boldsymbol{B}^{\ell} \phi(\boldsymbol{h}^{\ell})$. We can thus return to the $\hat\Phi$ and $\hat G$ saddle point equations and verify that these order parameters vanish

Equation (D.41)

since $\left\langle \boldsymbol{u}^{\ell+1} \boldsymbol{u}^{\ell+1 \top} \right\rangle = \boldsymbol{\Phi}^{\ell}$. Following an identical argument, $\hat{\boldsymbol{G}}^{\ell} = 0$. After this simplification, the single site MGF takes the form

Equation (D.42)

The interpretation is thus that $\boldsymbol{u}^{\ell}, \boldsymbol{r}^{\ell}$ are sampled independently from their respective Gaussian processes and the fields $\boldsymbol{\chi}^{\ell}$ and $\boldsymbol{\xi}^{\ell}$ are determined in terms of $\boldsymbol{u}^{\ell}, \boldsymbol{r}^{\ell}, \boldsymbol{h}^{\ell}, \boldsymbol{g}^{\ell}$. This means that we can apply Stein's Lemma (integration by parts) [94] to simplify the last two saddle point equations

Equation (D.43)

D.7. Final DMFT equations

We can now close this stochastic process in terms of preactivations $h^{\ell}$ and pre-gradients $z^{\ell}$. To match the formulas provided in the main text, we rescale $A^{\ell} \to A^{\ell} /\gamma_0 = \mathcal{O}_{\gamma_0}(1)$ and $B^{\ell} \to B^{\ell} /\gamma_0 = \mathcal{O}_{\gamma_0}(1)$, which makes it clear that the non-Gaussian corrections to the $h_{\mu}^{\ell}(t), z_{\mu}^{\ell}(t)$ fields are $\mathcal{O}(\gamma_0)$. After this rescaling, we have the following complete DMFT equations.


The base cases in the above equations are that $A^{0} = B^L = 0$ and $\Phi^0_{\mu\alpha}(t,s) = K^x_{\mu\alpha}$ and $G^{L+1}_{\mu\alpha}(t,s) = 1$. From the above self-consistent equations, one obtains the NTK dynamics and consequently the output predictions of the network with $\frac{\partial f_{\mu}}{\partial t} = \sum_{\alpha} \Delta_{\alpha}(t) \left[ \sum_{\ell} G^{\ell+1}_{\mu\alpha}(t,t) \Phi^{\ell}_{\mu\alpha}(t,t) \right]$.
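As a minimal sketch of how the prediction dynamics above are driven by the equal-time kernels, the following Euler integration assumes squared loss (so that $\Delta_{\mu}(t) = y_{\mu} - f_{\mu}(t)$), takes the layer sum to run over $\ell = 0,\ldots,L$ consistent with the stated base cases, and holds the kernels fixed at arbitrary placeholder values rather than solving for them self-consistently. It is an illustration of the ODE structure only, not of the full DMFT solver.

```python
import numpy as np

rng = np.random.default_rng(2)
P, L_depth = 4, 3
dt = 1e-2

def rand_psd(P, rng):
    M = rng.standard_normal((P, P))
    return M @ M.T / P

y = rng.standard_normal(P)                 # targets
f = np.zeros(P)                            # predictions

# equal-time kernels Phi^l(t,t) and G^{l+1}(t,t); random PSD placeholders,
# with the base cases Phi^0 = K^x and G^{L+1} = 1 (all ones)
Kx = rand_psd(P, rng)
Phi = [Kx] + [rand_psd(P, rng) for _ in range(L_depth)]                        # Phi^0 ... Phi^L
G = [None] + [rand_psd(P, rng) for _ in range(L_depth)] + [np.ones((P, P))]    # G^1 ... G^{L+1}

for step in range(1000):
    Delta = y - f                                                # squared-loss error signal
    K_ntk = sum(G[l + 1] * Phi[l] for l in range(L_depth + 1))   # elementwise products, summed over layers
    f = f + dt * K_ntk @ Delta                                   # d f_mu / dt = sum_alpha K_{mu alpha} Delta_alpha

print("final training error:", np.linalg.norm(y - f))
```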

D.8. Varying network widths and initialization scales

In this section, we relax the assumption of network widths being equal while taking all widths to infinity at a fixed ratio. This will allow us to analyze the influence of bottlenecks on the dynamics. We let $N^{\ell} = a_{\ell} N$ represent the width of layer $\ell$. Without loss of generality, we can choose that $N^L = N$ and proceed by defining order parameters in the usual way

Equation (D.45)

Since $N^L = N$, the variable $\boldsymbol{g}^L = \sqrt{N^L} \frac{\partial h^{L+1}}{\partial \boldsymbol{h}^L} = \boldsymbol{w}^L \odot \dot\phi(\boldsymbol{h}^L) = \mathcal{O}_{N,\gamma}(1)$ as desired. We extend this definition to each layer as before $\boldsymbol{g}^{\ell} = \sqrt{N^{\ell}} \frac{\partial h^{L+1}}{\partial \boldsymbol{h}^{\ell}}$ which again satisfies the recursion

Equation (D.46)

Now, we need to calculate the dynamics on weights $\boldsymbol{W}^{\ell}$

Equation (D.47)

Using our definition of the kernels and the $\boldsymbol{h},\boldsymbol{z}$ fields

Equation (D.48)

We also find the usual formula for the NTK

Equation (D.49)

Now, as before, we need to consider the distribution of $\boldsymbol{\chi},\boldsymbol{\xi}$ fields. We assume $W^{\ell}_{ij}(0) \sim \mathcal{N}(0,\sigma^2_{\ell})$. This requires computing integrals like

Equation (D.50)

where $\boldsymbol{A}^{\ell}_{\mu\alpha}(t,s) = -\frac{i}{N^{\ell}} \phi(\boldsymbol{h}^{\ell}_{\mu}(t)) \cdot \boldsymbol{\hat\xi}^{\ell}_{\alpha}(s)$. The action thus takes the form

Equation (D.51)

where the zero-source MGF for layer $\ell$ has the form

Equation (D.52)

The saddle point equations give

Equation (D.53)

where $\boldsymbol{u}^{\ell} \sim \mathcal{GP}(0,\sigma^2_{\ell-1} \boldsymbol{\Phi}^{\ell-1})$ and $\boldsymbol{r}^{\ell} \sim \mathcal{GP}(0,\sigma^2_{\ell} \boldsymbol{G}^{\ell+1})$. We redefine $\boldsymbol{B}^{\ell} \to \frac{1}{\sigma^2_{\ell}} \sqrt{\frac{a_{\ell}}{a_{\ell+1}}}\boldsymbol{B}^{\ell}$. To take the $N\to\infty$ limit of the field dynamics, we again use $\gamma_{0} = \gamma / \sqrt{N} = O_{N}(1)$. The field equations take the form

Equation (D.54)

We thus find that the evolution of the scalar fields in a given layer is set by the parameter $\gamma_{0}/\sqrt{a_{\ell}}$, indicating that relatively wider layers evolve less and contribute less of a change to the overall NTK. This definition of $\boldsymbol{A}^{\ell},\boldsymbol{B}^{\ell}$ is not ideal for extracting intuition about bottlenecks, since $\boldsymbol{A}^{\ell-1} \sim \mathcal{O}\left( \frac{\gamma_0}{\sqrt{a_{\ell-1}}} \right)$ and $\boldsymbol{B}^{\ell} \sim \mathcal{O}\left( \frac{\gamma_0}{\sqrt{a_{\ell+1}}} \right)$. To remedy this, we redefine $\tilde{\boldsymbol{A}}^{\ell} = \frac{\sqrt{a_{\ell}}}{\gamma_0} \boldsymbol{A}^{\ell}, \tilde{\boldsymbol{B}}^{\ell} = \frac{\sqrt{a_{\ell+1}}}{\gamma_0} \boldsymbol{B}^{\ell}$. With this choice, we have

Equation (D.55)

where $\tilde{A}^{\ell-1}, \tilde{B}^{\ell}$ do not have a leading order scaling with $a_{\ell-1}$ or $a_{\ell+1}$ respectively. Under this change of variables, it is now apparent that for a very wide layer $\ell$, where $\frac{\gamma_0}{\sqrt{a_{\ell}}} \ll 1$, the fields $h^{\ell}, z^{\ell}$ are well approximated by the Gaussian processes $u^{\ell}, r^{\ell}$, albeit with evolving covariances $\boldsymbol{\Phi}^{\ell-1}, \boldsymbol{G}^{\ell+1}$ respectively. In a realistic CNN architecture where the number of channels increases across layers, this result predicts that more feature learning and larger deviations from Gaussianity occur in the early layers, while the later layers ($\ell \sim L$) are well approximated as Gaussian fields $u^{\ell}, r^{\ell}$ with temporally evolving covariances. We leave evaluation of this prediction to future work.

Appendix E: Two-layer networks

In a two-layer network, there are no A or B order parameters, so the fields $\chi^1$ and $\xi^1$ are always independent. Further, $\chi^1$ and $\xi^1$ are both constant throughout the training dynamics. Thus we can obtain differential rather than integral equations for the stochastic fields $h^1, z^1$, which are

Equation (E.1)

where the average is taken over the random initial conditions $\boldsymbol{h}^1(0) \sim \mathcal{N}(0,\boldsymbol{K}^x)$ and $\boldsymbol{z}^1(0) \sim \mathcal{N}(0,\boldsymbol{1} \boldsymbol{1}^{\top})$. An example of the two-layer theory for a ReLU network can be found in appendix figure A1. In this two-layer setting, a drift PDE can be obtained for the joint density of preactivations and feedback fields $p(\boldsymbol{h},z;t)$

Equation (E.2)

which is a zero-diffusion feature space version of the PDE derived in the original two-layer mean field limit of neural networks [21, 42, 43].

Appendix F: Deep linear networks

In the deep linear case, the $g^{\ell}_{\mu}(t)$ fields are independent of sample index µ. We introduce the kernel $H^{\ell}_{\mu\alpha}(t,s) = \left\langle h^{\ell}_{\mu}(t) h^{\ell}_{\alpha}(s) \right\rangle$. The field equations are

Equation (F.1)

Or in vector notation $\boldsymbol{h}^{\ell} = \boldsymbol{u}^{\ell} + \gamma_0 \boldsymbol{C}^{\ell}\boldsymbol{g}^{\ell}$ and $\boldsymbol{g}^{\ell} = \boldsymbol{r}^{\ell} + \gamma_0 \boldsymbol{D}^{\ell} \boldsymbol{h}^{\ell}$ where

Equation (F.2)

Using the formulas which define the fields, we have

Equation (F.3)

The saddle point equations can thus be written as

Equation (F.4)

We solve these equations by repeatedly updating $\boldsymbol{H}^{\ell}, \boldsymbol{G}^{\ell}$, using equation (F.4) and the current estimate of $\boldsymbol{C}^{\ell}, \boldsymbol{D}^{\ell}$. We then use the new $\boldsymbol{H}^{\ell}, \boldsymbol{G}^{\ell}$ to recompute $\boldsymbol{K}^{\mathrm{NTK}}$ and $\boldsymbol{\Delta}(t)$, calculating $\boldsymbol{C}^{\ell}, \boldsymbol{D}^{\ell}$ and then recomputing $\boldsymbol{H}^{\ell}, \boldsymbol{G}^{\ell}$. This procedure usually converges in approximately five to ten steps.

F.1. Two-layer linear network

As we saw in appendix E, the field dynamics simplify considerably in the two-layer case, allowing the description of all fields in terms of differential equations. In a two-layer linear network, we let $\boldsymbol{h}(t) \in \mathbb{R}^P$ represent the hidden activation field and $g(t) \in \mathbb{R}$ represent the gradient

Equation (F.5)

The kernels $\boldsymbol{H}(t) = \left\langle \boldsymbol{h}(t) \boldsymbol{h}(t)^{\top} \right\rangle$ and $G(t) = \left\langle g(t)^2 \right\rangle$ thus evolve as

Equation (F.6)

It is easy to verify that the network predictions on the P training points are $\boldsymbol{f}(t) = \boldsymbol{y} - \boldsymbol{\Delta}(t) = \frac{1}{\gamma_0} \left\langle g(t) \boldsymbol{h}(t) \right\rangle \in \mathbb{R}^P$. Thus the dynamics of $\boldsymbol{H}(t), G(t)$ and $\boldsymbol{\Delta}(t)$ close

Equation (F.7)

where the initial conditions are $\boldsymbol{H}(0) = \boldsymbol{I}$, $G(0) = 1$ and $\boldsymbol{\Delta}(0) = \boldsymbol{y}$. These equations hold for any choice of data $\boldsymbol{K}^x, \boldsymbol{y}$.

F.1.1. Whitened data in two-layer linear network.

For whitened input data, where $\boldsymbol{K}^x = \mathbf{I}$, the dynamics can be simplified even further, recovering sigmoidal curves very similar to those obtained under a special initialization [69, 70, 72, 74]. In this case we note that the error signal always evolves in the y direction, $\boldsymbol{\Delta}(t) = \Delta(t) \frac{\boldsymbol{y}}{|\boldsymbol{y}|}$, and that H only evolves in the rank-one $\boldsymbol{y} \boldsymbol{y}^{\top}$ direction as well. Let $\frac{1}{|\boldsymbol{y}|^2} \boldsymbol{y}^{\top} \boldsymbol{H}(t) \boldsymbol{y} = H_y(t)$ and let $y = |\boldsymbol{y}|$ represent the norm of the target vector; the relevant scalar dynamics are then

Equation (F.8)

Now note that, at initialization $H_y(0) = G(0) = 1$ and that $ \frac{\partial }{\partial t} H_y(t) = \frac{\partial}{\partial t} G(t)$. Thus, we have an automatic balancing condition $H_y(t) = G(t)$ for all $t \in \mathbb{R}_+$ and the dynamics reduce to two variables

Equation (F.9)

We note that this system obeys a conservation law which constrains $(H_y, y-\Delta)$ to a hyperbola

Equation (F.10)

This conservation law implies that $H_y(0)^2 = 1 = \lim_{t\to\infty} H_y(t)^2 - \gamma_0^2 y^2$ or that the final kernel has the form $\lim_{t\to\infty} \boldsymbol{H}(t) = \frac{1}{y^2}[\sqrt{1 + \gamma_0^2 y^2 } -1 ] \boldsymbol{y} \boldsymbol{y}^{\top} + \mathbf{I}$. The result that the final kernel becomes a rank one spike in the direction of the target function was also obtained in finite width networks in the limit of small initialization [74] and also from a normative toy model of feature learning [83]. We can use the conservation law above $1 = H_y(t)^2 - \gamma_0^2 (\Delta(t) - y)^2$ to simplify the dynamics to a one-dimensional system

Equation (F.11)

where $f = y-\Delta$. We see that increasing γ0 provides strict acceleration in the learning dynamics, illustrating the training benefits of feature evolution. Since this system is separable, we can solve for the time it takes for the network output norm to reach output level f

Equation (F.12)

The NTK limit can be obtained by taking $\gamma_0 \to 0$ which gives

Equation (F.13)

which recovers the usual convergence rate of a linear model. The right hand side of equation (F.12) has a perturbation series in $\gamma_0^2$ which converges in the disk $\gamma_0 \lt \frac{1}{y}$. The other limit of interest is the $\gamma_0 \to \infty$ limit where

Equation (F.14)

which recovers the logistic growth observed in the initialization scheme of prior works [69, 70]. The timescale τ required to learn is only $\tau \sim \frac{1}{\gamma_0} \ll 1$, which is much smaller than the $O_{\gamma_0}(1)$ time to learn predicted by the small-$\gamma_0$ expansion. We note that the above leading order asymptotic behavior at large γ0 treats the DMFT initial condition $\Delta(0) = y$ as an unstable fixed point. For realistic learning curves, one would need to stipulate an alternative initial condition such as $\Delta(0) = y-\epsilon$ for some small $\epsilon > 0$ in order to have nontrivial leading order dynamics.

F.2. Deep linear whitened data

In this section, we examine the role of depth when linear networks are trained on whitened data. As in the two-layer case, all hidden kernels $\boldsymbol{H}^{\ell}(t,s)$ need only be tracked in the one-dimensional task-relevant subspace along the vector $\boldsymbol{y}$. We let $\Delta(t) = \frac{1}{y}\boldsymbol{y} \cdot \boldsymbol{\Delta}(t)$ and let $h^{\ell}_y(t) = \frac{1}{y} \boldsymbol{h}^{\ell}(t) \cdot \boldsymbol{y}$. We have

Equation (F.15)

Lastly, we have the simple evolution equation for the scalar error $\Delta(t)$

Equation (F.16)

Vectorizing, we find the following equations for the time × time matrix order parameters: $\boldsymbol{h}^{\ell} = \boldsymbol{u}^{\ell} + \gamma_0 \boldsymbol{C}^{\ell} \boldsymbol{g}^{\ell} \ , \ \boldsymbol{g}^{\ell} = \boldsymbol{r}^{\ell} + \gamma_0 \boldsymbol{D}^{\ell} \boldsymbol{h}^{\ell}$, from which we can solve for the response functions $\boldsymbol{A}^{\ell} = \left( \mathbf{I} - \gamma_0^2 \boldsymbol{C}^{\ell} \boldsymbol{D}^{\ell} \right)^{-1} \boldsymbol{C}^{\ell}$ and $\boldsymbol{B}^{\ell} = \left( \mathbf{I} - \gamma_0^2 \boldsymbol{D}^{\ell} \boldsymbol{C}^{\ell} \right)^{-1} \boldsymbol{D}^{\ell}$. This formulation has the advantage that it no longer has any sample-size dependence: arbitrary sample sizes can be handled at no additional computational cost.
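A minimal numpy sketch of the response-function formulas just stated, on a discretized time grid with arbitrary toy causal (lower-triangular) kernels standing in for $\boldsymbol{C}^{\ell}, \boldsymbol{D}^{\ell}$:

```python
import numpy as np

rng = np.random.default_rng(3)
T = 40                      # number of time-grid points (illustrative)
gamma0 = 0.5

# toy causal (lower-triangular) kernels C^l, D^l on the grid
C = np.tril(rng.standard_normal((T, T))) * 0.1
D = np.tril(rng.standard_normal((T, T))) * 0.1

I = np.eye(T)
# A^l = (I - gamma0^2 C D)^{-1} C ,   B^l = (I - gamma0^2 D C)^{-1} D
A = np.linalg.solve(I - gamma0**2 * C @ D, C)
B = np.linalg.solve(I - gamma0**2 * D @ C, D)

print(A.shape, B.shape)
```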

Appendix G: Convolutional networks with infinite channels

The DMFT described in this work can be extended to CNNs with infinitely many channels, much in the same way that infinite CNNs have a well defined kernel limit [95, 96]. We let $W^{\ell}_{ij,\mathfrak{a}}$ represent the value of the filter at spatial displacement $\mathfrak{a}$ from the center of the filter, which relates activity at channel j of layer $\ell$ to channel i of layer $\ell+1$. The fields $h_{\mu,i,\mathfrak{a}}^{\ell}$ are defined recursively as

Equation (G.1)

where $\mathcal{S}^{\ell}$ is the spatial receptive field at layer $\ell$. For example, a $(2k+1) \times (2k+1)$ convolution will have $\mathcal{S}^{\ell} = \{(i,j) \in \mathbb{Z}^2 : -k\unicode{x2A7D} i \unicode{x2A7D} k, -k \unicode{x2A7D} j \unicode{x2A7D} k \}$. The output function, obtained from the last layer, is defined as $f_{\mu} = \frac{1}{\gamma_0 N} \sum_{i = 1}^N w_{i,\mathfrak{a}}^{L} \phi(h^L_{\mu,i,\mathfrak{a}})$. The gradient fields have the same definition as before, $\boldsymbol{g}^{\ell}_{\mu,\mathfrak{a}} = \gamma_0 N \frac{\partial f_{\mu}}{\partial \boldsymbol{h}^{\ell}_{\mu,\mathfrak{a}}}$, and as before they satisfy the following recursion from the chain rule

Equation (G.2)

The dynamics of each set of filters $\{\boldsymbol{W}^{\ell}_{\mathfrak{b}} \}$ can therefore be written in terms of the features $\boldsymbol{h}^{\ell}_{\mathfrak{a}}, \boldsymbol{g}^{\ell}_{\mathfrak{a}}$

Equation (G.3)

The feature space description of the forward and backward pass relations is

Equation (G.4)

where $\boldsymbol{\chi}^{\ell+1}_{\mu,\mathfrak{a}}(t) = \frac{1}{\sqrt N} \boldsymbol{W}^{\ell}(0) \phi(\boldsymbol{h}^{\ell}_{\mu \mathfrak{a}}(t))$. The order parameters for this network architecture are

Equation (G.5)

These two order parameters per layer collectively define the NTK. Following the computation in appendix D, we obtain the following field theory in the $N \to \infty$ limit:

Equation (G.6)

We see that this field theory essentially multiplies the number of sample indices by the number of spatial indices $P \to P |\mathcal S|$. Thus the time complexity of evaluation of this theory scales very poorly as $\mathcal{O}(P^3 |\mathcal S|^3 T^3)$, rendering DMFT solutions very computationally intensive.

Appendix H: Trainable bias parameter

If we include a bias $\boldsymbol{b}^{\ell}(t) \in \mathbb{R}^{N}$ in our trainable model, so that

Equation (H.1)

then the dynamics on $\boldsymbol{b}^{\ell}(t)$ induced by gradient flow is

Equation (H.2)

Assuming that $b_i^{\ell}(0) \sim \mathcal{N}(0,1)$, the dynamics of the DMFT becomes

Equation (H.3)

Appendix I: Multiple output channels

We now consider network outputs on $C = \mathcal{O}_N(1)$ classes. The prediction for a data point $\mu \in [P] $ at time $t \in \mathbb{R}_+$ is $\boldsymbol{f}_{\mu}(t) \in \mathbb{R}^C$. As before, we define the error signal as $\boldsymbol{\Delta}_{\mu} = - \frac{\partial }{\partial \boldsymbol{f}_{\mu} } \ell(\boldsymbol{f}_{\mu},\boldsymbol{y}_{\mu}) \in \mathbb{R}^C$. For any pair of data points $\mu,\alpha$ the NTK is a C×C matrix $\boldsymbol{K}^{\mathrm{NTK}}_{\mu\alpha} \in \mathbb{R}^{C \times C}$ with entries $K_{\mu\alpha,cc^{^{\prime}}}^{\mathrm{NTK}} = \frac{\partial f_c(\boldsymbol{x}_{\mu})}{\partial \boldsymbol{\theta}} \cdot \frac{\partial f_{c^{^{\prime}}}(\boldsymbol{x}_{\alpha})}{\partial \boldsymbol{\theta}}$. From these matrices, we can compute the evolution of the predictions in the network.

Equation (I.1)

In this case, we have matrices for the backprop features $\boldsymbol{g}^{\ell} = \gamma \sqrt{N} \frac{\partial \boldsymbol{f}^{\top}}{\partial \boldsymbol{h}^{\ell}} \in \mathbb{R}^{N \times C}$. These satisfy the usual recursion

Equation (I.2)

We can now compute the NTK for samples $\mu,\alpha$

Equation (I.3)

where $\boldsymbol{G}_{\mu\alpha}^{\ell} = \frac{1}{N} \boldsymbol{g}_{\mu}^{\ell \top} \boldsymbol{g}_{\alpha}^{\ell} \in \mathbb{R}^{C \times C}$ and $\Phi_{\mu\alpha}^{\ell} = \frac{1}{N} \phi(\boldsymbol{h}^{\ell}_{\mu}) \cdot \phi(\boldsymbol{h}^{\ell}_{\alpha}) \in \mathbb{R}$. Next we introduce kernels $\boldsymbol{A}^{\ell}_{\mu\alpha}(t,s) \in\mathbb{R}^C$ and $\boldsymbol{B}^{\ell}_{\mu\alpha}(t,s) \in \mathbb{R}^C$ which are defined in the usual way. The corresponding field theory has the form

Equation (I.4)

From these fields, the saddle point equations define the kernels as

Equation (I.5)

This allows us to study the multi-class structure of learned representations.
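As a sketch of how the $C\times C$ blocks enter the prediction dynamics, by direct analogy with the scalar-output case of appendix D (an assumption, since equations (I.1) and (I.3) are not reproduced here), the following assembles $\boldsymbol{K}^{\mathrm{NTK}}_{\mu\alpha} = \sum_{\ell} \Phi^{\ell}_{\mu\alpha} \boldsymbol{G}^{\ell+1}_{\mu\alpha}$ from scalar $\Phi^{\ell}$ and $C\times C$ kernels $\boldsymbol{G}^{\ell+1}$ and takes one Euler step on the predictions. The kernels are arbitrary placeholders rather than self-consistent DMFT solutions.

```python
import numpy as np

rng = np.random.default_rng(4)
P, C, L_depth = 3, 5, 2
dt = 1e-2

# toy equal-time kernels: scalar Phi^l_{mu alpha} and C x C blocks G^{l+1}_{mu alpha}
Phi = rng.standard_normal((L_depth + 1, P, P))
Phi = Phi @ Phi.transpose(0, 2, 1) / P                       # PSD in (mu, alpha) for each layer
G = rng.standard_normal((L_depth + 1, P, P, C, C))
G = G + G.transpose(0, 2, 1, 4, 3)                           # symmetrize under (mu,c) <-> (alpha,c')

# C x C NTK blocks: K_{mu alpha} = sum_l Phi^l_{mu alpha} G^{l+1}_{mu alpha}
K = np.einsum('lma,lmaij->maij', Phi, G)                     # shape (P, P, C, C)

f = np.zeros((P, C))
y = rng.standard_normal((P, C))
Delta = y - f                                                # squared-loss error signal, Delta_mu in R^C
# one Euler step:  d f_{mu c} / dt = sum_{alpha, c'} K_{mu alpha, c c'} Delta_{alpha c'}
f = f + dt * np.einsum('maij,aj->mi', K, Delta)
print(f.shape)
```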

Appendix J: Weight decay in deep homogeneous networks

If we train with weight decay, $\frac{\mathrm{d}}{\mathrm{d}t}\boldsymbol{\theta} = - \gamma^2 \nabla_{\boldsymbol{\theta}} \mathcal{L} - \lambda \boldsymbol{\theta}$, in a κ-degree homogeneous network ($f(c\boldsymbol{\theta}) = c^\kappa f(\boldsymbol{\theta})$), the prediction dynamics satisfy

This holds by the following identity: $\frac{\partial}{\partial c} f(c \boldsymbol{\theta}) = \frac{\partial }{\partial c} c^\kappa f(\boldsymbol{\theta})$, which when evaluated at c = 1 gives $\frac{\partial}{\partial \boldsymbol{\theta}} f(\boldsymbol{\theta}) \cdot \boldsymbol{\theta} = \kappa f(\boldsymbol{\theta})$. This identity was utilized in a prior work which studied L2 regularization in the lazy regime [75]. For an L-hidden-layer ReLU network $\phi(h) = \max(0,h)$, the degree is $\kappa = L+1$, while rectified power-law nonlinearities $\phi(h) = \max(0,h)^q$ give degree $\kappa = \frac{q^{L+1}-1}{q-1}$. We note that the fixed point of the function dynamics above gives a representer theorem with the final NTK

Equation (J.1)

where $[\boldsymbol{k}(x)]_{\mu} = \lim_{t\to\infty }K(\boldsymbol{x},\boldsymbol{x}_{\mu}, t)$ and $K_{\mu\alpha} = \lim_{t \to \infty} K(\boldsymbol{x}_{\mu},\boldsymbol{x}_{\alpha},t)$. The prior work of Lewkowycz and Gur-Ari [75] considered the NTK parameterization $\gamma_0 = 0$. In this limit, the kernel (and consequently the output function) decays to zero at large time, but if $\gamma_0 \gt 0$, then the network converges to a nontrivial fixed point as $t \to \infty$. In the DMFT limit we can determine the final kernel by solving the following field dynamics

Equation (J.2)

We see that the contribution from the initial conditions is exponentially suppressed at large time t, while the second term contributes most once the system has equilibrated. We provide an example of the weight decay DMFT, showing its validity in a two-layer ReLU network, in figure 3.
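The homogeneity identity $f(c\boldsymbol{\theta}) = c^{\kappa} f(\boldsymbol{\theta})$ and the stated degree $\kappa = \frac{q^{L+1}-1}{q-1}$ for rectified power-law networks are easy to confirm numerically. A minimal sketch, using an arbitrary toy fully-connected architecture and random parameters (not the networks used in the experiments):

```python
import numpy as np

rng = np.random.default_rng(5)
N, D, L_hidden = 64, 10, 3
q = 2.0                                            # rectified power-law exponent phi(h) = max(0,h)^q

def forward(theta, x, q):
    """Tiny fully-connected net with L_hidden hidden layers; returns a scalar output."""
    h = x
    for W in theta[:-1]:
        h = np.maximum(W @ h, 0.0) ** q
    return theta[-1] @ h

# random parameters: L_hidden weight matrices plus a readout vector
theta = [rng.standard_normal((N, D))] + \
        [rng.standard_normal((N, N)) for _ in range(L_hidden - 1)] + \
        [rng.standard_normal(N)]
x = rng.standard_normal(D)

c = 1.7
kappa = (q ** (L_hidden + 1) - 1) / (q - 1)        # predicted homogeneity degree
ratio = forward([c * W for W in theta], x, q) / forward(theta, x, q)
print(ratio, c ** kappa)                           # the two numbers should agree
```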

Appendix K: Bayesian/Langevin trained mean field networks

Rather than studying exact gradient flow, many works have considered Langevin dynamics (gradient flow with white noise process on the weights) of neural network training [25, 3032, 97]. This setting is of special theoretical interest since the distribution of parameters converges at long times to a Gibbs equilibrium distribution which has a Bayesian interpretation [3, 4, 97]. The relevant Langevin equation for our mean field gradient flow is

Equation (K.1)

where λ is a ridge penalty which controls the scale of the parameters, and $\mathrm{d}\boldsymbol{\epsilon}(t)$ is a Brownian motion term with covariance structure $\left\langle \mathrm{d}\boldsymbol{\epsilon}(t) \mathrm{d}\boldsymbol{\epsilon}(t^{\prime})^{\top} \right\rangle = \delta(t-t^{\prime}) \mathbf{I}$. The parameter β, known as the inverse temperature, controls the scale of the random Gaussian noise injected into this stochastic process. The dynamical early-time treatment of the $\beta \to \infty$ limit coincides with our usual DMFT, while finite β exhibits a nontrivial balance between the usual DMFT feature updates and the random Langevin noise. At late times, such a system equilibrates to its Gibbs distribution.

K.1. Dynamical analysis

In this section we analyze the DMFT for these Langevin dynamics. First we note that the effect of regularization can be handled with a simple integrating factor

Equation (K.2)

where $\mathrm{d}\boldsymbol{\epsilon}(t) \in \mathbb{R}^{N \times N}$ is the Gaussian noise for layer $\ell$ at time t. It is straightforward to verify by Ito's lemma that, under mean field parameterization, the fluctuations in the dynamics of f due to Brownian motion are $\frac{\partial f}{\partial \boldsymbol{\theta}} \cdot \mathrm{d}\boldsymbol{\epsilon}(t) \sim \mathcal{O}(N^{-1/2})$ and are thus negligible in the $N \to \infty$ limit. Thus the evolution of the network function takes the form

We can express both of these parameter contractions in feature space provided we introduce the new features $r_{i,\mu}^{\ell}(t) = \frac{\partial g_{i,\mu}^{\ell}}{\partial h_{i,\mu}^{\ell}}$ which are necessary to compute Hessian terms like $\frac{\partial^2 f}{\partial W_{ij}^{\ell} \partial W^{\ell}_{ij}} = N^{-3/2} \frac{\partial }{\partial W^{\ell}_{ij}} [ g_i^{\ell+1} \phi(h_j^{\ell})] = N^{-2} \ r_i \ \phi(h_j^{\ell})^2$ in each layer. This gives the following evolution

Equation (K.3)

As before, we compute the next layer field $\boldsymbol{h}^{\ell+1}$ in terms of $\boldsymbol{\chi}^{\ell+1}$ and $\boldsymbol{z}^{\ell}$ in terms of $\boldsymbol{\xi}^{\ell}$

Equation (K.4)

The dependence on the initial condition through $\boldsymbol{\chi},\boldsymbol{\xi}$ is suppressed at long times due to the regularization factor $e^{-\frac{\lambda}{\beta} t}$, while the Brownian motion and gradient updates survive in the $t\to\infty$ limit. In addition to the usual $\{\boldsymbol{\chi}^{\ell},\boldsymbol{\xi}^{\ell}\}$ fields which arise from the initial condition, we see that $\boldsymbol{h}^{\ell}(t), \boldsymbol{z}^{\ell}(t)$ also depend on the following fields which arise from the integrated Brownian motion

Equation (K.5)

Our aim is now to compute the moment generating function for the $\{\boldsymbol{\chi},\boldsymbol{\xi},\boldsymbol{\chi}^{\epsilon},\boldsymbol{\xi}^{\epsilon}\}$ fields which causally determine $\{\boldsymbol{h},\boldsymbol{z}\}$. This MGF has the form

Equation (K.6)

We insert Dirac-delta functions in the usual way to enforce the definitions of $\boldsymbol{\chi},\boldsymbol{\xi},\boldsymbol{\chi}^{\epsilon},\boldsymbol{\xi}^{\epsilon}$ and then average over $\boldsymbol{\theta}_0 , \boldsymbol{\epsilon}(t)$. These averages can be performed separately, with the $\boldsymbol{\theta}_0$ average giving the same terms as derived in previous sections. We focus on the average over the Brownian disorder

Equation (K.7)

where we introduced the order parameter $i A^{\epsilon,\ell}_{\mu\alpha}(t,t^{\prime}) = \frac{1}{N} \phi(\boldsymbol{h}^{\ell}_{\mu}(t)) \cdot \boldsymbol{\hat\xi}^{\epsilon,\ell}_{\alpha}(t^{\prime})$. We will use the shorthand $C_{\lambda,\beta}(t,t^{\prime}) = \frac{1}{\lambda} \exp\left( - \frac{\lambda}{\beta}(t+t^{\prime}) \right) \left[ e^{2 \frac{\lambda}{\beta} \min\{t,t^{\prime}\}} - 1 \right] \sim_{t,t^{\prime} \to \infty} \frac{1}{\lambda} \exp\left( - \frac{\lambda}{\beta} |t-t^{\prime}| \right)$ for the temporal prefactor in the above. We insert a Lagrange multiplier $B^{\epsilon,\ell}$ to enforce the definition of $A^{\epsilon,\ell}$. After doing so, we obtain

Equation (K.8)

The order parameters can be determined by the saddle point equations. These equations for $\Phi,\hat\Phi,G,\hat{G},A,B$ are the same as before. The new equations are

Equation (K.9)

Using the fact that $\boldsymbol{\Phi}^{\ell},G^{\ell}$ concentrate, we can use the Hubbard trick to linearize the quadratic terms in $\hat{\chi}^{\epsilon}$ and $\hat{\xi}^{\epsilon}$.

Equation (K.10)

Equation (K.11)

Using the vectorization notation, we find the interpretation that $\boldsymbol{\chi}^{\epsilon,\ell} $ and $\boldsymbol{\xi}^{\epsilon,\ell}$ decouple as

Equation (K.12)

Equation (K.13)

As before, we make the substitutions $\boldsymbol{B} \to \gamma_0^{-1} {\boldsymbol{B}}^{\top}$ and $\boldsymbol{A} \to \gamma_0^{-1} \boldsymbol{A}$ and arrive at the final DMFT equations

Equation (K.14)

where the kernels are defined in the usual way. As expected, the contributions from the initial conditions $\chi^{\ell}, \xi^{\ell}$ are exponentially suppressed at late time whereas the contributions from the Brownian disorder $\chi^{\epsilon,\ell},\xi^{\epsilon,\ell}$ persist at late time.

K.2. Weak feature learning, long time limit

In the weak feature learning $\gamma_0 \to 0$ and long time $t \to \infty$ limit, the preactivation fields equilibrate to Gaussian processes $h^{\ell}_{\mu}(t) \sim u^{\epsilon,\ell}_{\mu}(t), z^{\ell}_{\mu}(t) \sim r^{\epsilon,\ell}_{\mu}(t)$, which have respective covariances $H^{\ell}_{\mu\alpha}(t,s) = \left\langle h^{\ell}_{\mu}(t) h^{\ell}_{\alpha}(s) \right\rangle = C_{\lambda,\beta}(t,s) \Phi^{\ell-1}_{\mu\alpha}(t,s)$ and $Z^{\ell}_{\mu\alpha}(t,s) = \left\langle z^{\ell}_{\mu}(t) z^{\ell}_{\alpha}(s) \right\rangle = C_{\lambda,\beta}(t,s) G^{\ell+1}_{\mu\alpha}(t,s)$. In this long time limit, the feature kernels are time-translation invariant, e.g. $\Phi^{\ell}_{\mu\alpha}(t,s) = \Phi^{\ell}_{\mu\alpha}(|t-s|)$. Letting $\tau = |t-s|$ and $C_{\lambda,\beta}(\tau ) = \frac{1}{\lambda} \exp\left(-\frac{\lambda}{\beta} \tau \right)$, we have the following recurrence for $H^{\ell},\Phi^{\ell}$

Equation (K.15)

Similarly, we can obtain $Z^{\ell}$ and $G^{\ell}$ in a backward pass recursion

Equation (K.16)

On the temporal diagonal τ = 0, these equations give the usual recursions used to compute the NNGP kernels at initialization [4], though with initialization variance $C_{\lambda,\beta}(0) = \lambda^{-1}$, set by the weight decay term in the Langevin dynamics. This indicates that the long time Langevin dynamics at $\gamma_0 \to 0$ simply rescale the Gaussian weight variance based on λ. It would be interesting to explore fluctuation-dissipation relationships at finite γ0 within this framework, which we leave to future work.
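As a purely numerical aside, the Brownian-disorder correlator $C_{\lambda,\beta}(t,t^{\prime})$ introduced in appendix K.1 can be checked against its stated late-time limit $\frac{1}{\lambda}e^{-\frac{\lambda}{\beta}|t-t^{\prime}|}$; the values of λ, β and the lag τ below are arbitrary.

```python
import numpy as np

lam, beta = 0.5, 2.0

def C_full(t, s, lam, beta):
    # C_{lambda,beta}(t,s) = (1/lambda) exp(-(lambda/beta)(t+s)) [exp(2(lambda/beta) min(t,s)) - 1]
    return np.exp(-(lam / beta) * (t + s)) * (np.exp(2 * (lam / beta) * np.minimum(t, s)) - 1) / lam

def C_stationary(t, s, lam, beta):
    # late-time limit: (1/lambda) exp(-(lambda/beta)|t - s|)
    return np.exp(-(lam / beta) * np.abs(t - s)) / lam

tau = 1.3
for t in [1.0, 5.0, 20.0, 80.0]:
    print(t, C_full(t + tau, t, lam, beta), C_stationary(t + tau, t, lam, beta))
```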

K.3. Equilibrium analysis

The Langevin dynamics at finite N converge (possibly in a time extensive in N) to an equilibrium distribution with several interesting properties, as was recently studied by Yang et al [97] and implicitly by Seroussi and Ringel [31] in a large sample size limit. This setting differs from the previous section, where the $N \to \infty$ limit is taken first, followed by a $t \to \infty$ limit in the DMFT. This section, on the other hand, studies, for any N, the $t \to \infty$ limiting equilibrium distribution; this equilibrated distribution is then analyzed in the $N \to \infty$ limit. The relationship between these two orders of limits remains an open problem. The equilibrium distribution over parameters $p(\boldsymbol{\theta}|\mathcal D) \propto \exp\left( - \beta \gamma^2 L(\boldsymbol{\theta}) - \frac{\lambda}{2} |\boldsymbol{\theta}|^2 \right)$ can be viewed as a Bayes posterior with log-likelihood $- \beta \gamma^2 L(\boldsymbol{\theta})$ and a Gaussian prior with scale $\lambda^{-1/2}$. In the mean field limit with $\gamma = \sqrt{N} \gamma_0$, we can express the density over pre-activations $\boldsymbol{h}^{\ell}$ and the output predictions f. This gives

Equation (K.17)

We see that $p(\boldsymbol{f}|\mathcal D) \propto \int \mathrm{d}\Phi \mathrm{d}\hat{\Phi} \exp\left( N S[\Phi,\hat{\Phi}] \right)$ where

Equation (K.18)

Thus the predictions $f_{\mu}$ become nonrandom in this $N \to \infty$ limit and can be determined from the saddle point equations as in [97]. Again, letting $\Delta_{\mu} = - \frac{\partial}{\partial f_{\mu}} \ell(f_{\mu},y_{\mu})$, we find

Equation (K.19)

which implies that $f_{\mu}$ at the fixed point satisfies the following equations

Equation (K.20)

The last layer's dual kernel has the form $\hat{\Phi}^L_{\mu\alpha} = - \frac{\gamma_0^2 \beta^2}{2\lambda} \Delta_{\mu} \Delta_{\alpha}$, which vanishes as the feature learning strength is taken to zero, $\gamma_0 \to 0$, while for non-negligible γ0 the last layer features are non-Gaussian. The moment generating function for the last layer field thus has the form

Equation (K.21)

In the $\gamma_0 \to 0$ limit, the non-Gaussian component of this density vanishes. Now that we have this form, we can compute $\Phi^L$ conditional on $\Phi^{L-1}$. Next, we calculate $\hat{\Phi}^{L-1}_{\mu\alpha} = \left\langle \hat{h}^L_{\mu} \hat{h}^L_{\alpha} \right\rangle$, giving

Equation (K.22)

Again, we note that in the $\gamma_0 \to 0$ limit, $\left\langle \boldsymbol{h}^L \boldsymbol{h}^L \right\rangle \sim \lambda^{-1} \boldsymbol{\Phi}^{L-1}$, so that $\hat{\boldsymbol{\Phi}}^{L-1} = 0$, implying that the $h^{L-1}$ fields are also Gaussian in this limit. For arbitrary γ0, this recursive argument can be completed going backwards using

Equation (K.23)

For deep linear networks, the distributions are all Gaussian, allowing one to close the saddle point equations for $\Phi,\hat{\Phi}$ algebraically [97].

Appendix L: Momentum dynamics

Standard GD often converges slowly and requires careful tuning of the learning rate. Momentum, in contrast, can be stable under a wider range of learning rates and can benefit from acceleration on certain problems [98–101]. In this section we show that our field theory is still valid when training with momentum; simply altering the field definitions appropriately gives the infinite-width feature learning behavior.

Momentum uses a low-pass filtered version of the gradients to update the weights. A continuous limit of momentum dynamics on the trainable parameters $\{\boldsymbol{W}^{\ell} \}$ would give the following differential equations.

Equation (L.1)

We write the expression this way so that the small time constant τ → 0 limit corresponds to classic GD. Integrating out the $\boldsymbol{Q}^{\ell}(t)$ variable gives the following weight dynamics

Equation (L.2)

which implies the following field evolution

Equation (L.3)

We see that in the τ → 0 limit, the $t^{\prime\prime}$ integral is dominated by the contribution at $t^{\prime\prime} \approx t^{\prime}$, recovering the usual GD dynamics. For large τ, the integral accumulates additional contributions from past values of the fields and kernels.
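A minimal discrete-time sketch of the low-pass-filter picture of momentum described above. This is one natural continuous-time realization consistent with that description (equation (L.1) itself is not reproduced here), and the quadratic loss, step sizes and time constants are arbitrary choices; it illustrates that the τ → 0 limit recovers plain GD.

```python
import numpy as np

# Assumed continuous-time form:  tau dQ/dt = -Q + grad L(W),   dW/dt = -eta Q,
# simulated with forward Euler on a toy quadratic loss L(W) = 0.5 * W^T H W.
rng = np.random.default_rng(7)
D = 5
M = rng.standard_normal((D, D))
H = M @ M.T / D + 0.1 * np.eye(D)       # positive-definite Hessian of the toy quadratic

def run_momentum(tau, eta=0.05, dt=1e-3, steps=20_000):
    W = np.ones(D)
    Q = np.zeros(D)
    for _ in range(steps):
        grad = H @ W
        Q = Q + dt * (grad - Q) / tau   # low-pass filter of the gradient with time constant tau
        W = W - dt * eta * Q
    return W

def run_gd(eta=0.05, dt=1e-3, steps=20_000):
    W = np.ones(D)
    for _ in range(steps):
        W = W - dt * eta * (H @ W)
    return W

W_gd = run_gd()
for tau in [1.0, 0.1, 0.01]:
    print(tau, np.linalg.norm(run_momentum(tau) - W_gd))  # difference shrinks as tau -> 0
```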

Appendix M: Discrete time

Our analysis can also be carried out in discrete time, though we lose the NTK as a key player in the theory (note that $\frac{\mathrm{d}}{\mathrm{d}t} f_{\mu} = \frac{\mathrm{d}f_{\mu}}{\mathrm{d}\theta} \cdot \frac{\mathrm{d}\theta}{\mathrm{d}t} = \sum_{\alpha} \Delta_{\alpha} K_{\mu\alpha}^{\mathrm{NTK}}$ requires a continuous time limit of the GD dynamics). For a discrete time analysis we let $t \in \mathbb{N}$ and define our network function as

Equation (M.1)

We treat $f_{\mu}(t)$ as a potentially random variable and insert

Equation (M.2)

Noting that $\boldsymbol{w}^L(0)$ is involved in the definition of both $f_{\mu}(t)$ and $\boldsymbol{\xi}_{\mu}^L(t)$, we see that the average over $\boldsymbol{w}^L(0)$ now takes the form

Equation (M.3)

We extend our definition as before $i A^L_{\mu\alpha}(t,s) = \frac{1}{N \gamma_0} \phi(\boldsymbol{h}^L_{\mu}(t)) \cdot \boldsymbol{\xi}^L_{\alpha}(s)$. Proceeding with the calculation as usual, we find that

Equation (M.4)

The saddle point equations can now be analyzed. In addition to the usual order parameters, we note that $f,\hat{f}$ also generate saddle point equations

Equation (M.5)

We also obtain saddle point equations for the new $A^L, B^L$ order parameters.

Equation (M.6)

Equation (M.7)

which implies $B^L_{\mu\alpha}(t,s) = 0$ and $A^L_{\mu\alpha}(t,s) = \gamma_0^{-1} \left\langle \frac{\partial \phi(h^L_{\mu}(t))}{\partial r^L_{\alpha}(s)} \right\rangle$. This gives the following DMFT

Equation (M.8)

We leave it to future work to verify that a continuous time limit of the above DMFT recovers function evolution governed by the NTK.

Appendix N: Equivalent parameterizations

In this section, we show the equivalence of our parameterization scheme with many alternatives, including the $\mu P$ parameterization of Yang and Hu [22]. We also compare the stochastic processes obtained with DMFT and with TPs in appendix N.6. Following Yang and Hu, we use a modified variant of the abc parameterization. We will assume the following parameterization and initialization

Equation (N.1)

and we consider training with gradient flow dynamics

Equation (N.2)

The learning rate is scaled as $\eta = \eta_0 \gamma^2 N^{-c}$ with $\eta_0 = \mathcal{O}(1)$. The factor of $\gamma^2$ in the learning rate η ensures that $\frac{d}{\mathrm{d}t} f|_{t = 0}$ does not depend on γ. Lastly, we will scale the Chizat and Bach feature learning parameter as $\gamma = \gamma_0 N^d$. We will ultimately find that only $d = \frac{1}{2}$ allows stable feature learning in the infinite width $N \to \infty$ limit.

We will now derive constraints on $(a, b, c, d)$ which give the desired large-width behavior. We will identify a one-dimensional family of parameterizations which satisfy three desiderata of network training: 1. finite preactivations, 2. learning in finite time, 3. feature learning.

N.1. Fields are $\mathcal{O}_N(1)$

In this section, we identify conditions under which $\boldsymbol{h}^{\ell}$ have $\mathcal{O}_N(1)$ entries, which ensures that the kernels $\Phi^{\ell}$ are also $\mathcal{O}_N(1)$. The base case for $h^1$ gives us the following covariance of entries at initialization

Equation (N.3)

Assuming that $K^x$ does not scale with N as $N \to \infty$, we find the constraint $2 a_0 + b_0 = 0$. Now that we have a condition for $h^1$ to be $\mathcal{O}_N(1)$ in its entries, giving $\Phi^1 \sim \mathcal{O}_N(1)$, we proceed with the induction step. We assume that $\Phi^{\ell} \sim \mathcal{O}_N(1)$ and then find conditions which guarantee that $\boldsymbol{h}^{\ell+1}$ has $\mathcal{O}_N(1)$ entries. The covariance at layer $\ell+1$ at initialization is

Equation (N.4)

Since we are assuming under the inductive hypothesis that $\Phi^{\ell} = \mathcal{O}_N(1)$, we identify the constraint $2a_{\ell} + b_{\ell} = 1$. Again we see that $(a_{\ell} = \frac{1}{2}, b_{\ell} = 0)$ works, but this is not the only possible scaling. Alternatively, the standard parameterization $(a_{\ell} = 0, b_{\ell} = 1)$ also preserves the $\mathcal{O}_N(1)$ scale of the features. To characterize prediction and feature dynamics, we next need to analyze the scale of the feature gradients $\frac{\partial h^{L+1}}{\partial \boldsymbol{h}^{\ell}}$. We start with the last layer and define

which has $\mathcal{O}_N(1)$ entries by construction. We similarly extend this definition to earlier layers $\boldsymbol{g}^{\ell} = N^{a_L + b_L/2} \frac{\partial h^{L+1}}{\partial \boldsymbol{h}^{\ell}}$ to see whether $\boldsymbol{g}^{\ell}$ remains $\mathcal{O}_N(1)$ under its backward-pass recursion

Equation (N.5)

Now, letting $\boldsymbol{z}^{\ell} = N^{-a_{\ell}} \boldsymbol{W}^{\ell}(0)^{\top} \boldsymbol{g}^{\ell+1}$ as in the main text, we have

Equation (N.6)

Under the inductive hypothesis that $G^{\ell+1} \sim \mathcal{O}_N(1)$ and the previous constraint $2 a_{\ell} + b_{\ell} = 1$, the z variables have $\mathcal{O}(1)$ variance. Overall, we can thus ensure that $\Phi^{\ell}, G^{\ell} \sim \mathcal{O}_N(1)$ if $2 a_{\ell} + b_{\ell} = 1$ for $\ell \in \{1,\ldots ,L\}$ and $2 a_0 + b_0 = 0$.
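A quick numerical check of the forward-pass constraint $2 a_{\ell} + b_{\ell} = 1$, assuming the convention that hidden-layer weights enter as $N^{-a_{\ell}} w^{\ell}$ with $w^{\ell}_{ij} \sim \mathcal{N}(0, N^{-b_{\ell}})$ (the abc convention of [22]); the depth, widths and ReLU nonlinearity below are arbitrary illustrative choices. The preactivation scale should be N-independent when the constraint holds.

```python
import numpy as np

rng = np.random.default_rng(8)

def preact_std(N, a, b, depth=3):
    """Std of entries of h^{depth} when hidden weights enter as N^{-a} w, w_ij ~ N(0, N^{-b})."""
    h = rng.standard_normal(N)                      # stand-in for an O(1) first-layer preactivation
    for _ in range(depth):
        w = rng.standard_normal((N, N)) * N ** (-b / 2)
        h = N ** (-a) * w @ np.maximum(h, 0.0)      # ReLU hidden layers
    return h.std()

for N in [256, 1024, 4096]:
    print(N,
          preact_std(N, a=0.5, b=0.0),   # 2a + b = 1 (mean-field/NTK-style scaling): O_N(1)
          preact_std(N, a=0.0, b=1.0),   # 2a + b = 1 (standard parameterization): O_N(1)
          preact_std(N, a=0.0, b=0.0))   # 2a + b = 0: grows with N
```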

N.2. Predictions evolve in $\mathcal{O}_N(1)$ time

As before, we define the NTK to be the matrix which characterizes the network prediction dynamics, $\partial_t f_{\mu} = \eta_0 \sum_{\alpha} K^{\mathrm{NTK}}_{\mu\alpha} \Delta_{\alpha}$. We demand that this matrix be $K^{\mathrm{NTK}} \sim \mathcal O_N(1)$ so that the network predictions evolve at rate $\partial_t f_{\mu} \sim \mathcal{O}_N(1)$

Equation (N.7)

where we used the usual definition of the kernels $\Phi^{\ell} = \frac{1}{N} \phi(\boldsymbol{h}^{\ell}) \cdot \phi(\boldsymbol{h}^{\ell})$ and $G^{\ell} = \frac{1}{N} \boldsymbol{g}^{\ell} \cdot \boldsymbol{g}^{\ell}$ which are $\mathcal{O}_N(1)$ under the assumptions of the previous section. We thus find the following constraints

Equation (N.8)

Again this recovers the parameterization in the main text provided c = 0, $a_0 = 0$ and $a_{\ell} = \frac{1}{2}$. We see that for nonzero c, we need nonzero $a_0$.

N.3. $\mathcal{O}_N(1)$ feature evolution

Now, we desire that the fields $h_i, z_i$ all evolve by an $\mathcal{O}_N(1)$ amount during network training, so that feature learning is stable. Under the assumption that $2 a_L + b_L = 1$ (see previous sections), the update equations for $\boldsymbol{W}^{\ell}$ and $\boldsymbol{h}^{\ell}$ give

Now, noting that $\boldsymbol{h}^{\ell+1}(t) = N^{-a_{\ell}} \boldsymbol{W}^{\ell}(t) \phi(\boldsymbol{h}^{\ell}(t))$ and $\Phi^{\ell}_{\mu\alpha} = \frac{1}{N} \phi(\boldsymbol{h}_{\mu}) \cdot \phi(\boldsymbol{h}_{\alpha})$, we have

Equation (N.9)

where we used $\gamma = \gamma_0 N^d$. The above equation implies that $d - c -2a_{\ell} + \frac{1}{2} = 0$ is necessary and sufficient for $\mathcal{O}_N(1)$ feature evolution. An identical argument for the pregradient fields $\boldsymbol{z}^{\ell}_{\mu}(t)$ gives the same constraint.

N.4. Putting constraints together

The set of parameterizations which yield $\mathcal{O}(1)$ feature evolution are those for which

  • (i)  
    Features $h,z$ are $\mathcal{O}_N(1) \implies 2a_{\ell} + b_{\ell} = 1$ for $\ell \in \{1,\ldots ,L\}$ and $2a_0 + b_0 = 0$.
  • (ii)  
    Output predictions evolve in $\mathcal{O}_N(1)$ time $\implies 2a_{\ell} + c = 1$, $2 a_0 + c = 0$
  • (iii)  
    Features $h,z $ have $\mathcal{O}_N(1)$ evolution $\implies d = \frac{1}{2}$.

The parameterization discussed in appendix D satisfies these with $d = \frac{1}{2}, a_{\ell} = \frac{1}{2}, b_{\ell} = 0, c = 0$. The quite general requirement for feature learning, $d = \frac{1}{2}$, indicates that $\gamma = \gamma_0 \sqrt{N}$ for any choice of $a_{\ell}, b_{\ell}, c$, as we use in the main text. This indicates that neural network prediction logits at initialization scale as $f_{\mu} \sim \mathcal{O}(N^{-1/2})$ in the feature learning infinite width limit. The set of parameterizations which meet these three requirements is one dimensional, with $d = \frac{1}{2}$ and $(a ,b ,c) \in \{(a,1-2a, 1-2a) : a \in \mathbb{R} \}$ for all layers except the first layer, which has $(a_0 = a - \frac{1}{2}, b_0 = 1-2a)$. Our parameterization corresponds to $a = \frac{1}{2}$. However, in the next section, we show that if one demands an $\mathcal{O}_N(1)$ raw learning rate $\eta = \eta_0 \gamma^2 N^{-c}$, then the parameterization is unique and is the $\mu P$ parameterization of Yang and Hu [22].

N.5. $\mathcal{O}_N(1)$ raw learning rate

We are also interested in parameterizations for which the raw learning rate is $\eta \sim \mathcal O(1)$: since $\eta = \eta_0 \gamma^2 N^{-c} = \mathcal{O}_N(N^{2d-c})$, demanding $\eta = \mathcal{O}_N(1)$ implies $c = 2d = 1$. Under this constraint, $a_{\ell} = 0$ and $b_{\ell} = 1$ for $\ell \in \{1,\ldots ,L\}$ and $a_0 = - \frac{1}{2}$ and $b_0 = 1$, which corresponds to a modification of standard parameterization with the first and last layers rescaled with width. In a computational algorithm, the learning rate would be $\eta = \eta_0 \gamma^2 N^{-c} = \eta_0 \gamma_0^2 = \mathcal{O}_N(1)$. This is equivalent to the $\mu P$ parameterization of Yang and Hu [22] stated in the main text.

N.6. Equivalence of DMFT at $\gamma_0 = 1$ and TP-derived stochastic process

Now that we have established that the parameterization we consider here (modified NTK parameterization) is equivalent to $\mu P$ (modified standard parameterization), we will demonstrate that the stochastic process we obtained through a stationary action principle applied to our DMFT action S is equivalent to the stochastic process derived from the TP framework of Yang [22, 96]. Using the notation from appendix H of Yang and Hu [22], they give the following evolution equations for the preactivations in a hidden layer under one-pass SGD

Equation (N.10)

where $\hat{Z}^{W x_t}$ is a mean zero Gaussian variable with covariance $\mathbb{E}[Z^{x_t} Z^{x_s}]$ and $\hat{Z}^{W^{\top} \mathrm{d}h_t}$ is a mean zero Gaussian with covariance $\mathbb{E}[Z^{\mathrm{d}h_t} Z^{\mathrm{d}h_s}]$. We can switch to the notation of this work by making the substitutions $Z^{h_t} \to h(t)$, $\hat{Z}^{W x_t} \to u(t)$, $\chi_s \to - \Delta(s)$, $\dot{Z}^{Wx} \to \sum_{s} \Delta(s) A(t,s)$ and $\mathbb{E}[ Z^{x_s} Z^{x_t} ] \to \Phi(t,s)$, and so on. A summary of the full set of notational substitutions between this work and TP are summarized in table N1.

Table N1. Dictionary relating the notation of the tensor programs (TP) framework [22] and this work.

DMFT: $h(t)$ | $\chi(t)$ | $g(t)$ | $\xi(t)$ | $\Phi^{\ell}(t,s)$ | $G^{\ell}(t,s)$ | $A^{\ell}(t,s), B^{\ell}(t,s)$ | $\Delta(t)$
TP: $Z^{h_t}$ | $Z^{Wx_t}$ | $Z^{\mathrm{d}x_t}$ | $Z^{W^{\top} \mathrm{d}h_t}$ | $\mathbb{E}[Z^{x_t} Z^{x_s}]$ | $\mathbb{E}[Z^{\mathrm{d}h_t} Z^{\mathrm{d}h_s}]$ | $\theta_{ts}$ | $\chi_t$

After these substitutions are made, we see that the equations above match the one-pass SGD version of the DMFT equations in appendix M. A similar identification can be made for the backward pass. This shows that both TPs and DMFT, though arrived at through different derivations, give identical descriptions of the stochastic processes induced by random initialization and GD in infinite-width neural networks.

Appendix O: Gradient independence

The gradient independence approximation treats the random initial weight matrix $\boldsymbol{W}^{\ell}(0)$ as an independently sampled Gaussian matrix when used in the backward pass. We let this second matrix be $\tilde{\boldsymbol{W}}^{\ell}(0)$. As before, we have $\boldsymbol{\chi}^{\ell+1} = \frac{1}{\sqrt N} \boldsymbol{W}^{\ell}(0) \phi(\boldsymbol{h}^{\ell})$, however we now define $\boldsymbol{\xi}^{\ell} = \frac{1}{\sqrt N} \tilde{\boldsymbol{W}}^{\ell}(0)^{\top} \boldsymbol{g}^{\ell+1}$. Now, when computing the moment generating function Z, the integrals over $\boldsymbol{W}^{\ell}(0)$ and $\tilde{\boldsymbol{W}}^{\ell}(0)$ factorize

Equation (O.1)

We see that in this field theory, the fields $\chi, \xi$ are all independent Gaussian processes $\{\chi^{\ell+1}_{\mu}(t) \} \sim \mathcal{GP}(0, \boldsymbol{\Phi}^{\ell})$ and $\{\xi^{\ell}_{\mu}(t) \} \sim \mathcal{GP}(0,\boldsymbol{G}^{\ell+1})$. This corresponds to making the assumption that $\boldsymbol{A}^{\ell} = \boldsymbol{B}^{\ell} = 0$ so that $\chi = u$ and $\xi = r$ within the full DMFT.

Appendix P: Perturbation theory

P.1. Small γ0 expansion

In this section we analyze the leading corrections in a small-$\gamma_0$ expansion of our DMFT. All fields at each time t are expanded in a power series in γ0.

Equation (P.1)

Our goal is to calculate all corrections to the kernels up to $\mathcal{O}(\gamma_0^3)$ to show that the leading correction is $\mathcal{O}(\gamma_0^2)$ and the subleading correction is $\mathcal{O}(\gamma_0^4)$. It will again be convenient to utilize the vector notation defined in appendix D.

We note that unlike other works on perturbation theory in wide networks, we do not attempt to characterize fluctuation effects in the kernels due to finite width, but rather operate in a regime where the kernels are concentrating and their variance is negligible. For a more thorough discussion of perturbative field theory in finite width networks, see [27, 28, 35].

P.1.1. Linear network.

The kernels in deep linear networks can be expanded in powers of $\gamma_0^2$, giving a leading order correction of size $\mathcal O(\gamma_0^2)$ which can be computed explicitly from the closed saddle point equations. We use the symmetrizer $\{\boldsymbol{X},\boldsymbol{Y} \}_{sym} = \boldsymbol{X} \boldsymbol{Y} + \boldsymbol{Y}^{\top} \boldsymbol{X}^{\top}$ as shorthand. The leading order behavior $\boldsymbol{C}^{\ell} \sim \boldsymbol{C}^{(0)} + \mathcal O(\gamma_0^2) \ , \ \boldsymbol{D}^{\ell} \sim \boldsymbol{D}^{(0)} + \mathcal O(\gamma_0^2), \boldsymbol{H}^{\ell,0} = \boldsymbol{H}^{(0)} = \boldsymbol{K}^x \otimes \boldsymbol{1} \boldsymbol{1}^{\top}, \boldsymbol{G}^{\ell,(0)} = \boldsymbol{G}^{(0)} = \boldsymbol{1}\boldsymbol{1}^{\top}$ is independent of the layer index, so we find the following leading order corrections

Equation (P.2)

Note that $[\boldsymbol{C}^0 \boldsymbol{g}]_{\mu t} = \int_0^t \mathrm{d}t^{^{\prime}} \sum_{\beta} H^0_{\mu\beta}(t,t^{^{\prime}}) \Delta_{\beta}(t^{^{\prime}}) g(t^{^{\prime}}) = \sum_{\beta} K^x_{\mu\beta} \int_0^t \mathrm{d}t^{^{\prime}} \Delta_{\beta}(t^{^{\prime}}) g(t^{^{\prime}})$ and note that $[\boldsymbol{D} \boldsymbol{h}]_{t} = \int_0^t \mathrm{d}t^{^{\prime}} G^0(t,t^{^{\prime}}) \sum_{\alpha} \Delta_{\alpha}(t^{^{\prime}}) h_{\alpha}(t^{^{\prime}}) = \sum_{\alpha} \int_0^t \mathrm{d}t^{^{\prime}} \Delta_{\alpha}(t^{^{\prime}}) h_{\alpha}(t^{^{\prime}})$.

Equation (P.3)

We can simplify the notation by introducing the functions $v_{\alpha}(t) = \int_0^t \mathrm{d}t^{\prime} \Delta_{\alpha}(t^{\prime})$ and $v_{\alpha\beta}(t) = \int_0^t \mathrm{d}t^{\prime} \Delta_{\alpha}(t^{\prime}) \int_0^{t^{\prime}} \mathrm{d}t^{\prime\prime} \Delta_{\beta}(t^{\prime\prime})$.

Equation (P.4)

Using the fact that

Equation (P.5)

and utilizing the identity $\sum_{\ell = 1}^L \ell = \frac{1}{2} L(L+1)$, we recover the result provided in the main text.

P.2. Nonlinear perturbation theory

In this section, we explore perturbation theory in nonlinear networks. We start with the formula which implicitly defines $\boldsymbol{h}^{\ell},\boldsymbol{z}^{\ell}$ treated as vectors over samples and time

Equation (P.6)

We proceed under the assumption of a power series in γ0

Equation (P.7)

As before, the leading terms for $\boldsymbol{C}^{\ell,0}, \boldsymbol{D}^{\ell,1}$ only depend on time through the functions $\{v_{\alpha}(t) \}$ and $\{v_{\alpha\beta}(t) \}$. Expanding both sides of the implicit equation for $\boldsymbol{z}^{\ell}$ we have

Equation (P.8)

Performing a similar exercise for $\boldsymbol{h}^{\ell}$, we find the following first three leading terms for $\boldsymbol{z}^{\ell}$ and $\boldsymbol{h}^{\ell}$

Equation (P.9)

As will become apparent soon, it is crucially important to identify the dependence of each of these terms on r . We note that $z^{\ell,1}$ does not depend on r and $h^{\ell,1}$ is linear in r. In the next section, we use this fact to show that $\Phi^{\ell,1} = 0$ and $G^{\ell,1} = 0$. These conditions imply that $C^{\ell,0}$ and $D^{\ell,1} = 0$. As a consequence, $z^{\ell,2}$ is linear in r and $h^{\ell,2}$ only contains even powers of r. Lastly, this implies that $z^{\ell,3}$ only contains even powers of r and $h^{\ell,3}$ contains only odd powers of r.

P.2.1. Leading correction to the $\Phi^1$ kernel is $\mathcal{O}(\gamma_0^2)$.

We start in the first layer where $\boldsymbol{u}^1 \sim \mathcal{GP}(0,\boldsymbol{K}^x \otimes \boldsymbol{1} \boldsymbol{1}^{\top})$ (note that this is $\mathcal{O}_{\gamma_0}(1)$) and compute the expansion of $\Phi^1$ in γ0

Equation (P.10)

where powers and multiplications of vectors are taken elementwise. Now, note that, as promised, the terms linear in γ0 vanish since $\boldsymbol{h}^{1,1}$ is linear in the Gaussian random variable $\boldsymbol{r}^1$, which is mean zero and independent of $\boldsymbol{u}^1$, so an average like $\left\langle \boldsymbol{r}^{1} F(\boldsymbol{u}^{1}) \right\rangle = \left\langle \boldsymbol{r}^{1} \right\rangle \left\langle F(\boldsymbol{u}^{1}) \right\rangle = 0$ must vanish for any function F. Thus we see that $\boldsymbol{\Phi}^{1}$'s leading correction is $\mathcal{O}(\gamma_0^2)$.

We also obtain, by a similar argument, that the cubic $\mathcal{O}(\gamma_0^3)$ term vanishes. To see this, note that $\boldsymbol{h}^{1,3}$ only contains odd powers of $\boldsymbol{r}^1$. Next, $\boldsymbol{h}^{1,1} \boldsymbol{h}^{1,2}$ contains only odd powers of $\boldsymbol{r}^1$, and $(\boldsymbol{h}^{1,1})^3$ is cubic in $\boldsymbol{r}^1$. Since all odd moments of a mean-zero Gaussian vanish, all averages of these terms over $\boldsymbol{r}^1$ are zero, causing the $\gamma_0^3$ terms to vanish. Thus $\boldsymbol{\Phi}^{1} = \boldsymbol{\Phi}^{1,0} + \gamma_0^2 \boldsymbol{\Phi}^{1,2} + \mathcal{O}(\gamma_0^4)$.
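The key step above, that averages of terms odd in the independent, mean-zero Gaussian field $r^1$ vanish, is easy to confirm by Monte Carlo. A minimal sketch with an arbitrary test function F (unit variances are sufficient for the illustration):

```python
import numpy as np

rng = np.random.default_rng(9)
S = 2_000_000
u = rng.standard_normal(S)          # stands in for u^1
r = rng.standard_normal(S)          # stands in for r^1: mean zero and independent of u

F = np.tanh(u) + u**2               # an arbitrary function F(u)
print(np.mean(r * F))               # <r F(u)> -> 0 up to O(1/sqrt(S)) Monte Carlo error
print(np.mean(r**3 * F))            # odd moments of r likewise give vanishing averages
```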

P.3. Forward pass induction for $\Phi^{\ell}$

We now assume, as the inductive hypothesis, that for some $\ell \in \{1,\ldots ,L-1\}$

Equation (P.11)

and we will show that this implies that the next layer has a similar expansion $\boldsymbol{\Phi}^{\ell+1} = \boldsymbol{\Phi}^{\ell+1,0} + \gamma_0^2 \boldsymbol{\Phi}^{\ell+1,2} + \mathcal{O}(\gamma_0^4)$. First, we note that $\boldsymbol{u}^{\ell+1} \sim \mathcal{GP}(0, \boldsymbol{\Phi}^{\ell,0} + \gamma_0^2 \boldsymbol{\Phi}^{\ell,2} + \ldots )$. As before, we compute the leading terms in the expansion of $\boldsymbol{\Phi}^{\ell+1}$

Equation (P.12)

where, as before, the $\gamma_0$ and $\gamma_0^3$ terms vanish by the fact that odd moments of $\boldsymbol{r}^{\ell+1}$ vanish. Now, note that all averages are performed over $\boldsymbol{u}^{\ell+1} \sim \mathcal{GP}(0,\boldsymbol{\Phi}^{\ell,0} + \gamma_0^2 \boldsymbol{\Phi}^{\ell,2} + \ldots )$, which depends on the perturbed kernel of the previous layer. How can we calculate the contribution of the correction which is due to the previous layer's kernel movement? This can be obtained easily from the following identity. Let $F(\boldsymbol{u},\boldsymbol{r})$ be an arbitrary observable which depends on Gaussian fields u and r with covariances $\boldsymbol{\Phi}^{\ell,0} + \gamma_0^2 \boldsymbol{\Phi}^{\ell,2} + \mathcal{O}(\gamma_0^4)$ and $\boldsymbol{G}^{\ell+2,0} + \gamma_0^2 \boldsymbol{G}^{\ell+2,2} + \mathcal{O}(\gamma_0^3)$ respectively (note this only requires that the linear-in-$\gamma_0$ terms of G vanish, which is easy to verify). Then

Equation (P.13)

Equation (P.14)

where $\boldsymbol{u}_0 \sim \mathcal{GP}(0,\boldsymbol{\Phi}^{\ell,0}) , \boldsymbol{r}_0 \sim \mathcal{N}(0,\boldsymbol{G}^{\ell+2,0})$. Thus, the leading order behavior of $\boldsymbol{\Phi}^{\ell+1}$ can easily be obtained in terms of averages over the original unperturbed covariances

Equation (P.15)

where the trace is taken against the Hessian indices and the indices on $\boldsymbol{\Phi}^{\ell,2}$. This gives the desired result by induction: for all $\ell \in \{1,\ldots ,L\}$, we have $\boldsymbol{\Phi}^{\ell} = \boldsymbol{\Phi}^{\ell,0} + \gamma_0^2 \boldsymbol{\Phi}^{\ell,2} + \mathcal{O}(\gamma_0^4)$. We see that $\boldsymbol{\Phi}^{\ell}$ accumulates corrections from the previous layers through the forward pass recursion.
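The covariance-perturbation identity in equations (P.13) and (P.14) is easy to verify numerically. The sketch below compares a Monte Carlo average under a perturbed Gaussian with the zeroth-order average plus the trace of the perturbation against the averaged Hessian; all choices of covariance, perturbation, observable, and variable names are illustrative stand-ins for $\boldsymbol{\Phi}^{\ell,0}$, $\gamma_0^2\boldsymbol{\Phi}^{\ell,2}$ and $F$.

import numpy as np

rng = np.random.default_rng(1)
S, eps = 1_000_000, 0.05                      # samples; eps plays the role of gamma_0^2

Sigma0 = np.array([[1.0, 0.3],                # stand-in for Phi^{l,0}
                   [0.3, 1.0]])
Delta  = np.array([[0.2, 0.1],                # stand-in for Phi^{l,2}
                   [0.1, 0.1]])

def F(u):                                     # an arbitrary smooth observable
    return np.tanh(u[0]) * np.tanh(u[1])

def hessian_F(u):                             # analytic Hessian of F, used in the trace term
    t1, t2 = np.tanh(u[0]), np.tanh(u[1])
    d1, d2 = 1 - t1**2, 1 - t2**2             # tanh'
    H = np.empty((2, 2) + u.shape[1:])
    H[0, 0] = -2 * t1 * d1 * t2               # d^2 F / du1^2
    H[1, 1] = -2 * t2 * d2 * t1               # d^2 F / du2^2
    H[0, 1] = H[1, 0] = d1 * d2               # d^2 F / du1 du2
    return H

z = rng.standard_normal((2, S))               # common random numbers for both averages
u0    = np.linalg.cholesky(Sigma0) @ z
u_eps = np.linalg.cholesky(Sigma0 + eps * Delta) @ z

lhs = F(u_eps).mean()                                         # <F> under perturbed covariance
rhs = F(u0).mean() + 0.5 * eps * np.sum(Delta * hessian_F(u0).mean(axis=-1))
print(lhs, rhs)                                               # agree up to O(eps^2) and MC noise

The factor of one half and the trace over both the Hessian indices and the indices of the perturbation mirror the structure described after equation (P.15).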

P.4. Leading correction to the $G^L$ kernel is $\mathcal{O}(\gamma_0^2)$

The analogous argument can now be given for $\boldsymbol{G}^{L}$. First, note that $\boldsymbol{r}^{L}$ is independent of $\boldsymbol{u}^{L}$ and of $\gamma_0$. Thus $\boldsymbol{G}^{L}$ has no linear-in-$\gamma_0$ term in its expansion since

Equation (P.16)

each term contains only odd powers of $\boldsymbol{r}^{L}$, and odd moments of Gaussian variables vanish. With more work, one can verify that $\boldsymbol{G}^{L,3}$ must also vanish since all of its terms contain odd powers of $\boldsymbol{r}$.

Equation (P.17)

First, note that $\boldsymbol{g}^{L,0}$ is linear in $\boldsymbol{r}$. Next, note that $\boldsymbol{g}^{L,1}$ depends only on even powers of $\boldsymbol{r}$ since $\boldsymbol{g}^{L,1} = \dot\phi(\boldsymbol{u}) \boldsymbol{z}^{L,1} + \ddot\phi(\boldsymbol{u}) \boldsymbol{h}^{L,1} \boldsymbol{r}$. Next, we have $\boldsymbol{g}^{L,2}$

Equation (P.18)

which depends only on odd powers of $\boldsymbol{r}$. Lastly, we have $\boldsymbol{g}^{L,3}$

Equation (P.19)

which contains only even powers of $\boldsymbol{r}$. Thus the product $\boldsymbol{g}^{L,3} \boldsymbol{g}^{L,0}$ is odd in $\boldsymbol{r}$. Looking at the expansion for $\boldsymbol{G}^{L,3}$, we see that all terms are odd in $\boldsymbol{r}$ and so their averages vanish under the Gaussian integrals.
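Schematically, if one writes the expansion of $\boldsymbol{G}^{L}$ so that the order-$\gamma_0^{k}$ coefficient collects the averaged products $\boldsymbol{g}^{L,a}(\boldsymbol{g}^{L,b})^{\top}$ with $a+b = k$ (the standard Taylor structure; the precise prefactors are in the equations above), the parities just established give

\[
\underbrace{\boldsymbol{g}^{L,0}(\boldsymbol{g}^{L,1})^{\top}}_{\text{odd}\,\times\,\text{even}} \;\;(k=1), \qquad
\underbrace{\boldsymbol{g}^{L,0}(\boldsymbol{g}^{L,3})^{\top}}_{\text{odd}\,\times\,\text{even}}, \;\;
\underbrace{\boldsymbol{g}^{L,1}(\boldsymbol{g}^{L,2})^{\top}}_{\text{even}\,\times\,\text{odd}} \;\;(k=3),
\]

all of which are odd in $\boldsymbol{r}$ and therefore average to zero, while the $k=2$ products $\boldsymbol{g}^{L,0}(\boldsymbol{g}^{L,2})^{\top}$ and $\boldsymbol{g}^{L,1}(\boldsymbol{g}^{L,1})^{\top}$ are even in $\boldsymbol{r}$ and generically survive. This is why the leading correction to $\boldsymbol{G}^{L}$ is $\mathcal{O}(\gamma_0^{2})$.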

P.5. Backward pass recursion for $G^{\ell}$

We can derive a similar recursion on the backward pass for the leading-order corrections to $\boldsymbol{G}^{\ell}$. Using the same idea as in the previous section, we find the following expressions

This time, we see that $\boldsymbol{G}^{\ell}$ accumulates corrections from succeeding layers through the backward pass recursion.

P.6. Form of the leading corrections

We can expand the $\boldsymbol{h}^{\ell}$ and $\boldsymbol{z}^{\ell}$ fields around $\boldsymbol{u}^{\ell,0},\boldsymbol{r}^{\ell,0}$ to find the leading order corrections to each feature kernel

Equation (P.20)

The first term requires additional expansion to extract the $\mathcal{O}(\gamma_0^2)$ corrections

Equation (P.21)

where we used the fact that $\boldsymbol{C}^{\ell,1} = 0$ (which follows from $\boldsymbol{\Phi}^{\ell-1,1} = 0$) and $\boldsymbol{\Delta}^{\ell,1} = 0$. Now, expanding out term by term

Equation (P.22)

We see that the corrections to the $\boldsymbol{\Phi}^{\ell}$ kernels accumulate on the forward pass through the final term, so that $\boldsymbol{\Phi}^{\ell,2} \sim \mathcal{O}(\ell)$. Now we perform the same analysis for $\boldsymbol{G}^{\ell}$.

Equation (P.23)

We see that, through the second term, the $\boldsymbol{G}^{\ell}$ kernels accumulate corrections on the backward pass, so that $\boldsymbol{G}^{\ell,2} \sim \mathcal{O}(L+1-\ell)$. As before, the difficult term is the first expression, which requires a full expansion of $\boldsymbol{g}^{\ell}$ to second order

Equation (P.24)

From these terms we find

Equation (P.25)
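The depth scalings quoted above follow from the schematic structure of these recursions (the precise local terms are given in equations (P.22) and (P.25); the form below is only a sketch consistent with the trace structure described after equation (P.15)):

\[
\boldsymbol{\Phi}^{\ell,2} \;=\; \underbrace{\big(\text{local term at layer } \ell\big)}_{\mathcal{O}(1)} \;+\; \underbrace{\tfrac{1}{2}\,\mathrm{Tr}\!\left[\boldsymbol{\Phi}^{\ell-1,2}\,\left\langle \nabla_{\boldsymbol{u}}\nabla_{\boldsymbol{u}}\,\phi(\boldsymbol{u}^{\ell,0})\phi(\boldsymbol{u}^{\ell,0})^{\top}\right\rangle\right]}_{\text{inherited from layer } \ell-1}, \qquad
\boldsymbol{G}^{\ell,2} \;=\; \big(\text{local term}\big) \;+\; \big(\text{term inherited from } \boldsymbol{G}^{\ell+1,2}\big).
\]

Unrolling the first recursion upward from layer 1 gives $\boldsymbol{\Phi}^{\ell,2} \sim \mathcal{O}(\ell)$, and unrolling the second downward from layer $L$ gives $\boldsymbol{G}^{\ell,2} \sim \mathcal{O}(L+1-\ell)$, as stated.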

Now the correction to the NTK has the form

Equation (P.26)

Since each $\boldsymbol{\Phi}^{\ell,2}, \boldsymbol{G}^{L+1-\ell,2} \sim \mathcal{O}(\ell)$, each of the two sums over $\ell \in \{1,\ldots ,L-1\}$ gives a depth scaling of the form $\sim \sum_{\ell = 1}^{L-1} \ell = \frac{L(L-1)}{2}$. Since the original NTK has scale $\boldsymbol{K}^{\mathrm{NTK},0} \sim \mathcal{O}(L)$, the relative change in the kernel is $\frac{|\boldsymbol{K}^2|}{|\boldsymbol{K}^0|} = \mathcal{O}(\gamma_0^2 L)$. For a network of finite width $N$, our definition $\gamma = \gamma_0 \sqrt{N}$ indicates that the corrections have scale $\gamma_0^2 L = \frac{\gamma^2 L}{N}$ in the NTK regime where $\gamma = \mathcal{O}_N(1)$, provided the network is sufficiently wide that initialization-dependent fluctuations in the kernels can be neglected.
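For concreteness (with hypothetical numbers, not taken from the experiments in the main text): since

\[
\sum_{\ell=1}^{L-1} \ell \;=\; \frac{L(L-1)}{2}, \qquad
\frac{|\boldsymbol{K}^{2}|}{|\boldsymbol{K}^{0}|} \;\sim\; \gamma_0^{2}\,\frac{L(L-1)/2}{L} \;=\; \mathcal{O}(\gamma_0^{2} L) \;=\; \mathcal{O}\!\left(\frac{\gamma^{2} L}{N}\right),
\]

a network with depth $L = 4$, width $N = 1000$ and $\gamma = \mathcal{O}_N(1)$, say $\gamma = 1$, would be expected to move its NTK by a relative amount of order $\gamma^{2}L/N = 4\times 10^{-3}$, whereas at $\gamma_0 = 1$ (i.e. $\gamma = \sqrt{N}$) the correction $\gamma_0^{2}L = 4$ is no longer small and the perturbative expansion is uncontrolled.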
