新型嵌入模型 M3-Embedding,其创新性体现在:
- 多语言支持(Multi-Linguality):超过 100 种语言
- 多功能实现(Multi-Functionality):同时支持稠密检索、多向量检索和稀疏检索
- 多粒度处理(Multi-Granularity):输入支持从短句到长达 8,192 个 token
在训练方法上,我们取得了一系列技术突破:
- 提出创新的自知识蒸馏(self-knowledge distillation)框架,通过整合不同检索功能的相关性评分作为信号来提升训练质量
- 优化批处理策略,实现大批量训练与高吞吐量,从而增强嵌入向量的判别力(discriminativeness)
实验结果表明,M3-Embedding 在多语言检索、跨语言检索及长文档检索等基准测试中均刷新了最优性能记录。
Introduction
现有的文本嵌入方法在通用性方面仍然有限:
- 大多数嵌入模型仅针对英语定制。
- 现有的嵌入模型通常只针对单一检索功能进行训练。典型的信息检索系统需要多种检索方法的复合工作流程。
- 由于训练成本过高,训练一个有竞争力的长文档检索器具有挑战性,大多数嵌入模型只能支持短输入。
为了优化嵌入质量,我们做出了以下技术贡献:
- 提出新颖的自知识蒸馏框架,其中多种检索功能可以联合学习并相互强化。[CLS] 嵌入用于密集检索,而来自其他 token 的嵌入用于稀疏检索和多向量检索。基于集成学习(ensemble learning)原理,这种异构预测器可以组合成一个更强的预测器。因此,我们整合来自不同检索功能的相关性分数作为教师信号,通过知识蒸馏来增强学习过程。
- 优化批处理策略以实现大批量大小和高训练吞吐量,这对嵌入的判别力做出了显著贡献。
- 进行海量且高质量的数据整理。我们的数据集包括三个来源:1)从海量多语言语料库中提取无监督数据,2)整合密切相关的有监督数据,3)合成稀缺的训练数据。这三个数据源相互补充并应用于不同的训练阶段,为通用文本嵌入奠定了坚实的基础。
Related Work
基于嵌入的检索方法中最常见的形式是密集检索,其通过聚合文本编码器的输出(例如通过 [CLS] 或均值池化)来计算嵌入相似度。另一种常见方法是多向量(multi vector)检索,该方法对文本编码器的输出进行细粒度交互以计算嵌入相似度。此外,文本嵌入还可转换为词项权重,从而支持稀疏(sparse)或词汇(lexical)检索。通常,上述检索方法通过不同的嵌入模型实现。据我们所知,目前尚无现有方法能统一所有这些功能。
M3-Embedding
Data Curation
M3-Embedding 需要大规模、多样化的多语言数据集。我们通过三个来源进行综合数据采集:来自未标注语料库的无监督数据、来自标注语料库的微调数据,以及通过合成生成的微调数据。这三种数据源互为补充,分别应用于训练过程的不同阶段。
- 无监督数据:通过从各类多语言语料中提取富含语义的结构(如标题-正文、标题-摘要、指令-输出等)来构建。原始数据经过过滤以去除潜在不良内容和低相关性样本,包含 194 种语言的 12 亿文本对及 2655 个跨语言对应关系。
- 微调数据:从标注语料库中收集了规模相对较小但多样性强、质量高的微调数据。
- 合成数据:为缓解长文档检索任务的数据短缺问题,我们通过合成方式生成额外多语言微调数据。从维基百科、悟道和 mC4 数据集中抽取长篇文章,随机选取其中段落作为基础,利用 GPT-3.5 生成对应问题。最终将生成的问题与原文组合构成新的文本对加入微调数据集。
Hybrid Retrieval
M3-Embedding 统一了嵌入模型的常见检索功能,即密集检索、词汇(稀疏)检索和多向量检索。
Dense retrieval
输入查询 $q$ 通过文本编码器转换为隐藏状态 $\mathbf{H_q}$。我们使用 “[CLS]” 的归一化隐藏状态来表示查询:$e_q = norm(\mathbf{H_q}[0])$。类似地,我们可以得到段落 $p$ 的嵌入表示:$e_p = norm(\mathbf{H_p}[0])$。因此,查询和段落之间的相关性分数通过两个嵌入 $e_q$ 和 $e_p$ 的内积来衡量:$s_{dense} \leftarrow ⟨e_p, e_q⟩$。
Lexical Retrieval
输出的嵌入也用于测算每个词项的重要性以辅助词汇检索。对于查询中的每个词项 $t$,词项权重计算为 $w_{q_t} \leftarrow Relu(\mathbf W^\top_{lex}\mathbf{H_q}[i])$,其中 $\mathbf W_{lex} \in R^{d \times 1}$ 是将隐藏状态映射到浮点数的矩阵。
如果词项 $t$ 在查询中出现多次,我们只保留其最大权重。基于测算的词项权重,查询和段落之间的相关性分数通过查询和段落中共存词项(表示为 $q \cap p$)的联合重要性来计算:$s_{lex} \leftarrow \sum_{t \in q \cap p}(w_{q_t} ∗ w_{p_t})$。
Multi-Vector Retrieval
多向量作为密集检索的扩展,使用整个输出的嵌入来表示查询和段落:$E_q = norm(\mathbf W^\top_{mul}\mathbf{H_q})$,$E_p = norm(\mathbf W^\top_{mul}\mathbf{H_p})$,其中 $\mathbf W_{mul} \in R^{d \times d}$ 是可学习的投影矩阵。
遵循 ColBERT 的方法,我们使用延迟交互来计算细粒度相关性分数:$s_{mul} \leftarrow \frac{1}{N}\sum_{i=1}^{N} \max_{j=1}^{M} E_q[i] \cdot E_p^\top[j]$;其中 $N$ 和 $M$ 分别是查询和段落的长度。
检索过程可以通过混合方式进行。候选结果可以通过每种方法单独检索,最终检索结果基于整合的相关性分数重新排序:
\[s_{rank} \leftarrow w_1 \cdot s_{dense} + w_2 \cdot s_{lex} + w_3 \cdot s_{mul}\tag{1}\]其中 $w_1$、$w_2$ 和 $w_3$ 的值取决于下游场景。
Self-Knowledge Distillation
嵌入模型被训练来区分正样本和负样本。对于每种检索方法,分配给查询的正样本的分数应该比负样本更高。因此,训练过程旨在最小化 InfoNCE 损失,其一般形式由以下损失函数表示:
\[L_{s(\cdot)} = -\log \frac{\exp(s(q, p^*)/\tau)}{\sum_{p \in \{p^*, P'\}} \exp(s(q, p)/\tau)}\]这里,$p^*$ 和 $P’$ 分别表示查询 $q$ 的正样本和负样本;$s(\cdot)$ 是 ${s_{dense}(\cdot), s_{lex}(\cdot), s_{mul}(\cdot)}$ 中的任意一个函数。
不同检索方法的训练目标可能相互冲突。为了促进多种检索功能的优化,我们提出在自知识蒸馏(self-knowledge distillation)的基础上统一训练过程。基于集成学习原理,来自不同检索方法的预测可以整合为更准确的相关性分数(得益于异构性质)。在最简单的形式中,可以通过不同预测分数的加权和进行整合:
\[s_{inter} \leftarrow w_1 \cdot s_{dense} + w_2 \cdot s_{lex} + w_3 \cdot s_{mul}\tag{1}\]然后我们计算 $L_{dense}$、$L_{lex}$、$L_{mul}$ 和 $L_{inter}$ 的加权和作为无自知识蒸馏的损失:
\[L \leftarrow \frac{\lambda_1 \cdot L_{dense} + \lambda_2 \cdot L_{lex} + \lambda_3 \cdot L_{mul} + L_{inter}}{4} \tag{4}\]在先前的研究中,嵌入模型的训练质量可以从知识蒸馏中受益,它利用了来自另一个排序模型的细粒度软标签。在这里,我们简单地使用整合分数 $s_{inter}$ 作为教师,其中每个检索方法的损失函数被修改为:
\[L'_{*} \leftarrow -p(s_{inter}) \cdot \log p(s_{*}) \tag{5}\]这里,$p(\cdot)$ 是 softmax 激活函数;$s_{*}$ 是 $s_{dense}$、$s_{lex}$ 和 $s_{mul}$ 中的任意一个成员。我们进一步整合并归一化修改后的损失函数:
\[L' \leftarrow (\lambda_1 \cdot L'_{dense} + \lambda_2 \cdot L'_{lex} + \lambda_3 \cdot L'_{mul})/3 \tag{6}\]最后,我们通过 $L$ 和 $L’$ 的线性组合推导出自知识蒸馏的最终损失函数:
\[L_{final} \leftarrow (L + L')/2 \tag{7}\]训练过程构成一个多阶段工作流程(如下图)。首先,文本编码器使用海量无监督数据进行预训练,其中仅以对比学习的基本形式训练密集检索。自知识蒸馏应用于第二阶段,其中嵌入模型被微调以建立三种检索功能。$\mathbf W_{lex}$ 的随机初始化会导致训练初期 $s_{lex}$ 准确率较低 $L_{lex}$ 较高。为了减少这种影响,我们在训练过程中设置 $w_1 = 1$、$w_2 = 0.3$、$w_3 = 1$、$\lambda_1 = 1$、$\lambda_2 = 0.1$ 和 $\lambda_3 = 1$。在此阶段使用标注数据和合成数据,其中按照 ANCE 方法为每个查询引入困难负样本。
Efficient Batch
嵌入模型需要从多样化和海量的多语言数据中学习,以充分捕获不同语言的通用语义。同时,它还需要保持尽可能大的批处理大小(引入大量批内负样本)以确保文本嵌入的判别力。考虑到GPU内存和计算能力的限制,人们通常将输入数据截断为短序列以实现高训练吞吐量和大批处理大小。但这种做法对 M3-Embedding 来说并不可行,因为它需要从短序列和长序列数据中学习,以有效处理不同粒度的输入。
在我们的工作中,我们通过优化批处理策略来提高训练效率,从而实现高训练吞吐量和大批处理大小:
-
按序列长度分组:训练数据按序列长度进行预处理分组。生成小批次时,从同一组中采样训练实例。由于序列长度相似,这显著减少了序列填充,GPU 得到更有效利用。
-
固定随机种子:为不同 GPU 采样训练数据时,随机种子始终保持固定,这确保了负载平衡并最小化每个训练步骤的等待时间。
-
子批次处理:处理长序列训练数据时,小批次进一步分为子批次,这减少了内存占用。我们使用梯度检查点迭代编码每个子批次并收集所有生成的嵌入。这种方法可以显著增加批处理大小。例如,当处理长度为 8192 的文本时,批处理大小可以增加 20 倍以上。
-
分布式嵌入广播:广播来自不同 GPU 的嵌入,允许每个设备在分布式环境中获得所有嵌入,这显著扩展了批内负样本的规模。
对于计算或数据资源严重受限的用户,我们提出了一个更简单的方法 MCLS(Multi-CLS),它简单地在推理期间向长文档插入多个 CLS token,并将所有 CLS 嵌入的平均值作为文档的最终嵌入。尽管简单,但在实践中却出奇地有效。具体而言,我们每隔固定数量的 token 插入一个CLS token(在我们的实验中,每 256 个token插入一个”[CLS]”),每个 CLS token 可以捕获其邻近token的语义信息。最终,通过平均所有 CLS token 的最后隐藏状态来获得最终的文本嵌入。