当前位置: 首页 > news >正文

SAC算法里的“熵”到底是啥?用Python代码带你直观理解最大熵强化学习

SAC算法中的熵用Python代码揭开强化学习探索之谜在强化学习的世界里我们常常教导智能体要聪明地行动——选择那些能带来最高奖励的动作。但有趣的是最先进的算法如SACSoft Actor-Critic却反其道而行之它鼓励智能体表现得不那么确定这就是熵的魔力。本文将通过Python代码带你直观理解这个看似矛盾却极其强大的概念。1. 熵在强化学习中的直观意义想象你正在玩一个全新的电子游戏。如果只选择已知能得分的操作你可能永远发现不了隐藏的彩蛋或更高效的得分方式。这就是传统强化学习的局限——过于功利的智能体容易陷入局部最优。而SAC通过引入熵让智能体保持适度的好奇心。熵的数学定义很简单对于一个概率分布π(a|s)其熵H(π) -Σπ(a|s)logπ(a|s)。在代码中我们可以这样计算import numpy as np def compute_entropy(prob_dist): return -np.sum(prob_dist * np.log(prob_dist 1e-10)) # 加小量避免log(0) # 示例两个不同的策略在3个动作上的分布 deterministic_policy np.array([0.9, 0.1, 0.0]) # 确定性强的策略 random_policy np.array([0.4, 0.3, 0.3]) # 随机性强的策略 print(f确定性策略熵: {compute_entropy(deterministic_policy):.3f}) print(f随机策略熵: {compute_entropy(random_policy):.3f})运行这段代码你会看到确定性策略的熵值明显更低。SAC的核心思想就是在奖励函数中加入这个熵值作为额外奖励鼓励策略保持一定的随机性。2. 构建极简SAC从网格世界开始为了直观展示熵的作用我们实现一个简化版SAC来解决网格世界问题。这个环境包含5x5网格起点在(0,0)目标在(4,4)某些格子有惩罚悬崖动作空间上、下、左、右import torch import torch.nn as nn import torch.optim as optim import numpy as np from collections import deque import random class GridWorld: def __init__(self): self.size 5 self.goal (4, 4) self.cliffs [(1, 2), (2, 2), (3, 2)] self.reset() def reset(self): self.pos (0, 0) return self.pos def step(self, action): x, y self.pos if action 0: y min(y 1, self.size - 1) # 上 elif action 1: y max(y - 1, 0) # 下 elif action 2: x max(x - 1, 0) # 左 elif action 3: x min(x 1, self.size - 1) # 右 self.pos (x, y) if self.pos in self.cliffs: return self.pos, -10, True if self.pos self.goal: return self.pos, 10, True return self.pos, -0.1, False # 每步小惩罚鼓励尽快到达目标3. SAC核心组件实现我们的极简SAC包含三个关键部分策略网络Actor、两个Q网络Critic和自动调节的温度参数α。class PolicyNetwork(nn.Module): def __init__(self, state_dim, action_dim, hidden_dim64): super().__init__() self.fc1 nn.Linear(state_dim, hidden_dim) self.fc_mean nn.Linear(hidden_dim, action_dim) self.fc_logstd nn.Linear(hidden_dim, action_dim) def forward(self, state): x torch.relu(self.fc1(state)) mean torch.tanh(self.fc_mean(x)) # 输出在[-1,1]之间 log_std self.fc_logstd(x) return mean, log_std def sample_action(self, state): mean, log_std self.forward(state) std log_std.exp() normal torch.distributions.Normal(mean, std) action normal.rsample() # 重参数化采样 log_prob normal.log_prob(action).sum(-1) return action.tanh(), log_prob # 使用tanh确保动作在[-1,1] class QNetwork(nn.Module): def __init__(self, state_dim, action_dim, hidden_dim64): super().__init__() self.fc1 nn.Linear(state_dim action_dim, hidden_dim) self.fc2 nn.Linear(hidden_dim, hidden_dim) self.fc_out nn.Linear(hidden_dim, 1) def forward(self, state, action): x torch.cat([state, action], dim-1) x torch.relu(self.fc1(x)) x torch.relu(self.fc2(x)) return self.fc_out(x)4. 熵如何影响智能体行为温度参数α控制着熵对策略的影响程度。我们可以通过调整α值来观察智能体的行为变化def train_sac(env, alpha0.2, episodes1000): state_dim 2 # (x,y)坐标 action_dim 4 # 上下左右 # 初始化网络 policy PolicyNetwork(state_dim, action_dim) q1 QNetwork(state_dim, action_dim) q2 QNetwork(state_dim, action_dim) # 优化器 policy_optim optim.Adam(policy.parameters(), lr3e-4) q_optim optim.Adam(list(q1.parameters()) list(q2.parameters()), lr3e-4) replay_buffer deque(maxlen10000) batch_size 64 for ep in range(episodes): state env.reset() episode_reward 0 while True: state_tensor torch.FloatTensor(state) action, log_prob policy.sample_action(state_tensor) action_idx torch.argmax(action).item() # 简化处理 next_state, reward, done env.step(action_idx) replay_buffer.append((state, action_idx, reward, next_state, done)) # 训练步骤 if len(replay_buffer) batch_size: batch random.sample(replay_buffer, batch_size) states, actions, rewards, next_states, dones zip(*batch) states torch.FloatTensor(np.array(states)) actions torch.LongTensor(np.array(actions)) rewards torch.FloatTensor(np.array(rewards)) next_states torch.FloatTensor(np.array(next_states)) dones torch.FloatTensor(np.array(dones)) # Q函数更新 with torch.no_grad(): next_actions, next_log_probs policy.sample_action(next_states) q1_next q1(next_states, next_actions) q2_next q2(next_states, next_actions) q_next torch.min(q1_next, q2_next) - alpha * next_log_probs target_q rewards 0.99 * (1 - dones) * q_next.squeeze() current_q1 q1(states, actions) current_q2 q2(states, actions) q1_loss nn.MSELoss()(current_q1.squeeze(), target_q) q2_loss nn.MSELoss()(current_q2.squeeze(), target_q) q_loss q1_loss q2_loss q_optim.zero_grad() q_loss.backward() q_optim.step() # 策略更新 new_actions, new_log_probs policy.sample_action(states) q1_new q1(states, new_actions) q2_new q2(states, new_actions) q_new torch.min(q1_new, q2_new) policy_loss (alpha * new_log_probs - q_new).mean() policy_optim.zero_grad() policy_loss.backward() policy_optim.step() episode_reward reward state next_state if done: break if ep % 50 0: print(fEpisode {ep}, Reward: {episode_reward:.1f})5. 温度参数α的调节艺术α值的选择直接影响智能体的探索行为高α值如1.0智能体像好奇宝宝愿意尝试各种路径即使看起来不是最优的低α值如0.1智能体变得功利快速锁定看似最优的路径自动调节的αSAC通常会自动调整α保持策略熵在一个目标值附近我们可以通过实验观察不同α值的效果# 比较不同α值的效果 for alpha in [0.1, 0.5, 1.0]: print(f\nTraining with alpha{alpha}) env GridWorld() train_sac(env, alphaalpha, episodes300)在实际运行中你会发现α0.1时智能体倾向于选择最短路径但可能掉入悬崖α1.0时智能体会探索更多路径最终可能发现更安全的路线适中的α值如0.5能在探索和利用间取得平衡6. 可视化熵在训练中的变化为了更直观理解熵的作用我们可以记录训练过程中策略熵的变化import matplotlib.pyplot as plt def plot_entropy_during_training(): alphas [0.1, 0.5, 1.0] entropy_records {alpha: [] for alpha in alphas} for alpha in alphas: env GridWorld() policy PolicyNetwork(2, 4) for _ in range(100): state env.reset() state_tensor torch.FloatTensor(state) _, log_prob policy.sample_action(state_tensor) entropy -log_prob.exp() * log_prob # 近似计算熵 entropy_records[alpha].append(entropy.item()) plt.figure(figsize(10, 6)) for alpha, entropies in entropy_records.items(): plt.plot(entropies, labelfα{alpha}) plt.xlabel(Training Steps) plt.ylabel(Policy Entropy) plt.title(Policy Entropy Under Different α Values) plt.legend() plt.show() plot_entropy_during_training()这张图会清晰展示高α值对应更高的策略熵更多探索随着训练进行所有策略的熵都会逐渐降低学会利用但高α值的策略始终保持着更高的随机性7. SAC在实际问题中的优势通过这个简化实现我们可以看到SAC相比传统强化学习算法的优势更鲁棒的策略学习不会轻易陷入局部最优自动平衡探索与利用通过熵正则化自然实现适应复杂环境在多模态奖励场景下表现优异例如在机器人控制中SAC能让机器人尝试多种行走方式而不仅限于一种固定步态在遇到障碍时能灵活切换策略持续学习新技能而不忘记已有能力# 实际应用中的SAC通常包含更多优化 class AdvancedSAC: def __init__(self, state_dim, action_dim): # 双Q网络和目标网络 self.q1 QNetwork(state_dim, action_dim) self.q2 QNetwork(state_dim, action_dim) self.target_q1 QNetwork(state_dim, action_dim) self.target_q2 QNetwork(state_dim, action_dim) # 自动调节的温度参数α self.target_entropy -action_dim # 常见设置 self.log_alpha torch.zeros(1, requires_gradTrue) self.alpha_optim optim.Adam([self.log_alpha], lr3e-4) def update_alpha(self, log_probs): alpha_loss -(self.log_alpha * (log_probs self.target_entropy).detach()).mean() self.alpha_optim.zero_grad() alpha_loss.backward() self.alpha_optim.step() return self.log_alpha.exp().item()这个进阶实现展示了SAC在实际应用中的常见组件包括目标网络和自动温度调节这些都是确保算法稳定性的关键。
http://www.gsyq.cn/news/1374281.html

相关文章:

  • 火箭设计仿真软件终极指南:OpenRocket如何让每个人都能设计专业火箭
  • C51工具覆盖分析机制与8051内存优化实践
  • 征集暑期亲子研学北京的靠谱机构,要求经验多,专业程度高 - 品牌2025
  • 大麦抢票终极指南:如何用自动化工具轻松获取热门演唱会门票
  • 如何在macOS上快速创建PDF文件:终极虚拟打印机解决方案
  • 如何安全烧录系统镜像:Balena Etcher免费开源工具的终极指南
  • Token CSS高级技巧:如何扩展自定义设计令牌和主题的终极指南
  • 如何将普通汽车升级为智能驾驶伙伴:openpilot开源项目深度解析
  • React Native 开发者必读:react-native-bottom-sheet-behavior 源码解析与自定义扩展
  • 避坑指南:VirtualBox装Ubuntu 22.04时,你可能忽略的3个关键设置(内存/磁盘/增强功能)
  • 在Ubuntu 18.04上用RTX 3060复现ICCV 2021 PMF:一个4天11小时的踩坑与加速训练实录
  • 2026年靠谱的杭州工装装修施工榜单优选公司 - 品牌宣传支持者
  • 别再让SSD越用越慢了!手把手教你检查并开启TRIM功能(Linux/Windows保姆级教程)
  • 北京研学机构哪家好?住宿条件好的青少年北京研学机构推荐 - 品牌2025
  • 用100行PyTorch代码实现扩散模型:从理论到实战的完整指南
  • 如何从零开始构建AI社会模拟:AgentSociety终极指南
  • 小电视空降助手:告别B站广告烦恼的终极解决方案
  • CSharpVerbalExpressions核心API详解:StartOfLine、Then、Maybe等方法的终极教程
  • Pushd新手入门:iOS/Android/Windows推送协议一键集成完整指南
  • 10个Promise核心概念解析:Async-JavaScript-Cheatsheet项目深度教程
  • GitHub Gem核心命令详解:10个必学的高效GitHub操作技巧
  • EasyDoc深度解析:如何将PDF、Word文档智能转换为JSON格式的终极指南
  • defx.nvim 高级操作技巧:50+动作命令提升文件管理效率
  • ARM SME指令集:LD1B与LD1D向量加载技术详解
  • C++打印 vector的几种方法小结
  • 如何通过Pushd API实现用户订阅管理?完整指南
  • 保姆级教程:手把手教你将DIOR遥感数据集转为YOLOv5可用的格式(附完整Python脚本)
  • ARM SVE指令集:UQINCH/UQINCW向量饱和递增详解
  • 2026保安岗亭品牌权威度评测报告:可移动垃圾房、台州岗亭、吸烟亭、嘉兴岗亭、杭州岗亭、浙江岗亭、湖州岗亭、移动卫生间选择指南 - 优质品牌商家
  • 解锁网络资源下载:res-downloader跨平台资源嗅探解决方案