flex_model.distributed.scatter_data_parallel
- flex_model.distributed.scatter_data_parallel(tensor: Tensor, dim: int, fmps: _ParallelStateAPI) → Tensor
Scatter tensors to ranks in the data parallel group.
- Parameters:
tensor (Tensor) – Activation tensor.
dim (int) – Dimension along which the tensor is scattered.
fmps (_ParallelStateAPI) – FlexModel parallel state handle.
- Returns:
The portion of the input tensor scattered to the local rank.
- Return type:
Tensor
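The sketch below illustrates what scattering along dim could look like. It runs in a single process and does not call FlexModel. It assumes each data parallel rank ends up with one equal-sized chunk of the tensor along dim, which this page does not confirm. The helper emulate_scatter and the world_size argument are hypothetical names introduced only for this illustration.

```python
import torch


def emulate_scatter(tensor: torch.Tensor, dim: int, world_size: int) -> list[torch.Tensor]:
    """Hypothetical helper: return the shard each rank would hold.

    Assumes an even split along `dim`. This is an illustration only,
    not FlexModel's implementation.
    """
    if tensor.shape[dim] % world_size != 0:
        raise ValueError("size along dim must be divisible by world_size")
    # torch.chunk splits the tensor into `world_size` views along `dim`.
    return list(torch.chunk(tensor, world_size, dim=dim))


# A toy activation tensor of shape (batch, hidden).
activations = torch.arange(24).reshape(4, 6)

# Pretend the data parallel group has two ranks and scatter along the batch dim.
shards = emulate_scatter(activations, dim=0, world_size=2)

for rank, shard in enumerate(shards):
    print(f"rank {rank}: shape {tuple(shard.shape)}")
```

Under this assumption, concatenating the shards along the same dim recovers the original tensor, which is the inverse relationship one would expect between a scatter and a matching gather.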