Transformer与图神经网络结合的算法推理模型
变换器(Transformers)以其简单而有效的架构彻底改变了机器学习。在互联网上大量的文本数据集上预训练变换器,为自然语言理解(NLU)任务带来了无与伦比的泛化能力。然而,当这些语言模型面临需要精确且稳健计算的算法形式推理任务时,它们仍然显得脆弱。为了解决这一局限性,我们提出了一种新的方法,将变换器的语言理解能力与基于图神经网络(GNN)的神经算法推理器(NARs)的稳健性结合起来。这些NARs已被证明是算法任务的有效通用求解器,当以图的形式指定时。为了使它们的嵌入能够被变换器访问,我们提出了一种混合架构和两阶段训练过程,允许语言模型中的标记跨注意力到NAR的节点嵌入。我们在CLRS-Text上评估了我们的结果TransNAR模型,这是CLRS-30基准的文本版本,并证明了在算法推理方面,无论是在分布内还是分布外,都比仅使用变换器的模型有显著提升。
引言
我们的TransNAR架构,它直接结合了变换器和神经算法推理器的协同作用,在CLRS-Text的广泛算法任务类别中,在分布外推理方面取得了明显改进。CLRS-Text是CLRS-30基准的文本版本。在这里,x轴表示CLRS-30的八个算法家族之一,y轴跨越了分布外示例数据集的平均执行准确率。TransNAR在此处展示的特定分布外领域中启用了新兴能力,在几个算法类别中有超过20%的绝对改进。
最近的工作激发了并展示了图神经网络在稳健解决各种输入大小的算法任务方面的有效性,无论是在分布内还是分布外——这样的系统通常被称为神经算法推理器。只要使用适当的归纳偏差,NARs即使在训练集中看到的输入大6倍的情况下,也能保持完美的泛化,用于具有长序列的高复杂性算法任务。然而,NARs仍然是相对狭窄的人工智能形式,因为它们需要严格结构化的输入格式,因此不能直接应用于以更嘈杂的形式提出的问题——例如在自然语言中——即使底层问题本质上仍然是算法性的。
增强LLMs的算法推理:TransNAR的鸟瞰图。一个大型语言模型(LLM)消耗输入标记并产生输出标记,这是单模态变换器的常见做法。神经算法推理器(NAR)模块是一个图神经网络(GNN),预先训练用于在基于图的输入集合上执行各种算法计算——预先训练的管道由褪色的箭头表示。在其前向传递过程中,变换器可以通过利用交叉注意力(通过可学习的“粘合”权重训练)来访问NAR计算的嵌入。
相反,目前无可争议的最先进方法用于模拟嘈杂文本数据是基于变换器的语言模型。尽管它们在自然语言理解属性方面无与伦比,但面对即使是最简单的算法任务时,它们也出了名的脆弱——特别是如果需要分布外泛化。
看来将变换器与NARs联合可以带来双方的丰硕回报。在本文中,我们首次探索这个接口,构建了TransNAR模型。
贡献
我们的探索取得了成果。我们在这项工作中呈现的关键要点如下:
- 我们提出了一种混合架构,结合了变换器的语言理解能力和预先训练的基于GNN的NAR的稳健推理能力。变换器使用NAR作为一个高维工具,将调节其标记嵌入。
- 我们通过在CLRS-Text上的评估,展示了这样一个增强了NAR的大型语言模型(LLM)在分布外展现出改进和更稳健的推理能力。
相关工作
我们的工作位于几个领域的交汇处:神经算法推理、语言模型的长度泛化、工具使用和多模态性。在这里,我们简要概述每个领域中各种相关的工作。由于观点的多样性,为了保持简洁,我们不提供相关工作的全面回顾,而是旨在提供对我们工作启发最大的特定工作的迹象。
神经算法推理
NAR通常是指构建能够捕捉算法计算的神经网络的艺术。通过算法对齐的选择、逐步训练或对比目标,这些能力可以得到增强。
最近,有研究表明:(1) 可以学习一个能够在其潜在空间中同时执行多个算法的NAR——Triplet-GMPNN就巧妙地为CLRS基准测试中的三十个算法集合做到了这一点;(2) 一旦训练完成,这样的NAR可以在各种下游任务中得到有效部署,包括强化学习、自监督学习、组合优化、计算生物学和神经科学。
我们对NAR的使用主要受到前面列出的两项工作的启发:我们使用一个相对较小的、预训练的、多任务NAR,并将其部署在一个规模更大的环境中——正如所引用的,NAR理论上应该能够扩展到比NAR训练分布大得多的系统。
长度泛化在LLMs中的应用
虽然NAR通常可以强烈泛化到更大的测试输入,但LLMs在这些场景中看到的成功要少得多。我们认为这归因于它们的自回归、因果掩蔽目标,这可能并不总是与算法输出预测的最逻辑顺序相对应。例如,通过逆序预测结果,可以显著提高各种LLMs在乘法上的性能。当然,在更复杂的算法上,确定最佳的输入排列方式可能要困难得多,并且可能不是最易于人类阅读的。
上述问题的认识导致大量努力投入到构建能够在长度上泛化的Transformer中。虽然长度泛化不是OOD推理感兴趣的唯一类型的分布偏移,但它是最容易模拟的偏移之一。因此,各种工作已经尝试通过使用仔细的提示、随机位置编码、课程或草稿纸来在LLMs中诱导长度泛化。我们坚信推理的一个重要特征是与提示质量的鲁棒性——只要提示明确指定了问题——因此在这里我们不探讨提示修改方法;只有随机位置[randomizedpe]被利用。
工具使用和多模态性
获得鲁棒泛化性能的另一种方式是通过教LLM调用其API来利用硬编码算法(也称为[emph{工具}])。可以说,LLMs在推理方面的大多数主要成功主要归功于LLM巧妙地使用工具,而不是LLM本身,因为工具按定义不会有泛化到不同输入的问题。
由于我们的目标是直接评估LLMs的推理能力,我们明确不允许在我们的基线中使用工具。也就是说,我们设想预训练的NAR作为Transformer嵌入的[emph{调制器}],它对OOD噪声更加鲁棒。因此,我们可能将NAR视为[emph{“内部工具”}]:Transformer和NAR可以使用它们的嵌入进行通信,打破相关的算法瓶颈。
如何实现这种通信和嵌入交换?为此,我们转向多模态LLMs寻求灵感,因为我们需要整合来自算法问题的两种不同表示(文本和图形)的信号。具体来说,我们的交换操作直接受到视觉语言模型(VLMs)和Flamingo中使用的交叉注意力操作的启发,它为融合文本和图像模态的信息提供了一种有原则的方式。
TransNAR:用预训练的基于GNN的NAR增强Transformer
本节描述了我们的混合TransNAR架构
本节描述了我们的混合TransNAR架构(参见图fig:architecture)。TransNAR接受双重输入,包括文本算法问题规范(包含T个标记)及其相应的CLRS-30特定图形表示(包含N个节点),并输出问题的文字回应。我们可以假设,一旦编码,文本输入存储在T(一个 ( T \times k ) 的实数矩阵)中,图形输入存储在G(一个 ( N \times l ) 的实数矩阵)中。注意,为了简化下面的方程,我们假设所有与问题图形版本相关的信息都存储在节点中——这在CLRS-30中通常不是真的(也可能有边和图级别的输入),但这不会改变下面介绍的基本数据流。
TransNAR的前向传递如下展开。首先,我们通过设置 ( \mathbf{T}^{(0)} = \mathbf{T} ) 和 ( \mathbf{G}^{(0)} = \mathbf{G} ) 正确初始化输入。接下来,为了计算步骤 ( (t+1) ) 的表示,文本(标记)表示被送入Transformer的当前层: \(\mathbf{\Theta}^{(t+1)} = \text{FFN}\left(\text{softmax}\left(\frac{(\mathbf{T}^{(t)}\mathbf{Q}_t)^\top\mathbf{T}^{(t)}\mathbf{K}_t}{\sqrt{d_k}}\right)\mathbf{T}^{(t)}\mathbf{V}_t\right)\) 其中 ( \mathbf{Q}_t, \mathbf{K}_t \in \mathbb{R}^{k \times d_k}, \mathbf{V}_t \in \mathbb{R}^{k \times k} ) 分别是键、查询和值转换,而 FFN 是一个前馈网络。类似地,图形表示被送入 NAR 层,实现例如标准 max-MPNN: \(\mathbf{g}^{(t+1)}_u = \phi\left(\mathbf{g}^{(t)}_u, \max_{1 \leq v \leq N}\psi\left(\mathbf{g}^{(t)}_u, \mathbf{g}^{(t)}_v\right)\right)\) 其中 ( \psi, \phi : \mathbb{R}^k \times \mathbb{R}^k \rightarrow \mathbb{R}^k ) 分别是可学习的消息和更新函数,而 max 是逐元素最大聚合。注意方程只提供节点之间的成对交互——实际上,我们的 NAR 是一个 Triplet-GMPNN,也包含三元组交互和一个门控机制。进一步注意,NAR 的可学习部分没有时间步索引——每一步,应用的是共享函数。这与算法计算在图形上的迭代、重复性质很好地对齐。
一旦两个流都准备好了它们的表示 ( \mathbf{\Theta}^{(t+1)} ) 和 ( \mathbf{G}^{(t+1)} ),图中的节点嵌入就调节 Transformer 的标记嵌入,产生 Transformer 流中 TransNAR 块的最终结果,灵感来自 Flamingo: \(\mathbf{T}^{(t+1)} = \text{FFN}\left(\text{softmax}\left(\frac{(\mathbf{\Theta}^{(t)}\mathbf{Q}^\times_t)^\top\mathbf{G}^{(t)}\mathbf{K}^\times_t}{\sqrt{d_k}}\right)\mathbf{G}^{(t)}\mathbf{V}^\times_t\right)\) 其中 ( \mathbf{Q}_t^\times, \mathbf{K}_t^\times \in \mathbb{R}^{k \times d_k}, \mathbf{V}_t^\times \in \mathbb{R}^{k \times k} ) 分别是交叉注意力的键、查询和值转换。在结束这一层之前,不执行 ( \mathbf{G}^{(t+1)} ) 的额外转换。 这个过程一直重复,直到最终的第$N_l$层,当最终的文本输出从${\bf T}^{(N_l)}$中读取出来。最终输出通过最终层产生的预测头转换为标记logits,我们通过标准的下一个标记预测目标来监督。
在TransNAR微调开始之前,我们预先训练NAR以稳健地执行CLRS-30涵盖的三十个算法,类似于generalist。众所周知,这样的程序能够在图形空间中实现高达4倍的输入尺寸的分布外泛化。NAR的参数在微调期间通常保持冻结,因为额外的梯度会消除模型原始的鲁棒性属性。这也是为什么图形嵌入不执行交叉注意力的原因。LLM本身可能在大规模数据集上预先训练过,以建立其一般语言先验,尽管即使LM最初是随机初始化的,我们也恢复了相同的实验结果。
实验
在我们的实验中,我们将展示TransNAR所提供的方案在语言模型架构中的分布外推理方面带来了显著的好处。在这一部分,我们将提供我们实验设置的详细信息。
变换器架构和初始化
我们使用Chinchilla家族的一个仅解码器、6层变换器模型,该模型在MassiveText上进行了预训练。我们特别使用了一个有7000万个参数,上下文大小为2048的模型。为了展示我们的方法无论训练起点如何都适用,我们运行了两个消融变体。在第一个变体中,变换器权重用预训练的结果初始化——模拟了一个微调场景——在第二个变体中,我们使用完全随机的初始化。在我们随后的结果图表中,我们将将这两种设置称为“预训练”和“未训练”。
随机化位置编码
先前的工作强调了变换器中随机位置嵌入的重要性,特别是为了实现更稳健的推理。根据先前对语言模型泛化能力的研究,随机位置嵌入确实在基线和TransNAR上都带来了显著的提升,允许在两者中都展现出更有趣的推理行为。因此,本文中的所有实验都将使用随机位置嵌入。我们在附录中提供了更多细节。
预训练NAR
按照generalist的方法,我们在CLRS-30基准测试的输入问题大小高达16的情况下预训练一个多任务MPNN基础的NAR。由于其图形结构公式,这样的NAR能够显著地进行分布外泛化——有时在大小是4倍的图形上仍然具有竞争力。我们将尝试通过TransNAR使用这样的模型,将这种丰富的表示知识传达给文本。
结合节点和边缘的交叉注意力贡献
通过generalist提出的方法预训练的NAR产生节点和边缘潜在表示,我们对它们都进行交叉注意力,因为它们可能包含有用的互补信息。为了对边缘特征进行交叉注意力,${\bf E}^{(t)}\in\mathbb{R}^{N\times N\times k}$,我们再次应用公式,但需要注意的是,我们需要将${\bf E}$的第一和第二轴展平为一个,以确保维度匹配。我们通过连接将节点和边缘嵌入提供的交叉注意力贡献与预训练的NAR结合起来,然后应用一个线性层。我们尝试使用其他缩减方案,例如将向量求和,或者应用2层MLP。我们还尝试了不同的预处理方案,例如使用Gram‐Schmidt过程正交化贡献,以确保它们在组合之前的代数互补性。然而,这些变化都没有带来比我们原始方法更好的改进。
数据集
我们使用CLRS-Text基准测试,即CLRS-30基准测试的文本版本。表1展示了该数据集的几个样本,以及它们的输入大小和令牌数量。请注意,文本表示直接从基于图形的CLRS-30以确定性方式派生,因此两个数据集传达的信息完全相同。但由于令牌化表示,我们在不超出Chinchilla模型的上下文长度的情况下,可以评估的问题大小有严格的限制。
因此,我们在较小的问题大小上训练我们的算法——4、8和12,并在问题大小10(分布外——插值)、12(分布内)和14(分布外——外推)上进行评估。
值得注意的是,CLRS-Text是语言模型进行长期推理的最具挑战性的任务之一,与当前的评估环境相比——与小学数学相比,复杂性明显提高,主要是因为它允许显式控制分布外泛化。然而,每个任务都有一个明确多项式时间算法描述,意味着它们可以用相对较少的参数来解释——当然比当今典型的大型语言模型要少得多!
数据集包括每个算法每个输入大小的10000个样本,总共240万个数据点,按照上述方式分成70%用于训练,30%用于验证。
训练细节
我们在训练数据上训练所有模型,经过七个epoch,批量大小为256,并使用Adam优化器,学习率为10^-4。我们在基础Chinchilla变换器使用的旋转位置编码(RoPE)之上应用随机位置编码,最大长度为8192。如前所述,对于所有TransNAR模型,我们在训练期间保持NAR冻结。
评估指标
我们不使用精确字符串匹配来计算每个模型的准确性,因为这种指标不能提供对特定数据点失败原因的洞察,而且更重要的是,它不能捕捉给定模型输出接近正确性的程度。相反,我们根据三个指标评估每个模型的性能,这些指标衡量生成文本的能力复杂性逐渐增加:
- 形状得分:一个二进制指标,捕捉输出是否具有正确的形状。例如,如果我们考虑一个排序任务,输出应该与输入具有完全相同的元素数量。同样,如果输出是一个矩阵,我们确保它的形状与输入和任务一致。
- 解析得分:一个二进制指标,捕捉输出是否不包含任何非法字符,例如,再次考虑一个对数字列表进行排序的任务,输出不应包含任何字母表中的字母。
- CLRS得分:输出中与真实答案匹配的元素百分比。这个分数是在CLRS-30中传统使用的,因此得名。注意,如果形状得分为0,我们将自动分配CLRS得分为0,因为输出索引之间没有明确的对应关系。
这些多方面的得分明确设计用来捕捉LLMs在学习对文本输入进行推理时的各种失败模式:它们可能过度专业化于训练问题大小(导致在测试时形状不正确),未能应对看不见的数字组合(导致解析不正确),当然,产生不正确或不一致的输出,由CLRS得分捕获。
结果
我们在图4(CLRS得分)中总结了我们的发现。我们的结果表明,我们的TransNAR在整体上和大多数单独算法上都显著优于基线变换器,无论是在分布内还是分布外。特别是,我们看到我们的方法不仅增强了现有的分布外泛化能力,而且在完全没有这些能力的情况下也引起了这些能力的出现——在图中反映为基线的性能为零或接近零。
形状得分的分析(图5)提供了另一种方式来阐明为什么TransNAR表现得和它一样好。回想一下,首先,如果形状不匹配,CLRS得分必然为零。观察到的形状得分表明,将变换器输出基于NAR嵌入显著增加了变换器将产生正确形状输出的输入比例——表明这是TransNAR帮助缓解的一个非常具体的失败模式。
然而,我们注意到,仍有一些算法TransNAR无法超越基线。仔细查看结果表明,这些任务(二分查找、查找最大子数组、最小值和快速选择)都涉及在输入列表中搜索特定索引的元素。这暗示了一个统一的失败模式:由于这些失败在插值和外推时都持续存在,作为实现的模型无法泛化到训练数据中未见过的新颖索引边界。因此,我们怀疑使用索引提示——正如已经由所证明的——是改善这种行为的有前途的途径。或者,可能是最终NAR计算的隐藏状态更难被交叉注意力层以一种泛化的方式解码,因此可能需要要么给交叉注意力增加额外的容量,或者执行更渐进式的解码:不是让所有交叉注意力层都从最终NAR计算的隐藏状态解码,而是让早期交叉注意力层从早期消息传递步骤来的隐藏状态解码,让后期交叉注意力层从后期消息传递步骤来的隐藏状态解码。
最后,我们在附录中提供了解析得分——省略了主要文本中的得分,因为在大多数情况下,解析可以完全准确完成。
限制
虽然我们的方法在所有我们评估的分布外场景中展示了有利的平均性能,我们强调TransNAR需要访问文本和图形表示输入才能被有效地训练和使用。虽然这限制了TransNAR在特定真实执行器或模拟器(或对其的先前信念)可用的情况下,现在我们知道TransNAR这样的想法是有益的,未来的研究可以使得这样的想法部署到纯单模态变换器中。例如,通过蒸馏训练的TransNAR模型获得的知识到一个普通的变换器模型中,可以取消对第二数据流的需求。
结论
我们提出了一种变换器-NAR混合架构:一种语言模型,它结合了变换器的语言理解能力和预训练的基于图神经网络的神经算法推理器的强大算法推理能力,以解决用自然语言指定的算法任务。我们在CLRS-text基准测试中展示了我们的模型相对于仅变换器模型的优越性,在分布内,更重要的是,在两个分布外的输入问题大小方面。我们希望未来的工作能够借鉴我们在这里分享的结果和见解,并进一步调查感兴趣的扩展,特别是具有更加模糊问题规范的数据集(如现实世界中经常遇到的),以及它们相应的等价求解器就绪的符号输入不是事先给出的。