当前位置: 首页 > news >正文

大模型微调示例四之Llama-Factory-DPO - 教程

大模型微调示例四之Llama-Factory-DPO

  • 一、强化学习数据处理
  • 二、配置训练文档
  • 三、模型预测

一、强化学习数据处理

原始数据地址:https://nijianmo.github.io/amazon/index.html

第一步:读取 video game 信息

import codecs, json, re
from random import shuffle
# Step 1: load the video-game metadata.
# `games` maps product ID (asin) -> product title.
games = {}
# The file holds one JSON object per line. Decode explicitly as UTF-8 so the
# result does not depend on the platform's default encoding.
with open('./data/src_data/meta_Video_Games.json', mode='r', encoding='utf-8') as fin:
    for line in fin:
        line = line.strip()
        if not line:
            # Tolerate blank lines instead of failing in json.loads.
            continue
        tmp_info = json.loads(line)
        # asin  - ID of the product
        # title - name of the product
        title = tmp_info.get("title")
        if title is None:
            # A record without a title cannot be rendered in a prompt; skip it
            # rather than abort the whole load with a KeyError.
            continue
        games[tmp_info["asin"]] = title
        # Progress report every 10000 entries.
        if len(games) % 10000 == 0:
            print(f'Length of games: {len(games)}')

第二步:读取用户评分信息

# Step 2: load the user ratings.
# `user_reviews` maps reviewer ID -> list of (title, rating, review date).
user_reviews = {}
with open('./data/src_data/Video_Games_5.json', mode='r', encoding='utf-8') as fin:
    for line in fin:
        tmp_info = json.loads(line.strip())
        # reviewerID - ID of the reviewer
        reviewer_id = tmp_info["reviewerID"]
        # reviewTime is split on ", " or " " into (month, day, year).
        # Zero-pad month and day so that the resulting "YYYY-MM-DD" strings
        # sort chronologically; unpadded values such as "2015-10-1" would sort
        # before "2015-2-1" when compared as text.
        time_info = re.split(', | ', tmp_info["reviewTime"])
        review_time = (
            time_info[2] + '-' + time_info[0].zfill(2) + '-' + time_info[1].zfill(2)
        )
        # asin - ID of the product
        product_id = tmp_info["asin"]
        # overall - rating of the product
        rating = tmp_info["overall"]
        # Reviews of products missing from the metadata have no title; drop them.
        if product_id not in games:
            continue
        is_new_reviewer = reviewer_id not in user_reviews
        user_reviews.setdefault(reviewer_id, []).append(
            (games[product_id], rating, review_time)
        )
        # Report progress once per 10000 distinct reviewers (only when the
        # count has just changed, so the same milestone is not printed twice).
        if is_new_reviewer and len(user_reviews) % 10000 == 0:
            print(f'Length of user_reviews: {len(user_reviews)}')

# Keep only users with enough history, ordered oldest review first.
user_reviews_sorted = {}
for k, v in user_reviews.items():
    # De-duplicate while keeping file order, so that reviews sharing the same
    # date end up in a deterministic order after the stable sort below.
    v = list(dict.fromkeys(v))
    # Require at least 7 distinct reviews per user.
    if len(v) >= 7:
        # Sort by review date ascending: this is the user's review history.
        user_reviews_sorted[k] = sorted(v, key=lambda x: x[2])
print(f'Length of user_reviews_sorted: {len(user_reviews_sorted)}')

第三步:训练数据生成

# Step 3: turn every user's review history into one preference sample.
# Shared instruction placed in front of every sample.
instruction = "You are an assistant working on Video Games recommendations. Given the user's history of Video Games they have shopped, which includes the \"Title\" of the Video Games and the \"Rating\" the user rate (the Rating value is like or dislike), please decide whether the user likes to shop the target Video Games by outputting the order of their titles."
# All generated samples.
samples = []
for v in user_reviews_sorted.values():
    # The last two reviews are the prediction targets. A sample is only
    # usable when exactly one of them is a "like" (rating above 3.0) and the
    # other a "dislike" (rating of 3.0 or below).
    last_liked = v[-1][1] > 3.0
    prev_liked = v[-2][1] > 3.0
    if last_liked == prev_liked:
        continue

    # Everything before the two targets is shown as the user's history.
    history = "".join(
        "<Title: {}, Rating: {}>\n".format(vv[0], 'like' if vv[1] > 3.0 else 'dislike')
        for vv in v[0:-2]
    )
    sample_input = (
        "User shopped Video Games histories (Title and Rating): \n"
        + history
        + "Based on the Video Games histories, please sort the following two Video Games titles. The one in the front is what the user like and should be recommended to user: \n"
        + "<Title: " + v[-2][0] + '>\n'
        + "<Title: " + v[-1][0] + '>\n'
    )

    # option1 is the liked title, option2 the disliked one.
    if last_liked:
        option1, option2 = v[-1][0], v[-2][0]
    else:
        option1, option2 = v[-2][0], v[-1][0]

    samples.append({
        "instruction": instruction,
        "input": sample_input,
        # chosen lists the liked title first ...
        "chosen": "<Title: " + option1 + '>\n' + "<Title: " + option2 + '>',
        # ... rejected lists the disliked title first.
        "rejected": "<Title: " + option2 + '>\n' + "<Title: " + option1 + '>',
    })
    # Progress report every 10000 samples.
    if len(samples) % 10000 == 0:
        print(f'Length of samples: {len(samples)}')
print(f'Length of samples: {len(samples)}')

第四步:划分 train 和 test 并保存样本

# Step 4: shuffle, split 80/20 into train and test, and save both as JSON.
# Shuffle in place first so the split is random.
shuffle(samples)
# Compute the boundary once so both slices are guaranteed to agree on it.
split_at = int(len(samples) * 0.8)
train = samples[:split_at]
test = samples[split_at:]
print(f'总样本数: {len(samples)},训练集样本数: {len(train)},测试集样本数: {len(test)}')
with open("./data/processed/rlhf_train.json", "w", encoding='utf-8') as save_file:
    json.dump(train, save_file, indent=4)
with open("./data/processed/rlhf_test.json", "w", encoding='utf-8') as save_file:
    json.dump(test, save_file, indent=4)

二、配置训练文档

rlhf_train.yaml

# Training configuration (rlhf_train.yaml) for preference tuning with LoRA.
### model
model_name_or_path: /ZhipuAI/glm-4-9b-chat
### method
# NOTE(review): the stage is `dpo` while the preference loss is set to `orpo`;
# confirm this is the intended objective for this run.
stage: dpo
do_train: true
finetuning_type: lora
lora_target: all
lora_rank: 16
lora_alpha: 32
pref_beta: 0.1
pref_loss: orpo
### dataset
# NOTE(review): presumably this name must be registered in the training
# framework's dataset registry and point at the generated rlhf_train.json -
# confirm against the framework's data documentation.
dataset: amazon_video_games
template: glm4
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
# The inference config's adapter_name_or_path uses this same directory.
output_dir: ./saves/amazon_video_games_orpo
logging_steps: 10
# NOTE(review): with max_samples 1000, batch size 1, accumulation 8 and
# 3 epochs there are at most about 375 optimizer steps on one device, so a
# 500-step save/eval interval would presumably never trigger - confirm.
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500

rlhf_inference.yaml

# Inference configuration (rlhf_inference.yaml): base model plus LoRA adapter.
model_name_or_path: /ZhipuAI/glm-4-9b-chat
# Same directory as output_dir in rlhf_train.yaml.
adapter_name_or_path: ./saves/amazon_video_games_orpo
template: glm4
finetuning_type: lora

三、模型预测

import json
from openai import OpenAI
from tqdm import tqdm
def _report_accuracy(labels, predictions):
    """Print and return the exact-match accuracy of predictions against labels.

    Both strings of each pair are stripped of surrounding whitespace before
    being compared. An empty input reports an accuracy of 0.0 instead of
    dividing by zero.

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(labels) != len(predictions):
        raise ValueError(
            f'labels and predictions differ in length: '
            f'{len(labels)} != {len(predictions)}'
        )
    total = len(labels)
    correct = sum(
        1
        for label, prediction in zip(labels, predictions)
        if label.strip() == prediction.strip()
    )
    wrong = total - correct
    accuracy = correct / total if total else 0.0
    print(f'总样本数:{total},准确数:{correct}, 错误数:{wrong}, 准确率:{accuracy}')
    return accuracy


# Client for the OpenAI-compatible endpoint that serves the model.
client = OpenAI(
    api_key="EMPTY",
    # Change this to the address of the deployed model server.
    base_url="http://10.114.16.65:8000/v1/"
)
# Load the held-out test samples.
test_file_path = "./data/processed/rlhf_test.json"
with open(test_file_path, "r", encoding='utf-8') as test_file:
    test_data = json.load(test_file)
print(len(test_data))
# Query the model once per sample and collect label/prediction pairs.
labels = []
predictions = []
for each_test in tqdm(test_data):
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": each_test["instruction"]
            },
            {
                "role": "user",
                "content": each_test["input"],
            }
        ],
        model="glm4",
    )
    # The message content may be None; treat that as an empty (and therefore
    # wrong) prediction instead of crashing later on .strip().
    predictions.append(chat_completion.choices[0].message.content or "")
    labels.append(each_test["chosen"])
    # Running accuracy every 100 samples.
    if len(labels) % 100 == 0:
        _report_accuracy(labels, predictions)
# Final accuracy over the whole test set.
_report_accuracy(labels, predictions)
http://www.gsyq.cn/news/10425.html

相关文章:

  • n8n+MySQL实现数据库查询!
  • firewalld 端口流量转发
  • Day20封装的初步认识
  • 【Qt开发】显示类控件(三)-> QProgressBar - 详解
  • 完整教程:数据结构与算法-树和二叉树-二叉树的存储结构(Binary Tree)
  • 工业相机与镜头靶面尺寸的关系:从原理到选型的避坑指南 - 教程
  • 提供优雅报错能力
  • Security Onion Solution
  • 详细介绍:MySQL进阶学习
  • 时序数据库 TimechoDB V2.0.6 发布 | 新增查询写回、黑白名单等功能
  • 第二篇
  • EasyDSS “进度条预览”黑科技,如何重塑视频点播的交互体验?
  • AI重塑招聘:从筛简历到做决策,HR如何借技术提效35%?
  • 直播点播之外,EasyDSS如何开辟“实时协作”第三极?它的会议功能,远比你想象的强大
  • 抖音视频关键词批量下载工具分享|分享痛点|
  • 第二部分:VTK核心类详解(第38章 vtkPointData点数据类) - 教程
  • 使用ai来搭建测试用例1
  • 总线的概念以及分类
  • 详细介绍:基于伪随机数的WPS PIN码逆向原理分析(精灵尘埃/仙尘攻击)
  • WPF Prism PrismApplication OnInitialized()
  • 使用shell脚本一键部署docker及docker-compose环境
  • 数据全生命周期安全建设方案推荐:双轮驱动架构的实践与创新
  • 噬菌体展示技术原理深度解析:从基因型-表型偶联到亲和筛选的核心逻辑
  • 日记2
  • AP2 (Agent Payments Protocol) 使用教程
  • RTK精度和时间 - MKT
  • LeetCode-100.相同的树
  • ubuntu安装minio并切换数据存储目录
  • 数据全生命周期安全解决方案推荐(2025):以全链路泛监测补强控制面,走通“观测先行—证据回灌—渐进加固”的落地路径
  • Java 语法糖大揭秘:让代码更甜更高效的幕后功臣 - 教程