Biased stochastic approximation#
We explore the convergence of continuous-time ordinary differential equations and their discrete-time analogs, such as stochastic approximation and gradient descent, through the lens of Lyapunov theory [B+98] [LR15]. From this perspective, we will study biased stochastic approximation [KMMW19] where the expectation of the stochastic updates conditioned on the past (which we call the conditional expectation) is not the same as the expectation of the stochastic updates under the stationary distribution (which we call the total expectation).
This post is a continuation of our series on spiking networks, path integrals and motivic information.
Continuous-time dynamics#
Suppose we have a continuous-time dynamical system with state space \(\mathcal{H} \subset \mathbb{R}^d\) and update rule
\[ \frac{d\eta}{dt}(t) = -h(\eta(t)), \]
and suppose we have a function \(V: \mathcal{H} \rightarrow \mathbb{R}\) that has a finite lower bound. We are primarily interested in the conditions that will lead to convergence of the system to two kinds of points, namely
To a fixed point, i.e. \(h(\eta) = 0,\)
To a critical point, i.e. \(\nabla V(\eta) = 0.\)
In the special case where \(h(\eta) = \nabla V(\eta),\) the two kinds of points are the same.
If the value of \(V(\eta(t))\) along any path \(\eta : \mathbb{R} \rightarrow \mathcal{H}\) in the dynamical system decreases strictly with time \(t\), then we say that \(V\) is a Lyapunov function for the dynamical system. Here, the system converges to a point that is both a fixed point of the system and a critical point of \(V\).
Suppose \(V\) is \(C^1\)-smooth and near each point \(\eta\) it has the expansion
\[ V(\eta') = V(\eta) + \langle \nabla V(\eta), \eta' - \eta \rangle + o(\Vert \eta' - \eta \Vert). \]
In continuous time, the updates \(\eta'-\eta\) are infinitesimal, so the change in \(V\) is dominated by the first-order term. Therefore, if \(0 < \langle \nabla V(\eta), h(\eta) \rangle\) for all \(\eta \in \mathcal{H}\) away from the zeros of \(h,\) then \(V\) is a Lyapunov function for the dynamical system.
In the special case where \(h(\eta) = \nabla V(\eta),\) we have
\[ \langle \nabla V(\eta), h(\eta) \rangle = \Vert \nabla V(\eta) \Vert^2, \]
which is strictly positive away from the zeros of \(h.\)
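As a quick numerical sanity check (my own toy example, not taken from the references), we can integrate the gradient flow for a simple quadratic \(V\) with a small Euler step and watch \(V\) decrease strictly along the trajectory:

```python
import numpy as np

# Toy example: gradient flow d(eta)/dt = -h(eta) with h = grad V for the
# quadratic Lyapunov function V(eta) = ||eta||^2 / 2, integrated with a
# small forward-Euler step to mimic continuous time.

def V(eta):
    return 0.5 * float(eta @ eta)

def h(eta):
    return eta  # grad V(eta) = eta for the quadratic case

eta = np.array([3.0, -4.0])
dt = 1e-3
values = [V(eta)]
for _ in range(5000):
    eta = eta - dt * h(eta)  # Euler step of d(eta)/dt = -h(eta)
    values.append(V(eta))

# V decreases strictly along the path, matching dV/dt = -||grad V||^2 < 0.
monotone = all(b < a for a, b in zip(values, values[1:]))
```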
Discrete-time dynamics#
For discrete-time dynamical systems such as gradient descent or stochastic approximation, the situation is more delicate. The value of the function \(V: \mathcal{H} \rightarrow \mathbb{R}\) may fluctuate up and down with each discrete time step, but if the total expectation of \(\Vert h(\eta_n) \Vert\) or of \(\Vert \nabla V(\eta_n) \Vert\) decreases over time, the system still converges to a fixed point or to a critical point.
Specifically, suppose that we have a discrete-time stochastic process
\[ \eta_{n+1} = \eta_n - \gamma_{n+1} H_{\eta_n}(X_{n+1}), \]
where \(\gamma_{n+1}\) is the time-dependent learning rate and \(H_{\eta_n}\) is the update at time \((n+1)\) that depends on the previous system state \(\eta_n.\) The update could be deterministic as in the case of gradient descent, or more generally it could be stochastic as in the case of stochastic approximation.
We will assume that any stochastic update \(H_{\eta_n}(X_{n+1})\) is a function of some parameter-controlled stochastic process \(\{X_n\}\), i.e. the distribution of \(X_{n+1}\) conditioned on its past \(X_n, X_{n-1}, \ldots\) is controlled by the parameter \(\eta_n.\) We will also assume that for each \(\eta \in \mathcal{H}\), the \(\eta\)-controlled stochastic process \(\{X_n\}\) has a unique stationary distribution \(\pi_\eta\). Let \(h(\eta)\) be the total expectation or mean field of the \(\eta\)-controlled stochastic process, i.e. the expectation of the stochastic updates \(H_{\eta}(x)\) with respect to \(x \sim \pi_\eta.\)
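To make this concrete, here is a minimal sketch (my example, using i.i.d. samples so the stationary distribution is immediate): take \(H_\eta(x) = \eta - x\) with \(x \sim \mathcal{N}(\mu, 1)\), so the mean field is \(h(\eta) = \eta - \mu\) and the unique fixed point is \(\eta = \mu\).

```python
import numpy as np

# Minimal stochastic approximation sketch (toy example): H_eta(x) = eta - x
# with i.i.d. X_n ~ N(mu, 1), so the mean field is h(eta) = eta - mu and
# the fixed point h(eta) = 0 is eta = mu.

rng = np.random.default_rng(0)
mu = 2.0
eta = 0.0
for n in range(1, 200001):
    x = mu + rng.standard_normal()  # sample X_{n+1} from its stationary law
    gamma = 1.0 / n                 # decreasing step size gamma_{n+1}
    eta = eta - gamma * (eta - x)   # eta_{n+1} = eta_n - gamma * H_{eta_n}(X_{n+1})

# eta is now close to the fixed point mu = 2 of the mean field.
```

With the step size \(\gamma_n = 1/n\) this recursion is exactly the running sample mean, which is the simplest instance of the scheme.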
As before, let \(V: \mathcal{H} \rightarrow \mathbb{R}\) be a function with a finite lower bound. We will again be interested in conditions that guarantee the convergence of the discrete-time dynamical system to two kinds of points, namely
To a fixed point, i.e. \(h(\eta) = 0,\)
To a critical point, i.e. \(\nabla V(\eta) = 0.\)
The general strategy for proving the convergence of the system is to show that the total expectation of \(V(\eta_n)\) is eventually a decreasing function of \(n\). To show that this decrease happens, it is often sufficient to check that for all \(\eta', \eta,\)
\[ V(\eta') \leq V(\eta) + \langle \nabla V(\eta), \eta' - \eta \rangle + \frac{\ell}{2} \Vert \eta' - \eta \Vert^2 \]
for some constant \(\ell>0.\) When the domain \(\mathcal{H}\) is convex, this condition is equivalent to the \(\ell\)-smoothness of \(V,\) i.e. for all \(\eta', \eta,\)
\[ \Vert \nabla V(\eta') - \nabla V(\eta) \Vert \leq \ell \, \Vert \eta' - \eta \Vert, \]
which is to say that \(\nabla V\) is Lipschitz continuous.
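For a concrete check of the descent inequality (my illustration, assuming a quadratic \(V\) whose gradient is Lipschitz with constant equal to the largest eigenvalue of its Hessian):

```python
import numpy as np

# Check V(eta') <= V(eta) + <grad V(eta), eta' - eta> + (l/2) ||eta' - eta||^2
# for the quadratic V(eta) = 0.5 eta^T A eta, where grad V = A eta is
# Lipschitz with constant l = largest eigenvalue of A.

rng = np.random.default_rng(1)
A = np.array([[2.0, 0.5],
              [0.5, 1.0]])              # symmetric positive definite Hessian
l = float(np.linalg.eigvalsh(A).max())  # Lipschitz constant of grad V

def V(eta):
    return 0.5 * float(eta @ A @ eta)

def grad_V(eta):
    return A @ eta

ok = True
for _ in range(1000):
    eta = rng.standard_normal(2)
    etap = rng.standard_normal(2)
    d = etap - eta
    lhs = V(etap)
    rhs = V(eta) + float(grad_V(eta) @ d) + 0.5 * l * float(d @ d)
    ok = ok and lhs <= rhs + 1e-12      # tiny slack for floating point
```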
Going further, we may want to prove some results about the speed of convergence. First, let us represent the update as
\[ H_{\eta_n}(X_{n+1}) = h(\eta_n) + E_{\eta_n}(X_{n+1}), \]
the sum of the mean field and a correction term \(E_{\eta_n}(X_{n+1})\). Substituting this representation into the \(\ell\)-smoothness condition, we get
\[ V(\eta_{n+1}) \leq V(\eta_n) - \gamma_{n+1} \langle \nabla V(\eta_n), h(\eta_n) \rangle - \gamma_{n+1} \langle \nabla V(\eta_n), E_{\eta_n}(X_{n+1}) \rangle + \frac{\ell \gamma_{n+1}^2}{2} \Vert h(\eta_n) + E_{\eta_n}(X_{n+1}) \Vert^2. \]
To manage the speed of convergence, one effective approach is to bound the terms involving \(\langle \nabla V(\eta_n), h(\eta_n) \rangle,\) \(\langle \nabla V(\eta_n), E_{\eta_n}(X_{n+1}) \rangle\) and \(\Vert E_{\eta_n}(X_{n+1})\Vert^2\) by some scalar multiple of \(\Vert h(\eta_n) \Vert^2.\)
Starting with \(\langle \nabla V(\eta_n), h(\eta_n) \rangle,\) we could require that for all \(\eta,\)
\[ \Vert h(\eta) \Vert^2 \leq C \, \langle \nabla V(\eta), h(\eta) \rangle \]
for some constant \(C > 0.\) This condition is automatically satisfied for gradient dynamical systems where \(h(\eta) = \nabla V(\eta)\): the inequality becomes an equality with \(C=1.\)
As for \(\Vert E_{\eta_n}(X_{n+1})\Vert^2,\) we could require that this correction term be uniformly bounded, or that its conditional expectation (i.e. expected value when conditioned on the past \(X_n, X_{n-1}, \ldots\)) be bounded by a scalar multiple of \(\Vert h(\eta_n) \Vert^2.\)
Lastly, for \(\langle \nabla V(\eta_n), E_{\eta_n}(X_{n+1}) \rangle,\) recall that
\[ E_{\eta_n}(X_{n+1}) = H_{\eta_n}(X_{n+1}) - h(\eta_n) \quad \text{where} \quad h(\eta) = \mathbb{E}_{x \sim \pi_\eta}\left[ H_\eta(x) \right], \]
so the total expectation of the correction \(E_{\eta_n}(X_{n+1})\) is zero by definition. If we further require that the conditional expectation of \(H_{\eta_n}(X_{n+1})\) be equal to the mean field \(h(\eta_n),\) then the conditional expectation of \(E_{\eta_n}(X_{n+1})\) will be zero as well. In this case, the expectation of \(\langle \nabla V(\eta_n), E_{\eta_n}(X_{n+1}) \rangle\) vanishes when we condition on the past, and we get a good handle on the speed of convergence. We refer to this scenario as unbiased stochastic approximation.
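As a small Monte Carlo illustration of the unbiased case (my example, reusing \(H_\eta(x) = \eta - x\) with i.i.d. Gaussian samples), the correction \(E_\eta(X) = H_\eta(X) - h(\eta)\) averages to zero:

```python
import numpy as np

# Unbiased case illustration: with i.i.d. X ~ N(mu, 1), the conditional law
# of each sample is already the stationary law, so the correction
# E_eta(X) = H_eta(X) - h(eta) = mu - X has zero (conditional) expectation.

rng = np.random.default_rng(6)
mu, eta = 2.0, 5.0
x = mu + rng.standard_normal(100000)
corrections = (eta - x) - (eta - mu)  # H_eta(X) - h(eta) = mu - X
mean_correction = float(corrections.mean())
```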
However, the conditional expectation of \(E_{\eta_n}(X_{n+1})\) is often non-zero in many important applications. We refer to this scenario as biased stochastic approximation. Here, the object of interest is the discrete-time stochastic integral
\[ \sum_{k=0}^{n} \gamma_{k+1} \langle \nabla V(\eta_k), E_{\eta_k}(X_{k+1}) \rangle. \]
If the conditional expectation of \(E_{\eta_n}(X_{n+1})\) is zero, then the above stochastic integral is a martingale. To tackle the scenario where the conditional expectation of \(E_{\eta_n}(X_{n+1})\) is non-zero, we will need to find another suitable martingale by solving a Poisson equation, so as to control the behavior of the stochastic integral. We introduce this strategy in the next section.
Martingales, stochastic integrals and the Poisson equation#
Given a (discrete-time or continuous-time) stochastic process \(\{H_t\}_{0 \leq t},\) let \(\mathcal{F}_s\) denote the filtration (i.e. sequence of sigma algebras with \(\mathcal{F}_s \subseteq \mathcal{F}_t\) for all \(s \leq t\)) generated by the random variables \(\{H_t\}_{0\leq t \leq s}.\) Recall that \(H_t\) is a martingale if for all \(s \leq t,\)
\[ \mathbb{E}\left[ H_t \mid \mathcal{F}_s \right] = H_s. \]
Note that this condition implies that for discrete-time stochastic processes, the expectation of the martingale difference \(E_{n+1} = H_{n+1} - H_n\) conditioned on \(\mathcal{F}_n\) is zero.
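A quick simulation (my example) of the discrete-time statement: a simple symmetric random walk is a martingale, so \(\mathbb{E}[H_n] = H_0 = 0\) for every \(n\):

```python
import numpy as np

# Simple random walk H_n = sum of i.i.d. +/-1 steps, a martingale with
# H_0 = 0; the martingale property forces E[H_n] = 0 for all n.

rng = np.random.default_rng(2)
steps = rng.choice([-1.0, 1.0], size=(100000, 50))  # 100k paths, 50 steps each
H = steps.cumsum(axis=1)                            # H_1, ..., H_50 per path
mean_path = H.mean(axis=0)                          # Monte Carlo estimate of E[H_n]
max_dev = float(np.abs(mean_path).max())            # should be near zero
```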
Martingales play an important role in stochastic integration [CChafaiG12]. For example, let \(\{X_t\}_{0\leq t}\) be a continuous-time stochastic process driven by Brownian motion, e.g.
\[ dX_t = b(X_t)\, dt + \sigma(X_t)\, dB_t, \]
where \(X_t\) is an \(n\)-dimensional vector, \(B_t\) is \(m\)-dimensional Brownian motion, \(b(X_t)\) is an \(n\)-dimensional vector and \(\sigma(X_t)\) is an \(n \times m\) matrix. Let \(L\) be the infinitesimal generator of this process, i.e. \(L\) acts on the space of measurable functions such that for any measurable function \(g\) and state \(x,\)
\[ Lg(x) = \lim_{t \downarrow 0} \frac{P^t g(x) - g(x)}{t}, \]
where \(P^t\) is the transition operator of the stochastic process \(\{X_t\}.\)
Then, one can show that
\[ g(X_t) = g(X_0) + \int_0^t Lg(X_s)\, ds + \int_0^t \nabla g(X_s)^\top \sigma(X_s)\, dB_s, \]
thanks to the Itô formula, e.g. see Lemma 7.3.2 of [Oks13].
The last term
\[ M_t = \int_0^t \nabla g(X_s)^\top \sigma(X_s)\, dB_s \]
is a martingale, so its expected value conditioned on \(X_0\) is the initial value \(M_0 = 0.\) Therefore,
\[ \mathbb{E}\left[ g(X_t) \mid X_0 = x \right] = g(x) + \mathbb{E}\left[ \int_0^t Lg(X_s)\, ds \,\middle|\, X_0 = x \right], \]
after taking expectations of the Itô formula.
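As a Monte Carlo sanity check (my example, not from [Oks13]): for standard Brownian motion the generator is \(L = \tfrac{1}{2} \frac{d^2}{dx^2}\), so taking \(g(x) = x^2\) gives \(Lg = 1\) and hence \(\mathbb{E}[g(X_t) \mid X_0 = x] = x^2 + t\):

```python
import numpy as np

# Dynkin-formula check for standard Brownian motion started at x0:
# X_t = x0 + B_t, generator L = (1/2) d^2/dx^2, g(x) = x^2, so Lg = 1 and
# E[g(X_t) | X_0 = x0] = g(x0) + E[ integral_0^t Lg(X_s) ds ] = x0^2 + t.

rng = np.random.default_rng(3)
x0, t, n_paths = 1.0, 1.0, 200000
X_t = x0 + np.sqrt(t) * rng.standard_normal(n_paths)  # exact law of x0 + B_t
mc = float((X_t ** 2).mean())                         # Monte Carlo E[g(X_t)]
exact = x0 ** 2 + t
```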
Interestingly, this formula gives us a strategy for computing expectations of stochastic integrals of the form
\[ \mathbb{E}\left[ \int_0^t f(X_s)\, ds \,\middle|\, X_0 = x \right]. \]
Indeed, if we are able to solve the Poisson equation
\[ Lg = f \]
with some solution \(g\) that is sufficiently regular, then we can use the above formula to compute the desired answer.
Under some regularity conditions on \(\{X_t\}\) and assuming that the total expectation of \(g(x)\) is zero, a candidate solution to the Poisson equation is
\[ g(x) = -\int_0^\infty P^t f(x)\, dt, \]
if the integral is well-defined [CChafaiG12]. Indeed, if \(Lg = f,\) then
\[ P^t g(x) - g(x) = \int_0^t P^s f(x)\, ds. \]
Taking limits as \(t \rightarrow \infty\) and observing that under strong ergodicity \(P^t g(x)\) goes to the total expectation of \(g(x)\) which is zero, the candidate solution follows.
The strategy for integrals of functions of discrete-time stochastic processes works in a similar way, where solutions to the Poisson equation provide martingales whose expected values vanish. We will use this strategy for analyzing biased stochastic approximation algorithms.
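Here is a finite-state sketch of that discrete-time version (my example): for a chain with transition matrix \(P\) and generator \(L = P - I\), the truncated series \(g = -\sum_{k \geq 0} P^k f\) solves \(Lg = f\) when \(f\) is centered under the stationary distribution:

```python
import numpy as np

# Discrete-time Poisson equation on a 2-state chain: generator L = P - I,
# candidate solution g = -(f + P f + P^2 f + ...) for f with pi @ f = 0.

P = np.array([[0.9, 0.1],
              [0.2, 0.8]])

# stationary distribution: normalized left eigenvector of P for eigenvalue 1
evals, evecs = np.linalg.eig(P.T)
pi = np.real(evecs[:, np.argmax(np.real(evals))])
pi = pi / pi.sum()

f = np.array([1.0, -1.0])
f = f - (pi @ f)            # center f so its total expectation is zero

g = np.zeros(2)
Pk_f = f.copy()
for _ in range(2000):       # truncate the series; P^k f decays geometrically
    g -= Pk_f
    Pk_f = P @ Pk_f

residual = float(np.abs((P @ g - g) - f).max())  # how well L g = f holds
```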
In a future post, we will explore martingales, stochastic integrals and the Poisson equation through the lens of regularity structures [Hai14].
Biased updates#
For the rest of this post, we assume that our discrete-time parameter-controlled stochastic process \(\{X_n\}_{0 \leq n}\) is Markov. Explicitly, let \((\mathcal{X}, \Sigma)\) be a measurable space. A function \(P:\mathcal{X} \times \Sigma \rightarrow [0,1]\) is a Markov kernel if \(P(x, \cdot): \Sigma \rightarrow [0,1]\) is a distribution for all \(x \in \mathcal{X}\), and \(P(\cdot, A): \mathcal{X} \rightarrow [0,1]\) is measurable for all \(A \in \Sigma.\) A Markov kernel generalizes the notion of a transition matrix beyond finite-state Markov chains.
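For a finite state space, a Markov kernel is just a row-stochastic matrix; a minimal illustration (my example):

```python
import numpy as np

# Finite-state Markov kernel: P(x, .) is row x of a transition matrix, so
# every row must be a probability distribution over the states.

P = np.array([[0.5, 0.5, 0.0],
              [0.1, 0.6, 0.3],
              [0.0, 0.4, 0.6]])

rows_are_distributions = bool(np.all(P >= 0.0) and np.allclose(P.sum(axis=1), 1.0))

# The kernel acts on functions by averaging over the next state:
# (P g)(x) = sum_y P(x, y) g(y), the one-step conditional expectation.
g = np.array([1.0, 2.0, 3.0])
Pg = P @ g
```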
Suppose we have a Markov kernel \(P_\eta\) for each \(\eta \in \mathcal{H}\). Let \(L_\eta\) be the generator of the process with Markov kernel \(P_\eta\), i.e. \(L_\eta\) acts on the space of measurable functions such that for any function \(g\) and state \(x,\)
\[ L_\eta g(x) = P_\eta g(x) - g(x), \quad \text{where} \quad P_\eta g(x) = \int_{\mathcal{X}} g(y)\, P_\eta(x, dy). \]
We also assume that for all \(\eta \in \mathcal{H}\), the Markov kernel \(P_\eta\) has a unique stationary distribution \(\pi_\eta\), i.e. for all \(A \in \Sigma,\)
\[ \pi_\eta(A) = \int_{\mathcal{X}} P_\eta(x, A)\, \pi_\eta(dx). \]
Let \(X_0, X_1, \ldots\) be random variables on this space such that for all bounded measurable functions \(\varphi\) and integers \(n \geq 0,\) we have
\[ \mathbb{E}\left[ \varphi(X_{n+1}) \mid \mathcal{F}_n \right] = P_{\eta_n} \varphi(X_n), \]
where \(\mathcal{F}_n\) is the filtration generated by \(X_0, \ldots, X_n\) and the parameter updates are given by
\[ \eta_{n+1} = \eta_n - \gamma_{n+1} H_{\eta_n}(X_{n+1}) \]
as before.
as before. We will primarily be interested in the stochastic integral
For simplicity, we first fix the parameters \(\eta_n = \eta\) for all \(n\) and drop the subscripts \(\eta\) in notations such as \(E_\eta,\) \(P_\eta\) and \(L_\eta\). Later, we will generalize the approach to the case with parameter updates [KMMW19].
Suppose that we have a solution \(\hat{H}\) to the Poisson equation \(L \hat{H} = E.\) Then
\[ \sum_{k=0}^{n} E(X_{k+1}) = P\hat{H}(X_{n+1}) - P\hat{H}(X_0) + \sum_{k=0}^{n} \left( P\hat{H}(X_k) - \hat{H}(X_{k+1}) \right). \]
The last sum is a martingale, because each of the summands \(P\hat{H}(X_{k}) - \hat{H}(X_{k+1})\) is a martingale difference, i.e. conditioned on \(X_s\) for some \(s \leq k,\)
\[ \mathbb{E}\left[ P\hat{H}(X_{k}) - \hat{H}(X_{k+1}) \,\middle|\, X_s \right] = 0. \]
This martingale term vanishes under both conditional and total expectation.
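We can see this numerically on a small chain (my sketch): for any test function \(\hat{H}\) on a 2-state chain, the differences \(P\hat{H}(X_k) - \hat{H}(X_{k+1})\) have zero conditional mean, so their running average vanishes along a trajectory:

```python
import numpy as np

# Martingale-difference check on a 2-state chain: since
# E[ Hhat(X_{k+1}) | X_k ] = (P Hhat)(X_k), the differences
# P Hhat(X_k) - Hhat(X_{k+1}) average to zero along a trajectory.

rng = np.random.default_rng(4)
P = np.array([[0.9, 0.1],
              [0.2, 0.8]])
Hhat = np.array([0.5, -2.0])   # an arbitrary function on the states {0, 1}
PH = P @ Hhat                  # (P Hhat)(x) = sum_y P[x, y] Hhat[y]

x = 0
diffs = []
for _ in range(50000):
    x_next = int(rng.choice(2, p=P[x]))
    diffs.append(PH[x] - Hhat[x_next])  # martingale difference at this step
    x = x_next

empirical_mean = float(np.mean(diffs))
```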
Now, we tackle the general case where the parameter \(\eta_n\) is updated. We are primarily interested in the stochastic integral
\[ \sum_{k=0}^{n} \gamma_{k+1} \langle \nabla V(\eta_k), E_{\eta_k}(X_{k+1}) \rangle. \]
Suppose that for all \(\eta\), we have a solution \(\hat{H}_\eta\) to the Poisson equation \(L_\eta \hat{H}_\eta = E_\eta.\) We may then decompose the above stochastic integral as the sum of the following expressions.
Here, the expressions \(S_0\) and \(S_1\) arise naturally from the application of the Poisson equation solutions. In particular, the terms \(P_{\eta_k} \hat{H}_{\eta_k}(X_{k}) - \hat{H}_{\eta_k}(X_{k+1})\) appearing in \(S_1\) are martingale differences so \(S_1\) vanishes under conditional and total expectations.
The expressions \(S_2,\) \(S_3\) and \(S_4\) are correction terms coming from updates to the parameters and step sizes, and they can be bounded under suitable assumptions on the regularity of \(P_\eta \hat{H}_\eta(x),\) \(\nabla V (\eta)\) and \(\gamma_{n},\) respectively.
To summarize, we have the following convergence result for biased stochastic approximation [KMMW19], starting with some regularity conditions. Note that \(h(\eta)\) plays the role of the gradient \(\nabla V(\eta),\) so we are guaranteed convergence only to a critical point of the Lyapunov function \(V.\)
A1 (Direction of mean field). There exist \(c_0, c_1 \geq 0\) such that for all \(\eta \in \mathcal{H},\)
A2 (Length of mean field). There exist \(d_0, d_1 \geq 0\) such that for all \(\eta \in \mathcal{H},\)
A3 (\(\ell\)-smoothness of Lyapunov function). There exists \(\ell < \infty\) such that for all \(\eta, \eta' \in \mathcal{H},\)
A4 (Solution of Poisson equation). There exists a Borel measurable function \(\hat{H} : \mathcal{H} \times \mathcal{X} \rightarrow \mathcal{H}\) such that for all \(\eta \in \mathcal{H}, x \in \mathcal{X}\)
A5 (Regularity of solution). There exist \(\ell_0, \ell_1 < \infty\) such that for all \(\eta, \eta' \in \mathcal{H}, x \in \mathcal{X},\)
A6 (Correction bound). There exists \(\sigma < \infty\) such that for all \(\eta \in \mathcal{H}, x \in \mathcal{X},\)
Theorem (Convergence of Biased Stochastic Approximation). Suppose that we have parameter updates
\[ \eta_{k+1} = \eta_k - \gamma_{k+1} H_{\eta_k}(X_{k+1}) \]
for \(0 \leq k \leq n,\) using step sizes \(\gamma_k = \gamma_0 k^{-1/2}\) for sufficiently small \(\gamma_0 > 0,\) and using a random stopping time \(0 \leq N \leq n\) with \(\mathbb{P}(N = l) := (\sum_{k=0}^n \gamma_{k+1})^{-1} \gamma_{l+1}.\) Then, assuming A1-A6, we have
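As an end-to-end toy run (my example, not from [KMMW19]): with AR(1) noise, the conditional expectation of the update differs from the mean field, so this is genuinely biased stochastic approximation, yet the scheme with steps \(\gamma_k = \gamma_0 k^{-1/2}\) still drives \(h(\eta) = \eta - m\) toward zero:

```python
import numpy as np

# Biased stochastic approximation toy run: X_n is an AR(1) Markov chain
# with stationary mean m, and H_eta(x) = eta - x, so the mean field is
# h(eta) = eta - m while E[ H_eta(X_{n+1}) | X_n ] = eta - (c + rho * X_n),
# which differs from h(eta): the updates are biased.

rng = np.random.default_rng(5)
c, rho = 0.6, 0.7
m = c / (1.0 - rho)                          # stationary mean of the chain
x, eta = 0.0, 0.0
for k in range(1, 100001):
    x = c + rho * x + rng.standard_normal()  # Markovian (not i.i.d.) noise
    gamma = 0.1 / np.sqrt(k)                 # gamma_k = gamma_0 * k^{-1/2}
    eta = eta - gamma * (eta - x)            # eta_{k+1} = eta_k - gamma * H

# eta approaches the fixed point m = 2 despite the bias.
```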
References#
- B+98
Léon Bottou and others. Online learning and stochastic approximations. On-line learning in neural networks, 17(9):142, 1998.
- CChafaiG12
Patrick Cattiaux, Djalil Chafaï, and Arnaud Guillin. Central limit theorems for additive functionals of ergodic Markov diffusions processes. ALEA, 9(2):337–382, 2012.
- Hai14
Martin Hairer. A theory of regularity structures. Inventiones mathematicae, 198(2):269–504, 2014.
- KMMW19
Belhal Karimi, Blazej Miasojedow, Eric Moulines, and Hoi-To Wai. Non-asymptotic analysis of biased stochastic approximation scheme. In Conference on Learning Theory, 1944–1974. PMLR, 2019.
- LR15
Yingzhen Li and Mark Rowland. Stochastic approximation theory (slides). http://yingzhenli.net/home/pdf/SA.pdf, 2015.
- Oks13
Bernt Oksendal. Stochastic differential equations: an introduction with applications. Springer Science & Business Media, 2013.