Transformer的推理中的kv缓存介绍

Wings Lv3

背景前提

随着整个Transformer架构的确定,模型也能够在一定的数据条件下被训练出来。接下来模型就需要推理。

Transformer框架的推理和训练并不是和之前深度学习那般,训练的数据格式是什么,推理的格式也是什么(大多数的逻辑是,输入不定的数据可以经过变换转换成和训练数据一样)。但是Transformer框架的训练和推理却并不一样,前者大概率在固定的序列长度训练,而后者则会不断地让整个序列变长。

这就导致,计算开销和内存开销在训练的时候可能并不是那么明显,因为此时数据长度是固定的(可以做切割和填充),虽然这仍然可能受到Attention架构的平方复杂度的影响。但是在推理的时候,推理的数据是变长的(本篇并不关心动态批次推理的问题,主要关心单样本推理的问题)。而且在当前大家所使用的场景来看,很容易超过训练所设定的序列长度,此时就更容易受到序列平方复杂度的影响。

而在这里虽然不能改变序列变成的情况,但是能够将计算开销和内存开销从的复杂度,这就是KV Cache。

推理下Attention的平方复杂度

在这里重新回顾一下在推理背景下,Attention的关于序列的平方复杂度。首先在推理背景下,假设给定前缀Token(很多时候也叫Prompt,提示词),通过给定的提示词Token生成下一个(有时候可能会一次性生成好几个,在这里首先关注一个)。

那么此时就有,给定一个Token序列其对应的嵌入向量有,所以整个序列所对应的嵌入向量矩阵为。当前经过Attention模块有,所以有转换。然后就有核心计算:

注意到。也就是说,随着不断地增大,整个矩阵计算和序列长度是一个平方的关系。

此时得到最终的输出以后,通过的状态得到下一个token的结果,又加入到原来的序列,所以整个序列会在不断地推理下变得原来越长。那么Attention关于序列长度时间复杂度为就成为了推理无法避免的一个问题——随着序列长度越大所需要的计算和内存也越大。

复用kvCache降低复杂度

在这里介绍一个最简单的改善思路,那就是KVCache(KV缓存)。正如上述所说,每一步推理生成新的token会拼接回原来的token,此时就有的序列长度,而之前的序列实际上已经在模型之中生成过对应的对应的矩阵了。所以一个比较直观的想法是,能不能复用原来的矩阵从而减少计算量?

在这里可以做一个简单地对比,假设新的序列长度为进入Attention中,此时得到的矩阵就是:,计算过程有

其中。这个过程,实际上我们只想要的是最新的状态,也就是。那为了得到这个,实际上只需要最新的,但需要所有的。这是因为它需要之前所有的位置来确定当前的序列状态。

而我们看,假设只有最后一个token输入,此时经过转换就有。如果我们把之前的拼接上去就可以得到上述的有:

所以每次推理,只需要把每一层的KV向量存储起来,然后再给定新的token下,拼接好KV,就可以很方便地进行Attention计算了。此时我们的矩阵计算就变成了:

其中,所以当前的复杂度就变成了从平方复杂度变成线性复杂度。虽然很多时候线性复杂度也是仍然不够的,因为整个系统仍然需要为了不断增长的内容而烦恼。

  • 标题: Transformer的推理中的kv缓存介绍
  • 作者: Wings
  • 创建于 : 2026-06-04 10:00:00
  • 更新于 : 2026-06-05 10:32:42
  • 链接: https://www.wingslab.top/深度学习/Transformer的推理中的kv缓存介绍/
  • 版权声明: 本文章采用 CC BY-NC-SA 4.0 进行许可。