Fast Best-of-N Decoding via Speculative Rejection



Hanshi Sun1*, Momin Haider2*, Ruiqi Zhang3*, Huitao Yang5, Jiahao Qiu4,
Ming Yin4, Mengdi Wang4, Peter Bartlett3, Andrea Zanette1*
1Carnegie Mellon University 2University of Virginia 3UC Berkeley
4Princeton University 5Fudan University

*Core Contributors

  Introduction

The safe and effective deployment of LLMs involves a critical step called alignment, which ensures that the model's responses accord with human preferences. Techniques such as DPO and PPO, and their variants, align LLMs by changing the pre-trained model weights during a phase called post-training. While predominant, these post-training methods add substantial complexity before LLMs can be deployed. Inference-time alignment methods avoid the complex post-training step and instead bias the generation towards responses that are aligned with human preferences. The best-known inference-time alignment method, called Best-of-N, is as effective as state-of-the-art post-training procedures. Unfortunately, Best-of-N requires vastly more resources at inference time than standard decoding strategies, which makes it computationally infeasible. We introduce Speculative Rejection, a computationally viable inference-time alignment algorithm. It generates high-scoring responses according to a given reward model, as Best-of-N does, while being between 16 and 32 times more computationally efficient.
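To make the baseline concrete, here is a minimal sketch of vanilla Best-of-N decoding with Hugging Face transformers. The model name and the reward_model_score helper are illustrative assumptions, not our implementation: Best-of-N simply samples N complete responses and returns the one the reward model scores highest.

# Minimal Best-of-N sketch (illustrative; not the paper's implementation).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def best_of_n(prompt: str, n: int = 120, max_new_tokens: int = 256) -> str:
    tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    llm = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16, device_map="auto"
    )
    inputs = tok(prompt, return_tensors="pt").to(llm.device)

    # Sample N complete responses independently for the same prompt.
    outputs = llm.generate(
        **inputs, do_sample=True, num_return_sequences=n, max_new_tokens=max_new_tokens
    )
    responses = tok.batch_decode(outputs, skip_special_tokens=True)

    # Score every complete response with a reward model and keep the best one.
    # reward_model_score is a hypothetical helper standing in for any reward model.
    scores = [reward_model_score(prompt, r) for r in responses]
    return responses[scores.index(max(scores))]

The cost of this procedure scales linearly with N, which is what makes large-N Best-of-N impractical at inference time and motivates Speculative Rejection.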

Evaluation on AlpacaFarm-Eval

We evaluate the effectiveness of Speculative Rejection on the AlpacaFarm-Eval dataset using various generative models and reward models. The numbers indicate N for Best-of-N and rejection rate α for Speculative Rejection. Our method consistently achieves higher reward scores with fewer computational resources compared to Best-of-N.

 Speculative Rejection

Speculative Rejection is based on the observation that the reward function used for scoring the utterances can distinguish high-quality responses from low-quality ones at an early stage of the generation. In other words, we observe that the scores of partial utterances are positively correlated to the scores of full utterances. As illustrated in the figure, this insight enables us to identify, during generation, utterances that are unlikely to achieve high scores upon completion, allowing us to halt their generation early.
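As a rough illustration of how this correlation can be checked, the sketch below truncates each full response at an early cutoff, scores both the partial and the complete utterance with the reward model, and reports their Pearson correlation. The reward_model_score helper and the cutoff length are assumptions made for illustration.

# Sketch: measure how well partial-utterance rewards track final rewards.
import numpy as np

def partial_final_correlation(prompt, full_responses, tokenizer, cutoff_tokens=128):
    partial_scores, final_scores = [], []
    for resp in full_responses:
        ids = tokenizer(resp, add_special_tokens=False)["input_ids"]
        partial = tokenizer.decode(ids[:cutoff_tokens])  # early, truncated utterance
        partial_scores.append(reward_model_score(prompt, partial))  # hypothetical scorer
        final_scores.append(reward_model_score(prompt, resp))
    # A strongly positive correlation justifies rejecting low-scoring partial generations early.
    return np.corrcoef(partial_scores, final_scores)[0, 1]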

Speculative Rejection System

Speculative Rejection begins with a very large batch size, effectively simulating the initial phases of Best-of-N with a large N (e.g., 5000) on a single accelerator. This increases the likelihood that the initial batch contains several generations that lead to high-quality responses once fully generated. However, such a large batch size would eventually exhaust the GPU memory during the later stages of auto-regressive generation. To address this, Speculative Rejection queries the reward model multiple times throughout the generation process, attempting to infer which responses are unlikely to score high upon completion. Our method dynamically reduces the batch size, preventing memory exhaustion while ensuring that only the most promising responses are fully generated.
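The sketch below illustrates this generate-score-reject loop, assuming hypothetical generate_chunk and reward_model_score helpers; the actual implementation batches generation on the accelerator and handles sequences that finish early, which this sketch omits.

# Illustrative sketch of Speculative Rejection's batched generate-score-reject loop.
def speculative_rejection(prompt, n_initial=5000, alpha=0.5,
                          chunk_tokens=128, max_tokens=1024):
    # Start like Best-of-N with a very large batch of candidate generations.
    candidates = [prompt] * n_initial
    generated = 0
    while generated < max_tokens and len(candidates) > 1:
        # Extend every surviving candidate by a chunk of tokens (hypothetical helper).
        candidates = [generate_chunk(c, chunk_tokens) for c in candidates]
        generated += chunk_tokens
        # Score the partial utterances and keep only the top (1 - alpha) fraction,
        # shrinking the batch before GPU memory is exhausted.
        scores = [reward_model_score(prompt, c) for c in candidates]
        keep = max(1, int(len(candidates) * (1 - alpha)))
        ranked = sorted(zip(scores, candidates), key=lambda p: p[0], reverse=True)
        candidates = [c for _, c in ranked[:keep]]
    # Complete the remaining candidates and return the highest-scoring response.
    finished = [generate_chunk(c, max_tokens - generated) for c in candidates]
    return max(finished, key=lambda r: reward_model_score(prompt, r))

Because each rejection round discards roughly an α fraction of the surviving candidates, memory usage shrinks geometrically as the sequences grow longer.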

  Win-rate Evaluation by GPT-4-Turbo

To further validate the generation quality, we evaluate both the win-rate and the length-controlled (LC) win-rate using GPT-4-Turbo via AlpacaEval. For each measurement, the win-rate baseline is Bo120. As shown in the table, Speculative Rejection maintains generation quality while achieving a notable speedup across various settings for the Mistral-7B, Llama-3-8B, and Llama-3-8B-Instruct models, scored by the reward model ArmoRM-Llama-3-8B and evaluated with GPT-4-Turbo. "WR" denotes win-rate and "LC-WR" denotes length-controlled win-rate.


Methods        | Mistral-7B     | Llama-3-8B     | Llama-3-8B-Instruct | Average
               | WR      LC-WR  | WR      LC-WR  | WR      LC-WR       | WR      LC-WR
Bo120          | 50.00   50.00  | 50.00   50.00  | 50.00   50.00       | 50.00   50.00
Bo240          | 60.69   60.07  | 50.45   50.27  | 49.92   52.89       | 53.69   54.41
Bo480          | 61.28   61.84  | 58.90   59.93  | 50.49   53.11       | 56.89   58.29
Bo960          | 67.50   68.07  | 59.20   60.26  | 50.39   51.64       | 59.03   59.99
Bo1920         | 75.20   76.27  | 60.57   61.05  | 51.86   53.13       | 62.54   63.48
Bo3840         | 76.13   77.21  | 59.19   57.91  | 53.36   54.01       | 62.89   63.04
Ours (α=0.5)   | 69.42   73.31  | 73.60   77.91  | 55.50   58.80       | 66.17   70.01

  Conclusion and Future Work

Speculative Rejection is a general-purpose technique to accelerate reward-oriented decoding from LLMs. The procedure is simple to implement and yields substantial speedups over the baseline Best-of-N. We now discuss its limitations and some promising avenues for future research.

Prompt-dependent Stopping. Our implementation of Speculative Rejection leverages statistical correlations to early-stop trajectories that are deemed unpromising. However, it is reasonable to expect that the correlation between partial and final rewards varies from prompt to prompt. For a target level of normalized score, early stopping can be more aggressive for some prompts and less so for others. This suggests that setting the rejection rate adaptively could achieve a higher speedup and normalized score across prompts. We leave this opportunity for future research.
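One hypothetical way to realize this, reusing the partial_final_correlation helper sketched above, is to estimate the partial-final reward correlation on a small pilot batch and map it to a prompt-specific rejection rate; all names and thresholds below are illustrative assumptions.

# Hypothetical prompt-dependent rejection rate: reject more aggressively when early
# scores are trustworthy (i.e., partial and final rewards correlate strongly).
def adaptive_alpha(prompt, pilot_responses, tokenizer, alpha_min=0.3, alpha_max=0.8):
    corr = partial_final_correlation(prompt, pilot_responses, tokenizer)
    corr = max(0.0, min(1.0, corr))  # clamp to [0, 1]
    return alpha_min + corr * (alpha_max - alpha_min)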

Reward Models as Value Functions. Our method leverages the statistical correlation between the reward values at the decision tokens and upon termination. Concurrently, recent literature also suggests training reward models as value functions. Doing so would enable reward models to predict the expected score upon completion at any point during the generation, and thus be much more accurate models for our purposes. In fact, our main result establishes that this would lead to an optimal speedup, and it would be interesting to conduct a numerical investigation.

BibTeX

@article{sun2024fast,
    title={Fast Best-of-N Decoding via Speculative Rejection},
    author={Sun, Hanshi and Haider, Momin and Zhang, Ruiqi and Yang, Huitao and Qiu, Jiahao and Yin, Ming and Wang, Mengdi and Bartlett, Peter and Zanette, Andrea},
    journal={arXiv preprint arXiv:2410.20290},
    year={2024}
}