flex_model.distributed.all_gather_tensor_parallel
- flex_model.distributed.all_gather_tensor_parallel(tensor: Tensor, dim: int, fmps: _ParallelStateAPI) → Tensor
All-gather tensors from every rank in the tensor parallel group.
- Parameters:
tensor (Tensor) – Activation tensor held by the local rank.
dim (int) – Dimension along which the gathered tensors are combined.
fmps (_ParallelStateAPI) – FlexModel parallel state handle.
- Returns:
The tensor gathered from all ranks in the tensor parallel group, combined along dim.
- Return type:
Tensor
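To illustrate what an all-gather along a dimension produces, here is a minimal pure-Python sketch that simulates the collective on nested lists. It is not FlexModel's implementation and does not call torch.distributed. The helper names (concat, simulate_all_gather) are hypothetical, and the sketch assumes the shards are concatenated along dim in rank order.

```python
from typing import List

Matrix = List[List[float]]


def concat(shards: List[Matrix], dim: int) -> Matrix:
    """Concatenate 2-D shards along dim (0 or -2 = rows, 1 or -1 = columns)."""
    if dim in (0, -2):
        # Stack the rows of every shard, in rank order.
        return [list(row) for shard in shards for row in shard]
    if dim in (1, -1):
        # Extend each row with the matching row of every shard, in rank order.
        n_rows = len(shards[0])
        return [[x for shard in shards for x in shard[r]] for r in range(n_rows)]
    raise ValueError(f"dim must be one of 0, 1, -1, -2 for 2-D data, got {dim}")


def simulate_all_gather(shards_by_rank: List[Matrix], dim: int) -> List[Matrix]:
    """Return what each simulated rank would hold after the all-gather.

    Assumption: every rank ends up with the same full tensor, built by
    concatenating all ranks' shards along dim.
    """
    return [concat(shards_by_rank, dim) for _ in shards_by_rank]


# Two simulated ranks, each holding a 2x2 shard of a 2x4 activation.
shards = [
    [[1.0, 2.0], [3.0, 4.0]],  # rank 0
    [[5.0, 6.0], [7.0, 8.0]],  # rank 1
]
gathered = simulate_all_gather(shards, dim=1)
print(gathered[0])
```

In this simulation both ranks end up holding the same 2x4 matrix, whereas before the gather each rank held only its own 2x2 shard.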