Speculative Decoding:投机解码
投机解码(Speculative Decoding)也叫预测解码/投机采样,它会利用小模型来预测大型模型的行为,从而提升模型在解码(decoding)阶段的解码效率问题,加速大型模型的执行。其核心思路如下图所示,首先以低成本的方式(以小模型为主,也有多头,检索,Early Exit 等方式)快速生成多个候选 Token(串行序列、树、多头树等),然后通过一次并行验证阶段快速验证多个 Token 的正确性,只要平均每个 Step 验证的 Token 数 > 1,就可以一次性生成多个 token,进而减少总的 Decoding 步数,实现加速的目的。
下图左侧是自回归解码模型,右侧是投机解码机制。
从本质上来说,投机解码希望在推理阶段在不大幅度改变模型的情况下,通过更好利用冗余算力来并行"投机"地猜测出模型接下来要输出的token。作为对比,也有一种方案是通过路由的方式组合多个不同规模和性能的模型。路由方式在调用之前已经确定好需要调用哪个模型,直到调用结束。而投机解码在一个 Query 内会反复调用大小模型。
1.1 问题
我们都知道,生成式 LLM 大部分是 Decoder-only 结构,其一方面模型比较大,推理时占用的存储空间、所需的计算量都比较大,另一方面,大模型解码时是一个 Token 一个 Token 串行生成,在 batch size 为 1 时,Transformer block 中的矩阵乘都退化为矩阵乘向量操作,对于 GPU 推理来说,这是非常明显的 IO bound,导致无法充分发挥 GPU 算力。
1.2 自回归解码
当前的主流 LLM 基本都是 Decoder Only 的 Transformer 模型,其推理阶段采用自回归采样,特点如下:
- 模型使用前缀作为输入,将输出结果处理+归一化成概率分布后,采样生成下一个token。
- 从生成第一个 Token之后,开始采用自回归方式一次生成一个 Token,即当前轮输出token 与历史输入 tokens 拼接,作为下一轮的输入 tokens,然后解码。
- 重复执行2。在后续执行过程中,前后两轮的输入只相差一个 token。
- 直到生成一个特殊的 Stop Token(或者满足用户的某个条件,比如超过特定长度) 才会结束。
自回归解码对应的算法如下图所示。
自回归采样的缺点如下:
- 因为在生成文本时,自回归采样是逐个 token 生成的,生成下一个 token 需要依赖前面已经生成的 token,这种串行的模式导致生成速度慢,效率很低。具体参见下图。假设输出总共有 N 个 Token,则 Decoding 阶段需要执行 N-1 次 Forward,这 N-1 次 Forward 只能串行执行。
- 在生成过程中,需要关注的 Token 越来越多(每个 Token 的生成都需要和之前的 Token 进行注意力计算),计算量也会随之增大。
- 大型模型的推理过程往往受制于访存速度。因为推理下一个 token 的时候,需要依赖前面的结果。所以在实际使用 GPU 进行计算时,需要将所有模型参数以及 kv-cache 移至片上内存进行运算,而一般来说片上内存带宽比计算性能要低两个数量级,这就使得大模型推理是 memory-bandwidth-bound 的,内存访问带宽成为严重的瓶颈。
另外,大模型的能力遵循 scaling law,也就是模型的参数越多其拥有的能力越强,而越大的模型自然就需要越多的计算资源。scaling law 告诉我们,我们没有办法通过直接减小模型的参数量来减小访存的访问量。
为了解决推理速度慢的问题,研究人员已经进行了许多针对推理的工程优化,例如:
- 改进的计算核心实现、多卡并行计算、批处理策略等等。其中,最朴素的做法就是增大推理时的 Batch size,比如使用 dynamic batching,将多个请求合并处理,将矩阵乘向量重新变为矩阵乘操作,在 Batch size 不大的情况下,几乎可以获得 QPS 的线性提升。然而,这些方法并没有从根本上解决LLM解码过程是受制于访存带宽的问题。
- 对模型以及 KV Cache 进行量化,使每一个 token 生成过程中读取模型参数时的总比特数减小,缓解 io 压力。
- increasing the arithmetic intensity,即提高“浮点数计算量/数据传输量”这个比值,让数据传输不要成为瓶颈。
- reducing the number of decoding steps,即缩短解码步骤。投机解码就属于这个范畴。
0x02 定义 & 历史
2.1 投机解码
投机解码(Speculative Decoding)允许我们将在同一个用户请求内的多个 Token 一起运算。其目的和 dynamic batching 类似,也是为了将矩阵乘向量重新变为矩阵乘操作,这很适合无法获得更大 Batch size 或者只想降低端到端延时的场景。投机解码一般使用两个模型:Draft Model 快速生成多个候选结果,然后 Target Model 并行验证和修改,最终得到满意答案。具体而言:
- draft model 用来猜测。draft model 推理较快,承担了串行的工作,它以自回归的方式生成 K 个 tokens,从而让目标模型能够并行的计算。
- target model 用来评估采样结果\审核修正。target model 通过并行计算多个 token 来从自回归模型中采样,用推理结果来决定是否使用 draft model 生成的这些 tokens。
投机解码的算法如下图所示。
投机解码无需对输出进行任何更改,就可以保证和使用原始模型的采样分布完全相同,因此和直接用大模型解码是等价的。下图右侧,草稿模型先生成5个预测token后,将5个token一起输入给目标模型。以该前缀作为输入时,目标模型会生成若干token,然后进行验证。绿色表示草稿模型生成的token和目标模型生成的token一致,预测token通过了“验证”——这个token本来就是LLM自己会生成的结果。红色token是没有通过验证的“推测”token。第一个没有通过验证的“推测”token和其后续的“推测”token都将被丢弃。因为这个红色token不是LLM自己会生成的结果,那么前缀正确性假设就被打破,这些后续token的验证都无法保证前缀输入是“正确”的了。
2.2 发展历史
下面给出了投机解码的发展历史。
其中有两篇文章需要特殊提一下,两篇文章都算是投机解码的开山之作,其中公案我们也难以说清。
Speculative Decoding
论文“Speculative Decoding: Exploiting Speculative Execution for Accelerating Seq2seq Generation”是第一篇提出 Speculative Decoding 这个词的文章,也确立了使用 draft-then-verify 这一方法加速 Auto-Regressive 生成的范式。
Speculative Decoding 希望解决的是现有的 Autoregressive 模型推理过慢的问题。下图(a)是Blockwise Decoding,其在目标自回归模型上引入了k − 1个FFN头,这些头使用共享注意力(shared attention)来预测下面k个tokken。(b)是Spec-Drafter模型,该模型是预测草稿token的独立模型,它使用不同的query来预测每个草稿token。下图上黄色部分是自回归AR模型,红色部分是新加入的模块。
Speculative Sampling
论文“Fast Inference from Transformers via Speculative Decoding”最早提出了 Speculative Sampling。此文章和上一篇文章是同时期的研究,被认为是SD的开山之作,后续许多研究都是基于此来展开。本文用 target model(目标模型)指代待加速的大模型,用 approximation model(近似模型)指代用来帮助加速大模型的小模型。
后续我们统一使用speculative decoding这个术语。
接下来,我们先对本领域的先驱之作"Blockwise Parallel Decoding"做简要分析,然后再结合两篇开山之作进行学习。
0x03 Blockwise Parallel Decoding
论文“Blockwise Parallel Decoding for Deep Autoregressive Models”提出的 Blockwise Parallel Decoding 是本领域的先行之作,或者说并行解码的第一个工作,所以我们仔细学习下,有助于我们理解后续脉络。Blockwise Parallel Decoding 使用多头的方式生成候选序列(一个串行序列),然后进行并行验证。
0x04 原理
看完了BPD这个基础之作,我们再来看看投机解码。
4.1 动机
投机解码的动机来自几点观察和一个借鉴。
4.1.1 观察
我们首先看看几点关键观察结果:
- 困难任务包含容易子任务。在困难的语言建模任务中,通常包含了一些相对容易的子任务,比如,预测有些 token 时,softmax 输出的概率分布会集中在某些 token 上,这说明模型有较大的置信度确定下一个输出的 token。这意味着不是所有的解码步骤都同样困难,如果我们用小模型去回答这些简单的问题,在遇到难题的情况下再调用大模型,就可以提高整体的生成效率。即,大多数容易生成的 tokens 其实用更少参数的模型也可以生成。
- 内存带宽和通信是大模型推理的瓶颈。对于 LLM 推理来说,通常瓶颈不是数学计算,而是内存带宽及通信量、通讯速度。LLM 每个解码步所用的推理时间大部分并不是用于模型的前向计算,而是消耗在了将 LLM 巨量的参数从 GPU 显存(High-Bandwidth Memory,HBM)迁移到高速缓存(cache)上(以进行运算操作)。这意味着在某些情况下,适当增加计算量并不会影响推理速度,可以用于提高并发性。
- 大模型在做推理任务(decoding 阶段)时,往往 batch size 为1,一次只能生成一个 token,无法并行计算,导致大量算力冗余。事实上,在数量增加有限的情况下,输入多个 tokens 和输入一个 token 单轮的计算时延基本一致。如果我们能让大模型一次处理一批 tokens,就能利用上算力,让大模型达到计算和访存平衡。
4.1.2 借鉴
“Speculative execution”(猜测性执行)是一种在处理器(CPU)中常见的优化技术。
它的基本思想是在不确定某个任务是否真正需要执行时,提前执行该任务,然后再来验证被执行任务是否真的被需要,这样做的好处可以增加并发性和性能,一个典型的例子是分支预测 branch prediction。在处理器中,“speculative execution” 通常用于处理分支(branch)指令。当处理器遇到一个分支指令时,它不知道分支条件的具体结果,因此会选择一条路径来执行。如果分支条件最终符合预期,那么一切正常,程序将继续执行。但如果条件不符合,处理器会回滚到分支前的状态,丢弃之前的操作,然后选择正确的路径进行执行。
4.2 思路
上文提到,投机解码最早在两篇论文中被提出。基于上述的观察结果和Speculative execution的机制,在解码自回归模型方面,两篇论文的作者将"speculative execution"这一优化技术进行了推广,将其应用于自回归模型的解码过程中。
投机解码使用两个模型:一个是原始 target model(目标模型),另一个是比原始模型小得多的 draft model(近似模型/草稿模型)。draft model 和 target mode 联合推理,draft 模型生成 $\gamma$ 个 token,而 target 模型则去验证 $\gamma$ 个 token 是否为最后需要的 token。就是使用一个小模型来生成多个草稿 token,然后使用大模型对这多个草稿 token 做并行验证、纠正和优化。这样就可以在接近大参数模型的生成一个 token 的时间里面生成多个 tokens。
- 论文里的并行就是指大模型一次计算多个 token,节省下来传输损耗。即用大模型并行验证这一些 token 是否符合大模型的输出,其思路如下。
- 在一次前向传播中,同时验证多个 draft token。在第一个 draft token 与原始模型输出不相符的位置截断,并丢弃在此之后的所有 draft token。这就是"Speculative execution"中的丢弃。
- 利用 prefill 阶段比 decoding 阶段计算效率高的特点。大模型可以一次 prefill 输入几个小模型 decode 步结果来仲裁、提高推理速度。用大模型的 prefill 模式代替 decode 模式可以节约大模型的访存,以及充分利用 tensor core 来加速矩阵乘法。这不是一个纯算法或者纯硬件系统角度考虑问题的加速方案,而是一个同时从考虑算法以及硬件系统的解决方案。
- 然后,利用一种新颖的采样方法(speculative sampling)来最大化这些推测性任务被接受的概率。“speculative decoding"这种验证和重采样过程在理论上是等价于直接从目标 LLM 采样,因此,可以保证最终生成的文本分布与目标 LLM 一致。
总结下,“speculative decoding” 可以通过充分利用模型之间的复杂度差异,以及采用并行计算的方法,使得从大型自回归模型中进行推理变得更快速和高效。同时保持了与目标模型相同的输出分布(在实现对 target LLM 推理加速的同时,不损失 LLM 的解码质量),而无需更改模型架构、训练过程或输出。下图给出了执行流程。
4.3 对比
投机执行和投机解码对比如下。
| 类别 | 投机执行 | 投机解码 |
|---|---|---|
| 提前执行 | 遇到一个分支指令时,CPU不知道分支条件的具体结果,因此会选择一条路径来执行 | draft model串行推理,生成草稿token。相当于用draft model做逐个token的decoding |
| 验证 | 验证执行结果 | target model针对draft model的串行产生结果并行推理,做验证和优化。相当于用大模型一次prefill输入小模型的几个decode步结果来仲裁 |
| 验证成功 | 如果分支条件最终符合预期,那么一切正常,程序将继续执行 | 接受小模型产生的token |
| 验证失败 | 如果条件不符合,处理器会回滚到分支前的状态,丢弃之前的操作 | 在第一个 draft token 与target model输出不相符的位置截断,并丢弃在此之后的所有 draft token |
| 失败后修复 | 选择正确的路径进行执行 | 调整概率分布 |
投机解码和之前方法对比如下。
| 类别 | 之前方案 | 投机解码 |
|---|---|---|
| 是否改变模型架构 | 许多先前的方法需要修改模型的结构,以使推理过程更高效 | 不需要 |
| 是否改变训练程序 | 一些方法可能需要修改训练过程,以便模型在推理阶段能够更有效地运行 | 不需要修改训练过程,可在现有模型上直接应用 |
| 是否重新训练 | 先前的方法可能需要对模型进行重新训练,以适应新的架构或训练程序 | 不需要 |
| 是否改变输出分布 | 先前的方法在加速推理过程时可能会导致模型的输出分布发生变化 | 通过"speculative sampling"方法,保证了从模型中生成的结果具有与原始模型相同的分布 |
另外,块并行解码(blockwise parallel decodin)和推测解码之间的主要区别在于它们的模型使用。投机解码需要额外的小模型来自回归地生成speculative tokens。这些小型模型受到约束,比目标模型更有效,因此加速可以覆盖它们的成本。
总的来说,作者提出的方法在加速推理过程时避免了许多先前方法所涉及的模型结构和训练方面的变化,同时保持了相同的输出特性。
4.4 分类&设计
投机解码实现加速的关键主要在于如下两点:
Speculative的高效性和准确性:如何又快又准地“推测”LLM 未来多个解码步的生成结果。Verify策略的选择:如何在确保质量的同时,让尽可能多的“推测”token 通过验证,提高解码并行性。
因此,研究人员通常基于这两点来对投机解码的实现和研究进行分类。当然,其分类方式也会略有差别。下图是论文“Unlocking Efficiency in Large Language Model Inference: A Comprehensive Survey of Speculative Decoding”给出的投机解码技术的一个正式分类,包括:
- draft model的策略。具体涵盖如何设计模型,运行终止条件,如何管理多个模型(如果有)。“推测”阶段的设计聚焦在“推测精度(accuracy)”和“推测耗时(latency)“的权衡上。一般来说,用以推测的模型越大,推测精度越高(即通过验证的token越多),但是推测阶段的耗时越大。如何在这两者之间达到权衡,使得推测解码总的加速比较高,是推测阶段主要关注的问题。
- 验证策略。此类别涉及到验证方案和验收标准的设计。验证模型通常是目标模型,其首要目的是保证解码结果的质量。接受标准旨在判断草稿 token 是否应(部分)接受,即接受的 token 长度是否小于 k。在每个解码步骤中,验证模型会并行验证草稿 token,以确保输出与目标 LLM 对齐。此过程还决定了每一步接受的 token 数量,这是影响加速的一个重要因素。采样方法具体来说也分为无损采样和有损采样。
- 无损采样主要是说对于原始 LLM 来说仍然采用原先的采样方法比如贪婪采样或者温度采样等等,然后对应地检查 draft 中是否有符合要求的 token。这种方法核心就是 drafting 对于原始 LLM 来说完全透明,不会损失模型性能。
- 有损采样主要是说通过校验阶段对 draft 质量的评估,然后根据一些先验的阈值来筛选一些高质量的 draft 接受,这种方法的核心就是为了提高 draft 的接受率,在可接受的一些质量损失情况下获得更高的加速。常见验证标准包括 Greedy Decoding,Speculative Sampling,Token Tree Verification 等。因为,并不是所有概率最大的 token 都是最合适的解码结果,所以也有一些工作提出可以适当地放松“验证”要求,使得更多高质量的“推测”token 被接受,进一步提升加速比。
下图则是该论文中对分类内容的进一步细化。draft model的策略对应下图标号1。验证策略对应下图标号2。具体的投机解码方法则对应下图标号3。
4.4.1 推测阶段的策略
推测阶段的策略主要有如下几个部分。
产生草稿
在某种程度上,草稿模型本身通常是一个因果语言模型,可以生成推测性的标记。草稿模型可以是目标模型之外的一个额外的小模型,如 speculative decoding 中生成候选 token,也可以是连接到目标模型的几个轻量级预测头,如 blockwise parallel decoding 中预测即将到来的 token。最近的进展表明,草稿模型也可以是从大型语料库中检索标记的检索者(retriever),以完成前面的上下文。
这些草稿模型具体特点如下。
- Independent Drafting。主要思路是:拿一个跟 target LLM 同系列的 smaller LM 进行“推测”。因为是同系列的模型,所以该小模型本身就存在一定的和 target LLM 之间的“行为相似性“(behavior alignment),适合用来作为高效的“推测“模型。需要强调的是,小模型必须与目标模型具有完全相同的词表。目前对于该思路的优化主要集中在增强小模型和大模型之间的“行为相似性”(behavior alignment),让小模型模仿得“更像”一些。比如知识蒸馏。
- 优点是易于实践和部署。
- 缺点是:
- 并不是所有的 LLM 都能找到现成的小模型
- 在单个系统中集成两个不同的模型会引入额外的计算复杂性,尤其不利于分布式部署场景
- 而且往往需要从头开始训练一个草稿模型,此预训练过程需要大量额外的计算资源。
- 单独的预训练可能会在草稿模型和原始模型之间产生分布变化,从而导致原始模型可能不喜欢的序列结果。
- Self-Drafting。因为上述劣势,相关研究工作提出利 target LLM 自己进行“高效推测”,即使用验证模型本身的作为 drafting model,比如,重用在原始 LLM 中的一些中间结果或者参数,用隐藏层状态来更好地预测未来序列。这种方式天然就没有模型表现一致方面的问题,减少了额外的计算开销,对分布式推理也很友好。在时延方面,Self-Drafting 使用一些策略来使得验证模型平均参数量减少,以此来达到高效的目的。比如 Blockwise Decoding 和 Medusa 在 target LLM 最后一层 decoder layer 之上引入了多个额外的 FFN Heads,使得模型可以在每个解码步并行生成多个 token,作为“推测”结果。然而,这些 FFN Heads 依然需要进行额外的训练。除了这两个工作,还有一些研究提出利用 Early-Exiting 或者 Layer-Skipping 来进行“高效推测“,甚至仅仅是在模型输入的最后插入多个
[PAD]token,从而实现并行的“推测”。- Early-Exiting 则是基于 saturation 的观察:在生成某个 token 时,如果在经过第
i层的前后输出 token 完全一致,我们就认为已经达到饱和点,后续层不需要再继续处理,直接返回第 i 层生成的 token 即可。因为除去了第 i 层后面的层,所以模型参数量会减少。 - Layer-Skipping 是判别哪些 layer 如果被跳过,但是对大多数 token 生成影响不大,就在生成 token 时跳过这些层,以此减少 drafting model 的参数量。
- Early-Exiting 则是基于 saturation 的观察:在生成某个 token 时,如果在经过第
- 基于检索的方法。其思想是大部分常见的句子里面的单词组是可以统计出来的,因此在生成某个token之后,可以通过这个token去检索统计的数据库得到这个token之后大概率是哪些tokens,然后把这些tokens取出来去做验证。
此外,草稿模型不仅限于一个小模型。有人认为,在集成学习的推动下,不同尺度的分阶段或级联小模型可以进一步提高性能。比如论文“Cascade Speculative Drafting for Even Faster LLM Inference”提出了Vertical Cascade 和 Horizontal Cascade。Vertical Cascade 用 Speculative Decoding 来加速 Speculative Decoding。Horizontal Cascade 指的是在接受率较高的前几个 token 用较大的 Draft Model,在接受率较小的靠后的 token 用较小的模型来“糊弄”。
终止条件
speculative tokens 的序列太短或太长都是次优的,但是也难以找到非常合适的判别标准。因此,研究人员也对终止条件进行了深入研究,具体大致分为几种。
- Static Setting:最简单的解决方案是将长度 k 设置为一个静态值,该值可以迭代和手动重新设置。
- Adaptive Thresholding:虽然静态设置可以满足大多数用例,但需要不停的手动调节也可能很麻烦。为了解决这个问题,已经提出了自适应阈值方法,旨在尽早停止基于每个 token 一致性(per-token confidence)的草稿生成动作。如果一致性低于阈值,草稿模型的生成动作将停止。阈值可以根据某些优化目标(例如,草稿 token 的质量)进行自适应调整。
- Heuristic Rules:一些启发式规则也可以用于终止条件的判断。比如,如果验证中完全接受之前的猜测,则推测token的长度将增加,否则将减少。另一种方法可能是从系统服务的角度根据批量大小来改变长度。
尽管已经开发了各种方法来自动检测终止条件的理想值,但仍然很难判断它们是否足够好。在这种需求下,我们应该建立更稳健的方法来搜索和设置这些参数,从而获得更稳定、更吸引人的性能。
4.4.2 验证阶段的策略
在 verification 阶段,也就是使用大模型校验阶段中,分为
- 验证方案(如何组织多个序列的输入,比如 token 树验证(token tree verification))
- 验收标准的设计(采样方法,比如贪婪采样,nucleus 采样,typical 采样)
验证方案
组织多个序列的输入最简单的方法就是直接将所有可能输入形成多个batch。
如果只需要验证一个 token,那么基于链的验证器(将 token 作为序列或链接收的通用验证器)应该就足够了。但是,如果使用多个 token,逐一连续验证这些 token 会有冗余计算的问题,将过于耗时。比如有两个序列 ”maching learning is a“ 和 ”machine learning is the“,其实区别只在于最后一个 token 是”a“还是”the“,前缀相同。
因此,有研究人员提出了一种基于树的验证方法,该策略使目标 LLM 能够并行验证多个草稿序列。该方法首先通过共享前缀从多个候选 token 序列建立一个 trie,并从 trie 树中修剪不太频繁的节点。然后,它在一次运行中用树注意力对其进行并行验证(即,子 token 只能通过注意力掩码看到其父 token),这促进了对潜在多 token 的并行验证。作为对比,如果是单个 token,只需要一个注意力链。而基于树的验证方法所依赖的是因果关系和下三角关系(causal and lower-triangular)掩码,如下图所示。
验收标准
一旦草稿 token 被输入目标模型,我们就可以获得相应的输出概率。通过对齐推测 token 和概率,我们可以推断每个 token 在草稿中是否有效。
精确匹配
最简单的接受标准是精确匹配,它检查 speculative token 是否相应地具有最大概率。该策略是基于贪心算法的。贪心采样的验证主要是保证 Drafting model 和 Verification model 都使用贪心策略的时候结果一致。也就是说,需要验证验证模型的每一个生成是否和 drafting model 的生成完全一样。
注意:两篇开山之作的Mp,Mq是相反的,请大家在阅读时候务必注意。
虽然精确匹配简单清晰直接,可以用较小的成本来保证经过验证的输出与目标模型本身的输出一致,但是存在一些问题:
- 虽然精确匹配可以用较小的成本来保证经过验证的输出与目标模型本身的输出一致,但只有在使用贪婪解码时,这种等式才成立。
- 对于目标模型使用采样解码(sampling decoding)的情况,精确匹配很难从草稿模型中接受token,这可能会导致解码速度减慢而不是加快。
- 过于严格的匹配要求通常会导致拒绝高质量的token,仅仅是因为它们与目标LLM的前1个预测不同,从而限制了范式的加速。
拒绝采样(Rejection Sampling)
基于上述问题,多项研究提出了各种近似验证标准。与无损标准相比,这些方法略微放宽了匹配要求,以更加信任草稿,从而提高了草稿token的接受度。比如,研究人员提出了一种从拒绝采样(Rejection Sampling)中修改的验收标准来缓解这一问题(就是那两篇开山之作)。理论上,这种接受标准可以应用于贪婪解码和采样解码。
Typical Acceptance
上述两个验收标准为质量提供了严格的保证。然而,过于严格的验收标准可能会抵消并行验证的努力,并降低推测执行的负担,尤其是在施加温度参数的情况下。因此,在某些情况下,需要适度放宽接受标准,以实现更明显的加速。Typical Acceptance就可以做到这一点:如果token的投机概率超过硬阈值,则接受草稿中的token。另外,阈值也是可以通过top-k约束动态调整的。对于提供多个token的情况,Typical Acceptance将考虑形成最长序列的token,并放弃其他token。
0x05 算法
5.1 总体流程
下图给出了投机解码的算法总体流程。该算法通过首先使用更高效的近似模型 Mq 生成多个猜测 token,然后使用目标模型 Mp 并行评估这些猜测 token 的概率,并根据评估结果来决定哪些猜测 token 可以被接受(并行地接受那些能够导致相同分布的猜测 token)。如果需要,算法还会调整目标模型的分布以保持一致性。最终,算法会返回从 Mp 和 Mq 中得到的生成结果。这个过程有效地利用了两个模型的优势,加速了生成过程。
这里假设pi(x),qi(x)分别是target,draft模型的分布。
我们用一个例子展示随机采样的工作方式。下图中,每一行代表一次迭代。绿色的token是由近似模型提出、且目标模型接受的建议。红色token:近似模型提出但目标模型拒绝的建议;蓝色token:目标模型对于红色token的订正,即拒绝红色的token并重新采样得到蓝色的token。
在第一行中,近似模型生成了5个 token,目标模型使用这5个 token 和前缀拼接后的句子” [START] japan’s bechmark bond”作为输入,通过一次推理执行来验证小模型的生成效果。因为最后一个 token ”bond“ 被目标模型拒绝,重新采样生成”n“。这样中间的四个 tokens,”japan” “’s” “benchmark”都是小模型生成的。以此类推,由于用大模型对输入序列并行地执行,大模型只 forward 了9次,就生成了37个 tokens。尽管大模型的总计算量不变,但是大模型推理一个 token 的延迟和小模型生成5个 token 延迟类似(并行总是比一个一个生成要快),从而显著提高了生成速度。
5.2 关键步骤
我们接下来分析 SpeculativeDecodingStep 算法的关键步骤和操作。
5.2.1 前置条件
算法的输入有三个参数:目标模型(target model)Mp,草稿模型(draft model)Mq和已知前缀prefix。
target model
- 目标模型是指原始的大型自回归模型,例如大型的Transformer模型。它是进行推理的主要模型,负责生成精确的输出。目标模型通常拥有更多的参数和计算资源,但也因此导致单步推理速度较慢。
- 假设 Mp 为目标模型,模型推理就是给定前缀输入 x<t,从模型获得对应的分布 p(xt|x<t)。投机解码要做的就是加速这个推理过程。
draft model
- 草稿模型是一个更为高效的近似模型,其设计旨在在给定前缀的情况下,能够更快地生成下一个token。相对于目标模型,它可能具有较少的参数和更高的计算效率,以便提高整体推理速度。草稿模型可以采用与原始模型相同的结构,但参数更少,或者干脆使用n-gram模型。
- 假设 Mq 为针对相同任务的更高效的近似模型,给定前缀输入 x<t,从模型可以获得对应的分布 q(xt|x<t)。
论文“Speculative Decoding: Exploiting Speculative Execution for Accelerating Seq2seq Generation”确立了 草稿模型的两个原则:Capability Principle(尽可能准)和 Latency Principle(尽可能快)。另外需要注意的是,小模型的参数量要远小于原模型参数量一个级别才效果明显;草稿模型和原模型需要使用同样的tokenizer,不然会增加额外的解码、编码时间。
5.2.2 第一步 - 采样
对于输入 prefix,在用 LLM(目标模型)做推理的同时,并行地让草稿模型基于输入 prefix 以自回归的方式串行运行 γ 次,生成得到 γ 个 token(称作 guesses 或 draft tokens)。即,对于每个 i∈γ,计算 qi(x),其中 qi 是 Mq 在 prefix+[x1,...,xi−1] 上的分布,然后从这个分布 qi 中采样一个 token xi。
5.2.3 第二步 - 并行运行目标模型
把生成的 γ 个tokens和前缀拼接一起送进目标模型Mp执行一次前向传播(验证模型把候选tokens输入采样以prefill的方式并行计算)。即,对于每个猜测的token xi,将它附加到前缀 prefix+[x1,…,xi−1] 上,然后并行运行目标模型Mp,分别得到了Mp的输出 p1(x),…,pγ+1(x)。和自回归相比,尽管计算量一样,但是 γ 个tokens可以同时参与计算,计算访存比显著提升。
5.2.4 第三步 - 计算接受的猜测token数量
既然得到了目标模型和草稿模型的输出,我们就可以计算接受的猜测token数量。但是,如何评价一个token生成的好坏与否?这是通过比较 Mq 生成的猜测和 Mp 对这些猜测的评估概率(p,q表示在大小模型采样概率,也就是logits归一化后的概率分布)来决定的。即,依次使用每个 Token 对应的分布进行对比,找到满足要求的 Token 个数 n,或者说,根据p/q针对每个token进行拒绝采样,以此来确定接受的token数。
-
如果q(x)⩽p(x), 说明大模型在小模型采样出的这个token x上置信度更高,不需要再消减在这个token上的概率分布,就接受这个结果。
-
反之,如果q(x) > p(x),说明大模型在token x上的置信度低于小模型,那么需要把大模型在这个token上的一些采样概率分给其它大模型上信心更高的token,则以p(x)q(x)的概率接受当前token,以1−p(x)q(x)的概率丢弃x。
5.2.5 第四步 - 调整概率分布
这一步会依据需要来调整Mp 的分布。
- 如果接受的猜测token数量n小于γ,那么可能需要对 Mp 生成的分布进行调整,以确保接受的猜测与目标模型的输出分布保持一致。
- 具体是通过创建一个新的分布 p′(x)=norm(max(0,pn+1(x)−qn+1(x)))来调整。
需要注意的是,如果一个样本被拒绝,这意味着我们需要重新进行计算,这会导致一些计算资源的浪费。因此,在算法运行过程中,我们希望尽可能地接受 Mq 生成的样本,以减少计算成本。
我们后续会详细阐释调整概率分布的意义。
5.2.6 第五步 - 返回生成的结果
调整概率分布之后,会返回生成的结果。
-
如果小模型生成结果都满意,则用大模型采样下一个token。即,用 Mp 采样下一个token,加上 Mq 生成的 n 个 tokens,一并返回。
-
如果某一个token x 不满意,x 被拒绝,则从token x 之后的tokens都被丢弃。因为第四步已经调整了Mp的分布,会从这个新的概率分布p′(x)=norm(max(0,p(x)−q(x)))中重新采样一个token作为纠正。
因为加上了后面这个大模型拒绝采样,并补充大模型概率分布差采样的过程,所以上面这个采样过程和直接从p(x) 采样是等价的。
一共最多可以生成多少个token?如果把验证过程看成接受概率为α的连续γ次判定过程,从上述算法流程知道输出token的长度范围是[1,γ+1],有以下3种情况
- 情况1:当第1个token就被大模型拒绝了,那么就直接用大模型的采样输出,生成长度为γ=1
- 情况2:当第t个token被大模型接受,但是第t+1个token被大模型拒绝的时候,生成长度为L=t+1。注意此时t≤γ−1
- 情况3:当所有k个token都被大模型接受,此时理应达到最大生成长度L=γ。但如果draft生成的γ个token都通过验证,那还可以从已经计算的第γ+1个token的logits中额外采样出一个,而且这个token是target模型生成的,也就不需要验证了。因此最终生成长度L=γ+1
5.3 重点分析
我们接下来看看投机解码中的一些重点。
5.3.1 并行验证
我们用示例来看看如何进行并行验证。
下图中,输入为:Our technique illustrated in the case of 。小模型串行生成三个token,小模型每次都是接受(1, vocal_size)的输入。具体参见下图标号1。
- 第1次推理,小模型生成 unconditional。
- 第2次推理,小模型生成 language。
- 第3次推理,小模型生成 modeling。
有两种方案来验证这些token。
方案1是论文中提出的方案,具体参见上图标号2。论文里的并行就是指一次计算多个token,节省传输损耗。然而,论文里对Mp进行并行计算,是一种不顾及计算资源的加速。它在每一步都尝试并行计算大模型的观点,从而达到速度上的最优化,但同时对并行计算能力要求极高。比如r为3时,就需要4个大模型同时计算。在极致并行的情况下,速度可以达到理论最优,但代价是算力的浪费,这在工程上是不可接受的。
方案2是实际工作中的方案,利用prefill阶段(并行处理多个token)比decoding阶段(串行生成多个token)计算效率高的特点来完成加速。target模型的任务不是生成,而是验证。由于现代计算机的并行能力,我们可以近似的认为大模型处理一个token和并行处理多个token的用时是几乎一样的。这就保证额验证这一过程可以并行实现,即调用一次target模型执行prefill操作,就可以完成对多个草稿模型(多个decoding步骤)的一次性验证,从而减少了推理步骤。同时,根据 Mq 对 Mp 的逼近程度,还可能生成多个新token,最多可以生成 γ + 1 个。上图标号3展示了这个过程。大模型一次性接受的是(3, vocal_size)的输入,即,直接检查unconditional ”、“language ”、“modeling” 这3个新token,所以叫并行。其思路和 LLM 训练阶段的交叉熵验证一样,通过错位方和矩阵计算的并行性,一步计算就可以验证小模型生成的3个结果对不对,即可完成验证。需要并行执行四次验证(以 argmax 为例):
- Prefix “Our technique illustrated in the case of”,生成 “unconditional ”,与近似模型生成的第一个 Token “unconditional ” 相同,接受。
- Prefix “Our technique illustrated in the case of unconditional”,生成 “language”,与近似模型生成的第二个 Token “language” 相同,接受。
- Prefix “Our technique illustrated in the case of unconditional language”,生成 “method”,与近似模型生成的第三个 Token “modeling” 不相同,不接受。
- Prefix “Our technique illustrated in the case of unconditional language modeling”,生成 “of”,作为候选,如果前面都接受,则接受该 Token。
5.3.2 加速效果
实现加速的原理是什么?简而言之,投机解码相比自回归采样之所以有加速效果,是因为它减少了对原始模型串行调用的次数。因此,投机解码需要将以下两个步骤结合在一起,才能实现推理的加速。
- 草稿生成。Mq 生成 γ 个补全。因为 draft 模型参数量少,相比于 target 模型生成 token 更快,是更高效的模型,所以减少了生成补全的时间。
- 草稿校验。使用目标模型 Mp 并行评估来自 Mq 的所有猜测及其相应的概率,接受那些可以导致相同分布的猜测,并从调整后的分布中抽取额外的一个token,以修复第一个被拒绝的token,或者如果所有token都被接受,则添加一个额外的token。即,通过错位和矩阵计算的并行性,一步计算就可以验证小模型生成的 γ 个结果对不对。
下图示例中包含不同的 𝛶(验证的 Token 数目),其中紫色为执行目标模型 Mp 的 decoder,蓝色为执行近似模型 Mq 的 decoder,黄色和橙色为调用 encoder。这里规定一次迭代可以接收小模型的tokens数为generated tokens。加速效果和 γ ,p,q都相关。直觉上讲,γ 越大,p、q分布越接近,则 generated tokens越大。
用通俗的话来解释。
- 最下方是大模型直接预测新的token,耗费时间太长。
- 中间和上方是先使用小模型预测 𝛶 个 token,然后大模型借助矩阵计算的并行特性,一次性就可以验证这 𝛶 个中,前面哪几个是对的。如果有对的,那就节约很多时间(因为小模型远小于大模型,所以小模型消耗的时间基本可以忽略不记)。
影响加速比的因素是:
- 小模型的尺寸及一次推理的token数目。
- 小模型生成候选tokens的时延。
- 大模型对小模型推理token的接受率,或者说小模型和大模型的Align程度。
因此,如果小模型的输出草稿接受率足够高,且生成候选tokens的时延不长,那么投机解码就能够获得更高的加速比。假设我们一次猜n个tokens,平均有m个token会被最终接收,那么在这个过程中:我们调用了n次小模型D,1次大模型T,生成了m个token。只要nD显著地小于(m-1)T,就能实现很好的加速效果。
理解了原理,我们就可以知道这个方法加速的限制:小模型生成的分布是否与大模型一致。验证的接受率会很大程度上影响最终的加速比,接受率越高,减少的 Decoding Step 数量就越多,因未接收而浪费的计算就越少。
5.3.3 调整分布
我们提出一个问题:在算法的第四步,当 n < γ 时,为什么需要调整从目标模型(Mp)得到的分布?这个调整的目的是什么?
这就涉及到投机解码的另外一个核心:如何确保通过投机解码得到的token的概率和从大模型直接采样相同。事实上,投机解码和投机解码两篇论文都给出了证明:这种验证和重采样过程在理论等价于直接从目标 LLM 采样,因此,可以保证最终生成的文本分布与目标 LLM 一致。即,对于任意分布p(x)和q(x),通过从p(x)和q(x)进行投机解码所得到的token的分布与仅从p(x)进行采样所得到的token的分布是相同的。
我们首先概述下如何证明。本质上我们想考察的是p(x=~x)的概率,在使用了投机解码策略之后,是否还依然等于我们的原始概率q(x=~x),即q(~x)。概率拆解思路为:有两种可能采样出~x,可以证明通过重采样之后,总体概率和原始概率一致。
- 路径1:小模型p(⋅|⋅)采样出了~x,并且成功的接受了。注意,如果此时对~x发生了拒绝,是不可能通过重采样得到~x。原因是,发生拒绝就说明q(~x)小于p(~x),因此在重采样中max(q(~x)−p(~x),0)为0,不可能重采样出~x。
- 路径2:小模型p(⋅|⋅)采样得到了其他值x≠~x,并且发生了拒绝,此时重采样得到~x。
其次,详细推导流程参见下图,我们基于论文 “Accelerating Large Language Model Decoding with Speculative Sampling” 的公式进行整理和注释。
产生偏差
当 n < γ 时,意味着从更高效的近似模型 Mq 中采样的token数量少于 γ,也就是说,其中一些猜测被目标模型 Mp 拒绝了。这可能是因为 Mq 生成的猜测与目标模型 Mp 的真实分布存在一定的偏差。
当使用这个近似模型 Mq 生成的token的概率小于或等于目标模型 Mp 生成这个 token 的概率时,我们会保留这个token。当近似模型 Mq 生成的token的概率大于目标模型 Mp 生成这个 token 的概率时,我们不能简单地接受这个token,因为这可能会导致生成的结果与目标模型的分布不一致。因此,在这种情况下,我们会以一定的概率拒绝这个token,并重新从调整后的概率分布中重新采样。
注:快速理解,如果 Mp 生成某个token的概率是0.5,Mq 生成该token的概率是0.6,说明 Mq已经比大模型还飘,不可信了。
弥补偏差
调整目标模型 Mp 分布的目的是为了弥补从近似模型 Mq 中得到的猜测与目标模型 Mp 分布之间的差异,以保证最终生成的结果符合目标模型的真实分布。这样可以确保在猜测性解码过程中得到的结果保持了一定的准确性和一致性。
调整分布操作弥补了小模型 Mq 和大模型 Mp 之间的概率分布的gap。思路是:对于小模型 Mq 的每一次猜测,根据大模型 Mp 和小模型 Mq 的概率分布去判断这一次猜测有多大概率是正确的。相当于是从小模型 Mq 的采样到大模型 Mp 的采样之间做了一个映射。可以把小模型 Mq 和大模型 Mp 的概率分别看成若干个随机事件,然后将小模型 Mq 的随机事件和大模型 Mp 的随机事件做映射,如果两边的随机事件的结果一致,我们就认为这个猜测是正确的。特别地,如果两个概率分布一样,则猜测正确的概率为1。如果在某一步中,我们认为小模型 Mq 的猜测是错误的,那么后面的结果都是无效的。此时用大模型 Mp 最后一步得到的概率分布做一个采样后退出。这一步既是保证输出是同分布的,又可以保证每次至少输出一个token。
具体来说,作者需要定义一个新的分布 p′(x),它是根据目标模型 Mp 的原始输出分布 pn+1(x) 调整而来的。如果 n < γ(即目标模型拒绝了一些猜测),作者使用了一个调整函数来修改 pn+1(x)。这个调整函数是 max(0,pn+1(x)−qn+1(x)),它的作用是确保 pn+1(x) 不小于 qn+1(x)。这样做的目的是为了尽量保持目标模型生成的分布与近似模型的分布一致。
这里给一个直观的解释。这个调整后的概率分布p′(x)是通过将目标模型的概率分布(p(x))与来自近似模型的概率分布(q(x))进行相减,并取结果的最大值,然后将其归一化得到的。这个调整后的分布确保了我们从目标模型中采样的结果具有相同的分布特性,同时也能够处理那些被拒绝的token,保证最终的生成结果保持一致性。
p(x’) > q(x’)说明大模型在token x’上概率大于小模型,则大模型对生成token x’更有把握,说明小模型生成的问题不大,可以保留x’。如果p(x’) ≤ q(x’)则小模型更有把握,大模型就以1-p(x)/q(x)为概率概率拒绝,并重新采样。因为接收的概率更偏向q(x)大的位置,重新采样的概率应该更偏向p(x)大的位置,所以是norm(max(0, p(x)-q(x))。
弥补结果
从调整后的分布中生成一个额外的 Token(根据第一个出错 Token 之前的 Token 生成),来修复第一个出错的 Token,如果所有 Token 都被接受,则额外新增一个新生成的 Token(这个token是target模型生成的,也就不需要验证了),以此来保证每次至少生成一个新的 Token。这样,即使在最坏情况下,目标模型相当于完全串行运行,运行次数也不会超过常规模式直接串行运行目标模型的次数(每个目标模型的并行运行至少会生成一个新的标记);当然,也很可能能够生成更多的 Token,最多可以达到 𝛶+1,这取决于近似模型 Mq 对目标模型 Mp 的逼近程度。
5.3.4 优化
在推测解码方法中,草稿 token 的接受率受到草稿模型的输出分布与原始大模型的输出分布的一致程度的显著影响。因此,大量的研究工作都是在改进草稿模型。
DistillSpec直接从目标大模型中提取较小的草稿模型。SSD包括从目标大模型中自动识别子模型(模型层的子集)作为草稿模型,从而消除了对草稿模型进行单独训练的需要。OSD动态调整草稿模型的输出分布,以匹配在线大模型服务中的用户查询分布。它通过监视来自大模型的被拒绝的草稿token,并使用该数据通过蒸馏来改进草稿模型来实现这一点。PaSS提出利用目标大模型本身作为草稿模型,将可训练的token(lookahead token)作为输入序列,以同时生成后续token。REST引入了一种基于检索的推测解码方法,采用非参数检索数据存储作为草稿模型。SpecInfer引入了一种集体提升调优技术来对齐一组草稿模型的输出分布通过目标大模型。Lookahead decoding 包含大模型并行生成n-grams来生成草稿token。Medusa对大模型的几个头进行微调,专门用于生成后续的草稿token。Eagle采用一种称为自回归头的轻量级Transformer层,以自回归的方式生成草稿token,将目标大模型的丰富上下文特征集成到草稿模型的输入中。
另一项研究侧重于设计更有效的草稿构建策略。传统的方法通常产生单一的草稿token序列,这对通过验证提出了挑战。对此,Spectr主张生成多个草稿token序列,并采用k-sequential草稿选择技术并发验证k个序列。该方法利用推测抽样,确保输出分布的一致性。类似地,SpecInfer采用了类似的方法。然而,与Spectr不同的是,SpecInfer将草稿token序列合并到一个“token tree”中,并引入了一个用于验证的树形注意力机制。这种策略被称为“token tree verifier”。由于其有效性,token tree verifier在众多推测解码算法中被广泛采用。除了这些努力之外,Stage Speculative Decoding和Cascade Speculative Drafting(CS Drafting)建议通过将投机解码直接集成到token生成过程中来加速草稿构建。
0x06 实现
我们使用 https://github.com/huggingface/transformers/src/transformers/generation/utils.py 来进行学习
6.1 全局循环
在_assisted_decoding()函数中的while循环里面进行投机解码。
|
|
6.2 外层逻辑
此处包括获取草稿模型的输出,调用论文的算法,依据算法结果对token进行调整。
|
|
6.3 实施算法
注释中写到,实现了论文“Fast Inference from Transformers via Speculative Decoding”的算法1,即如下算法。
代码如下。
|
|
0x07 Token Tree Verification
因为Token Tree Verification的重要性,我们单独用一节来进行阐释。
前文提到过,Token Tree Verification 使目标 LLM 能够并行验证多个草稿序列。其思路就是:让草稿模型在每个时间步都输出 k 个候选 token,然后通过共享前缀从多个候选 token 序列建立一个 trie,并从 trie 树中修剪不太频繁的节点。最后在一次运行中用树注意力对其进行并行验证(子 token 被注意力掩蔽,只能看到其父 token)。
7.1 问题
7.1.1 采样多个序列
论文“SpecInfer: Accelerating Generative Large Language Model Serving with Tree-based Speculative Inference and Verification”发现,大模型验证失败的时候,真实生成的token大多数时候其实也是小参数模型的top-k的tokens。下图展示了使用greedy和stochastic decoding两种方法topK里面k从1到5在各个数据集上的验证成功率。可以看出,尽管预测next next token的top-1准确率徘徊在60%左右,但是在小参数模型每一个step都保留top-5的时候,最后的验证成功率都大大提高。如果使用necleus sampling,top-3的成功率就已经超过了90%。
基于此,我们不应该采样一个单独的序列型的的tokens,而是采样一个树状的token树。不止在第一步猜k个token,我们可以在每一步都猜多个tokens,这样每一步的几率都会变大。只要由此带来的额外的计算开销小于更高的带来的加速,那么猜更多的token就是可以接受的。
7.1.2 验证多个序列
但是,如何对这个token树进行验证?即,如何组织多个序列的输入?组织多个序列的输入最简单的方法就是直接把每一个叶子节点到根节点的所有token组成一个序列,然后进行验证,这种方案存在几个问题:
- 逐一连续验证这些token会有冗余计算的问题,将过于耗时。
- 一些工作发现,一次预测一条链的话,概率衰减的非常快,所以不能预测很长的链,导致不能充分利用上大模型验证的并行度。
另一个方法是把每一个叶子节点到根节点的所有token组成一个序列,n多个叶子节点就会组成n个序列,然后把这n个序列当成batch size=n的输入进行prefill。然而这种方式的问题是根节点的计算不能被复用。
我么接下来看看研究人员是如何解决上述问题。
7.2 思路
7.2.1 开山之作SpecInfer
为了解决上述问题,SpecInfer设计了 Tree Based Parallel Decoding 机制。其核心思路为:通过一系列小模型 SSM(Small Speculative Model)联合预测 LLM 输出,并将这些小模型的预测输出组织为 Token 树,树中每个分支表示一个唯一的候选 Token 序列。最后,LLM 使用基于树的并行解码(Tree-Based Parallel Decoding)机制来并行的验证 Token 树中所有 Token 的正确性,这里树的解码算法还可以重用这些序列之间共享的中间结果。SpecInfer 使用 LLM 作为 Token 树验证器而非增量解码器,这显著降低了生成式 LLM 的端到端延迟,同时可以保持模型的质量。
SpecInfer的具体流程如下。
-
先为每个 SSM 生成了一棵输出树,即在每个 token 取若干种可能性构成一棵树,之后将这些树合并成一棵更大的树。当生成更大的树之后,把该树拓展成若干个token序列。
-
将生成的树进行验证。树结构会带来token之间复杂的依赖关系,如果对树上的每一个从root到leaf的路径都用大模型做一次验证,大量的叶子节点也会导致算法退化到最原始的一次预测一个token的场景。针对这个情况,SpecInfer提出了tree attention来加速decoding的速度。方法是将树上的祖先关系变成attention-mask的可见关系,使得模型可以一次验证多个 sequence。如下图所示,对于这样一棵树,如果采用常规的 mask 方式,t6 是可以看到 t5 的,但在图上的 mask 矩阵下,每个 token 只可以看到自己的 prefix,从而使得 LLM 可以一次完成对于多个 sequence 的不互相干扰的验证。
7.2.2 如何组织树
有多种组织树的方法,具体参见下图。
以下图右下角的Sequoia为例,接受向量是p=(p1,p2,…,pk,…),其中验证算法在子位置k接受token的概率为pk。树的具体构建方法基于 positional acceptance assumption:假设token t 是已接受token的第k个子token,则验证算法接受token t 的概率仅取决于 k 的值 。每个子节点的得分为从根节点到此节点的所有 pk (验证算法在子位置 k 接受token的概率)相乘。最后的目标是,在给定节点数量的情况下使整棵树所有节点得分相加最大。这个问题的解可以用更小的子问题的解来表示,因此可以通过动态规划求解。求得的树结构满足预测概率较大的子节点会有更多的子孙。
再比如,下图给出了EAGLE-2的Token Tree Verification。树的边上的数字表示草稿模型的置信度得分,块内括号中的数字表示节点的值。在扩展阶段,我们从当前层(橙色块)中选择值最高的前2个节点作为草稿模型的输入,并将生成的token(绿色块)连接到草稿树。在重新排序阶段,我们从所有节点(蓝色块)中选择值最高的前8个节点,将其展平为一维序列以形成最终草稿。然后,我们根据树结构来构建注意力掩码,确保每个token只能看到其祖先节点。
7.3 Attention Mask
Medusa 中的 Attention Mask 矩阵如下图所示。左侧给出了候选序列。而其对应的 Attention Mask 矩阵如右侧所示。在图上,Head 1 在下一个位置生成 2 个可能的 Token(It 和 I),Head 2 在下下一个位置生成 3 个可能的 Token(is,’ 和 the)。因为第一个头部的任何预测都可以与第二个头部的任何预测配对,这样下一个位置和下下一个位置就有了 2 x 3 = 6 种可能的候选序列,最终形成一个多层树结构。这棵树的每一层都对应于一个Medusa Head的预测。在这棵树内,Attention Mask只限制一个token对其前面token的注意力。
参考资料
-
No backlinks found.