from __future__ import annotations
import argparse
import os
from argparse import Namespace
from functools import partial
from typing import Any, Callable
import einops
import matplotlib.pyplot as plt
import torch
import torch.distributed as dist
import torch.nn.functional as F
from flex_model.core import FlexModel, HookFunction
from torch import nn
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp import (
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
)
from transformers import (
LlamaConfig,
LlamaForCausalLM,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
def setup() -> int:
    """Instantiate the process group and select this worker's CUDA device.

    Returns:
    -------
        The local rank of this worker, read from the ``LOCAL_RANK``
        environment variable (0 when the variable is not set).
    """
    dist.init_process_group("nccl")
    # LOCAL_RANK is read from the environment; fall back to 0 when it is
    # absent so the script still picks a device.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    return local_rank
def cleanup() -> None:
    """Tear down the distributed process group."""
    dist.destroy_process_group()
def args() -> Namespace:
    """Build the command-line parser and parse the process arguments.

    Returns:
    -------
        The parsed namespace, holding ``model_path`` (required string) and
        ``seq_length`` (integer, 50 when not given).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_path", required=True, type=str)
    cli.add_argument("--seq_length", required=False, type=int, default=50)
    namespace = cli.parse_args()
    return namespace
def setup_model(
    model_path: str,
    local_rank: int,
) -> tuple[nn.Module, LlamaConfig]:
    """Instantiate the model and its config.

    Local rank 0 loads the model normally; every other rank performs the
    same load inside a ``torch.device("meta")`` context.
    NOTE(review): presumably the non-zero ranks are later materialized and
    filled by FSDP (see the ``param_init_fn`` / ``sync_module_states``
    settings built in ``fsdp_config``) -- confirm.

    Args:
    ----
        model_path: A path to the model being instantiated
        local_rank: The local rank of the worker

    Returns:
    -------
        A tuple of length two containing the model and the config.
    """

    def _load_bf16() -> nn.Module:
        """Load the causal LM from ``model_path`` in bfloat16."""
        return LlamaForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
        )

    config = LlamaConfig.from_pretrained(model_path)

    if local_rank != 0:
        with torch.device("meta"):
            model = _load_bf16()
    else:
        model = _load_bf16()

    return model, config
def fsdp_config(local_rank: int) -> dict[str, Any]:
    """Return the config to be used by FSDP.

    Args:
    ----
        local_rank: The local rank of the worker

    Returns:
    -------
        A dictionary mapping FSDP constructor keyword names to their
        respective configuration values.
    """

    def _module_init_fn(module: nn.Module) -> nn.Module:
        """Move a single module (non-recursively) to empty storage on the
        current CUDA device and return it.
        """
        return module.to_empty(
            device=torch.cuda.current_device(),
            recurse=False,
        )

    # Wrap each Llama decoder layer in its own FSDP unit.
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )

    # Parameters, buffers, and reductions all use bfloat16.
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
    )

    return {
        "auto_wrap_policy": auto_wrap_policy,
        "sharding_strategy": ShardingStrategy.FULL_SHARD,
        "device_id": torch.cuda.current_device(),
        "sync_module_states": True,
        # Only non-zero local ranks get an init function; rank 0 passes None.
        # NOTE(review): this pairs with setup_model loading non-zero ranks
        # on the meta device -- confirm.
        "param_init_fn": _module_init_fn if local_rank != 0 else None,
        "mixed_precision": mp_policy,
    }
def calculate_induction_score(
    num_hidden_layers: int,
    num_attention_heads: int,
    activation_dict: dict[str, torch.Tensor],
    module_names: list[str],
    sequence_length: int,
) -> None:
    """Compute per-head induction scores and save them as a heatmap.

    The heatmap is written to ``induction_score_by_head.png``.

    Args:
    ----
        num_hidden_layers: The number of transformer blocks in the model
        num_attention_heads: The number of attention heads in the model
        activation_dict: Dictionary containing the activations retrieved
            using FlexModel; the first entry stored under each module name
            is used
        module_names: A list of the module names to which we have attached
            hooks, one per layer, in layer order
        sequence_length: The sequence length of the (un-repeated) prompt
            passed into the model
    """
    device = torch.cuda.current_device()

    # Score matrix: one row per layer, one column per head.
    scores = torch.zeros(
        (num_hidden_layers, num_attention_heads),
        device=device,
    )

    # The induction stripe is the diagonal of the query/key dims shifted by
    # 1 - sequence_length: attention from the *current* token to the token
    # that came right after the *previous occurrence* of the current token.
    # Anthropic's "In-context Learning and Induction Heads" paper has a
    # visualization of this stripe.
    stripe_offset = 1 - sequence_length

    for layer_idx, name in enumerate(module_names):
        # Gathered attention maps for this layer, expected to be of shape
        # [batch, head, seq, seq].
        layer_attn = activation_dict[name][0].detach().to(device)

        stripe = layer_attn.diagonal(offset=stripe_offset, dim1=-2, dim2=-1)

        # Average over the batch and the stripe positions, leaving one
        # score per head.
        scores[layer_idx, :] = einops.reduce(
            stripe,
            "batch head_index position -> head_index",
            "mean",
        )

    heatmap = scores.detach().cpu().numpy()
    plt.imshow(heatmap, origin="lower")
    plt.xlabel("Head")
    plt.ylabel("Layer")
    plt.title("Induction Score by Head")
    plt.colorbar()
    plt.savefig("induction_score_by_head.png", bbox_inches="tight")
def get_module_names(num_hidden_layers: int) -> list[str]:
    """Build the names of the modules that hooks are attached to.

    One name is produced per layer index, in increasing order.

    Args:
    ----
        num_hidden_layers: The number of transformer blocks in the model

    Returns:
    -------
        A list of module names that we're applying HookFunctions to
    """
    layer_prefix = "_fsdp_wrapped_module.model.layers."
    attn_suffix = "._fsdp_wrapped_module.self_attn.dummy"
    return [
        layer_prefix + str(layer_idx) + attn_suffix
        for layer_idx in range(num_hidden_layers)
    ]
def calculate_per_token_loss(
    logits: torch.Tensor,
    prompt: torch.Tensor,
) -> None:
    """Calculate and plot the cross-entropy loss per token.

    The plot is written to ``induction_loss.png``.

    Args:
    ----
        logits: The model's output logits, with the vocabulary on the last
            dim and the sequence position on the second-to-last dim
        prompt: The input prompt sequence of token ids
    """
    # Half-precision logits are upcast to float32 first: NumPy has no
    # bfloat16 dtype, so the later .numpy() call would raise, and the
    # log-softmax is more accurate in float32. Other dtypes are untouched.
    if logits.dtype in (torch.float16, torch.bfloat16):
        logits = logits.float()

    # Log-softmax across the vocab dim gives log probabilities.
    log_probs = F.log_softmax(logits, dim=-1)

    # The prediction at position t is scored against the token at t + 1, so
    # drop the last position of the log probs and the first token of the
    # prompt. The trailing None gives the index the same rank as log_probs,
    # as .gather requires.
    next_tokens = prompt[..., 1:, None]

    # Gather the log prob of each correct next token, negate it to get the
    # loss, and drop the trailing dim added for the gather with [..., 0].
    token_nll = -log_probs[..., :-1, :].gather(
        dim=-1,
        index=next_tokens,
    )[..., 0]

    # Average loss across the batch dimension.
    loss_by_position = einops.reduce(
        token_nll,
        "batch position -> position",
        "mean",
    )

    plt.plot(
        range(len(loss_by_position)),
        loss_by_position.detach().cpu().numpy(),
    )
    plt.xlabel("Token Index")
    plt.ylabel("Loss")
    plt.title("Loss by position on random repeated tokens")
    plt.savefig("induction_loss.png", bbox_inches="tight")
def main(args: Namespace) -> None:
    """Run the induction-head demo end to end.

    Sets up the process group, builds a batch of random token sequences
    repeated twice, runs them through the FSDP-sharded model wrapped in a
    FlexModel with hooks on every layer, and produces the plots on rank 0.

    Args:
    ----
        args: Command-line arguments (``model_path`` and ``seq_length``)
    """
    local_rank = setup()

    # Random prompt of token ids, then the same sequence repeated twice
    # along the sequence dim.
    sequence_length = args.seq_length
    num_prompts = 4
    vocab_low, vocab_high = 500, 15000
    random_tokens = torch.randint(
        vocab_low,
        vocab_high,
        (num_prompts, sequence_length),
    )
    random_tokens = random_tokens.to(torch.cuda.current_device())
    repeated_tokens = einops.repeat(
        random_tokens,
        "batch seq_len -> batch (2 seq_len)",
    )

    # Load and shard the model.
    model, config = setup_model(args.model_path, local_rank)
    model = FSDP(model, **fsdp_config(local_rank))

    # Wrap the model; retrieved activations land in this dict.
    activations = {}
    model = FlexModel(
        model,
        activations,
        data_parallel_size=dist.get_world_size(),
    )

    # Register hooks for activations, one per layer.
    # NOTE(review): the 4-tuple of None is presumably the expected
    # activation shape with every dim left unspecified -- confirm against
    # the HookFunction API.
    hooked_modules = get_module_names(config.num_hidden_layers)
    for name in hooked_modules:
        hook = HookFunction(name, (None, None, None, None))
        model.register_forward_hook(hook)

    logits = model(repeated_tokens).logits

    # Do plotting on main rank
    if dist.get_rank() == 0:
        calculate_induction_score(
            config.num_hidden_layers,
            config.num_attention_heads,
            activations,
            hooked_modules,
            sequence_length,
        )
        # Clear the heatmap figure before drawing the loss curve.
        plt.clf()
        # Note: we are only calculating this over the main rank's output
        # for the purpose of demonstration
        calculate_per_token_loss(logits, repeated_tokens)

    cleanup()
if __name__ == "__main__":
    # Parse the command line and run the demo only when executed as a script.
    main(args())