Python torch.nn.Module.register_forward_pre_hook用法及代码示例

class AddTrainableMask(ABC):

    _tensor_name: str
    
    def __init__(self):
        pass
    
    def __call__(self, module, inputs):

        setattr(module, self._tensor_name, self.apply_mask(module))

    def apply_mask(self, module):

        mask_train = getattr(module, self._tensor_name + "_mask_train")
        mask_fixed = getattr(module, self._tensor_name + "_mask_fixed")
        orig_weight = getattr(module, self._tensor_name + "_orig_weight")
        pruned_weight = mask_train * mask_fixed * orig_weight

        return pruned_weight

    @classmethod
    def apply(cls, module, name, mask_train, mask_fixed, *args, **kwargs):

        method = cls(*args, **kwargs)  
        method._tensor_name = name
        orig = getattr(module, name)

        module.register_parameter(name + "_mask_train", mask_train.to(dtype=orig.dtype))
        module.register_parameter(name + "_mask_fixed", mask_fixed.to(dtype=orig.dtype))
        module.register_parameter(name + "_orig_weight", orig)#这个权重参数的id指向原来的weight
        del module._parameters[name]

        setattr(module, name, method.apply_mask(module))#每次forwar之前都会调用这个钩子,所以每次weight的权重都被直接改了
        module.register_forward_pre_hook(method)

        return method


用法

register_forward_pre_hook(hook)

返回:
一个句柄,可用于通过调用 handle.remove() 删除添加的钩子

返回类型:

torch.utils.hooks.RemovableHandle

在模块上注册一个前向预挂钩。

每次调用 forward() 之前都会调用该钩子。它应该具有以下签名:

hook(module, input) -> None or modified input

输入仅包含给模块的位置参数。关键字参数不会传递给钩子,只会传递给 forward 。钩子可以修改输入。用户可以在钩子中返回一个元组或单个修改值。如果返回单个值,我们会将值包装到一个元组中(除非该值已经是一个元组)。


版权声明:本文为qq_43263543原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。