flex_model.distributed.all_gather_tensor_parallel

flex_model.distributed.all_gather_tensor_parallel(tensor: Tensor, dim: int, fmps: _ParallelStateAPI) → Tensor

All-gather tensors from every rank in the tensor parallel group and concatenate them along dimension dim.

Parameters:
  • tensor (Tensor) – Activation tensor held by the local rank.

  • dim (int) – Dimension along which the gathered tensors are concatenated.

  • fmps (_ParallelStateAPI) – FlexModel parallel state handle.

Returns:

The tensors gathered from all ranks in the tensor parallel group, concatenated along dim.

Return type:

Tensor
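To illustrate the semantics without a live process group, the following stdlib-only sketch simulates an all-gather over 2-D tensors represented as nested lists. It is not FlexModel's implementation: the helpers concat and simulated_all_gather are hypothetical, and the sketch assumes that shards are ordered by tensor parallel rank and share the same shape outside the concatenation dimension.

```python
from typing import List

Matrix = List[List[float]]  # a 2-D "tensor" stored as a list of rows


def concat(shards: List[Matrix], dim: int) -> Matrix:
    """Concatenate 2-D matrices along dim (0 = rows, 1 or -1 = columns)."""
    if dim == 0:
        # Stack the rows of each shard in rank order.
        return [list(row) for shard in shards for row in shard]
    if dim in (1, -1):
        # Join row i of every shard; all shards must have the same row count.
        return [
            sum((list(shard[i]) for shard in shards), [])
            for i in range(len(shards[0]))
        ]
    raise ValueError("only dim 0, 1, or -1 is supported for 2-D inputs")


def simulated_all_gather(local_shards: List[Matrix], dim: int) -> List[Matrix]:
    """Return what each rank would hold after an all-gather plus concat.

    local_shards[r] is the tensor held by (hypothetical) tensor parallel
    rank r. Every rank ends up with the same result: all shards in rank
    order, concatenated along dim. Each rank gets its own copy.
    """
    world_size = len(local_shards)
    return [concat(local_shards, dim) for _ in range(world_size)]


if __name__ == "__main__":
    # Two ranks each hold a 2x2 column shard of a 2x4 activation.
    shards = [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
    print(simulated_all_gather(shards, dim=-1))
```

Gathering along the last dimension is the usual case for column-sharded activations, where each rank computes a slice of the hidden dimension and the full activation is rebuilt by joining the slices side by side.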