flex_model.core.HookFunction
- class flex_model.core.HookFunction(module_name: str, expected_shape: Tuple[int | None, ...] | None = None, editing_function: Callable | None = None, unpack_idx: int = 0)
Function which retrieves and edits activations in a PyTorch nn.Module.
The user provides the module_name of the target submodule. The user can optionally pass in an editing_function containing arbitrarily complex Python code, which will be used to edit the full submodule activation tensor. If certain dimensions of the activation tensor are expected to be sharded over distributed workers, the user must also provide an expected_shape hint so that the full activation tensor can be assembled.
- Variables:
module_name (str) – Name of the nn.Module submodule to hook into.
expected_shape – Shape of the full activation tensor. Only the sharded dimensions need to be provided; other dimensions can be annotated as None and will be auto-completed.
editing_function – Function which is run on the full activation tensor and returns the edited activation tensor. Global contexts like the save context and trainable modules are available to the editing function at runtime.
save_ctx – Global save context that is exposed to the editing_function.
modules – Global trainable modules that are exposed to the editing_function.
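The expected_shape hint can be read as follows: dimensions given as None are copied from the local shard, and dimensions given explicitly describe the full size after gathering. The sketch below illustrates that reading with a hypothetical helper, complete_expected_shape; it is not FlexModel's implementation. It assumes shards are split evenly and that a dimension counts as sharded when its full size differs from the local size. After gathering, the shards would then be concatenated along each sharded dimension.

```python
from typing import Optional, Sequence, Tuple


def complete_expected_shape(
    local_shape: Sequence[int],
    expected_shape: Sequence[Optional[int]],
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical helper: fill None entries and report sharded dims.

    Returns (full_shape, sharded_dims).
    """
    if len(local_shape) != len(expected_shape):
        raise ValueError("expected_shape must have the same rank as the activation")

    full_shape = []
    sharded_dims = []
    for dim, (local, expected) in enumerate(zip(local_shape, expected_shape)):
        if expected is None:
            # Unannotated dimension: assume it is not sharded, keep the local size.
            full_shape.append(local)
            continue
        # Assumption: shards are evenly sized, so the full size is a multiple.
        if expected % local != 0:
            raise ValueError(f"dim {dim}: {expected} is not a multiple of {local}")
        full_shape.append(expected)
        if expected != local:
            sharded_dims.append(dim)
    return tuple(full_shape), tuple(sharded_dims)


# A (4, 512, 5120) activation split over 4 workers along the last dimension.
full_shape, sharded = complete_expected_shape((4, 512, 1280), (None, None, 5120))
# full_shape == (4, 512, 5120), sharded == (2,)
```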
- Note:
save_ctx and modules are populated when the HookFunction is registered with a FlexModel instance.
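Retrieving or editing a submodule's activation in PyTorch generally relies on forward hooks. The following is a minimal plain-PyTorch sketch of that mechanism, not FlexModel's own registration code. attach_editing_hook is a hypothetical helper that resolves the submodule name with nn.Module.get_submodule and returns the edited output from a forward hook, which PyTorch substitutes for the submodule's output. It deliberately ignores save_ctx, modules, and sharding.

```python
import torch
from torch import nn


def attach_editing_hook(model: nn.Module, module_name: str, editing_function):
    """Hypothetical helper: hook `module_name` and replace its output."""
    submodule = model.get_submodule(module_name)

    def hook(module, args, output):
        # A non-None return value from a forward hook replaces the output.
        return editing_function(module, output)

    # The returned handle can later be used to remove the hook.
    return submodule.register_forward_hook(hook)


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
captured = {}


def save_and_zero(module, output):
    # Retrieve: keep a copy of the activation.
    captured["activation"] = output.detach().clone()
    # Edit: zero the activation before it reaches the next layer.
    return torch.zeros_like(output)


handle = attach_editing_hook(model, "1", save_and_zero)  # hook the ReLU
out = model(torch.randn(3, 4))
handle.remove()
```

Because the ReLU output is zeroed, the final layer's output reduces to its bias. Returning None from the hook instead would leave the output unchanged, which corresponds to a retrieve-only hook.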
Example:
    import torch
    from torch import Tensor

    from flex_model.core import HookFunction

    # Define editing function to be run on an activation tensor.
    def my_editing_function(current_module, inputs, save_ctx, modules) -> Tensor:
        # Cache data for later.
        _, s, _ = torch.svd(inputs)
        save_ctx.activation_singular_values = s

        # Edit the activation tensor: zero out entries not greater than 1.0.
        inputs = torch.where(inputs > 1.0, inputs, 0.0)

        # Apply a torch layer to the activation tensor.
        outputs = modules["linear_projection"](inputs)

        # Pass edited activation tensor to next layer.
        return outputs

    # Instantiate registration-ready hook function.
    my_hook_function = HookFunction(
        "my_model.layers.16.self_attention",
        expected_shape=(4, 512, 5120),
        editing_function=my_editing_function,
    )
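Since an editing function is ordinary Python, it can be exercised without a FlexModel by supplying stand-ins for save_ctx and modules. The sketch below assumes, as in the example above, that the function is called positionally as (current_module, inputs, save_ctx, modules). The SimpleNamespace save context, the nn.Linear stand-in, and the name clip_and_project are assumptions for testing only; torch.linalg.svdvals is used to compute only the singular values.

```python
from types import SimpleNamespace

import torch
from torch import Tensor, nn


def clip_and_project(current_module, inputs, save_ctx, modules) -> Tensor:
    # Cache singular values; svdvals batches over the leading dimension.
    save_ctx.activation_singular_values = torch.linalg.svdvals(inputs)
    # Zero out every entry that is not greater than 1.0.
    inputs = torch.where(inputs > 1.0, inputs, torch.zeros_like(inputs))
    # Project with a trainable layer looked up by name.
    return modules["linear_projection"](inputs)


# Stand-ins for what FlexModel would normally provide (assumptions).
save_ctx = SimpleNamespace()
modules = {"linear_projection": nn.Linear(64, 64)}
activation = torch.randn(2, 8, 64)

output = clip_and_project(None, activation, save_ctx, modules)
print(output.shape)                               # torch.Size([2, 8, 64])
print(save_ctx.activation_singular_values.shape)  # torch.Size([2, 8])
```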
- __init__(module_name: str, expected_shape: Tuple[int | None, ...] | None = None, editing_function: Callable | None = None, unpack_idx: int = 0) -> None
Initializes the instance by wrapping the editing_function.
- Parameters:
module_name (str) – Name of the nn.Module submodule to hook into.
expected_shape (Optional[Tuple[Optional[int], ...]]) – Shape of the full activation tensor.
editing_function (Optional[Callable]) – Function which edits the activation tensor.
hook_type (str) – Type of hook to register, e.g. forward or backward.
unpack_idx (int) – Index of the activation tensor in the unpacked layer output list. Before the editing function runs, the layer output is recursively unpacked and every torch.Tensor it contains is collected into a list; unpack_idx selects which tensor in that list the HookFunction treats as the activation tensor for downstream processing.
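The recursive unpacking that unpack_idx indexes into can be sketched as a depth-first walk over nested containers. The helper unpack_tensors below is hypothetical, and its traversal order (depth-first, left to right, dict values in insertion order) is an assumption; FlexModel's own unpacking may order or filter tensors differently.

```python
from typing import Any, List

import torch


def unpack_tensors(obj: Any) -> List[torch.Tensor]:
    """Hypothetical helper: collect every torch.Tensor in a nested output."""
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, dict):
        items = list(obj.values())
    elif isinstance(obj, (list, tuple)):
        items = list(obj)
    else:
        # Non-tensor leaves (None, ints, config objects) are skipped.
        return []
    tensors: List[torch.Tensor] = []
    for item in items:
        tensors.extend(unpack_tensors(item))
    return tensors


# A layer output shaped like (hidden_states, (key_cache, value_cache), None).
hidden = torch.randn(2, 5, 16)
key_cache, value_cache = torch.randn(2, 4, 5, 4), torch.randn(2, 4, 5, 4)
layer_output = (hidden, (key_cache, value_cache), None)

unpack_idx = 0  # pick the hidden states as the activation tensor
activation = unpack_tensors(layer_output)[unpack_idx]
```

With unpack_idx = 1, the key cache would be selected instead.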
Methods
__init__(module_name[, expected_shape, ...]) – Initializes the instance by wrapping the editing_function.