Distributed: Backend, mappings and strategies

Distributed API

initialize_distributed_state

Mappings

broadcast_tensor_parallel

Broadcast a tensor to all ranks in the tensor parallel group.

broadcast_data_parallel

Broadcast a tensor to all ranks in the data parallel group.

all_gather_tensor_parallel

All-gather tensors across the ranks in the tensor parallel group, so that every rank receives the tensors from all ranks.

all_gather_data_parallel

All-gather tensors across the ranks in the data parallel group, so that every rank receives the tensors from all ranks.

scatter_tensor_parallel

Scatter tensors to ranks in the tensor parallel group.

scatter_data_parallel

Scatter tensors to ranks in the data parallel group.
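
The three collective patterns listed above (broadcast, all-gather, scatter) can be illustrated without any distributed backend. The following is a minimal single-process sketch, not this library's implementation: a Python list stands in for a process group (index = rank), plain lists stand in for tensors, and the helper names `simulate_broadcast`, `simulate_all_gather`, and `simulate_scatter` are hypothetical. The source rank for broadcast and the even, contiguous chunking for scatter are assumptions made for the example.

```python
from typing import List, Sequence


def simulate_broadcast(per_rank: Sequence[list], src: int = 0) -> List[list]:
    """After a broadcast, every rank holds a copy of the source rank's value.

    The source rank `src` is an assumption of this sketch.
    """
    return [list(per_rank[src]) for _ in per_rank]


def simulate_all_gather(per_rank: Sequence[list]) -> List[list]:
    """After an all-gather, every rank holds all shards, concatenated in rank order."""
    gathered = [x for shard in per_rank for x in shard]
    return [list(gathered) for _ in per_rank]


def simulate_scatter(full: list, world_size: int) -> List[list]:
    """After a scatter, rank i holds the i-th contiguous chunk of the input.

    Assumes the input length is divisible by the group size.
    """
    if len(full) % world_size != 0:
        raise ValueError("input length must be divisible by world_size")
    chunk = len(full) // world_size
    return [full[i * chunk:(i + 1) * chunk] for i in range(world_size)]


# A group of 3 ranks, each holding a different 2-element "tensor".
group = [[0, 1], [2, 3], [4, 5]]
after_broadcast = simulate_broadcast(group, src=1)
after_all_gather = simulate_all_gather(group)
after_scatter = simulate_scatter([0, 1, 2, 3, 4, 5], world_size=3)
```

In this model, all-gather and scatter are inverses of each other up to replication: scattering the all-gathered result returns each rank's original shard.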

gather_pipeline_parallel_tensor_dicts

Gather groups of tensors from all ranks in the pipeline parallel group onto pipeline rank 0.
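
As a rough illustration of a gather onto rank 0, the single-process sketch below models each pipeline rank as holding a dictionary of named tensors (plain lists here). The helper `simulate_gather_tensor_dicts` is hypothetical, and the result shape (a list of per-rank dictionaries on the destination rank, `None` elsewhere) is an assumption of this example rather than the documented return type.

```python
from typing import Dict, List, Optional

TensorDict = Dict[str, list]


def simulate_gather_tensor_dicts(
    per_rank: List[TensorDict], dst: int = 0
) -> List[Optional[List[TensorDict]]]:
    """Return what each rank holds after gathering onto rank `dst`.

    Only the destination rank receives data: a list with one dictionary
    per source rank, in rank order. All other ranks receive None.
    """
    collected = [dict(d) for d in per_rank]
    return [collected if rank == dst else None for rank in range(len(per_rank))]


# Two pipeline stages, each producing different named tensors.
stage_outputs = [
    {"hidden": [0.1, 0.2]},
    {"logits": [1.0, 2.0, 3.0]},
]
gathered = simulate_gather_tensor_dicts(stage_outputs)
```

Unlike an all-gather, the non-destination ranks end up with nothing, which keeps memory use on those ranks unchanged.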

Strategies

BaseRoutingStrategy

Defines a routing strategy in which every device participates.

ParameterTensorParallelRoutingStrategy

Defines a routing strategy for parameter tensors supporting TP sharding.

ActivationTensorAllToAllRoutingStrategy

Defines a routing strategy which materializes the activation tensor on all TP and DP ranks via all-gather collectives.
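
To make "materializes the activation tensor on all TP and DP ranks" concrete, here is a minimal single-process sketch of a two-stage all-gather over a grid of ranks. It is not the class's implementation: the `(dp_rank, tp_rank)` grid layout, the order of the two stages, the shard ordering, and the helper names `all_gather_group` and `materialize_everywhere` are all assumptions made for illustration.

```python
from typing import Dict, List, Tuple

Rank = Tuple[int, int]  # (dp_rank, tp_rank), an assumed layout


def all_gather_group(shards: List[list]) -> list:
    """Concatenate the shards of one group in rank order."""
    return [x for shard in shards for x in shard]


def materialize_everywhere(
    shards: Dict[Rank, list], dp_size: int, tp_size: int
) -> Dict[Rank, list]:
    """Give every rank the full tensor via two all-gather stages."""
    # Stage 1: all-gather within each TP group (fixed dp_rank).
    after_tp: Dict[Rank, list] = {}
    for dp in range(dp_size):
        row = all_gather_group([shards[(dp, tp)] for tp in range(tp_size)])
        for tp in range(tp_size):
            after_tp[(dp, tp)] = list(row)
    # Stage 2: all-gather within each DP group (fixed tp_rank).
    result: Dict[Rank, list] = {}
    for tp in range(tp_size):
        col = all_gather_group([after_tp[(dp, tp)] for dp in range(dp_size)])
        for dp in range(dp_size):
            result[(dp, tp)] = list(col)
    return result


# A 2 x 2 grid where each rank starts with one element of the activation.
initial = {(0, 0): [0], (0, 1): [1], (1, 0): [2], (1, 1): [3]}
full = materialize_everywhere(initial, dp_size=2, tp_size=2)
```

After both stages, every rank in the grid holds the same complete tensor, at the cost of replicating it `dp_size * tp_size` times.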

BaseOffloadStrategy

Defines an offload strategy in which each device may or may not participate.

NullMemoryOffloadStrategy

CPUPinnedMemoryOffloadStrategy

CPUPagedMemoryOffloadStrategy

GPUMemoryOffloadStrategy

BaseFunctionStrategy

Defines an editing function execution strategy.

NonValidatedFunctionStrategy