flex_model.core.FlexModel

class flex_model.core.FlexModel(module: Module, output_ptr: Dict[str, List[Tensor]], tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1, data_parallel_size: int = 1, process_group: ProcessGroup | None = None)

Wraps a PyTorch nn.Module to provide an interface for various model-surgery techniques. Most importantly, it allows registration of user-instantiated HookFunction instances, which perform the model surgery.

Note:

Supported features include:

  • Registering, enabling, and disabling HookFunction instances.

  • Creation of HookFunction groups, which may be selectively activated during model forward passes.

  • Exposing global states to all HookFunction runtimes.

  • Distributed orchestration of 1-D to 3-D parallelisms.

  • Providing convenience functions for various attributes.

Note:

output_dict is populated in place, so running a subsequent forward pass with the same hooks registered will delete the previous activations.
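
Because the dictionary is reused across passes, activations that must outlive one forward pass have to be copied out first. The following is a minimal sketch of that pattern. It assumes a hypothetical fake_forward stand-in for a hooked forward pass and uses plain Python lists in place of tensors; neither is part of FlexModel.

```python
import copy


def fake_forward(output_dict, step):
    """Hypothetical stand-in for a hooked forward pass.

    It empties the dictionary and refills it in place, mimicking the
    behaviour described in the note above.
    """
    output_dict.clear()
    output_dict["my_model.layers.15.mlp"] = [[float(step)] * 3]


output_dict = {}
history = []

for step in range(2):
    fake_forward(output_dict, step)
    # Snapshot the activations before the next pass replaces them.
    history.append(copy.deepcopy(output_dict))
```

With real activations the same idea applies: copy or clone the entries you need before running the next forward pass.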

Variables:
  • module (nn.Module) – The wrapped PyTorch nn.Module to hook into.

  • hook_functions (Dict[str, HookFunction]) – Collection of HookFunction instances keyed by the module name to hook into.

  • output_ptr (Dict[str, List[Tensor]]) – Pointer to the output dictionary provided by the user. Activations are streamed here on the rank 0 process only. The returned tensors are all on CPU.

  • save_ctx (Namespace) – Context for caching activations or other metadata to be accessed later within the same or a later forward pass.

  • trainable_modules (nn.ModuleDict) – Collection of named PyTorch modules/layers globally accessible to all HookFunction runtimes. Can be trained using calls to .backward().

  • tp_size (int) – Tensor parallel dimension size.

  • pp_size (int) – Pipeline parallel dimension size.

  • dp_size (int) – Data parallel dimension size.

Note:

Before calling .backward(), consider calling wrapped_module_requires_grad(False); otherwise gradients will be generated for the entire wrapped model as well as for trainable_modules.

Example:

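## Code block being run by 4 GPUs ##
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes HookFunction is exported alongside FlexModel.
from flex_model.core import FlexModel, HookFunction

# Load model (MyModel stands in for the user's own model class).
model = MyModel.from_pretrained(...)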

# Distribute model over many workers using fully-sharded data parallel.
model = FSDP(model)

# Create output dictionary where activations will stream to.
output_dict = {}

# Wrap the model.
flex_model = FlexModel(
    model,
    output_dict,
    tensor_parallel_size=1,
    pipeline_parallel_size=1,
    data_parallel_size=4,
)

# Create hook function for post-mlp.
my_hook_function = HookFunction(
    "my_model.layers.15.mlp",
    expected_shape=(16, 1024, 4096),
    editing_function=None,
)

# Register the hook function (same as PyTorch API).
flex_model.register_forward_hook(my_hook_function)

# Run forward pass. Output dictionary will become populated.
outputs = flex_model(inputs)

__init__(module: Module, output_ptr: Dict[str, List[Tensor]], tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1, data_parallel_size: int = 1, process_group: ProcessGroup | None = None)

Initialize the instance by wrapping the PyTorch module.

Parameters:
  • module (nn.Module) – nn.Module to wrap and apply hooks to.

  • output_ptr (Dict[str, List[Tensor]]) – Output dictionary to dump activations to.

  • tensor_parallel_size (int) – Number of workers in each tensor parallel group.

  • pipeline_parallel_size (int) – Number of workers in each pipeline parallel group.

  • data_parallel_size (int) – Number of processes in each data parallel group.
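
The three sizes describe how workers are split across the parallel dimensions. In the class example above, 4 GPUs are covered by data_parallel_size=4 with the other two sizes left at 1. The sketch below assumes that the product of the three sizes should equal the number of participating processes. This is common for 3-D parallel layouts but is not stated on this page. check_parallel_layout is a hypothetical helper, not part of FlexModel.

```python
def check_parallel_layout(
    world_size,
    tensor_parallel_size=1,
    pipeline_parallel_size=1,
    data_parallel_size=1,
):
    """Hypothetical sanity check run before constructing the wrapper.

    Assumption: the process count equals the product of the three sizes.
    """
    sizes = {
        "tensor_parallel_size": tensor_parallel_size,
        "pipeline_parallel_size": pipeline_parallel_size,
        "data_parallel_size": data_parallel_size,
    }
    for name, size in sizes.items():
        if size < 1:
            raise ValueError(f"{name} must be >= 1, got {size}")

    expected = tensor_parallel_size * pipeline_parallel_size * data_parallel_size
    if expected != world_size:
        raise ValueError(
            f"parallel sizes multiply to {expected}, "
            f"but {world_size} processes are running"
        )
    return expected


# Mirrors the class example: 4 processes, pure data parallelism.
layout_ok = check_parallel_layout(4, data_parallel_size=4)
```

In a running job, world_size would typically come from torch.distributed.get_world_size().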

Methods

__init__(module, output_ptr[, ...])

Initialize the instance by wrapping the PyTorch module.

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

create_hook_group(group_name, group_constructor)

Create a group of HookFunctions.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(*args[, groups, complement])

Run a forward pass of the model with all hooks active by default.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_group_hook_functions(group_name)

Get a collection of HookFunction instances that belong to the given group.

get_hook_function_groups(hook_function)

Get a collection of groups that the hook_function belongs to.

get_module_parameter(parameter_name, ...)

Retrieves unsharded parameter from wrapped module.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers(*args, **kwargs)

Get the buffer and name for all buffers in the module.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters(*args, **kwargs)

Get the parameter and name for all parameters in the module.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook_function)

Register a forward hook function.

register_forward_pre_hook(hook_function)

Register a pre-forward hook function.

register_full_backward_hook(hook_function)

Register a backward hook function.

register_full_backward_pre_hook(hook_function)

Register a pre-backward hook function.

register_hook(hook_function)

Register a backward hook function on a tensor.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

register_trainable_module(name, module)

Register trainable module accessible to all HookFunction instances.

remove_hook_groups(group_constructor, group_name)

Removes a group reference from a set of HookFunction instances.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

restore()

Cleans up dangling states and modifications to wrapped module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

update_hook_groups(group_constructor, group_name)

Adds a group reference to a set of HookFunction instances.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.
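
The hook-group methods listed above (create_hook_group, get_group_hook_functions, update_hook_groups, remove_hook_groups) work together with the groups and complement arguments of forward. The sketch below illustrates one plausible reading of that selection logic in plain Python. It assumes that every hook belongs to a default "all" group and that complement=True activates exactly the hooks outside the requested groups. select_active_hooks and the hook names are hypothetical and are not FlexModel's implementation.

```python
def select_active_hooks(hook_to_groups, groups="all", complement=False):
    """Return the names of hooks that would be active for a forward pass.

    hook_to_groups maps a hook name to the set of group names it belongs to.
    groups is one group name or an iterable of group names.
    """
    requested = {groups} if isinstance(groups, str) else set(groups)
    active = set()
    for hook_name, member_of in hook_to_groups.items():
        in_requested = bool(member_of & requested)
        # With complement=True, keep only hooks outside the requested groups.
        if in_requested != complement:
            active.add(hook_name)
    return active


# Hypothetical registry: three hooks, each also in the default "all" group.
hook_to_groups = {
    "layers.0.mlp": {"all", "mlp"},
    "layers.0.attn": {"all", "attn"},
    "layers.1.mlp": {"all", "mlp"},
}

mlp_only = select_active_hooks(hook_to_groups, groups="mlp")
not_mlp = select_active_hooks(hook_to_groups, groups="mlp", complement=True)
everything = select_active_hooks(hook_to_groups)
```

Keeping group membership as a set per hook makes it cheap to add or remove a group reference without touching the hook itself.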

Attributes

T_destination

all_modules_names

Names of all submodules.

call_super_init

dump_patches

trainable_modules_names

Names of trainable modules.

wrapped_module_names

Names of wrapped module submodules.

training