flex_model.distributed.gather_pipeline_parallel_tensor_dicts
- flex_model.distributed.gather_pipeline_parallel_tensor_dicts(fmps: _ParallelStateAPI, tensor_dict: Dict[str, Tensor]) → Dict[str, Tensor]
Gather groups of tensors from the ranks of the pipeline parallel group to pipeline rank 0.
Note: Assumes input tensors are on CPU and places output tensors on CPU. This behaviour is subject to change depending on various optimizations.
- Parameters:
fmps (_ParallelStateAPI) – FlexModel parallel state handle.
tensor_dict (Dict[str, Tensor]) – Some Python object that can be pickled. May contain tensors.
- Returns:
A collection of the objects sent from all pipeline parallel group ranks.
- Return type:
Dict[str, Tensor]
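Example: the gather semantics described above can be illustrated with a small single-process simulation. This is only a sketch under stated assumptions, not the library's implementation: simulate_gather_to_rank0 is a hypothetical helper, plain Python lists stand in for CPU tensors, keys are assumed to be unique across ranks, and non-destination ranks are assumed to receive nothing.

```python
from typing import Dict, List, Optional

# Stand-in for a CPU tensor so the sketch needs no third-party packages.
FakeTensor = List[float]


def simulate_gather_to_rank0(
    per_rank_dicts: List[Dict[str, FakeTensor]],
    rank: int,
) -> Optional[Dict[str, FakeTensor]]:
    """Hypothetical helper: mimic gathering per-rank dictionaries on rank 0."""
    if rank != 0:
        # Assumption: only the destination rank ends up with the gathered data.
        return None

    merged: Dict[str, FakeTensor] = {}
    for source_rank, rank_dict in enumerate(per_rank_dicts):
        for name, tensor in rank_dict.items():
            # Assumption: each key is produced by exactly one pipeline rank.
            if name in merged:
                raise KeyError(f"duplicate key {name!r} from rank {source_rank}")
            merged[name] = tensor
    return merged


# One dictionary per simulated pipeline rank.
per_rank = [
    {"layer0.out": [1.0, 2.0]},
    {"layer1.out": [3.0]},
]
gathered = simulate_gather_to_rank0(per_rank, rank=0)
```

In the real function the transfer happens across processes in the pipeline parallel group, which is why the objects must be picklable; the simulation only shows the shape of the result on rank 0.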