Biased stochastic approximation with mutable processes#
The goal of this post is to derive a general online learning recipe for training a mutable process \(\{Z_t,X_t\}\) to learn the true distribution \(Q_*(X)\) of a partially-observed Markov process \(\{X_t\}\). The recipe returns a generative distribution \(P(Z,X)\) whose marginal \(P(X)\) approximates \(Q_*(X).\)
The variables \(Z\) of the mutable process are auxiliary variables that assist in inference and computation. During training, the distribution of \(Z\) given \(X\) is controlled by a discriminative model \(\{Q(Z\vert X)\}.\) Our method works in both discrete time and continuous time. We assume that in the mutable process, for each time \(t,\) the variables \(Z_t\) and \(X_t\) are conditionally independent of each other given their past.
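Concretely, the conditional-independence assumption says that each one-step transition of the mutable process factorizes; the display below is our shorthand for it, written for the discriminative distribution,

\[
Q(Z_t, X_t \vert Z_{t-1}, X_{t-1}) = Q(Z_t \vert Z_{t-1}, X_{t-1})\, Q(X_t \vert Z_{t-1}, X_{t-1}),
\]

and similarly for the generative distribution \(P(Z, X).\)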
Our strategy is relative inference, where we use a relative information objective that measures the divergence between the discriminative distribution \(Q(Z,X)\) and the generative distribution \(P(Z,X).\) We minimize this objective by coordinate-wise updates to the discriminative and generative distributions using stochastic gradients.
We will be using biased stochastic approximation [KMMW19] where the stochastic updates are dependent on the past but the conditional expectation of the stochastic updates given the past is not equal to the mean field. These biased stochastic approximation schemes for mutable processes generalize the classical expectation maximization algorithm for mutable models.
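For orientation, here is the generic shape of such a scheme in our own notation, with \(\psi\) the parameter, \(W\) the underlying Markovian data, \(H\) the stochastic update, \(\eta\) the step size and \(h\) the mean field:

\[
\psi_{n+1} = \psi_n + \eta_{n+1}\, H(\psi_n, W_{n+1}),
\qquad
h(\psi) = \mathbb{E}_{W \sim \bar{\pi}_\psi}\bigl[H(\psi, W)\bigr],
\]

where \(\bar{\pi}_\psi\) is the stationary distribution of the chain \(\{W_n\}.\) The bias at step \(n\) is the gap between the conditional expectation \(\mathbb{E}\bigl[H(\psi_n, W_{n+1}) \vert W_n\bigr]\) and the mean field \(h(\psi_n).\)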
This post is a continuation of our series on spiking networks, path integrals and motivic information.
What do we assume about the true distribution, the model and the learning objective?#
As before, we assume that the universe is a Markov process \(\{X_t\},\) and let its true distribution be the path measure \(Q_*.\)
Suppose that we have a parametric discriminative model \(\{Q_\lambda : \lambda \in \Lambda\}\) and a parametric generative model \(\{P_\theta : \theta \in \Theta\},\) where the distributions \(Q_\lambda\) and \(P_\theta\) are path measures on some joint process \(\{(Z_t, X_t)\}.\) The random variables \(Z_t\) represent computational states in this discriminative-generative model. We can also interpret the \(Z_t\) as beliefs sampled from the belief distributions \(Q_\lambda(Z_t\vert Z_{t-1},X_{t-1}).\)
We assume that in both models, the distributions are Markov and that, for each \(t,\) the variables \(Z_t\) and \(X_t\) are conditionally independent given their past. We also assume that the marginals \(Q(X_{0\ldots T})\) of the discriminative model distributions \(Q_\lambda(Z_{0 \ldots T}, X_{0\ldots T})\) are all equal to the true distribution \(Q_*(X_{0\ldots T}).\)
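Under these assumptions, the discriminative path measure factorizes as in the sketch below, where the form of the initial factor \(Q(Z_0, X_0)\) is left unspecified:

\[
Q_\lambda(Z_{0\ldots T}, X_{0\ldots T})
= Q(Z_0, X_0) \prod_{t=1}^{T} Q_\lambda(Z_t \vert Z_{t-1}, X_{t-1})\, Q_*(X_t \vert X_{t-1}),
\]

so the parameter \(\lambda\) enters only through the belief transitions \(Q_\lambda(Z_t \vert Z_{t-1}, X_{t-1}).\)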
Some parts of the universe \(\{X_t\}\) are observed and other parts are hidden. We will impose these conditions by putting constraints on the structure of the models \(\{Q_\lambda\}\) and \(\{P_\theta\}\), as described in this post.
Our goal is to train the models by minimizing the asymptotic relative information rate (continuous time)
or asymptotic conditional relative information (discrete time)
over \(\{Q_\lambda\}\) and \(\{P_\theta\}\). We first explore the problem in discrete time, before discussing the analogous results in continuous time.
We assume that \(Q_\lambda\) has a stationary distribution \(\bar{\pi}_\lambda,\) and let \(\bar{Q}_\lambda\) be the distribution of a Markov chain that has the same transition probabilities as \(Q_\lambda\) but has the initial distribution \(\bar{\pi}_\lambda.\) Then, under ergodicity assumptions, the asymptotic objective reduces to the stationary one-step conditional relative information \(I_{\bar{Q}_\lambda \Vert P_\theta}(Z_1, X_1 \vert Z_0, X_0),\) which is the quantity minimized in the two-step procedure below.
What is the general intuition behind online learning for mutable processes?#
To minimize the relative information objective, we adopt an approach similar to the expectation-maximization (EM) or exponential-mixture (em) algorithm. Specifically, we perform coordinate-wise minimization for the discriminative distribution \(Q_\lambda\) and for the generative distribution \(P_\theta\), updating one distribution while holding the other constant.
First, we pick some initial generative model distribution \(P_{\theta_0}\) and discriminative model distribution \(Q_{\lambda_0}.\) Then, for \(n = 0, 1, \ldots,\) we repeat the next two steps. Here, we will perform both steps in parallel rather than in an alternating fashion.
Step 1 (generative model update). Fixing the discriminative model distribution \(Q_{\lambda_{n}}(Z_1 \vert Z_0, X_0),\) minimize \(I_{\bar{Q}_{\lambda_{n}}\Vert P_{\theta}}(Z_1, X_1 \vert Z_0, X_0)\) over generative model distributions \(P_{\theta}\).
By definition,
where we note that the first term is independent of \(\theta\).
We update the parameter \(\theta\) using the gradient
where we can also write
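To summarize Step 1 in symbols (a sketch under the stated assumptions, in plain expectation notation): by definition of conditional relative information,

\[
I_{\bar{Q}_{\lambda_{n}}\Vert P_{\theta}}(Z_1, X_1 \vert Z_0, X_0)
= \mathbb{E}_{\bar{Q}_{\lambda_{n}}}\bigl[\log \bar{Q}_{\lambda_{n}}(Z_1, X_1 \vert Z_0, X_0)\bigr]
- \mathbb{E}_{\bar{Q}_{\lambda_{n}}}\bigl[\log P_{\theta}(Z_1, X_1 \vert Z_0, X_0)\bigr],
\]

so only the second term depends on \(\theta,\) and by conditional independence its integrand splits as \(\log P_\theta(Z_1 \vert Z_0, X_0) + \log P_\theta(X_1 \vert Z_0, X_0).\) A stochastic gradient step therefore ascends \(\nabla_\theta \log P_{\theta}(Z_1 \vert Z_0, X_0) + \nabla_\theta \log P_{\theta}(X_1 \vert Z_0, X_0)\) at a sampled transition.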
Step 2 (discriminative model update). Fixing the generative model distribution \(P_{\theta_{n}},\) minimize \(I_{\bar{Q}_\lambda \Vert P_{\theta_{n}}}(Z_1, X_1 \vert Z_0, X_0)\) over discriminative model distributions \(Q_\lambda.\)
We update the parameter \(\lambda\) using the gradient
where, as shown in the appendix, we have
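As a preview of the appendix computation (a sketch, suppressing regularity conditions), write \(r_\lambda(Z_0, X_0) = \mathbb{E}_{Q_\lambda}\bigl[\log \tfrac{Q_\lambda(Z_1, X_1 \vert Z_0, X_0)}{P_{\theta_{n}}(Z_1, X_1 \vert Z_0, X_0)} \,\big\vert\, Z_0, X_0\bigr]\) for the conditional integrand (our notation). The product rule then splits the gradient into a transition part and a stationary-distribution part,

\[
\nabla_\lambda I_{\bar{Q}_\lambda \Vert P_{\theta_{n}}}(Z_1, X_1 \vert Z_0, X_0)
= \int \bar{\pi}_\lambda(dZ_0, dX_0)\, \nabla_\lambda r_\lambda(Z_0, X_0)
+ \int \nabla_\lambda \bar{\pi}_\lambda(dZ_0, dX_0)\, r_\lambda(Z_0, X_0),
\]

where the second integral is the one handled with the score-function identity of [BB01] and [KMMW19] in the appendix.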
Is there a stochastic approximation of the above procedure?#
In the above two-step procedure, the term involving the true log-transition \(\log Q_*(X_{T+1}\vert X_{T})\)
cannot be evaluated because it depends on the true distribution \(Q_*.\) Fortunately, this term only scales the discriminative model update; it does not change the direction of the update. We will then replace the unknown \(\log Q_*(X_{T+1}\vert X_{T})\) with an estimate.
Suppose we study the asymptotic time-average
of the negative log-transition of the true distribution. Under mild regularity conditions, we have the ergodic relationship
where \(\bar{\pi}_*\) is the stationary distribution of \(Q_*.\) Let \(\bar{Q}_*\) be the distribution of the true stationary process with initial distribution \(\bar{\pi}_*\) and transition probabilities \(Q_*.\) The asymptotic time-average \(H\) is therefore the true conditional entropy of \(X_1\) given \(X_0\) under the true stationary process.
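In symbols (a sketch, assuming ergodicity of the true process), the time-average and its limit are

\[
H = \lim_{T \to \infty} -\frac{1}{T} \sum_{t=1}^{T} \log Q_*(X_t \vert X_{t-1})
= -\int \bar{\pi}_*(dX_0) \int Q_*(dX_1 \vert X_0)\, \log Q_*(X_1 \vert X_0).
\]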
More precisely, given random variables \(X_0, X_1, X_1',\) we construct two distributions, namely
where \(\mathbb{I}(X_1 = X_1')\) is the indicator function that ensures that \(X_1\) and \(X_1'\) are copies of each other. Then, the true conditional entropy is
Let \(-\xi\) be an estimate of this true conditional entropy. We can substitute the unknown \(\log Q_*(X_{T+1}\vert X_{T})\) with this constant without affecting the convergence of the algorithm, as we shall see in another post. More generally, we can replace the unknown with any estimate \(\xi(X_{T+1} \vert X_T)\) that does not depend on parameters \(\theta, \lambda\) or beliefs \(Z_{T+1}, Z_T.\)
Now, the above two-step procedure has the following stochastic approximation.
In continuous time, the above updates become differential equations: the samples \(Z_t\) are driven by a Poisson process, and the transition probabilities appearing in the updates for \(\theta_t\), \(\alpha_t\), \(\gamma_t\) are replaced by transition rates.
Before we make some preliminary observations about this stochastic approximation, let us introduce some terminology. Given \((Z_{n}, X_{n}),\) suppose we sample \((Z_{n+1}, X_{n+1})\) from \(Q_\lambda(Z_{n+1},X_{n+1} \vert Z_{n}, X_{n}).\) The conditional expectation of a function \(r(Z_{n+1}, X_{n+1}, Z_{n}, X_{n})\) is the expectation of \(r\) conditioned on some given values of \((Z_{n}, X_{n}).\) The mean field or total expectation of \(r\) is the expectation of its conditional expectation over the stationary distribution \(\bar{\pi}_\lambda\) on \((Z_{n}, X_{n}).\)
If the conditional expectations of the updates are independent of \((Z_{n}, X_{n})\), then they will be equal to their mean fields. In this case, we say that the stochastic approximation is unbiased. On the other hand, if the conditional expectations depend on \((Z_{n}, X_{n})\), we say that the stochastic approximation is biased.
In continuous time, the mean fields will be derivatives of relative information rates. The conditional expectations which depend on the current states \((Z_t,X_t)\) will be biased estimates of the mean fields.
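To make the shape of such an online scheme concrete, here is a minimal, self-contained Python sketch on a toy two-state example. Everything in it is our own simplification for illustration: the tabular softmax parametrization, the constant estimate `xi`, the names (`Q_star`, `lam`, `theta_z`, `theta_x`), and a plain score-function (REINFORCE-style) estimate for the \(\lambda\)-update that keeps only the transition term of the gradient and drops the stationary-distribution term derived in the appendix. It is not the post's exact update rule.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy universe: X_t takes values in {0, 1} with a fixed true transition
# matrix Q*(x' | x).  The learner never reads Q_star directly; it only
# sees the sampled observations.
Q_star = np.array([[0.9, 0.1],
                   [0.2, 0.8]])

def softmax(logits):
    logits = logits - logits.max()
    w = np.exp(logits)
    return w / w.sum()

# Tabular logits: discriminative Q_lambda(z' | z, x) and generative
# P_theta(z' | z, x), P_theta(x' | z, x).
lam = np.zeros((2, 2, 2))      # indexed by (z, x, z')
theta_z = np.zeros((2, 2, 2))  # indexed by (z, x, z')
theta_x = np.zeros((2, 2, 2))  # indexed by (z, x, x')

eta = 0.01          # step size
xi = np.log(0.5)    # crude constant estimate of log Q*(x' | x)

z, x = 0, 0
for n in range(20000):
    # Sample one step of the mutable process: the belief Z_{n+1} from the
    # discriminative model, the observation X_{n+1} from the true dynamics.
    qz = softmax(lam[z, x])
    z_next = rng.choice(2, p=qz)
    x_next = rng.choice(2, p=Q_star[x])

    pz = softmax(theta_z[z, x])
    px = softmax(theta_x[z, x])

    # Generative update: ascend log P_theta(z_next, x_next | z, x).
    grad_z = -pz
    grad_z[z_next] += 1.0      # gradient of log softmax w.r.t. logits
    grad_x = -px
    grad_x[x_next] += 1.0
    theta_z[z, x] += eta * grad_z
    theta_x[z, x] += eta * grad_x

    # Discriminative update: descend a score-function estimate, with an
    # "exploit" weight log Q_lambda / P_theta over the belief and an
    # "explore" weight xi - log P_theta(x_next | z, x) standing in for
    # the unknown log Q*(x_next | x).
    score = -qz
    score[z_next] += 1.0       # gradient of log Q_lambda w.r.t. logits
    weight = (np.log(qz[z_next]) - np.log(pz[z_next])
              + xi - np.log(px[x_next]))
    lam[z, x] -= eta * weight * score

    z, x = z_next, x_next

print("learned P_theta(x'|z=0,x=0):", softmax(theta_x[0, 0]), "true row:", Q_star[0])
```

In this toy, the generative transition \(P_\theta(X_{n+1}\vert Z_n, X_n)\) simply performs online maximum likelihood on the observed transitions; the point of the sketch is the shape of the two parallel updates, not their precise form.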
How can we interpret the discriminative model update?#
For a fixed generative model \(P_\theta,\) the discriminative model update looks for a distribution \(Q_\lambda(Z_{n+1}\vert Z_{n},X_{n})\) that minimizes the learning objective \(I_{\bar{Q}_\lambda \Vert P_\theta}(Z_{n+1}, X_{n+1} \vert Z_{n}, X_{n}).\) Intuitively, we can think of the update as looking for a good belief \(Z_{n+1}\) given the previous belief \(Z_{n}\) and observation \(X_{n}.\)
Because \(Z_{n+1}\) and \(X_{n+1}\) are conditionally independent given the past, the learning objective decomposes as a sum of two terms.
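Writing this out (a direct consequence of the conditional-independence assumption on both models):

\[
I_{\bar{Q}_\lambda \Vert P_\theta}(Z_{n+1}, X_{n+1} \vert Z_{n}, X_{n})
= I_{\bar{Q}_\lambda \Vert P_\theta}(Z_{n+1} \vert Z_{n}, X_{n})
+ I_{\bar{Q}_\lambda \Vert P_\theta}(X_{n+1} \vert Z_{n}, X_{n}).
\]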
The first term vanishes when \(Q_\lambda(Z_{n+1}\vert Z_{n},X_{n}) = P_\theta(Z_{n+1}\vert Z_{n}, X_{n}).\)
This term shows that the discriminative model update tends to exploit the generative model \(P_\theta(Z_{n+1}\vert Z_{n}, X_{n})\) in generating a belief \(Z_{n+1}.\)
The second term vanishes when \(Q_\lambda(X_{n+1}\vert Z_{n},X_{n}) = Q_*(X_{n+1}\vert X_{n})\) equals \(P_\theta(X_{n+1}\vert Z_{n},X_{n}),\) but the discriminative model update cannot enforce this equality, because the left-hand side is the fixed true transition and does not depend on \(\lambda.\)
Instead, note that \(Q_\lambda(X_{n+1}\vert Z_{n},X_{n}) = Q_*(X_{n+1}\vert X_{n})\) is fixed and that the stationary marginal of \(X_n\) under \(\bar{Q}_\lambda\) does not depend on \(\lambda,\)
so the parameter \(\lambda\) has an effect only on the stationary transition \(\bar{\pi}_\lambda(dZ_n\vert dX_n).\) Thus, in the long run, the discriminative model update tends to pair beliefs \(Z_n\) with the current \(X_n\) such that the generative model \(P_\theta(X_{n+1} \vert Z_n, X_n)\) is able to effectively guess the next state \(X_{n+1}\) under \(Q_*(X_{n+1}\vert X_n).\) In simpler words, the discriminative model update tends to explore good beliefs \(Z_n\) for predicting the next observation \(X_{n+1}\).
Note that the above two tendencies to exploit and explore could be in conflict with each other. For example, at the start of training, the generative model \(P_\theta\) is often a poor fit for the observations. In exploiting this bad generative model, the discriminative model update may end up with a belief \(Z_n\) whose prediction \(P_\theta(X_{n+1}\vert Z_n, X_n),\) made under the same generative model, poorly matches the next observation \(X_{n+1}.\) However, by exploring beliefs \(Z_n\) that predict the next observation well under the generative model, the discriminative model update gives feedback that the generative model update can use to strengthen the useful parts of the generative model. More precisely, the generative model update will make these useful beliefs more likely under \(P_\theta(Z_{n+1}\vert Z_{n}, X_{n})\) so that they can be exploited at the next discriminative model update.
In the long run, when the generative model is a good fit for the observations, the tendencies to exploit and to explore will be more in tune with each other. This is because beliefs generated by the model \(P_\theta(Z_n\vert Z_{n-1}, X_{n-1})\) will also be useful for predicting the next observation via \(P_\theta(X_{n+1}\vert Z_n,X_n).\)
Explicitly, the exploitative part of the discriminative model update is estimated by
while the explorative part is estimated by
The explorative update is large when \(Q_*(X_{T+1}\vert X_{T})\) and \(P_\theta(X_{T+1}\vert Z_{T},X_{T})\) are far apart.
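In symbols, a sketch of these per-sample estimates under the stated assumptions (the exact weighting in the full update comes from the gradient derived in the appendix):

\[
\text{exploit:}\quad \log \frac{Q_{\lambda}(Z_{T+1} \vert Z_{T}, X_{T})}{P_{\theta}(Z_{T+1} \vert Z_{T}, X_{T})},
\qquad
\text{explore:}\quad \log \frac{Q_{*}(X_{T+1} \vert X_{T})}{P_{\theta}(X_{T+1} \vert Z_{T}, X_{T})}.
\]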
In the stochastic approximation, the explorative part is controlled by
where \(\xi(X_{n+1} \vert X_n)\) is an estimate of the true log-likelihood \(\log Q_*(X_{n+1}\vert X_{n})\). When \(X_{n+1}\) is too likely or too unlikely given \((Z_{n}, X_{n})\), there will be a big difference between the log-likelihood \(\log P_{\theta_{n}}(X_{n+1}\vert Z_{n},X_{n})\) and the threshold \(\xi(X_{n+1} \vert X_n).\) This will generate a strong signal response in the learning system to correct the discrepancy.
In [RG14], this strong signal was called novelty or surprise. The authors hypothesized that biological neural networks could implement this signal using neuromodulation.
In [PBR20], a reinforcement learning scheme for training multilayer neural networks was derived. To implement the weight updates, besides computing the usual feedforward signals, the scheme also computes feedback signals using feedback connections, a global modulating signal representing the reward prediction error, and a local gating signal representing top-down attention. The resulting weight updates are Hebbian.
While there are many interesting similarities between their scheme and our algorithm, one major difference is that we do not require the feedback weights to be the same as the feedforward weights. In our algorithm, the feedback weights are represented by the parameter \(\lambda\) and the feedforward weights by \(\theta\). At the end of training, the feedback weights will tend towards the feedforward weights because of the tendency to exploit. However, tying the weights together at the start of training could be detrimental to learning due to the need of the neural network to explore.
Appendix: Discriminative model update#
In this appendix, we derive the gradient
used in the discriminative model update. The methods used are similar to those employed in the policy gradient theorem [BB01].
We start with the following formula from [BB01] and [KMMW19] for the integral of a function \(r(W)\) with respect to the derivative of the stationary distribution \(\bar{\pi}_\lambda(W)\).
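A commonly used form of this score-function identity, stated informally and assuming geometric ergodicity (our paraphrase), is

\[
\int r(w)\, \nabla_\lambda \bar{\pi}_\lambda(dw)
= \lim_{T \to \infty} \mathbb{E}\!\left[ r(W_T) \sum_{n=1}^{T} \nabla_\lambda \log Q_\lambda(W_n \vert W_{n-1}) \right],
\]

where the chain \(\{W_n\}\) may be started from any initial distribution, since the limit does not depend on it.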
We now derive the discriminative model update. Let \(\{W_n\}\) denote the Markov chain \(\{(Z_{n+1},X_{n+1},Z_{n},X_{n})\}.\) Abusing notation, we write the distribution of \(W_n\) as
and its stationary distribution as
By the product rule,
The first term equals
Taking derivatives of the stationary distribution, the second term becomes
Lastly, because
the gradient simplifies (after a change of indices) to
where the last equality follows because the limit does not depend on the initial distribution of \((Z_0, X_0).\)
References#
- BB01
Jonathan Baxter and Peter L Bartlett. Infinite-horizon policy-gradient estimation. Journal of Artificial Intelligence Research, 15:319–350, 2001.
- KMMW19
Belhal Karimi, Blazej Miasojedow, Eric Moulines, and Hoi-To Wai. Non-asymptotic analysis of biased stochastic approximation scheme. In Conference on Learning Theory, 1944–1974. PMLR, 2019.
- PBR20
Isabella Pozzi, Sander Bohte, and Pieter Roelfsema. Attention-gated brain propagation: how the brain can implement reward-based error backpropagation. Advances in Neural Information Processing Systems, 33:2516–2526, 2020.
- RG14
Danilo Jimenez Rezende and Wulfram Gerstner. Stochastic variational learning in recurrent spiking networks. Frontiers in Computational Neuroscience, 8:38, 2014.