探秘Transformer系列之(28)--- DeepSeek MLA

探秘Transformer系列之(28)--- DeepSeek MLA

0x00 概述

MLA(Multi-Head Local Attention / 多头潜在注意力)的基本思想是将注意力输入\(h_t\) 压缩成一个低维的潜在向量 \(c^{KV}_t\) ,维度为 \(d_c\),且 \(d_c\) 远小于原始的维度(\(h_nd_h\))。在需要计算注意力时,可将这个潜在向量 \(c^{KV}_t\) 映射回高维空间。因此,只需要存储潜在向量 \(c^{KV}_t\) ,就可以显著减少内存的占用。

这个过程可以通过以下公式更正式地进行描述。其中 \(c^{KV}_t\) 表示潜在向量;\(W^{DKV}\) 是压缩矩阵(上标 D 代表"下投影",即降维操作),负责将 \(h_t\) 的维度从(\(h_n·d_h\))压缩到\(d_c\)\(W^{UK}\)\(W^{UV}\) 是上投影矩阵,负责将共享的潜在向量 \(c^{KV}_t\) 映射回高维空间。只需要存储这个潜在向量 \(c^{KV}_t\) ,就能获得对应不同文本特征的Key和Value,而不需要对每个文本特征都存储对应的Key和Value。

类似地,我们也可以将查询向量映射到一个潜在的低维向量,然后再将其映射回原始的高维空间。而且,MLA又结合了权重吸收技术,减少了计算开销。


注:

  • 全部文章列表在这里,估计最终在35篇左右,后续每发一篇文章,会修改此文章列表。cnblogs 探秘Transformer系列之文章列表
  • 本系列是对论文、博客和代码的学习和解读,借鉴了很多网上朋友的文章,在此表示感谢,并且会在参考中列出。因为本系列参考文章实在太多,可能有漏给出处的现象。如果原作者或者其他朋友发现,还请指出,我在参考文献中进行增补。

0x01 原理

1.1 问题

标准Transformer的一大障碍就是KV Cache的空间占用问题:多头注意力机制需要为每个注意力头单独存储历史生成的Key和Value向量(即KV缓存)。随着序列长度增加,KV缓存的存储需求呈指数级增长,导致内存占用急剧上升。而GPU的显存空间往往非常有限,较大的KV Cache会导致同时处理的request数量变少,也即batch size较小;为了减少KV缓存需求,研究人员提出了像Multi-Query Attention(MQA)和Group-Query Attention(GQA)这些方法。这些方法虽然降低了缓存要求,可模型的性能也受到影响。MQA或GQA算子在计算注意力的过程中,所有KV Cache中的数据读取后都仅参与一次或几次计算,导致该算子的MFU极低,并且由于每个request有自己的KV Cache,这一问题无法通过提高batch size的方式解决。

因此,如何减少推理过程的KV Cache,从而实现在更少的设备上推理更长的Context,或者在相同的Context长度下增大batch size,实现更快的推理速度或者更大的吞吐总量,最终降低推理成本。是一个关键问题。

1.2 当前状况

我们首先总结下当前各种方案的情况来看看有没有可以改进的空间。下图给出了MHA、GQA、MQA 以及 MLA 做法。

图上从左到右依次是 MHA、GQA、MQA 以及 MLA 。图中有阴影的长条表示会缓存到显存的结果。MHA、GQA、MQA 都需要将 KVCache 缓存到显存。几种方案特点如下。

  • MHA:MHA KVCache 在注意力头这个维度和 Q 矩阵一样,属于“一对一”。MHA把一个注意力计算拆成多个注意力头,每个注意力头使用独立的Q、K、V进行计算,需要把K、V都存储下来,KV Cache中每个token需要缓存的参数量为\(2𝑛_ℎ𝑑_ℎ𝑙\)。而GQA、MQA 在注意力头的维度比 Q 矩阵小。
  • MQA:所有查询头共享相同的单一键和值头,因此只需要存储共享的K和V,KV Cache中每个token需要缓存的参数量为\(2d_hl\)。在计算注意力时,会把共享的单一K头和V头广播给每个查询头,然后分别一一计算。
  • GQA:将所有的Q头分成g组,同一组的Q头共享一个K头和一个V头,因此KV Cache中每个token需要缓存的参数量为\(2𝑛_g𝑑_ℎ𝑙\)。在计算注意力时,会把KV头复制给所在组的所有Q头进行计算。

\(n_h\)是注意力头数量,\(n_g\)是GQA分组数,\(d_h\)是隐藏层维度,\(l\)为block的块数,\(ℎ_𝑡\in𝑅^𝑑\) 表示第 𝑡 个token在一个attention层的输入。

1.3 改进思路

MLA是对MHA、GQA、MQA方案的改进,其思路是加强信息压缩能力(对应下图标号1)和丰富信息表达能力(对应下图上标号2),其实,两个标号也对应了从输入到Q、K、V的数据流上两个关键点,也是硬币的两面:增强了矩阵的表现能力的同时,也会使得压缩能力更大。

于是就是研究人员经常遇到的困境了:既要压缩更低(降低推理过程中的KV Cache资源开销),又要表现力更强(缓解MQA、MGA对性能的损耗),或者说新方案的表现力要尽可能接近MHA。

1.3.1 增强信息压缩能力

思路

从某个角度考虑,MQA和GQA 也属于低秩压缩的思想,MQA将 \(2n_ℎ\) 压缩到2,GQA 则压缩到 \(2n_ℎ/g\)。但是压缩能力和性能难以兼顾,所以GQA效果要好于MQA。

因此我们要思考,是不是可以在“增强信息压缩能力且兼顾效果”之上再进一步?因为MQA在KV头上已经几乎做到了极致,因此我们没法从KV头数量上做减少。那就势必得从KV本身思考。目前,不管是GQA还是MQA,都需要缓存K、V两个值,两个值不一样。那么,是否可以把两个值合并为一个值?有没有可能每个缓存的KV都比之前小?从LoRA那里得到启发,一个\(M \times N\)的矩阵可以近似成两个\(M\times k\)\(k \times N\)矩阵的乘积,如果我把一个K或者V矩阵拆成两个小矩阵的乘积,就可以减少KV Cache的显存占用。

方案

MLA的核心是对注意力键和值进行低秩联合压缩,以减少推理期间的键值(KV)缓存大小,从而提高推理效率。与 GQA、MQA 直接压缩 KVCache 头维度不同,MLA通过使用下投影矩阵 \(W^{DKV}\)将多个注意力头的Key和Value投影到一个低维的共享潜在向量空间中,取代了传统的逐头存储方式。

具体而言,MLA将KV矩阵转换为低秩形式:将原矩阵表示为两个较小矩阵(相当于潜向量)的乘积。具体而言,

  • 对输入矩阵的 HiddenState 会先做低秩转换,将一个 Shape 为 [S,H] 的 HiddenState 压缩到 Shape 为 [S,CH] 的潜在向量\(c_t^{KV}\),其中 CH≪H 。H是token维度。
  • 将压缩后的KV向量\(c_t^{KV}\)作为 KVCache 存储到显存中,这样就达到了降低 KV 大小的目的。在V2的论文中, \(K_t\) 的表达从 \(W^Kh_t\) 变为 \(W^{UK}W^{DKV}h_t\) , 原来缓存的是 \(W^Kh_t\),而现在缓存的是 \(K_t\) 的一部分 \(W^{DKV}h_t\)
问题

但这有一个问题,如果单纯的把一个大的K/V矩阵拆成2个小矩阵进行缓存,在推理的时候,还是需要计算出完整的K矩阵,这样就失去了缓存的意义,毕竟缓存的意义就是减少计算。

问题就变成:有没有一种方法,即能减少缓存大小,又不增加推理时候的计算?

1.3.2 丰富信息表达

思路

我们可以注意到,在MQA和GQA计算注意力时,只用到了简单的广播或者复制机制把KV头复制给对应的Q头进行计算。我们以GQA为例,GQA 目的是减少KV Cache占用,存储的是KV,即\(C^{KV}\)。下面公式是如何得到k(这里省略了v的操作)。

  • 首先它将向量对半分为两份分别作为K、V。
  • 然后每一份又均分为g份。
  • 每一份复制h/g次,以此来“凑”够h个Attention Head所需要的K、V。

这里的\(W^{UK}\)是一组简单线性变化(比如简单复制)的组合,其表现能力是有限的,所以其压缩维度不大。

\[k = W^{UK}C^{KV} = W^{UK}[k^1,...,k^g,v^1,...,v^g] =\\ [k^1,...,k^1, k^2,...,k^2, ..., k^g,...,k^g] \]

既然MQA和GQA的信息表达能力不强,那么我们是不是可以引入一个矩阵变换来替代这些简单的线性变换操作(切分、复制)?比如通过针对每个 𝑞 都去自适应学习,这样就可以让这一层的信息表达更加丰富。

方案

我们已经得到了潜在向量\(c_t^{KV}\),那么就可以在推理期间使用每个头的上投影矩阵\(W^{UK}\)(用于“键”)和\(W^{UV}\)(用于“值”)从这个潜在向量中\(c_t^{KV}\)重建K和V。

具体而言,MLA 在 Decode 阶段将:

  • 加载压缩的KVCache潜在向量 \(c_t^{KV}\)
  • 然后通过上投影矩阵\(W^{UK}\)\(W^{UV}\)做两个升秩转换,分别转换为 Shape 均为[S,H] 的 K、V 矩阵,即从潜在向量中恢复出每个头的Key和Value(将这个潜在向量映射回高维空间)。上投影矩阵\(W^{UK}\)\(W^{UV}\)做两个升秩转换起到的作用比GQA 的简单线性变化(比如简单复制)的组合要大得多。
  • 进行 MHA 计算。这样,MLA在推断过程中仅缓存潜向量,而不缓存完整的键KV。

MLA的本质是对KV信息的有损压缩,但MLA可以通过训练学习到如何在提高存储信息密度的同时尽可能保留关键细节。这规避了分组查询注意力和多查询注意力的查询的信息损失,从而在降低KV缓存的前提下获得更好的性能。从MLA算子的计算特征来看,同时解决了这两方面的问题:

  • 一方面,通过低秩压缩大幅降低了推理过程中的KV Cache资源开销。减少推理过程的KV Cache,从而实现在更少的设备上推理更长的Context,或者在相同的Context长度下增大batch size,实现更快的推理速度或者更大的吞吐总量,最终降低推理成本。
  • 另一方面,MLA解压缩后的多头注意力机制能够提供较高的计算强度(正比于 Head 数),有助于充分利用GPU的算力资源,缓解MQA、MGA对性能的损耗。MLA 通过低秩转换方式压缩 KVCache,从公式来看引入了额外的升秩转换计算,并且需要存储升秩转换计算的激活值结果。但可以根据矩阵乘的交换率特性,将升秩转换的矩阵乘权重和其他权重融合,然后在 attention kernel 直接完成 attention 计算,无需引入额外的计算开销以及存储开销。

1.3.2 解决位置编码冲突

然而,压缩和RoPE位置编码是冲突的,即矩阵吸收后的\(c_t^{KV}\)没有了位置相关信息(原因:RoPE对key和query都是位置敏感的)。在这种情况下,只依靠\(c_t^{KV}\)来压缩KV-Cache的路是行不通的,所以需要额外的信息来表达qk之间位置关系。为了走出这个困境,DeepSeek提出了一种折中的方法:使用\(W^{QR}\)\(W^{KR}\)两个矩阵来表征跟ROPE相关的特征提取,为q和k都增加一个额外的维度\(d^R_h\)来添加ROPE编码,之前的\(d_h\)维度不使用ROPE编码,总长度变为\(d_h+d_r\)。即,MLA采用了MQA的思想,构造了所有head共享的cache变量\(c_t^{KV}\)\(k_i^R\),这样才大幅降低了KV Cache。其中 \(c_t^{KV}\)是参数低秩分解中Down处理后Up处理前的低维向量,而\(k_i^R\) 可视作是MQA版本的RoPE。

具体参见下图。

1.4 架构图 & 流程

作为对比,下图给出了MHA的数学公式,对于每个token需要缓存\(2n_hd_hl\)个元素。如果是千问72B,则需要$2 \times 80 \times 64 $。在这里 \(𝑞_{𝑡,𝑖}\),\(𝑘_{𝑡,𝑖}\),\(𝑣_{𝑡,𝑖}\) 都是用列向量表示。t是第t个token,j是迭代第1到t个token的序号,i是迭代head的序号。

下图给出了MLA的架构图,以及公式。

图中,黄色区域公式主要是为了计算Q(即Attention中的Q矩阵)。绿色区域主要是为了计算K的位置不敏感部分。紫色区域是计算K的位置敏感部分;灰色是把K聚合起来;红色是计算V。具体流程如下:

  • 查询(Q)的降维压缩:输入序列中的 t 个Token(\(h_t\))通过一个下投影矩阵\(W^{DQ}\)压缩为压缩潜在向量\(c_t^{Q}\)(其维度远远小于输入Token的维度)。此处对应图上标号37。
  • 键(K)和值(V)的联合压缩:输入序列中的第t个Token(\(h_t\))通过一个下投影矩阵\(W^{DKV}\)压缩为压缩潜在向量\(c_t^{KV}\)(其维度\(d_c\)远远小于输入Token的维度d)。在推理阶段,MLA仅需要缓存\(c_t^{KV}\),即KV缓存仅\(d_c \times l\)个元素,其中l为模型层数。此处对应图上标号41。
  • 解耦RoPE策略:为提高模型对序列中上下文信息的敏感性,MLA中应用了解耦旋转位置编码(RoPE)技术。因RoPE与低秩KV压缩矩阵不兼容,故MLA引入额外的查询向量\(q_t^R\)和共享键向量\(k_t^R\)来携带RoPE信息,避免了RoPE与低秩压缩矩阵之间的耦合问题,解决了位置信息与推理效率之间的矛盾。此处大致对应图上标号39和标号43。
  • 恢复信息:进行注意力计算时,进行注意力计算时,\(c_t^{KV}\)分别通过上投影矩阵\(W^{UK}\)\(W^{UV}\)还原出键和值,此处对应图上标号42和45。每个注意力头上的键再与携带了RoPE信息的共享键向量\(k_t^R\)拼接形成MHA的键值输入,此处对应图上标号44。\(c_t^{Q}\)通过上投影矩阵\(W^{UQ}\)\(W^{QR}\)升维还原并生成查询向量\(q_t^C\)(对应图上标号38)和携带RoPE信息的查询向量\(q_t^R\)(对应图上标号39),二者拼接形成MHA的查询向量输入,此处对应图上标号40。
  • 注意力计算。此处对应图上标号46。
  • 最终多个头的输入拼接在一起,并经过线性映射\(W^O\)得到最终的输出。此处对应图上标号47。

从图上可以看出MLA的特色:

从定性角度看,可以节约内存,因为:

  • 在进入标准MHA算法之前,用压缩的向量来替代之前更长的KV向量。之前是缓存K和V两个向量,现在只存储压缩后的一个向量。
  • 不仅仅压缩了KV,而且还能重建成K和V(不是标准MHA下面的K和V)。

如果定量来可看,每个Transformer层,只缓存了上述公式蓝框的向量: \(𝑐_𝑡^{𝐾𝑉}\)\(𝑘_𝑡^𝑅\) ,其它的都可以利用“矩阵吸收”,重新恢复过来。 \(𝑐_𝑡^{𝐾𝑉}\)\(𝑘_𝑡^𝑅\) 这两个向量的大小分别为:

\(𝑐_𝑡^{𝐾𝑉}\) : 维度为 \(𝑑_𝑐=4×𝑑_ℎ\)\(d_h\)是单个头的向量维度。 \(𝑐_𝑡^{𝐾𝑉}\) 是多头共享的。

\(𝑘_𝑡^𝑅\) :维度为 \(𝑑_ℎ^𝑅=𝑑_ℎ/2\)\(𝑘_𝑡^𝑅\) 是多头共享的。

对比MQA(每层有一个\(𝑑_ℎ\) 维度的 𝑘 和 一个 \(𝑑_ℎ\) 维度的 𝑣 ,共 2\(𝑑_ℎ\) 个元素),MLA相当于增加了2.25倍的存储。对比MHA的\(2n_hd_h\),则\(n_h\)会大于2.25,所以肯定减少缓存。

1.5 代码

下图给出了DeepSeek V3源码中MLA的定义部分。

class MLA(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        #隐藏层维度
        self.dim = args.dim
        # 注意力头的总数量
        self.n_heads = args.n_heads
        # 计算每个并行进程的本地注意力头数量
        self.n_local_heads = args.n_heads // world_size
        # 对应 query 压缩后的隐向量的维度 d'_c
        self.q_lora_rank = args.q_lora_rank # q的低秩压缩的维度
        # 对应 key-value 压缩后的隐向量维度 d_c
        self.kv_lora_rank = args.kv_lora_rank # kv的低秩压缩的维度
        # 表示query和key的向量中,不应用旋转位置编码部分的头维度, $d_h$
        self.qk_nope_head_dim = args.qk_nope_head_dim
        # 对应$d_h^R$,表示应用了旋转位置编码的query和key的一个头的维度。
        self.qk_rope_head_dim = args.qk_rope_head_dim
        # $d_h + d_h^R$, 注意力头大小为非rope部分大小加上rope部分大小
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        # value 的一个注意力头的隐藏层维度
        self.v_head_dim = args.v_head_dim

        if self.q_lora_rank == 0:
            # 不适用低秩分解,回归到传统MHA
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            # 其实就是$W^{DQ}$,用来生成$c_t^Q$
            # 下采样矩阵,得到压缩后的q向量
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            # $W^{UQ}$
            # 上采样矩阵,用来恢复q向量
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        # $ [W^{DKV}; W^{KR}] $    
        # 下采样矩阵,得到压缩后的kv向量    
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        # 上采样矩阵,用来恢复kv向量
        # $ [W^{UK}; W^{UV}] $
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        # output权重矩阵
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        # 计算1/sqrt(d_k)
        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale
         
        if attn_impl == "naive": # native模式下,kvcache存储的是没有压缩的数据,大小为d_h + d_h^R, 不但没有节省,反而增加了显存消耗   
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        else:
            # 在非native模式下,存储的是压缩的c,大小为d_c
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

很明显,MLA算子是针对现代GPU硬件特点“量体裁衣”定制的一个注意力机制,通过对存储和计算的再平衡,能够充分发挥现代GPU的各项优势。我们接下来就对MLA的几个核心实现要点进行仔细分析。

0x02 核心要点

MLA的核心要点如下:

  • 通过低秩KV联合压缩(Low-Rank Key-Value Joint Compression)降低了KV Cache的资源占用。在计算注意力时,对压缩后的向量进行升维变换,进而增强模型的表达能力。
  • 通过权重吸收减少了向上投影的计算量。
  • 通过解耦RoPE策略(Decoupled Rotary Position Embedding)来解决RoPE和权重吸收的冲突。

2.1 低秩KV联合压缩

2.1.1 低秩分解

低秩矩阵分解(Low-Rank Matrix Factorization)是一种特别有效的矩阵分解方法,用于发现数据中的低维结构。低秩矩阵分解的核心思想是将一个大矩阵分解为两个或多个更小、更简单的矩阵的乘积,这些小矩阵通常具有更低的秩。

在神经网络层中使用低秩分解一般都是用内存成本换取计算成本,这种方法的变体在LoRA微调等场景中很受欢迎,因为这些场景受限于总内存成本,而不是计算开销或推理速度。其好处是压缩后的矩阵使用的参数更少,并且在某种程度上更具表现力(层数增多)。它们最终可以大致近似或等同于一个更大的矩阵,因此在理论上,我们可以将这些矩阵的权重相乘,以恢复原始矩阵的近似值。

其缺点是,我们现在每次使用这些矩阵时都必须执行两次操作(即,对于每个压缩和解压缩的层,我们将矩阵乘法的总数翻倍,以换取使它们变得更小)。并且因为将它们限制为秩r或更低的矩阵,显然会损失原始矩阵的一部分表示能力。

2.1.2 思路

传统的注意力机制直接将输入X映射到QKV的注意力头维度,MQA和GQA通过共享机制来变相压缩KV Cache的头维度。MLA的核心思想是采用类似LoRA的方式表示KV。具体而言,是在prefill期间构建一个压缩空间,对输入矩阵的HiddenSize 维度进行压缩。即先将输入X映射到隐向量c存储起来。简单理解就是,假设有个矩阵的维度是𝑛∗𝑛,那么可以将其分解为两个𝑛∗𝑑的矩阵相乘,而𝑑≪𝑛。这样就降低了存储量。在decode阶段计算注意力之前,会通过上投影矩阵将c恢复到QKV的原始维度,这样可以减少注意力键(keys)和值(values)在推理过程中的缓存,从而提高推理效率。

其实这里还有一个问题:按照这种低秩方案,传统意义上的\(W^Q, W^K, W^V\)全部变成了低秩矩阵。既然存在了低秩矩阵对满秩矩阵的替换,就可能存在性能问题。既然DeepSeek做了替换且效果不错,就说明\(W^Q, W^K, W^V\)这几个原来的满秩矩阵可能就是冗余的,具备较大的低秩特性。

在实现过程中,Q、K 和 V 的权重矩阵通常会进行融合以提高 GPU 的计算和内存效率。与分别进行独立投影不同,采用组合权重矩阵可以优化运算过程。

2.1.3 向下投影

上图给出了向下投影的具体流程,其中 \(ℎ_t\) 作为输入向量, \(W^{DKV}\)\(W^{DQ}\)为压缩矩阵,用于降维, \(c_t^{KV}\)\(c_t^Q\) 分别是压缩后的KV潜向量和Q潜向量(潜向量的维度远小于输入向量的维度与自注意力头数之积)。这个 \(c_t^{KV}\) 是和具体的哪个head(索引为i)无关的,需要被缓存,相当于说,我们不再直接缓存key/value这两个维度和\(ℎ_𝑡\) 一样的向量,而是缓存 \(c_t^{KV}\) ,并通过计算来动态的恢复 \(k_𝑡\)\(v_t\)

  • 对于KV,构建一个共享的降维映射矩阵\(𝑊^{𝐷𝐾𝑉}\)用来对模型输入进行降维。
    • \(𝑊^{𝐷𝐾𝑉}\)会将输入\(ℎ_𝑡\)(hidden state)投射到隐向量\(𝑐_𝑡^{𝐾𝑉}\),这是key和value的联合隐向量。即将一个 Shape 为 [S,H] 的 HiddenState 压缩到 Shape 为 \([S,d_c]\),其中 \(𝑐_𝑡^{𝐾𝑉}\)的维度\(𝑑_𝑐\)远小于多头key和value的原始维度\(d_h\)。MLA 不保留完整的隐藏维度,而是缩小了它们的尺寸。
    • 将压缩后的KV向量作为 KVCache 存储到显存中。推理的过程中只需要缓存每一层的隐向量\(𝑐_𝑡^{𝐾𝑉}\)(因为每一层的注意力头共享该参数)。由于\(𝑐_𝑡^{𝐾𝑉}\)的维度远小于K、V。因此在MLA中,每一步token推理产生的KV Cache参数由之前的\(2𝑛_ℎ𝑑_ℎ𝑙\)变成\(𝑑_𝑐𝑙\),从而大大减少 KV 缓存的内存占用。
  • 对于Q,使用降维映射矩阵\(𝑊^{𝐷Q}\)用来对模型输入进行降维。这与减少KV Cache无关,主要是为了减少训练期间参数量和相应激活所占的显存。这是因为在模型训练的过程中,每个输入的token会通过多头注意力机制生成对应的query、key和value。这些中间数据的维度往往非常高,因此占用的内存量也相应很大。

2.1.4 向上投影

当 Decode 阶段需要进行 MHA 时,会将加载KVCache,然后利用\(𝑊^{𝑈𝐾}\)\(W^{UV}\)\(𝑐_𝑡^{𝐾𝑉}\)向上投影以恢复更大的尺寸。这个更大的尺寸既可以与原始输入 \(h_t\)的维度匹配,也可以根据注意力头的配置来调整。DeepSeek是将KV的维度扩展回\(𝑑=𝑑_ℎ𝑛_ℎ\),从图上也可知,新的\(k_t^C,v_t^C\)分别被均分为\(n_h\)个向量,即每个注意力头有一个单独的 𝑘,𝑣 (跟MHA的KV数量一致)。

具体参见下图。 \(W^{UK}\)\(W^{UV}\) \(W^{UQ}\)均为投影矩阵,用于升维。注:此处忽略了RoPE,后续会结合RoPE再进行扩充和更新。

结合向下投影和向上投影,我们可以看到, \(W^Q,W^K,W^V\) 的矩阵实际上分别被拆分成了两个,做成了lora的形式进行信息压缩,这个形式下MLA就是MQA加上lora形式的扩展,并且计算量从dxd的复杂度减少为 2 x d x c。这种信息压缩后再恢复原维度的方式相比于之前只有一个矩阵的形式,能很好的帮助网络进一步学习到更有效的信息。实现了同样的低秩分解下更好的效果,这就是MLA比GQA更进一步压缩KV Cache的根本原因。

下图给出了如何拆分,上方和MLA,下方是作为比对的MQA。

实际上,论文“TransMLA: Multi-Head Latent Attention Is All You Need"就对MLA的表达能力做了相关分析。论文指出传统的GQA模型在计算注意力的时候,同一组里头的头都会共享相同的键值对,这就导致它在表达能力上有点受限。而MLA就不一样啦,它通过低秩分解,再加上独特的投影矩阵设计,突破了这个限制。

具体参加下图,在MLA里,就\(W_K^b\)拿来说,如果这里面的向量是正交的,那么每个通道在乘以\(XW_k^a\)之后,输出在各个通道间都不一样。可GQA呢,同一组里所有头的输出都是相同的。就因为这种结构上的差别,在KV缓存大小一样的情况下,MLA的表达能力更强。说白了,MLA通过调整网络结构,优化参数更新策略,让注意力计算过程更高效,这样就能更好地捕捉复杂的语义关系,提升模型的能力。

2.1.5 完整过程

完整的对比过程如下图。图中上方是总体思路。下方是MLA和GQA的对比,其中又分为两部分,上部分是通过公式看看MLA如何增强表现力;下半部分是完整的流程。

2.2 权重吸收

2.2.1 当前状态

我们目前已经通过向下投影将压缩的隐向量进行保存,这减少了KV Cache的内存占用。也通过向上投影矩阵增强了表达能力。然而,MLA强调的是激活值也随之减少。当前我们还看不出来怎么减少激活值的数量的。因为虽然压缩之后的KV占据内存比较少,但是在每次推理的时候,都必须通过 \(𝑊^{𝑈𝐾},𝑊^{𝑈𝑉}\) 来根据缓存的\(c_t^{KV}\)重新计算出 \(𝑘_{𝑡,𝑖},𝑣_{𝑡,𝑖}\),单从KV的数量和维度上看跟MHA是一个量级,比GQA和MQA都要多,上采样后的 kv cache巨大,可能导致OOM。不但内存不少(\(𝑘_{𝑡,𝑖},𝑣_{𝑡,𝑖}\)依然存在),还引入了新的计算量,会处于计算瓶颈。没有达到用时间和算力来换取空间的目的。

2.2.2 权重吸收

既然每次计算量太大,DeepSeek就想是否可以在保存压缩的隐向量的基础上来减少这个计算量(其实也减少了新KV的内存占用),于是他们给出了权重吸收这个法宝。即其作者利用矩阵结合律的特性对这些公式进行了优化,避免了针对每个query重新计算key与value,下面是文章中的原文:

备注:矩阵吸收计算是指利用矩阵乘法的结合律或低秩分解等线性代数技巧,改变矩阵的乘法顺序,重新组合某些矩阵因子,使原本需要独立计算的矩阵乘积合并在一起,避免生成大矩阵,从而降低计算复杂度和内存开销的过程。

比如,给定三个矩阵 \(A \in R^{m,k}\)\(B \in R^{k,p}\)\(C \in R^{p,n}\),通过矩阵乘法的可知\((A \times B) \times C = A \times (B \times C)\),但是二者的计算复杂度是不一样的。 \((A \times B) \times C\)的计算复杂度是 $2\times m\times k\times p+2\times m\times p\times n=2\times m\times p\times (k+n) $, \(A \times (B \times C)\) 的计算复杂度是\(2\times m\times k\times n+2\times k\times p\times n=2\times n\times k\times (m+p)\)。当 n 相比 m 和 p 都显著更小的时候,第二种计算顺序的性能会远好于第一种。假设 ,m=k=p=4096,n=1 ,那么第一种计算顺序的计算复杂度是 \(2\times 4096\times 4096\times 4097\),第二种方式的计算复杂度是 \(2\times1\times4096\times8192\),显著低于第一种。

但是,具体要如何用矩阵吸收,如何使用矩阵乘法结合律,需要权衡计算量,memory读写量和瓶颈,可以套用典型的Roofline Model进行分析。这里的核心就是 AC x CB 矩阵的最终效果和 AB 矩阵的效果有多少差异。

2.2.3 推导

KQ合并

我们来结合Dot-Attention的具体形式,看看如何通过一个简单但不失巧妙的恒等变换来规避这个问题。首先,在训练阶段还是照常进行,此时优化空间不大;然后,在推理阶段,我们利用如下公式(不带位置编码)可以看到,在推理阶段,我们把\({W^{UQ}}^{\top} W^{UK}\) 合并为一个(和位置无关的)矩阵W作为Q的投影矩阵,就可以用\(c_t\)代替原本的\(k_t\)。这样就避免了重复计算中间结果q和k。

其中转置 ⊤ 表示交换张量形状中最后两个维度。各个张量的形状如下,这里注意 num_heads 要拎出来成为一个维度,因为最后 attention weight 的结果是头间独立的。

  • \(C^Q : [batch\_size,1,q\_len, q\_lora\_rank]\)

  • \(W^{UQ}:[num\_heads, q\_lora\_rank, qk\_nope\_head\_num]\)

  • \(W^{UK}:[num\_heads, kv\_lora\_rank, qk\_nope\_head\_num]\)

  • \(C^K : [batch\_size,1,kv\_len, kv\_lora\_rank]\)

我们每次缓存的 \(c_t^{KV}\)都可以直接参与计算,而不需要显式的计算出K。而且,W 矩阵是可以事先就通过\({W^{UQ}}^{\top} W^{UK}\)计算出来的(其实就是被神经网络自动计算出来)。

代码表述如下:

"""来源:https://mathmach.com/8b428574/"""
# 消融W_UK
W_UQ = tf.reshape(W_UQ, [q_lora_dim, num_head, head_dim])
W_UQ = tf.transpose(W_UQ, perm=[1, 0, 2]) # [num_head, q_lora_dim, head_dim]
W_UK = tf.reshape(W_UK, [kv_lora_dim, num_head, head_dim])
W_UK = tf.transpose(W_UK, perm=[1, 2, 0]) # [num_head, head_dim, kv_lora_dim]
W_UQUK = W_UQ * W_UK # [num_head, q_lora_dim, kv_lora_dim]

# 计算qk内积
c_Q = tf.reshape(c_Q, [batch_size, q_seq_len, q_lora_dim])
c_KV = tf.reshape(c_KV, [batch_size, kv_seq_len, kv_lora_dim])
c_KV = tf.transpose(c_KV, perm=[0, 2, 1]) # [batch_size, kv_lora_dim, kv_seq_len]
c_Q_product_W_UQUK = tf.einsum('bij,hjk->bhik', c_Q, W_UQUK) # [batch_size, num_head, q_seq_len, kv_lora_dim]
q_product_k = tf.einsum('bhik,bkj->bhij', c_Q_product_W_UQUK, c_KV) # [batch_size, num_head, q_seq_len, kv_seq_len]
VO合并

另外,传统方法需要先计算 Value 向量 \(v_t^C\) ,然后再进行注意力计算并投影到最终的输出层。我们可以直接将 \(W^{UV}\)吸收到 \(W^{O}\)里,简化最终的输出计算。吸收公式如下(此处提取剧透了rope和nope分离的模式):

\[(p \cdot (c_{kv} \cdot W^{UV})) \cdot W^O = (p \cdot c_{kv}) \cdot (W^{UV} \cdot W^O) = (softmax(q_{nope} \cdot c_{kv} + q_{pe} \cdot k_{pe}) \cdot c_{kv}) \cdot W^{UV} \cdot W^O \]

可以用代码描述为:

q_pe = W_QR(c_q)
q_nope = W_UQ_UK(c_q)
output = W_UV_O(MQA(q_pe, q_nope, c_kv, k_pe))

注意我们需要小心的通过转置等手段保证数学上的恒等。参见下面图,每个注意力头都可以消融成一个矩阵,因此,实际代码中可以使用高维矩阵将所有head消融在一个矩阵里,代码表述见下面。

代码表述:

"""来源:https://mathmach.com/8b428574/"""
# 消融W_UV
W_O = tf.reshape(W_O, [num_head, head_dim, hidden_dim])
W_UV = tf.reshape(W_UV, [kv_lora_dim, num_head, head_dim])
W_UV = tf.transpose(W_UV, perm=[1, 0, 2]) # [num_head, kv_lora_dim, head_dim]
W_OUV = W_UV * W_O # [num_head, kv_lora_dim, hidden_dim]

# 计算u
q_R = RoPE(c_Q * W_QR) # [batch_size, q_seq_len, num_head, rope_dim]
k_R = RoPE(h * W_KR) # [batch_size, kv_seq_len, rope_dim]
q_product_k_rope = tf.einsum('bijk,bhk->bijh', q_R, k_R) # [batch_size, q_seq_len, num_head, kv_seq_len]
q_product_k_rope = tf.transpose(q_product_k_rope, perm=[0, 2, 1, 3]) # [batch_size, num_head, q_seq_len, kv_seq_len]
attention_weight = tf.softmax((q_product_k + rope_score) / tf.sqrt(head_dim + rope_dim)) # [batch_size, num_head, q_seq_len, kv_seq_len]
c_KV = tf.transpose(c_KV, perm=[0, 2, 1]) # [batch_size, kv_lora_dim, kv_seq_len]
attention_weight_product_c_KV = tf.einsum('bijk,bhk->bijh', attention_weight, c_KV) # [batch_size, num_head, q_seq_len, kv_lora_dim]
u = tf.einsum('bijh,ihd->bjd', attention_weight_product_c_KV, W_OUV) # [batch_size, q_seq_len, hidden_dim]
结合

把目前的合并结合起来,我们得到如下:

\[O = AW^O = \phi(QK^T)VW^O= \phi[HW^Q(C^{KV}W^{UK})^T]C^{KV}W^{UV}W^O = \phi[H(W^QW^{{UK}^T}C^{{KV}^T}]C^{KV}(W^{UV}W^O) \]

这样,在推理时期\(W^{UK}\)可以和\(W^{UQ}.W^{DQ}\)结合,\(W^{UV}\)\(W^{O}\)结合,最终只有\(W^Q\)\(W^O\)。矩阵合并以后,对KV的整个计算过程都在低维空间进行,不会出现再把\(C^{KV}\)解压缩回高维空间的情况。 而且,上述矩阵全都是模型的权重,再推理过程重是不会变的,可以看作常量。如果是部署推理服务的话,再加载模型的时候就可以把这两个矩阵乘好,为以后的每次推理节省两次矩阵乘法。实际上并无额外的算力开销。MLA就达到了克服以往方法中KV Cache过大的问题并且保留的KV Cache该有的减少重复计算的功能。

2.2.4 讨论

训练

论文中一直提到在推理阶段使用权重吸收,这点很好理解,因为此时权重矩阵固定了。

那么什么不在训练阶段直接结合\(W^{UK}\)\(W^{UV}\),其原因大致如下:

  • 从梯度更新的角度来看,不做权重吸收会使得优化更加简单,即遵从下面的方式进行训练更好\(\nabla (\phi \psi) =\psi \nabla (\phi ) + \phi\nabla ( \psi)\)
  • 从投影的角度来看,KV共享\(W^{DKV}\)某种意义上对于空间构成了一种约束,Weight Tying 使得模型能够更好的收敛,并且提高其泛化能力,还可以提高模型的稳定性。

所以,MLA在训练阶段和MHA类似。除了多一步低秩投影以及只在部分维度加RoPE之外,MLA与Q、K的头维度由\(d_k\) 换成\(d_k+d_R\)的MHA一样。

MHA

其次,既然权重吸收这么好,为什么MHA没有做权重吸收?

我们先看看推理阶段的特点。

首先,MHA中的计算公式如下(为了演示方便,这里先讨论单头),在标准的MHA实现中,quey、key、value的embedding是分别计算的,然后通过query embedding和key embedding来计算self-attention的权重矩阵,之后将这个权重矩阵和value embedding进行相乘得到最终的结果。但是如果我们展开公式如下。

\[Z = softmax(\frac{q_t^Tk_i}{\sqrt d_k})v_tW^O = softmax(\frac{h_t^T(W^Q)^TW^Kh_i}{\sqrt d_k})h_iW^VW^O \]

此处看起来,\((W^Q)^TW^K\)\(W^VW^O\)都有吸收的可能。

其次,Decode 计算时,输入的 Q 往往只有一个 token,这就天然给我们一个简化计算的机会。即这个顺序是可以交换的,即从query的embedding出发,一直向下进行计算,得到最终的结果。因为首先将比较小的query embedding参与计算,因此看起来整体计算复杂度会明显降低。而且看起来和MLA的思路非常类似,即将
K 的 projection 放到 Q projection 之后,将 V projection 放到 attention 之后,output projection 之前。

目前看起来MHA做矩阵吸收的好处颇多。然而,事实并非如下简单。我们通过\(q_t^Tk_i\) 为例来进行分析为何MHA不适合吸收,以及为何MLA可以提高效率。

\[q_t^Tk_i = (W^Qh_t)^T(W^Kh_i) = h_t^T(W^Q)^TW^Kh_i \]

对于单个头,\(n_h\)=1,对应矩阵乘是\([1,d] \times [d, d_h] \times [d_h, d] \times [d, 1]\)。我们来看看这个矩阵乘哪些可以计算,哪些可以存储。有以下几种可能:

  • 标准KV Cache。
    • 存储角度:我们把\(W^Kh_i\)存储起来,就是存储k(v和k一致),则KV Cache大小为:\(2n_hd_hl\).
    • 计算角度:每个头实例化参数是\(W^Q\)\(W^K\)\(W^V\)\(W^O\),大小为\(4dd_h\)
  • \((W^Q)^TW^K\)结合到一起,并把结合后的权重施加到x上
    • 存储角度:存储\((W^Q)^TW^Kh_i\)作为新的cache,其大小为\(2dn_hl\),与KV Cache相比扩大了\(n_h\)倍。
    • 计算角度:每个头实例化参数是\((W^Q)^TW^K\)\(W^VW^O\)。大小为\(2d^2\)
  • \((W^Q)^TW^K\)结合到一起,但是只cache x,不cache k和v的权重。
    • 存储角度,需要存储的cache大小是\(dl\),相比标准kv cache减少了一半;
    • 计算角度,每个头实例化的参数为\((W^Q)^TW^K\)\(W^VW^O\)。大小为 \(2d^2\)

结合上面的分析,标准的kv cache已经相对而言在空间开销上和计算上是最优的了,尽管我们可以通过只 cache x减少一半的kv cache,但是结合后的矩阵放到运行时计算也增加了计算量,权衡下并不是好的方案。

我们再来看看MLA。\(W^K\)做了低秩变化后,从\([d_h,d]\)变成了\([d_h,r] \times [r,d]\), $ h_tT(WQ)TWKh_i\(变成了\) h_tT(WQ)TWW^{DKV}h_i$。

对应矩阵乘是\([1,d] \times [d,d_h] \times [d_h, d_c] \times [d_c, d] \times [d, 1]\)。我们来看看这个矩阵乘哪些可以计算,哪些可以存储。 ,那么有以下几种可能:

  • 从存储的角度:此时存储的kv cache就是 \(W^{DKV}h_i\), cache大小是 \(d_cl\) ,加上旋转位置编码的部分,总的kv cache是$ (d_c+d_h^R)l$ ,和MHA进行比较,则是$ (d_c+d_h^R)/2d$ =(512+64)/(2∗5120) =5.58%
  • 从计算的角度: \(W^{UK}\) 可以被合并(merge)到 \(W^Q\) 中,类似地,\(W^{UV}\) 可以被合并(merge)到 \(W^O\)中。这样实例化的权重就变成了原来的 \(d/r\) 分之一
  • 无论是存储还是计算的角度,MLA的拆分方法都优于MHA。

所以到这里我们就明白了,MLA的好处来源于两个方面,一个是kv cache的显著降低,另一个是权重的合并和吸收。

不合并

具体实施过程中需要依据实际情况进行抉择,比如 李伟华大神 在https://developnotes.readthedocs.io/zh-cn/latest/deepseek.html#id1 有精彩论述。

考虑如下运算:\(Y=XAB,C=AB\)。其中\(X \in R^{m\times d}\)是输入的hidden states,\(A \in R^{d\times d_c}\)\(B \in R^{d_c \times n}\)是权重矩阵,\(C\in R^{d \times n}\)是吸收后的矩阵。

直接计算\(Y=XAB\)的flops是 \(2mdd_c + 2mnd_c = 2md_c(d+n)\),合并后计算\(C=AB\)的flops是\(2mdn\)。如果\(d_c\)较小,则\(dn \gt d_c(d+n)\),计算量太大,所以不一定需要进行权重吸收。

或者我们使用MLA的实际代码来看。已知配置如下:

"hidden_size": 5120, # 隐藏层的大小
"kv_lora_rank": 512, # KV压缩维度
"q_lora_rank": 1536, # Query压缩维度
"qk_rope_head_dim": 64, # 解耦Query和Key的每个头部维度
"qk_nope_head_dim":128 # 

两种情况的计算量如下:

  • \({c_t^Q}^{\top}{W^{UQ}}^{\top} W^{UK}\)的计算量是:\(2 \times (q\_lora\_rank \times hidden\_size \times qk\_nope\_head\_dim + kv\_lora\_rank \times hidden\_size \times qk\_nope\_head\_dim) = \\ 2 \times hidden\_size \times qk\_nope\_head\_dim(q\_lora\_rank + kv\_lora\_rank ) = \\ 2 \times 5120 \times 128 (1536 + 512 )\)
  • \({c_t^Q}^{\top}W^{UQK}\) 的计算量是:$2 \times hidden_size \times q_lora_rank \times kv_lora_rank= 2 \times 5120 \times 1536 \times 512 $。

可以看到,把\(W^{UQ}W^{UK}\)合并后计算量反而增大很多。prefill 的时候其实是不要做“吸收”的,可以按 $ ({c_tQ}{W{UQ}} )(W^{UK} c_t^{KV})$ 或者$ ({c_tQ}{W{UQ}} W^{UK} )c_t^{KV}$来计算。

因此,他认为,Absorb 的真实含义其实是矩阵乘法结合律,优先结合某些矩阵,并缓存 compressed latent vector \(c_t^{KV}\), 并不是合并权重矩阵,用 Absorb 命名有一定误导性。如果吸收,也是\(W^{UK}\)被吸收到\(Q^C\),而非\(W^{UQ}\)

2.3 解耦RoPE

为提高模型对序列中上下文信息的敏感性,MLA中应用了解耦旋转位置编码(RoPE)技术。而迄今为止,我们在分析中丢失了一个非常重要的步骤,即位置编码。这是因为RoPE与低秩KV压缩矩阵不兼容(与权重吸收会冲突),此时还无法无缝切换。为了解决这个问题,MLA引入额外的查询向量\(q_t^R\)和共享键向量\(k_t^R\)来携带RoPE信息。从架构图中可以发现,DeepSeek的q和k各自都有2个部分,分别是\([q_t^R,q_t^C]\)\([k_t^R,k_t^C]\)

  • 1个部分是压缩部分:\([q_t^C]\)\([k_t^C]\)
  • 1个部分则加上了RoPE位置编码。即有独立一路做RoPE:\([q_t^R]\)\([k_t^R]\)

最终两个部分拼接成Q,K矩阵。这样就把RoPE与低秩压缩矩阵之间做了解耦,解决了位置信息与推理效率之间的矛盾。

我们接下来仔细进行剖析。

2.3.1 RoPE背景

下面代码是Llama 3计算注意力的摘要。RoPE 旋转位置编码中Query和Key都是位置相关的。在进行注意力计算前,代码是先应用\(W^K\)等矩阵得到Q和K,然后在Q和K上施加RoPE(乘以一个旋转矩阵),以此在Q和K中融入相对位置信息。

class Attention(nn.Module):       
    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,       mask: Optional[torch.Tensor],):
        bsz, seqlen, _ = x.shape
        # 获取Q、K和V
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        # 施加RoPE
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # 处理KV Cache
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)        
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # 计算注意力,分开计算了RoPE部分的q和k的注意力计算再求和
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)        

2.3.2 问题

无法直接应用到低秩压缩

我们先看看是否可以把RoPE 施加到低秩压缩向量上,即RoPE直接被低秩压缩向量K和V所吸收。

因为K和V的低秩表示已经是压缩了的状态,压缩操作可能已经丢失了某些信息,而RoPE矩阵对key和value是位置敏感的,直接在\(𝑐_𝑡^𝑄\)\(𝑐_𝑡^{𝐾𝑉}\) 上应用 \(𝑅_𝑚\)\(𝑅_𝑛\) 不再等价于在完整的Q和K上应用位置编码,不能直接和有效地反映原始Q和K的相对位置关系。换言之,RoPE与低秩KV压缩不兼容(RoPE is incompatible with low-rank KV compression),只能作用到原始K和V上。即只能从低秩KV压缩先还原成原始的KV,然后在原始KV上施加RoPE。之前已经学习过,这样做对性能有损失,所以采用了权重吸收。

与权重吸收不兼容

我们仔细看看RoPE作用到原始K和V上时,是否可以被权重吸收。

在RoPE的实现中,如果我们要让Q、K带上位置信息,会分别乘以相应的位置编码矩阵。

\[\hat Q= 𝑅_𝑚𝑄\\ \hat K =𝑅_𝑛𝐾 \]

如果计算\(𝑄^T𝐾\)时,就变成了

\[S = Q^TR_m^TR_nK \]

DeepSeek-V2对Q和K都进行了压缩,则整个过程变成:

\[𝑆=(𝑊^{𝑈𝑄}𝑐_𝑡^𝑄)^𝑇𝑅_𝑚^𝑇𝑅_𝑛𝑊^{𝑈𝐾}𝑐_𝑡^{𝐾𝑉} = \\ {c_t^Q}^{\top}{W^{UQ}}^{\top}𝑅_𝑚^𝑇𝑅_n W^{UK} c_t^{KV} = \\ {c_t^Q}^{\top}{W^{UQ}}^{\top}𝑅_{m-n} W^{UK} c_t^{KV} \]

这里,\(𝑊^{𝑈𝑄}\)\(𝑊^{𝑈𝐾}\) 分别是用于从低秩表示恢复到原始维度的解压缩矩阵。目前公式中间多了一个与token位置差t-i相关的矩阵\(R_{m-n}\),该矩阵随着相对位置变化而变化,并不是个固定矩阵,无法提前计算好。并且矩阵乘法不遵循交换律,没办法把\(R_{m-n}\)挪到公式的其它地方,因此在推理时,\(𝑊^{𝑈𝑄}\)\(𝑊^{𝑈𝐾}\) 无法直接进行交互,\(𝑊^{𝑈𝐾}\) 就无法整合到 \(𝑊^𝑄\) 中。即\(𝑊^{𝑈𝐾}\)\(𝑊^𝑄\) 无法合并为一个固定的投影矩阵。如果要强行降低KV Cache,则必须将参数簇\(R_{m-n}^s, s=1,2,...,head_{num}\)全部全部缓存下来。这个参数簇包含了\(O(sequence\_length^2)\)个参数张量,实在太大。因此,这就导致DeepSeek-V2原定的权重吸收无法实现,在推理过程中需要对所有前置tokens对应的Key进行旋转位置编码的计算,这会降低推理速度。

下图给出了更加精确的阐述,上方是NoPE,下方是RoPE。

2.3.3 解决方案

为了解决MLA中的RoPE与低秩KV联合压缩不兼容的问题,DeepSeek团队提出了解耦RoPE的策略:对于一个head,用一个高维度的向量表示其文本信息,以及一个低维度的向量来表示其旋转位置编码信息。前面的高维度向量称为nope,后面的低维度向量称为rope。具体而言是,把Query和Key进行拆分为\([q_t^R,q_t^C ]\)\([k_t^R,k_t^C]\),其中一部分小向量进行了旋转位置编码( \(q_t^R,k_t^R\) ),一部分大向量进行压缩( \(q_t^C,k_t^C\))。

  • 信息存储部分( \(q_t^C,k_t^C\))。这部分存储了大部分的业务信息,是被压缩的。下图的红圈和紫圈表明,我们有\(n_h\)个注意力头,因此,我们需要把\(q_t^C,k_t^C\)𝑡分别均分为\(n_h\)份。下标 i 表示的是第 i 个头。
  • 位置信息部分( \(q_t^R,k_t^R\) )。具体又分为两部分。
    • 使用共享的键(shared keys)\(𝑘_𝑡^𝑅∈𝑅^{𝑑_ℎ^𝑅}\) 来携带RoPE信息,\(𝑑_ℎ^𝑅\) 表示解耦的queries和key的一个head的维度。共享的\(𝑘_𝑡^𝑅\)指的是每个头的K都用这同一个\(𝑘_𝑡^𝑅\)。注意,此处是基于 \(h_t\)(输入嵌入)而不是基于向下投影的 \(C_t^{KV}\) 来生成\(k_t^R\)
    • 使用额外的多头查询(multi-head queries) \(𝑞_{𝑡,𝑖}^𝑅∈𝑅^{d_ℎ^𝑅}\) 来携带RoPE位置信息。注意,此处是基于\(c_t^Q\)生成\(q_t^R\),而且每个头会有自己的\(𝑞_{𝑡,𝑖}^𝑅\)

最后将这四个变量分别拼接起来进行注意力计算。从而在推理时不需要对Key进行位置编码的计算,避免了RoPE与低秩压缩矩阵之间的耦合问题,解决了位置信息与推理效率之间的矛盾,提高了推理效率。具体参见下图。

最终乘积计算如图中标号4.1,其中前一项(标号4.2)按照无RoPE的情况计算,推理时只需要缓存\(c_t^{KV}\),后者(标号4.3)则对于所有注意力头只缓存一个共享\(k_t^R\)。即,在推理阶段,单个Token产生的KV Cache包含了两个部分。

  • 需要缓存键值的压缩潜在向量\(c_t^{KV}\)(维度为\(d_c\))。
  • 携带RoPE信息的共享键向量\(k_t^R\)(维度为\(d_h^R\))。

一共是\((𝑑_𝑐+𝑑_ℎ^𝑅)𝑙\) 个元素,l是层数。这种折中的方法保证了KV Cache的显存空间依然很小(虽然在 𝑑𝑐 的基础上增加了64维的 𝑑𝑟 ),FLOPS上有增加但是代价不大。

经过Concat过程会增加 Q 和 K 向量的维度。为了处理增加的维度,模型可以选择:

  • 增加注意力头的数量:这将保持原有的每头维度,但需要更多的计算资源。
  • 调整每个头的处理维度:保持头的数量不变,但提高每个头的维度,以适应Concat向量。

下图给出了清晰的对比。进行注意力计算时,\(c_t^{KV}\)分别通过上投影矩阵\(W^{UK}\)\(W^{UV}\)还原出键和值,每个注意力头上的键再与携带了RoPE信息的共享键向量\(k^R_t\)拼接形成MHA的键值输入。\(c_t^Q\)通过上投影矩阵\(W^{UQ}\)\(W^{UR}\)还原并生成查询向量\(q_t^C\)和携带RoPE信息的查询向量 \(q_t^R\),二者拼接形成MHA的查询向量输入。最终多个头的输入拼接在一起,并经过线性映射\(W^O\)得到最终的输出。

2.3.5 和权重吸收结合

我们再看看结合权重吸收之后如何处理,这里就需要将nope和rope也加进来,公式演变如下。

2.4 资源占用

2.4.1 参数量

MLA的思路来自LoRA,LoRA强调的是参数量的减少,而MLA也确实做到了减少参数量。按DeepSeek-V3的参数配置,两个低秩矩阵参数量: \(2 \times d_c \times d =2\times512\times7168\) ,而正常MHA的参数矩阵参数量: \(d \times d=7168 \times 7168\)

具体参数如下:

"vocab_size": 129280,
"dim": 7168,
"inter_dim": 18432,
"n_heads": 128,
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,

各个矩阵的参数量如下:

  • \(W^{DKV}\):dim * kv_lora_rank = 7168 * 512

  • \(W^{UK}\):kv_lora_rank * qk_rope_head_dim * n_heads = 512 * 128 * 128

  • \(W^{UV}\):kv_lora_rank * qk_nope_head_dim * n_heads = 512 * 128 * 128

  • \(W^{KR}\): dim * qk_rope_head_dim = 7168 * 64

  • \(W^{DQ}\):dim * q_lora_rank = 7168 * 1536

  • \(W^{UQ}\): q_lora_rank * qk_nope_head_dim * n_heads = 1536 * 128 * 128

  • \(W^{QR}\):q_lora_rank * qk_rope_head_dim * n_heads = 1536 * 64 * 128

  • \(W^O\):n_heads * v_head_dim * hidden_size = 128 * 128 * 7168。

2.4.2 内存占用

但MLA强调的是KV-cache的减少,也就是KV的激活值减少。我们接下来继续分析。与经典的MHA和GQA,MQA比较。MLA实际缓存的向量是:

  • \(c_t^{KV}\),维度是\(d_c\)
  • \(k_t^R\),维度是\(d_h/2\)

如下图所示,我们可以看出,MLA在优化kv cache和保证模型效果上有很强的优越性。图中\(n_h\)是注意力头数量,\(n_g\)是GQA分组数,\(d_h\)是隐藏层维度(低秩压缩后的维度),\(d_c\)是KV压缩维度,\(l\)为block的块数。和MHA相比,Q和K的头维度变成了\(d_c+d_r\),V的头维度变成了\(d_c\),对于DeepSeek-V2,\(d_c\) 被设置为\(4d_h\),而\(d_h^R\)被设置为\(\frac{d_h}{2}\)。KV Cache的数量以元素数量来衡量(不考虑存储精度)。

  • 在MHA中,推理阶段针对每个Token,需要缓存其键向量和值向量,则每个Token的缓存参数个数为\(2 \times n_h \times d_n \times l\)。与MHA相比,MLA占用的token数\(\frac{9}{2}d_hl\) 通常要小于\(2n_hd_hl\),所以MLA能获得比 MHA 更强的性能,显著降低了KV缓存的大小。
  • GQA 通过分组共享 K/V 矩阵(如 LLaMA-70B 设置 g=8)减少显存占用,但压缩率有限(仅减少到 g/h 倍)。与GQA相比,MLA相当于GQA中的组数量 𝑛𝑔 =2.25,小于大多数Model里的 group数量,由此可见,其kv cache的尺寸会大大减小。即,MLA 的 KVCache 存储成本约等于GroupNum=2.25 的 GQA 的 KVCache 存储成本。
  • 与MQA相比,MLA相当于增加了2.25倍的存储,但是MLA的性能和效果显著优于MQA,甚至强于MHA和GQA,真正实现了即降低推理成本,又保证了模型性能。

2.4.3 计算量

和MHA相比,MLA的Q和K的头维度变成了\(d_c+d_h^R\),V的头维度变成了\(d_c\)。而 DeepSeek V3的一些超参数如下:

  • \(d_k\)(hidden dimension/模型维度):7168。
  • \(n_h\)(注意力头数):128。因为MLA的KV Cache大小跟\(n_h\)无关,增大\(n_h\)只会增加计算量和提升模型能力,但不会增加KV Cache。
  • \(d_h\)(每个注意力头的维度):128。
  • \(d_c\)(KV的压缩维度):512,即\(4d_h\)
  • \(d_h^R\)(RoPE头相关维度):64,即\(\frac{d_h}{2}\)

既然MLA每个头的Q/K的head size变大了不小,所以MLA的推理计算量增加了。那为什么还能提高推理效率呢?其实,MLA可以提高效率是因为结合了LLM推理的瓶颈时访存而不是计算这一特性。我们可以将LLM的推理分两部分:第一个Token的生成(Prefill)和后续每个Token的生成(Generation),Prefill阶段涉及到对输入所有Token的并行计算,然后把对应的KV Cache存下来,这部分对于计算、带宽和显存都是瓶颈,MLA虽然增大了计算量,但KV Cache的减少也降低了显存和带宽的压力。Generation阶段由于每步只计算一个Token,实际上它更多的是带宽瓶颈和显存瓶颈,因此MLA的引入理论上能明显提高Generation的速度。另一方面,由于Compressed KV在每个head中都参与了计算,DeepSeek-V2的128个heads能够提供足够的计算强度(正比于 Head 数),这样就把 LLM 解码过程的访存密集型,转换为计算密集型的操作,因此Attention部分的MFU也得到了大幅提高。

我们假设 q 的形状是\((b,n_h,s_q, d_h)\)\(c^{KV}\)的形状是\((b,1,s_{kv},d_c)\)\(W^{UK}\) 的形状是\((d_c,n_h,d_h)\)。prefill阶段,\(s_q = s_{kv} =s\)

  • native的计算量是:\(2bsd_cd_hn_h + 2bn_hssd_h = 2bn_hd_hs(d_c+s)\)
  • 吸收后的计算量是:\(2bsd_cd_hn_h + 2bn_hssd_c = 2bn_hd_cs(d_h+s)\)

两者相比是:\((d_h(d_c+s)) / (d_c(d_h+s))\)

decode阶段,\(s_q=1,s_{kv}=s\)

  • 缓存K的计算量。\(2bd_cd_hn_h+2bn_hsd_h=2bn_hd_h(d_c+s)\)
  • 缓存潜向量时候的计算量。\(2bsd_cd_hn_h+2bn_hsd_h=2bn_hd_h(d_cs+s)\)
  • 吸收后的计算量。\(2bd_cd_hn_h+2bn_hsd_c=2bn_hd_c(d_h+s)\)

2.4.4 信息转移

有研究人员再读MLA,还有多少细节是你不知道的认为,MLA的作用其实是"信息转移“,即把KV头中独有的信息转移到对应的Q头上,而把KV头中间共享的相同信息存储到KV Cache中。具体思路如下:

  • 改进目的:在尽量不压缩head上K、V信息的情况下,节省kv cache。
  • 改进背景:之所以要保存token对应的所有注意头上的K、V值,是因为每个k_head附带有不同的信息,它将用这份独有的信息和对应的q_head进行注意力计算。
  • 改进思路(下面以K头为例,V头类似):
    • 把一个token中所有K头中的共有信息抽取出来,压缩到KV Cache中,因为这些共有信息会更少,只保存它们才能减少KV Cache的大小。这个相同信息是每个tokens的所有k_heads共享一份,同时在不同tokens间共享。
    • 把K中每个头上独有的信息转移到对应的Q头上。因为Q头需要承载更多信息,所以Q和K的头维度变成了\(d_c+d_r\)\(d_c\) 被设置为\(4d_h\)。在V3上,\(d_c\)是512,相当于把缓存7168维的向量降低到了缓存512维。而q压缩之后是1536维,之所以这么大,就是因为Q要承载更多的信息。

虽然从形式上来说,MLA和MQA/GQA很像,似乎都是通过压缩k/v_heads的数量来节省KV cache大小的。但MLA是压缩num_heads,不压缩信息(把信息转移到了q_heads上);而MQA/GQA则在一定程度上对信息做了压缩。具体这些相同信息、相异信息存储在何处?是在\(W_K\)矩阵中?还是存储在原始token \(h_t\)中?笔者目前不能确定。所以只能用下图展示。

另外,GQA 的分组数需严格匹配硬件规模(如 8 卡对应 g=8),限制了模型部署的灵活性。而 MLA 通过潜在空间投影和解耦式权重合并,可动态适配不同硬件配置(如单卡或多机集群)。GQA 为弥补性能损失需增大 FFN 层规模(如 LLaMA3-70B 的 FFN 参数量增加 20%),导致模型复杂度上升。MLA 则通过低秩投影和动态路由,无需额外补偿即可维持性能。

2.5 并行

在大模型推理的decode阶段,MLA无法使用张量并行。故在目前的一些开源实现中,主要还是基于数据并行来对MLA进行处理,即不同请求的KVCache存储到不同的GPU中。DeepSeek-V3论文提到使用张量并行和序列并行。

  • 张量并行:MHA通常对head_num维度进行切分来实现张量并行。而MLA则有自己的特点,如果采用 tp 并行时,部分权重和 kvcache 都无法按 head_num 划分到不同的卡上。
    • 使用张量并行部分:kv_b_proj、o_proj等模块都包括了head维度,因此可以按照head维度切分执行张量并行,将MLA计算均匀的划分到多卡上,实现并行加速。
    • 难以使用张量并行部分。
      • mla存储KV Cache时,对于一个token存储的是(1, 1, kv_lora_rank+qk_rope_head_dim),而不是常规MHA下的(1, kv_head_num, head_dim)。因此KVCache中只保存一份潜空间的压缩向量,并不包含head维度,没有办法按照head进行划分。导致每张卡上都要保存所有请求的的完整kvcache,其形状是(bs, 1, seq_len, kv_lora_rank),这意味着KVCache 各个卡的存储是冗余的。
      • 部分权重由于head_num=1无法切分到不同的卡上,比如q_a_proj 和kv_a_proj_with_mqa不能按 head_num 切分。只有上投影矩阵才能考虑按列切分和最后输出矩阵按行切分。
  • 数据并行。即按照请求切分,不同请求的潜空间的压缩向量存储到不同的GPU中。但是因为不同GPU上的请求长度可能差异很大,这样会导致显存占用不均衡,也会导致不同GPU上计算时间差异较大,进而导致性能最差的GPU拖慢整体进度。
  • 序列并行:MLA会用序列并行(Sequence Parallel)来进行辅助。即,对KVCache按照序列维度进行切分,每一张卡上都使用query来做local的attention计算,然后对结果进行规约。

0x03 计算过程

我们来梳理下MLA在推理阶段的计算流程。

3.1 公式

首先,我们给出Q、K、V的变换过程对应的公式。后续会按照这个公式来进行解析。

3.2 原始流程

我们将上述公式转换为流程图,图中细节如下:

  • 从上到下分为Q、K、V三路。
    • Q和K又都细分为两路,“上路”绿色的权重和激活值对应隐向量/低秩部分;“下路”灰色渐变的权重和激活值对应decoupled RoPE。
    • K的下路和V路的数据流向有所交错。
  • “缓存” 代表在推理阶段会进行缓存的数据,具体分为两部分:
    • KV联合隐向量 \(c_t^{KV}\)
    • 单独施加了RoPE的键$k_t^R \(。K路位置编码模块接受的输入还是原始的\)ℎ_t\(而不是压缩后的\)c_t$。

此处假设头数\(n_h\)为2,矩阵大小并不是完全按照比例缩放。

3.3 吸收

3.3.1 过程

接下来第二步,将论文中所说的权重吸收过程施加进去,得到下图:

  • 推理阶段要缓存的东西不变。
  • \(W^{UK}\) 吸收进 \(W^{UQ}\) 之后。
    • Q的上路计算逻辑没有变,但是权重和激活值的形状都有相应的调整。
    • K的上路则直接少掉了一处线性映射的计算逻辑,变成了重复拷贝$n_h $份,与K下路类似。
  • $W^{UV} $吸收进 \(W^{O}\)之后。
    • V路由线性映射退化为重复拷贝的逻辑。
    • 最后输出映射的计算逻辑不变,但是权重和激活值的形状有相应的调整。
  • 红色字体公式代表了吸收对应的公式。绿色箭头表示有进一步吸收的可能。

3.3.2 吸收结果

我们对上图进行整理,得到吸收的结果如下。

3.3.3 MQA形式

MLA推理阶段的计算逻辑其实很像一个MQA,我们进行比对下(不考虑 RoPE)。

MQA和MHA的最大区别在于 \(K,V\) 是所有 head共享的,因此能够减少KV Cache的显存占用。其中 $$ Q_iTK=HT(W_iQ)TW^KH $$。

对于MLA,单独看 Attention 计算的前一部分,其中$ Q_iTK_i=HT(W{DQ})T(W_i{UQ})TW{UK}_iWH$,令 \(W_i^Q=(W_i^{UK})^TW_i^{UQ}W^{DQ}\),我们有 $$ Q_iTK_i=HT(W_iQ)TW^{DKV}H $$ 。可以看到这一计算公式和 Multi-Query Attention 其实是一样的,都是使用的单独的 \(Q\) 和共享的 \(K\)\(C^{KV}\)),等价于将single-head的KV重复拷贝若干遍再执行正常的MHA。

区别在于,这里 \(W_i^QH,W^{DKV}H\in\mathbb{R}^{d_c\times l}\)。也就是说在进行 attention 计算的时候,向量点积的维度是 \(d_c\) 而不是 \(d\)。在论文中实际设置的是 \(d_c=4d\)。也就是说 Multi-Head Latent Attention 其实是 head dimension 提高到4倍的 Multi-Query Attention。在论文中也提到了在推理的时候 absorb \(W^{UK}\) into \(W^{UQ}\),其实就代表了这里的结合方式。因为每个head的维度提高了,所以能够计算出更加复杂的 attention分布,从而相比起 Multi-Query Attention 取得性能提升。相比起直接提高 head dimension,其优点在于所有head的 \(W^{DQ},W^{UQ},W^{UK}\)的总参数量是 \(d\cdot d_c+d \cdot d_c+ d \cdot d_c=3d\cdot d_c=12d\cdot d_h\),而所有 head 的 \(W^Q\) 的参数量是 \(d \cdot d_c\cdot n_h=4d^2\),节省了参数量。也就是说对 \(W^Q\) 做了一个低秩分解。

但是这个提升并不是免费午餐,因为 head dimension 提高意味着 attention 的计算量也提高,而 attention 的计算量是 \(O(l^2)\) 的。为了处理长文本,现在大家一般都倾向于尽可能降低 attention 计算量的常数,而这个方法是会增加常数的。以上分析没有考虑 RoPE,如果考虑 RoPE 的话,每个 head 的维度会从 \(4d\) 变成 \(4.5d\),其中\(4d\)是没有 positional encoding的,\(0.5d\) 是使用 RoPE encoding的。其实 ChatGLM2-6B 中已经使用过类似的做法,即只在一半的 head dimension 上使用 RoPE ,目的是为了把 attention 计算分成位置相关和位置无关的两部分,与性能提升的关系并不大。

了看得更明显,我们可以把图中的一些权重进一步吸收合并,得到下图。

  • Q的计算过程退化为普通multi-head线性映射
    • 每个head一部分维度保持不动,对应绿色部分
    • 每个head另一部分维度施加RoPE变换,对应红色部分
  • K的计算过程退化为single-head线性映射
    • 同样只对部分维度施加RoPE变换。
    • 施加后进行重复拷贝(逻辑上如此呈现以便于理解,计算上当然可以优化掉)。
  • V则直接使用K中未经施加RoPE变换的部分,同样重复拷贝。

下图与标准MQA的区别是:

  • QK只有部分维度施加RoPE;
  • V与未施加RoPE的K共享激活值。

0x04 代码

我们主要使用V2的代码来分析,因为条理更加清晰。也需要注意的是,DeepSeek的代码在很多地方和论文不一致。V2中的DeepseekV2Attention的实现本质上和V3中的native一样,其实并没有节省KV-Cache,V3版本的非native版本是跟论文一致,节省了显存。

4.1 配置

我们摘录一些相关配置信息如下。在 Naive 实现中,512 维的 Latent KV \(c^{KV}\) 被映射回对应 128 个 head,每个 head 128 维的 \(k^C\)\(v^C\),然后再拼接上位置向量 \(k^R\) ,最终形成标准的 q、k、v 输入到标准的 Multi Head Attention 进行 Attetion 计算。另外,代码中也使用了norm,在论文中也有相应提及。

具体配置信息如下。其中:

  • 键和值的压缩维度 \(d_c\) :设置为 512 ,原始嵌入维度 𝑑=5120,比例为 1/10。由于键和值在推理时需要缓存,因此采用较大的压缩比例以显著减少内存开销。
  • 查询的压缩维度 \(d'_c\) :设置为 1536 ,比例为 0.3 。查询在训练时需要频繁计算,因此采用较小的压缩比例以保留更多信息,确保模型性能。
"num_hidden_layers": 60, # Transformer层的数量
"hidden_size": 5120, # 隐藏层的大小
"num_attention_heads": 128, # 注意力头的数量
"kv_lora_rank": 512, # KV压缩维度
"q_lora_rank": 1536, # Query压缩维度
"qk_rope_head_dim": 64, # 解耦Query和Key的每个头部维度
"n_shared_experts": 2, # MoE层中的共享专家数量
"n_routed_experts": 160, # MoE层中的路由专家数量
"moe_intermediate_size": 1536, # 每个MoE专家的中间隐藏层的维度
"num_experts_per_tok": 6, # 每个token激活的专家数量
"routed_scaling_factor": 16.0, # 路由专家的缩放因子
"rms_norm_eps": 1e-06 # RMS归一化的epsilon值

4.2 定义

给定输入向量\(h_t \in \mathbb{R}^{B \times L \times 5120}\),其中\(B\)为batch size,\(L\)为sequence length。

class DeepseekV2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
    def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # 对应 query 压缩后的隐向量的维度 d'_c
        self.q_lora_rank = config.q_lora_rank
        # query和key的隐藏向量中,应用rope部分的维度,对应d_h^R
        self.qk_rope_head_dim = config.qk_rope_head_dim
        # 对应 key-value 压缩后的隐向量维度 d_c
        self.kv_lora_rank = config.kv_lora_rank
        # value 的一个注意力头的隐藏层维度
        self.v_head_dim = config.v_head_dim
        # 向量中不应用rope部分的维度
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # 每一个注意力头的维度应该是nope和rope两部分之和
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        self.is_causal = True

        # MLA 中对 Q 投影矩阵也做了一个低秩分解,对应生成 q_a_proj 和 q_b_proj 两个矩阵,即两阶段投影:先将hidden_size投影到q_lora_rank,再投影到最终维度
        # 对query进行压缩,即down-projection。即,第一阶段投影:hidden_size -> q_lora_rank,对应论文公式中的W^DQ
        self.q_a_proj = nn.Linear(
            self.hidden_size, config.q_lora_rank, bias=config.attention_bias
        )
        self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
        # 对压缩后的query映射成高维,即up-projection。对应上述公式中的W^UQ和W^QR合并后的大矩阵,仅仅只是内存放在一起。
        # q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128 + 64
        self.q_b_proj = nn.Linear(
            config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
        )

        # KV向量的生成也是先投影到一个低维的 compressed_kv 向量(对应c_t^{KV}),再升维展开
        # 对应论文公式中的W^{DKV}和W^{KR}
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
        # 对应论文公式中的W^{UK}和W^{UV},由于 W^{UK} 只涉及 non-rope 的部分,所以维度中把 qk_rope_head_dim 去掉了
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        # 对应论文公式的第 47 行
        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

对应的一些信息如下。把整个计算流程拆成 q_nope, k_nope, k_pe, k_nope 这四个部分就是为了把RoPE进行解耦。两个pe结尾的变量就是用于储存旋转位置编码的信息。Deepseek-V2将kv cache压缩到了同一个小矩阵中,后面再解压缩出来。

# q = q.view(bsz, q_len, num_heads, q_head_dim).transpose(1, 2)
# q_nope, q_pe = torch.split(q, [qk_nope_head_dim, qk_rope_head_dim], dim=-1)
q_pe : torch.Size([16, 128, 1, 64])
q_nope : torch.Size([16, 128, 1, 128])
# query_states = k_pe.new_empty(bsz, num_heads, q_len, q_head_dim)
query_states : torch.Size([16, 128, 1, 192])
    
# kv = .view(bsz, kv_seq_len, num_heads, qk_nope_head_dim + v_head_dim).transpose(1, 2)
# k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
value_states : torch.Size([16, 128, 1024, 128])
k_nope : torch.Size([16, 128, 1024, 128])  
# k_pe = k_pe.view(bsz, kv_seq_len, 1, qk_rope_head_dim).transpose(1, 2)
k_pe : torch.Size([16, 1, 1024, 64])
# key_states = k_pe.new_empty(bsz, num_heads, kv_seq_len, q_head_dim)
key_states : torch.Size([16, 128, 1024, 192])

self = {DeepseekAttention}  
 hidden_size = {int} 5120
 kv_a_layernorm = {DeepseekV2RMSNorm} DeepseekV2RMSNorm()
 kv_a_proj_with_mqa = {Linear} Linear(in_features=5120, out_features=576, bias=False)
 kv_b_proj = {Linear} Linear(in_features=512, out_features=32768, bias=False)
 kv_lora_rank = {int} 512
 num_heads = {int} 128
 o_proj = {Linear} Linear(in_features=16384, out_features=5120, bias=False)
 q_a_layernorm = {DeepseekV2RMSNorm} DeepseekV2RMSNorm()
 q_a_proj = {Linear} Linear(in_features=5120, out_features=1536, bias=False)
 q_b_proj = {Linear} Linear(in_features=1536, out_features=24576, bias=False)
 q_head_dim = {int} 192
 q_lora_rank = {int} 1536
 qk_nope_head_dim = {int} 128
 qk_rope_head_dim = {int} 64
 rotary_emb = {DeepseekV2RotaryEmbedding} DeepseekV2RotaryEmbedding()
 softmax_scale = {Tensor} tensor(0.0723, dtype=torch.bfloat16)
 v_head_dim = {int} 128

另外,https://github.com/sgl-project/sglang/discussions/3082 这里阐释了为何使用norm。

4.3 操作Q

我们把Q相关的代码都合并在一起进行分析。总的流程是:模型处理上一层计算出的隐藏状态(hidden_size=5120)时,首先会将模型的q压缩到 q_lora_rank 这一维度(设定为1536),再扩展到 q_b_proj 的输出维度(num_heads * q_head_dim),最后切分成 q_peq_nope 两个部分。

4.3.1 变量定义

MLA 中对 Q 投影矩阵\(W^Q\)做了一个低秩分解,对应生成 q_a_proj 和 q_b_proj 两个矩阵。

  • q_a_proj 大小为 [hidden_size, q_lora_rank] = [5120, 1536],对应公式中的 $$W^{DQ}$$,用来降维。
  • q_b_proj 大小为 [q_lora_rank, num_heads * q_head_dim] = [q_lora_rank, num_attention_heads * (qk_nope_head_dim + qk_rope_head_dim)] = [1536, 128*(128+64)] = [1536, 24576] ,用来升维,对应公式中的 \(W^{UQ}\)\(W^{QR}\)合并后的大矩阵。因为从公式来看这两个矩阵都需要和\(c_t^Q\)计算,所以可以合并矩阵后再进行拆分。对于一个head,用一个128维度的向量表示其文本信息,以及一个64维度的向量来表示其旋转位置编码信息。前面的128维度,称为nope,后面的64维度,称为rope
self.num_heads = config.num_attention_heads # 128
self.q_lora_rank = config.q_lora_rank # 1536
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim # 128 + 64

# 对query进行压缩,即down-projection
self.q_a_proj = nn.Linear(
    self.hidden_size, config.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
# 对压缩后的query映射成高维,即up-projection
self.q_b_proj = nn.Linear(
    config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)

4.3.2 变量操作

在DeepSeek-V2中,Q向量也采用了低秩压缩的方式。

  • 首先,将输入向量投影到一个1536维的低维空间:$$ c_t^Q = W^{DQ} ,h_t \in \mathbb{R}^{B \times L \times 1536} $$。对应论文第37号公式。
  • 然后,将其投影到\(\mathbb{R}^{H \times 128}\)的多头向量空间上(其中\(H=128\)是heads数),得到了Q向量的第一部分:$$ q_t^C = W^{UQ} c_t^Q \in \mathbb{R}^{B \times L \times H \times 128} $$。对应第38号公式。
  • 再将其投影到\(\mathbb{R}^{H \times 64}\)上并使用RoPE嵌入位置信息,得到Q向量的第二部分:$$ q_t^R = \mathrm{RoPE}(W^{KR} h_t) \in \mathbb{R}^{B \times L \times H \times 64} $$。对应第39号公式。每个head有自己的旋转位置编码,每个head之间不共享。
  • 将两部分拼接的到最终的Q向量:$$ q_t = [q_t^C, q_t^R] \in \mathbb{R}^{B \times L \times H \times 192} $$。对应第40号公式。

在具体的实现过程中其输入为 hidden_states 向量,对应公式中的 \(ℎ_t\)。是一个大小为 [batch_Size, sequence_length, hidden_size] 的矩阵,其中 hidden_size 具体为 5120。后续的nope指代非rope。

# hidden_states对应公式中的h_t,hidden_states的shape是(batch_size, seq_length, hidden_size),其中 hidden_size为 5120,是num_head * q_head_dim
bsz, q_len, _ = hidden_states.size()

# 下面两行代码对应第37、38号公式,先降维再升维。q_b_proj维度是[1536, 24576],q_a_proj维度是[5120, 1536],是W^Q [5120, 24576]矩阵的低秩分解。即[5120, 24576] -> [5120, 1536] * [1536, 24576] 
# 首先,使用全连接层(self.q_a_proj)对输入的隐状态(hidden_states)进行降维投影
# 然后,使用全连接层(self.q_b_proj)对压缩的向量进行上投影  
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))

# 重塑为多头形式,是第40号公式的前置准备操作,或者说是40号公式的反向操作
# q_pe 要扔给 RoPE模块,所以需要重整下形状
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)

# 把最后一维切分成nope和rope两部分
# 将最后一层 192 的hidden_states切分为 128 (qk_nope_head_dim) + 64 (qk_rope_head_dim),即将查询表示(q)分为两部分:没有经过位置编码的部分(q_nope)和经过位置编码的部分(q_pe),q_nope表示不需要应用RoPE的,q_pe表示需要应用RoPE的
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

# 第39号公式,给q和k施加RoPE
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

# 初始化查询状态(query_states)的张量,这个张量将用于存储融合了解耦RoPE的查询表示,其中q_head_dim = qk_nope_head_dim + qk_rope_head_dim = 128 + 64 = 192
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)

# 下面两行对应第40号公式
# 将未经过位置编码的查询表示(q_nope)复制到 query_states 张量的前一部分,即那些不包含位置编码的维度。
# 这样做可以有利于后续将原始的查询表示与含有位置编码信息的查询表示分开来处理
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope # 128
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe # 64       

4.4 操作KV

我们把KV相关的代码都合并在一起进行分析。对于kv矩阵的设计,模型使用了kv压缩矩阵设计(只有576维),在训练时进行先降维再升维。在模型推理的时候,需要缓存的量变成 compressed_kv,经过 kv_b_proj 升高维度得到 k,v 的计算结果。

4.4.1 变量定义

KV向量和Q向量类似,也做了一个低秩分解,对应生成 kv_a_proj_with_mqa和 kv_b_proj 两个矩阵。

  • kv_a_proj_with_mqa 大小为 [hidden_size, kv_lora_rank + qk_rope_head_dim] = [5120, 512 + 64] = [5120, 576],对应上述公式中的 $$W^{DKV}$$ 和 $$W^{KR}$$的合并矩阵,用来把输入先投影到一个低维的空间(对应 $$C_t^{KV}$$),同时做两种降维操作(nope,rope的前置操作)。因为因为从公式来看这两个矩阵都需要和\(h_t\)计算,所以可以合并矩阵计算后再进行拆分。输出的维度则是512+64=576了。前面的512维度是给kv的,后面的64维度是给key的旋转位置编码的。
  • kv_b_proj 大小为 [kv_lora_rank,num_heads * (q_head_dim - qk_rope_head_dim + v_head_dim)] = [512, 128*((128+64)-64+128)] = [512, 32768],对应上述公式中的 $$W^{UK}$$ 和$$W^{UV}$$的合并矩阵。由于 $$W^{UK}$$ 只涉及nope 的部分,所以维度中把 qk_rope_head_dim 去掉了。192-64是把key表示向量中的64维度的旋转位置编码向量从192维度中减去;然后的128维度是留给value的,因为value不需要考虑位置信息。需要考虑位置信息的只有query和key。

或者说,通过kv_a_proj_with_mqa 来对head脱敏,即得到的张量和具体的head无关;通过kv_b_proj来重新恢复成对每个head敏感,得到的是形如[1, 16, 26, 128]这样的,和具体16个head分别相关的张量。

self.kv_lora_rank = kv_lora_rank # 512,key和value各占256维度
self.qk_rope_head_dim = config.qk_rope_head_dim # 64
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim # 128 + 64
self.v_head_dim = config.v_head_dim # 128
self.hidden_size = config.hidden_size # 5120

# 计算压缩后的latent kv以及需要缓存的应用RoPE的k的部分:k_t^R(前置条件),即把隐向量的5120维度 映射到 config.kv_lora_rank + config.qk_rope_head_dim = 512 + 64维度
self.kv_a_proj_with_mqa = nn.Linear(
    self.hidden_size,
    config.kv_lora_rank + config.qk_rope_head_dim,
    bias=config.attention_bias,
)
self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
# 计算up-projection后的不应用RoPE的k的部分 和 up-projection后的v的结果
self.kv_b_proj = nn.Linear(
    config.kv_lora_rank,
    self.num_heads
    * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
    bias=False,
)

4.4.2 变量操作

计算KV向量时,有几个和公式中不同的地方,即把某些矩阵操作打包在一起执行(同时将K,V的向量一起产出了),后续再拆分开。

  • 首先需要将输入向量投影为512维的联合压缩表示:$$ c_t^{KV} = W^{DKV} h_t \in \mathbb{R}^{B \times L \times 512} $$,对应第41号公式。

  • 与Q向量的计算过程类似,K向量的第一部分是将\(c_t^{KV}\)通过投影解压缩到\(\mathbb{R}^{H \times 128}\)的多头向量空间:$$ k_t^C = W^{UK} c_t^{KV} \in \mathbb{R}^{B \times L \times H \times 128} $$,对应第42号公式。注意:此处增加了一个头维度。

  • K的第二部分是将输入向量投影到64维向量空间并施加RoPE嵌入位置信息:$$ k_t^R = \mathrm{RoPE}(W^{KR} h_t) \in \mathbb{R}^{B \times L \times 64} $$,对应第43号公式。

  • 与Q不同的是,完整的K是将K的第二部分广播到每个head后与第一部分拼接得到:

    \[ k_t = \begin{bmatrix} k_{t,1}^C & k_t^R \\ k_{t,2}^C & k_t^R \\ \vdots & \vdots \\ \end{bmatrix} \in \mathbb{R}^{B \times L \times H \times 192} \]

    也就是说,每个head的RoPE部分是完全相同的。此处对应第44号公式。再强调下:对于query,每个head有自己的旋转位置编码向量;key则是所有heads共享同一个旋转位置编码向量。

  • V向量的计算较为简单,直接将\(c_t^{KV}\)解压缩到\(\mathbb{R}^{H \times 128}\)即可:$$ v_t = W^{UV} c_t^{KV} \in \mathbb{R}^{B \times L \times H \times 128} $$,对应第45号公式。

通过维度分析可以看到 kv_lora_rank 是 qk_nope_head_dim 的 4 倍且 K 和 V 共享 latent state,qk_rope_head_dim 只有 qk_nope_head_dim 的一半,结合起来 4+1/2=9/2,是 正式下图中 MLA KVCache per Token 大小的来源。

具体的代码实现如下,可以发现除了在对q做计算时涉及到gemv之外,也就是q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))),其它地方的矩阵乘运算q_len维度都是和num_heads在一起做计算,而num_heads在Deepseek2的配置里面已经是128了,导致其它的Matmul几乎都落在了计算密集的范畴。

# 使用MQA(Multi-Query Attention)对输入的隐状态进行处理,得到压缩后的键值对表示(compressed_kv),对应41号公式和43号(还没有加 rope)。此时compressed_kv就是公式中的c_t^{KV}+W^{KR}h_t,形状是[B, q_len, kv_lora_rank + qk_rope_head_dim],kv_lora_rank是d_t
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)

# 将压缩后的键值对表示分为两部分:低秩压缩的键值对部分和经过位置编码的键部分(k_pe),分别对于nope和rope。这是第44号公式的前置准备操作,或者说是44号公式的反向操作
# 此时compressed_kv才是公式中的c_t^{KV},k_pe是公式中的W^{KR}h_t
compressed_kv, k_pe = torch.split(
    compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)

# k_pe 要传给 RoPE模块,所以需要重整下形状,增加注意力头这个维度
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)

# 计算得到k^C和v^C
# 1. 对压缩后的键值对升维,包括RMSNorm(self.kv_a_layernorm)和全连接层(self.kv_b_proj,对应W^{UK}和W^{UV}),是42号和45号公式结合体的前半部分,得到W^{UK}c^{KV}_t(k^C_t)和W^{UV}c^{KV}_t(V^C_t),但此时k^C_t和V^C_t是拼接在一起的
# 2. 用view()和transpose()函数将MLA展开成标准MHA的形式。注意:此处增加了一个头维度
kv = (
    self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
    .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
    .transpose(1, 2)
)

# 使用torch.split函数将k^C_t和V^C_t分离开,是42号和45号公式结合体的后半部分。因为 kv_b_proj 包括 W^{UK} 和 W^{UV},因此要把它们的计算结果分离出来,分别在不同的地方吸收,最终拆分成两部分:
# k_nope是没有经过位置编码的键部分,不包含位置信息。维度为[B, num_head, kv_seq_len, qk_nope_head_dim]
# value_states是值部分,用于后续的位置编码和注意力权重计算,维度为[B, num_head, kv_seq_len, v_head_dim]
k_nope, value_states = torch.split(
    kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
)

# 获取key/value的序列长度,即包含当前位置可用上下文的长度
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

# 调用self.rotary_emb函数,根据value_states和更新后的序列长度kv_seq_len计算RoPE的cos和sin值
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

# 使用apply_rotary_pos_emb函数对W^{KR}h_t施加RoPE,得到k_t^R,即k_pe变量
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

# 初始化键状态(key_states)的张量,存储融合了解耦RoPE的键表示
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope # k^C_t
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe # k^C_t + k_t^R

4.5 注意力操作

4.5.1 变量定义

o_proj对应矩阵\(W^O\),大小为[num_heads * v_head_dim, hidden_states]=[128 * 128, 5120]。

self.v_head_dim = config.v_head_dim # 128
self.num_heads = config.num_attention_heads # 128
self.hidden_size = config.hidden_size # 5120

self.o_proj = nn.Linear( # 对应第47号公式
    self.num_heads * self.v_head_dim,
    self.hidden_size,
    bias=config.attention_bias,
)

4.5.2 变量操作

生成 QKV 向量之后的流程就基本上等同于标准的 MHA 计算了。唯一的区别在于只有 q_pe, k_pe 这两个部分给加上了 rope。具体流程如下:

首先计算attention score:

\[ a = \mathrm{softmax}\left(\frac{q_t^\top k_t + \mathrm{Mask}}{\sqrt{192}}\right) = \mathrm{softmax}\left(\frac{{q_t^C}^\top k_t^C + {q_t^R}^\top k_t^R + \mathrm{Mask}}{\sqrt{128 + 64}} \right) \in \mathbb{R}^{B \times L \times H \times L} \]

然后计算对V的加权和,并将所有head压平,得到Attention输出:

\[o = a \cdot v_t \in \mathbb{R}^{B \times L \times H \times 128} \cong \mathbb{R}^{B \times L \times 16384} \]

最后经过另一个矩阵的投影,就能得到MLA的最终输出:

\[u = W^O o \in \mathbb{R}^{B \times L \times 5120} \]

# 更新和拼接历史 KVCache,将当前位置之前的压缩后的kv以及应用过rope的k的部分拼接进去,可以看到这里存储的是展开后的 MHA KVCache
if past_key_value is not None:           
    cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
    key_states, value_states = past_key_value.update( # 更新kv cache
        key_states, value_states, self.layer_idx, cache_kwargs
    )

# 后续就是标准的 MHA 代码,代码 Q^T*K*V*O
attn_weights = (
    torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
)

if attention_mask is not None:
    attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(
    attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
    attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
if not output_attentions:
    attn_weights = None

return attn_output, attn_weights, past_key_value

4.6 前向传播

我们把完整的前向传播代码摘录如下,大家可以更好的理解。

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None, # V2代码中,kv cache存储的是全部缓存,不是压缩后的
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # hidden_states对应公式中的h_t,的shape是(batch_size, seq_length,hidden_size)
    bsz, q_len, _ = hidden_states.size()

    q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
    q_nope, q_pe = torch.split(
        q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
    )

    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
    compressed_kv, k_pe = torch.split(
        compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
    )
    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
    kv = (
        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        .transpose(1, 2)
    )

    k_nope, value_states = torch.split(
        kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
    )
    kv_seq_len = value_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

    query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
    query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
    query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

    key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
    key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
    key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )

    attn_weights = (
        torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
    )

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(
        attn_weights, dim=-1, dtype=torch.float32
    ).to(query_states.dtype)
    attn_weights = nn.functional.dropout(
        attn_weights, p=self.attention_dropout, training=self.training
    )
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value

对应如下图例。

4.7 V3 代码

我们也给出V3代码具体如下。V3中的 native 版本其实并没有节省KV-Cache(甚至还多了存储),V3版本的非native版本是跟论文一致,节省了显存。

native 版本的实现直观、适合学习,但是不适合Decode阶段,因为Decode阶段需要用到KV Cache。针对KV Cache,native 版本的实现有两种选择:

  • ① 缓存 Latent KV。缓存规模小,矩阵运算是\((b,n_h,1,d_c) \times (b,1,s,d_c)\),假定是bfloat16精度,内存读取量是\(2bn_hd_c + 2bsd_c = 2bd_c(n_h+s)\)。但 Latent KV 缓存不能直接送 MHA 计算,还得经过 \(W^{UK}\)\(W^{UV}\) 的线性映射,这是两个规模不小的矩阵计算,而且每轮都得重复计算。

  • ② 缓存 KV。缓存规模大,不用重复计算,性能好。标准MHA \((b,n_h,1,d_h) \times (b,n_h,s,d_h)\)的内存读取量是\(2bn_hd_h+2bn_hsd_h = 2bd_hn_h(1+s)\)。但 MLA 的一大好处就是 KV Cache 压缩,这样显存内能缓存更多 token,支持更大的 batch 和 prefix cache。如果缓存 KV,在显存上对比 MHA 就完全没有优势了。

native 版本最终的选择是方案2。所以,Naive 实现可能会用于 Prefill阶段,但在 Decode 计算时需要更好的计算方法,也就是非native版本。在非native版本最核心的 Attention kernel 计算中,“吸收“模式下 K/V tensor Shape 中不携带 num_attn_heads 信息,计算逻辑转换成了类 MQA 计算,“不吸收”模式下 K/V tensor 仍携带 num_attn_heads,就是MHA计算。

# from: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MLA(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        # 对应 query 压缩后的隐向量的维度 d'_c
        self.q_lora_rank = args.q_lora_rank # q 压缩后的维度
        # 对应 key-value 压缩后的隐向量维度 d_c
        self.kv_lora_rank = args.kv_lora_rank # kv 压缩后的维度
        # 表示query和key的向量中应用rope部分的维度, $d_h$
        self.qk_nope_head_dim = args.qk_nope_head_dim
        # 对应$d_h^R$, 表示应用了rope的 queries 和 key 的一个 head 的维度。
        self.qk_rope_head_dim = args.qk_rope_head_dim
        # $d_h + d_h^R$, 注意力头大小为非rope部分大小加上rope部分大小
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim

        if self.q_lora_rank == 0:
            # 不适用低秩分解,回归到传统MHA
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            # 其实就是$W^{DQ}$,用来生成$c_t^Q$
            # 下采样矩阵,得到压缩后的q向量
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            # $W^{UQ}$
            # 上采样矩阵,用来恢复q向量
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        # $[W^{DKV}; W^{KR}]$    
        # 下采样矩阵,得到压缩后的kv向量    
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        # 上采样矩阵,用来恢复kv向量
        # $[W^{UK}; W^{UV}]$
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale
         
        if attn_impl == "naive": # native模式下,kvcache存储的是没有压缩的数据,大小为d_h + d_h^R, 不但没有节省,反而增加了显存消耗   
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        else:
            # 在非native模式下,存储的是压缩的c,大小为d_c
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        # 计算q
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        # 分离nope,rope
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # 执行RoPE计算
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        # KV-Cache大小为wkv_a outputdim(self.kv_lora_rank + self.qk_rope_head_dim)
        # 分离KV和K位置编码
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # 执行RoPE计算
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        if attn_impl == "naive":
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k # 存储的是完全没有压缩的k
            self.v_cache[:bsz, start_pos:end_pos] = v # 存储的是完全没有压缩的v
            # score = q^T \times k_cache
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else:
            # 处理KV u-pprojection矩阵
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) 
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            # q_{nope} = q_{nope} \times W^{UK}
            # q中不需要位置编码的先和K的不需要位置编码的权重相乘
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv) # 保存KV Cache
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) # 保存K的位置编码Cache(pe cache)
            # scores = q_{nope}^T \times kv\_cache + q_{pe}^T \times pe\_cache
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
        if mask is not None:
            scores += mask.unsqueeze(1)
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
        if attn_impl == "naive":
            # score \times v_cache
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        else:
            # u = W^{UV} \times scores \times kv\_cache
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            # out = W^O \times u
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x

具体比对如下图。

0x05 优化代码

DeepSeek代码并没有给出某些功能的具体方案,比如压缩优化和权重吸收。因此,我们主要以章明星老师给出的方案 https://github.com/madsys-dev/deepseekv2-profile/tree/main DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子为例进行学习。

5.1 压缩优化

目前V2代码中,Attention中的KV Cache缓存的仍然是全量的key和value(从隐向量又解压缩出来),而并非论文中所说的压缩后的compressed_kv以及k_pe,导致其实没有减少KV Cache的缓存。

主要原因可能是:一方面复用transformers原有的Cache逻辑,方便实验和理解;另一方面这部分应该是训练代码,而推理代码会针对这部分进行优化和改进。

我们可以做如下修改,也将RoPE后的k_pe一并缓存入KV Cache中。

# 将当前位置之前的压缩后的kv(c_t^{kv})以及应用过rope的k的部分拼接到KV Cache前面
if past_key_value is not None:
    # 得到的应该是
    # compressed_kv: [B, kv_seq_len, d_c]
    # k_pe: [B, 1, kv_seq_len, qk_rope_head_dim]
    compressed_kv, k_pe = past_key_value.update(compressed_kv, k_pe)

章明星老师给出了更详尽的方案。

# CacheCompressed
def forward(self, hidden_states_q: torch.Tensor, q_position_ids: torch.LongTensor, compressed_kv: torch.Tensor):
    ...
    kv_seq_len = compressed_kv.size(1)
    # 对应完整公式的 44 行反过来
    compressed_kv, k_pe = torch.split(
        compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
    )
    k_pe = k_pe.view(bsz, kv_seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)
    kv = self.kv_b_proj(compressed_kv) \
        .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) \
        .transpose(1, 2)
    
    k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
    ... 
    
def compress_kv(self, hidden_states_kv: torch.Tensor, kv_position_ids: torch.LongTensor) -> torch.Tensor:
    # return the RoPE'ed & compressed kv
    bsz, kv_seq_len, _ = hidden_states_kv.size()
    compressed_kv = self.kv_a_proj_with_mqa(hidden_states_kv) 
    compressed_kv, k_pe = torch.split(
        compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
    )
    compressed_kv = self.kv_a_layernorm(compressed_kv)
    k_pe = k_pe.view(bsz, kv_seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)
    cos, sin = self.rotary_emb(k_pe) 
    k_pe = apply_rotary_pos_emb(k_pe, cos, sin, kv_position_ids).view(bsz, kv_seq_len, self.qk_rope_head_dim)
    return torch.cat([compressed_kv, k_pe],dim=-1) 

5.2 权重吸收

在计算MLA的时候,仍然需要存储解压后的完整的KV Cache,这很可能引起OOM崩溃。DeepSeek-V2的论文中提出,可以将KV的解压缩矩阵吸收到Q-projection和Out-projection中,从而可以在不解压缩KV Cache的情况下直接计算最终的Attention结果。

实际上,把权重吸收理解成矩阵乘法交换律更合适。因为实际上是提前将两个参数矩阵乘起来,即把 \((W^{UQ})^TW^{UK}\) 的计算结果做为新的参数矩阵,然后再跟中间张量乘,在性能上不一定比分开计算更好。

下图分别给出了MHA、MLA和权重吸收的MLA的计算示例。最右侧的两个虚线箭头,显示了在优化的计算过程中,哪些参数矩阵被交换了位置。它们能交换的原因,就是从数学上这样修改是等价的(矩阵乘法交换律)。此时,输入注意力机制的 q、k、v 形状发生了明显的变化。q 的形状由 $$[n_h \times (d_h+d_h^R)]$$ 变化成了 $$[n_h \times (d_c+d_h^R)]$$,k 的形状由 \([n_h \times (d_h + d_h^R)]\) 变化成了 \([n_h \times (d_c + d_h^R)]\),v 的形状由 \(d_h\) 变化成了 \(d_c\)。这样一来,新的计算过程中只剩下 Latent KV 了。原来的 KV 就不存在了,变成可以用Latent KV表示。而且实际上 V 也不存在了,因为 V 就是 K 的前 512 维。这其实就是MQA,这实际上就是 FlashMLA 代码库解决的问题。

我们接下来依据章老师的代码和文字来继续学习。

5.2.1 absorbed_cache_compressed.py

与论文不同,此处将代码中 kv_b_proj 中属于 K 的部分权重(论文中对应\(W^{UK}\))吸收进 q_nope(论文中对应 \(q^C\),而且是在运行时做,非提前吸收);将代码中 kv_b_proj 中属于 V 的部分权重(论文中对应\(W^{UV}\))吸收进 attn_out。抽象一点的理解就是,将 Q 也映射到 KV 的低秩空间,然后在低秩空间做完整的 Attention,之后再映射回 Q 的原始空间。

\(W^{UK}\)

对于K的吸收,在注意力分数的计算公式中,非RoPE部分可以做如下展开:

\[{q_t^C}^\top k_t^C = (W^{UQ} c_t^Q)^{\top} W^{UK} c_t^{KV} = {c_t^Q}^{\top}{W^{UQ}}^{\top} W^{UK} c_t^{KV} = ({c_t^Q}^{\top}{W^{UQ}}^{\top} W^{UK}) c_t^{KV} \]

也就是说,我们事实上不需要每次都将低维的\(c_t^{KV}\)展开为\(k_t\)再计算,而是通过矩阵乘法结合律,直接将 \(W^{UK}\) 通过结合律先和左边做乘法,改为计算,避免了解压缩出完整的K矩阵。即将前三者进行计算:

\[attention\_weights = ({c_t^Q}^{\top}{W^{UQ}}^{\top} W^{UK}) c_t^{KV} \]

此外,在原始版本的解压缩的过程中,由于每个token的key都需要与\(W^{UK}\)相乘才能得到,因此计算量较大;矩阵吸收后,\(W^{UK}\)只需要对\(q_t^C\)这一个向量相乘,也大大减少了浮点计算量。

# Absorbed_CacheCompressed
def forward(hidden_states_q: torch.Tensor, q_position_ids: torch.LongTensor, compressed_kv: torch.Tensor):
    ...
    # 从 kv_b_proj 中分离的 W^{UK} 和 W^{UV} 两部分,他们要分别在不同的地方吸收
    kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
    q_absorb = kv_b_proj[:, :self.qk_nope_head_dim,:]
    out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]
    
    cos, sin = self.rotary_emb(q_pe)
    q_pe = apply_rotary_pos_emb(q_pe, cos, sin, q_position_ids)
    
    qk_head_dim = self.kv_lora_rank + self.qk_rope_head_dim
    query_states = k_pe.new_empty(bsz, self.num_heads, q_len, qk_head_dim)
    # 此处改变了q_nope的计算顺序,把 W^{UK} 吸收到 W^{UQ}
    query_states[:, :, :, : self.kv_lora_rank] = torch.einsum('hdc,bhid->bhic', q_absorb, q_nope)
    query_states[:, :, :, self.kv_lora_rank :] = q_pe
    
    ...

    attn_weights = nn.functional.softmax(
        attn_weights, dim=-1, dtype=torch.float32
    ).to(q_nope.dtype)
    # 此处改变了attn_output的计算顺序
    attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
    attn_output = torch.einsum('bhqc,hdc->bhqd', attn_output, out_absorb)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
    attn_output = self.o_proj(attn_output)

除了压缩KV Cache之外,我们还可以观察到上面涉及到的2个矩阵乘法实际上都来到了计算密集的领域,例如对于 torch.einsum('bhqc,blc->bhql', q_nope, compressed_kv) 。由于不同 head 的 q_nope 部分共享了共同的 compressed_kv 部分,实际计算的是 batch_size 个 [head_num * q_len, kv_lora_rank] 和 [past_len, kv_lora_rank] 的矩阵乘法。计算等价于一个 MQA 操作,计算强度正比于 head_num 的也就是 128。因此相比 MHA,吸收后的 MLA 计算强度要大得多,可以更加充分的利用 GPU 算力。

\(W^{UV}\)

对于V的吸收,情况稍微复杂。为表述的清楚性,我们采用Einstein求和约定描述该过程

v_t = einsum('hdc,blc->blhd', W_UV, c_t_KV) # (1)
o   = einsum('bqhl,blhd->bqhd', a, v_t)     # (2)
u   = einsum('hdD,bhqd->bhD', W_o, o)       # (3)

# 将上述三式合并,得到总的计算过程
u   = einsum('hdc,blc,bqhl,hdD->bhD', W_UV, c_t_KV, a, W_o)

# 利用结合律改变计算顺序
o_  = einsum('bhql,blc->bhqc', a, c_t_KV) # (4)
o   = einsum('bhqc,hdc->bhqd', o_, W_UV)  # (5)
u   = einsum('hdD,bhqd->bhD', W_o, o)     # (6)

5.2.2 Move Elision

不过,这样还不能完全发挥出MLA的威力。在原始代码中,query_states和key_states会通过拼接RoPE和非RoPE部分得到:

def forward(...):
    ...
    # 更新和拼接历史 KVCache,可以看到这里存储的是展开后的 MHA KVCache
    # 其中 q_head_dim 等于 qk_nope_head_dim + qk_rope_head_dim    
    query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
    query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
    query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

    key_states = k_pe.new_empty(bsz, self.num_heads, kv_seq_len, self.q_head_dim)
    key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
    key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
    ...

当我们采取了上述优化后,此处的拼接过程会产生大量无用的数据拷贝和广播,同时也会占用大量显存空间导致OOM,而且如果是concat放在框架做,但可能会增加IO,尤其是decode本就是IO瓶颈。而且,先对Latent解压缩再计算,则Attn的计算是一个实打实的Multi Head Attention,会增大计算量。

为此,我们采用MoveElision优化策略,即省略此处的拼接RoPE部分和非RoPE部分的过程,而是直接分别计算量部分的Attention Score并相加(考虑\(q_t^\top k_t = {q_t^C}^\top k_t^C + {q_t^R}^\top k_t^R\))。即,将 RoPE 部分与 NoPE 部分分别做乘法,然后进行拼接的操作,改为 NoPE 部分 Attention 和 RoPE 部分 Attention 两者结果相加,这样做的好处在于节省了内存搬运操作,这种做法等效于ALiBi。我们具体推导如下。

\[[q^{\top}_{t,i}k_{j,i}] = [{c_t^Q}{W^{UQ}}^{\top},q_t^{R^\top}]\begin{bmatrix}W^{UK}c_t^{KV}\\ k_t^R \end{bmatrix} = c_t^Q{W^{UQ}}^{\top}W^{UK}c_t^{KV} + q_t^{R^{\top}}k_t^R \]

具体对应下面代码中的torch.matmul(q_pe, k_pe.transpose(2, 3))这行。即,分开计算了RoPE部分的q和k的注意力计算再求和。标准实现是将加上了 rope 的 q_pe/k_pe 和没加 rope 的 q_nope/k_nope 拼接起来一起。

# Absorbed_CacheCompressed_MoveElision
def forward(...):
    ...
    # qk_head_dim = self.kv_lora_rank + self.qk_rope_head_dim
    # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, qk_head_dim)
    # query_states[:, :, :, : self.kv_lora_rank] = torch.einsum('hdc,bhid->bhic', q_absorb, q_nope)
    # query_states[:, :, :, self.kv_lora_rank :] = q_pe

    # key_states = k_pe.new_empty(bsz, self.num_heads, kv_seq_len, qk_head_dim)
    # key_states[:, :, :, : self.kv_lora_rank] = compressed_kv.unsqueeze(1)
    # key_states[:, :, :, self.kv_lora_rank :] = k_pe

    # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale

    # 吸收后 attn_weights 直接基于 compressed_kv 计算不用展开
    attn_weights = torch.matmul(q_pe, k_pe.transpose(2, 3)) + torch.einsum('bhqc,blc->bhql', q_nope, compressed_kv)
    attn_weights *= self.softmax_scale
    ...

代码比对如下:

5.2.3 Materializing Projection Matrices

DeepSeek-V2的论文中说:

不过,似乎并没有必要再改变顺序,对模型参数进行预处理,将\(W^{UK}\)\(W^{UQ}\)相乘,以及将\(W^{UV}\)\(W^O\)相乘。这是因为,\(W^{UK}\)\(W^{UQ}\)相乘后的结果可以视为\(H\)个大小为\(1536 \times 512\)的低秩(不超过128)矩阵,而\(W^{UV}\)\(W^O\)相乘的结果可以视为\(H\)个大小为\(5120 \times 512\)的低秩矩阵。相比用这些特别大的低秩矩阵做投影,明显不如按照低秩分解形式依次相乘来得划算。因此,章老师认为这一步的优化并不是很有必要。

因为假设有矩阵 A[m,k],B[k,n],C[n,l],B 和 C 为低秩矩阵,依次相乘 A⋅B⋅C 需要的算力: 2mkn+2mnl=2mn⋅(k+l),而提前合并 D=(B⋅C),A⋅D 需要的算力:2mkl,当 n⋅(k+l)

具体代码如下:

def forward(self, hidden_states_q: torch.Tensor, q_position_ids: torch.LongTensor, compressed_kv: torch.Tensor):
    '''
    Attention masks and past cache are removed.
    Input: 
    - hidden_states_q: [bsz, q_len, hidden_size]
    - compressed_kv: [bsz, kv_len, kv_lora_rank]
    - position_ids: [bsz, q_len]
    '''
    bsz, q_len, _ = hidden_states_q.size()
    q_b_proj_rope, q_absorbed, out_absorbed = self.get_absorbed_proj()
    q = self.q_a_layernorm(self.q_a_proj(hidden_states_q))
    q_nope = torch.einsum('bqc,hdc->bhqd', q, q_absorbed)
    q_pe = torch.einsum('bqc,hdc->bhqd', q, q_b_proj_rope)
    
    cos, sin = self.rotary_emb(q_pe)
    q_pe = apply_rotary_pos_emb(q_pe, cos, sin, q_position_ids)
    kv_seq_len = compressed_kv.size(1)
    compressed_kv, k_pe = torch.split(
        compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
    )
    k_pe = k_pe.view(bsz, 1, kv_seq_len, self.qk_rope_head_dim)
    
    attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * elf.softmax_scale

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(
        attn_weights, dim=-1, dtype=torch.float32
    ).to(q_nope.dtype)
    attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
    attn_output = torch.einsum('bhqc,dhc->bqd', attn_output, out_absorbed)
    return attn_output

5.3 融合算子

另外,如果针对prefill和decode阶段进行不同处理,则在推理的时候Prefill 和Decode 走的逻辑不同。

  • 推理的时候 Prefill 是不做矩阵吸收的(原因是Prefill做矩阵吸收会增加计算量),MLA计算与普通的MHA计算大致相同,唯一的区别在于需要支持q/k和v/o使用不同的head_dim。

  • Decode 是要做矩阵吸收的,矩阵吸收ops 远小于矩阵不吸收。这是因为此时Q的长度是1,原本重复在KV 上做up projection的操作转移到了Q 上,让Q 投影到kv 的latent space 上,Q的长度远小于KV的长度,不需要对KV做重复做up projection。或者说,MLA的主要思路就是通过交换矩阵计算顺序,利用decode阶段query seq_len比较小的特点,优化矩阵计算开销,进而达到只存储Multi-head attention中hidden states cache,而不是key和value两个cache,进而降低一半KVCache存储的目的。

因此Decode阶段需要单独设计高效的融合算子,以便高效地与低秩kv-cache进行attention计算。

权重吸收之后,公式如下:

\[(p \cdot (c_{kv} \cdot W^{UV})) \cdot W^O = (p \cdot c_{kv}) \cdot (W^{UV} \cdot W^O) = (softmax(q_{nope} \cdot c_{kv} + q_{pe} \cdot k_{pe}) \cdot c_{kv}) \cdot W^{UV} \cdot W^O \]

可以用代码描述如下,即可以设计一个MQA算子来实现。

q_pe = W_QR(c_q)
q_nope = W_UQ_UK(c_q)
output = W_UV_O(MQA(q_pe, q_nope, c_kv, k_pe))

FlashAttention最初设计的初衷是减少对softmax矩阵储存的开销,其大小正比于 \(l_q \cdot l_{kv}\),占整体I/O的比值为:

\[ratio(softmax) = \frac{1}{1+\frac{H_{kv}}{H_{qo}} \frac{D}{L_{qo}}+\frac{D}{L_{kv}}} \]

对于推理阶段而言,\(l_q\) 其实是非常小的,不融合qk和pv两阶段的计算也能取得不错的效果。但是对于MLA而言,融合是必要的,这是因为:

  • MLA有较大的group ratio: \(𝐻_{𝑞𝑜}/𝐻_{𝐾𝑉}=128\) ,会增大softmax的占比。
  • MLA复用了key和value矩阵,因此如果我们不融合两阶段的话,前后两个算子将各自访问一遍KV-Cache,如果硬件的cache不够大的话,带宽利用率将无法超过50%。

0x06 转换

6.1 GQA

Group Query Attention(GQA)是MHA的一种变体,旨在减少KV缓存的开销。它将查询头分成多个组,每个组共享一个键和值对。这种方法通过减少键和值头的数量来降低KV缓存的大小,但可能会牺牲模型的表达能力。可以将GQA看作是MLA的一种特例。由于GQA是通过复制产生的,而MLA不受这种限制,表达能力更强。

尽管MLA在Deepseek V2/V3/R1中已经证明了其效率和有效性,但许多主要的模型提供商仍然依赖GQA。为了促进MLA的更广泛应用,论文“TransMLA: Multi-Head Latent Attention Is All You Need"提出了TransMLA,这是一种后训练方法,可以将广泛使用的基于GQA的预训练模型(例如LLaMA、Qwen、Mixtral)转换为基于MLA的模型。转换后,模型可以进行额外的训练以增强表达能力,而不会增加KV缓存的大小。

6.1.1 思路

论文首先证明了对于相同的KV缓存开销,MLA的表达能力总是大于GQA。具体来说,任何GQA配置都可以等价地转换为MLA表示,但反之不然。这一结论为将基于GQA的模型转换为基于MLA的模型提供了理论基础。

在等价转换过程中,TransMLA方法首先将GQA中的键矩阵进行复制,以匹配查询头的数量。然后,它将这个复制后的键矩阵分解为两个较小矩阵的乘积,从而得到MLA中的低秩表示。通过这种方法,TransMLA可以在不增加KV缓存大小的情况下,将基于GQA的模型转换为基于MLA的模型。

6.1.2 方案

第一步是复制key矩阵,以匹配查询头的数量。在GQA中,为使标准多头注意力计算时,𝑄和𝐾(以及𝑉)具有相同数量的头,需要对𝐾进行扩展,从\(n_k\)个头扩展到\(n_q\)个头。这其实也有两种方法。

  • 定义复制因子\(𝑠=\frac{𝑛_𝑞}{𝑛_𝑘}\)\(𝑛_𝑞\)为𝑄的头数,\(𝑛_𝑘\)为𝐾的头数),将𝐾按列划分为\(𝑛_𝑘\)个块\(𝐾^{(𝑖)}\),通过将每个\(𝐾^{(𝑖)}\)复制𝑠次并连接,得到扩展矩阵𝐾′。具体见下图(a)。
  • 另一种方法是将复制操作移到参数侧(其实也是使用MHA替代GQA的方法),即在计算K之前,先复制投影矩阵\(W_K\)。先将\(𝑊_𝐾\)按列拆分为\(𝑛_𝑘\)个部分\(𝑊_𝐾^{(𝑖)}\),然后复制每个\(𝑊_𝐾^{(𝑖)}\) 𝑠次并连接,形成新的投影矩阵\(𝑊'_𝐾\),再应用\(𝑊'_𝐾\)到𝑋直接得到\(𝐾′=𝑋𝑊′_𝐾\),此方法与先计算𝐾再复制其头在数学上是等效的。具体见下图(b)。

由于\(𝑊'_𝐾\)由复制\(𝑊_𝐾\)形成,其自由度最多为\(𝑛_𝑘𝑑_ℎ\),因此它的秩最多为\(𝑛_𝑘𝑑_ℎ\)。为了更正式地理解这一点,通过奇异值分解(SVD)对\(𝑊'_𝐾\)进行分解:\(𝑊'_𝐾=𝑈_𝐾𝑆_𝐾𝑉_𝐾^⊤\) ,其中\(𝑈_𝐾\)\(𝑉_𝐾\)是𝐷×𝐷正交矩阵,\(𝑆_𝐾\)是包含奇异值的𝐷×𝐷对角矩阵。只有前\(n_kd_h\)(或更少)的奇异值可能是非零的。因此,可以截断SVD,只保留前 r 个奇异值,其中$ r \le n_kd_h\(。则\)𝑊'_𝐾=𝑊_𝐾𝑎𝑊_𝐾𝑏\(且\)𝐾′=𝑋𝑊_𝐾𝑎𝑊_𝐾𝑏$ 。这样就将GQA的“重复KV”方案解释为类似MLA的低秩分解形式,在实际缓存时,仅需存储低秩表示\(𝑋𝑊_𝐾^𝑎\),在注意力计算时通过乘以\(𝑊_𝐾^𝑏\)恢复完整维度,增强了模型的表现力。

6.2 MHA

如何使原本为 MHA 训练的 LLMs(如 Llama)快速适应 MLA 进行推理,而无需从头开始预训练,既具有意义又充满挑战。论文“Towards Economical Inference: Enabling DeepSeek’s Multi-Head Latent Attention in Any Transformer-based LLMs” 第一种数据高效的微调方法MHA2MLA,用于*从MHA转换到MLA。该方法包含两个关键组件:

  • 对于partial-RoPE,论文从对注意力分数贡献较小的查询和键的维度中去除 RoPE。

  • 对于低秩近似,论文基于键和值的预训练参数引入联合SVD近似。

这些精心设计的策略使 MHA2MLA 仅使用极少部分(3‰至 6‰)的数据就能恢复性能,显著降低推理成本,同时能与 KV 缓存量化等压缩技术无缝集成。

6.2.1 partial-RoPE

为实现从标准 MHA 到 MLA 的迁移,论文提出 partial-RoPE 微调策略,从目标比例的维度中去除 RoPE 并转换为 NoPE。

MHA

MHA 的 Full-RoPE 通过特定频率的旋转将位置信息编码到查询和键中,具体如下图所示。

拆解

MLA中,\(k_i\)\([k_{i,nope};k_{i,rope}]\)组成,所以我们首先需要把MHA的\(k_{i,rope}\)也分解成这样的无RoPE编码和有RoPE两部分。

DeepSeek的MLA里面其实是在原始的每个head的不使用RoPE编码\(d_h\)维度上,再增加一个使用RoPE编码的\(d_h^R\)维度。但是我们现在只能把全长为\(d_h\)维度的\(k_{i,rope}\)进行拆解,把里面\(d_r,dr \ll d_h\)部分做RoPE编码。也就是\(r=\frac{d_r}{2}\)长度的2D子空间做旋转编码。

在注意力计算中,并非所有维度上的旋转位置编码(RoPE)都对结果有同等的贡献。Partial-RoPE 技术通过去除对结果贡献较小的维度上的 RoPE,减少了冗余计算。这就像是在一场考试中,抓住重点知识进行复习,避免在一些无关紧要的知识点上浪费时间。通过这种方式,Partial-RoPE 技术在不影响模型性能的前提下,有效提升了计算效率。

在从 Full-RoPE 转换到 Partial-RoPE 时,我们选择哪一部分子空间来做旋转编码呢?论文提出四种策略(主要是依据旋转的频率)来旋转 RoPE 编码的子空间。

  • 高频保留:保留 r 个旋转最快(高频)的子空间,即位置最靠前的个2D子空间。
  • 低频保留:保留 r 个旋转最慢(低频)的子空间。
  • 均匀采样:选择间隔相等的 r 个子空间,即不管是高频还是低频,按照等距离采样,这样高低频都分别有一部分。
  • 根据每个头2-norm贡献选择(Head-wise 2-norm Contribution):根据每个头中各子空间的 2-norm分数对所有子空间进行排序,选择前 r 个。第 r 个频率子空间对最终的attention logits的贡献有上界。

选择好了\(d_h\)维度中的\(d_r\)维度做RoPE位置编码,剩下的\(d_h - d_r\)部分我们就要当成当成MLA中的无位置编码部分,也就是\(q_{nope}\)。但是要注意DeepSeek的MLA中这部分维度是\(d_h\),我们这里是\(d_h - d_r\)

6.2.2 低秩近似

MHA中的\(k_i = W_kx_i,v_i=W_vx_i\)。我们已经使用上面的四种方法之一找到了需要做RoPE的部分,也就可以把\(W_k\)对应的部分取出来得到\(W^{KR}\)

我们也把\(W_k\)中对应非RoPE的部分参数提取出来:

\[k_{i,nope} = W_{k,nope}x_i \\ v_{i,nope} = W_{v,nope}x_i \]

我们的目标是从\(W_{k,nope},W_{v,nope}\)中构造出MLA中的\(W^{DKV}\)

从 Full RoPE 转换到 Partial RoPE 后,为得到 MLA 中 KV 缓存的第二个组件\(c_{i,kv}\),论文提出两种基于SVD的策略:解耦 SVD和联合 SVD,具体参见下图。

  • 解耦 SVD(\(SVD_{split}\)):分别对\(W_{k,nope}\)\(W_n\)进行截断 SVD 分解,分配\(d_{kv}/2\)个维度给每个矩阵。
  • 联合 SVD(\(SVD_{joint}\)):为保留\(K_{nope}\)和V之间的交互,对连接矩阵\([W_{k,nope},W_v]\)进行联合分解。这种分解方式更加贴合MLA的标准格式。

到这里,我们就处理完了key和value部分。query部分并不像DeepSeek里面的MLA一样再做低秩分解,而是把得到的query对应key中的nope和rope部分也分解成两部分。

0xFF 参考

DP MLA For DeepSeek In Sglang 是小肖啊

DeepSeek V3, R1, Janus-Pro系列模型方法解读 榴莲酥

【LLM算法】MLA 技术在 DeepSeek-R1 大显神通,清华 TransMLA 将 GQA 一键转换成 MLA SmartMindAI

首个参数高效微调框架:在任何LLMs中使用DeepSeek的MLA AcademicDaily00 [AcademicDaily]([removed]void(0)😉

【LLM算法】MLA 技术在 DeepSeek-R1 大显神通,清华 TransMLA 将 GQA 一键转换成 MLA SmartMindAI

DeepSeekV2之MLA(Multi-head Latent Attention)详解 一滴水的使命

DeepSeek模型解读:Scaling Law,MLA,MoE JMXGODLZ

还在用MHA?MLA来了DeepSeek-v2的MLA的总结和思考 rainbow

一文通透DeepSeek-V2(改造Transformer的中文模型):详解MoE、GRPO、MLA v_JULY_v

DeepSeekV2之MLA(Multi-head Latent Attention)详解 一滴水的使命

大模型KV Cache节省神器MLA学习笔记(包含推理时的矩阵吸收分析) BBuf

浅读 DeepSeek-V2 技术报告 AGI 梦工厂

用PyTorch从零开始编写DeepSeek-V2 Deephub

图解Mixtral 8 * 7b推理优化原理与源码实现 猛猿

从MHA到MLA看Attention优化:谈谈DeepSeek拼多多级的推理价格 扎波特的橡皮擦 [zartbot]([removed]void(0)😉

继续谈谈MLA以及DeepSeek-MoE和SnowFlake Dense-MoE 扎波特的橡皮擦 [zartbot]([removed]void(0)😉

关于 MHLA(Multi-Head Latent Attention)的一些分析 Zhengxiao Du

[LLM底座] 关于DeepSeek-V2中的MLA(含代码) 莫冉

deepseek-v2 MLA深度解析 单字卓

Deepseek-V2技术详解 队长

如何看待 DeepSeek 发布的 MoE 大模型 DeepSeek-V2? 郑华滨

缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA 苏剑林

DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子 ZHANG Mingxing

速读 deepseek v2(一) —— 理解MLA Bruce 仗剑走天涯

还在用MHA?MLA来了DeepSeek-v2的MLA的总结和思考 rainbow

如何看待 DeepSeek 发布的 MoE 大模型 DeepSeek-V2? - 知乎 (zhihu.com)

Deepseek-V2技术报告解读!全网最细! (qq.com) [包包算法笔记]([removed]void(0)😉 2

DeepSeek-V2高性能推理优化笔记:MLA优化 madsys-dev

GQA 论文阅读以及相关的思考 clvsit

LLM 加速技巧:Muti Query Attention deephub

大模型基础|注意力机制|MHA|稀疏|MQA|GQA 养生的控制人

Attention优化:Flash Attn和Paged Attn,MQA以及GQA miangangzhen

从头开始编写 LoRA 代码

大模型轻量级微调(LoRA):训练速度、显存占用分析 绝密伏击

MLKV:跨层 KV Cache 共享,降低内存占用 AI闲谈

继续谈谈MLA以及DeepSeek-MoE和SnowFlake Dense-MoE 扎波特的橡皮擦 [zartbot]([removed]void(0)😉

【深度学习】DeepSeek核心架构-MLA:剖析低秩联合压缩优化KV缓存、提升推理效率的技术细节 赵南夏 [南夏的算法驿站]([removed]void(0)😉

DeepSeek-R1模型架构深度解读(二)MLA [AI算法之道]([removed]void(0)😉

SGLang DP MLA 特性解读 BBuf [GiantPandaCV]([removed]void(0)😉

【LLM论文详解】MLA 技术在 DeepSeek-R1 大显神通,清华 TransMLA 将 GQA 一键转换成 MLA AI-PaperDaily [AI-PaperDaily]([removed]void(0)😉

TransMLA: Multi-Head Latent Attention Is All You Need

SGLang DP MLA 特性解读 BBuf [GiantPandaCV]([removed]void(0)😉

从代码角度学习和彻底理解 DeepSeek MLA 算法 chaofa用代码打点酱油

全网最细!DeepSeekMLA 多头隐变量注意力:从算法原理到代码实现 懂点AI事儿

deepseek技术解读(1)-彻底理解MLA(Multi-Head Latent Attention) 姜富春

[代码学习]deepseek-v2的inference code学习-MLA-part 1 迷途小书僮

[代码学习]deepseek-v2的inference code学习-MLA -part 3 迷途小书僮

[代码学习]deepseek-v2的inference code学习-MLA -part 4 迷途小书僮

[代码学习]deepseek-v2的inference code学习-MLA -part 2 迷途小书僮

缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA 苏剑林

DeepSeek开源FlashMLA之际从原理到代码详解MLA 杜凌霄 [探知轩]([removed]void(0)😉

首个参数高效微调框架:在任何LLMs中使用DeepSeek的MLA [AcademicDaily]([removed]void(0)😉

如何把预训练好的模型中的MHA变为MLA? 杜凌霄 [探知轩]([removed]void(0)😉

终于把 deepseek 中的多头潜在注意力机制搞懂了!! 程序员小寒 [程序员学长]

DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model

DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子

FlashInfer中DeepSeek MLA的内核设计

细说DeepSeek MLA矩阵消融 formath 2025-02-24

sglang mla 代码解析 hcy

DP MLA For DeepSeek In Sglang 是小肖啊

SGLang DP MLA 特性解读 BBuf

DeepSeek V2/V3中的MLA和Matrix Absorption ariesjzj

FlashInfer中DeepSeek MLA的内核设计 yzh119

终于把 deepseek 中的多头潜在注意力机制搞懂了!! 程序员小寒 [程序员学长]([removed]void(0)😉

DeepSeek 开源周第一天开源的项目 FlashMLA,有哪些亮点值得关注? SIY.Z

大模型KV Cache节省神器MLA学习笔记(包含推理时的矩阵吸收分析)

Qwen架构改造成Deepseek,再复现R1计划 孟繁续

DeepSeek V2 “多头潜在注意力”论文解读 (上) 大模型咖啡时间

Deepseek MLA 一定要做吸收吗? 代码搬运工

DeepSeek V3推理: MLA与MOE解析 Arthur

DeepSeek MLA引发的一些记忆碎片 YyWangCS

谈谈深度学习性能优化中的矩阵计算顺序 YyWangCS

[Deepseek v3技术报告学习] 1.MLA Duludulu

attention中的concat能不能换成相加? Zhai Feiyue

sglang mla 代码解析 hcy

MLA 实现理解 大卫

SGLang MLA 实现解析 BBuf

DeepSeek V3推理: MLA与MOE解析 Arthur

理解 FlashMLA 在 DeepSeek MLA 计算过程中的位置和作用 solrex [边际效应]

MLA 吸收之谜 拉航母的小朱

MLA原理介绍(极简版) opter

DeepSeek-V3/R1推理效率分析(v0.17) zartbot

DeepSeek V3/R1 推理效率分析(2): DeepSeek 满血版逆向工程分析 Han Shen

DeepSeek V3/R1 推理效率分析(3):Decode 配置泛化讨论 Han Shen

DeepSeek V3/R1 推理效率分析(1):关于DeepSeek V3/R1 Decoding吞吐极限的一些不负责任估计 Han Shen

MoE Inference On AnyScale MoE-On-AnyScale

基于 chunked prefill 理解 prefill 和 decode 的计算特性 Chayenne Zhao

LLM PD 分离背后的架构问题 极客博哥

deepseek MLA推理优化 屈屈臣氏

DeepSeek-V3 MTP 工程实现思考 极客博哥

一点浅见:deepep 为什么快? 云开

prefill 和 decode 该分离到不同的卡上么? Chayenne Zhao

[1. deepseek模型学习笔记](https://developnotes.readthedocs.io/zh-cn/latest/deepseek.html#id1) 李伟华

DeepSeek-V3 (671B) 模型参数量分解计算 ZihaoZhao

vLLM 深度解析:Deekseek and vLLM -1 stephenxi

DeepSeek MLA在SGLang中的推理过程及代码实现 榴莲酥

MHA->MQA->GQA->MLA的演进之路 假如给我一只AI

The Annotated Transformer https://nlp.seas.harvard.edu/2018/04/03/ention.html
Attention Is All You Need https://arxiv.org/pdf/1706.03762.pdf

Fast Transformer Decoding: One Write-Head is All You Need https://arxiv.org/pdf/1.02150.pdf

https://www.researchgate.net/figure/led-dot-product-self-attention-mechanism_fig1_363923096

GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints https://arxiv.org/pdf/5.13245.pdf

How Attention works in Deep Learning: understanding the attention mechanism in sequence models https://theaisummer.com/ention/

A simple overview of RNN, LSTM and Attention Mechanism https://medium.com/swlh/imple-overview-of-rnn-lstm-and-attention-mechanism-9e844763d07b

https://pytorch-forecasting.readthedocs.io/en/latest/_modules/pytorch_forecasting/models/temporal_fusion_transformer/_modules.html#ScaledDotProductAttention

浅谈Transformer的初始化、参数化与标准化 https://spaces.ac.cn/archives/0

https://theaisummer.com/self-attention/ ps://theaisummer.com/self-attention/

https://zhuanlan.zhihu.com/p/626820422 https://zhuanlan.zhihu.com/p/626820422
Are Sixteen Heads Really Better than One? https://arxiv.org/pdf/5.10650.pdf

This post is all you need(上卷)——层层剥开Transformer https://zhuanlan.zhihu.com/p/420820453

The Illustrated Transformer https://jalammar.github.io/ustrated-transformer/

Multi-Query Attention is All You Need https://blog.fireworks.ai/multi

DeepSeek MLA的序列并行和张量并行 YyWangCS

DP MLA For DeepSeek In Sglang 是小肖啊

SGLang MLA 实现解析 BBuf

Multi-Head Latent Attention (MLA) 详细介绍(来自Deepseek V3的回答) 银翼的魔朮师

MLA机制原理及代码研究 zrq96

DeepSeek面试通关(1)|MLA如何让推理效率飙升200%? 丁师兄大模型

DeepSeek-V2 MLA KV Cache 真的省了吗? 沉积岩

【Deepseek技术原理】第一篇:深度剖析和图解模型结构MLA 罗辑

From:https://www.cnblogs.com/rossiXYZ/p/18827618
DeepSeek MLA - 罗西的思考
100+评论
captcha