Single-GPU Example

"""Runs Llama-2-13B on a single GPU using Huggingface Transformers. This script
demonstrates basic usage of the `FlexModel` wrapper with a generic
`HookFunction`.

Running:

python single_gpu_example.py
"""
import argparse
from typing import Dict, List

from torch import Tensor
from transformers import AutoModelForCausalLM, AutoTokenizer

from flex_model.core import FlexModel, HookFunction


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint_dir", type=str, default="/model-weights/Llama-2-7b-hf"
    )
    parser.add_argument(
        "--tokenizer_dir", type=str, default="/model-weights/Llama-2-7b-hf"
    )
    args = parser.parse_args()
    return args


def main(args):
    """Single forward pass of llama-2-13b-hf model retrieving a single activation.

    This script runs on a single GPU only.
    """
    # Load llama-2-13b-hf model
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_dir,
        local_files_only=True,
        torch_dtype="auto",
        device_map="auto",
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_dir,
        local_files_only=True,
    )

    ## NEW ##
    # Define output to dump activations to
    activation_dict: Dict[str, List[Tensor]] = {}

    # Wrap model in FlexModel
    model = FlexModel(model, activation_dict)

    # Create a hook function
    hook_function = HookFunction(
        module_name="model.layers.24",
        expected_shape=(None, None, None),  # Not sharded, can pass None per dim
        editing_function=None,  # Just doing retrieval
    )

    # Register hook function with the model
    model.register_forward_hook(hook_function)
    ## NEW ##

    # Tokenize a prompt
    inputs = tokenizer(
        "Where is the best spot for lunch?", return_tensors="pt"
    )["input_ids"]

    # Run through model to generate logits and activations
    _outputs = model(inputs)

    print(activation_dict)


if __name__ == "__main__":
    args = parse_args()
    main(args)