flex_model.distributed.gather_pipeline_parallel_tensor_dicts

flex_model.distributed.gather_pipeline_parallel_tensor_dicts(fmps: _ParallelStateAPI, tensor_dict: Dict[str, Tensor]) → Dict[str, Tensor]

Gather groups of tensors from the ranks of the pipeline parallel group to pipeline rank 0.

Note: Assumes input tensors are on CPU and places output tensors on CPU. This behaviour is subject to change depending on various optimizations.

Parameters:
  • fmps (_ParallelStateAPI) – FlexModel parallel state handle.

  • tensor_dict (Dict[str, Tensor]) – Dictionary mapping names to the tensors to gather. It must be picklable.

Returns:

A collection of the objects sent from all pipeline parallel group ranks.

Return type:

Dict[str, Tensor]
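
The gather semantics described above can be illustrated with a small single-process sketch. This is not the library's implementation: `simulate_pipeline_gather` is a hypothetical helper, plain Python lists stand in for CPU tensors, and it assumes that each pipeline rank contributes distinct keys which end up merged into one dictionary on pipeline rank 0. The actual function communicates through the process group held by `fmps`, which this sketch does not model.

```python
from typing import Dict, List

# Plain lists stand in for CPU tensors so the sketch runs without PyTorch.
FakeTensor = List[float]


def simulate_pipeline_gather(
    per_rank_dicts: List[Dict[str, FakeTensor]],
) -> Dict[str, FakeTensor]:
    """Return the dict that pipeline rank 0 would hold after the gather.

    Hypothetical helper (not part of flex_model). Assumes keys are unique
    across ranks; a duplicate key is treated as an error here.
    """
    merged: Dict[str, FakeTensor] = {}
    # per_rank_dicts[i] is the tensor_dict passed in by pipeline rank i.
    for rank, tensor_dict in enumerate(per_rank_dicts):
        for name, tensor in tensor_dict.items():
            if name in merged:
                raise KeyError(f"duplicate key {name!r} from pipeline rank {rank}")
            merged[name] = tensor
    return merged


# Two pipeline stages, each holding the tensors for its own layers.
per_rank = [
    {"layers.0.mlp": [1.0, 2.0]},
    {"layers.1.mlp": [3.0, 4.0]},
]
gathered = simulate_pipeline_gather(per_rank)
```

In this sketch, `gathered` contains both `layers.0.mlp` and `layers.1.mlp`. What the real function returns on ranks other than pipeline rank 0 is not stated in this reference and is deliberately left out of the simulation.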