flex_model.distributed.ParameterTensorParallelRoutingStrategy

class flex_model.distributed.ParameterTensorParallelRoutingStrategy(prologue_fn, epilogue_fn)

Defines a routing strategy for parameter tensors that supports tensor-parallel (TP) sharding.
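As a rough illustration of what TP sharding means for a parameter, the sketch below splits a small weight matrix row-wise across a hypothetical number of tensor-parallel ranks and reassembles it by concatenation. It uses plain Python lists in place of real tensors and does not call any `flex_model` or `torch.distributed` API. `shard_rows` and `gather_rows` are hypothetical helpers, and splitting along the first dimension is an assumption; a real model may shard a given parameter along either dimension.

```python
def shard_rows(weight, world_size):
    """Split a 2-D weight (list of rows) into equal row-blocks, one per TP rank."""
    if len(weight) % world_size != 0:
        raise ValueError("number of rows must be divisible by world_size")
    block = len(weight) // world_size
    return [weight[r * block:(r + 1) * block] for r in range(world_size)]


def gather_rows(shards):
    """Reassemble the full weight by concatenating row-blocks in rank order."""
    full = []
    for shard in shards:
        full.extend(shard)
    return full


weight = [[1, 2], [3, 4], [5, 6], [7, 8]]
shards = shard_rows(weight, world_size=2)
print(shards)                         # [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
print(gather_rows(shards) == weight)  # True
```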


Methods

__init__(prologue_fn, epilogue_fn)

execute_epilogue(tensor)

execute_prologue(tensor)

initialize(fmps, tensor, expected_shape)
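The constructor takes a `prologue_fn` and an `epilogue_fn`, and the class exposes matching `execute_prologue(tensor)` and `execute_epilogue(tensor)` methods. The toy class below is a hypothetical sketch of that pairing, assuming (not confirmed by this page) that each `execute_*` method applies the corresponding stored function to its argument. As a further assumption, the prologue here gathers rank-local shards into the full parameter and the epilogue splits it back. All ranks are simulated in one process with lists, and `initialize(fmps, tensor, expected_shape)` is not modeled.

```python
from typing import Callable


class ToyRoutingStrategy:
    """Hypothetical stand-in: stores a prologue/epilogue pair and applies them."""

    def __init__(self, prologue_fn: Callable, epilogue_fn: Callable):
        self.prologue_fn = prologue_fn
        self.epilogue_fn = epilogue_fn

    def execute_prologue(self, tensor):
        # Assumed behaviour: route the tensor into the form a user would inspect.
        return self.prologue_fn(tensor)

    def execute_epilogue(self, tensor):
        # Assumed behaviour: route the (possibly edited) tensor back to its sharded layout.
        return self.epilogue_fn(tensor)


WORLD_SIZE = 2  # hypothetical number of tensor-parallel ranks


def gather(shards):
    # Concatenate per-rank 1-D shards into the full parameter.
    return [x for shard in shards for x in shard]


def split(full):
    # Re-shard the full parameter into WORLD_SIZE equal contiguous pieces.
    block = len(full) // WORLD_SIZE
    return [full[r * block:(r + 1) * block] for r in range(WORLD_SIZE)]


strategy = ToyRoutingStrategy(prologue_fn=gather, epilogue_fn=split)
local_shards = [[1.0, 2.0], [3.0, 4.0]]
full = strategy.execute_prologue(local_shards)  # [1.0, 2.0, 3.0, 4.0]
edited = [2 * x for x in full]                  # e.g. a user edit on the full parameter
print(strategy.execute_epilogue(edited))        # [[2.0, 4.0], [6.0, 8.0]]
```

Injecting the prologue and epilogue as callables lets one strategy object carry whatever gather and re-shard logic a particular sharding layout needs; in a real multi-process run those callables would wrap collective communication rather than list operations.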