Understanding Flow Matching
Recently, the 2023 paper “Voicebox: Text-Guided Multilingual Universal Speech Generation at Scale” by M. Le et al. came out together with the associated model, which can do a bunch of nifty things. We have already looked into the (currently) more popular “diffusion” based models in “Understanding Generative AI (Stable Diffusion) as Galton Board” (no prerequisites required). The “Voicebox” model, however, is based on “Flow Matching”, and this gives us as good an opportunity as any to talk about a specific approach to “Continuous Normalizing Flows”. In particular, we will basically be working through the 2022 paper “Flow Matching for Generative Modeling” by Y. Lipman et al., but we will use somewhat more physics-flavored language to keep our intuition happy. Note that this essay assumes knowledge of calculus and basic probability, as well as some familiarity with Generative AI. So without further ado, let’s go!
Continuous Normalizing Flows
In order to get a generative model, we want to create a map from a known, simple distribution to an unknown, complex distribution for which we are given “data” samples. Creating such a map directly is tricky. One way to go about it is to gradually “morph” the simple distribution into the sampled one. This is basically the idea behind “Normalizing Flows”: we gradually “morph” (aka “flow”). The word “normalizing” indicates that the morphs must keep the total probability of our distribution equal to 1, so not just any morph will do. “Continuous Normalizing Flows” then amount to doing “infinitesimal morphs”. In practice, this means we will be dealing with differential equations, like our forefathers did.
Fluid mechanics point of view
We already mentioned the “normalizing” constraint and the “morph gradually” idea. Is there a physical example of such a process? Indeed there is: fluids, and conservation of mass for fluids. Let’s visualize a “cloudy mist” (a fluid doesn’t have to be a liquid, just anything that “flows”; physicists, right?). Where the mist is denser, we can think of the probability density as being higher, and where the mist is thinner, as lower. Conservation of mass means that no new, extra mist appears. We should be more precise, though: conservation of total mass on its own is not sufficient, since total mass is also conserved if mist disappears in one place and is teleported somewhere else. Conservation of total mass is just equivalent to the “probability is normalized” constraint. What we additionally insist on for the mist is that if the mass inside some volume changes, the difference must have flowed across that volume’s boundaries. This is known as the “continuity equation”, and it is our “gradually” constraint. The derivation is easy, but we will just state the equation here for the one-dimensional case, as the interested reader can find a derivation on page 1 of any fluid dynamics book.
$$ \partial_t p + \partial_x(p v) = 0 $$
where \(p\) is the density (aka probability density) and \(v\) is the velocity of the mist. So, whatever we do, we require that this equation is satisfied, and then we have a “continuous normalizing flow”.
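As a quick sanity check of what the continuity equation demands, here is a minimal Python sketch (standard library only) that estimates the residual \(\partial_t p + \partial_x(p v)\) with central finite differences. The example flow is an illustrative assumption, not something from the text: a unit-width Gaussian blob of mist drifting at a constant speed `C`. With the matching velocity the residual vanishes up to discretization error. With a mismatched velocity it is clearly nonzero.

```python
import math

def gaussian(x, mean, std):
    """Normal density N(x; mean, std^2)."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def continuity_residual(p, v, x, t, h=1e-4):
    """Central-difference estimate of d/dt p + d/dx (p v) at the point (x, t)."""
    def flux(xx):
        return p(xx, t) * v(xx, t)
    dp_dt = (p(x, t + h) - p(x, t - h)) / (2.0 * h)
    dflux_dx = (flux(x + h) - flux(x - h)) / (2.0 * h)
    return dp_dt + dflux_dx

C = 0.7  # drift speed of the blob (arbitrary illustrative value)

def density(x, t):
    """A unit-width blob of mist whose center sits at C * t."""
    return gaussian(x, C * t, 1.0)

def matching_velocity(x, t):
    """Every bit of mist moves along with the blob."""
    return C

def wrong_velocity(x, t):
    """Twice too fast: inconsistent with how the density actually changes."""
    return 2.0 * C

grid = [i / 10 for i in range(-30, 31)]
good = max(abs(continuity_residual(density, matching_velocity, x, 0.5)) for x in grid)
bad = max(abs(continuity_residual(density, wrong_velocity, x, 0.5)) for x in grid)
print(f"max residual, matching velocity: {good:.1e}")
print(f"max residual, wrong velocity:    {bad:.1e}")
```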
Our general angle of attack will be to pick densities (and velocities satisfying the continuity equation) that have an easy-to-work-with shape, and to build the final solution as a combination of these simple solutions. Then, to generate a sample from our generative model, we sample from the initial distribution \(p_{start}\) (which we will have picked to be easy to sample from) and follow the velocity field. It is as if we dropped a leaf into the water at a random spot and let it drift to its destination, which gives us the generated sample.
Adding up simple solutions
In fluid dynamics, the “continuity equation” with boundary conditions is not enough to uniquely determine velocity and density. For that, we would also need “F=ma” for fluids, known on the streets as the “Navier–Stokes equations”. The thing is, we don’t have one here. This is good: our equations are underspecified, so we are free to pick a solution for the density that we want to work with.
Here is the simple solution we will pick (we could keep things more general, but we will be a bit more specific for ease). It is a normal distribution with a standard deviation of 1, centered on 0, where the center moves linearly to \(x_{end}\) and the standard deviation shrinks linearly to zero (or, more precisely, to some small number \(\sigma_{min}\); the reader might recognize a naive Dirac delta here). We will use the notation \(p(x_t|x_{end})\) for this solution. It is the density at position \(x\) and time \(t\) of the solution whose center ends at \(x_{end}\), where time goes from 0 to 1. So \(p(x_t|x_{end}) = \mathcal{N}(t\cdot x_{end}, 1-(1-\sigma_{min})\cdot t)\), where the second argument of \(\mathcal{N}\) denotes the standard deviation. Furthermore, our true data distribution will be \(q(x_{end})\). Now, let us propose that the total solution is the sum of the individual solutions weighted by \(q\), which is a pretty common instinct:
$$ p(x_t) = \int p(x_t|x_{end})q(x_{end})dx_{end} $$
It’s worth noting that since \(p(x_t|x_{end})\) approximates a Dirac delta centered on \(x_{end}\) at \(t=1\), the total density at the end satisfies \(p(x_{end}) \approx q(x_{end})\). So far so good.
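To make the sum concrete, the sketch below assumes, purely for illustration, that \(q\) is an empirical distribution over a hypothetical data set `DATA = [-2.0, 0.5, 3.0]`. The integral over \(x_{end}\) then becomes an average, and \(p(x_t)\) becomes a mixture of normal distributions. The value `SIGMA_MIN = 1e-2` is likewise an arbitrary choice. The sketch checks three things numerically:

- \(p(x_t)\) stays normalized at all times.
- At \(t=1\), nearly all of the mass sits right next to the data points, as the Dirac-delta argument above suggests.
- Sampling in two steps reproduces the mean of \(p(x_t)\). The steps are: first draw \(x_{end}\) from \(q\), then set \(x_t = t\,x_{end} + (1-(1-\sigma_{min})t)\,\epsilon\) with \(\epsilon \sim \mathcal{N}(0,1)\).

```python
import math
import random

SIGMA_MIN = 1e-2
DATA = [-2.0, 0.5, 3.0]  # hypothetical data set; q puts mass 1/3 on each point

def gaussian(x, mean, std):
    """Normal density N(x; mean, std^2)."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def sigma_t(t):
    """Standard deviation of p(x_t | x_end): 1 at t=0, SIGMA_MIN at t=1."""
    return 1.0 - (1.0 - SIGMA_MIN) * t

def marginal_density(x, t):
    """p(x_t) = sum over data points of p(x_t | x_end) q(x_end)."""
    return sum(gaussian(x, t * x_end, sigma_t(t)) for x_end in DATA) / len(DATA)

def sample_marginal(t, rng):
    """Two-step sampling: x_end ~ q, then x_t = t * x_end + sigma_t(t) * eps, eps ~ N(0, 1)."""
    x_end = rng.choice(DATA)
    return t * x_end + sigma_t(t) * rng.gauss(0.0, 1.0)

def integrate(f, lo=-8.0, hi=8.0, n=16_000):
    """Midpoint-rule integral of f over [lo, hi]."""
    dx = (hi - lo) / n
    return sum(f(lo + (i + 0.5) * dx) for i in range(n)) * dx

for t in (0.0, 0.5, 1.0):
    total = integrate(lambda x: marginal_density(x, t))
    print(f"t={t}: total probability {total:.4f}")

# Mass of p(x_t) at t=1 lying within 0.05 (= 5 * SIGMA_MIN) of some data point.
near_data = integrate(
    lambda x: marginal_density(x, 1.0) if min(abs(x - d) for d in DATA) < 0.05 else 0.0
)
print(f"t=1: mass within 0.05 of a data point {near_data:.4f}")

# The two-step sampler should reproduce the mean of p(x_t), which is t * mean(DATA).
rng = random.Random(0)
draws = [sample_marginal(0.5, rng) for _ in range(100_000)]
expected_mean = 0.5 * sum(DATA) / len(DATA)
print(f"t=0.5: sample mean {sum(draws) / len(draws):.3f}, expected {expected_mean:.3f}")
```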
What about the velocities for the individual solutions? No problem: each individual velocity is constrained by the continuity equation, and it can be found with a bit of math which we will not do here, but which is quite easy. Now, we summed up densities to get the “total” density. Can we also just naively add up the velocities, as we did for the densities? Unfortunately not. The continuity equation is linear in the density and in the flux \(pv\), but not in the velocity alone, since in general \((p_1+p_2)(v_1+v_2) \neq p_1v_1 + p_2v_2\). I.e., we cannot just naively sum up the individual \(p(x_t|x_{end})\) and the corresponding \(v(x_t|x_{end})\) and still satisfy the continuity equation. However, we shouldn’t despair. Let’s plug \(p(x_t)\) into the continuity equation and use the fact that each simpler solution satisfies the continuity equation individually.
$$ \partial_t p(x_t) = \partial_t \int p(x_t|x_{end})q(x_{end})dx_{end} = $$ $$ \int \partial_t p(x_t|x_{end})q(x_{end})dx_{end} = $$ $$ -\int \partial_x (v(x_t|x_{end})p(x_t|x_{end}))q(x_{end})dx_{end} = $$ $$ -\partial_x \int v(x_t|x_{end})p(x_t|x_{end})q(x_{end})dx_{end} = $$ $$ -\partial_x \left( p(x_t) \int v(x_t|x_{end})\frac{p(x_t|x_{end})}{p(x_t)}q(x_{end})dx_{end} \right) = -\partial_x (p(x_t)v(x_t)) $$
where we have defined
$$ v(x_t) \equiv \int v(x_t|x_{end})\frac{p(x_t|x_{end})}{p(x_t)}q(x_{end})dx_{end} $$
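I.e., like \(p\), the total velocity is still a sum over \(x_{end}\), but a weighted one. By Bayes’ rule, the weight \(p(x_t|x_{end})q(x_{end})/p(x_t)\) is the probability of \(x_{end}\) given \(x_t\), so the weights integrate to 1.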
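For the Gaussian path we picked, the “bit of math” for the individual velocities can be done by following a single point. A starting point \(x_{start}\) moves as \(x_t = t\,x_{end} + (1-(1-\sigma_{min})t)\,x_{start}\), so its velocity is \(x_{end} - (1-\sigma_{min})x_{start}\). Eliminating \(x_{start}\) gives

$$ v(x_t|x_{end}) = \frac{x_{end} - (1-\sigma_{min})\,x}{1-(1-\sigma_{min})\,t} $$

This matches the velocity Lipman et al. use for this path. The minimal Python sketch below assumes a hypothetical three-point data set `DATA` with \(q\) uniform on it, and an arbitrary `SIGMA_MIN = 1e-2`. It builds the weighted total velocity \(v(x_t)\) and checks the continuity equation numerically for the total \(p\) and \(v\). For contrast, it also tries the naive unweighted average of the individual velocities, which fails the check.

```python
import math

SIGMA_MIN = 1e-2
DATA = [-2.0, 0.5, 3.0]  # hypothetical data set; q puts mass 1/3 on each point

def gaussian(x, mean, std):
    """Normal density N(x; mean, std^2)."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def sigma_t(t):
    return 1.0 - (1.0 - SIGMA_MIN) * t

def cond_density(x, t, x_end):
    """p(x_t | x_end) = N(t * x_end, sigma_t(t)^2)."""
    return gaussian(x, t * x_end, sigma_t(t))

def cond_velocity(x, t, x_end):
    """v(x_t | x_end) for the Gaussian path above."""
    return (x_end - (1.0 - SIGMA_MIN) * x) / sigma_t(t)

def marginal_density(x, t):
    return sum(cond_density(x, t, x_end) for x_end in DATA) / len(DATA)

def marginal_velocity(x, t):
    """Weighted sum with weights p(x_t|x_end) q(x_end) / p(x_t).

    q(x_end) = 1/len(DATA) is the same for every point, so it cancels in the ratio.
    """
    weights = [cond_density(x, t, x_end) for x_end in DATA]
    return sum(w * cond_velocity(x, t, x_end) for w, x_end in zip(weights, DATA)) / sum(weights)

def naive_velocity(x, t):
    """Unweighted average of the individual velocities (the tempting but wrong choice)."""
    return sum(cond_velocity(x, t, x_end) for x_end in DATA) / len(DATA)

def continuity_residual(p, v, x, t, h=1e-4):
    """Central-difference estimate of d/dt p + d/dx (p v)."""
    def flux(xx):
        return p(xx, t) * v(xx, t)
    dp_dt = (p(x, t + h) - p(x, t - h)) / (2.0 * h)
    return dp_dt + (flux(x + h) - flux(x - h)) / (2.0 * h)

grid = [i / 10 for i in range(-30, 31)]
good = max(abs(continuity_residual(marginal_density, marginal_velocity, x, 0.5)) for x in grid)
bad = max(abs(continuity_residual(marginal_density, naive_velocity, x, 0.5)) for x in grid)
print(f"max residual, weighted velocity:   {good:.1e}")
print(f"max residual, unweighted velocity: {bad:.1e}")
```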
Sampling
Okay, so we have picked simple solutions and built a total solution as a sum of the simpler ones. The catch is that \(p(x_t)\) and \(v(x_t)\) are integrals over the unknown data distribution \(q\), which we cannot really calculate. Instead, we will approximate \(v(x_t)\) with a machine-learned \(v^\theta(x_t)\). We will simply minimize:
$$ \mathcal{L} = \mathbb{E}_{t, p(x_t)} \left|\left| v^\theta(x_t)-v(x_t) \right|\right|^2 $$
Now, let’s plug in the integrals for the total \(p\) and total \(v\), and see what happens to the expectation at a specific time:
$$ \mathcal{L} = \mathbb{E}_{p(x_t)} \left|\left| v^\theta(x_t)-v(x_t) \right|\right|^2 = \int (v^\theta(x_t)-v(x_t))^2p(x_t)dx_t = $$ $$ \int (v^\theta(x_t)^2 - 2v^\theta(x_t)v(x_t) + v(x_t)^2)p(x_t)dx_t = $$ $$ \int v^\theta(x_t)^2p(x_t)dx_t - 2\int v^\theta(x_t)v(x_t)p(x_t)dx_t + \int v(x_t)^2p(x_t)dx_t $$
Let’s do the three integrals separately, using the fact that our total \(p\) and total \(v\) are sums. The first one: $$ \int v^\theta(x_t)^2p(x_t)dx_t = \int v^\theta(x_t)^2 p(x_t|x_{end})q(x_{end})dx_{end}dx_t $$ The second one: $$ \int v^\theta(x_t)v(x_t)p(x_t)dx_t = $$ $$ \int v^\theta(x_t)\left(\int v(x_t|x_{end})\frac{p(x_t|x_{end})}{p(x_t)}q(x_{end})dx_{end}\right)p(x_t)dx_t = $$ $$ \int v^\theta(x_t)v(x_t|x_{end})p(x_t|x_{end})q(x_{end})dx_{end}dx_t $$ Now the third one. Note the following: when running machine learning, we will be calculating the derivative of \(\mathcal{L}\) with respect to \(\theta\), i.e. \(\nabla_{\theta}\mathcal{L}\), and the third integral does not depend on \(\theta\) at all. Because it is a constant, we can replace it with any other \(\theta\)-independent quantity without changing the gradient. We will pick something nice: $$ \int v(x_t)^2p(x_t)dx_t \rightarrow \int v(x_t|x_{end})^2p(x_t|x_{end})q(x_{end})dx_{end}dx_t $$ So let’s put the three integrals back together and get this new loss \(\mathcal{L}_{EQ}\), which is equivalent in the sense that it has the same gradient:
$$ \mathcal{L}_{EQ} \equiv \int v^\theta(x_t)^2 p(x_t|x_{end})q(x_{end})dx_{end}dx_t - $$ $$ 2\int v^\theta(x_t)v(x_t|x_{end})p(x_t|x_{end})q(x_{end})dx_{end}dx_t + $$ $$ \int v(x_t|x_{end})^2p(x_t|x_{end})q(x_{end})dx_{end}dx_t = $$ $$ \int (v^\theta(x_t) - v(x_t|x_{end}))^2 p(x_t|x_{end}) q(x_{end}) dx_{end}dx_t $$
Let’s take stock of the last equation. It is the expectation value of \((v^{\theta}(x_t) - v(x_t|x_{end}))^2\) when sampling first \(x_{end}\) from \(q(x_{end})\) and then \(x_t\) from \(p(x_t|x_{end})\). This is nice, and life is good! Instead of calculating integrals we cannot calculate, we can estimate the gradient of the loss function by simple sampling and evaluating simple functions. Note that training builds its examples “backwards”: we pick an “end” point from the data and then a point on its path from the “start” distribution. Following a velocity field gives a one-to-one map, so we can then generate samples by going forward, from “start” to “end”.
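To see numerically that swapping the third integral only shifts the loss by a constant, the minimal sketch below evaluates both \(\mathcal{L}\) and \(\mathcal{L}_{EQ}\) by quadrature at one fixed time. It makes two illustrative assumptions. First, \(q\) is uniform on a hypothetical three-point data set `DATA`, so that the “uncomputable” total velocity can actually be computed. Second, the model is a hypothetical straight line \(v^\theta(x) = a x + b\). For every choice of \((a, b)\), the gap \(\mathcal{L} - \mathcal{L}_{EQ}\) comes out the same, so both losses have the same gradient with respect to \(\theta\).

```python
import math

SIGMA_MIN = 1e-2
DATA = [-2.0, 0.5, 3.0]  # hypothetical data set; q puts mass 1/3 on each point
T = 0.5                  # both losses are compared at this one fixed time
SIGMA_T = 1.0 - (1.0 - SIGMA_MIN) * T  # width of p(x_t | x_end) at time T

def gaussian(x, mean, std):
    """Normal density N(x; mean, std^2)."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def cond_density(x, x_end):
    """p(x_t | x_end) at time T."""
    return gaussian(x, T * x_end, SIGMA_T)

def cond_velocity(x, x_end):
    """v(x_t | x_end) at time T."""
    return (x_end - (1.0 - SIGMA_MIN) * x) / SIGMA_T

def marginal_density(x):
    """Total p(x_t): average of the conditional densities (q is uniform on DATA)."""
    return sum(cond_density(x, e) for e in DATA) / len(DATA)

def marginal_velocity(x):
    """Total v(x_t): posterior-weighted average of the conditional velocities."""
    w = [cond_density(x, e) for e in DATA]
    return sum(wi * cond_velocity(x, e) for wi, e in zip(w, DATA)) / sum(w)

# Midpoint-rule quadrature grid over x.
LO, HI, N = -8.0, 8.0, 8000
DX = (HI - LO) / N
GRID = [LO + (i + 0.5) * DX for i in range(N)]

def model(x, a, b):
    """Hypothetical stand-in for v^theta: a straight line with theta = (a, b)."""
    return a * x + b

def loss_original(a, b):
    """L: squared error against the total velocity, weighted by the total density."""
    return sum((model(x, a, b) - marginal_velocity(x)) ** 2 * marginal_density(x) for x in GRID) * DX

def loss_eq(a, b):
    """L_EQ: squared error against each conditional velocity, weighted by p(x_t|x_end) q(x_end)."""
    return sum(
        (model(x, a, b) - cond_velocity(x, e)) ** 2 * cond_density(x, e)
        for x in GRID
        for e in DATA
    ) * DX / len(DATA)

for a, b in [(0.0, 0.0), (1.0, -0.5), (-2.0, 3.0)]:
    l_orig, l_eq = loss_original(a, b), loss_eq(a, b)
    print(f"theta=({a}, {b}): L = {l_orig:.4f}, L_EQ = {l_eq:.4f}, L - L_EQ = {l_orig - l_eq:.6f}")
```

The gap is negative. \(\mathcal{L}_{EQ}\) is the larger of the two because it also contains the spread of the individual velocities around their weighted average, which no model can remove.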
Taking stock
We are done, but let’s recap the setup:
We have picked a simple solution for the density, \(p(x_t|x_{end})\), for which we can derive a simple associated velocity \(v(x_t|x_{end})\). To train, we sample a time \(t\) and a data point \(x_{end} \sim q(x_{end})\), and then draw \(x_t \sim p(x_t|x_{end})\), which is easy since it is a normal distribution. We then calculate \(v(x_t|x_{end})\) and the gradient of the loss function, and use it to adjust \(\theta\) so that \(v^\theta(x_t)\) improves.
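The training recipe above fits in a few lines. The sketch below is a toy built on stated assumptions:

- The data distribution \(q\) is a hypothetical normal with mean 2 and standard deviation 0.5.
- \(t\) is drawn uniformly from \([0, 1)\).
- Instead of a neural network, \(v^\theta\) is a linear combination of four hand-picked features of \((x, t)\). The time is fed to the model explicitly.
- Training uses plain minibatch stochastic gradient descent.

The regression target is the individual velocity \(v(x_t|x_{end})\) of our Gaussian path.

```python
import random

SIGMA_MIN = 1e-2

def sample_data(rng):
    """Hypothetical data distribution q: normal with mean 2 and standard deviation 0.5."""
    return rng.gauss(2.0, 0.5)

def features(x, t):
    """Hand-picked features of (x, t) for a tiny linear-in-theta model (illustrative only)."""
    return [1.0, t, x, x * t]

def model(theta, x, t):
    """v^theta(x_t): a linear combination of the features."""
    return sum(th * f for th, f in zip(theta, features(x, t)))

def training_example(rng):
    """Sample t, then x_end ~ q, then x_t ~ p(x_t|x_end); return (x_t, t, v(x_t|x_end))."""
    t = rng.random()                     # uniform on [0, 1)
    x_end = sample_data(rng)
    sigma = 1.0 - (1.0 - SIGMA_MIN) * t  # width of p(x_t|x_end)
    x_t = t * x_end + sigma * rng.gauss(0.0, 1.0)
    target = (x_end - (1.0 - SIGMA_MIN) * x_t) / sigma
    return x_t, t, target

def batch_loss(theta, batch):
    """Monte Carlo estimate of L_EQ on a batch of training examples."""
    return sum((model(theta, x, t) - v) ** 2 for x, t, v in batch) / len(batch)

def train(steps=1000, batch_size=64, lr=0.05, seed=0):
    """Plain minibatch SGD on the squared error between v^theta and v(x_t|x_end)."""
    rng = random.Random(seed)
    theta = [0.0, 0.0, 0.0, 0.0]
    for _ in range(steps):
        batch = [training_example(rng) for _ in range(batch_size)]
        grad = [0.0] * len(theta)
        for x, t, v in batch:
            err = model(theta, x, t) - v
            for k, f in enumerate(features(x, t)):
                grad[k] += 2.0 * err * f / batch_size
        theta = [th - lr * g for th, g in zip(theta, grad)]
    return theta

theta = train()
eval_rng = random.Random(123)
eval_batch = [training_example(eval_rng) for _ in range(5000)]
print("loss before training:", round(batch_loss([0.0] * 4, eval_batch), 3))
print("loss after training: ", round(batch_loss(theta, eval_batch), 3))
```

The loss does not go to zero, and it should not. The same \(x_t\) can come from different \(x_{end}\), so the individual velocities scatter around the total velocity that \(v^\theta\) is actually learning.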
To generate a sample, draw \(x_{start} \sim p(x_{start})\) from the simple starting distribution. Then, starting at \(t=0\), repeatedly evaluate the model’s \(v^\theta(x_t)\), advance \(x\) by \(\Delta x = v^\theta(x_t) \Delta t\) and \(t\) by \(\Delta t\) (a simple Euler integration), and stop once \(t=1\). After this loop, we have our generated sample.
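A minimal Euler sampler might look like the sketch below. To keep it self-contained, it makes a few assumptions:

- The data set `DATA` is a hypothetical three-point set with \(q\) uniform on it.
- The exact weighted velocity \(v(x_t)\) for that data stands in for a trained \(v^\theta(x_t)\).
- `SIGMA_MIN = 1e-2` and the step count of 200 are arbitrary choices.

Starting points are drawn from \(\mathcal{N}(0, 1)\), and the generated samples should land close to the data points.

```python
import math
import random

SIGMA_MIN = 1e-2
DATA = [-2.0, 0.5, 3.0]  # hypothetical data set; q puts mass 1/3 on each point

def sigma_t(t):
    return 1.0 - (1.0 - SIGMA_MIN) * t

def velocity(x, t):
    """Stand-in for a trained v^theta: the exact posterior-weighted velocity for DATA."""
    s = sigma_t(t)
    # Log-weights of p(x_t | x_end). The normalizing constant is the same for every
    # data point, so it cancels; subtracting the max avoids underflow when s is small.
    logw = [-0.5 * ((x - t * e) / s) ** 2 for e in DATA]
    top = max(logw)
    w = [math.exp(lw - top) for lw in logw]
    return sum(wi * (e - (1.0 - SIGMA_MIN) * x) / s for wi, e in zip(w, DATA)) / sum(w)

def generate(rng, n_steps=200):
    """Euler integration: start at x ~ N(0, 1) and step x += v * dt from t=0 to t=1."""
    x = rng.gauss(0.0, 1.0)
    dt = 1.0 / n_steps
    for i in range(n_steps):
        x += velocity(x, i * dt) * dt
    return x

rng = random.Random(0)
samples = [generate(rng) for _ in range(500)]
dist = [min(abs(s - d) for d in DATA) for s in samples]
print("median distance to nearest data point:", sorted(dist)[len(dist) // 2])
```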