flex_model.distributed.scatter_tensor_parallel

flex_model.distributed.scatter_tensor_parallel(tensor: Tensor, dim: int, fmps: _ParallelStateAPI) → Tensor

Scatter a tensor across the ranks in the tensor parallel group by splitting it along dim.

Parameters:
  • tensor (Tensor) – Activation tensor to scatter.

  • dim (int) – Dimension along which tensor is split across the tensor parallel ranks.

  • fmps (_ParallelStateAPI) – FlexModel parallel state handle.

Returns:

The shard of tensor along dim that belongs to the calling rank.

Return type:

Tensor
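For intuition, the following framework-free sketch shows what scattering along dim is generally understood to mean: the tensor is split into equal shards along dim, and each rank in the tensor parallel group keeps the shard matching its rank. It is not the FlexModel implementation. It uses 2-D nested lists instead of torch Tensors, ToyParallelState and toy_scatter are hypothetical names, and the accessors get_tensor_parallel_rank and get_tensor_parallel_world_size are assumptions rather than a confirmed part of _ParallelStateAPI.

```python
from dataclasses import dataclass
from typing import List

Matrix = List[List[float]]


@dataclass
class ToyParallelState:
    """Hypothetical stand-in for _ParallelStateAPI (accessor names are assumptions)."""

    rank: int
    world_size: int

    def get_tensor_parallel_rank(self) -> int:
        return self.rank

    def get_tensor_parallel_world_size(self) -> int:
        return self.world_size


def toy_scatter(tensor: Matrix, dim: int, fmps: ToyParallelState) -> Matrix:
    """Return the calling rank's equal-sized shard of a 2-D nested list along dim."""
    world_size = fmps.get_tensor_parallel_world_size()
    if world_size == 1:
        # A single-rank group has nothing to split, so the whole tensor is kept.
        return tensor
    if dim not in (-2, -1, 0, 1):
        raise ValueError(f"dim {dim} is out of range for a 2-D input")
    dim = dim % 2  # normalise negative dims: -2 -> 0, -1 -> 1
    rank = fmps.get_tensor_parallel_rank()
    size = len(tensor) if dim == 0 else len(tensor[0])
    if size % world_size != 0:
        # This sketch only supports equal shards.
        raise ValueError(f"size {size} along dim {dim} is not divisible by {world_size}")
    shard = size // world_size
    start, end = rank * shard, (rank + 1) * shard
    if dim == 0:
        return [row[:] for row in tensor[start:end]]  # keep a block of rows
    return [row[start:end] for row in tensor]  # keep a block of columns


# Simulate a tensor parallel group of two ranks splitting the last dimension.
x = [[0, 1, 2, 3],
     [4, 5, 6, 7]]
for r in range(2):
    print(r, toy_scatter(x, dim=-1, fmps=ToyParallelState(rank=r, world_size=2)))
```

Running it prints each rank's shard: rank 0 receives columns 0 and 1, and rank 1 receives columns 2 and 3. The sketch rejects sizes that do not divide evenly by the group size for simplicity; a real implementation may handle that case differently.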