当前位置: 首页 > news >正文

GritLM:用一个 LLM 既做 embedding 又做生成

问题背景

embedding 模型和生成模型一直是两条独立路线。BERT 类双向 encoder 适合做表示,decoder-only LLM 适合做生成,把 LLM 直接拿来取 hidden state 作为 embedding 一般效果不好。论文给的对照是 Llama 2 70B 用 weighted-mean pooling 在 MTEB 上只拿到 35.6,而 BGE Large 0.34B 是 64.2,参数差两个数量级却被反超。

反方向同样不通。把 embedding 模型 fine-tune 之后再加回 LM head 做生成,论文实测 Emb.-only 7B 在 MMLU 上是 23.5(随机基线 25.0),意味着 embedding 训练把生成能力彻底洗掉了。

实际部署里这导致 RAG 系统要同时持有两个模型:embedding 模型编码 query 和文档做检索,生成模型再读 query + 召回文档生成回答。query 和文档都要在两个模型上各跑一遍,共 4 次 forward pass。GritLM 想消除的就是这种割裂。

GRIT 方法

一个预训练 LLM,两份训练数据,两套 forward 路径,两个损失加权求和。

embedding 流。输入格式<s><|user|>{instruction}<|embed|>{sample},注意力切成 bidirectional,对最后一层 hidden state 做 mean pooling(仅对 sample 部分平均,instruction 和格式 token 不计入但通过 self-attention 影响表示)。损失是 in-batch negatives 的 InfoNCE:

是 cosine similarity, 是温度。

生成流。输入格式<s><|user|>{instruction}<|assistant|>{response}</s>,注意力保持 causal,过 LM head 做 next-token 预测,损失只在 response 部分计算。

总损失:

<|embed|>这个特殊 token 是关键开关。E5 数据集的 instruction 没有固定前缀,模型必须靠这个 token 知道当前样本走对比损失而不是 LM 损失。

GritLM 7B 用 Mistral 7B 初始化,embedding batch 2048,生成 batch 256,训练 1253 步(embedding 1.36 epoch,生成 1 epoch)。GritLM 8x7B 因算力受限,embedding batch 降到 256。

为什么两个目标能共存

直觉上对比损失逼模型学统一的 sentence-level 表示,LM 损失逼模型学逐 token 的条件概率,两者目标不一致,混训理应互相伤害。论文的实测结果是 GritLM 7B 在 MTEB 上 66.8,单独训 embedding 的 Emb.-only 也是 66.8;GritLM 7B 生成平均 55.5,单独训生成的 Gen.-only 是 55.2。两个目标几乎不打架。

论文给的解释是这两类任务都要求模型深度理解自然语言,差别只在表达方式。论文进一步推测模型内部可能存在"少量参数充当开关",让最终表示要么适合 mean pooling 后做 embedding,要么适合喂给 LM head 做生成,但论文将这一点明确标注为推测,未做定位实验。

值得注意的是 MEDI2 数据集下,加生成目标后 embedding 性能反而比 embedding-only 更高,但这一现象在切换到 E5 数据集后消失,两者持平。

实验结果

MTEB(embedding)。GritLM 7B 在 56 个数据集平均 66.8,超过 E5 Mistral 7B 的 66.6 和 BGE Large 的 64.2,是当时开源模型 SOTA。GritLM 8x7B 是 65.7,比 7B 略低,论文归因于 embedding batch 从 2048 砍到 256(算力受限)。

生成任务。GritLM 7B 在 MMLU/GSM8K/BBH/TyDi QA/HumanEval/AlpacaEval 六项平均 55.5,超过 Tülu 2 7B 的 46.3 和 Mistral 7B Instruct 的 44.1,已经能压过 Llama 2 70B 的 46.4。GritLM 8x7B 平均 65.7,在论文对比的开源生成模型里最高,超过 Mixtral 8x7B Instruct 的 60.3 和 Tülu 2 70B 的 65.1。

reranking。GritLM 既能当 bi-encoder 又能当 cross-encoder。论文用 Sun et al. 的 permutation generation prompt,让生成能力对 top-10 召回结果重排,MTEB 检索平均从 57.4 提到 57.9,16 个检索数据集中 15 个有提升。

RAG 缓存加速。这是论文的卖点之一。传统 RAG 用两个独立模型:embedding 模型编码 query 做检索,生成模型读"query + 召回文档"出回答,query 和 doc 在两个模型上各跑一次,共 4 次 forward。GritLM 检索和生成同一组权重,embedding 阶段算出的 transformer 内部 KV states 可以直接喂给生成阶段,省掉重复 forward。这是不同模型无法做的,因为 KV 是模型内部表示,跨模型不通用。

论文给出三种缓存策略:

Query Caching:embedding 阶段算 query 的 KV 时顺手缓存,生成阶段不再重新 forward query

Doc Caching:建索引时不仅存 doc embedding,还把每篇 doc 的 KV 一起存进索引,命中后直接喂给生成阶段

Query-Doc / Doc-Query Caching:两者都缓存,但因为 query 和 doc 各自缓存时没机会互相 attend,会偏离原始 RAG 的语义

在 Natural Questions 上,sample A(query 1 token,doc 4000 tokens)CPU 上 Doc Caching 5.25s vs 传统 RAG 14.18s,提速 63%;sample B(doc 1 token,query 4000 tokens)CPU 上 Query Caching 6.87s vs RAG 14.88s,提速 54%。GPU 上提速幅度小一些(30% 量级),论文解释是 GPU 本来就并行处理整个序列,缓存收益相对小。

但 Query Caching 改变了 query 的 attention 模式(embedding 时是双向,生成时模型期望 causal),论文实测 match 分数 Query Caching 从 RAG 的 30.50 掉到 25.46。Doc Caching 反而微涨到 33.38,论文的解释是文档不需要被像 query 那样彻底理解,"略微损坏"的 KV 状态对生成质量影响不大。Query-Doc 和 Doc-Query Caching 因有双重 attention 错位,分数掉到 21.63 和 18.39,接近 No RAG 的 21.00。论文的 takeaway:Query-Doc Caching 实用性受限,单边缓存才是性价比高的选择。

关键消融

注意力。把 causal LLM 在 fine-tune 阶段切成 bidirectional 后再做 mean pooling,embedding 涨了 1.8(causal+wmean 60.0 → bidirectional+mean 61.8),论文确认了"causal LLM 拿来做 embedding 应该改成双向"这一结论。改成 PrefixLM(instruction 双向 + response causal)反而掉分。

初始化。Mistral 7B > Llama 2 7B > GPT-J 6B 在 embedding 和生成上都成立。论文一个有意思的发现:预训练后直接测 embedding,GPT-J 比 Mistral 强;但 fine-tune 后 Mistral 反超。结论是 pretrained embedding 能力不能预测 fine-tuned embedding 能力,pretrained 生成能力反而更靠谱。

embedding 数据集。E5 (66.0) > MEDI2 (64.7) > MEDI (64.0)。论文将 E5 的优势归为 GPT-4 生成的 hard negative 质量更高、任务多样性更好。

生成损失粒度。token level 还是 sample level 直接影响生成长度,进而影响 AlpacaEval(已知偏好长回答)。论文最终用 mix:32 个样本内 token level,再 8 个 sub-batch 间 sample level。这个 mix 在 AlpacaEval 上是 74.7,纯 sample-level 只有 67.6,差 7 分,对应生成中位长度 941 → 865。

in-batch negative 来源。让 negative 全部来自同一数据集 vs 任意数据集,平均分一样(66.0),但 Retrieval 子集涨 1.3。论文将原因归为同数据集内的 negative 区分难度更高,逼模型学更细的差异。

embedding batch size。从 256 涨到 4096,embedding 平均 +1.0,主要来自 15 个 retrieval 数据集,生成性能不变。

精度。整体 BF16 mixed precision 即可,但 pooling 和相似度计算必须 cast 到 FP32,否则 embedding 性能略有下降。论文未对此给出更深的理论解释,只是经验性地建议这样做。

few-shot embedding 不 work。在 instruction 后面加一个示例,整体性能下降。即使在 MEDI2 训练里塞了 5% 的 few-shot 样本,模型也没学会用。论文将这一现象简单归为"模型似乎没学会利用 few-shot 示例",未做更深分析。

一些易被忽视的实现细节

asymmetric 任务用 one-sided instruction。E5 数据集对 retrieval 类任务只给 query 加 instruction,文档不加。这样文档只需编码一次就能跨任务复用,缓存友好。symmetric 任务训练时也是单边,但评估时按双边格式喂给模型,论文说这是合理的,因为 cosine similarity 的传递性保证了 A↔B↔C 仍然成立。

KV 缓存的存储代价。对 2,681,468 个文档用 7B 模型,KV states 总量约 30TB。论文指出这部分可以完全 offload 到磁盘,按需加载,每个样本约 12.5MB。原始 index 只是 43GB,KV cache 比 index 大三个数量级。

KTO 对齐 trade-off。KTO(Kahneman-Tversky Optimization)是一种偏好对齐方法,相比 DPO 不需要成对偏好数据,只需要每条样本的"好/坏"二元标签。论文在 GRIT 之后追加了一段 KTO 阶段,UltraFeedback 二元化数据,只训生成不训 embedding。结果 MTEB 从 66.8 微跌到 66.7,AlpacaEval 涨超过 10 分。意味着对齐阶段会缓慢侵蚀 embedding 能力,需要继续维持 embedding 训练才能保住。

的设定。论文坚持 ,理由是模型已经预训练过 LM 损失,对比损失是新东西需要更多学习。实际 embedding 损失下降很快,到训练后期两个损失都稳定在 1.0 附近,初始的权重差异被自动平滑掉。

embedding head 的取舍。可选加一个 4096→1024 的下投影线性层,存储省 4 倍但 embedding 平均掉 1 分。GritLM 最终未采用,留给下游用 PCA 等后处理压维。

学AI大模型的正确顺序,千万不要搞错了

🤔2026年AI风口已来!各行各业的AI渗透肉眼可见,超多公司要么转型做AI相关产品,要么高薪挖AI技术人才,机遇直接摆在眼前!

有往AI方向发展,或者本身有后端编程基础的朋友,直接冲AI大模型应用开发转岗超合适!

就算暂时不打算转岗,了解大模型、RAG、Prompt、Agent这些热门概念,能上手做简单项目,也绝对是求职加分王🔋

📝给大家整理了超全最新的AI大模型应用开发学习清单和资料,手把手帮你快速入门!👇👇

学习路线:

✅大模型基础认知—大模型核心原理、发展历程、主流模型(GPT、文心一言等)特点解析
✅核心技术模块—RAG检索增强生成、Prompt工程实战、Agent智能体开发逻辑
✅开发基础能力—Python进阶、API接口调用、大模型开发框架(LangChain等)实操
✅应用场景开发—智能问答系统、企业知识库、AIGC内容生成工具、行业定制化大模型应用
✅项目落地流程—需求拆解、技术选型、模型调优、测试上线、运维迭代
✅面试求职冲刺—岗位JD解析、简历AI项目包装、高频面试题汇总、模拟面经

以上6大模块,看似清晰好上手,实则每个部分都有扎实的核心内容需要吃透!

我把大模型的学习全流程已经整理📚好了!抓住AI时代风口,轻松解锁职业新可能,希望大家都能把握机遇,实现薪资/职业跃迁~

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

http://www.gsyq.cn/news/1490496.html

相关文章:

  • 2026年6月目前优秀的不锈钢板现货厂家推荐,不锈钢板定制厂家,质量上乘,品质有保障的钢板 - 品牌推荐师
  • 超越QFIL GUI:命令行dump高通设备eMMC全分区的实战与参数详解
  • 告别卡顿!手把手教你将TUM RGBD的tgz包转成30Hz流畅ROS Bag(附Python脚本)
  • 从原理图到数据:手把手教你用STM32同时读取多个DS18B20的温度
  • 智谱清言粘贴到 word 格式混乱难题破解,AI 导出鸭实现版式精准还原与稳定输出
  • 2026年小型熔炼机专业品牌TOP5排行:立式淬火机/立柱移动式伺服数控淬火机床/贵金属熔炼小型熔炼机/贵金属熔炼柜式熔金机/选择指南 - 优质品牌商家
  • 别再只会用AT指令了!用HC-05蓝牙模块和安卓手机,做个无线控制小项目(附完整代码)
  • 别再买错卡了!Arduino+RC522复制门禁卡前,你必须知道的M1卡、UID卡区别与避坑指南
  • 不止于安装:深入理解Horizon连接服务器与CA证书的信任链(附配置清单)
  • 跳出熬夜写稿怪圈:在 paperxie 毕业论文 AI 写作里,找到学术创作的全新解题思路
  • Parasolid核心函数PK_TOPOL_facet深度解析:几何匹配、拓扑匹配、修剪匹配到底怎么选?
  • 人生“地震”来临时,你的反应决定了你的结局
  • 别再一个个改文件权限了!一键配置阿里云OSS存储桶公共读,并理解其安全边界
  • 2026年5月YBP德国意普产品符合欧标吗,poloplast/YBP德国意普/普立曼,YBP德国意普售后保障怎么样 - 品牌推荐师
  • TestDisk与PhotoRec:免费开源的数据恢复终极指南,拯救丢失的分区和文件
  • 第六周. nginx实践
  • 织带原料多维度评测:远动袜专用尼龙纱线、锦纶DTY、锦纶染色丝、锦纶色纺丝、70D140D锦纶高弹丝、仿锦纶、尼龙彩色高弹丝选择指南 - 优质品牌商家
  • 2026洪泽湖大闸蟹选购评测:大闸蟹礼券/大闸蟹礼品卡/大闸蟹礼盒/大闸蟹自助/大闸蟹蟹卡/湖蟹/红膏大闸蟹/苏州蟹黄面/选择指南 - 优质品牌商家
  • 2026年保定公考品牌排行:石家庄申论教学/石家庄考公培训品牌/石家庄考公机构/邢台公考品牌/邢台考公基地/邢台考公机构/选择指南 - 优质品牌商家
  • 【Redis分布式缓存实战】第19章 多级缓存架构设计实战
  • 用手机App Inventor 2做个蓝牙遥控器,5分钟控制你的Arduino LED灯(HC-42模块实战)
  • 斯坦福评测第一!北大 EvoPhys-World世界模型在摩尔线程GPU完成原生训练
  • 别再到处找破解版了!用这个免费在线工具draw.io,5分钟画出高颜值技术架构图
  • 别再只学攻击了!用Kali Linux的arpspoof工具,手把手教你搭建ARP欺骗防御测试环境
  • 2026年口碑好的南通二手房家装改造公司/南通本地家装设计公司业主好评榜 - 品牌宣传支持者
  • aixingpan.cn API开发文档:api_docs_authentication接口指南
  • 别再死记硬背公式了!用Python+NumPy手把手模拟MIMO信道,直观理解空分复用
  • 告别迷茫:用C++从零手搓一个Echo Server(附完整代码与nc测试)
  • EoM:用哈耶克的市场经济理论开发智能体,效果惊人
  • 都2026年了!想入行网络安全却不知道从哪开始?