image-20250218215709695

摘要

长上下文建模对于下一代语言模型至关重要,然而标准注意力机制的高计算成本带来了显著的计算挑战。稀疏注意力为提高效率同时保持模型能力提供了有前景的方向。我们提出了NSA(可原生训练的稀疏注意力机制),通过算法创新与硬件特性对齐的优化相结合,实现了高效的长上下文建模。NSA采用动态分层稀疏策略,结合粗粒度Token压缩与细粒度Token选择,在保持全局上下文感知能力的同时确保局部精度。该方法通过两大关键创新推进稀疏注意力设计:(1)通过算术强度平衡的算法设计实现显著加速,并针对现代硬件进行实现优化;(2)支持端到端训练,在保持模型性能的前提下减少预训练计算量。如图1所示,实验表明采用NSA预训练的模型在通用基准测试、长上下文任务和基于指令的推理中均保持或超越全注意力模型。与此同时,NSA在处理64k长度序列时,在解码、前向传播和后向传播阶段相较全注意力机制均实现了显著加速,有效验证了该方案在整个模型生命周期中的效率优势。

介绍

image-20250218222325217

一种实现高效长文本建模的自然方法是利用 softmax 注意力的固有稀疏性。通过选择性地计算关键的查询-键对,可以在保持性能的同时显著降低计算开销。最近的进展通过各种策略展示了这种潜力,包括:KV-cache 驱逐方法、分块 KV-cache 选择方法,以及基于采样、聚类或哈希的选择方法。尽管这些策略很有前景,但现有的稀疏注意力方法在实际部署中往往表现不足。许多方法未能实现与其理论收益相当的加速;此外,大多数方法主要关注推理阶段,缺乏有效的训练时支持来充分利用注意力的稀疏模式,从而无法充分挖掘注意力的稀疏模式。为了解决这些局限性,有效部署稀疏注意力必须应对两个关键挑战:(1)硬件对齐的推理加速:将理论计算量的减少转化为实际的速度提升,需要在预填充和解码阶段进行硬件友好的算法设计,以缓解内存访问和硬件调度瓶颈;(2)训练感知的算法设计:通过可训练的算子实现端到端计算,以降低训练成本,同时保持模型性能。这些要求对于实际应用至关重要,以实现快速的长文本推理或训练。综合考虑这两个方面,现有方法仍然存在明显的差距。

为了实现更有效和高效的稀疏注意力机制,我们提出了 NSA,一种原生可训练的稀疏注意力架构,它集成了分层 Token 建模。如图2 所示,NSA 通过将键 (key) 和值 (value) 组织成时间块,并通过三个注意力路径处理它们来减少每个查询的计算量:压缩的粗粒度 Token、选择性保留的细粒度 Token 和用于局部上下文信息的滑动窗口。然后,我们实现专门的内核以最大限度地提高其实际效率。NSA 引入了两个与上述关键要求相对应的核心创新:(1)硬件对齐系统:优化块状稀疏注意力,以充分利用 Tensor Core 并优化内存访问,确保平衡的算术强度。(2)训练感知设计:通过高效的算法和反向算子实现稳定的端到端训练。这种优化使 NSA 能够支持高效部署和端到端训练。我们通过对真实世界语言语料库的综合实验来评估 NSA。在使用 260B Token 的 27B 参数 Transformer 主干上进行预训练后,我们评估了 NSA 在通用语言评估、长上下文评估和思维链推理评估中的性能。我们进一步比较了 A100 GPU 上内核速度与优化的 Triton 实现。实验结果表明,NSA 实现了与完整注意力基线相当或更优越的性能,同时优于现有的稀疏注意力方法。此外,与完整注意力相比,NSA 在解码、前向和后向阶段都提供了显著的加速,并且加速比随着序列长度的增加而增加。这些结果验证了我们的分层稀疏注意力设计有效地平衡了模型能力和计算效率。

image-20250218220013643

重新思考稀疏注意力方法

大多数方法主要在推理过程中应用稀疏性,同时保留预训练的完整注意力骨干网络,这可能会引入架构偏差,从而限制它们充分利用稀疏注意力的优势。在介绍我们原生的稀疏架构之前,我们通过两个关键视角系统地分析这些局限性。

高效推理的错觉

尽管在注意力计算中实现了稀疏性,但许多方法未能实现相应的推理延迟降低,这主要是由于两个挑战:阶段性稀疏。诸如 H2O 之类的方法在自回归解码期间应用稀疏性,同时在预填充期间需要计算密集型预处理(例如,注意力图计算、索引构建)。相比之下,像 MInference 这样的方法仅专注于预填充稀疏性。这些方法未能实现跨所有推理阶段的加速,因为至少一个阶段的计算成本与完整注意力机制相当。这种阶段专业化降低了这些方法在以预填充为主的工作负载(如书籍摘要和代码补全)或以解码为主的工作负载(如长链式思维推理)中的加速能力。

与高级注意力架构的不兼容。 一些稀疏注意力方法无法很好地适配现代高效解码架构,例如多查询注意力(MQA)和分组查询注意力(GQA)。这些架构通过在多个查询头之间共享 KV,显著降低了解码过程中的内存访问瓶颈。 例如,像 Quest 这样的方法中,每个注意力头独立地选择其 KV-cache 子集。 虽然这种方法在多头注意力(MHA)模型中表现出一致的计算稀疏性和内存访问稀疏性,但在基于 GQA 等架构的模型中,情况则有所不同。在 GQA 架构中,KV-cache 的内存访问量取决于同一 GQA 组内所有查询头选择的并集。 这种架构特性意味着,虽然这些方法可以减少计算操作,但所需的 KV-cache 内存访问量仍然相对较高。 这种限制迫使我们面临一个关键选择:虽然一些稀疏注意力方法减少了计算量,但它们分散的内存访问模式与高级架构中高效的内存访问设计相冲突。 这些限制的出现是因为许多现有的稀疏注意力方法侧重于 KV-cache 减少或理论计算减少,但难以在高级框架或后端中实现显著的延迟降低。 这促使我们开发算法,将先进的架构和硬件高效的实现相结合,从而充分利用稀疏性来提高模型效率。

可训练稀疏性的迷思

我们对原生可训练稀疏注意力机制的探索,源于对仅用于推理的方法进行分析后得到的两个关键洞察:(1) 性能下降:事后强加稀疏性会迫使模型偏离其预训练的优化轨迹。正如 @magicpig 所展示的,前 20% 的注意力只能覆盖总注意力分数的 70%,这使得预训练模型中诸如检索头之类的结构在推理时容易受到剪枝的影响。(2)训练效率需求:高效处理长序列训练对于现代大语言模型 (LLM) 的开发至关重要。这包括在更长的文档上进行预训练以增强模型容量,以及随后的适应阶段,例如长上下文微调和强化学习等。然而,现有的稀疏注意力方法主要针对推理,很大程度上忽略了训练中的计算挑战。这种局限性阻碍了通过高效训练来开发更强大的长上下文模型。此外,将现有稀疏注意力应用于训练的尝试也暴露了一些挑战:不可训练的组件。 诸如 ClusterKV(包含 k-means 聚类)和 MagicPIG(包含基于 SimHash 的选择)等方法中的离散操作会在计算图中造成不连续性。这些不可训练的组件会阻碍梯度流经 Token 选择过程,从而限制模型学习最优稀疏模式的能力。

低效的反向传播。一些理论上可训练的稀疏注意力方法在实际训练中存在效率问题。诸如 HashAttention 等方法中使用的 Token 粒度选择策略,导致在注意力计算过程中需要从 KV 缓存中加载大量独立的 Token。这种非连续的内存访问阻碍了 FlashAttention 等快速注意力技术的高效应用,后者依赖于连续内存访问和分块计算来实现高吞吐量。因此,实际应用中不得不退回到较低的硬件利用率,从而显著降低了训练效率。

原生稀疏性势在必行

推理效率和训练可行性的这些限制促使我们从根本上重新设计稀疏注意力机制。我们提出了 NSA,一个原生稀疏注意力框架,它解决了计算效率和训练要求。在以下章节中,我们将详细介绍 NSA 的算法设计和算子实现。

方法论

我们的技术方法涵盖算法设计和内核优化。在以下小节中,我们首先介绍我们方法论的背景。然后,我们介绍 NSA 的整体框架,接着是其关键算法组件。最后,我们详细介绍我们为最大限度提高实际效率而设计的硬件优化内核设计。

背景

$$ \mathbf{o}_t = \operatorname{Attn}\left(\mathbf{q}_t, \mathbf{k}_{:t}, \mathbf{v}_{:t}\right) $$$$ \operatorname{Attn}\left(\mathbf{q}_t, \mathbf{k}_{:t}, \mathbf{v}_{:t}\right) = \sum_{i=1}^t\frac{ \alpha_{t,i} \mathbf{v}_i}{\sum_{j=1}^t \alpha_{t,j}}, \quad \alpha_{t,i} = e^{\frac{\mathbf{q}_t^\top \mathbf{k}_i}{\sqrt{d_k}}}\,. $$

这里,$\alpha_{t,i}$ 表示 $\mathbf{q}_t$ 和 $\mathbf{k}_i$ 之间的注意力权重,$d_k$ 是键的特征维度。随着序列长度的增加,注意力计算在整体计算成本中变得越来越重要,这给长上下文处理带来了重大挑战。

算术强度是指计算操作与内存访问的比率。它本质上决定了硬件上的算法优化策略。每个 GPU 都有一个临界算术强度,这个值由其峰值计算能力和内存带宽决定,计算方法是这两个硬件限制的比率。对于计算任务而言,高于此临界值的算术强度会成为计算密集型(受 GPU FLOPS 限制),而低于此临界值的算术强度则会成为内存密集型(受内存带宽限制)。具体到因果自注意力机制,在训练和预填充阶段,批量矩阵乘法和注意力计算表现出较高的算术强度,使得这些阶段在现代加速器上属于计算密集型。相反,自回归解码则会受到内存带宽的限制,因为它每次前向传递只生成一个 Token,同时需要加载整个键值缓存,导致算术强度较低。这就引出了不同的优化目标——在训练和预填充阶段降低计算成本,而在解码阶段减少内存访问。

总体框架

$$ \tilde{K}_t = f_K(\mathbf{q}_t, \mathbf{k}_{:t}, \mathbf{v}_{:t}), \quad \tilde{V}_t = f_V(\mathbf{q}_t, \mathbf{k}_{:t}, \mathbf{v}_{:t}) $$$$ \mathbf{o}^*_t=\operatorname{Attn}\left(\mathbf{q}_t,\tilde{K}_t, \tilde{V}_t \right) $$$$ \mathbf{o}^*_t = \sum_{c \in \mathcal{C}} g_t^c \cdot \text{Attn}(\mathbf{q}_t, \tilde{K}_t^c, \tilde{V}_t^c). $$$$N_t = \sum_{c \in \mathcal{C}}\text{size}[\tilde{K}^c_t].$$

我们通过确保 ${N_t}{} \ll t$ 来维持高稀疏度。

算法设计

在本小节中,我们将介绍重映射策略 $f_K$ 和 $f_V$ 的设计:Token 压缩、Token 选择和滑动窗口。

Token 压缩

$$ \tilde{K}^\text{cmp}_t = f_K^\text{cmp}(\mathbf{k}_{:t}) = \left\{\varphi(\mathbf{k}_{i d+1: i d+l})\middle| 1\leqslant i\leqslant\left\lfloor\frac{t-l}{d}\right\rfloor\right\} $$

其中 $l$ 是块长度,$d$ 是相邻块之间的滑动步长,$\varphi$ 是一个可学习的 MLP,具有块内位置编码,用于将块中的键映射到单个压缩键。$\tilde{K}_t^\text{cmp}\in \mathbb{R}^{ d_k \times \left\lfloor\frac{t-l}{d}\right\rfloor }$ 是由压缩键组成的张量。通常,我们采用 $d < l$ 来减轻信息碎片化。对于压缩的值表示 $\tilde{V}_t^\text{cmp}$,存在类似的公式。压缩的表示捕获了更粗粒度的高级语义信息,并减少了注意力的计算负担。

Token 选择

仅使用压缩的键 (key),值 (value) 可能会丢失重要的细粒度信息,这促使我们有选择地保留单个键 (key)、值 (value)。下面我们描述了我们高效的 Token 选择机制,该机制以低计算开销识别并保留最相关的 Token。

分块式选择。 我们的选择策略在空间连续的块中处理键 (key) 和值 (value) 序列,这受到两个关键因素的推动:硬件效率考虑和注意力得分的固有分布模式。

分块式选择对于在现代 GPU 上实现高效计算至关重要。 这是因为与基于随机索引的读取相比,现代 GPU 架构对于连续块访问表现出明显更高的吞吐量。此外,分块式计算能够最佳地利用 Tensor Core。这种架构特性已将分块内存访问和计算确立为高性能注意力实现的基本原则,如 FlashAttention 的基于块的设计所例证的那样。

分块式选择遵循注意力得分的固有分布模式。 先前的工作表明,注意力得分通常表现出空间上的连续性,这表明相邻的键 (key) 倾向于共享相似的重要性级别。为了实现分块选择,我们首先将键(key)、值(value)序列划分为选择块。为了识别用于注意力计算的最重要块,我们需要为每个块分配重要性分数。下面我们将介绍我们计算这些块级别重要性分数的方法。

$$ \mathbf{p}_t^\text{cmp} = \operatorname{Softmax}\left(\mathbf{q}_t^T \tilde{K}_t^\text{cmp}\right), $$$$ \mathbf{p}_t^\text{slc}[j] = \sum_{m=0}^{\frac{l'}{d}-1}\sum_{n=0}^{\frac{l}{d} -1} \mathbf{p}_t^\text{cmp}\left[\frac{l'}{d}j+m +n \right], $$$$ {\mathbf{p}_t^{\text{slc}}}' = \sum_{h=1}^{H} \mathbf{p}_{t}^{\text{slc}, (h)},$$

其中上标$(h)$表示头索引,$H$是每个组中查询头的数量。这种聚合确保了同一组内各头之间的一致块选择。

$$ \mathcal{I}_t = \{i \mid \text{rank}({\mathbf{p}_t^\text{slc}}'[i]) \leqslant n\} $$$$ \tilde{K}^\text{slc}_t = \operatorname{Cat}\left[\{\mathbf{k}_{il'+1:(i+1)l'}|i \in \mathcal{I}_t\}\right], $$

其中 rank$(\cdot)$ 表示降序排列的位置,rank = 1 对应于最高分,$\mathcal{I}_t$ 是所选块的索引集合,$\operatorname{Cat}$ 表示连接操作。$\tilde{K}_t^\text{slc}\in \mathbb{R}^{ d_k \times nl' }$ 是由压缩键组成的张量。类似的公式适用于细粒度值 $\tilde{V}^\text{slc}_t$。然后,所选的键和值参与与 $\mathbf{q}_t$ 的注意力计算。

滑动窗口

在注意力机制中,局部模式通常适应得更快,并且可能主导学习过程,从而可能阻止模型有效地从压缩和选择 Token 中学习。为了解决这个问题,我们引入了一个专用的滑动窗口分支,该分支显式地处理局部上下文,从而允许其他分支(压缩和选择)专注于学习它们各自的特征,而不会被局部模式所绕过。具体来说,我们在窗口 $w$ 中维护最近的 Token $\tilde{K}_t^\text{win}=\mathbf{k}_{t-w:t}, \tilde{V}_t^\text{win}=\mathbf{v}_{t-w:t}$,并将不同信息源(压缩 Token、选择的 Token、滑动窗口)的注意力计算隔离到单独的分支中。然后,这些分支的输出通过学习到的门控机制进行聚合。为了进一步防止具有边际计算开销的注意力分支之间的捷径学习,我们为三个分支提供独立的键和值。这种架构设计通过防止局部和长程模式识别之间的梯度干扰来实现稳定的学习,同时引入最小的开销。在获得所有三个类别的键和值($\tilde{K}_t^\text{cmp}, \tilde{V}_t^\text{cmp}$;$\tilde{K}_t^\text{slc}, \tilde{V}_t^\text{slc}$;以及$\tilde{K}_t^\text{win}, \tilde{V}_t^\text{win}$)之后,我们按照 gate merge计算最终的注意力输出。结合上述压缩、选择和滑动窗口机制,这构成了 NSA 的完整算法框架。

内核设计

为了在训练和预填充阶段实现 FlashAttention 级别的加速,我们基于 Triton 实现了硬件对齐的稀疏注意力内核。 考虑到多头注意力机制(MHA)在解码过程中是内存密集型且效率较低,我们参考当前最先进的 大语言模型 (LLM),专注于采用共享 KV 缓存的架构,例如 GQA 和 MQA。 虽然压缩和滑动窗口注意力计算可以很容易地与现有的 FlashAttention-2 内核兼容,但我们针对稀疏选择注意力引入了专门设计的内核。 如果我们沿用 FlashAttention 的策略,将时间上连续的查询块加载到 SRAM 中,会导致内存访问效率低下,因为一个块内的查询可能需要不相交的 KV 块。 为了解决这个问题,我们的主要优化策略是采用不同的查询分组方式:对于查询序列上的每个位置,我们将一个 GQA 组内的所有查询头(它们共享相同的稀疏 KV 块)加载到 SRAM 中。图3展示了我们的前向传播实现。 所提出的内核架构具有以下关键特征:

  1. 以组为中心的数据加载。 对于每个内部循环,加载组中位置 $t$ 处的所有头的查询 $Q\in \mathbb{R}^{[ h, d_k]}$ 及其共享的稀疏键/值块索引 $\mathcal{I}_t$。
  2. 共享键值 (KV) 获取。 在内部循环中,顺序加载由 $\mathcal{I}_t$ 索引的连续键/值块到 SRAM 中,表示为 $K \in \mathbb{R}^{[B_k, d_k]}, V \in \mathbb{R}^{[B_k, d_v]}$,以最小化内存加载。其中 $B_k$ 是满足 $B_k | l'$ 的内核块大小。
  3. 网格上的外部循环。 由于内部循环的长度(与所选块计数 $n$ 成正比)对于不同的查询块几乎相同,我们将查询/输出循环放置在 Triton 的网格调度器中,以简化和优化内核。

此设计通过以下方式实现接近最佳的算术强度:(1) 通过组共享消除冗余的 KV 传输;以及 (2) 平衡 GPU 流式多处理器上的计算工作负载。 image-20250218220453912

实验

我们从三个方面评估 NSA 的性能:(1)通用基准测试性能,(2)长文本基准测试性能,以及(3)思维链推理性能。我们将 NSA 与完整注意力机制基线以及目前最先进的稀疏注意力方法进行比较。

预训练设置

遵循最先进的 大语言模型 (LLM) 中的常见做法,我们的实验采用了一个结合了分组查询注意力(GQA)和混合专家(MoE)的骨干网络,总参数为$27\text{B}$,其中激活参数为$3\text{B}$。该模型由30层组成,隐藏维度为2560。对于GQA,我们将组数设置为4,总共有64个注意力头。对于每个头,查询、键和值的隐藏维度分别配置为$d_q = d_k = 192$和$d_v = 128$。对于MoE,我们使用DeepSeekMoE结构,具有72个路由专家和2个共享专家,并将top-k专家设置为6。为了确保训练稳定性,第一层中的MoE被替换为SwiGLU形式的MLP。所提出的架构实现了计算成本和模型性能之间的有效权衡。对于NSA,我们设置压缩块大小$l=32$,滑动步长$d=16$,选择的块大小$l'=64$,选择的块计数$n=16$(包括固定激活的1个初始块和2个局部块),以及滑动窗口大小$w=512$。全注意力模型和稀疏注意力模型都在$270\text{B}$个 Token 的$8\text{k}$长度文本上进行预训练,然后通过YaRN在$32\text{k}$长度文本上进行持续训练和监督 微调,以实现长上下文适应。对这两个模型都进行了充分的收敛训练,以确保公平的比较。如图3所示,我们的NSA和全注意力基线的预训练损失曲线显示出稳定且平滑的下降,其中NSA始终优于全注意力模型。

image-20250218222416676

基线方法

除了与全注意力(Full Attention)进行比较外,我们还评估了几种最先进的推理阶段稀疏注意力方法:H2O、$v1$ infLLM、$v2$ Quest、$v3$ 和 Exact-Top。这些方法首先计算完整的注意力分数,然后选择对应于每个提示词(prompt)的前 $v4$ 个最高分数的键,然后基于这些位置计算注意力。这些方法涵盖了各种稀疏注意力范式,包括 KV 缓存驱逐、查询感知选择和精确的 top-$v5$ 稀疏选择。

对于一般评估,在大多数样本的长度都在稀疏注意力基线的局部上下文窗口内的情况下,这些方法实际上等同于全注意力(Full Attention)。因此,在这种情况下,我们仅展示 NSA 和全注意力(Full Attention)基线之间的比较结果。在长上下文评估中,我们对所有基线方法进行比较,所有稀疏注意力方法的稀疏度都设置为相同,以确保公平的比较。对于需要长文本监督微调的思维链推理评估,我们将比较限制为全注意力(Full Attention),因为稀疏注意力基线不支持训练。

性能比较

image-20250218220531859

image-20250218220606919

image-20250218222438056

image-20250218220704285

效率分析

image-20250218222451316

训练速度

我们将基于 Triton 实现的 NSA 注意力机制和完整注意力机制与基于 Triton 的 FlashAttention-2 进行比较,以确保在相同后端进行公平的速度比较。随着上下文长度的增加,我们的 NSA 实现了逐渐增加的加速,在 64k 上下文长度下,正向传播加速高达 9.0$\times$,反向传播加速高达 6.0$\times$。值得注意的是,速度优势随着序列长度的增加而更加明显。这种加速源于我们的硬件对齐算法设计,旨在最大限度地提高稀疏注意力架构的效率:(1)分块式内存访问模式通过合并加载最大限度地利用 Tensor Core,(2)内核中精细的循环调度消除了冗余的 KV 传输。

解码速度

Attention 的解码速度主要取决于内存访问瓶颈,这与 KV 缓存的加载量密切相关。在每个解码步骤中,我们的 NSA 只需要加载最多 $\left\lfloor\frac{s-l}{d}\right\rfloor$ 个压缩 Token,$nl'$ 个选定的 Token 和 $w$ 个邻居 Token,其中 $s$ 是缓存的序列长度。随着解码长度的增加,我们的方法在延迟方面表现出显著的降低,在 64k 上下文长度下实现了高达 11.6$\times$ 的加速。这种内存访问效率的优势也会随着序列长度的增加而放大。

讨论

在本节中,我们将回顾 NSA 的开发过程,并讨论从我们对不同稀疏注意力策略的探索中获得的关键见解。虽然我们的方法展示了有希望的结果,但理解替代策略遇到的挑战并分析注意力模式,为未来的研究方向提供了有价值的背景。我们首先检查替代 Token 选择策略所面临的挑战,这些挑战促使我们做出了设计选择;然后,通过可视化,可以深入了解注意力分布模式。