Core: Wrapper and HookFunction

Modules

  • FlexModel – Wraps a PyTorch nn.Module to provide an interface for various model-surgery techniques.

  • HookFunction – Function which retrieves/edits activations in a PyTorch nn.Module.

class flex_model.core.FlexModel(module: Module, output_ptr: Dict[str, List[Tensor]], tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1, data_parallel_size: int = 1, process_group: ProcessGroup | None = None)

Wraps a PyTorch nn.Module to provide an interface for various model-surgery techniques. Most importantly, it allows registration of user-instantiated HookFunction instances, which perform the model surgery.

Note:

Supported features include:

  • Registration, enabling and disabling of HookFunction instances.

  • Creation of HookFunction groups, which may be selectively activated during model forward passes.

  • Exposing global states to all HookFunction runtimes.

  • Distributed orchestration of 1-D to 3-D parallelisms.

  • Providing convenience functions for various attributes.

Note:

output_dict (the dictionary passed as output_ptr) is populated in place, so running a subsequent forward pass with the same hooks registered will overwrite the activations from the previous pass.
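To illustrate what in-place population means for the caller, here is a framework-free toy in which a hypothetical toy_forward function plays the role of a hooked forward pass and plain lists stand in for tensors. It is a sketch of the behavior described in the note above, not FlexModel's code:

```python
from typing import Dict, List

# Hypothetical module name used as the dictionary key.
MODULE = "my_model.layers.15.mlp"


def toy_forward(batch: List[float], output_dict: Dict[str, List[float]]) -> List[float]:
    """Toy stand-in for a hooked forward pass; not flex_model's implementation."""
    activation = [2.0 * x for x in batch]  # pretend this is the hooked activation
    output_dict[MODULE] = activation       # written in place, keyed by module name
    return activation


output_dict: Dict[str, List[float]] = {}

toy_forward([1.0, 2.0], output_dict)
saved = dict(output_dict)  # copy the entries before the next pass

toy_forward([3.0, 4.0], output_dict)

print(output_dict[MODULE])  # [6.0, 8.0]: the first pass's entry was replaced
print(saved[MODULE])        # [2.0, 4.0]: kept alive by the copy
```

Copying the dictionary between passes, as saved does here, is one way to keep earlier results. Whether a shallow copy is enough for real tensors depends on whether FlexModel replaces or mutates the stored tensors, which this page does not specify.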

Variables:
  • module (nn.Module) – The wrapped PyTorch nn.Module to hook into.

  • hook_functions (Dict[str, HookFunction]) – Collection of HookFunction instances keyed by the module name to hook into.

  • output_ptr (Dict[str, Tensor]) – Pointer to the output dictionary provided by the user. Activations will be streamed here on the rank 0 process only, and the returned tensors will all be on CPU.

  • save_ctx (Namespace) – Context for caching activations or other metadata to be accessed later within the same or a later forward pass.

  • trainable_modules (nn.ModuleDict) – Collection of named PyTorch modules/layers globally accessible to all HookFunction runtimes. They can be trained using calls to .backward().

  • tp_size (int) – Tensor parallel dimension size.

  • pp_size (int) – Pipeline parallel dimension size.

  • dp_size (int) – Data parallel dimension size.
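save_ctx is described as a Namespace shared by all HookFunction runtimes. The following is a hedged, framework-free sketch of how two editing functions could communicate through it. It uses plain lists instead of tensors and a simplified, hypothetical editing-function signature (activation, save_ctx) rather than the (current_module, inputs, save_ctx, modules) signature shown in the HookFunction example:

```python
from argparse import Namespace
from typing import List

# Stands in for FlexModel's shared save context.
save_ctx = Namespace()


def cache_norm(activation: List[float], save_ctx: Namespace) -> List[float]:
    # Earlier hook: stash a statistic for later use, pass the activation through.
    save_ctx.l1_norm = sum(abs(x) for x in activation)
    return activation


def normalize(activation: List[float], save_ctx: Namespace) -> List[float]:
    # Later hook: read the value cached by the earlier hook.
    return [x / save_ctx.l1_norm for x in activation]


a = cache_norm([1.0, -3.0], save_ctx)
b = normalize([2.0, 2.0], save_ctx)
print(b)  # [0.5, 0.5]
```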

Note:

Before calling .backward(), consider calling wrapped_module_requires_grad(False); otherwise, gradients will be computed for the entire wrapped model as well as for trainable_modules.

Example:

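## Code block being run by 4 GPUs (e.g. launched with torchrun) ##

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from flex_model.core import FlexModel, HookFunction

# FSDP requires the default process group to be initialized.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Load model (MyModel is a placeholder for your model class).
model = MyModel.from_pretrained(...)

# Distribute model over many workers using fully-sharded data parallel.
model = FSDP(model)

# Create output dictionary where activations will stream to.
output_dict = {}

# Wrap the model.
flex_model = FlexModel(
    model,
    output_dict,
    tensor_parallel_size=1,
    pipeline_parallel_size=1,
    data_parallel_size=4,
)

# Create hook function for post-mlp.
my_hook_function = HookFunction(
    "my_model.layers.15.mlp",
    expected_shape=(16, 1024, 4096),
    editing_function=None,
)

# Register the hook function (same as PyTorch API).
flex_model.register_forward_hook(my_hook_function)

# Run forward pass on a user-provided batch `inputs`. The output
# dictionary will become populated (on rank 0 only).
outputs = flex_model(inputs)

forward(*args, groups: str | List[str] = 'all', complement: bool = False, **kwargs) → Any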

Run a forward pass of the model with all hooks active by default.

Parameters:
  • groups (Union[str, List[str]]) – HookFunction groups to activate during the forward pass.

  • complement (bool) – If True, the HookFunction groups passed in the groups argument are deactivated, and only the HookFunctions outside those groups are active.
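The interaction between groups and complement can be sketched as plain set logic. This is a hypothetical helper, not FlexModel's implementation; it assumes every hook implicitly belongs to the "all" group, and that with complement=True a hook is switched off if it belongs to any requested group:

```python
from typing import Dict, List, Set, Union


def active_hooks(
    hook_groups: Dict[str, Set[str]],
    groups: Union[str, List[str]] = "all",
    complement: bool = False,
) -> Set[str]:
    """Return the names of hooks that would run during a forward pass."""
    requested = {groups} if isinstance(groups, str) else set(groups)
    # A hook is selected if it shares at least one group with the request.
    selected = {
        name
        for name, member_of in hook_groups.items()
        if requested & (member_of | {"all"})
    }
    # complement=True flips the selection: requested groups are switched off.
    return set(hook_groups) - selected if complement else selected


hook_groups = {
    "layers.0.mlp": {"mlp"},
    "layers.0.attn": {"attn"},
    "layers.1.mlp": {"mlp"},
}
print(sorted(active_hooks(hook_groups)))                          # all three hooks
print(sorted(active_hooks(hook_groups, "mlp")))                   # ['layers.0.mlp', 'layers.1.mlp']
print(sorted(active_hooks(hook_groups, "mlp", complement=True)))  # ['layers.0.attn']
```

Under these assumptions, calling with the default groups='all' and complement=True disables every hook.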

get_hook_function_groups(hook_function: HookFunction) → Set[str]

Get the set of groups that hook_function belongs to.

Parameters:

hook_function (HookFunction) – HookFunction whose groups are fetched.

get_group_hook_functions(group_name: str) → List[HookFunction]

Get a list of the `HookFunction`s that belong to the given group.

Parameters:

group_name (str) – Group to fetch related `HookFunction`s from.
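get_hook_function_groups and get_group_hook_functions are two views of the same membership relation. A hypothetical sketch of that bookkeeping follows; it identifies hooks by module-name strings instead of HookFunction instances and is not FlexModel's internal data structure:

```python
from collections import defaultdict
from typing import DefaultDict, List, Set


class ToyGroupRegistry:
    """Hypothetical group bookkeeping; hooks are identified by module name."""

    def __init__(self) -> None:
        # hook name -> names of the groups it belongs to
        self._groups_of: DefaultDict[str, Set[str]] = defaultdict(set)

    def add(self, hook: str, group: str) -> None:
        self._groups_of[hook].add(group)

    def get_hook_function_groups(self, hook: str) -> Set[str]:
        # Forward lookup; returns a copy so callers cannot mutate the registry.
        return set(self._groups_of.get(hook, set()))

    def get_group_hook_functions(self, group: str) -> List[str]:
        # Inverse lookup: scan every hook for membership in `group`.
        return [hook for hook, groups in self._groups_of.items() if group in groups]


registry = ToyGroupRegistry()
registry.add("layers.0.mlp", "mlp")
registry.add("layers.0.mlp", "layer0")
registry.add("layers.0.attn", "layer0")

print(registry.get_hook_function_groups("layers.0.mlp"))  # {'mlp', 'layer0'} (set order may vary)
print(registry.get_group_hook_functions("layer0"))        # ['layers.0.mlp', 'layers.0.attn']
```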

update_hook_groups(group_constructor: List[HookFunction] | HookFunction | str, group_name: str) → None

Adds a group reference to a set of `HookFunction`s.

group_constructor can be one of three things:
  1. List of `HookFunction`s to add the group to.

  2. A single HookFunction to add the group to.

  3. A string pattern to match against the `HookFunction`s' `module_name` attributes. The matching `HookFunction`s will have the group reference added.

Parameters:
  • group_constructor (Union[List[HookFunction], HookFunction, str]) – See above note.

  • group_name (str) – Name of the group to add.
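The three accepted forms of group_constructor can be normalized into a single list of hooks before the group is added. The sketch below is hypothetical: ToyHook stands in for HookFunction, and it assumes the string form is a regular expression searched within each module_name, which may differ from FlexModel's actual matching rule:

```python
import re
from dataclasses import dataclass
from typing import List, Union


@dataclass(frozen=True)
class ToyHook:
    """Stand-in for a HookFunction; only module_name matters here."""
    module_name: str


def resolve_constructor(
    group_constructor: Union[List[ToyHook], ToyHook, str],
    registered: List[ToyHook],
) -> List[ToyHook]:
    """Normalize the three accepted forms of group_constructor to a list."""
    if isinstance(group_constructor, str):
        # Assumption: regex search against each registered hook's module_name.
        pattern = re.compile(group_constructor)
        return [h for h in registered if pattern.search(h.module_name)]
    if isinstance(group_constructor, ToyHook):
        return [group_constructor]
    return list(group_constructor)


hooks = [
    ToyHook("model.layers.0.mlp"),
    ToyHook("model.layers.0.attn"),
    ToyHook("model.layers.1.mlp"),
]
print([h.module_name for h in resolve_constructor("mlp", hooks)])
# ['model.layers.0.mlp', 'model.layers.1.mlp']
```

Once the list is resolved, adding a group amounts to inserting group_name into each hook's set of groups; remove_hook_groups accepts the same three forms and would discard it instead.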

remove_hook_groups(group_constructor: List[HookFunction] | HookFunction | str, group_name: str) → None

Removes a group reference from a set of `HookFunction`s.

group_constructor can be one of three things:
  1. List of `HookFunction`s to remove the group from.

  2. A single HookFunction to remove the group from.

  3. A string pattern to match against the `HookFunction`s' `module_name` attributes. The matching `HookFunction`s will have the group reference removed.

Parameters:
  • group_constructor (Union[List[HookFunction], HookFunction, str]) – See above note.

  • group_name (str) – Name of the group to remove.

create_hook_group(group_name: str, group_constructor: str, expected_shape: Tuple[int | None, ...] | None = None, editing_function: Callable | None = None, unpack_idx: int | None = 0, hook_type: str = 'forward') → None

Create a group of HookFunctions.

Instantiates one `HookFunction` for each module/parameter name matching group_constructor. The remaining arguments are broadcast, so every created `HookFunction` receives the same expected_shape, editing_function, unpack_idx and hook_type. The instantiated `HookFunction`s are then added to the group group_name.

Parameters:
  • group_name (str) – Group name to assign.

  • group_constructor (str) – String pattern matched against module/parameter names; each matching name is used as the module_name of one created `HookFunction`.

  • expected_shape (Optional[Tuple[Optional[int], ...]]) – Expected shape of the activations.

  • editing_function (Callable) – Editing function given to each created HookFunction.

  • unpack_idx (int) – Index of the tensor within the module outputs.

  • hook_type (str) – Type of PyTorch hook to use.
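The broadcast behavior can be sketched as building one spec per matching name with identical remaining arguments. ToyHookSpec and toy_create_hook_group are hypothetical stand-ins, and the regex matching rule is an assumption; the real method creates HookFunction objects and registers them with the wrapped model:

```python
import re
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Set, Tuple


@dataclass
class ToyHookSpec:
    """Hypothetical stand-in for the arguments of one HookFunction."""
    module_name: str
    expected_shape: Optional[Tuple[Optional[int], ...]] = None
    editing_function: Optional[Callable] = None
    unpack_idx: Optional[int] = 0
    hook_type: str = "forward"
    groups: Set[str] = field(default_factory=set)


def toy_create_hook_group(
    module_names: List[str],
    group_name: str,
    group_constructor: str,
    expected_shape: Optional[Tuple[Optional[int], ...]] = None,
    editing_function: Optional[Callable] = None,
    unpack_idx: Optional[int] = 0,
    hook_type: str = "forward",
) -> List[ToyHookSpec]:
    # Assumption: group_constructor is a regex searched in each module name.
    pattern = re.compile(group_constructor)
    # Broadcast: every matching name gets the same remaining arguments,
    # and every created spec is tagged with group_name.
    return [
        ToyHookSpec(name, expected_shape, editing_function, unpack_idx, hook_type, {group_name})
        for name in module_names
        if pattern.search(name)
    ]


names = ["model.layers.0.mlp", "model.layers.0.attn", "model.layers.1.mlp"]
specs = toy_create_hook_group(names, "mlps", r"\.mlp$", expected_shape=(None, None, 4096))
print([s.module_name for s in specs])  # ['model.layers.0.mlp', 'model.layers.1.mlp']
```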

register_forward_hook(hook_function: HookFunction) → None

Register a forward hook function.

Parameters:

hook_function (HookFunction) – HookFunction instance to register.

register_full_backward_hook(hook_function: HookFunction) → None

Register a full backward hook function.

Parameters:

hook_function (HookFunction) – HookFunction instance to register.

register_hook(hook_function: HookFunction) → None

Register a backward hook function on a tensor.

Parameters:

hook_function (HookFunction) – HookFunction instance to register.

register_forward_pre_hook(hook_function: HookFunction) → None

Register a pre-forward hook function.

Parameters:

hook_function (HookFunction) – HookFunction instance to register.

register_full_backward_pre_hook(hook_function: HookFunction) → None

Register a pre-backward hook function.

Parameters:

hook_function (HookFunction) – HookFunction instance to register.

register_trainable_module(name: str, module: Module) → None

Register trainable module accessible to all HookFunction instances.

Given an nn.Module, add it to the nn.ModuleDict which is exposed to all HookFunction runtimes.

Parameters:
  • name (str) – Name of the module/layer.

  • module (nn.Module) – nn.Module to register.

get_module_parameter(parameter_name: str, expected_shape: Tuple[int, ...]) → Tensor

Retrieves an unsharded parameter from the wrapped module.

Given the name of a parameter belonging to a submodule of the wrapped module, gathers it across the relevant process group if necessary and returns it to the user on CPU.

Parameters:
  • parameter_name (str) – Name of the wrapped module's submodule parameter to retrieve.

  • expected_shape (Tuple[int, ...]) – Shape of the full parameter tensor. Only the dimensions which are sharded need to be provided. Other dimensions can be annotated as None and will be auto-completed.

Returns:

The requested unsharded parameter tensor detached from the computation graph and on CPU.

Return type:

Tensor
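The expected_shape convention (give the full size of sharded dimensions, None elsewhere) can be illustrated by auto-completing a shape from the local shard. complete_shape is a hypothetical helper, not FlexModel's implementation; it assumes that None marks an unsharded dimension and that sharded dimensions are split evenly across ranks:

```python
from typing import List, Optional, Sequence, Tuple


def complete_shape(
    expected_shape: Sequence[Optional[int]],
    local_shape: Sequence[int],
) -> Tuple[Tuple[int, ...], List[int]]:
    """Fill None entries from the local shard and report the sharded dims."""
    if len(expected_shape) != len(local_shape):
        raise ValueError("expected_shape and the local shape must have the same rank")
    # None means "same as the local shard", i.e. this dimension is not sharded.
    full = tuple(loc if exp is None else exp for exp, loc in zip(expected_shape, local_shape))
    # Any dimension whose full size differs from the local size is sharded.
    sharded = [d for d, (f, loc) in enumerate(zip(full, local_shape)) if f != loc]
    for d in sharded:
        if full[d] % local_shape[d] != 0:
            raise ValueError(f"dim {d}: {full[d]} is not a multiple of {local_shape[d]}")
    return full, sharded


# A (4096, 4096) weight split column-wise over 4 tensor-parallel ranks.
full, sharded = complete_shape((None, 4096), (4096, 1024))
print(full, sharded)  # (4096, 4096) [1]
```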

property wrapped_module_names: List[str]

Names of wrapped module submodules.

Returns:

List of module names.

Return type:

List[str]

property trainable_modules_names: List[str]

Names of trainable modules.

Returns:

List of module names.

Return type:

List[str]

property all_modules_names: List[str]

Names of all submodules.

Returns:

List of module names.

Return type:

List[str]

named_parameters(*args, **kwargs) → Iterator[Tuple[str, Parameter]]

Get the name and parameter for all parameters in the module.

named_buffers(*args, **kwargs) → Iterator[Tuple[str, Tensor]]

Get the name and buffer for all buffers in the module.

restore() → Module

Cleans up dangling states and modifications to wrapped module.

class flex_model.core.HookFunction(module_name: str, expected_shape: Tuple[int | None, ...] | None = None, editing_function: Callable | None = None, unpack_idx: int = 0)

Function which retrieves/edits activations in a PyTorch nn.Module.

The user provides the module_name of the target submodule. The user can optionally pass in an editing_function containing arbitrarily complex Python code, which will be used to edit the full submodule activation tensor. If certain dimensions of the activation tensor are expected to be sharded over distributed workers, the user must also provide an expected_shape hint so that the full activation tensor can be assembled.
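Conceptually, a hook on a sharded activation gathers the shards, runs the editing function on the full tensor, and hands each worker back its own slice. The single-process toy below simulates that flow with Python lists; run_hook_on_shards and center are hypothetical, the real library communicates between processes, and the shape-preservation check is an assumption of this sketch:

```python
from typing import Callable, List


def run_hook_on_shards(
    shards: List[List[float]],
    editing_function: Callable[[List[float]], List[float]],
) -> List[List[float]]:
    """Single-process toy of editing a sharded 1-D activation.

    Each inner list is one rank's contiguous chunk of the activation.
    """
    sizes = [len(shard) for shard in shards]
    # 1. Gather: concatenate the chunks into the full activation.
    full = [x for shard in shards for x in shard]
    # 2. Edit: the editing function always sees the full, unsharded activation.
    edited = editing_function(full)
    if len(edited) != len(full):
        raise ValueError("this sketch requires editing_function to preserve the shape")
    # 3. Re-shard: give each rank back its own slice of the edited activation.
    out, start = [], 0
    for n in sizes:
        out.append(edited[start:start + n])
        start += n
    return out


def center(xs: List[float]) -> List[float]:
    # Needs a global statistic, so it could not run correctly on one shard alone.
    mean = sum(xs) / len(xs)
    return [x - mean for x in xs]


print(run_hook_on_shards([[1.0, -2.0], [3.0, -4.0]], center))
# [[1.5, -1.5], [3.5, -3.5]]
```

center illustrates why the full tensor matters: its mean depends on every shard, so applying it per shard would give a different result.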

Variables:
  • module_name (str) – Name of the nn.Module submodule to hook into.

  • expected_shape – Shape of the full activation tensor. Only the dimensions which are sharded need to be provided. Other dimensions can be annotated as None and will be auto-completed.

  • editing_function – Function which is run on the full activation tensor and returns the edited activation tensor. Global contexts like the save context and trainable modules are available for use in the editing function runtime.

  • save_ctx – Global save context that is exposed to the editing_function.

  • modules – Global trainable modules that are exposed to the editing_function.

Note:

save_ctx and modules are populated when the HookFunction is registered with a FlexModel instance.

Example:

import torch
from torch import Tensor

from flex_model.core import HookFunction

# Define editing function to be run on an activation tensor.
def my_editing_function(current_module,
                        inputs,
                        save_ctx,
                        modules) -> Tensor:

    # Cache data for later: the singular values of the activation.
    s = torch.linalg.svdvals(inputs)
    save_ctx.activation_singular_values = s

    # Edit the activation tensor: zero out entries not greater than 1.0.
    inputs = torch.where(inputs > 1.0, inputs, 0.0)

    # Apply a torch layer to the activation tensor. Assumes a module named
    # "linear_projection" was registered beforehand with
    # FlexModel.register_trainable_module().
    outputs = modules["linear_projection"](inputs)

    # Pass edited activation tensor to next layer.
    return outputs

# Instantiate registration-ready hook function.
my_hook_function = HookFunction(
    "my_model.layers.16.self_attention",
    expected_shape=(4, 512, 5120),
    editing_function=my_editing_function,
)

Miscellaneous

class flex_model.core.DummyModule

Identity module used to expose activations.

Can be placed in any nn.Module to artificially create an activation to be hooked onto. For instance, explicitly calling a module’s .forward() method will not run forward hooks and therefore will not generate an activation. However, passing the output of that call through a DummyModule will generate an activation which can be hooked onto.
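Why an identity module helps can be shown with a toy analogue of nn.Module's hook dispatch, where hooks fire from __call__ but not from forward(). ToyModule, Doubler and ToyDummy are hypothetical classes written for illustration, and the hook signature is simplified to (module, output), whereas PyTorch forward hooks receive (module, args, output):

```python
from typing import Any, Callable, List


class ToyModule:
    """Minimal analogue of nn.Module hook dispatch (not PyTorch itself)."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable[["ToyModule", Any], None]] = []

    def register_forward_hook(self, hook: Callable[["ToyModule", Any], None]) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: Any) -> Any:
        return x

    def __call__(self, x: Any) -> Any:
        # As in PyTorch, hooks run from __call__, not from forward() itself.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, out)
        return out


class Doubler(ToyModule):
    def forward(self, x: float) -> float:
        return 2 * x


class ToyDummy(ToyModule):
    """Identity: inherits forward(), which returns its input unchanged."""


seen: List[float] = []
doubler, dummy = Doubler(), ToyDummy()
doubler.register_forward_hook(lambda m, out: seen.append(out))
dummy.register_forward_hook(lambda m, out: seen.append(out))

y = doubler.forward(3.0)  # explicit .forward(): no hook runs
print(seen)               # []
y = dummy(y)              # routing the result through the identity fires its hook
print(seen)               # [6.0]
```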