flex_model.distributed.ActivationTensorAllToAllRoutingStrategy

class flex_model.distributed.ActivationTensorAllToAllRoutingStrategy(prologue_fn, epilogue_fn)

Defines a routing strategy that materializes the full activation tensor on every tensor-parallel (TP) and data-parallel (DP) rank via all-gather collectives.
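The sketch below makes "materializes the activation tensor on all TP and DP ranks" concrete. It runs in a single process and uses nested lists in place of real tensors and collectives. It assumes that TP shards split the hidden (last) dimension and DP shards split the batch (first) dimension, and it gathers TP first and then DP. This is one possible ordering and not necessarily FlexModel's. The helpers `all_gather_tp`, `all_gather_dp` and `materialize` are hypothetical names used only for illustration.

```python
from typing import List

Matrix = List[List[float]]  # (batch, hidden) activation as nested lists


def all_gather_tp(shards: List[Matrix]) -> Matrix:
    """Concatenate TP shards (assumed split along the hidden dim) in rank order."""
    return [[x for shard in shards for x in shard[row]] for row in range(len(shards[0]))]


def all_gather_dp(replicas: List[Matrix]) -> Matrix:
    """Concatenate DP shards (assumed split along the batch dim) in rank order."""
    return [row for replica in replicas for row in replica]


def materialize(local: List[List[Matrix]]) -> Matrix:
    """local[dp_rank][tp_rank] is the activation shard held by that rank.

    Returns the full tensor that, after real all-gathers, every rank would hold.
    """
    # Step 1: within each DP replica, gather the hidden dim across TP ranks.
    per_dp = [all_gather_tp(tp_shards) for tp_shards in local]
    # Step 2: gather the batch dim across DP ranks.
    return all_gather_dp(per_dp)
```

With DP = 2 and TP = 2, the four ranks hold `[[0, 1]]`, `[[2, 3]]`, `[[4, 5]]` and `[[6, 7]]`. `materialize` reassembles `[[0, 1, 2, 3], [4, 5, 6, 7]]`, which is the tensor each rank would see after the collectives.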

Methods

__init__(prologue_fn, epilogue_fn)

execute_epilogue(tensor)

execute_prologue(tensor)

initialize(fmps, tensor, expected_shape)
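The constructor takes a `prologue_fn` and an `epilogue_fn`. The toy class below illustrates that pattern by mirroring only the method names listed above, and it is not FlexModel's implementation. Several assumptions are built in. `initialize` is treated as a factory that compares the local shard's shape with `expected_shape`. Only TP sharding along the last dimension is handled. The prologue gathers the full tensor and the epilogue slices this rank's shard back out. `fmps` is replaced by a hypothetical single-process `FakeTPGroup`. `ToyAllToAllStrategy` and `FakeTPGroup` are illustrative names.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence

Matrix = List[List[float]]  # (rows, hidden) activation as nested lists


@dataclass
class FakeTPGroup:
    """Hypothetical single-process stand-in for the `fmps` argument."""

    rank: int
    world_size: int
    all_shards: List[Matrix]  # every TP rank's local shard, in rank order

    def all_gather(self, local: Matrix) -> List[Matrix]:
        # A real all-gather exchanges data between processes; here we
        # return the pre-populated shards with our own slot replaced.
        shards = list(self.all_shards)
        shards[self.rank] = local
        return shards


def _identity(t: Matrix) -> Matrix:
    return t


class ToyAllToAllStrategy:
    """Toy mirror of the documented interface (assumed semantics)."""

    def __init__(self, prologue_fn: Callable[[Matrix], Matrix],
                 epilogue_fn: Callable[[Matrix], Matrix]):
        self.prologue_fn = prologue_fn
        self.epilogue_fn = epilogue_fn

    @classmethod
    def initialize(cls, fmps: FakeTPGroup, tensor: Matrix,
                   expected_shape: Sequence[int]) -> "ToyAllToAllStrategy":
        local_cols = len(tensor[0])
        if local_cols == expected_shape[1]:
            # Already full-sized: nothing to gather or re-shard.
            return cls(_identity, _identity)
        if local_cols * fmps.world_size != expected_shape[1]:
            raise ValueError("local shard does not evenly tile expected_shape")
        start = fmps.rank * local_cols

        def prologue(t: Matrix) -> Matrix:
            # Gather every rank's shard and concatenate along the hidden dim.
            shards = fmps.all_gather(t)
            return [[x for s in shards for x in s[i]] for i in range(len(t))]

        def epilogue(t: Matrix) -> Matrix:
            # Keep only this rank's slice of the (possibly edited) full tensor.
            return [row[start:start + local_cols] for row in t]

        return cls(prologue, epilogue)

    def execute_prologue(self, tensor: Matrix) -> Matrix:
        return self.prologue_fn(tensor)

    def execute_epilogue(self, tensor: Matrix) -> Matrix:
        return self.epilogue_fn(tensor)
```

Consider two TP ranks holding `[[1, 2], [5, 6]]` and `[[3, 4], [7, 8]]`. Rank 1's prologue returns `[[1, 2, 3, 4], [5, 6, 7, 8]]`, and its epilogue maps that full tensor back to `[[3, 4], [7, 8]]`. Building both functions once in `initialize` means each later call is just a function application.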