当前位置: 首页 > news >正文

大模型微调--MoELora

文章目录

      • MOELoRA 的核心组件
      • MOE 在多任务学习中的作用
      • LoRA 在参数高效微调中的贡献
      • MOELoRA 的协同工作机制

https://arxiv.org/pdf/2310.18339
When MOE Meets LLMs: Parameter Efficient Fine-tuning for Multi-task Medical Applications


MOELoRA 的核心组件

MOELoRA 的核心思想建立在两个关键技术上:混合专家系统(MOE)和低秩自适应(LoRA)。MOE 负责处理多任务学习中的任务分配和专家协作,LoRA 则专注于参数高效的模型微调。

MOE 在多任务学习中的作用

MOE 结构通过动态路由机制将输入数据分配给不同的专家模块,每个专家专注于特定任务或数据子集。这种设计允许模型在不显著增加参数量的情况下,灵活处理多任务场景。MOE 的优势在于其能够根据任务复杂度自动调整专家资源的分配,提升模型在有限数据和计算资源下的表现。

LoRA 在参数高效微调中的贡献

LoRA 通过低秩矩阵分解技术,在预训练模型的基础上引入少量可训练参数,大幅降低微调阶段的资源消耗。具体实现中,LoRA 将权重更新 ΔW 分解为两个低秩矩阵的乘积(例如 ΔW = BA,其中 B 和 A 的秩远小于原权重矩阵)。这种方法既保留了预训练模型的知识,又实现了高效的任务适配。

MOELoRA 的协同工作机制

MOELoRA 将 MOE 的任务分配能力与 LoRA 的参数效率结合,形成分层优化结构。MOE 层负责识别任务类型并激活对应的专家模块,每个专家内部采用 LoRA 进行微调。这种设计既避免了多任务间的干扰,又通过共享基础模型参数减少了冗余。


https://github.com/liuqidong07/MOELoRA-peft/blob/master/src/MLoRA/peft/tuners/mmoelora.py

classMMOELoraLayer(LoraLayer):def__init__(self,in_features:int,out_features:int,expert_num:int):super().__init__(in_features,out_features)self.expert_num=expert_numdefupdate_layer(self,adapter_name,r,lora_alpha,lora_dropout,init_lora_weights):self.r[adapter_name]=r self.lora_alpha[adapter_name]=lora_alphaiflora_dropout>0.0:lora_dropout_layer=nn.Dropout(p=lora_dropout)else:lora_dropout_layer=nn.Identity()self.lora_dropout.update(nn.ModuleDict({adapter_name:lora_dropout_layer}))# Actual trainable parametersifr>0:self.lora_A.update(nn.ModuleDict({adapter_name:MMOELinearA(self.in_features,r,self.expert_num)}))self.lora_B.update(nn.ModuleDict({adapter_name:MMOELinearB(r,self.out_features,self.expert_num)}))self.scaling[adapter_name]=lora_alpha/rifinit_lora_weights:self.reset_lora_parameters(adapter_name)self.to(self.weight.device)defreset_lora_parameters(self,adapter_name):ifadapter_nameinself.lora_A.keys():# initialize A the same way as the default for nn.Linear and B to zeroforiinrange(self.expert_num):nn.init.normal_(self.lora_A[adapter_name].loraA[i].mlp.weight,mean=0.0,std=0.01)nn.init.zeros_(self.lora_B[adapter_name].loraB[i].mlp.weight)classMMOELoraLinear(nn.Linear,MMOELoraLayer):# Lora implemented in a dense layer# nn.Linear is the pretrained weights in LLM, MMOELoraLayer is the designed trainable Loradef__init__(self,adapter_name:str,in_features:int,out_features:int,r:int=0,lora_alpha:int=1,lora_dropout:float=0.0,fan_in_fan_out:bool=False,# Set this to True if the layer to replace stores weight like (fan_in, fan_out)**kwargs,):init_lora_weights=kwargs.pop("init_lora_weights",True)self.expert_num=kwargs.pop("expert_num",True)self.task_num=kwargs.pop("task_num",True)self.te_dim=kwargs.pop("task_embedding_dim",True)nn.Linear.__init__(self,in_features,out_features,**kwargs)MMOELoraLayer.__init__(self,in_features=in_features,out_features=out_features,expert_num=self.expert_num)# init the Gate networkself.lora_task_embedding=nn.ModuleDict({})self.lora_gate=nn.ModuleDict({})self.lora_task_embedding.update(nn.ModuleDict({adapter_name:nn.Embedding(self.task_num+1,self.te_dim)}))self.lora_gate.update(nn.ModuleDict({adapter_name:Gate(self.te_dim,self.expert_num)}))# Freezing the pre-trained weight matrixself.weight.requires_grad=Falseself.fan_in_fan_out=fan_in_fan_outiffan_in_fan_out:self.weight.data=self.weight.data.T nn.Linear.reset_parameters(self)self.update_layer(adapter_name,r,lora_alpha,lora_dropout,init_lora_weights)self.active_adapter=adapter_namedefmerge(self,task_id):ifself.active_adapternotinself.lora_A.keys():returnifself.merged:warnings.warn("Already merged. Nothing to do.")returnifself.r[self.active_adapter]>0:expert_weight=self.lora_gate[self.active_adapter](self.lora_task_embedding[self.active_adapter](task_id))foriinrange(self.expert_num):lora_A_weights=self.lora_A[self.active_adapter].loraA[i].mlp.weight lora_B_weights=self.lora_B[self.active_adapter].loraB[i].mlp.weight self.weight.data+=(transpose(lora_B_weights @ lora_A_weights,self.fan_in_fan_out,)*self.scaling[self.active_adapter]*expert_weight[...,i])self.merged=Truedefunmerge(self,task_id):ifself.active_adapternotinself.lora_A.keys():returnifnotself.merged:warnings.warn("Already unmerged. Nothing to do.")returnifself.r[self.active_adapter]>0:expert_weight=self.lora_gate[self.active_adapter](self.lora_task_embedding[self.active_adapter](task_id))foriinrange(self.expert_num):lora_A_weights=self.lora_A[self.active_adapter].loraA[i].mlp.weight lora_B_weights=self.lora_B[self.active_adapter].loraB[i].mlp.weight self.weight.data-=(transpose(lora_B_weights @ lora_A_weights,self.fan_in_fan_out,)*self.scaling[self.active_adapter]*expert_weight[...,i])self.merged=Falsedefforward(self,x:torch.Tensor,**kwargs):task_id=kwargs["task_id"]previous_dtype=x.dtypeifself.active_adapternotinself.lora_A.keys():# No adapter, directly use linearreturnF.linear(x,transpose(self.weight,self.fan_in_fan_out),bias=self.bias)ifself.disable_adapters:# No adapterifself.r[self.active_adapter]>0andself.merged:# merge the adapter to linearself.unmerge(task_id)result=F.linear(x,transpose(self.weight,self.fan_in_fan_out),bias=self.bias)elifself.r[self.active_adapter]>0andnotself.merged:# general lora processresult=F.linear(x,transpose(self.weight,self.fan_in_fan_out),bias=self.bias)x=x.to(self.lora_A[self.active_adapter].loraA[0].weight.dtype)expert_weight=self.lora_gate[self.active_adapter](self.lora_task_embedding[self.active_adapter](task_id))foriinrange(self.expert_num):result+=(# lora processself.lora_B[self.active_adapter].loraB[i](self.lora_A[self.active_adapter].loraA[i](self.lora_dropout[self.active_adapter](x)),)*self.scaling[self.active_adapter]*expert_weight[...,i].unsqueeze(-1).unsqueeze(0))else:result=F.linear(x,transpose(self.weight,self.fan_in_fan_out),bias=self.bias)result=result.to(previous_dtype)returnresultclassMMOELinearA(nn.Module):'''MMOE based LoRA block'''def__init__(self,in_features,out_features,expert_num)->None:super().__init__()self.expert_num=expert_num self.in_features,self.out_features=in_features,out_features self.loraA=nn.ModuleList([])assertself.out_features%self.expert_num==0# lora rank should be divided by expert numberself.r=self.out_features//self.expert_numfor_inrange(self.expert_num):self.loraA.append(Expert(self.in_features,self.r))defforward(self,x):'''input x is a vector, return output is a list'''outputs=[]foriinrange(self.expert_num):outputs.append(self.loraA[i](x))returnoutputsclassMMOELinearB(nn.Module):'''MMOE based LoRA block'''def__init__(self,in_features,out_features,expert_num)->None:super().__init__()self.expert_num=expert_num self.in_features,self.out_features=in_features,out_features self.loraB=nn.ModuleList([])assertself.in_features%self.expert_num==0self.r=self.in_features//self.expert_numfor_inrange(self.expert_num):self.loraB.append(Expert(self.r,self.out_features))defforward(self,x):'''input x is a list, return output is also a list'''outputs=[]foriinrange(self.expert_num):outputs.append(self.loraB[i](x[i]))returnoutputsclassExpert(nn.Module):def__init__(self,in_features,out_features):super().__init__()self.in_features,self.out_features=in_features,out_features self.mlp=nn.Linear(self.in_features,self.out_features,bias=False)self.weight=self.mlp.weightdefforward(self,x):# LoRA A or B blocky=self.mlp(x)returnyclassGate(nn.Module):def__init__(self,input_size,expert_num):super().__init__()# 使用embedding来代替线性层self.GateL=nn.Linear(input_size,expert_num,bias=False)self.act=nn.Softmax(dim=1)# 第0维为batch sizedefforward(self,x):y=self.GateL(x)y=self.act(y)returny
http://www.gsyq.cn/news/134408.html

相关文章:

  • Open-AutoGLM输入法无法响应?5分钟快速诊断与恢复流程曝光
  • 2025年年终深圳家电搬运公司推荐:专业排行解析与多维度服务对比指南 - 十大品牌推荐
  • LangFlow能否支持增量更新?部分节点重新执行机制
  • 别再被重复文本困扰!Open-AutoGLM输入清洗的7个关键步骤(独家实战经验)
  • LangFlow是否提供权限管理系统?多用户访问控制现状
  • Open-AutoGLM字符编码崩溃怎么办?资深架构师教你快速定位并修复
  • 仅限内部流传的Open-AutoGLM调试秘技:触控无响应的7个隐藏原因(首次公开)
  • LangFlow工作流导出为API接口的操作步骤详解
  • PHP网络/磁盘 I/O 远慢于 CPU的庖丁解牛
  • $urls = array_chunk($urls, ceil(count($urls)/$workers));的庖丁解牛
  • 2025年年终深圳家电搬运公司推荐:实力榜单TOP5与全方位服务对比评测 - 十大品牌推荐
  • Laravel 中 Http::get() 默认同步,切勿在循环中直接使用!
  • LangFlow与Google Docs联动编辑AI生成内容实测
  • 2025年年终济南家电搬运公司推荐:深度评测报告与关键指标对比分析 - 十大品牌推荐
  • Open-AutoGLM输入法频繁崩溃?3步精准定位并修复切换异常
  • sam9x60 tcp协议栈 小记
  • 【Open-AutoGLM输入法异常处理指南】:99%开发者忽略的5大切换故障根源揭秘
  • LangFlow中的条件分支节点如何配置?逻辑控制进阶教学
  • LangFlow在高校教学中的应用前景:AI课程实验平台搭建
  • 2025年高性价比短视频代运营公司排行榜,专业服务商推荐 - 工业推荐榜
  • 毕业设计项目 python小游戏设计 吃豆人小游戏
  • 2025年机油供应商靠谱推荐,口碑好的汽轮机油机油源头厂家有哪些? - myqiye
  • LangFlow支持哪些LangChain模块?兼容性与扩展性测试报告
  • LangFlow未来发展方向预测:是否会成为标准开发工具?
  • 2025年年终成都管道疏通推荐:专业评测、用户评价与排名指南 - 十大品牌推荐
  • Open-AutoGLM长按功能卡顿问题全解析(一线工程师实战经验曝光)
  • LangFlow中的变量传递机制详解:上下文共享原理
  • 2025年五大常州泽尔达机械同行对比排行榜,常州泽尔达机械的节能效果如何? - mypinpai
  • 揭秘Open-AutoGLM滑动无响应之谜:5个关键修复方案立即生效
  • LangFlow与向量数据库(如Pinecone)集成实战教程