当前位置:首页 > news > 正文

如何在消费级 GPU 上优雅地跑 PPO:一次绕过 PyTorch 优化器坑的实战记录

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
PPO 训练在 DirectML 后端上的“平民级”完美运行脚本
无需 NVIDIA CUDA,消费级集成显卡/AMD/Intel 都能跑。
绕过 PyTorch 优化器内部不兼容算子,实现纯 GPU 训练。
博客展示用:自动安装依赖、检测设备、无警告无 fallback。
"""

import subprocess
import sys
import os
import importlib

def install_package(package):
    """Install or upgrade one package with pip, reporting failures instead of raising.

    pip is invoked through ``sys.executable`` so the package is installed into
    the same interpreter environment that is running this script.

    Args:
        package: pip requirement name to install/upgrade.

    Returns:
        True when pip exits successfully; False when pip exits with a non-zero
        status (the CalledProcessError is printed, not propagated).
    """
    print(f"正在安装: {package} ...")
    pip_cmd = [sys.executable, '-m', 'pip', 'install', '--upgrade', package]
    try:
        subprocess.check_call(pip_cmd)
    except subprocess.CalledProcessError as err:
        print(f"安装 {package} 失败: {err}")
        return False
    return True

def install_requirements(packages=None):
    """Upgrade pip, then install the script's dependencies one at a time.

    Args:
        packages: optional iterable of pip package names to install. When
            omitted, the default set ``numpy``, ``psutil``, ``torch-directml``
            is installed (the original behavior).

    A failed pip self-upgrade is reported but not fatal: an older pip can
    usually still install the dependencies, and any real problem surfaces in
    the per-package step below. If any package fails to install, a manual
    install hint is printed and the process exits with status 1.
    """
    print("升级 pip...")
    try:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--upgrade', 'pip'])
    except subprocess.CalledProcessError as e:
        # Handle this like install_package does instead of dying with an
        # uncaught traceback (e.g. in environments where pip may not modify itself).
        print(f"升级 pip 失败(将继续安装依赖): {e}")

    if packages is None:
        # Per the original note, torch-directml pulls in torch and torchvision.
        packages = ['numpy', 'psutil', 'torch-directml']
    for pkg in packages:
        if not install_package(pkg):
            print(f"请手动安装 {pkg} 后再运行脚本: pip install {pkg}")
            sys.exit(1)

# Probe each runtime dependency by its *import* name and record the matching
# *pip* name of anything missing, so it can be installed before the real
# imports further down.
_DEPENDENCIES = {
    'numpy': 'numpy',
    'psutil': 'psutil',
    'torch_directml': 'torch-directml',
}
missing = []
for module_name, pip_name in _DEPENDENCIES.items():
    try:
        importlib.import_module(module_name)
    except ImportError:
        missing.append(pip_name)

if missing:
    print("检测到缺失依赖:", missing)
    install_requirements()
    # The import system caches directory listings; packages that pip just
    # wrote during this same run may stay invisible to the imports below
    # unless those finder caches are invalidated.
    importlib.invalidate_caches()

# 现在导入
import numpy as np
import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import time
import logging

# Optional DirectML backend: if it cannot be imported, HAS_DIRECTML stays
# False and device selection below uses the CPU instead.
HAS_DIRECTML = False
try:
    import torch_directml
except ImportError:
    pass
else:
    HAS_DIRECTML = True

# Root logging configuration plus the script's named logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("PPO_Demo")

def get_device():
    """Return the DirectML device when it is usable, otherwise the CPU device.

    DirectML is attempted only if ``torch_directml`` was imported. The device
    is smoke-tested by allocating a one-element tensor on it; any exception
    raised while obtaining or probing it is logged as a warning, and the CPU
    is used instead.

    Returns:
        The object returned by ``torch_directml.device()`` on success, else
        ``torch.device("cpu")``.
    """
    if HAS_DIRECTML:
        try:
            dml_device = torch_directml.device()
            torch.zeros(1, device=dml_device)  # smoke-test allocation
        except Exception as exc:
            logger.warning(f"DirectML 初始化失败: {exc},将使用 CPU")
        else:
            logger.info(f"✅ 使用 DirectML 设备: {dml_device} (消费级显卡/集成显卡)")
            return dml_device
    logger.warning("DirectML 不可用,使用 CPU(速度较慢,但不会报错)")
    return torch.device("cpu")

# Resolve the compute device once at import time; models and tensors created
# below are placed on it.
device = get_device()

# Dimensions of the synthetic environment.
STATE_DIM = 8      # length of each observation vector fed to the networks
NUM_ACTIONS = 6    # size of the discrete action space

class Actor(nn.Module):
    """Policy network mapping a state vector to unnormalized action logits.

    Layout: Linear(STATE_DIM, 64) -> ReLU -> Linear(64, 64) -> ReLU ->
    Linear(64, NUM_ACTIONS). The output is raw logits; callers apply softmax.
    """

    def __init__(self):
        super().__init__()
        width = 64
        layers = [
            nn.Linear(STATE_DIM, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, NUM_ACTIONS),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

class Critic(nn.Module):
    """Value network mapping a state vector to a single scalar estimate.

    Layout: Linear(STATE_DIM, 64) -> ReLU -> Linear(64, 1); the output keeps a
    trailing dimension of size 1.
    """

    def __init__(self):
        super().__init__()
        width = 64
        layers = [
            nn.Linear(STATE_DIM, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        value = self.net(x)
        return value

# Hand-written optimizer (SGD with momentum) that avoids the incompatible
# internal ops used by torch.optim on this backend.
class ManualOptimizer:
    """Minimal SGD-with-momentum optimizer built only from elementary tensor ops.

    Per parameter ``p`` with gradient ``g`` and velocity buffer ``v``::

        v <- momentum * v - lr * g
        p <- p + v

    Args:
        model: ``nn.Module`` whose parameters are optimized.
        lr: learning rate.
        momentum: momentum coefficient applied to the velocity buffer.
    """

    def __init__(self, model, lr=3e-4, momentum=0.9):
        self.model = model
        self.lr = lr
        self.momentum = momentum
        # One zero-initialized velocity buffer per trainable parameter, keyed
        # by the parameter's qualified name.
        self.momentum_buffers = {
            name: torch.zeros_like(param.data)
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def step(self):
        """Apply one momentum-SGD update to every parameter that has a gradient."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if param.grad is None:
                    continue
                # Look the buffer up by name rather than pairing positionally
                # with named_parameters(): the buffer dict omits frozen
                # parameters, so positional pairing would mis-align buffers
                # and skip trailing trainable parameters.
                buf = self.momentum_buffers.get(name)
                if buf is None:
                    # A parameter that was not trainable at construction time
                    # (e.g. unfrozen later) starts with zero velocity.
                    buf = torch.zeros_like(param)
                new_buf = self.momentum * buf - self.lr * param.grad
                self.momentum_buffers[name] = new_buf
                param.add_(new_buf)

    def zero_grad(self):
        """Reset every existing gradient tensor to zero in place."""
        for param in self.model.parameters():
            if param.grad is not None:
                param.grad.detach_()
                param.grad.zero_()

class SimplePPO:
    """Minimal PPO agent trained on a synthetic random environment.

    Holds an actor (policy) and a critic (value) network, each updated by its
    own ManualOptimizer, plus an on-policy rollout buffer kept as Python lists.

    Args:
        gamma: discount factor.
        gae_lambda: GAE lambda.
        clip_epsilon: PPO probability-ratio clipping range.
        lr: learning rate for both optimizers.
    The defaults reproduce the original hard-coded settings.
    """

    def __init__(self, gamma=0.99, gae_lambda=0.95, clip_epsilon=0.2, lr=3e-4):
        self.actor = Actor().to(device)
        self.critic = Critic().to(device)
        self.actor_opt = ManualOptimizer(self.actor, lr=lr, momentum=0.9)
        self.critic_opt = ManualOptimizer(self.critic, lr=lr, momentum=0.9)
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon

        # Rollout buffer of host-side values (numpy arrays / scalars); it is
        # converted to device tensors inside update().
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.dones = []
        self.log_probs = []

    def get_action(self, state):
        """Sample one action for a single state.

        Returns:
            (action, log_prob): the sampled action index and
            ``log(pi(action | state) + 1e-8)`` computed on the host.
        """
        state_t = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        # Acting needs no gradients; avoid building an autograd graph per step.
        with torch.no_grad():
            logits = self.actor(state_t)
            probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
        action = np.random.choice(NUM_ACTIONS, p=probs)
        log_prob = np.log(probs[action] + 1e-8)
        return action, log_prob

    def collect_experience(self, num_steps=500):
        """Append ``num_steps`` synthetic transitions (random states, noise rewards)."""
        for _ in range(num_steps):
            state = np.random.rand(STATE_DIM).astype(np.float32)
            action, logp = self.get_action(state)
            next_state = np.random.rand(STATE_DIM).astype(np.float32)
            reward = np.random.randn() * 0.1
            done = False

            self.states.append(state)
            self.actions.append(action)
            self.rewards.append(reward)
            self.next_states.append(next_state)
            self.dones.append(done)
            self.log_probs.append(logp)
        logger.info(f"收集了 {num_steps} 条经验")

    @staticmethod
    def _normalize(advantages):
        """Standardize to zero mean / unit std; return the input unchanged if std <= 1e-8."""
        adv_std = advantages.std()
        if adv_std > 1e-8:
            return (advantages - advantages.mean()) / adv_std
        return advantages

    def compute_gae(self, values, next_values, normalize=True):
        """Generalized Advantage Estimation over the buffered rewards and dones.

        Args:
            values: V(s_t) for each buffered step (indexable, length T).
            next_values: V(s_{t+1}) for each buffered step (length T).
            normalize: if True (the original behavior), return standardized
                advantages; pass False to get the raw GAE values, which are
                what value targets must be built from.

        Returns:
            float64 numpy array of length T.
        """
        T = len(values)
        advantages = np.zeros(T)
        gae = 0.0
        for t in reversed(range(T)):
            not_done = 1 - self.dones[t]
            delta = self.rewards[t] + self.gamma * next_values[t] * not_done - values[t]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[t] = gae
        if normalize and T > 0:
            return self._normalize(advantages)
        return advantages

    def _clear_buffer(self):
        """Drop all buffered transitions (PPO is on-policy: each rollout is used once)."""
        for buf in (self.states, self.actions, self.rewards,
                    self.next_states, self.dones, self.log_probs):
            buf.clear()

    def update(self, epochs=3, batch_size=64):
        """Run PPO epochs over the buffered rollout, then clear the buffer.

        Does nothing (and keeps the buffer) when fewer than ``batch_size``
        transitions are buffered.
        """
        if len(self.states) < batch_size:
            return

        states_t = torch.tensor(np.array(self.states), dtype=torch.float32, device=device)
        actions_t = torch.tensor(self.actions, dtype=torch.long, device=device)
        old_log_probs_t = torch.tensor(self.log_probs, dtype=torch.float32, device=device)

        with torch.no_grad():
            # squeeze(-1) rather than squeeze(): keeps a 1-D array even when
            # only one transition is buffered.
            values = self.critic(states_t).squeeze(-1).cpu().numpy()
            next_states_t = torch.tensor(np.array(self.next_states), dtype=torch.float32, device=device)
            next_values = self.critic(next_states_t).squeeze(-1).cpu().numpy()

        # Value targets come from the *raw* GAE advantages; only the policy
        # term uses the normalized ones. Adding normalized advantages to the
        # values would give the critic targets on the wrong scale.
        raw_advantages = self.compute_gae(values, next_values, normalize=False)
        returns_np = raw_advantages + values
        advantages_np = self._normalize(raw_advantages)
        advantages_t = torch.tensor(advantages_np, dtype=torch.float32, device=device)
        returns_t = torch.tensor(returns_np, dtype=torch.float32, device=device)

        dataset_size = len(self.states)
        indices = list(range(dataset_size))

        for _ in range(epochs):
            random.shuffle(indices)
            for start in range(0, dataset_size, batch_size):
                idx = indices[start:start + batch_size]
                batch_states = states_t[idx]
                batch_actions = actions_t[idx]
                batch_adv = advantages_t[idx]
                batch_ret = returns_t[idx]
                batch_old_logp = old_log_probs_t[idx]

                logits = self.actor(batch_states)
                probs = F.softmax(logits, dim=-1)
                action_probs = torch.gather(probs, 1, batch_actions.unsqueeze(1)).squeeze(1)
                new_log_probs = torch.log(action_probs + 1e-8)
                entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1).mean()

                # Clipped surrogate objective.
                ratio = torch.exp(new_log_probs - batch_old_logp)
                surr1 = ratio * batch_adv
                surr2 = torch.clamp(ratio, 1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon) * batch_adv
                actor_loss = -torch.min(surr1, surr2).mean()

                # squeeze(-1) keeps shape (B,) for a final minibatch of size 1,
                # matching batch_ret instead of broadcasting a 0-d tensor.
                values_pred = self.critic(batch_states).squeeze(-1)
                value_loss = F.mse_loss(values_pred, batch_ret)

                total_loss = actor_loss + 0.5 * value_loss - 0.01 * entropy

                self.actor_opt.zero_grad()
                self.critic_opt.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
                self.actor_opt.step()
                self.critic_opt.step()

        self._clear_buffer()
        logger.info("PPO 更新完成")

if __name__ == "__main__":
    # Demo driver: three collect/update rounds, then one inference call.
    print("\n🚀 消费级电脑 PPO 训练演示 (DirectML + 手动优化器,无警告无 fallback)\n")
    ppo = SimplePPO()

    for i in range(3):
        print(f"\n--- 迭代 {i+1} ---")
        ppo.collect_experience(num_steps=200)
        ppo.update(epochs=2, batch_size=64)

    test_state = np.random.rand(STATE_DIM).astype(np.float32)
    action, _ = ppo.get_action(test_state)
    print(f"\n✅ 测试推理成功,输入状态 → 动作 {action}")

    # get_device() falls back to torch.device("cpu") when DirectML is missing
    # or fails its probe; only claim GPU execution if that did not happen.
    if device.type != "cpu":
        print("\n🎉 脚本运行完毕:无任何警告,所有计算均在 DirectML GPU 上完成(采样/GAE 在 CPU,不影响性能)。")
        print(" 手动优化器完美绕过了 PyTorch 优化器内部不兼容 DirectML 的算子。")
    else:
        print("\n🎉 脚本运行完毕:DirectML 不可用,本次所有计算均在 CPU 上完成。")

友情提示:请确保使用 Python 3.12 运行本脚本。

http://www.gsyq.cn/news/1511083.html

相关文章:

  • MC68HC16Z1微控制器:模块化架构、CPU16核心与低功耗设计深度解析
  • GitHub功能全揭秘:涵盖代码创作、安全等,FPS.cob带来独特游戏开发体验!
  • 瑞沃金刚S1-YCF25耐用抗造,破解工地运输难题 - 资讯纵览
  • Windows平台C# WinForm人脸检测与识别演示工程(含阅面SDK接入和Token配置指南)
  • 2026昆明十佳医疗纠纷律师精选推荐|5家专攻医患维权靠谱团队盘点 - GEO真实测评
  • 实战踩坑:在Qt Widgets和QML混编项目里,如何正确使用Q_PROPERTY实现数据绑定与同步?
  • 2026推荐:游泳池高危证检测|卫生证水质检测|房屋安全性检测 - 公共场所卫生检测
  • 腾讯会议同传工具推荐
  • 56F8357混合信号控制器:DSP与MCU融合架构解析与电机控制实战
  • 百度地图infoBox弹窗组件:带关闭按钮和多图示例的轻量级实现
  • 深入解析D3D8到D3D9转换引擎:经典游戏兼容性解决方案
  • 良品率提至99.3%:高周波塑胶熔接机汽车内饰案例 - 资讯快报
  • ZigBee与IEEE 802.15.4技术解析:从低功耗无线原理到飞思卡尔平台实战
  • Umi-OCR:如何实现高效离线文字识别与自动化处理?
  • 5个英雄联盟智能工具实战技巧:如何高效提升游戏体验的完整指南
  • 3分钟找回遗忘的Navicat数据库密码:开源解密工具完全指南
  • 如何高效使用开源工具:完整实战指南 - N_m3u8DL-CLI-SimpleG
  • 怎样选西安回收黄金门店推荐|五大正规商家测评,禹竞名奢汇高价靠谱排首位 - 名奢变现站
  • 2026Q3 深圳代理记账公司权威推荐榜:6 大本土企业实测财税服务机构(靠谱、正规、资质强) - 品牌智鉴榜
  • 微信聊天记录丢失怎么办?3步教你用WechatBakTool实现完整备份与恢复
  • 基于NXP HAP SDK的嵌入式HomeKit设备开发:安全架构与硬件接口详解
  • 从‘归档焦虑’到从容应对:给你的KingbaseES数据库WAL日志配置一份保姆级调优与监控方案
  • MC68HC16S2异常处理与SRAM设计:嵌入式系统可靠性的硬件基石
  • MPC823嵌入式SoC:双核异构架构与高集成外设的经典设计解析
  • Android进程永生技术深度解析:基于Linux内核特性的终极保活方案实现
  • 2026 德宏梁河县黄金回收攻略|五大正规商家汇总 全域免费上门不踩坑 - 奢佳美黄金珠宝
  • 人工智能代码数量宣称盛行,成果指标才是关键?
  • 深入解析高集成度工业微处理器MCF5373:架构、外设与实战设计
  • 三维真实地形下的蚁群路径寻优MATLAB工具包(含高程数据与可视化)
  • Android进程管理:Linux内核级保活技术深度解析