当前位置: 首页 > news >正文

hook 工具随笔

hook 是 pytorch 中的一个工具,主要作用是 在模型前向/反向传播过程之前/之后执行一些自定义操作

比如说: 打印查看一些模型参数;修改梯度等等

基本形态

基本的形式是:

# 自定义操作
def forward_hook(module, input, output):print(f"Module: {module}")print(f"Input: {input}")print(f"Output: {output}")# 注册 hook
hook = model.fc1.register_forward_hook(forward_hook)# 传递
output = model(input_data)# 移除 hook
hook.remove()

其中 forward_hook 是自定义的函数,其 输入参数 随 register_xx_hook 的不同而有所不同,常见的 register 类型如下表所示:

Hook 类型 注册函数 函数签名 调用时机
forward hook register_forward_hook(hook) hook(module, input, output) 模块执行完前向传播后调用
forward pre-hook register_forward_pre_hook(hook) hook(module, input) 模块执行前向传播前调用
full backward hook register_full_backward_hook(hook) hook(module, grad_input, grad_output) 模块反向传播时调用,兼容复杂的 autograd 结构
Tensor 级别 hook tensor.register_hook(hook) hook(grad) 注册在 Tensor 上,在该 Tensor 的梯度被计算时调用

实现机制

hook 实际上是基于 Autograd 引擎的回调机制 实现的。大致流程如下:

  1. 每个 nn.Module 在执行 forward() 时,都会在 Autograd 图中注册节点(Function)。
  2. 当你注册 hook 时,PyTorch 会把你的函数加入到这些节点的 回调列表。
  3. 前向传播:如果是 forward_pre_hook → 在执行 module.forward() 前调用;如果是 forward_hook → 在 forward() 结束后调用。
  4. 反向传播:当 Autograd 计算梯度经过该节点时,会触发该节点的 backward hook。

内部实现类似这样:

# 伪代码示例
for pre_hook in module._forward_pre_hooks:x = pre_hook(module, x)out = module.forward(x)for hook in module._forward_hooks:out = hook(module, x, out)

实例介绍

这段代码来自论文中的一个实验,其核心思想可以理解为:针对两个不同的 query,它们在语义上应当得到相同的 answer。研究者通过将第一个 query 在前向传播过程中生成的中间隐藏状态,替换到第二个 query 的对应层位置中,然后再让模型继续生成输出。若此时模型仍然能够产生目标 answer,就说明模型的推理结果更多地依赖于其内部的表征与推导能力 😎 ,而非仅仅依靠上下文记忆或表面模式匹配

这个过程正是通过 forward_hook 实现的——在前向传播结束后,动态地修改指定层的输出,从而实现对模型中间表示的干预与验证。

def cross_query_semantic_patching(model, tokenizer, device, queries, position, layer):# initialize counterssuccess_counts = 0total_counts = 0for source_prompt, target_prompt, expected_e3 in tqdm(queries):# get the source hidden statesdecoder_temp = tokenizer([source_prompt], return_tensors="pt", padding=True)decoder_input_ids, decoder_attention_mask = decoder_temp["input_ids"], decoder_temp["attention_mask"]decoder_input_ids, decoder_attention_mask = decoder_input_ids.to(device), decoder_attention_mask.to(device)with torch.no_grad():outputs1 = model(input_ids=decoder_input_ids,attention_mask=decoder_attention_mask,output_hidden_states=True)hidden_states_batch = outputs1.hidden_states  # [1+num_layers, batch_size, seq_len, hidden_size]# replace the hidden states of the target position with the source hidden statesdef hook_fn(module, input, output):# output ([batch_size, seq_len, hidden_size], ...)main_output = output[0].clone()main_output[0, position, :] = hidden_states_batch[layer][0, position, :]return (main_output,) + output[1:]# 注册前向钩子handle = model.transformer.h[layer - 1].register_forward_hook(hook_fn) # [num_layers, batch_size, seq_len, hidden_size]# target promptdecoder_temp = tokenizer([target_prompt], return_tensors="pt", padding=True)decoder_input_ids, decoder_attention_mask = decoder_temp["input_ids"], decoder_temp["attention_mask"]target_decoder_input_ids, target_decoder_attention_mask = decoder_input_ids.to(device), decoder_attention_mask.to(device)with torch.no_grad():outputs2 = model(input_ids=target_decoder_input_ids,attention_mask=target_decoder_attention_mask,# output_hidden_states=True)# 移除钩子,避免影响后续推理handle.remove()# decode the predicted tokenlogits = outputs2.logits  # [batch_size, seq_len, vocab_size]predicted_token_ids = torch.argmax(logits, dim=-1)  # [batch_size, seq_len]decoded_text = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)decoded_token = decoded_text[0].split()[-1]         # check if the decoded token is the expected tokentotal_counts += 1if decoded_token == expected_e3:success_counts += 1return success_counts/total_counts

模型结构介绍

主要聚焦于这一部分:

hidden_states_batch = outputs1.hidden_states  # [1+num_layers, batch_size, seq_len, hidden_size]# replace the hidden states of the target position with the source hidden states
def hook_fn(module, input, output):# output ([batch_size, seq_len, hidden_size], ...)main_output = output[0].clone()main_output[0, position, :] = hidden_states_batch[layer][0, position, :]return (main_output,) + output[1:]# 注册前向钩子
handle = model.transformer.h[layer - 1].register_forward_hook(hook_fn) # [num_layers, batch_size, seq_len, hidden_size]

这里主要是回顾一下模型结构:

hidden_states_batch 是模型在前向传播过程中保存的各层隐藏状态,其维度为 [1 + num_layers, batch_size, seq_len, hidden_size],其中第 0 个元素对应 embedding 层的输出,而后续的每个元素 hidden_states_batch[layer] 分别对应第 layer 个 Transformer 层的隐藏状态;model.transformer.h 是模型中所有 Transformer 层的列表,因此 model.transformer.h[layer - 1] 就表示第 layer 层,索引从 0 开始。

在 hook_fn 中,input 表示传入该层的张量,output 表示该层前向传播后的输出结果。由于这里的目标是修改该层在前向传播后的输出隐藏状态,而不是输入,因此选择对 output 进行处理;output 通常是一个元组 (hidden_states, other_outputs),其中 output[0] 是该层的主要输出张量,即 [batch_size, seq_len, hidden_size] 的隐藏状态,而其余部分可能包含注意力缓存或其他附加信息,之后的替换就很好理解。

http://www.gsyq.cn/news/35146.html

相关文章:

  • 堆和栈的生命周期对于代码的影响
  • pgsql索引冗余分析
  • 详细介绍:Leetcode 3700. Number of ZigZag Arrays II
  • 老旧环境torch版本(0.4.1)环境配置总结
  • 代码大全阅读笔记3
  • 通过中国信通院SQL质量管理最高等级评测,天翼云TeleDB引领数据库管理新标准!
  • 代码大阅读笔记
  • 第二次软件基础作业
  • 实用指南:从0死磕全栈之Next.js Server Actions 入门实战:在服务端安全执行逻辑,告别 API 路由!
  • 重塑生产力:天翼云全球首发RaaS,开启“机器人即服务”商业时代!
  • Sequence2Sequence - -一叶知秋
  • 第177天:信息收集篇自动项目本机导出外部打点域内通讯PillagerBloodHound
  • 如何在Linux中,为Flatpak版本的Edge浏览器导入证书
  • Java 集合 “Map(1)”面试清单(含超通俗生活案例与深度理解) - 教程
  • 2025 年铸铁井盖生产厂家最新推荐榜,技术实力与市场口碑深度解析防沉降球墨/防沉降/电力/双层铸铁井盖公司推荐
  • Bilidown Setup 1.2.7下载
  • 0291-Nand-实现基础逻辑门(一)
  • NASM下载和安装教程(附安装包)
  • 0292-Nand-实现基础逻辑门(二)
  • 单点登录SSO是怎么实现的?
  • 2025年上海房产继承律师权威推荐榜单:继承律师/离婚律师/婚姻律师事务所精选
  • autotiny下载_v3.0.0.2
  • Python嵌套_多条件判断 _ 对象今天会生气吗 II
  • 解析视频融合平台EasyCVR的分析平台技术如何成为“全域视频管理中台”
  • 2025年10月logo/VI设计专业公司权威推荐排行榜:探索年最佳设计服务
  • 深入解析:GitPuk入门教程:安装及使用指南,一文轻松上手
  • 完整教程:Linux启动流程与字符设备驱动详解 - 从bootloader到驱动开发
  • 学术会议会议合集 | 电子信息工程、计算机技术、文学、人文发展、数字经济等EI会议合集
  • 2025 年弯管机生产厂家最新推荐榜,技术实力与市场口碑深度解析且高性能与可靠性兼具四轴/双轴/双层膜弯管机公司推荐
  • 2025年智慧厕所厂家权威推荐榜单:智慧厕所智能水表/智慧公厕系统/智慧厕所源头厂家精选