# FSDP Example

"""Runs Llama-2-13B on 2 GPUs using PyTorch's FSDP wrapper. This script
demonstrates basic usage of the `FlexModel` wrapper with a generic
`HookFunction`.

Running:

torchrun --nnodes 1 --nproc_per_node 2 fsdp_example.py
"""
import argparse
import functools
import os
from typing import Dict, List

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.distributed.fsdp import BackwardPrefetch, CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import LlamaForCausalLM, LlamaTokenizerFast
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from flex_model.core import FlexModel, HookFunction
from flex_model.utils import setup_logger


def parse_args():
    """Parse the command-line options of this example.

    Returns:
        An ``argparse.Namespace`` with the string attributes ``log_level``,
        ``checkpoint_dir`` and ``tokenizer_dir``.
    """
    llama_dir = "/model-weights/Llama-2-13b-hf"

    cli = argparse.ArgumentParser()
    # Every option is a plain string flag, so register them from a table.
    for flag, fallback in (
        ("--log_level", "debug"),
        ("--checkpoint_dir", llama_dir),
        ("--tokenizer_dir", llama_dir),
    ):
        cli.add_argument(flag, type=str, default=fallback)

    return cli.parse_args()


def get_llama2_tokenizer(tokenizer_dir):
    """Load the Llama-2 fast tokenizer from a local directory.

    Args:
        tokenizer_dir: Directory containing the tokenizer files. Only local
            files are used (``local_files_only=True``).

    Returns:
        A ``LlamaTokenizerFast`` whose maximum length is 512 and whose PAD
        token is set to the EOS token.
    """
    llama_tokenizer = LlamaTokenizerFast.from_pretrained(
        tokenizer_dir, local_files_only=True
    )

    # Llama-2 ships without a PAD token, so reuse EOS for padding.
    llama_tokenizer.pad_token = llama_tokenizer.eos_token
    llama_tokenizer.model_max_length = 512

    return llama_tokenizer


def make_llama2_fsdp(checkpoint_dir):
    """Load Llama-2 from ``checkpoint_dir`` and wrap it in FSDP.

    Rank 0 loads the real bfloat16 weights. Every other rank instantiates
    the model under the ``meta`` device and hands FSDP a ``param_init_fn``
    that allocates uninitialized storage on its current CUDA device. FSDP is
    built with ``sync_module_states=True``, which broadcasts rank 0's
    parameters and buffers to the other ranks, keeping CPU RAM usage low.

    The default process group must already be initialized, because
    ``dist.get_rank()`` is called here.

    Args:
        checkpoint_dir: Local directory holding the Hugging Face checkpoint.

    Returns:
        The FSDP-wrapped ``LlamaForCausalLM``.
    """

    def _allocate_on_gpu(module: nn.Module):
        # Give this module's own tensors (no recursion into children)
        # uninitialized storage on the current GPU.
        return module.to_empty(
            device=torch.cuda.current_device(), recurse=False
        )

    load_options = {
        "local_files_only": True,
        "torch_dtype": torch.bfloat16,
    }

    if dist.get_rank() == 0:
        # Only the main rank reads the actual checkpoint weights.
        base_model = LlamaForCausalLM.from_pretrained(
            checkpoint_dir, **load_options
        )
        param_init_fn = None
    else:
        with torch.device("meta"):
            base_model = LlamaForCausalLM.from_pretrained(
                checkpoint_dir, **load_options
            )
        param_init_fn = _allocate_on_gpu

    fsdp_options = dict(
        process_group=None,  # default pg.
        # Shard model parameters, optimizer, grads over all GPUs.
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        # Don't offload to CPU.
        cpu_offload=CPUOffload(offload_params=False),
        # Each decoder layer becomes its own FSDP unit.
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={LlamaDecoderLayer},
        ),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
            cast_root_forward_inputs=True,
        ),
        ignored_modules=None,
        param_init_fn=param_init_fn,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,
        forward_prefetch=True,
        limit_all_gathers=True,
        use_orig_params=False,
    )

    return FSDP(base_model, **fsdp_options)


def init_dist():
    """Join the default NCCL process group and select this process's GPU.

    The GPU index is read from the ``LOCAL_RANK`` environment variable
    (exported by launchers such as ``torchrun``) and defaults to 0 when the
    variable is unset.
    """
    dist.init_process_group("nccl")

    gpu_index = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(gpu_index)


def main(args):
    """Run a forward pass of FSDP-wrapped Llama-2-13B and dump activations.

    The script must be launched with one process per GPU (e.g. through
    ``torchrun``) because it joins an NCCL process group. The prompt batch
    is split evenly over the data-parallel ranks, a forward hook is
    registered on one decoder-layer MLP through ``FlexModel``, and the
    retrieved activation (gathered over all DP workers along the batch
    dimension) is printed and sanity-checked on rank 0.

    Args:
        args: Parsed command-line namespace providing ``checkpoint_dir`` and
            ``tokenizer_dir``, and optionally ``log_level`` (``"debug"`` is
            used when the attribute is absent).

    Raises:
        ValueError: If the number of prompts is not divisible by the world
            size, so the batch cannot be split evenly across ranks.
        AssertionError: If the retrieved activation on rank 0 does not have
            the expected batch or hidden dimension.
    """
    # Honor the --log_level option instead of a hard-coded level.
    setup_logger(getattr(args, "log_level", "debug"))

    init_dist()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    prompts = [
        "It's a nice day we're having",
        "The capital of Canada is",
        "What should I eat for dinner tonight?",
        "There's about three people going to",
    ]

    # Fail early (before loading a 13B model) and identically on every rank:
    # otherwise `chunk(...)[rank]` below raises IndexError on only some
    # ranks while the others block in collectives.
    if len(prompts) % world_size != 0:
        raise ValueError(
            f"Number of prompts ({len(prompts)}) must be divisible by the "
            f"world size ({world_size})."
        )

    model = make_llama2_fsdp(args.checkpoint_dir)

    # Load tokenizer
    tokenizer = get_llama2_tokenizer(args.tokenizer_dir)

    # Define output to dump activations to
    activation_dict: Dict[str, List[Tensor]] = {}

    # Wrap model in FlexModel
    model = FlexModel(
        model,
        activation_dict,
        data_parallel_size=world_size,
    )

    # Create a hook function. The module path includes the
    # `_fsdp_wrapped_module` prefixes that FSDP wrapping adds.
    module_name = (
        "_fsdp_wrapped_module.model.layers.30._fsdp_wrapped_module.mlp"
    )
    hook_function = HookFunction(
        module_name=module_name,
        expected_shape=(None, None, None),
        editing_function=None,
    )

    # Register hook function with the model
    model.register_forward_hook(hook_function)

    # Tokenize the prompts, padding each one to the tokenizer's max length
    inputs = tokenizer(prompts, padding="max_length", return_tensors="pt")[
        "input_ids"
    ]

    # Split the batch across dp workers
    dp_worker_inputs = inputs.chunk(world_size, dim=0)[rank]

    # Run through model to generate logits and activations
    _outputs = model(dp_worker_inputs)

    # Activations are only dumped to main process
    if rank == 0:
        activation = activation_dict[module_name][0]
        print(f"Activation shape: {activation.shape}")
        print(activation)

        # The gathered batch dimension must cover every prompt.
        assert activation.shape[0] == len(prompts)
        # 5120 is the hidden size of Llama-2-13B.
        assert activation.shape[-1] == 5120

if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the example.
    main(parse_args())