LLaMA-Factory 微调大模型教程,AMD 环境也能轻松搞定
环境搭建与核心配置
在 AMD GPU 上进行大模型微调,最让人头疼的往往不是算法本身,而是环境配置的“坑”。对于科研人员和学生党来说,时间宝贵,我们直接切入正题,如何在 ROCm 环境下丝滑安装并配置 LLaMA-Factory。
首先,确保你的系统已经安装了适配当前显卡型号的 ROCm 驱动(建议 5.6 及以上版本)。不要试图混用 CUDA 版本的 PyTorch,这是很多报错的根源。我们需要创建一个干净的 Conda 环境:
conda create-nllama-rocmpython=3.10conda activate llama-rocm接下来是重头戏:安装支持 ROCm 的 PyTorch。请务必去 PyTorch 官网查找对应 ROCm 版本的安装命令,通常长这样:
pipinstalltorch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0安装完成后,验证一下torch.cuda.is_available()是否返回 True(在 ROCm 中它依然识别为 cuda 接口,但底层走的是 HIP)。确认无误后,我们可以安装 LLaMA-Factory。为了获得最新的 ROCm 适配修复,建议直接从源码安装:
gitclone https://github.com/hiyouga/LLaMA-Factory.gitcdLLaMA-Factory pipinstall-e".[torch,metrics]"关键点来了:很多教程到这里就结束了,但在 AMD 卡上跑不通,多半是因为训练引擎后端没选对。LLaMA-Factory 默认可能尝试调用 DeepSpeed 的 CUDA 版本,这在纯 ROCm 环境下会直接崩溃。你需要在启动脚本或配置文件中显式指定使用 PyTorch Native 的分布式策略,或者确保安装的 DeepSpeed 是编译过 ROCm 支持的版本。对于大多数单卡或多卡微调场景,直接使用torchrun配合 PyTorch FSDP 是最稳妥的选择。
在配置文件train_config.yaml中,务必检查以下字段:
compute_type:"fp16"# 或 "bf16",取决于你的显卡架构支持backend:"pytorch"# 强制指定后端,避免自动探测失败LoRA 微调实战:从数据到训练
环境就绪后,我们来实际跑通一个 Llama 3 模型的 LoRA 微调流程。LoRA(Low-Rank Adaptation)非常适合资源有限的场景,它能用极少的显存占用实现不错的效果。
数据集预处理与 Tokenizer 兼容
假设我们有一份自定义的指令微调数据data.json,格式为标准的多轮对话。在 ROCm 环境下,Tokenizer 的加载偶尔会出现编码问题,特别是涉及特殊字符时。建议在预处理阶段增加一步兼容性检查:
fromtransformersimportAutoTokenizerimportjson model_name="meta-llama/Meta-Llama-3-8B-Instruct"tokenizer=AutoTokenizer.from_pretrained(model_name,trust_remote_code=True)# 简单的兼容性测试test_text="测试中文编码与特殊符号 @#$%"encoded=tokenizer.encode(test_text)decoded=tokenizer.decode(encoded)print(f"原始:{test_text}")print(f"还原:{decoded}")asserttest_text==decoded,"Tokenizer 编解码不一致,请检查版本!"# 预处理函数defpreprocess_data(data_path):withopen(data_path,'r',encoding='utf-8')asf:data=json.load(f)processed=[]foritemindata:# 构建 Llama 3 特有的对话模板messages=item["messages"]text=tokenizer.apply_chat_template(messages,tokenize=False)processed.append({"text":text})returnprocessed processed_data=preprocess_data("data.json")# 保存为 LLaMA-Factory 识别的格式withopen("dataset/processed_data.json","w",encoding="utf-8")asf:json.dump(processed_data,f,ensure_ascii=False)这段代码不仅完成了格式转换,还顺带验证了 Tokenizer 在当前环境下的稳定性。如果断言失败,通常意味着需要更新transformers库或手动调整 Chat Template。
启动训练与参数配置
准备好数据后,我们编写具体的训练配置。针对 AMD 显卡,我整理了一份经过验证的配置模板,你可以直接复用:
### modelmodel_name_or_path:meta-llama/Meta-Llama-3-8B-Instructtrust_remote_code:true### methodstage:sftdo_train:truefinetuning_type:loralora_target:alllora_rank:16lora_alpha:32lora_dropout:0.1### datasetdataset:processed_datatemplate:llama3cutoff_len:2048max_samples:1000preprocessing_num_workers:4### trainoutput_dir:saves/llama3-lora-rocmper_device_train_batch_size:2gradient_accumulation_steps:4learning_rate:1.0e-4num_train_epochs:3.0lr_scheduler_type:cosinewarmup_ratio:0.1fp16:true# 如果显卡支持 bf16,建议改为 bf16: trueddp_timeout:180000000### evalval_size:0.1per_device_eval_batch_size:2eval_strategy:stepseval_steps:100启动训练的命令非常标准:
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml注意,如果在运行过程中遇到RuntimeError: CUDA error之类的提示,别慌,在 ROCm 环境下这通常只是报错信息的惯性显示,只要没有伴随HIP initialization failed或显存溢出,一般可以忽略。但如果程序直接崩溃,请检查是否开启了flash_attention,部分旧版 ROCm 对 Flash Attention 2 的支持尚不完善,建议在配置中暂时关闭它:
disable_flash_attn:true监控优化与异常排查
训练开始后,如何知道模型是否在正常学习?LLaMA-Factory 自带了基于 Web 的可视化界面,这对于监控损失曲线至关重要。
启动训练时,加上--plot_loss参数,或者直接访问本地服务端口(默认通常是 7860,具体看日志输出)。在浏览器中输入http://localhost:7860,你就能看到实时的 Loss 下降曲线。在 AMD 环境下,我发现有时图表刷新会有轻微延迟,这是正常的,只要数据点在持续更新即可。重点关注loss是否平稳下降,如果出现剧烈震荡或 NaN,说明学习率过大或数据有问题。
关于混合精度训练的“梯度爆炸”问题:
在使用 FP16 进行微调时,ROCm 某些版本下可能会出现梯度溢出导致训练中断,表现为 Loss 突然变成 NaN。这时候不要急着改数据,最有效的方案是切换到纯 FP32 模式,或者调整损失缩放因子。
修改配置文件:
fp16:falsebf16:false# 确保两者都关闭,强制使用 FP32# 或者如果必须用半精度,尝试调整以下参数(如果框架支持)# loss_scale: 1024.0虽然 FP32 会增加显存占用并略微降低训练速度,但它能极大提升数值稳定性。对于显存紧张的卡片,可以适当减小per_device_train_batch_size来换取空间。我在一次实验中,将 Batch Size 从 4 降到 2,开启 FP32 后,原本频繁崩溃的训练任务顺利跑完了 3 个 Epoch。
最后,训练完成的模型保存在output_dir指定的目录下。你可以使用同样的环境加载适配器进行推理验证:
fromllmtunerimportChatModel chat_model=ChatModel()# 会自动读取最新保存的检查点response,history=chat_model.chat("你好,介绍一下你自己。")print(response)在 AMD 显卡上跑通大模型微调并非难事,关键在于选对工具链和配置细节。希望这份实战记录能帮你少走弯路,轻松上手 ROCm 生态。
200小时GPU算力已就位,快来领取:https://marketing.csdn.net/questions/Q2604140858304426315?utm_source=AIpaper
