FlowRL can be derived from first principles by turning RLHF-style reward maximization into reverse-KL distribution matching against a reward-weighted target, which yields a tractable sequence-level trajectory-balance loss with a learned partition function and practical stabilizers for long chain-of-thought training. The result is a simple squared “flow residual” objective whose minimization makes sampling proportional to reward (modulated by a reference model), connecting directly to GFlowNet trajectory balance guarantees for mode coverage.
Problem setup
- Given a prompt \(x\), the policy generates a full sequence \(y=(y_1,\dots,y_T)\) with autoregressive likelihood \(\pi_{\theta}(y\mid x)=\prod_{t=1}^{T}\pi_{\theta}(y_t\mid y_{<t},x)\) and obtains a sequence-level reward \(r(x,y)\) from a verifier or reward model.
- A frozen reference model \(\pi_{\mathrm{ref}}(y\mid x)\) acts as a prior that constrains drift, and FlowRL learns a per-prompt normalizer \(Z_{\phi}(x)\) (partition function) to model the intractable normalization of a reward-weighted target distribution.
Target distribution view
- FlowRL aligns the policy to a reward-weighted target distribution \(\tilde{\pi}(y\mid x)=\frac{1}{Z(x)}\exp\!\big(\beta\,r(x,y)\big)\,\pi_{\mathrm{ref}}(y\mid x)\), where \(\beta>0\) is a temperature and \(Z(x)=\sum_{y}\exp\!\big(\beta\,r(x,y)\big)\,\pi_{\mathrm{ref}}(y\mid x)\) is the partition function.
- Because \(Z(x)\) is unknown, the method learns \(Z_{\phi}(x)\) and minimizes the reverse KL divergence \(D_{\mathrm{KL}}(\pi_{\theta}(\cdot\mid x)\,\|\,\tilde{\pi}(\cdot\mid x))\) per prompt, turning RL into distribution matching over complete trajectories.
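In a toy setting where the answer set is small enough to enumerate, the target distribution and its partition function can be computed by brute force. A minimal sketch (the helper name `reward_weighted_target` is illustrative, not from the paper):

```python
import math

def reward_weighted_target(rewards, ref_probs, beta):
    """Normalized target pi_tilde(y|x) proportional to exp(beta * r) * pi_ref(y|x),
    computed by enumeration over a tiny answer set. For real LLM outputs Z(x) is
    intractable, which is exactly why FlowRL learns Z_phi(x) instead."""
    weights = [math.exp(beta * r) * p for r, p in zip(rewards, ref_probs)]
    Z = sum(weights)  # the partition function Z(x)
    return [w / Z for w in weights], Z

# Two candidate answers with equal prior mass: the higher-reward one
# receives more target probability, sharpened by beta.
target, Z = reward_weighted_target(rewards=[1.0, 0.0], ref_probs=[0.5, 0.5], beta=2.0)
```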
Reverse KL objective
- Writing out the reverse KL yields the per-prompt objective
\(J(\theta,\phi;x)\;=\;\mathbb{E}_{y\sim \pi_{\theta}(\cdot\mid x)}\!\Big[\log \pi_{\theta}(y\mid x)\;-\;\beta\,r(x,y)\;-\;\log \pi_{\mathrm{ref}}(y\mid x)\;+\;\log Z_{\phi}(x)\Big]\,,\) which is tractable since the expectation is under the current policy and \(\log Z_{\phi}(x)\) is a single scalar output.
- Intuitively, this objective pulls \(\pi_{\theta}\) toward sequences that have high reward and are likely under \(\pi_{\mathrm{ref}}\), with \(\beta\) controlling sharpness and \(\log Z_{\phi}(x)\) adjusting the overall scale to normalize the target.
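The per-prompt objective above is just an average over rollouts of a per-trajectory term, as this sketch shows (function and variable names are illustrative). A useful sanity check: when \(\pi_{\theta}\) already equals the target, each term is exactly zero.

```python
def reverse_kl_objective(logp_cur, rewards, logp_ref, log_Z, beta):
    """Monte Carlo estimate of J(theta, phi; x): the average over rollouts
    y ~ pi_theta of  log pi_theta(y|x) - beta*r(x,y) - log pi_ref(y|x) + log Z_phi(x).
    Inputs are per-rollout lists of sequence-level quantities plus the scalar log_Z."""
    terms = [lp - beta * r - lr + log_Z
             for lp, r, lr in zip(logp_cur, rewards, logp_ref)]
    return sum(terms) / len(terms)

# Construct rollouts whose log-probs match the target exactly:
# log pi_theta = beta*r + log pi_ref - log Z, so every term cancels to zero.
beta, log_Z = 1.0, 0.7
logp_ref = [-2.0, -3.0]
rewards = [1.5, 0.5]
logp_cur = [beta * r + lr - log_Z for r, lr in zip(rewards, logp_ref)]
J = reverse_kl_objective(logp_cur, rewards, logp_ref, log_Z, beta)
```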
Gradient of the reverse KL
- Using the score-function identity \(\nabla_{\theta}\mathbb{E}_{\pi_{\theta}}[f]=\mathbb{E}_{\pi_{\theta}}[\nabla_{\theta}\log \pi_{\theta}\,f+\nabla_{\theta}f]\), let \(f(y,\theta)=\log \pi_{\theta}(y\mid x)-\beta\,r(x,y)-\log \pi_{\mathrm{ref}}(y\mid x)+\log Z_{\phi}(x)\) to obtain
\(\nabla_{\theta}J=\mathbb{E}_{\pi_{\theta}}\!\big[\nabla_{\theta}\log \pi_{\theta}(y\mid x)\,\big(f(y,\theta)+1\big)\big]\,,\) since only \(\log \pi_{\theta}\) depends on \(\theta\) inside \(f\).
- Using the score property \(\mathbb{E}_{\pi_{\theta}}[\nabla_{\theta}\log \pi_{\theta}]=0\), the “+1” term does not affect stationarity, yielding the core stationary condition \(\mathbb{E}_{\pi_{\theta}}[\nabla_{\theta}\log \pi_{\theta}(y\mid x)\,\delta(x,y)]=0\) with residual \(\delta=\log Z_{\phi}+\log \pi_{\theta}-\beta r-\log \pi_{\mathrm{ref}}\).
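The score property that makes the “+1” term drop out can be checked exactly for a small softmax policy, where the gradient of \(\log\pi(y)\) with respect to the logits has the closed form \(\mathbf{1}[i=y]-\pi_i\). A minimal sketch (helper names are illustrative):

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]

def expected_score(logits):
    """E_{y ~ pi}[grad_logits log pi(y)] for a softmax policy. Since
    d/d logit_i log pi(y) = 1[i == y] - pi_i, the expectation over y is
    pi - pi = 0 componentwise, which is why E[grad log pi * 1] vanishes."""
    p = softmax(logits)
    n = len(logits)
    return [sum(p[y] * ((1.0 if i == y else 0.0) - p[i]) for y in range(n))
            for i in range(n)]

score = expected_score([0.3, -1.2, 0.8])  # componentwise ~0 up to float error
```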
From reverse KL to trajectory balance
- FlowRL shows that driving \(\mathbb{E}_{\pi_{\theta}}[\nabla_{\theta}\log \pi_{\theta}\,\delta]=0\) is equivalent in expected gradients to minimizing the squared trajectory-balance residual \(\delta(x,y)^2\), producing a simple regression-style loss.
- This is the autoregressive-tree specialization of GFlowNet trajectory balance, where minimizing the squared residual ensures the induced sampling distribution is proportional to the terminal reward (here \(\exp(\beta r)\pi_{\mathrm{ref}}\)) once the partition function is learned, providing mode-covering behavior.
Sequence-level TB loss
- The unnormalized TB residual is
\(\delta(x,y)\;=\;\log Z_{\phi}(x)\;+\;\log \pi_{\theta}(y\mid x)\;-\;\beta\,r(x,y)\;-\;\log \pi_{\mathrm{ref}}(y\mid x)\,,\) and the basic loss is \(L_{\mathrm{TB}}=\delta(x,y)^2\) per trajectory.
- Minimizing \(L_{\mathrm{TB}}\) over \((\theta,\phi)\) matches \(\pi_{\theta}\) to the reward-weighted target distribution and regresses \(\log Z_{\phi}(x)\) to the true \(\log Z(x)\) at optimum, recovering the reverse-KL solution with a squared surrogate.
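The residual is a one-line function of four scalars, and it vanishes exactly when the policy log-probability matches the target decomposition \(\log\pi_{\theta}=\beta r+\log\pi_{\mathrm{ref}}-\log Z\). A minimal sketch (names are illustrative):

```python
def tb_residual(log_Z, logp_cur, reward, logp_ref, beta):
    """Unnormalized trajectory-balance residual delta(x, y); the basic
    per-trajectory loss is simply delta**2."""
    return log_Z + logp_cur - beta * reward - logp_ref

# At the optimum, log pi_theta(y|x) = beta*r - log Z + log pi_ref(y|x),
# so the residual (and hence the squared loss) is zero.
beta, log_Z, reward, logp_ref = 2.0, 1.3, 0.8, -4.0
logp_opt = beta * reward - log_Z + logp_ref
delta = tb_residual(log_Z, logp_opt, reward, logp_ref, beta)
```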
Length normalization
- For long chain-of-thought, \(\log \pi_{\theta}(y\mid x)=\sum_{t=1}^{T}\log \pi_{\theta}(y_t\mid y_{<t},x)\) grows with \(T\), inflating residuals and gradients, so FlowRL replaces sequence log-likelihoods by their averages \(\frac{1}{T}\log \pi_{\theta}(y\mid x)\) and \(\frac{1}{T}\log \pi_{\mathrm{ref}}(y\mid x)\) inside \(\delta\).
- The length-normalized residual is
\(\delta_{\text{norm}}=\log Z_{\phi}(x)+\frac{1}{T}\log \pi_{\theta}(y\mid x)-\beta\,\hat r(x,y)-\frac{1}{T}\log \pi_{\mathrm{ref}}(y\mid x)\,,\) where \(\hat r\) is a stabilized reward (see below), keeping gradient magnitudes comparable across response lengths.
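The effect of length normalization can be seen with a stylized rollout in which every token contributes a constant log-probability (the per-token values below are made up for illustration): the raw residual grows linearly in \(T\), while the normalized one is length-invariant.

```python
def residual(log_Z, logp_cur, reward, logp_ref, beta):
    """delta = log Z + logp_cur - beta*r - logp_ref, for either raw sequence
    log-likelihoods or their per-token averages."""
    return log_Z + logp_cur - beta * reward - logp_ref

log_Z, beta, r_hat = 0.5, 1.0, 0.3
per_tok_cur, per_tok_ref = -2.0, -2.5  # constant per-token log-probs (stylized)

# Raw residual uses T * per-token log-probs; normalized uses the averages.
raw = {T: residual(log_Z, T * per_tok_cur, r_hat, T * per_tok_ref, beta)
       for T in (100, 1000)}
norm = {T: residual(log_Z, per_tok_cur, r_hat, per_tok_ref, beta)
        for T in (100, 1000)}
```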
Reward standardization
- Rewards are group-normalized per prompt by sampling \(G\) rollouts \(\{y_i\}_{i=1}^{G}\) and setting \(\hat r_i=\frac{r_i-\mu}{\sigma+\epsilon}\) with \(\mu=\frac{1}{G}\sum_i r_i\) and \(\sigma^2=\frac{1}{G}\sum_i (r_i-\mu)^2\), which stabilizes scale across prompts and batches.
- This standardization retains the ranking within a prompt group while mitigating scale drift, making a single \(\beta\) effective across diverse prompts during sequence-level optimization.
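The group standardization is a few lines; note the population standard deviation (dividing by \(G\), not \(G-1\)) to match the formula above. A minimal sketch (the function name is illustrative):

```python
def standardize_rewards(rewards, eps=1e-6):
    """Group-normalize one prompt's G rollout rewards:
    r_hat_i = (r_i - mu) / (sigma + eps), with mu, sigma the group mean and
    population standard deviation."""
    G = len(rewards)
    mu = sum(rewards) / G
    sigma = (sum((r - mu) ** 2 for r in rewards) / G) ** 0.5
    return [(r - mu) / (sigma + eps) for r in rewards]

# Binary verifier rewards for G = 4 rollouts: standardized values are
# approximately +/-1 and keep the within-group ranking.
r_hat = standardize_rewards([0.0, 1.0, 1.0, 0.0])
```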
Off-policy correction
- To reuse trajectories from a behavior policy \(\pi_{\text{old}}\), expectations under \(\pi_{\theta}\) are reweighted by the sequence-level importance ratio \(\rho(y)=\frac{\pi_{\theta}(y\mid x)}{\pi_{\text{old}}(y\mid x)}=\exp\big(\log \pi_{\theta}(y\mid x)-\log \pi_{\text{old}}(y\mid x)\big)\).
- FlowRL uses a PPO-style clipped, detached weight \(w=\mathrm{clip}(\rho,\,1-\epsilon,\,1+\epsilon)^{\text{detach}}\) to control variance and avoid unstable feedback through \(\rho\) inside a squared residual, yielding the final per-trajectory loss \(L=w\,\delta_{\text{norm}}^2\).
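The weight itself is a scalar function of the two sequence log-likelihoods; the detach is an autograd concern and is noted in the comment. A minimal sketch (names are illustrative):

```python
import math

def clipped_weight(logp_cur_seq, logp_old_seq, eps=0.2):
    """Sequence-level importance ratio rho = pi_theta(y|x) / pi_old(y|x),
    clipped to [1 - eps, 1 + eps]. In an autograd framework the result would
    additionally be detached so no gradient flows back through the weight."""
    rho = math.exp(logp_cur_seq - logp_old_seq)
    return max(1.0 - eps, min(1.0 + eps, rho))

w_on_policy = clipped_weight(-10.0, -10.0)  # rho = 1, inside the clip range
w_drifted = clipped_weight(-8.0, -10.0)     # rho = e^2, clipped to 1 + eps
```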
Final FlowRL loss
- Putting the pieces together, the train-time objective is
\(L_{\text{FlowRL}}\;=\;\mathrm{clip}\!\left(\frac{\pi_{\theta}(y\mid x)}{\pi_{\text{old}}(y\mid x)},\,1-\epsilon,\,1+\epsilon\right)^{\text{detach}}\cdot\Big(\log Z_{\phi}(x)+\tfrac{1}{T}\log \pi_{\theta}(y\mid x)-\beta\,\hat r(x,y)-\tfrac{1}{T}\log \pi_{\mathrm{ref}}(y\mid x)\Big)^{2}\,,\) optimized over \(\theta\) and \(\phi\) jointly.
- This loss is sequence-level, requires no critic, and integrates KL-to-reference directly inside the TB residual, differing fundamentally from token-level PPO/GRPO surrogates that optimize expected advantage.
Gradients in detail
- Let \(\delta=\delta_{\text{norm}}\) and \(w\) be the detached clipped ratio; the gradients are \(\nabla_{\theta}L=2w\,\delta\,\nabla_{\theta}\big(\tfrac{1}{T}\log \pi_{\theta}(y\mid x)\big)\) and \(\nabla_{\phi}L=2w\,\delta\,\nabla_{\phi}\log Z_{\phi}(x)\).
- Expanding the policy term gives \(\nabla_{\theta}\big(\tfrac{1}{T}\log \pi_{\theta}(y\mid x)\big)=\tfrac{1}{T}\sum_{t=1}^{T}\nabla_{\theta}\log \pi_{\theta}(y_t\mid y_{<t},x)\), i.e., a normalized sum of token-level score gradients.
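The decomposition of the sequence gradient into an average of token-level score gradients follows from linearity and can be verified with autograd on a tiny stand-in "policy" (one shared categorical distribution per token, a toy model rather than an LLM):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(5, requires_grad=True)  # shared token distribution
tokens = torch.tensor([1, 3, 0, 4])          # a length-4 "sequence"
T = len(tokens)

# Gradient of the length-normalized sequence log-likelihood in one shot.
seq_obj = torch.log_softmax(logits, dim=0)[tokens].sum() / T
(g_seq,) = torch.autograd.grad(seq_obj, logits)

# The same gradient rebuilt as the average of per-token score gradients.
g_tok = torch.zeros(5)
for t in range(T):
    log_p_t = torch.log_softmax(logits, dim=0)[tokens[t]]
    (g,) = torch.autograd.grad(log_p_t, logits)
    g_tok += g / T
```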
Why squared TB matches reverse KL
- At reverse-KL stationarity, \(\mathbb{E}_{\pi_{\theta}}[\nabla_{\theta}\log \pi_{\theta}\,\delta]=0\), and minimizing \(\mathbb{E}[\delta^2]\) drives \(\delta\) toward zero while sharing the same stationary points under regularity, providing a stable regression proxy for the KL problem.
- This mirrors GFlowNet trajectory balance guarantees: when the TB residual is globally minimized in an autoregressive tree, the induced trajectory distribution matches the desired reward distribution, enabling proportional sampling across multiple high-reward modes.
Relation to GFlowNets
- GFlowNets view sampling as learning flows on a DAG such that terminal-state probability is proportional to reward, with trajectory balance enforcing a log-space equality between source flow, path log-probability, and terminal reward.
- FlowRL specializes TB to sequence generation, adds a reference prior and length normalization, and couples it with off-policy reuse for LLM post-training, maintaining the core mode-covering property of GFlowNets for reasoning trajectories.
Toy example
- Consider one prompt \(x\) with two answers \(y_1,y_2\) and rewards \(r_1,r_2\), and a reference \(\pi_{\mathrm{ref}}(y_i\mid x)\); the target is \(\tilde{\pi}(y_i\mid x)\propto \exp(\beta r_i)\pi_{\mathrm{ref}}(y_i\mid x)\).
- Minimizing TB gives \(\pi_{\theta}(y_i\mid x)=\frac{\exp(\beta r_i)\pi_{\mathrm{ref}}(y_i\mid x)}{\sum_j \exp(\beta r_j)\pi_{\mathrm{ref}}(y_j\mid x)}\) at optimum, directly illustrating distribution matching and the role of the learned partition \(Z_{\phi}(x)\).
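The toy example's closed form is easy to verify numerically: the optimal policy is the softmax of \(\beta r_i+\log\pi_{\mathrm{ref}}(y_i\mid x)\), the learned normalizer recovers \(\log Z(x)\), and every TB residual vanishes. A minimal sketch (the helper name `tb_optimum` is illustrative):

```python
import math

def tb_optimum(rewards, ref_probs, beta):
    """Closed-form solution of the enumerable toy problem: pi_theta is the
    normalized reward-weighted prior, log_Z is the true log partition, and
    every trajectory-balance residual is zero at this point."""
    weights = [math.exp(beta * r) * p for r, p in zip(rewards, ref_probs)]
    Z = sum(weights)
    pi = [w / Z for w in weights]
    log_Z = math.log(Z)
    deltas = [log_Z + math.log(pi_i) - beta * r - math.log(p)
              for pi_i, r, p in zip(pi, rewards, ref_probs)]
    return pi, log_Z, deltas

pi, log_Z, deltas = tb_optimum(rewards=[1.0, 0.0], ref_probs=[0.7, 0.3], beta=1.0)
```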
Implementation recipe
- For each prompt, sample \(G\) rollouts from \(\pi_{\text{old}}\), compute \(r\), group-normalize to \(\hat r\), and evaluate \(\log \pi_{\theta}\), \(\log \pi_{\mathrm{old}}\), and \(\log \pi_{\mathrm{ref}}\), keeping token sums to form sequence-level log-probabilities and lengths \(T\).
- Form \(w=\mathrm{clip}(\exp(\log \pi_{\theta}-\log \pi_{\text{old}}),1-\epsilon,1+\epsilon)^{\text{detach}}\), compute \(\delta_{\text{norm}}\), and minimize \(w\,\delta_{\text{norm}}^2\) while updating \(\theta,\phi\), periodically refreshing \(\pi_{\text{old}}\) and using multi-sample evaluation (e.g., 16 rollouts) for reasoning metrics.
Practical stabilizers, justified
- Length normalization makes the residual scale roughly invariant to response length, avoiding gradient explosions on 8k–16k token CoT and improving optimization conditioning without changing the essential stationary solution class.
- Clipped, detached importance weights enable off-policy reuse with controlled variance and avoid destabilizing gradients through the ratio inside a squared loss, with ablations showing substantial degradation if this component is removed.
Contrast with PPO/GRPO
- PPO optimizes a clipped advantage surrogate with an explicit KL penalty and a value critic, operating token-wise and tending to be mode-seeking under sharp reward landscapes.
- GRPO removes the critic via group-normalized advantages but still maximizes expected return, whereas FlowRL’s sequence-level TB loss matches the whole reward-induced distribution and empirically preserves multiple reasoning modes.
Minimal PyTorch-style loss
```python
import torch

# Given: x, y, T, logp_cur_seq, logp_old_seq, logp_ref_seq, reward
# Hyperparams: beta, eps
# Models: policy (for logp_cur_seq), partition_net Z_phi(x)
with torch.no_grad():
    # Population std (unbiased=False) to match sigma^2 = (1/G) sum_i (r_i - mu)^2
    r_hat = (reward - reward.mean(dim=1, keepdim=True)) / \
            (reward.std(dim=1, keepdim=True, unbiased=False) + 1e-6)

logp_cur_norm = logp_cur_seq / T
logp_ref_norm = logp_ref_seq / T

with torch.no_grad():
    # Detached, clipped sequence-level importance weight
    rho = torch.exp(logp_cur_seq - logp_old_seq)
    w = torch.clamp(rho, 1 - eps, 1 + eps)

logZ = partition_net(x)  # scalar per prompt
delta = logZ + logp_cur_norm - beta * r_hat - logp_ref_norm
loss = (w * (delta ** 2)).mean()
loss.backward()
```
This implements length-normalized, sequence-level trajectory balance with a learned partition function and detached clipped importance weighting as specified in FlowRL.
Takeaways
- The FlowRL objective is a squared sequence-level trajectory-balance residual that is equivalent (in expected gradients) to reverse-KL distribution matching against \(\exp(\beta r)\pi_{\mathrm{ref}}\), giving a simple, critic-free route to reward-proportional sampling.
- The three crucial engineering choices—length normalization, per-prompt reward standardization, and detached clipped importance weights—stabilize long-CoT optimization and enable efficient off-policy training in LLM post-training pipelines.
References: FlowRL arXiv for derivations, implementation details, and ablations; Trajectory Balance (GFlowNet) for the theoretical guarantee behind proportional sampling via TB.