flex_model.core.HookFunction

class flex_model.core.HookFunction(module_name: str, expected_shape: Tuple[int | None, ...] | None = None, editing_function: Callable | None = None, unpack_idx: int = 0)

Function which retrieves/edits activations in a PyTorch nn.Module.

The user provides the module_name of the target submodule. The user can optionally pass in an editing_function containing arbitrary Python code, which is used to edit the full submodule activation tensor. If any dimensions of the activation tensor are expected to be sharded over distributed workers, the user must also provide an expected_shape hint so that the full activation tensor can be assembled.

Variables:
  • module_name (str) – Name of the nn.Module submodule to hook into.

  • expected_shape – Shape of the full activation tensor. Only the dimensions which are sharded need to be provided. Other dimensions can be annotated as None and will be auto-completed.

  • editing_function – Function which is run on the full activation tensor and returns the edited tensor. Global contexts like the save context and trainable modules are available for use in the editing function runtime.

  • save_ctx – Global save context that is exposed to the editing_function.

  • modules – Global trainable modules that are exposed to the editing_function.

Note:

save_ctx and modules are populated when the HookFunction is registered with a FlexModel instance.
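
The expected_shape auto-completion described above can be pictured with a short sketch. This is not the library's implementation: complete_shape is a hypothetical helper, and it assumes that dimensions annotated as None are simply taken from the shape of the locally held tensor, while provided dimensions give the full (unsharded) size.

```python
from typing import Optional, Tuple


def complete_shape(
    expected_shape: Tuple[Optional[int], ...],
    local_shape: Tuple[int, ...],
) -> Tuple[int, ...]:
    """Hypothetical helper: fill None entries from the local tensor shape."""
    if len(expected_shape) != len(local_shape):
        raise ValueError("expected_shape and local_shape must have the same rank")
    # Provided entries describe the full size of a sharded dimension;
    # None entries are assumed to be unsharded and are copied as-is.
    return tuple(
        local if expected is None else expected
        for expected, local in zip(expected_shape, local_shape)
    )


# Assumed scenario: the hidden dimension (5120) is split across two workers,
# so each worker holds a (4, 512, 2560) shard.
full_shape = complete_shape((None, None, 5120), (4, 512, 2560))
print(full_shape)
```

Under these assumptions, passing expected_shape=(None, None, 5120) carries the same information as the fully specified (4, 512, 5120) used in the example below.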

Example:

import torch
from torch import Tensor

from flex_model.core import HookFunction

# Define editing function to be run on an activation tensor.
def my_editing_function(current_module,
                        inputs,
                        save_ctx,
                        modules) -> Tensor:

    # Cache data for later.
    # torch.svd is deprecated; svdvals returns only the singular values.
    s = torch.linalg.svdvals(inputs)
    save_ctx.activation_singular_values = s

    # Edit the activation tensor.
    inputs = torch.where(inputs > 1.0, inputs, 0.0)

    # Apply a torch layer to the activation tensor.
    outputs = modules["linear_projection"](inputs)

    # Pass edited activation tensor to next layer.
    return outputs

# Instantiate registration-ready hook function.
my_hook_function = HookFunction(
    "my_model.layers.16.self_attention",
    expected_shape=(4, 512, 5120),
    editing_function=my_editing_function,
)

__init__(module_name: str, expected_shape: Tuple[int | None, ...] | None = None, editing_function: Callable | None = None, unpack_idx: int = 0) → None

Initializes the instance by wrapping the editing_function.

Parameters:
  • module_name (str) – Name of the nn.Module submodule to hook into.

  • expected_shape (Optional[Tuple[Optional[int], ...]]) – Shape of the full activation tensor.

  • editing_function (Optional[Callable]) – Function which edits the activation tensor.

  • hook_type (str) – Type of hook to register, e.g. forward, backward, etc.

  • unpack_idx (int) – Index of the activation tensor within the list of tensors unpacked from the layer output. Before the editing function runs, the layer output is recursively unpacked and every valid torch.Tensor found is collected into a list. unpack_idx selects which tensor in that list the HookFunction treats as the activation tensor for downstream processing.
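
To illustrate how unpack_idx might select a tensor, here is a minimal sketch of recursive unpacking. It is an illustration only: unpack_tensors is a hypothetical helper, the Tensor class is a stand-in for torch.Tensor so the snippet runs without PyTorch, and the depth-first, left-to-right ordering is an assumption rather than documented behaviour.

```python
from typing import Any, List


class Tensor:
    """Stand-in for torch.Tensor, used only to keep this sketch dependency-free."""

    def __init__(self, name: str) -> None:
        self.name = name


def unpack_tensors(output: Any) -> List[Tensor]:
    """Hypothetical helper: collect tensors from a nested layer output."""
    if isinstance(output, Tensor):
        return [output]
    if isinstance(output, (list, tuple)):
        found: List[Tensor] = []
        for item in output:
            # Assumed order: depth-first, left to right.
            found.extend(unpack_tensors(item))
        return found
    # Non-tensor leaves (None, ints, ...) are skipped.
    return []


# A layer that returns hidden states plus a nested cache tuple.
layer_output = (
    Tensor("hidden_states"),
    (Tensor("key_cache"), Tensor("value_cache")),
    None,
)
tensors = unpack_tensors(layer_output)

unpack_idx = 0  # the default selects the first tensor found
activation = tensors[unpack_idx]
print(activation.name)
```

With this ordering, the default unpack_idx of 0 would pick the hidden states, and unpack_idx=1 would pick the first cache tensor instead.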

Methods

__init__(module_name[, expected_shape, ...])

Initializes the instance by wrapping the editing_function.