flex_model.distributed.ParameterTensorParallelRoutingStrategy
- class flex_model.distributed.ParameterTensorParallelRoutingStrategy(prologue_fn, epilogue_fn)
Defines a routing strategy for parameter tensors that supports tensor-parallel (TP) sharding.
- __init__(prologue_fn, epilogue_fn)
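The constructor takes two callables, `prologue_fn` and `epilogue_fn`. The sketch below shows one way such a pair could work for a TP-sharded parameter: the prologue reassembles the full tensor and the epilogue splits it back into per-rank shards. It rests on several assumptions. `SketchRoutingStrategy`, `gather`, and `make_scatter` are hypothetical names, not part of `flex_model`. A real TP setup would run an all-gather collective across processes, but here one process holds every rank's shard as a plain list. This is not the library's implementation.

```python
from typing import Callable, List

# Hypothetical stand-in for sharded tensors: one flat list per simulated TP rank.
Shards = List[List[float]]


class SketchRoutingStrategy:
    """Hypothetical stand-in that stores a prologue and an epilogue and applies them."""

    def __init__(self, prologue_fn: Callable, epilogue_fn: Callable) -> None:
        self.prologue_fn = prologue_fn
        self.epilogue_fn = epilogue_fn

    def execute_prologue(self, tensor):
        # Assumed role: un-shard the parameter before the user inspects it.
        return self.prologue_fn(tensor)

    def execute_epilogue(self, tensor):
        # Assumed role: re-shard the parameter after the user is done with it.
        return self.epilogue_fn(tensor)


def gather(shards: Shards) -> List[float]:
    # Single-process simulation of an all-gather along dim 0: concatenate rank shards.
    full: List[float] = []
    for shard in shards:
        full.extend(shard)
    return full


def make_scatter(world_size: int) -> Callable[[List[float]], Shards]:
    def scatter(full: List[float]) -> Shards:
        # Split the full parameter into equal contiguous chunks, one per simulated rank.
        if len(full) % world_size != 0:
            raise ValueError("length must be divisible by world_size")
        size = len(full) // world_size
        return [full[r * size:(r + 1) * size] for r in range(world_size)]
    return scatter


# Usage: round-trip a parameter sharded across two simulated TP ranks.
strategy = SketchRoutingStrategy(gather, make_scatter(world_size=2))
shards = [[1.0, 2.0], [3.0, 4.0]]
full = strategy.execute_prologue(shards)
resharded = strategy.execute_epilogue(full)
```

Because the strategy only stores and calls the two functions, any change to the sharding scheme stays in the callables and the strategy class does not need to change.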
Methods
- __init__(prologue_fn, epilogue_fn)
- execute_epilogue(tensor)
- execute_prologue(tensor)
- initialize(fmps, tensor, expected_shape)
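The signature `initialize(fmps, tensor, expected_shape)` suggests that a strategy can be built by comparing a parameter's local shape with its expected full shape. The sketch below illustrates that idea. It rests on several assumptions. The `infer_sharded_dim` helper is hypothetical. The code assumes sharding happens along at most one dimension, and it assumes the TP world size is available. In the real method that size would presumably come from `fmps`, which is not modeled here. This sketch does not document what `initialize` actually does.

```python
from typing import Optional, Sequence


def infer_sharded_dim(
    local_shape: Sequence[int],
    expected_shape: Sequence[int],
    tp_world_size: int,
) -> Optional[int]:
    """Hypothetical helper: return the dim along which local_shape is a
    1/tp_world_size shard of expected_shape, or None if it is not sharded."""
    if len(local_shape) != len(expected_shape):
        raise ValueError("rank mismatch between local and expected shapes")
    # Collect every dimension whose local size differs from the full size.
    mismatched = [
        d for d, (loc, exp) in enumerate(zip(local_shape, expected_shape)) if loc != exp
    ]
    if not mismatched:
        return None  # Parameter is replicated on every rank, so nothing to gather.
    if len(mismatched) > 1:
        raise ValueError("only single-dimension TP sharding is supported")
    dim = mismatched[0]
    # The local shard must be exactly 1/tp_world_size of the full dimension.
    if local_shape[dim] * tp_world_size != expected_shape[dim]:
        raise ValueError("local shard size is inconsistent with tp_world_size")
    return dim
```

For example, a (1024, 4096) local weight with an expected shape of (4096, 4096) and 4 TP ranks is sharded along dim 0, and a strategy could then choose a gather/scatter pair for that dimension.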