Core: Wrapper and HookFunction
Modules
FlexModel | Wraps a PyTorch nn.Module to provide an interface for various model-surgery techniques.
HookFunction | Function which retrieves/edits activations in a PyTorch nn.Module.
- class flex_model.core.FlexModel(module: Module, output_ptr: Dict[str, List[Tensor]], tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1, data_parallel_size: int = 1, process_group: ProcessGroup | None = None)
Wraps a PyTorch nn.Module to provide an interface for various model-surgery techniques. Most importantly, allows registration of user-instantiated HookFunction objects which perform model-surgery.
- Note:
Supported features include:
- Registration, enabling, and disabling of HookFunction instances.
- Creation of HookFunction groups, which may be selectively activated during model forward passes.
- Exposing global states to all HookFunction runtimes.
- Distributed orchestration of 1-D to 3-D parallelisms.
- Providing convenience functions for various attributes.
- Note:
output_dict is populated in-place, so running a subsequent forward pass with the same hooks in place will delete the previous activations.
- Variables:
module (nn.Module) – The wrapped PyTorch nn.Module to hook into.
hook_functions (Dict[str, HookFunction]) – Collection of HookFunction instances keyed by the module name to hook into.
output_ptr (Dict[str, Tensor]) – Pointer to output dictionary provided by the user. Activations will be streamed here on the rank 0 process only. The returned tensors will all be on CPU.
save_ctx (Namespace) – Context for caching activations or other metadata to be accessed later within the same or a later forward pass.
trainable_modules (nn.ModuleDict) – Collection of named PyTorch modules/layers globally accessible to all HookFunction runtimes. Can be trained using calls to .backward().
tp_size (int) – Tensor parallel dimension size.
pp_size (int) – Pipeline parallel dimension size.
dp_size (int) – Data parallel dimension size.
- Note:
Before calling .backward(), consider calling wrapped_module_requires_grad(False); otherwise gradients will be generated for the entire wrapped model as well as for trainable_modules.
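The effect described in the note above can be illustrated with plain PyTorch. This is a minimal sketch that does not use FlexModel: it assumes that wrapped_module_requires_grad(False) behaves roughly like setting requires_grad=False on every parameter of the wrapped model, which is an assumption rather than something this page confirms. The names `wrapped` and `probe` are hypothetical stand-ins.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins: `wrapped` plays the role of the wrapped model,
# `probe` plays the role of a module registered in trainable_modules.
wrapped = nn.Linear(8, 8)
probe = nn.Linear(8, 2)

# Assumed equivalent of wrapped_module_requires_grad(False):
# freeze every parameter of the wrapped model.
for param in wrapped.parameters():
    param.requires_grad_(False)

activations = wrapped(torch.randn(4, 8))
loss = probe(activations).sum()
loss.backward()

# Collect gradients so they can be inspected after the backward pass.
frozen_grads = [p.grad for p in wrapped.parameters()]
probe_grads = [p.grad for p in probe.parameters()]
```

With the wrapped parameters frozen, only the probe accumulates gradients, which avoids building gradients for the whole model when only the small trainable module needs updating.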
Example:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    from flex_model.core import FlexModel, HookFunction

    ## Code block being run by 4 GPUs ##

    # Load model.
    model = MyModel.from_pretrained(...)

    # Distribute model over many workers using fully-sharded data parallel.
    model = FSDP(model)

    # Create output dictionary where activations will stream to.
    output_dict = {}

    # Wrap the model.
    flex_model = FlexModel(
        model,
        output_dict,
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
        data_parallel_size=4,
    )

    # Create hook function for post-mlp.
    my_hook_function = HookFunction(
        "my_model.layers.15.mlp",
        expected_shape=(16, 1024, 4096),
        editing_function=None,
    )

    # Register the hook function (same as PyTorch API).
    flex_model.register_forward_hook(my_hook_function)

    # Run forward pass. Output dictionary will become populated.
    outputs = flex_model(inputs)
- forward(*args, groups: str | List[str] = 'all', complement: bool = False, **kwargs) Any
Run a forward pass of the model with all hooks active by default.
- Parameters:
groups (Union[str, List[str]]) – HookFunction groups to activate during the forward pass.
complement (bool) – If True, the HookFunction groups passed in the groups argument are not active.
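The interplay of groups and complement can be sketched in pure Python. This is not the library's implementation: the helper `select_active_hooks` and the `registry` mapping are hypothetical, and the sketch assumes that every hook belongs to a default "all" group, that a hook is selected when it belongs to at least one requested group, and that complement=True inverts that selection.

```python
from typing import Dict, List, Set, Union


def select_active_hooks(
    hook_groups: Dict[str, Set[str]],
    groups: Union[str, List[str]] = "all",
    complement: bool = False,
) -> Set[str]:
    """Return the names of hooks that would be active for one forward pass.

    Assumption: a hook is selected when it belongs to at least one of the
    requested groups, and complement=True inverts the selection.
    """
    requested = {groups} if isinstance(groups, str) else set(groups)
    selected = {
        name for name, member_of in hook_groups.items() if member_of & requested
    }
    if complement:
        return set(hook_groups) - selected
    return selected


# Hypothetical registry: hook name -> groups the hook belongs to.
registry = {
    "layers.0.mlp": {"all", "mlp"},
    "layers.1.mlp": {"all", "mlp"},
    "layers.1.attn": {"all", "attn"},
}

default_pass = select_active_hooks(registry)
mlp_only = select_active_hooks(registry, groups="mlp")
all_but_mlp = select_active_hooks(registry, groups=["mlp"], complement=True)
```

Under these assumptions, the default call activates every hook, `groups="mlp"` activates only the two MLP hooks, and adding `complement=True` leaves only the attention hook active.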
- get_hook_function_groups(hook_function: HookFunction) Set[str]
Get a collection of groups that the hook_function belongs to.
- Parameters:
hook_function (HookFunction) – HookFunction to fetch related groups.
- get_group_hook_functions(group_name: str) List[HookFunction]
Get a collection of `HookFunction`s that belong in the given group.
- Parameters:
group_name (str) – Group to fetch related `HookFunction`s from.
- update_hook_groups(group_constructor: List[HookFunction] | HookFunction | str, group_name: str) None
Adds a group reference to a set of `HookFunction`s.
- group_constructor can be one of three things: a list of `HookFunction`s, a single `HookFunction`, or a string.
- Parameters:
group_constructor (Union[List[HookFunction], HookFunction, str]) – See the note above.
group_name (str) – Name of the group to add.
- remove_hook_groups(group_constructor: List[HookFunction] | HookFunction | str, group_name: str) None
Removes a group reference from a set of `HookFunction`s.
- group_constructor can be one of three things: a list of `HookFunction`s, a single `HookFunction`, or a string.
- Parameters:
group_constructor (Union[List[HookFunction], HookFunction, str]) – See the note above.
group_name (str) – Name of the group to remove.
- create_hook_group(group_name: str, group_constructor: str, expected_shape: Tuple[int | None, ...] | None = None, editing_function: Callable | None = None, unpack_idx: int | None = 0, hook_type: str = 'forward') None
Create a group of HookFunctions.
Instantiates a collection of `HookFunction`s according to the provided arguments (broadcast). Adds the instantiated `HookFunction`s to a group.
- Parameters:
group_name (str) – Group name to assign.
group_constructor (str) – String pattern to match module/parameter names as the module_name parameter for creating the `HookFunction`s.
expected_shape – Expected shape of the activations.
editing_function (Callable) – Editing function to apply on each HookFunction.
unpack_idx (int) – Index of tensor in module outputs.
hook_type (str) – Type of PyTorch hook to use.
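The "broadcast" behaviour described for create_hook_group can be pictured as applying one set of hook arguments to every module whose name matches the pattern. The sketch below is illustrative only: `HookSpec` and `broadcast_hook_specs` are hypothetical names, and treating the pattern as a plain substring is an assumption, since this page does not state the matching rule.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence, Tuple


@dataclass
class HookSpec:
    """Hypothetical stand-in for the arguments of one HookFunction."""

    module_name: str
    expected_shape: Optional[Tuple[Optional[int], ...]] = None
    editing_function: Optional[Callable] = None
    unpack_idx: int = 0


def broadcast_hook_specs(
    module_names: Sequence[str], pattern: str, **hook_kwargs
) -> List[HookSpec]:
    # Assumption: the pattern is matched as a plain substring of the name.
    return [
        HookSpec(name, **hook_kwargs) for name in module_names if pattern in name
    ]


names = ["layers.0.mlp", "layers.0.attn", "layers.1.mlp", "lm_head"]
specs = broadcast_hook_specs(names, "mlp", expected_shape=(None, None, 4096))
matched_names = [spec.module_name for spec in specs]
```

Every matched module receives the same expected_shape, editing_function, and unpack_idx, which is what makes a single call sufficient to hook a repeated layer type across the whole model.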
- register_forward_hook(hook_function: HookFunction) None
Register a forward hook function.
- Parameters:
hook_function (HookFunction) – HookFunction instance to register.
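For readers less familiar with the underlying mechanism, the following sketch shows a plain PyTorch forward hook that both retrieves and edits an activation. It uses only torch.nn APIs and does not involve FlexModel or HookFunction; the model and the hook name `capture_and_zero` are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
captured = {}


def capture_and_zero(module, inputs, output):
    # Retrieve: store a detached CPU copy of the activation.
    captured["relu"] = output.detach().cpu()
    # Edit: returning a tensor from a forward hook replaces the output.
    return torch.zeros_like(output)


handle = model[1].register_forward_hook(capture_and_zero)
x = torch.randn(3, 4)
edited_out = model(x)

# Remove the hook so later forward passes run unmodified.
handle.remove()
```

Because the ReLU output is replaced with zeros, the final linear layer sees a zero input and its output reduces to its bias, which makes the effect of the edit easy to check.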
- register_full_backward_hook(hook_function: HookFunction) None
Register a backward hook function.
- Parameters:
hook_function (HookFunction) – HookFunction instance to register.
- register_hook(hook_function: HookFunction) None
Register a backward hook function on a tensor.
- Parameters:
hook_function (HookFunction) – HookFunction instance to register.
- register_forward_pre_hook(hook_function: HookFunction) None
Register a pre-forward hook function.
- Parameters:
hook_function (HookFunction) – HookFunction instance to register.
- register_full_backward_pre_hook(hook_function: HookFunction) None
Register a pre-backward hook function.
- Parameters:
hook_function (HookFunction) – HookFunction instance to register.
- register_trainable_module(name: str, module: Module) None
Register a trainable module accessible to all HookFunction instances.
Given an nn.Module, add it to the nn.ModuleDict which is exposed to all HookFunction runtimes.
- Parameters:
name (str) – Name of the module/layer.
module (nn.Module) – nn.Module to register.
- get_module_parameter(parameter_name: str, expected_shape: Tuple[int, ...]) Tensor
Retrieves unsharded parameter from wrapped module.
Given the name of a parameter belonging to a submodule of the wrapped module, gathers it across the relevant process group if necessary and returns it to the user on CPU.
- Parameters:
parameter_name (str) – Name of the wrapped module submodule parameter to retrieve.
expected_shape (Tuple[int, ...]) – Shape of the full parameter tensor. Only the dimensions which are sharded need to be provided. Other dimensions can be annotated as None and will be auto-completed.
- Returns:
The requested unsharded parameter tensor detached from the computation graph and on CPU.
- Return type:
Tensor
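The expected_shape convention (give the sharded dimensions, leave the rest as None) can be sketched without any distributed setup. The helpers `complete_shape` and `sharded_dims` are hypothetical, and the sketch assumes that None entries are filled in from the shape of the local shard; the page only says they are "auto-completed".

```python
from typing import List, Optional, Sequence, Tuple


def complete_shape(
    expected_shape: Sequence[Optional[int]],
    local_shape: Sequence[int],
) -> Tuple[int, ...]:
    """Fill None entries of expected_shape using the local shard's shape."""
    if len(expected_shape) != len(local_shape):
        raise ValueError("expected_shape and local_shape must have the same rank")
    return tuple(
        local if full is None else full
        for full, local in zip(expected_shape, local_shape)
    )


def sharded_dims(full_shape: Sequence[int], local_shape: Sequence[int]) -> List[int]:
    """Indices of the dimensions where the local shard differs from the full shape."""
    return [
        dim
        for dim, (full, local) in enumerate(zip(full_shape, local_shape))
        if full != local
    ]


# Hypothetical case: a (4096, 4096) weight split along dim 0 over 4 ranks.
local_shape = (1024, 4096)
full_shape = complete_shape((4096, None), local_shape)
dims_to_gather = sharded_dims(full_shape, local_shape)
```

Here only dimension 0 differs between the shard and the full tensor, so it is the only dimension that would need to be gathered across workers.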
- property wrapped_module_names: List[str]
Names of wrapped module submodules.
- Returns:
List of module names.
- Return type:
List[str]
- property trainable_modules_names: List[str]
Names of trainable modules.
- Returns:
List of module names.
- Return type:
List[str]
- property all_modules_names: List[str]
Names of all submodules.
- Returns:
List of module names.
- Return type:
List[str]
- named_parameters(*args, **kwargs) Iterator[Tuple[str, Parameter]]
Get the parameter and name for all parameters in the module.
- named_buffers(*args, **kwargs) Iterator[Tuple[str, Tensor]]
Get the buffer and name for all buffers in the module.
- restore() Module
Cleans up dangling states and modifications to wrapped module.
- class flex_model.core.HookFunction(module_name: str, expected_shape: Tuple[int | None, ...] | None = None, editing_function: Callable | None = None, unpack_idx: int = 0)
Function which retrieves/edits activations in a PyTorch nn.Module.
The user provides the module_name of the target submodule. The user can optionally pass in an editing_function containing arbitrarily complex Python code, which will be used to edit the full submodule activation tensor. If certain dimensions of the activation tensor are expected to be sharded over distributed workers, the user must also provide an expected_shape hint so the activation tensor can be assembled.
- Variables:
module_name (str) – Name of the nn.Module submodule to hook into.
expected_shape – Shape of the full activation tensor. Only the dimensions which are sharded need to be provided. Other dimensions can be annotated as None and will be auto-completed.
editing_function – Function which is run on the full activation tensor and returns the edited tensor. Global contexts like the save context and trainable modules are available for use in the editing function runtime.
save_ctx – Global save context that is exposed to the editing_function.
modules – Global trainable modules that are exposed to the editing_function.
- Note:
save_ctx and modules are populated when the HookFunction is registered with a FlexModel instance.
Example:
    import torch
    from torch import Tensor

    from flex_model.core import HookFunction

    # Define editing function to be run on an activation tensor.
    def my_editing_function(current_module, inputs, save_ctx, modules) -> Tensor:
        # Cache data for later.
        _, s, _ = torch.svd(inputs)
        save_ctx.activation_singular_values = s

        # Edit the activation tensor.
        inputs = torch.where(inputs > 1.0, inputs, 0.0)

        # Apply a torch layer to the activation tensor.
        outputs = modules["linear_projection"](inputs)

        # Pass edited activation tensor to next layer.
        return outputs

    # Instantiate registration-ready hook function.
    my_hook_function = HookFunction(
        "my_model.layers.16.self_attention",
        expected_shape=(4, 512, 5120),
        editing_function=my_editing_function,
    )
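Since an editing function is an ordinary Python callable, it can be exercised on its own before being attached to a model. The sketch below calls one directly with stand-in objects: a types.SimpleNamespace in place of the save context and a small nn.ModuleDict in place of the trainable modules. The four-argument signature follows the example above; the stand-ins, the function name `clamp_and_project`, and the tensor sizes are hypothetical.

```python
from types import SimpleNamespace

import torch
from torch import Tensor, nn

torch.manual_seed(0)


def clamp_and_project(current_module, inputs, save_ctx, modules) -> Tensor:
    # Cache a summary statistic for later inspection.
    save_ctx.activation_mean = inputs.mean().item()
    # Zero out every entry that is not greater than 1.0.
    inputs = torch.where(inputs > 1.0, inputs, torch.zeros_like(inputs))
    # Apply a trainable layer to the edited activation.
    return modules["linear_projection"](inputs)


# Stand-ins for the objects that would be supplied at hook runtime.
save_ctx = SimpleNamespace()
modules = nn.ModuleDict({"linear_projection": nn.Linear(16, 4)})
activation = torch.randn(2, 8, 16)

edited = clamp_and_project(None, activation, save_ctx, modules)
```

Running the function in isolation like this is a cheap way to confirm output shapes and cached values before involving a distributed model.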
Miscellaneous
- class flex_model.core.DummyModule
Identity module used to expose activations.
Can be placed in any nn.Module to artificially create an activation to be hooked onto. For instance, explicitly calling a module’s .forward() method will not run forward hooks and therefore will not generate an activation. However, applying this module to the output of that call will generate an activation which can be hooked onto.
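The behaviour described above can be reproduced with plain PyTorch. In this sketch, `ExposeActivation` is a hypothetical identity module written for illustration; it is not the library's DummyModule, whose implementation is not shown on this page.

```python
import torch
from torch import nn


class ExposeActivation(nn.Module):
    """Hypothetical identity module used only to give hooks an attachment point."""

    def forward(self, x):
        return x


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.expose = ExposeActivation()

    def forward(self, x):
        # Calling .forward() directly bypasses hooks registered on self.proj.
        hidden = self.proj.forward(x)
        # Routing the result through the identity module runs its hooks.
        return self.expose(hidden)


block = Block()
seen = []
block.proj.register_forward_hook(lambda module, inputs, output: seen.append("proj"))
block.expose.register_forward_hook(lambda module, inputs, output: seen.append("expose"))

out = block(torch.randn(2, 4))
```

Only the hook on the identity module fires, because hooks are dispatched from a module's `__call__` and not from its `.forward()` method.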