当前位置: 首页 > news >正文

强化学习Q-learning求最优策略

理论基础:

on policy:behavior policy=target policy

off policy:behavior policy!=target policy

注意:

behavior policy的初始化最好具有较强的随机性,就能尽可能遍历到所有的(s, a)pair。

强化学习的数据基础这种书中有不同的behavior policy导致的不同的探索路径的图:

代码可运行:

import numpy as np from env import GridWorldEnv from utils import drow_policy class Q_Learning(object): def __init__(self, env: GridWorldEnv, gamma=0.9, alpha=0.001, epsilon=0.1, samples=1, start_state=(0, 0),mode="on policy"): ''' :param env: 定义了网格的基础配置 :param gamma: discount rate :param alpha: learning rate :param samples: 从起点到终点采样的路径数 :param start_state: 起点 :param mode: 模式 ''' self.env = env self.action_space_size = self.env.num_actions # 上下左右原地 self.state_space_size = self.env.num_states self.reward_list = self.env.reward_list self.gamma = gamma self.samples = samples self.alpha = alpha self.epsilon = epsilon self.mode=mode self.start_state = self.env.state_id(start_state[0], start_state[1]) self.behavior_policy = np.ones( (self.state_space_size, self.action_space_size)) / self.action_space_size # 探索性很强 self.target_policy = np.zeros((self.state_space_size, self.action_space_size)) self.qvalues = np.zeros((self.state_space_size, self.action_space_size)) def update_qvalues(self,s_t,a_t,s_next,r_next): max_q_next = np.max(self.qvalues[s_next]) td_target = r_next + self.gamma * max_q_next td_error = td_target - self.qvalues[s_t][a_t] # 负号提出去 self.qvalues[s_t][a_t] += self.alpha * td_error def solve(self): if self.mode=="off policy": for _ in range(self.samples): s = self.start_state a = np.random.choice(self.action_space_size, p=self.behavior_policy[s]) episode = self.env.generate_episodes(self.behavior_policy, s, a) for i in range(len(episode)): s_t, a_t, r_next_t, s_next_t= episode[i] self.update_qvalues(s_t,a_t,s_next_t,r_next_t) # greedy best_a = np.argmax(self.qvalues[s_t]) self.target_policy[s_t] = np.eye(self.action_space_size)[best_a] elif self.mode=="on policy": # target_policy=behavior_policy for _ in range(self.samples): s = self.start_state while s not in self.env.terminal: a = np.random.choice(self.action_space_size, p=self.behavior_policy[s]) # generate at following πt(st) next_s, next_r, _ = self.env.step(s, a) # generate rt+1, st+1 by interacting with the environment # updata q-value for (s_t,a_t) # qt+1(st, at) = qt(st, at) − αt(st, at) [ qt(st, at) − (rt+1 + γ max(qt(st+1, a)))] self.update_qvalues(s,a,next_s,next_r) # update policy for s_t: epsilon greedy 因为要用policy生成数据,因此需要策略具有一定的探索性,因此使用epsilon greedy best_a = np.argmax(self.qvalues[s]) self.behavior_policy[s] = self.epsilon / self.action_space_size self.behavior_policy[s, best_a] += 1 - self.epsilon self.target_policy=self.behavior_policy s = next_s else: raise Exception("Invalid mode") if __name__ == '__main__': env = GridWorldEnv( size=5, forbidden=[(1, 2), (3, 3)], terminal=[(4, 4)], r_boundary=-1, r_other=-0.04, r_terminal=1, r_forbidden=-1, r_stay=-0.1 ) # 注意samples要大一点,否则每个state被访问到的概率很小 vi = Q_Learning(env=env, gamma=0.8, alpha=0.01, samples=1000, start_state=(0, 0),mode="off policy") vi.solve() print("\n state value: ") print(vi.qvalues) drow_policy(vi.target_policy, env)
http://www.gsyq.cn/news/99771.html

相关文章:

  • ComfyUI文生图工作流详解
  • c#教程实战应用案例分享
  • 3分钟搞定网盘限速:无需会员的高速下载加速方案
  • 【开题答辩全过程】以 基于Spring Boot的香飘万里外卖平台为例,包含答辩的问题和答案
  • 实战LLaMA2-7B指令微调
  • 优化Sigmoid函数计算:提升AI模型训练速度
  • 计算机毕业设计springboot餐厅预定系统 基于SpringBoot的智慧餐饮订座平台 SpringBoot驱动的线上餐厅席位预约管理系统
  • Java 八大排序算法详解(从入门到面试)
  • AI如何革新漏洞扫描工具的开发流程
  • 深入解析 DNS:互联网的隐形神经系统
  • 数字色彩的骨架:计算机如何理解颜色
  • AI大模型赋能消费升级:新机遇与新路径
  • Ascend C算子精度调试全攻略 - 从Print函数到结构化数据比对
  • Web3.js钱包与账户管理
  • 【开题答辩全过程】以 基于微信小程序的失物认领系统为例,包含答辩的问题和答案
  • 《线性代数应该这样学》学习笔记 | 第一章 向量空间
  • 光电设计大赛-基于树莓派4B的YOLOv5-Lite目标检测的移植与部署
  • AI弱智文章 - sunny
  • 亚马逊基本功:低成本测品攻略
  • MATLAB程序设计基础
  • 密码系统
  • 电商系统中ES检索技术设计和运用 - 实践
  • C#+VisionMaster联合开发(十)_全局触发
  • 学生党必备!这款桌面课表工具太省心了
  • 江西过碳酸钠生产厂、浙江过碳酸钠生产厂实力榜,值得关注 - 品牌2026
  • 重磅科研发现:香蕉是宇宙的终极遥控器 - sunny
  • 基于springboot的课程作业管理系统(11490)
  • 成膜助剂出口厂商有哪些?有出口资质的成膜助剂供应商推荐 - 品牌2026
  • 过碳酸钠供应商、生产厂家盘点,靠谱供应商及制造商合集 - 品牌2026
  • C#+VisionMaster联合开发(十一)_全局脚本