diff --git a/README.md b/README.md index 61bc70bbb67d02129744b852f9d6ef5622f750c0..6d3b368f65b9b86ce755e0a79c06ccd0c6ca31e2 100644 --- a/README.md +++ b/README.md @@ -403,11 +403,13 @@ DeepSparkHub甄选上百个应用算法和模型,覆盖AI和通用计算各领 [Llama2-7B](nlp/llm/llama2-7b/megatron-deepspeed/README.md) | PyTorch (Megatron-DeepSpeed) | Bookcorpus [Llama2-7B Reward Model Finetuning](nlp/llm/llama2-7b_reward_sft/deepspeed/README.md) | PyTorch (DeepSpeed) | Dahoas/rm-static [Llama2-7B RLHF](nlp/llm/llama2-7b_rlhf/megatron-deepspeed/README.md) | PyTorch (Megatron-DeepSpeed) | llama2-7b&tiny-llama -[Llama2-7B SFT](nlp/llm/llama2-7b_sft/megatron-deepspeed/README.md) | PyTorch (Megatron-DeepSpeed) | gpt_small-117M +[Llama2-7B SFT](nlp/llm/llama2-7b_sft/megatron-deepspeed/README.md) | PyTorch (Megatron-DeepSpeed) | GPT Small-117M [Llama2-13B](nlp/llm/llama2-13b/megatron-deepspeed/README.md) | PyTorch (Megatron-DeepSpeed) | Bookcorpus [Llama2-34B](nlp/llm/llama2-34b/megatron-deepspeed/README.md) | PyTorch (Megatron-DeepSpeed) | Bookcorpus [Llama3-8B](nlp/llm/llama3_8b/megatron-deepspeed/README.md) | PyTorch (Megatron-DeepSpeed) | Bookcorpus [Llama3-8B SFT](nlp/llm/llama3_8b/ColossalAI/applications/Colossal-LLaMA/README.md) | PyTorch (ColossalAI) | school_math_0.25M +[Mamba-2](nlp/llm/mamba-2/megatron-lm/README.md) | PyTorch (Megatron-LM) | GPT Small-117M +[Mixtral 8x7B](nlp/llm/mixtral/megatron-lm/README.md) | PyTorch (Megatron-LM) | GPT Small-117M [QWen-7B](nlp/llm/qwen-7b/firefly/README.md) | PyTorch (Firefly) | qwen-7b [QWen1.5-7B](nlp/llm/qwen1.5-7b/firefly/README.md) | PyTorch (Firefly) | school_math [QWen1.5-14B](nlp/llm/qwen1.5-14b/firefly/README.md) | PyTorch (Firefly) | school_math diff --git a/nlp/llm/mamba-2/megatron-lm/README.md b/nlp/llm/mamba-2/megatron-lm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..76b28465085897424a5a0aca3bf666ee1484c105 --- /dev/null +++ b/nlp/llm/mamba-2/megatron-lm/README.md @@ -0,0 +1,39 @@ +# Mamba-2 (Megatron-LM) + +## Model description + +Mamba-2 is a cutting-edge state space model (SSM) architecture designed as a highly efficient alternative to traditional Transformer-based large language models (LLMs). It is the second version of the Mamba model and builds on the strengths of its predecessor by offering faster inference, improved scalability for long sequences, and lower computational complexity. + +## Step 1: Installation + +```sh +# uninstall +pip3 uninstall -y megatron-lm + +# clone and install +git clone https://github.com/NVIDIA/Megatron-LM.git +(cd Megatron-LM/ && git checkout bd677bfb13ac2f19deaa927adc6da6f9201d66aa) +## apply patch +cp -r -T ../../../../toolbox/Megatron-LM/patch ./Megatron-LM/ +## install +cd Megatron-LM/ +python3 setup.py develop +``` + +## Step 2: Preparing datasets + +```sh +cd datasets/ +bash download_and_convert_dataset.sh +``` + +## Step 3: Training + +```bash +cd examples/mamba +bash train.sh +``` + +## Reference + +- [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/mamba) diff --git a/nlp/llm/mixtral/megatron-lm/README.md b/nlp/llm/mixtral/megatron-lm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6963e4e5bc5225c5d545355b378b6d9558e82259 --- /dev/null +++ b/nlp/llm/mixtral/megatron-lm/README.md @@ -0,0 +1,39 @@ +# Mixtral 8x7B (Megatron-LM) + +## Model description + +The Mixtral model is a Mixture of Experts (MoE)-based large language model developed by Mistral AI, an innovative company focusing on open-source AI models. Mixtral is designed to achieve high performance while maintaining computational efficiency, making it an excellent choice for real-world applications. + +## Step 1: Installation + +```sh +# uninstall +pip3 uninstall -y megatron-lm + +# clone and install +git clone https://github.com/NVIDIA/Megatron-LM.git +(cd Megatron-LM/ && git checkout bd677bfb13ac2f19deaa927adc6da6f9201d66aa) +## apply patch +cp -r -T ../../../../toolbox/Megatron-LM/patch ./Megatron-LM/ +## install +cd Megatron-LM/ +python3 setup.py develop +``` + +## Step 2: Preparing datasets + +```sh +cd datasets/ +bash download_and_convert_dataset.sh +``` + +## Step 3: Training + +```bash +cd examples/mixtral +bash train_mixtral_8x7b_distributed.sh +``` + +## Reference + +- [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/mixtral) diff --git a/toolbox/Megatron-LM/patch/datasets/download_and_convert_dataset.sh b/toolbox/Megatron-LM/patch/datasets/download_and_convert_dataset.sh new file mode 100644 index 0000000000000000000000000000000000000000..80a93020311274ca4d7dadf8b37f272d1c1396af --- /dev/null +++ b/toolbox/Megatron-LM/patch/datasets/download_and_convert_dataset.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +set -euox pipefail + +CUR_DIR=$(pwd) +if [[ ! -f $CUR_DIR/small-117M.train.jsonl ]]; then + wget http://files.deepspark.org.cn:880/deepspark/small-117M.train.jsonl +fi + +if [[ ! -f $CUR_DIR/tokenizer.model ]]; then + wget -O tokenizer.model http://files.deepspark.org.cn:880/deepspark/megatron-lm_tokenizer.model +fi + +PROJ_HOME=$(dirname "$PWD") +SAVE_PATH=./gpt_small_117M_Mixtral +mkdir -p $SAVE_PATH + +TOKENIZER=Llama2Tokenizer +TOKENIZER_PATH=./tokenizer.model + +python3 $PROJ_HOME/tools/preprocess_data.py \ + --input ./small-117M.train.jsonl \ + --json-keys text \ + --tokenizer-type $TOKENIZER \ + --tokenizer-model $TOKENIZER_PATH \ + --output-prefix $SAVE_PATH/gpt_small_117M \ + --append-eod \ + --workers 32 + +rm -f small-117M.train.jsonl diff --git a/toolbox/Megatron-LM/patch/examples/mamba/train.sh b/toolbox/Megatron-LM/patch/examples/mamba/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..74e9d46df56276262eda79ae5fb33024a033d963 --- /dev/null +++ b/toolbox/Megatron-LM/patch/examples/mamba/train.sh @@ -0,0 +1,107 @@ +#!/bin/bash +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. + +# Use: ./train.sh + +MODEL_SCALE="800M" # or "8B" + +case "${MODEL_SCALE}" in + "800M") + TENSOR_MODEL_PARALLEL_SIZE=1 + NUM_LAYERS=48 + HIDDEN_SIZE=1024 + NUM_ATTENTION_HEADS=16 + GLOBAL_BATCH_SIZE=16 + ;; + "8B") + TENSOR_MODEL_PARALLEL_SIZE=4 + NUM_LAYERS=56 + HIDDEN_SIZE=4096 + NUM_ATTENTION_HEADS=32 + GLOBAL_BATCH_SIZE=8 + ;; + *) + echo "Invalid version specified" + exit 1 + ;; +esac + +TOKENIZER_PATH=../../datasets/tokenizer.model +DATA_PATH=../../datasets/gpt_small_117M_Mixtral/gpt_small_117M_text_document + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_IB_TIMEOUT=19 +export NCCL_IB_QPS_PER_CONNECTION=4 + +CHECKPOINT_DIR="./checkpoints" +DATACACHE_DIR="./data-cache" +TENSORBOARD_DIR="./tensorboard" + +mkdir -p ${CHECKPOINT_DIR} +mkdir -p ${DATACACHE_DIR} +mkdir -p ${TENSORBOARD_DIR} + +export TRITON_CACHE_DIR="./triton-cache/" +export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager" + +SEQ_LEN=4096 +TRAIN_SAMPLES=73242188 # 300B tokens / 4096 +LR_WARMUP_SAMPLES=50000 +LR_DECAY_SAMPLES=73192188 # TRAIN_SAMPLES - LR_WARMUP_SAMPLES + +options=" \ + --tensor-model-parallel-size ${TENSOR_MODEL_PARALLEL_SIZE} \ + --sequence-parallel \ + --pipeline-model-parallel-size 1 \ + --use-distributed-optimizer \ + --overlap-param-gather \ + --overlap-grad-reduce \ + --untie-embeddings-and-output-weights \ + --init-method-std 0.02 \ + --position-embedding-type none \ + --num-layers ${NUM_LAYERS} \ + --hidden-size ${HIDDEN_SIZE} \ + --num-attention-heads ${NUM_ATTENTION_HEADS} \ + --group-query-attention \ + --num-query-groups 8 \ + --hybrid-attention-ratio 0.08 \ + --hybrid-mlp-ratio 0.5 \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${SEQ_LEN} \ + --train-samples ${TRAIN_SAMPLES} \ + --lr-warmup-samples ${LR_WARMUP_SAMPLES} \ + --lr-decay-samples ${LR_DECAY_SAMPLES} \ + --save ${CHECKPOINT_DIR} \ + --load ${CHECKPOINT_DIR} \ + --data-path ${DATA_PATH} \ + --data-cache-path ${DATACACHE_DIR} \ + --split 99,1,0 \ + --tokenizer-type GPTSentencePieceTokenizer \ + --tokenizer-model ${TOKENIZER_PATH} \ + --distributed-backend nccl \ + --micro-batch-size 1 \ + --global-batch-size ${GLOBAL_BATCH_SIZE} \ + --lr 2.5e-4 \ + --min-lr 2.5e-5 \ + --lr-decay-style cosine \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --disable-bias-linear \ + --normalization RMSNorm \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --log-interval 10 \ + --save-interval 2000 \ + --eval-interval 2000 \ + --eval-iters 32 \ + --bf16 \ + --use-mcore-models \ + --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ + --no-create-attention-mask-in-dataloader \ + --tensorboard-dir ${TENSORBOARD_DIR}" + +torchrun --nproc_per_node 16 ../../pretrain_mamba.py ${options} diff --git a/toolbox/Megatron-LM/patch/examples/mixtral/pretrain_gpt.py b/toolbox/Megatron-LM/patch/examples/mixtral/pretrain_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..d9cb6facf77c0e23f79aa7115b56d47c5accdc60 --- /dev/null +++ b/toolbox/Megatron-LM/patch/examples/mixtral/pretrain_gpt.py @@ -0,0 +1,310 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. + +"""Pretrain GPT.""" + +import os +import torch +from functools import partial +from contextlib import nullcontext +import inspect + +from typing import List, Optional, Tuple, Union +from megatron.training import get_args +from megatron.training import print_rank_0 +from megatron.training import get_timers +from megatron.training import get_tokenizer +from megatron.core import mpu +from megatron.core.enums import ModelType +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.gpt_dataset import GPTDatasetConfig +from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset +from megatron.core.rerun_state_machine import get_rerun_state_machine +import megatron.legacy.model +from megatron.core.models.gpt import GPTModel +from megatron.training import pretrain +from megatron.core.utils import StragglerDetector +from megatron.core.transformer.spec_utils import import_module +from megatron.training.utils import ( + get_batch_on_this_cp_rank, + get_batch_on_this_tp_rank, + get_blend_and_blend_per_split, +) +from megatron.training.arguments import core_transformer_config_from_args +from megatron.training.yaml_arguments import core_transformer_config_from_yaml +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_decoder_block_spec, + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) + + +stimer = StragglerDetector() + +def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]: + """Builds the model. + + If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. + + Args: + pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. + post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. + + + Returns: + Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model + """ + args = get_args() + use_te = args.transformer_impl == "transformer_engine" + args.use_legacy_models=True + + if args.record_memory_history: + torch.cuda.memory._record_memory_history(True, + # keep 100,000 alloc/free events from before the snapshot + trace_alloc_max_entries=100000, + + # record stack information for the trace events + trace_alloc_record_context=True) + + print_rank_0('building GPT model ...') + # Experimental loading arguments from yaml + if args.yaml_cfg is not None: + config = core_transformer_config_from_yaml(args, "language_model") + else: + config = core_transformer_config_from_args(args) + + if args.use_legacy_models: + model = megatron.legacy.model.GPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + ) + else: # using core models + if args.spec is not None: + transformer_layer_spec = import_module(args.spec) + else: + if args.num_experts: + # Define the decoder block spec + transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te) + else: + # Define the decoder layer spec + if use_te: + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + args.num_experts, args.moe_grouped_gemm, + args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm) + else: + transformer_layer_spec = get_gpt_layer_local_spec( + args.num_experts, args.moe_grouped_gemm, + args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm) + + build_model_context = nullcontext + build_model_context_args = {} + if args.fp8_param_gather: + try: + from transformer_engine.pytorch import fp8_model_init + + build_model_context = fp8_model_init + build_model_context_args["enabled"] = True + + # Check if fp8_model_init supports preserve_high_precision_init_val + if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters: + build_model_context_args["preserve_high_precision_init_val"] = True + except: + raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.") + + with build_model_context(**build_model_context_args): + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + rope_scaling=args.use_rope_scaling + ) + + return model + + +def get_batch(data_iterator): + """Generate a batch.""" + + # TODO: this is pretty hacky, find a better way + if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): + return None, None, None, None, None + + # get batches based on the TP rank you are on + batch = get_batch_on_this_tp_rank(data_iterator) + + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) + + return batch.values() + + +# define spiky loss as a variation of 20% or more +SPIKY_LOSS_PERC = 0.2 + + +def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): + """Loss function. + + Args: + loss_mask (torch.Tensor): Used to mask out some portions of the loss + output_tensor (torch.Tensor): The tensor with the losses + + Returns: + the loss scalar for this micro-batch + the number of non-padded tokens in this microbatch + a dict containing reporting metrics on the loss and number of tokens across + the data parallel ranks + """ + args = get_args() + + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + total_tokens = loss_mask.sum() + loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)]) + + if args.context_parallel_size > 1: + torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group()) + + # Check individual rank losses are not NaN prior to DP all-reduce. + rerun_state_machine = get_rerun_state_machine() + if args.check_for_nan_in_loss_and_grad: + rerun_state_machine.validate_result( + result=loss[0], + rejection_func=torch.isnan, + message="found NaN in local forward loss calculation", + tolerance=0.0, # forward pass calculations are determinisic + fatal=True, + ) + # Check for spiky loss + if args.check_for_spiky_loss: + rerun_state_machine.validate_result( + result=loss[0], + rejection_func=partial(rerun_state_machine.is_spiky_loss, threshold=SPIKY_LOSS_PERC), + message="Spiky loss", + tolerance=0.0, # forward pass calculations are determinisic + fatal=False, + ) + # Reduce loss for logging. + reporting_loss = loss.clone().detach() + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + + local_num_tokens = loss[1].clone().detach().to(torch.int) + return ( + loss[0] * args.context_parallel_size, + local_num_tokens, + {'lm loss': (reporting_loss[0], reporting_loss[1])}, + ) + + +def forward_step(data_iterator, model: GPTModel): + """Forward training step. + + Args: + data_iterator : Input data iterator + model (GPTModel): The GPT Model + """ + args = get_args() + timers = get_timers() + + # Get the batch. + timers('batch-generator', log_level=2).start() + global stimer + with stimer(bdata=True): + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + with stimer: + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + return output_tensor, partial(loss_func, loss_mask) + + +def is_dataset_built_on_rank(): + return ( + mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage() + ) and mpu.get_tensor_model_parallel_rank() == 0 + + +def core_gpt_dataset_config_from_args(args): + tokenizer = get_tokenizer() + + # Sometimes --data-path is too long, instead we parse it from a file. + blend: Optional[Tuple[List[str], Optional[List[float]]]] + blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]] + blend, blend_per_split = get_blend_and_blend_per_split(args) + + return GPTDatasetConfig( + random_seed=args.seed, + sequence_length=args.seq_length, + blend=blend, + blend_per_split=blend_per_split, + renormalize_blend_weights=args.renormalize_blend_weights, + split=args.split, + num_dataset_builder_threads=args.num_dataset_builder_threads, + path_to_cache=args.data_cache_path, + mmap_bin_files=args.mmap_bin_files, + tokenizer=tokenizer, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + create_attention_mask=args.create_attention_mask_in_dataloader, + s3_cache_path=args.s3_cache_path, + ) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build the train test and validation datasets. + + Args: + train_val_test_num_samples : A list containing the number of samples in train test and validation. + """ + args = get_args() + + config = core_gpt_dataset_config_from_args(args) + + if args.mock_data: + dataset_type = MockGPTDataset + else: + dataset_type = GPTDataset + + print_rank_0("> building train, validation, and test datasets for GPT ...") + + train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + dataset_type, + train_val_test_num_samples, + is_dataset_built_on_rank, + config + ).build() + + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +if __name__ == "__main__": + + # Temporary for transition to core datasets + train_valid_test_datasets_provider.is_distributed = True + + pretrain( + train_valid_test_datasets_provider, + model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + ) diff --git a/toolbox/Megatron-LM/patch/examples/mixtral/train_mixtral_8x7b_distributed.sh b/toolbox/Megatron-LM/patch/examples/mixtral/train_mixtral_8x7b_distributed.sh new file mode 100644 index 0000000000000000000000000000000000000000..1b94ef5a164d897a39d6e4a0f7e8e5ee8d3f9db0 --- /dev/null +++ b/toolbox/Megatron-LM/patch/examples/mixtral/train_mixtral_8x7b_distributed.sh @@ -0,0 +1,119 @@ +#!/bin/bash +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. + +# Runs Mixtral 8x7B model + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +GPUS_PER_NODE=16 +# Change for multinode config +MASTER_ADDR=${MASTER_ADDR:-"localhost"} +MASTER_PORT=${MASTER_PORT:-"6000"} +NNODES=${SLURM_NNODES:-"1"} +NODE_RANK=${RANK:-"0"} +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CHECKPOINT_PATH=./checkpoints/ +TOKENIZER_MODEL=../../datasets/tokenizer.model +DATA_PATH=../../datasets/gpt_small_117M_Mixtral/gpt_small_117M_text_document + +DISTRIBUTED_ARGS=( + --nproc_per_node $GPUS_PER_NODE + --nnodes $NNODES + --node_rank $NODE_RANK + --master_addr $MASTER_ADDR + --master_port $MASTER_PORT +) +TRANSFORMER_IMPL=local +MODEL_ARGS=( + --use-mcore-models + --disable-bias-linear + --seq-length 4096 + --max-position-embeddings 32768 + --num-layers 4 + --hidden-size 4096 + --ffn-hidden-size 14336 + --num-attention-heads 32 + --init-method-std 0.01 + --attention-dropout 0.0 + --hidden-dropout 0.0 + --normalization RMSNorm + --position-embedding-type rope + --swiglu + --untie-embeddings-and-output-weights + --group-query-attention + --num-query-groups 8 + --no-masked-softmax-fusion + --no-position-embedding + --rotary-base 1000000 +) + +MOE_ARGS=( + --num-experts 8 + --moe-router-topk 2 + --moe-router-load-balancing-type aux_loss + --moe-aux-loss-coeff 1e-2 + #--moe-grouped-gemm + --moe-token-dispatcher-type alltoall + --overlap-param-gather + --overlap-grad-reduce +) + +DATA_ARGS=( + --tokenizer-type Llama2Tokenizer + --tokenizer-model ${TOKENIZER_MODEL} + --data-path $DATA_PATH + --split 99990,8,2 +) + +TRAINING_ARGS=( + --micro-batch-size 1 + --transformer-impl $TRANSFORMER_IMPL\ + --global-batch-size 256 + --lr 1e-4 + --train-iters 500000 + --lr-decay-iters 320000 + --lr-decay-style cosine + --min-lr 1.0e-5 + --weight-decay 0.1 + --lr-warmup-iters 500 + --clip-grad 1.0 + --bf16 +) + +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size 1 + --pipeline-model-parallel-size 2 + --expert-model-parallel-size 4 + --use-distributed-optimizer + --sequence-parallel +) + +LOGGING_ARGS=( + --log-interval 1 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --tensorboard-dir "${CHECKPOINT_PATH}/tensorboard" \ + --no-load-optim \ + --no-load-rng +) + +if [ -n "${WANDB_API_KEY}" ]; then + LOGGING_ARGS+=( + --wandb-project ${WANDB_PROJECT:-"Mixtral"} + --wandb-exp-name ${WANDB_NAME:-"Mixtral_8x7B"} + ) +fi + + +torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \ + ${MODEL_ARGS[@]} \ + ${MOE_ARGS[@]} \ + ${DATA_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${LOGGING_ARGS[@]} diff --git a/toolbox/Megatron-LM/patch/megatron/core/extensions/transformer_engine.py b/toolbox/Megatron-LM/patch/megatron/core/extensions/transformer_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..c3fbf35ced6bfc43d9abcfddb9b4434f7714e46e --- /dev/null +++ b/toolbox/Megatron-LM/patch/megatron/core/extensions/transformer_engine.py @@ -0,0 +1,1264 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. + +import dataclasses +import io +import os +import pickle +import warnings +from typing import Callable + +import torch +import transformer_engine as te +from packaging.version import Version as PkgVersion +from torch import Tensor +from torch.nn.parameter import Parameter + +from megatron.core import ModelParallelConfig +from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import ( + get_context_parallel_global_ranks, + get_context_parallel_group, + get_expert_data_parallel_rank, + get_expert_model_parallel_rank, + get_expert_model_parallel_world_size, + get_expert_tensor_parallel_group, + get_expert_tensor_parallel_rank, + get_expert_tensor_parallel_world_size, + get_hierarchical_context_parallel_groups, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name +from megatron.core.tensor_parallel.layers import ( + _initialize_affine_weight_cpu, + set_tensor_model_parallel_attributes, +) +from megatron.core.tensor_parallel.utils import divide +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint +from megatron.core.utils import get_te_version, is_te_min_version + + +def _get_extra_te_kwargs(config: TransformerConfig): + extra_transformer_engine_kwargs = {"params_dtype": config.params_dtype} + + if is_te_min_version("0.12.0"): + if config.use_cpu_initialization: + extra_transformer_engine_kwargs["device"] = 'cpu' + else: + extra_transformer_engine_kwargs["device"] = torch.cuda.current_device() + return extra_transformer_engine_kwargs + + +def condition_init_method(config, init_method): + """Condition TE init_method on config.perform_initialization.""" + return init_method if config.perform_initialization else (lambda w: None) + + +class TENorm: + """ + A conditional wrapper to initialize an instance of Transformer-Engine's + `LayerNorm` or `RMSNorm` based on input + """ + + # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? + def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5): + if config.normalization == "LayerNorm": + instance = te.pytorch.LayerNorm( + hidden_size=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **_get_extra_te_kwargs(config), + ) + elif config.normalization == "RMSNorm": + assert hasattr( + te.pytorch, "RMSNorm" + ), "Transformer-Engine >= v0.11 required to use this feature" + instance = te.pytorch.RMSNorm( + hidden_size=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **_get_extra_te_kwargs(config), + ) + else: + raise Exception('Only LayerNorm and RMSNorm are curently supported') + + return instance + + +class TELinear(te.pytorch.Linear): + """ + Wrapper for the Transformer-Engine's `Linear` layer. + + Note that if Megatron's parallel_state has not been initialized + yet, the tp_group passed to TE will be None and must be set later + via set_tensor_parallel_group(). + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + parallel_mode: str, + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + skip_weight_param_allocation: bool, + tp_comm_buffer_name: str = None, + is_expert: bool = False, + ): + self.config = config + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache + if skip_weight_param_allocation: + raise ValueError( + 'Transformer Engine linear layers do not support skip_weight_param_allocation' + ) + + extra_kwargs = _get_extra_te_kwargs(config) + + if is_te_min_version("0.8.0"): + if self.config.tp_comm_overlap: + if is_te_min_version("1.5.0"): + # Use old overlap flags if they were supplied instead + extra_kwargs["ub_overlap_ag"] = ( + self.config.tp_comm_overlap_ag + if hasattr(self.config, "tp_comm_overlap_ag") + else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag + ) + extra_kwargs["ub_overlap_rs"] = ( + self.config.tp_comm_overlap_rs + if hasattr(self.config, "tp_comm_overlap_rs") + else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs + ) + # Disable ub overlap for experts. + if is_expert: + extra_kwargs["ub_overlap_ag"] = False + extra_kwargs["ub_overlap_rs"] = False + else: + extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag + extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag + extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs + extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs + # Disable ub overlap for experts. + if is_expert: + extra_kwargs["ub_split_ag"] = False + extra_kwargs["ub_atomic_gemm_ag"] = False + extra_kwargs["ub_split_rs"] = False + extra_kwargs["ub_atomic_gemm_rs"] = False + if is_te_min_version("1.0.0", check_equality=False): + assert ( + tp_comm_buffer_name is not None + ), "Buffer name should be set to configure communication overlap settings" + extra_kwargs["ub_name"] = tp_comm_buffer_name + + self.expert_parallel = self.config.expert_model_parallel_size > 1 + if is_expert: + rng_tracker_name = get_expert_parallel_rng_tracker_name() + else: + rng_tracker_name = None + if is_te_min_version("1.7.0"): + extra_kwargs["rng_tracker_name"] = rng_tracker_name + + # Disable communications in TE when using TP or EP by making TE agnostic of model parallel. + if is_expert: + tp_group = get_expert_tensor_parallel_group(check_initialized=False) + tp_size = get_expert_tensor_parallel_world_size() + else: + tp_group = get_tensor_model_parallel_group(check_initialized=False) + tp_size = get_tensor_model_parallel_world_size() + explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) + + if explicit_expert_comm: + if parallel_mode == "column": + output_size = divide(output_size, tp_size) + elif parallel_mode == "row": + input_size = divide(input_size, tp_size) + parallel_mode = None + tp_size = 1 + tp_group = None + + super().__init__( + in_features=input_size, + out_features=output_size, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=tp_group, + tp_size=tp_size, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + init_method=condition_init_method(config, init_method), + bias=bias, + return_bias=self.te_return_bias, + parallel_mode=parallel_mode, + **extra_kwargs, + ) + + for param in self.parameters(): + setattr(param, 'allreduce', not (is_expert and self.expert_parallel)) + + def forward(self, x): + """Forward.""" + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) + out = super().forward(x, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + + # TE only returns a tuple when return_bias is True, otherwise + # it returns a single Tensor, we always want to return two + # values regardless of the arguments. + if self.te_return_bias: + return out + return out, None + + +class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): + """ + Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines + layernorm and linear layers + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: TransformerConfig, + init_method: Callable, + gather_output: bool, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + skip_weight_param_allocation: bool = False, + tp_comm_buffer_name: str = None, + ): + self.config = config + + if gather_output: + raise ValueError('Transformer Engine linear layers do not support gather_output = True') + + if is_expert: + raise ValueError('Transformer Engine linear layers do not yet support MoE') + + if skip_weight_param_allocation: + raise ValueError( + 'Transformer Engine linear layers do not support skip_weight_param_allocation' + ) + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache + extra_kwargs = _get_extra_te_kwargs(config) + + # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` + if is_te_min_version("0.11.0"): + extra_kwargs["normalization"] = self.config.normalization + elif self.config.normalization != "LayerNorm": + te_version = get_te_version() + raise ValueError( + f"Transformer Engine v{te_version} does not support {self.config.normalization}." + ) + + if is_te_min_version("0.8.0"): + if self.config.tp_comm_overlap: + extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad + extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad + if is_te_min_version("1.5.0", check_equality=False): + # Use old overlap flags if they were supplied instead + extra_kwargs["ub_overlap_ag"] = ( + self.config.tp_comm_overlap_ag + if hasattr(self.config, "tp_comm_overlap_ag") + else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag + ) + if is_te_min_version("1.6.0.dev0", check_equality=False): + extra_kwargs["ub_overlap_rs_dgrad"] = ( + self.config.tp_comm_overlap_rs_dgrad + if hasattr(self.config, "tp_comm_overlap_rs_dgrad") + else False + ) + if tp_comm_buffer_name == 'qkv' and self.config.tp_comm_overlap_disable_qkv: + extra_kwargs["ub_overlap_ag"] = False + extra_kwargs["ub_overlap_rs_dgrad"] = False + + if tp_comm_buffer_name == 'fc1' and self.config.tp_comm_overlap_disable_fc1: + extra_kwargs["ub_overlap_ag"] = False + extra_kwargs["ub_overlap_rs_dgrad"] = False + else: + extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag + extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag + if is_te_min_version("1.0.0", check_equality=False): + assert ( + tp_comm_buffer_name is not None + ), "Buffer name should be set to configure communication overlap settings" + extra_kwargs["ub_name"] = tp_comm_buffer_name + + super().__init__( + in_features=input_size, + out_features=output_size, + eps=self.config.layernorm_epsilon, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=get_tensor_model_parallel_group(check_initialized=False), + tp_size=self.config.tensor_model_parallel_size, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + init_method=( + condition_init_method(config, init_method) + if not config.use_cpu_initialization + else lambda w: None + ), + bias=bias, + return_bias=self.te_return_bias, + parallel_mode="column", + return_layernorm_output=False, + zero_centered_gamma=self.config.layernorm_zero_centered_gamma, + **extra_kwargs, + ) + + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + + if config.use_cpu_initialization: + output_size_per_partition = divide(output_size, world_size) + _ = _initialize_affine_weight_cpu( + self.weight, + output_size, + input_size, + output_size_per_partition, + 0, + init_method=condition_init_method(config, init_method), + stride=1, + return_master_weight=False, + rank=rank, + world_size=world_size, + skip_set_tensor_parallel_attributes=True, + ) + if bias: + self.bias = Parameter( + torch.empty(output_size_per_partition, dtype=config.params_dtype) + ) + set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + with torch.no_grad(): + self.bias.zero_() + setattr(self.bias, 'allreduce', True) + + def forward(self, x): + """Forward.""" + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) + out = super().forward(x, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + + # TE only returns a tuple when return_bias is True, otherwise + # it returns a single Tensor, we always want to return two + # values regardless of the arguments. + if self.te_return_bias: + return out + return out, None + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 0, bias sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets + ) + + +class TEColumnParallelLinear(TELinear): + """ + Wrapper for the Transformer-Engine's `Linear` layer but specialized similar + to megatron's `ColumnParallelLinear` layer. + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + gather_output: bool, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + skip_weight_param_allocation: bool = False, + tp_comm_buffer_name: str = None, + ): + if gather_output: + raise ValueError('Transformer Engine linear layers do not support gather_output = True') + + super().__init__( + input_size=input_size, + output_size=output_size, + parallel_mode="column", + config=config, + init_method=( + condition_init_method(config, init_method) + if not config.use_cpu_initialization + else lambda w: None + ), + bias=bias, + skip_bias_add=skip_bias_add, + is_expert=is_expert, + skip_weight_param_allocation=skip_weight_param_allocation, + tp_comm_buffer_name=tp_comm_buffer_name, + ) + + if config.use_cpu_initialization: + if is_expert: + world_size = get_expert_tensor_parallel_world_size() + rank = get_expert_tensor_parallel_rank() + else: + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + output_size_per_partition = divide(output_size, world_size) + _ = _initialize_affine_weight_cpu( + self.weight, + output_size, + input_size, + output_size_per_partition, + 0, + init_method=condition_init_method(config, init_method), + stride=1, + return_master_weight=False, + rank=rank, + world_size=world_size, + skip_set_tensor_parallel_attributes=True, + ) + if bias: + self.bias = Parameter( + torch.empty(output_size_per_partition, dtype=config.params_dtype) + ) + set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + with torch.no_grad(): + self.bias.zero_() + setattr(self.bias, 'allreduce', True) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 0, bias sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets + ) + + +class TERowParallelLinear(TELinear): + """ + Wrapper for the Transformer-Engine's `Linear` layer but specialized similar + to megatron's `RowParallelLinear` layer. + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + input_is_parallel: bool, + skip_bias_add: bool, + is_expert: bool, + tp_comm_buffer_name: str = None, + ): + if not input_is_parallel: + raise ValueError( + "Transformer Engine linear layers do not support input_is_parallel = False" + ) + + super().__init__( + input_size=input_size, + output_size=output_size, + parallel_mode="row", + config=config, + init_method=( + condition_init_method(config, init_method) + if not config.use_cpu_initialization + else lambda w: None + ), + bias=bias, + skip_bias_add=skip_bias_add, + skip_weight_param_allocation=False, # We don't currently use this for row parallel layers # pylint: disable=line-too-long + is_expert=is_expert, + tp_comm_buffer_name=tp_comm_buffer_name, + ) + if config.use_cpu_initialization: + if is_expert: + world_size = get_expert_tensor_parallel_world_size() + rank = get_expert_tensor_parallel_rank() + else: + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + input_size_per_partition = divide(input_size, world_size) + self.master_weight = _initialize_affine_weight_cpu( + self.weight, + output_size, + input_size, + input_size_per_partition, + 1, + init_method=condition_init_method(config, init_method), + stride=1, + return_master_weight=False, + params_dtype=config.params_dtype, + rank=rank, + world_size=world_size, + skip_set_tensor_parallel_attributes=True, + ) + if bias: + self.bias = Parameter(torch.empty(output_size, dtype=config.params_dtype)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + setattr(self.bias, 'allreduce', True) + setattr(self.bias, 'sequence_parallel', config.sequence_parallel) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 1, bias not sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, prefix, {'weight': 1}, sharded_offsets + ) + + +class TEDotProductAttention(te.pytorch.DotProductAttention): + """ + Wrapper for the Transformer-Engine's `DotProductAttention` layer that also + has "flash attention" enabled. + + Note that if Megatron's parallel_state has not been initialized yet, the + tp_group and cp_group passed to TE will be None and must be set later + via set_tensor_parallel_group() and set_context_parallel_group(). + """ + + cp_stream: torch.cuda.Stream = None + + def __init__( + self, + config: TransformerConfig, + layer_number: int, + attn_mask_type: AttnMaskType, + attention_type: str, + attention_dropout: float = None, + softmax_scale: float = None, + k_channels: int = None, + v_channels: int = None, + cp_comm_type: str = "p2p", + ): + self.config = config + self.te_forward_mask_type = False + self.qkv_format: str = 'sbhd' + + if self.config.apply_query_key_layer_scaling != bool( + int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0')) + ): + raise ValueError( + f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} " + f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is " + f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support " + f"setting query key layer scaling via argument, so these two must match." + ) + + extra_kwargs = {} + if is_te_min_version("0.11.0"): + extra_kwargs["num_gqa_groups"] = self.config.num_query_groups + elif self.config.num_query_groups != self.config.num_attention_heads: + raise ValueError( + f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, " + f"use a newer version of Transformer Engine. " + f"(num_query_groups ({self.config.num_query_groups}) != " + f"num_attention_heads ({self.config.num_attention_heads}))" + ) + + if is_te_min_version("0.10.0"): + extra_kwargs["attention_type"] = attention_type + # older version don't need attention_type + + if is_te_min_version("0.12.0", check_equality=False): + self.te_forward_mask_type = True + + # This check is important as CP config can be disabled while having a valid CP group + # Example - Disabling CP for encoder while a valid CP group exists for decoder + if self.config.context_parallel_size > 1: + assert is_te_min_version( + "1.0.0" + ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!" + if getattr(TEDotProductAttention, "cp_stream") is None: + TEDotProductAttention.cp_stream = torch.cuda.Stream() + extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False) + extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks( + check_initialized=False + ) + extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream + if is_te_min_version("1.10.0"): + if cp_comm_type is None: + extra_kwargs["cp_comm_type"] = "p2p" + elif cp_comm_type == "a2a+p2p": + assert is_te_min_version("1.12.0"), ( + f"Transformer-Engine v{get_te_version()} must be >= 1.12.0 to support" + "hierarchical cp commucation." + ) + extra_kwargs["cp_comm_type"] = "a2a+p2p" + extra_kwargs["cp_group"] = get_hierarchical_context_parallel_groups( + check_initialized=False + ) + else: + extra_kwargs["cp_comm_type"] = cp_comm_type + + if self.config.deterministic_mode: + if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0: + raise RuntimeError( + "deterministic_mode is on and we are using DotProductAttention from " + "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. " + f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}." + ) + + if config.window_size is not None: + # Check version + assert is_te_min_version("1.2.0"), ( + f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support" + "sliding window attention." + ) + extra_kwargs['window_size'] = config.window_size + + if is_te_min_version("1.10.0"): + # TE 1.10.0 introduces the ability to set the different k and v channels + kv_channels = ( + (k_channels, v_channels) + if k_channels is not None and v_channels is not None + else self.config.kv_channels + ) + extra_kwargs['softmax_scale'] = softmax_scale + else: + kv_channels = self.config.kv_channels + + super().__init__( + num_attention_heads=self.config.num_attention_heads, + kv_channels=kv_channels, + attention_dropout=( + self.config.attention_dropout if attention_dropout is None else attention_dropout + ), + attn_mask_type=attn_mask_type.name, + sequence_parallel=self.config.sequence_parallel, + tp_size=self.config.tensor_model_parallel_size, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + tp_group=get_tensor_model_parallel_group(check_initialized=False), + layer_number=layer_number, + **extra_kwargs, + ) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attention_mask: Tensor, + attn_mask_type: AttnMaskType, + attention_bias: Tensor = None, + packed_seq_params: PackedSeqParams = None, + ): + """Forward.""" + packed_seq_kwargs = ( + dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {} + ) + # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set + # after init + if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False): + self.qkv_format = 'bshd' + + qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format) + + if get_te_version() < PkgVersion("1.3.0"): + # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H + # copies (#555) + # These two arguments did not exist prior to 1.3.0 + packed_seq_kwargs.pop("max_seqlen_q", None) + packed_seq_kwargs.pop("max_seqlen_kv", None) + + if get_te_version() < PkgVersion("1.10.0"): + # TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted + # in each individual sequence in THD format dataset + # These two arguments did not exist prior to 1.8.0.Full support added in 1.10.0 (#1012) + packed_seq_kwargs.pop("cu_seqlens_q_padded", None) + packed_seq_kwargs.pop("cu_seqlens_kv_padded", None) + + # WAR for peak memory usage. + # See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388 + if self.config.apply_rope_fusion and qkv_format == 'bshd': + query, key, value = [x.contiguous().transpose(0, 1) for x in (query, key, value)] + # In PyTorch, the following two tensors are in fact the same: + # Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1) + # Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1) + # Stride for a dimension that is 1 has no meaning, so tensors created two different ways + # can have same shape but different strides. + # We unify them to the first one to pass the stride check in TE + if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride(): + value = value.as_strided(value.shape, key.stride()) + + attention_bias_kwargs = {} + if attention_bias is not None: + assert is_te_min_version("1.2.0"), ( + f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support" + "`attention_bias`." + ) + attention_bias_kwargs = dict( + core_attention_bias_type='post_scale_bias', core_attention_bias=attention_bias + ) + + if self.te_forward_mask_type: + if qkv_format == 'thd' and is_te_min_version("1.7.0"): + # thd format uses flash attention with cuDNN kernel which requires is_padding=True, + # so the only acceptable mask types are `padding_causal` and `padding`. These do not + # necessarily indicate there are padded tokens in the sequence. + if attn_mask_type == AttnMaskType.causal: + attn_mask_type = AttnMaskType.padding_causal + elif attn_mask_type == AttnMaskType.no_mask: + attn_mask_type = AttnMaskType.padding + core_attn_out = super().forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type.name, + **attention_bias_kwargs, + **packed_seq_kwargs, + ) + else: + core_attn_out = super().forward( + query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs + ) + + if self.config.apply_rope_fusion and qkv_format == 'bshd': + return core_attn_out.transpose(0, 1) + else: + return core_attn_out + + +if is_te_min_version("1.9.0.dev0"): + + class TEGroupedLinear(te.pytorch.GroupedLinear): + """ + Wrapper for the Transformer-Engine's `GroupedLinear` layer. + + Note that if Megatron's parallel_state has not been initialized + yet, the tp_group passed to TE will be None and must be set later + via set_tensor_parallel_group(). + """ + + def __init__( + self, + num_gemms: int, + input_size: int, + output_size: int, + *, + parallel_mode: str, + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + is_expert: bool = False, + tp_comm_buffer_name: str = None, + ): + self.config = config + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache + + extra_kwargs = _get_extra_te_kwargs(config) + extra_kwargs["ub_name"] = tp_comm_buffer_name + + self.expert_parallel = self.config.expert_model_parallel_size > 1 + if is_expert: + extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name() + + # The comms between TP and EP group is explicitly handled by MoE token dispatcher. + # So we disable comms by making TE agnostic of model parallel. + if is_expert: + tp_group = get_expert_tensor_parallel_group(check_initialized=False) + tp_size = get_expert_tensor_parallel_world_size() + else: + tp_group = get_tensor_model_parallel_group(check_initialized=False) + tp_size = get_tensor_model_parallel_world_size() + self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) + + if self.explicit_expert_comm: + if parallel_mode == "column": + output_size = divide(output_size, tp_size) + elif parallel_mode == "row": + input_size = divide(input_size, tp_size) + parallel_mode = None + tp_size = 1 + tp_group = None + + super().__init__( + num_gemms=num_gemms, + in_features=input_size, + out_features=output_size, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=tp_group, + tp_size=tp_size, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + init_method=condition_init_method(config, init_method), + bias=bias, + return_bias=self.te_return_bias, + parallel_mode=parallel_mode, + **extra_kwargs, + ) + + for param in self.parameters(): + setattr(param, 'allreduce', not (is_expert and self.expert_parallel)) + + def merge_extra_states( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """ + Merge multiple "_extra_state" into one. + """ + self.init_fp8_metadata(num_gemms=self.num_gemms) + fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration + + try: + state_list = [ + state_dict.pop(f"{prefix}_extra_state{i}") for i in range(1, self.num_gemms) + ] + except KeyError: + # "_extra_state{i}" only exists for dist-ckpt. Return for torch native ckpt. + return + + if not fp8_checkpoint: + return + state_list = [state_dict.pop(f"{prefix}_extra_state")] + state_list + state_list = [self._decode_extra_state(state) for state in state_list] + extra_fp8_variables = state_list[0]['extra_fp8_variables'] + extra_fp8_variables['num_gemms'] = self.num_gemms + extra_state = { + "scale_fwd": torch.cat( + [state['scale_fwd'].view(-1, 1) for state in state_list], dim=1 + ).view(-1), + "scale_inv_fwd": torch.cat( + [state['scale_inv_fwd'].view(-1, 1) for state in state_list], dim=1 + ).view(-1), + "amax_history_fwd": torch.cat( + [state['amax_history_fwd'].view(-1, 1) for state in state_list], dim=1 + ).view(self.fp8_meta["recipe"].amax_history_len, -1), + "scale_bwd": torch.cat( + [state['scale_bwd'].view(-1, 1) for state in state_list], dim=1 + ).view(-1), + "scale_inv_bwd": torch.cat( + [state['scale_inv_bwd'].view(-1, 1) for state in state_list], dim=1 + ).view(-1), + "amax_history_bwd": torch.cat( + [state['amax_history_bwd'].view(-1, 1) for state in state_list], dim=1 + ).view(self.fp8_meta["recipe"].amax_history_len, -1), + "extra_fp8_variables": extra_fp8_variables, + } + state_dict[f"{prefix}_extra_state"] = self._encode_extra_state(extra_state) + + self._register_load_state_dict_pre_hook(merge_extra_states, with_module=True) + + def forward(self, x, m_splits): + """Forward.""" + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) + out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + + # TE only returns a tuple when return_bias is True, otherwise + # it returns a single Tensor, we always want to return two + # values regardless of the arguments. + if self.te_return_bias: + return out + return out, None + + def _encode_extra_state(self, state): + state_serialized = io.BytesIO() + torch.save(state, state_serialized) + return state_serialized + + def _decode_extra_state(self, state): + if isinstance(state, torch.Tensor): + return pickle.loads(state.detach().cpu().numpy().tobytes()) + elif isinstance(state, io.BytesIO): + state.seek(0) + return torch.load(state, map_location="cuda") + else: + raise RuntimeError("Unsupported checkpoint format.") + + def _split_extra_state(self, state): + fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration + + if not fp8_checkpoint: + return [state] * self.num_gemms + + state = self._decode_extra_state(state) + extra_states = [] + extra_fp8_variables = state['extra_fp8_variables'] + extra_fp8_variables['num_gemms'] = 1 + for gemm_idx in range(self.num_gemms): + tmp_state = { + "scale_fwd": state['scale_fwd'].view(3, -1)[:, gemm_idx], + "scale_inv_fwd": state['scale_inv_fwd'].view(3, -1)[:, gemm_idx], + "amax_history_fwd": state['amax_history_fwd'].view( + self.fp8_meta["recipe"].amax_history_len, 3, -1 + )[:, :, gemm_idx], + "scale_bwd": state['scale_bwd'].view(2, -1)[:, gemm_idx], + "scale_inv_bwd": state['scale_inv_bwd'].view(2, -1)[:, gemm_idx], + "amax_history_bwd": state['amax_history_bwd'].view( + self.fp8_meta["recipe"].amax_history_len, 2, -1 + )[:, :, gemm_idx], + "extra_fp8_variables": extra_fp8_variables, + } + extra_states.append(self._encode_extra_state(tmp_state)) + return extra_states + + def _sharded_state_dict_grouped( + self, tp_axis_map, prefix='', sharded_offsets=(), metadata=None + ): + """ + prefix should be module_name to make keys identical to sequetial ones. + """ + sharded_state_dict = {} + full_state_dict = self.state_dict(prefix='', keep_vars=True) + num_global_experts = get_expert_model_parallel_world_size() * self.num_gemms + local_expert_indices_offset = get_expert_model_parallel_rank() * self.num_gemms + ep_axis = len(sharded_offsets) + extra_states = self._split_extra_state(full_state_dict['_extra_state']) + for gemm_idx in range(self.num_gemms): + state_dict = { + f'{gemm_idx}.weight': full_state_dict[f'weight{gemm_idx}'], + f'{gemm_idx}._extra_state': extra_states[gemm_idx], + } + if self.use_bias: + state_dict[f'{gemm_idx}.bias'] = full_state_dict[f'bias{gemm_idx}'] + sub_sd = make_sharded_tensors_for_checkpoint( + state_dict, + '', + tp_axis_map, + ( + *sharded_offsets, + (ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts), + ), + ) + # Remove expert layers indexing from sharded keys + replace_prefix_for_sharding(sub_sd, f'{gemm_idx}.', prefix) + sharded_state_dict.update( + { + f'{prefix}weight{gemm_idx}': sub_sd[f'{gemm_idx}.weight'], + f'{prefix}_extra_state{"" if gemm_idx == 0 else gemm_idx}': sub_sd[ + f'{gemm_idx}._extra_state' + ], + } + ) + if self.use_bias: + sharded_state_dict[f'{prefix}bias{gemm_idx}'] = sub_sd[f'{gemm_idx}.bias'] + # Adjust replica ids - replication along DP modulo EP + for k, sh_ten in sharded_state_dict.items(): + replica_id = sh_ten.replica_id + assert ( + len(replica_id) == 3 + ), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}' + sh_ten.replica_id = (*replica_id[:2], get_expert_data_parallel_rank()) + return sharded_state_dict + + class TEColumnParallelGroupedLinear(TEGroupedLinear): + """ + Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized + to column-parallel style. + """ + + def __init__( + self, + num_gemms: int, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + tp_comm_buffer_name: str = None, + ): + + super().__init__( + num_gemms=num_gemms, + input_size=input_size, + output_size=output_size, + parallel_mode="column", + config=config, + init_method=condition_init_method(config, init_method), + bias=bias, + skip_bias_add=skip_bias_add, + is_expert=is_expert, + tp_comm_buffer_name=tp_comm_buffer_name, + ) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """ + For each gemm, sharding along axis 0, bias sharded. + Assume sharded_offsets[-1] is the expert parallel offset. + """ + tp_axis_map = {} + for gemm_idx in range(self.num_gemms): + tp_axis_map.update({f'{gemm_idx}.weight': 0, f'{gemm_idx}.bias': 0}) + return super()._sharded_state_dict_grouped( + tp_axis_map, prefix, sharded_offsets, metadata + ) + + class TERowParallelGroupedLinear(TEGroupedLinear): + """ + Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized + to row-parallel style. + """ + + def __init__( + self, + num_gemms: int, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + tp_comm_buffer_name: str = None, + ): + + super().__init__( + num_gemms=num_gemms, + input_size=input_size, + output_size=output_size, + parallel_mode="row", + config=config, + init_method=condition_init_method(config, init_method), + bias=bias, + skip_bias_add=skip_bias_add, + is_expert=is_expert, + tp_comm_buffer_name=tp_comm_buffer_name, + ) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """ + For each gemm, sharding along axis 1, bias not sharded. + Assume sharded_offsets[-1] is the expert parallel offset. + """ + tp_axis_map = {f'{gemm_idx}.weight': 1 for gemm_idx in range(self.num_gemms)} + return super()._sharded_state_dict_grouped( + tp_axis_map, prefix, sharded_offsets, metadata + ) + +else: + + TEGroupedLinear = None + TEColumnParallelGroupedLinear = None + TERowParallelGroupedLinear = None + + +class TEDelayedScaling(te.common.recipe.DelayedScaling): + """ + Wrapper for the Transformer-Engine's `DelayedScaling` layer. + """ + + def __init__( + self, + config: ModelParallelConfig, + fp8_format: int, + override_linear_precision: tuple = (False, False, False), + ): + extra_kwargs = _get_extra_te_kwargs(config) + if is_te_min_version("1.6.0.dev0"): + extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention + extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention + if get_te_version() < PkgVersion("1.8.0"): + extra_kwargs["interval"] = config.fp8_interval + elif config.fp8_interval != 1: + warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.") + + super().__init__( + margin=config.fp8_margin, + fp8_format=fp8_format, + amax_compute_algo=config.fp8_amax_compute_algo, + amax_history_len=config.fp8_amax_history_len, + override_linear_precision=override_linear_precision, + **extra_kwargs, + ) + + +class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker): + """Wraps TransformerEngine's CudaRNGStatesTracker so that it is + interchangeable with Megatron's RNG tracker""" + + def is_initialized(self): + """Checks if the internal RNG state has been set wirth set_states().""" + return self._is_initialized + + def reset(self): + """Reset the internal RNG state.""" + super().reset() + self._is_initialized = False + + def set_states(self, states): + """Set the internal RNG state.""" + super().set_states(states) + self._is_initialized = True + + def add(self, name, seed): + """Track the rng state.""" + super().add(name, seed) + self._is_initialized = True + + +def te_checkpoint( + forward_func, + distribute_saved_activations, + get_rng_state_tracker, + tp_group, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, +): + """Checkpointing with Transformer-Engine.""" + from transformer_engine.pytorch.distributed import checkpoint + + if is_te_min_version("1.5.0"): + return checkpoint( + forward_func, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + distribute_saved_activations=distribute_saved_activations, + get_rng_state_tracker=get_rng_state_tracker, + tp_group=tp_group, + ) + else: + return checkpoint( + forward_func, + distribute_saved_activations, + get_rng_state_tracker, + tp_group, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + + +try: + + from transformer_engine.pytorch.attention import _SplitAlongDim + + SplitAlongDim = _SplitAlongDim.apply + +except ImportError: + + SplitAlongDim = None + +try: + + from transformer_engine.pytorch.cpu_offload import ( + get_cpu_offload_context as _get_cpu_offload_context, + ) + + def get_cpu_offload_context( + enabled, num_layers, model_layers, activation_offloading, weight_offloading + ): + """Get CPU offload context and sync function.""" + if is_te_min_version("1.10.0.dev0"): + context, sync_func = _get_cpu_offload_context( + enabled, num_layers, model_layers, activation_offloading, weight_offloading + ) + else: + context, sync_func = _get_cpu_offload_context( + enabled, num_layers, activation_offloading, weight_offloading + ) + + return context, sync_func + +except ImportError: + + get_cpu_offload_context = None + +try: + + from transformer_engine.pytorch.attention import FusedRoPEFunc + + def fused_apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + """Apply rotary positional embedding to input tensor T in `sbhd` format.""" + return FusedRoPEFunc.apply(t, freqs, "sbhd") + + def fused_apply_rotary_pos_emb_thd( + t: torch.Tensor, + cu_seqlens: torch.Tensor, + freqs: torch.Tensor, + cp_size: int = 1, + cp_rank: int = 0, + ) -> torch.Tensor: + """ + Apply rotary positional embedding to input tensor T in `thd` format with CP support. + """ + if is_te_min_version("1.11.0", check_equality=False): + return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens, cp_size, cp_rank) + else: + return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens) + +except ImportError: + + pass + +try: + + from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding # pylint: disable=unused-import + +except ImportError: + + Fp8Padding = None + Fp8Unpadding = None diff --git a/toolbox/Megatron-LM/patch/megatron/core/models/common/embeddings/rope_utils.py b/toolbox/Megatron-LM/patch/megatron/core/models/common/embeddings/rope_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a1dad801319072cd198502718326dbe579720e4 --- /dev/null +++ b/toolbox/Megatron-LM/patch/megatron/core/models/common/embeddings/rope_utils.py @@ -0,0 +1,260 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from megatron.core.transformer.transformer_config import TransformerConfig + +import logging + +import torch +from torch import Tensor + +from megatron.core import parallel_state +from megatron.core.utils import is_te_min_version + +logger = logging.getLogger(__name__) + +try: + from megatron.core.extensions.transformer_engine import ( + fused_apply_rotary_pos_emb, + fused_apply_rotary_pos_emb_thd, + ) + + HAVE_APPLY_ROPE_FUSION = True +except ImportError: + try: + from apex.transformer.functional import ( + fused_apply_rotary_pos_emb, + #fused_apply_rotary_pos_emb_thd, + ) + + HAVE_APPLY_ROPE_FUSION = True + except ImportError: + HAVE_APPLY_ROPE_FUSION = False + + +try: + from flash_attn.layers.rotary import apply_rotary_emb as apply_rotary_emb_flash +except ImportError: + apply_rotary_emb_flash = None + + +__all__ = ['apply_rotary_emb_flash'] + + +def get_pos_emb_on_this_cp_rank(pos_emb: Tensor, seq_dim: int) -> Tensor: + """Get the position embedding on the current context parallel rank. + + Args: + pos_emb (Tensor): Positional embedding tensor + seq_dim (int): Sequence dimension + """ + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + cp_idx = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True + ).cuda(non_blocking=True) + pos_emb = pos_emb.view( + *pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :] + ) + pos_emb = pos_emb.index_select(seq_dim, cp_idx) + pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :]) + return pos_emb + + +def _rotate_half(x: Tensor, rotary_interleaved: bool) -> Tensor: + """Change sign so the last dimension becomes [-odd, +even] + + Args: + x (Tensor): Input tensor + + Returns: + Tensor: Tensor rotated half + """ + if not rotary_interleaved: + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x_new = torch.stack((-x2, x1), dim=-1) + return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1) + + +def _apply_rotary_pos_emb_bshd( + t: Tensor, + freqs: Tensor, + rotary_interleaved: bool = False, + multi_latent_attention: bool = False, + mscale: float = 1.0, +) -> Tensor: + """Apply rotary positional embedding to input tensor T. + + check https://kexue.fm/archives/8265 for detailed formulas + + Args: + t (Tensor): Input tensor T is of shape [seq_length, ... , dim] + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim] + + Returns: + Tensor: The input tensor after applying RoPE + """ + rot_dim = freqs.shape[-1] + + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + if multi_latent_attention: + x1 = t[..., 0::2] + x2 = t[..., 1::2] + t = torch.cat((x1, x2), dim=-1) + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + cos_ = (torch.cos(freqs) * mscale).to(t.dtype) + sin_ = (torch.sin(freqs) * mscale).to(t.dtype) + + t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_) + return torch.cat((t, t_pass), dim=-1) + + +def _get_thd_freqs_on_this_cp_rank(cp_rank: int, cp_size: int, x: Tensor, freqs: Tensor) -> Tensor: + if cp_size > 1: + cp_seg = x.size(0) // 2 + full_seqlen = cp_size * x.size(0) + return torch.cat( + [ + freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg], + freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg], + ] + ) + else: + return freqs[: x.size(0)] + + +def _apply_rotary_pos_emb_thd( + t: Tensor, + cu_seqlens: Tensor, + freqs: Tensor, + rotary_interleaved: bool = False, + multi_latent_attention: bool = False, + mscale: float = 1.0, +) -> Tensor: + """A baseline implementation of applying RoPE for `thd` format. + + Args: + t (Tensor): Input tensor T is of shape [t, h, d] + cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, + with shape [b + 1] and dtype torch.int32. + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + + Returns: + Tensor: Shape [t, h, d]. The input tensor after applying RoPE. + """ + + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + cu_seqlens = cu_seqlens // cp_size + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + + return torch.cat( + [ + _apply_rotary_pos_emb_bshd( + x.unsqueeze(1), + _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs), + rotary_interleaved=rotary_interleaved, + multi_latent_attention=multi_latent_attention, + mscale=mscale, + ) + for x in torch.split(t, seqlens) + ] + ).squeeze(1) + + +def apply_rotary_pos_emb( + t: Tensor, + freqs: Tensor, + config: TransformerConfig, + cu_seqlens: Optional[Tensor] = None, + mscale: float = 1.0, +): + """ + Reroute to the appropriate apply_rotary_pos_emb function depending on + fused/unfused kernels, or bshd (conventional) / thd (packed seq) format + """ + + if config.apply_rope_fusion: + if cu_seqlens is None: + return fused_apply_rotary_pos_emb(t, freqs) + else: + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size > 1: + if not is_te_min_version("1.11.0", check_equality=False): + raise ValueError("Only TE >= 1.12 supports RoPE fusion for THD format with CP.") + return fused_apply_rotary_pos_emb_thd( + t, + cu_seqlens, + freqs, + cp_size=cp_size, + cp_rank=parallel_state.get_context_parallel_rank(), + ) + else: + return fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs) + else: + if cu_seqlens is None: + return _apply_rotary_pos_emb_bshd( + t, + freqs, + rotary_interleaved=config.rotary_interleaved, + multi_latent_attention=config.multi_latent_attention, + mscale=mscale, + ) + else: + return _apply_rotary_pos_emb_thd( + t, + cu_seqlens, + freqs, + rotary_interleaved=config.rotary_interleaved, + multi_latent_attention=config.multi_latent_attention, + mscale=mscale, + ) + + +def apply_rotary_pos_emb_with_cos_sin( + t: Tensor, cos: Tensor, sin: Tensor, rotary_interleaved: bool = False +) -> Tensor: + """ + This function applies rotary positional embedding to the target tensor t + using precomputed cos and sin of size (seq_len, d_rot / 2) + """ + cos = cos.to(t.dtype) + sin = sin.to(t.dtype) + + if apply_rotary_emb_flash is None: + # Combine cos and sin into freqs + freqs = torch.stack([cos, sin], dim=-1).flatten(start_dim=-2) + + # Expand freqs to match t's shape + while freqs.dim() < t.dim(): + freqs = freqs.unsqueeze(1) + freqs = freqs.expand(t.shape[:-1] + (-1,)) + + y = _apply_rotary_pos_emb_bshd( + t, + freqs, + rotary_interleaved=rotary_interleaved, + multi_latent_attention=False, + mscale=1.0, + ) + else: + # Use Flash Attention's optimized kernel for rotary embedding + t = t.permute(1, 0, 2, 3) + y = apply_rotary_emb_flash(t, cos, sin, rotary_interleaved) + y = y.permute(1, 0, 2, 3) + + return y diff --git a/toolbox/Megatron-LM/patch/megatron/core/models/multimodal/llava_model.py b/toolbox/Megatron-LM/patch/megatron/core/models/multimodal/llava_model.py new file mode 100644 index 0000000000000000000000000000000000000000..29e6b5e2d6df9b4166adebc9651f901b86760769 --- /dev/null +++ b/toolbox/Megatron-LM/patch/megatron/core/models/multimodal/llava_model.py @@ -0,0 +1,924 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. + +import logging +from collections import namedtuple +from functools import partial +from typing import List, Optional + +import torch + +from megatron.core import InferenceParams, tensor_parallel +from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.models.gpt import GPTModel +from megatron.core.models.vision.clip_vit_model import CLIPViTModel, get_num_image_embeddings +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import get_context_parallel_group, get_context_parallel_world_size +from megatron.core.transformer import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import get_batch_on_this_cp_rank, log_single_rank + +try: + import transformer_engine # pylint: disable=unused-import + from transformer_engine.pytorch.distributed import gather_along_first_dim + + from megatron.core.extensions.transformer_engine import TEDotProductAttention + from megatron.core.utils import is_te_min_version + + HAVE_TE = True +except: + HAVE_TE = False + if get_context_parallel_world_size() > 1: + raise RuntimeError("ContextParallelism requires TransformerEngine support, but not found.") + + +IGNORE_INDEX = -100 # ID for labels that should be ignored. +# Image token index can be tokenizer dependent so the default value does not work in all cases. +DEFAULT_IMAGE_TOKEN_INDEX = -200 +IMAGE_TOKEN = "" + + +# Note: This is under development and may be missing features. +class LLaVAModel(MegatronModule): + """LLaVA multi-modal model. + + Args: + language_transformer_config (TransformerConfig): Transformer config for the language model. + language_transformer_layer_spec (ModuleSpec): Language model spec. + language_vocab_size (int): Language model vocabulary size. + language_max_sequence_length (int): Language model maximum sequence length. + vision_transformer_config (TransformerConfig): Transformer config for the vision model. + vision_transformer_layer_spec (ModuleSpec): Vision model spec. + drop_vision_class_token (bool): Drop vision class token(s) before the language model. + vision_projection_config (TransformerConfig): Vision projection config. + vision_projection_layer_spec (ModuleSpec): Vision projection spec. + vision_projection_type (str): Type of the vision projection. Default: 2-layer MLP. + allow_missing_vision_projection_checkpoint (bool): Allow vision projection weights to be + missing when loading a checkpoint. Default False. + parallel_output (bool): Keep outputs split across tensor parallel ranks. + This is typically True for training and False for inference. + language_position_embedding_type (str): Language model position embedding type. + language_rotary_percent (float): RoPE percent. Defaults to 1.0. + pre_process (bool): Include embedding layer in the decoder (used with pipeline parallel). + post_process (bool): Include output layer in the decoder (used with pipeline parallel). + add_encoder (bool): Construct the encoder (used with pipeline parallel). + When we use pipelining, the encoder will live on only the first stage + add_decoder (bool): Construct the decoder (used with pipeline parallel). + When we use pipelining, the decoder will live on every stage after the first one. + img_h (int): Input image height. + img_w (int): Input image width. + patch_dim (int): The size of each image patch side. + language_rotary_base (int): RoPE base. + language_rope_scaling (bool): Toggle RoPE scaling. + image_token_index (int): Token ID for image token such as . + pixel_shuffle (bool): Enable pixel shuffle. + tile_tags (list): Optional tile tags. + """ + + def __init__( + self, + language_transformer_config: TransformerConfig, + language_transformer_layer_spec: ModuleSpec, + language_vocab_size: int, + language_max_sequence_length: int, + vision_transformer_config: TransformerConfig, + vision_transformer_layer_spec: ModuleSpec, + drop_vision_class_token: bool, + vision_projection_config: TransformerConfig, + vision_projection_layer_spec: ModuleSpec, + vision_projection_type: str = "mlp", + allow_missing_vision_projection_checkpoint: bool = False, + parallel_output: bool = True, + language_position_embedding_type: str = 'learned_absolute', + language_rotary_percent: float = 1.0, + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + img_h: int = 336, + img_w: int = 336, + patch_dim: int = 14, + language_rotary_base: int = 10000, + language_rope_scaling: bool = False, + image_token_index: int = DEFAULT_IMAGE_TOKEN_INDEX, + pixel_shuffle: bool = False, + tile_tags: Optional[list] = None, + ) -> None: + super().__init__(config=language_transformer_config) + + if has_config_logger_enabled(language_transformer_config): + log_config_to_disk(language_transformer_config, locals(), prefix=type(self).__name__) + + log_single_rank( + logging.getLogger(__name__), + logging.WARNING, + "LLaVA is work in progress. Features are missing and methods can change.", + ) + + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder + + self.encoder_hidden_state = None + self.vision_model = None + self.vision_projection = None + self.language_model = None + + self.sequence_parallel_lm = language_transformer_config.sequence_parallel + self.tp_comm_overlap_lm = language_transformer_config.tp_comm_overlap + self.context_parallel_lm = language_transformer_config.context_parallel_size + if self.sequence_parallel_lm or self.context_parallel_lm > 1: + assert ( + language_transformer_layer_spec.submodules.self_attention.submodules.core_attention + == TEDotProductAttention + and HAVE_TE + ), "Sequence/Context Parallelism is supported only with TE DotProductAttention." + if self.context_parallel_lm > 1: + assert is_te_min_version( + "1.10.0" + ), "Context Parallelism in LLaVA requires TE v1.10 or higher" + self.tensor_model_parallel_size_lm = language_transformer_config.tensor_model_parallel_size + + # This attribute is needed to check if an all-reduce is required + # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. + self.share_embeddings_and_output_weights = False + if self.add_decoder: + self.language_model = GPTModel( + config=language_transformer_config, + transformer_layer_spec=language_transformer_layer_spec, + vocab_size=language_vocab_size, + max_sequence_length=language_max_sequence_length, + parallel_output=parallel_output, + position_embedding_type=language_position_embedding_type, + rotary_percent=language_rotary_percent, + pre_process=self.pre_process, + post_process=self.post_process, + rotary_base=language_rotary_base, + rope_scaling=language_rope_scaling, + scatter_embedding_sequence_parallel=False, + ) + self.share_embeddings_and_output_weights = ( + self.language_model.share_embeddings_and_output_weights + ) + self._language_max_sequence_length = language_max_sequence_length + self._language_is_pipeline_parallel = ( + language_transformer_config.pipeline_model_parallel_size > 1 + ) + + class_token_len = 1 + if self.add_encoder: + self._drop_vision_class_token = drop_vision_class_token + add_class_token = True + if vision_transformer_config.vision_model_type == "siglip": + class_token_len = 0 + add_class_token = False + error_msg = ( + "Siglip does not support vision class token, " + "set disable-vision-class-token to False." + ) + assert not self._drop_vision_class_token, error_msg + self.vision_model = CLIPViTModel( + vision_transformer_config, + vision_transformer_layer_spec, + img_h=img_h, + img_w=img_w, + class_token_len=class_token_len, + patch_dim=patch_dim, + model_subtype=vision_transformer_config.vision_model_type, + add_class_token=add_class_token, + ) + + vision_projection_input_size = vision_transformer_config.hidden_size + vision_projection_input_size *= 4 if pixel_shuffle else 1 + + # Map (intermediate) vision model outputs to the language model input dimension. + self.vision_projection = MultimodalProjector( + vision_projection_config, + vision_projection_layer_spec, + vision_projection_type, + vision_projection_input_size, + ) + # Ignore missing weights for the vision projection during checkpoint loading. + # This should be disabled by default but can be enabled if your checkpoint contains + # pretrained vision and language models but not the projection from vision model + # outputs to language model inputs. + if allow_missing_vision_projection_checkpoint: + vision_projection_param_names = [ + f"vision_projection.{name}" + for name in self.vision_projection.state_dict().keys() + ] + self.vision_projection.register_load_state_dict_post_hook( + partial(_load_state_dict_hook_ignore_param_names, vision_projection_param_names) + ) + + self._img_seq_len = get_num_image_embeddings( + img_h, + img_w, + patch_dim, + vision_transformer_config.vision_model_type, + drop_vision_class_token, + class_token_len, + pixel_shuffle, + tile_tags is not None, # Tile tags enabled/disabled. + ) + + self.image_token_index = image_token_index + self._pixel_shuffle = pixel_shuffle + self._tile_tags = tile_tags + + def shared_embedding_or_output_weight(self): + """This is a convenience method to surface the language model's word embeddings, which is + necessary for `finalize_model_grads._allreduce_word_embedding_grads`.""" + if self.add_decoder: + return self.language_model.shared_embedding_or_output_weight() + return None + + def set_input_tensor(self, input_tensor) -> None: + """Set model chunk input tensor.""" + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1, 'input_tensor should only be length 1 for llava' + + if self.add_encoder and self.add_decoder: + self.vision_model.set_input_tensor(input_tensor[0]) + elif self.add_encoder: + self.vision_model.set_input_tensor(input_tensor[0]) + elif self.pre_process: + self.encoder_hidden_state = input_tensor[0] + else: + self.language_model.set_input_tensor(input_tensor[0]) + + def freeze( + self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool + ): + """Freeze model modules. + + Make specific modules non-trainable by setting requires_grad to False. + + Args: + freeze_language_model (bool): Freeze the language model module. + freeze_vision_model (bool): Freeze the vision model module. + freeze_vision_projection (bool): Freeze the vision projection module. + """ + modules = [] + if freeze_language_model and self.language_model is not None: + modules.append(self.language_model) + if freeze_vision_model and self.vision_model is not None: + modules.append(self.vision_model) + if freeze_vision_projection and self.vision_projection is not None: + modules.append(self.vision_projection) + + for module in modules: + for param in module.parameters(): + param.requires_grad = False + + def _preprocess_data( + self, + image_embeddings, + language_embeddings, + input_ids, + loss_mask, + labels, + use_inference_kv_cache, + inference_params, + image_token_index, + num_image_tiles, + image_token_mask=None, + ): + """Preprocess input data before input to language model. + + This function is adopted from + https://github.com/huggingface/transformers/blob/85817d98fb60977c97e3014196a462b732d2ed1a/src/transformers/models/llava_next/modeling_llava_next.py#L409 + for our input data conventions. + + image_token_index = -200 indicates the image position in the input_ids = [0, 1, -200, 2, 3] + and labels = [1, -200, 2, 3, 4], for example. + We want to replace the image position (-200) with image_embeddings and return the following: + - final_embeddings = [0, 1, image_embeddings, 2, 3], + - final_labels = [1, -100, 2, 3, 4] + - final_loss_mask = [1, 0, 0, 1, 1] + + This function handles samples without images (text-only sample). It also handles samples + with images that are split into multiples tiles. + + If pipeline parallelism is not used, then self.pre_process and self.post_process + are both True and we update both input embeddings, labels and loss masks (if available). + + If pipeline parallelism is used, then we do the following + - the first language model chunk has self.pre_process = True and + self.post_process = False. We update input embeddings. + - the middle language model chunk(s) has self.pre_process = False and + self.post_process = False. We don't need to update anything. + - the last language model chunk has self.pre_process = False and + self.post_process = True. We update labels and loss mask. + + TODO: This function should adjust the attention mask too. + Currently, we assume the language model uses a causal mask. + + Returns: + final_embedding (torch.Tensor): image and text embeddings [combined_seq_len, b, h]. + final_labels (torch.Tensor): labels for image and text positions [b, combined_seq_len]. + final_loss_mask (torch.Tensor): loss mask [b, combined_seq_len]. + """ + assert self.add_decoder, "input text preprocessing is only needed for the language model" + + # No pre- or postprocessing needed. + # With pipeline parallel > 2, this means a chunk in the middle of the model. + if not self.pre_process and not self.post_process: + return None, None, None + + # If using the inference KV cache, the image tokens are already computed. + if use_inference_kv_cache: + return language_embeddings, loss_mask, labels + + img_seq_len = self._img_seq_len + batch_size, text_seq_len = input_ids.shape + # input_ids seq len is expected to be sharded by CP size + if self.context_parallel_lm: + text_seq_len *= self.context_parallel_lm + + has_labels = labels is not None + if has_labels: + assert ( + labels.shape == loss_mask.shape + ), f"mismatching labels shape {labels.shape} and loss mask shape {loss_mask.shape}" + + # Create indices for new text and label positions. + with torch.no_grad(): + if image_token_mask is None: + assert ( + self.context_parallel_lm <= 1 + ), "image_token_mask cannot be inferred from input_ids if using \ + Context Parallelism. Please provide in forward_step" + image_token_mask = input_ids == image_token_index + num_images_per_sample = torch.sum(image_token_mask, dim=-1) + + # Number of tiles per sample. + num_image_tiles_batch = num_image_tiles.split(num_images_per_sample.tolist(), dim=0) + num_image_tiles_batch = torch.tensor( + [x.sum() for x in num_image_tiles_batch], device=input_ids.device + ) + + # Sequence length for each sample is the image sequence length multiplied by + # the number of tiles for that image, minus image token indices, + # plus text sequence length. + seq_lens = num_image_tiles_batch * img_seq_len - num_images_per_sample + text_seq_len + max_seq_len = seq_lens.max() + # Pipeline parallel expects fixed input size. Check if we need to pad. + if ( + self._language_is_pipeline_parallel + and max_seq_len < self._language_max_sequence_length + and inference_params is None + ): + max_seq_len = self._language_max_sequence_length + + batch_indices, non_image_indices = torch.where(image_token_mask != True) + + # New position ids for the text tokens, shifted by the image sequence length. + # E.g. for input_ids = [-200, 1, 2, 3] and img_seq_len = 576, we get + # new_position_ids = [576, 577, 578, 579]. text_position_ids are then [577, 578, 579]. + image_token_mask_lens = image_token_mask.int().clone() + # -1 is for the removed image token index. + image_token_mask_lens[image_token_mask] = num_image_tiles * img_seq_len - 1 + # +1 is needed here for the cumulative sum. -1 is adjusting for zero-based indexing. + new_position_ids = torch.cumsum((image_token_mask_lens + 1), dim=-1) - 1 + text_position_ids = new_position_ids[batch_indices, non_image_indices] + + # Labels are shifted to left by one. + # So, shift text position ids and non-image indices to left by one. + if has_labels: + label_text_position_ids = text_position_ids - 1 + valid_label_text_position_ids = label_text_position_ids >= 0 + label_text_position_ids = label_text_position_ids[valid_label_text_position_ids] + + label_batch_indices = batch_indices[valid_label_text_position_ids] + + label_non_image_indices = non_image_indices - 1 + valid_label_non_image_indices = label_non_image_indices >= 0 + label_non_image_indices = label_non_image_indices[valid_label_non_image_indices] + + # Create a mask for the image embedding positions. + images_mask = torch.full( + (batch_size, max_seq_len), True, dtype=torch.bool, device=input_ids.device + ) + # No images in the text positions. + images_mask[batch_indices, text_position_ids] = False + # Samples can have different amount of images tokens. + # new_position_ids[:, -1] gives the last text position id for each sample. + # Padding is needed when the number of image tokens differs. + first_padding_idx = new_position_ids[:, -1] + 1 + images_mask[ + torch.arange(max_seq_len, device=first_padding_idx.device).repeat(batch_size, 1) + >= first_padding_idx.unsqueeze(1) + ] = False + + # Create the final input embedding (if this is the first language model stage). + final_embedding = None + if self.pre_process: + embed_dim = language_embeddings.shape[-1] + final_embedding = torch.zeros( + batch_size, + max_seq_len, + embed_dim, + dtype=language_embeddings.dtype, + device=language_embeddings.device, + ) + + # Put text embeddings to the text positions in the result tensor. + final_embedding[batch_indices, text_position_ids] = language_embeddings[ + batch_indices, non_image_indices + ] + + # Put image embeddings to image positions. + final_embedding[images_mask] = ( + image_embeddings.permute(1, 0, 2).reshape(-1, embed_dim).contiguous() + ) + + # Create the final labels and loss mask (if this is the last language model stage). + final_labels, final_loss_mask = None, None + if self.post_process and has_labels: + final_labels = torch.full( + (batch_size, max_seq_len), IGNORE_INDEX, dtype=labels.dtype, device=labels.device + ) + final_loss_mask = torch.full( + (batch_size, max_seq_len), 0, dtype=loss_mask.dtype, device=loss_mask.device + ) + + # Put text labels and loss mask to the text positions. + final_labels[label_batch_indices, label_text_position_ids] = labels[ + label_batch_indices, label_non_image_indices + ] + + final_loss_mask[batch_indices, text_position_ids] = loss_mask[ + batch_indices, non_image_indices + ] + + # For labels, pick the last label index that got dropped by the shift to left. + label_extra_text_position_ids = seq_lens - 1 + batch_range = torch.arange(len(label_extra_text_position_ids)) + final_labels[batch_range, label_extra_text_position_ids] = labels[batch_range, -1] + + # Loss mask the image positions. + final_loss_mask[images_mask] = 0 + + # Loss mask last text position just before an image + # so that text token does not need to predict the first image token. + batch_image_indices, image_indices = torch.where(image_token_mask) + # Indices just before image tokens. If it's -1, skip it. + before_image_indices = image_indices - 1 + valid = before_image_indices >= 0 + valid_batch_image_indices = batch_image_indices[valid] + valid_before_image_indices = before_image_indices[valid] + # Map those indices those position ids. + valid_before_image_indices = new_position_ids[ + valid_batch_image_indices, valid_before_image_indices + ] + + final_loss_mask[valid_batch_image_indices, valid_before_image_indices] = 0 + + if final_embedding is not None and final_labels is not None: + assert ( + final_embedding.shape[:2] == final_labels.shape == final_loss_mask.shape + ), "unexpected shapes after data preprocessing" + + if final_embedding is not None: + # Truncate if exceeding the language model's max sequence length. + if final_embedding.shape[1] > self._language_max_sequence_length: + final_embedding = final_embedding[:, : self._language_max_sequence_length] + # Transpose to [s,b,h] if not using CP because CP Sharding expects seq in dim=1 + if self.context_parallel_lm == 1: + final_embedding = final_embedding.transpose(1, 0).contiguous() + + truncate_labels = ( + final_labels is not None and final_labels.shape[1] > self._language_max_sequence_length + ) + if truncate_labels: + final_labels = final_labels[:, : self._language_max_sequence_length] + final_loss_mask = final_loss_mask[:, : self._language_max_sequence_length] + + return final_embedding, final_labels, final_loss_mask + + def _process_embedding_token_parallel( + self, combined_embeddings, new_labels, new_loss_mask, packed_seq_params + ): + """Processes the input data for model parallelism support. + + When using sequence parallelism (SP) or context parallelism (CP), the sequence is sharded + across different GPUs. This function helps ensure that the sharding is done correctly by + 1. Calculates `padding_factor` which determines based on how many chunks we expect to shard + the sequence + 2. Calculates and pads the inputs to necessary length to ensure equal sized chunks + 3. Creates/Modifies PackedSeqParams which helps mask padded tokens during calculations + 4. Performs any layout changes if necessary + 5. Distributes the sequence across GPUs for SP and CP + + Context Parallelism is a feature that helps improve memory efficiency for + long sequence training by distributing sequence across CP ranks. + It requires token length to be divisible by (CP size *2) to ensure proper load balance. + Please refer to `get_batch_on_this_cp_rank` function for more details. + + Sequence Parallelism is a feature that helps improve memory efficiency for + long sequence training by distributing sequence across TP ranks. + It requires token length to be divisible by TP size. + + Returns: + combined_embeddings (torch.Tensor): image and text embeddings combined and distributed. + new_labels (torch.Tensor): Distributed labels for image and text positions. + new_loss_mask (torch.Tensor): Distributed loss mask. + packed_seq_params (PackedSeqParams): Dict with padded token information. + + """ + # combined_embeddings - `s,b,h` if not using CP, `b,s,h` if using CP + batch_size = ( + combined_embeddings.shape[0] + if self.context_parallel_lm > 1 + else combined_embeddings.shape[1] + ) + seq_dim = 1 if self.context_parallel_lm > 1 else 0 + + padding_mask_type = 'padding' in str( + self.language_model.transformer_layer_spec.submodules.self_attention.params.get( + 'attn_mask_type', '' + ) + ) + if self.sequence_parallel_lm and self.tp_comm_overlap_lm: + assert ( + combined_embeddings.shape[seq_dim] == self._language_max_sequence_length + ) or padding_mask_type, f"TP Comm overlap either requires Vision+Text token length \ + == language_max_sequence_length or mask type to be set to padding/padding_causal" + + if padding_mask_type: + # Calculate the padded sequence length needed to support SP and CP + # SP and CP are used to distributed the sequence across GPUs to improve + # memory efficiency and enable very long context training. + # To distribute workload equally, we need to ensure that the sequence is + # divisible by the appropriate padding factor calculated below. + padding_factor = None + padded_seq_len = None + mp_padding_needed = 0 + if self.context_parallel_lm > 1 and self.sequence_parallel_lm: + padding_factor = self.tensor_model_parallel_size_lm * self.context_parallel_lm * 2 + elif self.context_parallel_lm > 1: + padding_factor = self.context_parallel_lm * 2 + elif self.sequence_parallel_lm: + padding_factor = self.tensor_model_parallel_size_lm + + padded_seq_len = int( + (combined_embeddings.shape[seq_dim] + (padding_factor - 1)) + // padding_factor + * padding_factor + ) + + assert ( + padded_seq_len <= self._language_max_sequence_length + ), f"Sequence length after padding {padded_seq_len} for SP/CP has exceeded \ + language_max_sequence_length. Ensure language_max_sequence_length is \ + divisible by SP/CP factor: {padding_factor}" + + if self.sequence_parallel_lm and self.tp_comm_overlap_lm: + # TP Comm overlap initializes the user buffer shape used for communication + # at the beginning of training run and the same shape is expected to be + # used throughout the training. + # Pad to language_max_sequence_length to use TP Comm overlap. + assert ( + self._language_max_sequence_length % padding_factor == 0 + ), f"TP Comm overlap uses language_max_sequence_length \ + which needs to be divisible by SP/CP factor {padding_factor}" + padded_seq_len = self._language_max_sequence_length + + assert ( + packed_seq_params is not None + ), "Please provide PackedSeqParams dict when using SP or CP with padding" + valid_seqlens = packed_seq_params.cu_seqlens_q[1:] - packed_seq_params.cu_seqlens_q[:-1] + valid_seq_len = max(valid_seqlens) + assert ( + padded_seq_len >= valid_seq_len + ), f"Padded Seq Len calculated for model parallelism: {padded_seq_len} \ + is shorter than expected valid token len {valid_seq_len} provided." + + mp_padding_needed = padded_seq_len - combined_embeddings.shape[seq_dim] + if mp_padding_needed > 0: + new_labels = torch.nn.functional.pad( + new_labels, (0, mp_padding_needed), value=IGNORE_INDEX + ) + new_loss_mask = torch.nn.functional.pad(new_loss_mask, (0, mp_padding_needed)) + if self.context_parallel_lm > 1: + combined_embeddings = torch.nn.functional.pad( + combined_embeddings, (0, 0, 0, mp_padding_needed) + ) + else: + combined_embeddings = torch.nn.functional.pad( + combined_embeddings, (0, 0, 0, 0, 0, mp_padding_needed) + ) + + # Update PackedSeqParams if padding needed beyond user provided PackedSeqParams + packed_seq_params.max_seqlen_q = padded_seq_len + packed_seq_params.max_seqlen_kv = padded_seq_len + cu_seqlens_padded = None + # We need cu_seqlens_q_padded/cu_seqlens_kv_padded when doing + # CP+Padding to support accurate Attention with THD format. + if self.context_parallel_lm > 1: + cu_seqlens_padded = torch.arange( + 0, + (batch_size + 1) * (padded_seq_len), + step=(padded_seq_len), + dtype=torch.int32, + device=combined_embeddings.device, + ) + packed_seq_params.cu_seqlens_q_padded = cu_seqlens_padded + packed_seq_params.cu_seqlens_kv_padded = cu_seqlens_padded + packed_seq_params.qkv_format = 'thd' + else: + packed_seq_params.qkv_format = 'sbhd' + + if self.context_parallel_lm > 1: + # Distribute sequence across CP ranks + batch = get_batch_on_this_cp_rank( + { + "combined_embeddings": combined_embeddings, + "new_labels": new_labels, + "new_loss_mask": new_loss_mask, + } + ) + + combined_embeddings = batch["combined_embeddings"] # [B, S/CP, H] + new_labels = batch["new_labels"] + new_loss_mask = batch["new_loss_mask"] + + if getattr(packed_seq_params, 'qkv_format', None) == 'thd': + # If PackedSeqParams requires THD format, + # reshape embedding from [B,S,H] to [T,1,H] where T=B*S + combined_embeddings = ( + combined_embeddings.contiguous() + .view(combined_embeddings.shape[0] * combined_embeddings.shape[1], -1) + .unsqueeze(1) + ) + new_labels = new_labels.view(new_labels.shape[0] * new_labels.shape[1]).unsqueeze(0) + new_loss_mask = new_loss_mask.view( + new_loss_mask.shape[0] * new_loss_mask.shape[1] + ).unsqueeze(0) + else: + combined_embeddings = combined_embeddings.transpose( + 1, 0 + ).contiguous() # [B,S/CP,H] -> [S/CP,B,H] + + if self.sequence_parallel_lm: + combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region( + combined_embeddings + ) # [S/(CP*TP),B,H] + + return combined_embeddings, new_labels, new_loss_mask, packed_seq_params + + def _apply_tile_tagging(self, image_embeddings, num_image_tiles): + """Apply tile tagging. + + The image embeddings of multiple tiles are prepended with tile tags such as . + This implements the method used in NVLM https://arxiv.org/pdf/2409.11402. + + Args: + image_embeddings (torch.Tensor): [img_seq_len, num_tiles, h_language]. + num_image_tiles (torch.Tensor): Number of tiles for each input image [num_images]. + + Returns: + torch.Tensor: Tile tags prepended to image embeddings. + [tile_seq_len (=5) + img_seq_len, num_tiles, h_language] + """ + assert ( + num_image_tiles.shape[0] == 1 and len(num_image_tiles) == 1 + ), "multiple input images are not supported yet." + + num_tiles = num_image_tiles[0].item() + tile_tags = self._tile_tags[: num_tiles - 1] + [self._tile_tags[-1]] + + # [num_tiles, tile_seq_len (=5)] + tile_tag_input_ids = torch.tensor( + tile_tags, dtype=torch.int64, device=num_image_tiles.device + ) + + # [tile_seq_len, num_tiles, h_language] + tile_tag_embeds = self.language_model.embedding(tile_tag_input_ids, position_ids=None) + + # [num_tiles, dim] should be the same same + assert tile_tag_embeds.shape[1:] == image_embeddings.shape[1:] + + image_embeddings = torch.cat([tile_tag_embeds, image_embeddings]) + + return image_embeddings # [tile_seq_len + img_seq_len, num_tiles, h_language] + + def forward( + self, + images: torch.Tensor, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + labels: Optional[torch.Tensor] = None, + loss_mask: Optional[torch.Tensor] = None, + inference_params: Optional[InferenceParams] = None, + num_image_tiles: Optional[List[int]] = None, + image_token_index: Optional[int] = None, + runtime_gather_output: Optional[bool] = None, + image_token_mask: Optional[torch.Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + ) -> torch.Tensor: + """Forward function of the LLaVA model. + + Args: + images (torch.Tensor): input images of shape [num_tiles, img_h, img_w]. + num_tiles means the number of image tiles in this batch. + num_tiles = 0 if the batch doesn't contain images. + input_ids (torch.Tensor): input text ids [batch, text_seq_len]. + position_ids (torch.Tensor): input text position ids [batch, text_seq_len]. + attention_mask (torch.Tensor): Language model attention mask + [batch, 1, 1, combined_seq_len]. NOTE: attention_mask is typically None and + attn_mask_type in layer specs determines the attention mask used. + labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. + loss_mask (torch.Tensor): Text loss mask [batch, text_seq_len]. + inference_params (InferenceParams): Inference-time parameters including KV cache. + num_image_tiles (list of int): Number of tiles per image. Default 1 tile per image. + image_token_index (int): ID for input images. Default None means `image_token_index` + arg in the constructor will be used. + runtime_gather_output (bool): Gather output at runtime. Default None means + `parallel_output` arg in the constructor will be used. + image_token_mask (torch.Tensor): Tensor indicating the location of + image token index in input_ids. + packed_seq_params (PackedSeqParams): 1) If using sequence packing, must contain + subsample length information. 2) If using SP/CP with padding mask type, + must contain padded token information. + + Returns: + output (torch.Tensor): Loss of shape [b, s] if labels are provided, + otherwise logits of shape [b, s, vocab_size]. + loss_mask (torch.Tensor): Loss mask expanded to combined sequence length. Shape [b, s]. + """ + use_inference_kv_cache = ( + inference_params is not None + and "image_tokens_count" in inference_params.key_value_memory_dict + ) + has_images = images is not None and images.shape[0] > 0 + + # If running inference, we can skip image token computation + # if they were computed already earlier for this sample. + if use_inference_kv_cache: + image_embeddings = None + elif self.add_encoder and not has_images: + # If no images provided, use an empty image embeddings tensor. + image_embeddings = torch.tensor([], dtype=images.dtype, device=images.device).reshape( + 0, 0, 0 + ) + elif self.add_encoder and has_images: + image_embeddings = self.vision_model(images) # [num_tiles, img_seq_len, h_vision] + if self._drop_vision_class_token: + image_embeddings = image_embeddings[:, self.vision_model.class_token_len :, :] + + if self._pixel_shuffle: + image_embeddings = pixel_shuffle( + image_embeddings + ) # [num_tiles, img_seq_len_shuffled, h_vision_shuffled] + + # contiguous() required as `permute` can sparsify the tensor and this breaks pipelining + image_embeddings = image_embeddings.permute( + 1, 0, 2 + ).contiguous() # [img_seq_len, num_tiles, h_vision] + + # map vision model output size to language model input size. + image_embeddings = self.vision_projection( + image_embeddings + ) # [img_seq_len, num_tiles, h_language] + + # Apply tile tagging if enabled and an image token is present. + if self._tile_tags is not None and torch.any(input_ids == self.image_token_index): + image_embeddings = self._apply_tile_tagging(image_embeddings, num_image_tiles) + + # TODO: Support batched inference. + # In inference, the language model KV cache will be updated for image token positions. + # Store the image tokens sequence length to be used as an offset to the KV cache later. + if inference_params is not None: + inference_params.key_value_memory_dict["image_tokens_count"] = ( + image_embeddings.shape[0] * image_embeddings.shape[1] + ) + else: + image_embeddings = self.encoder_hidden_state + + if not self.add_decoder: + return image_embeddings, loss_mask + + language_embeddings = None + if self.pre_process: + input_ids_text = input_ids.clone() + input_ids_text[input_ids_text == self.image_token_index] = 0 + # Note: This adds absolute position embedding but not RoPE. + # Each image is counted as one position. + # RoPE is added in language_model forward. Each image embedding is one position. + language_embeddings = self.language_model.embedding( + input_ids=input_ids_text, position_ids=position_ids + ) # [text_seq_len, b, h_language] + # Gather the language embeddings back. We need the full embedding to insert + # image embeddings and then scatter again to avoid load imbalance. + if self.context_parallel_lm > 1: + cp_group = get_context_parallel_group() + language_embeddings, _ = gather_along_first_dim(language_embeddings, cp_group) + + language_embeddings = language_embeddings.transpose( + 1, 0 + ).contiguous() # [b, text_seq_len, h_language] + + # Assume 1 tile per image if the number of tiles is not provided. + if num_image_tiles is None: + num_image_tiles = torch.ones(images.shape[0], dtype=torch.int, device=input_ids.device) + + combined_embeddings, new_labels, new_loss_mask = self._preprocess_data( + image_embeddings, + language_embeddings, + input_ids, + loss_mask, + labels, + use_inference_kv_cache, + inference_params, + image_token_index if image_token_index is not None else self.image_token_index, + num_image_tiles, + image_token_mask, + ) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len] + + if self.context_parallel_lm > 1 or self.sequence_parallel_lm: + combined_embeddings, new_labels, new_loss_mask, packed_seq_params = ( + self._process_embedding_token_parallel( + combined_embeddings, new_labels, new_loss_mask, packed_seq_params + ) + ) + + output = self.language_model( + input_ids=None, + position_ids=None, + attention_mask=attention_mask, + decoder_input=combined_embeddings, + labels=new_labels, + inference_params=inference_params, + runtime_gather_output=runtime_gather_output, + packed_seq_params=packed_seq_params, + ) + + return output, new_loss_mask + + +def _load_state_dict_hook_ignore_param_names( + param_names: List[str], module: torch.nn.Module, incompatible_keys: namedtuple +): + """Hook to ignore missing keys during checkpoint loading. + + By default, this should not be used to avoid accidentally missing weights in checkpoint loading. + + Example use case: Use this if you want to load a checkpoint that contains vision and language + model weights but not the vision projection weights. + + Args: + param_names (list str): Parameter names allowed to be missing when calling load_state_dict. + module (torch.nn.Module): The torch module this hook applies to. Required by the torch API. + incompatible_keys (namedtuple): Namedtuple with fields missing_keys and unexpected_keys, + which collect the missing and unexpected keys, respectively. + """ + for param_name in param_names: + if param_name in incompatible_keys.missing_keys: + logging.getLogger(__name__).warning( + f"{param_name} being removed from incompatible_keys.missing_keys in LlavaModel" + ) + incompatible_keys.missing_keys.remove(param_name) + + +# pylint: disable-next=line-too-long +# Based on https://github.com/OpenGVLab/InternVL/blob/c7c5af1a8930b4862afe8ed14672307082ef61fa/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py#L218 +# Copyright (c) 2023 OpenGVLab. +def pixel_shuffle(x, scale_factor=0.5, version=2): + """Pixel shuffle based on InternVL but adapted for our use case. + + Args: + x (torch.Tensor): Vision model outputs [num_tiles, img_seq_len, h_vision] + version (int): Implementation version. + + Returns: + Shuffled vision model outputs [num_tiles, (sq ** 2) * (scale ** 2), h_vision / (scale ** 2)] + """ + h = w = int(x.shape[1] ** 0.5) # sq + x = x.reshape(x.shape[0], h, w, -1) # [num_tiles, sq, sq, h_vision] + + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) + x = x.view( + n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)) + ) + + if version == 2: + x = x.permute(0, 2, 1, 3).contiguous() + + x = x.reshape(x.shape[0], -1, x.shape[-1]) + + return x diff --git a/toolbox/Megatron-LM/patch/megatron/core/transformer/transformer_block.py b/toolbox/Megatron-LM/patch/megatron/core/transformer/transformer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc0919e88b83d6fedb33d5539de11f7ead16f35 --- /dev/null +++ b/toolbox/Megatron-LM/patch/megatron/core/transformer/transformer_block.py @@ -0,0 +1,617 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. + +from contextlib import nullcontext +from dataclasses import dataclass +from typing import List, Optional, Union + +import torch +from torch import Tensor + +from megatron.core import InferenceParams, parallel_state, tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding +from megatron.core.fusions.fused_layer_norm import FusedLayerNorm +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import BaseTransformerLayer, TransformerLayer +from megatron.core.transformer.utils import sharded_state_dict_default +from megatron.core.utils import is_te_min_version, make_viewless_tensor + +try: + from megatron.core.extensions.transformer_engine import ( + TEDelayedScaling, + TENorm, + get_cpu_offload_context, + te_checkpoint, + ) + + HAVE_TE = True + LayerNormImpl = TENorm +except ImportError: + HAVE_TE = False + get_cpu_offload_context = None + + try: + import apex # pylint: disable=unused-import + + LayerNormImpl = FusedLayerNorm + + except ImportError: + from megatron.core.transformer.torch_norm import WrappedTorchNorm + + LayerNormImpl = WrappedTorchNorm + + +def get_num_layers_to_build(config: TransformerConfig) -> int: + """ + Determine the number of transformer layers to build for the current pipeline stage. + Args: + config (TransformerConfig): Configuration object containing transformer model parameters. + + Returns: + int: The number of layers to be built for the current pipeline stage. + """ + if config.first_pipeline_num_layers is not None or config.last_pipeline_num_layers is not None: + assert ( + parallel_state.get_virtual_pipeline_model_parallel_world_size() is None + ), "Uneven number of layer not compatible with interleaved pipeline schedule" + + # Number of layers to distribute over rest of pipeline stages + layers_to_distribute = config.num_layers + # Number of pipeline stages left for distributing transformer layers + pipeline_stages_left = parallel_state.get_pipeline_model_parallel_world_size() + + if config.first_pipeline_num_layers is not None: + layers_to_distribute -= config.first_pipeline_num_layers + pipeline_stages_left -= 1 + if parallel_state.is_pipeline_first_stage(): + return config.first_pipeline_num_layers + + if config.last_pipeline_num_layers is not None: + layers_to_distribute -= config.last_pipeline_num_layers + pipeline_stages_left -= 1 + if parallel_state.is_pipeline_last_stage(): + return config.last_pipeline_num_layers + + assert ( + layers_to_distribute % pipeline_stages_left == 0 + ), "With uneven pipelineing the left over layers must be divisible by left over stages" + num_layers_per_pipeline_rank = layers_to_distribute // pipeline_stages_left + else: + pipeline_ranks = config.pipeline_model_parallel_size + num_layers_per_pipeline_rank = config.num_layers // pipeline_ranks + + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + # Interleaved pipeline parallelism: + # Number of layers in each model chunk is the number of layers in the stage, + # divided by the number of model chunks in a stage. + # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0] [2] [4] [6] + # Stage 1: [1] [3] [5] [7] + # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0, 1] [4, 5] + # Stage 1: [2, 3] [6, 7] + + vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + + num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size + + num_layers_to_build = num_layers_per_virtual_rank + + else: + # Non-interleaved pipeline parallelism: + # Each stage gets a contiguous set of layers. + + num_layers_to_build = num_layers_per_pipeline_rank + + return num_layers_to_build + + +@dataclass +class TransformerBlockSubmodules: + """ + Dataclass for specifying the submodules of a transformer block. + + This class defines the structure for configuring the layers and normalization + within a transformer block, allowing for flexible and customizable architecture designs. + + Args: + layer_specs (List[ModuleSpec], optional): A list of module specifications for + the layers within the transformer block. Each specification typically + defines a complete transformer layer (e.g., self-attention, feed-forward network). + layer_norm (Optional[Union[ModuleSpec, torch.nn.Module]], optional): Specification + or instance of the layer normalization to be applied. + """ + + layer_specs: List[ModuleSpec] = None + layer_norm: Optional[Union[ModuleSpec, torch.nn.Module]] = None + + +def _get_block_submodules( + config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec] +) -> TransformerBlockSubmodules: + """ + Retrieve or construct TransformerBlockSubmodules based on the provided specification. + + Args: + config (TransformerConfig): Configuration object for the transformer model. + spec (Union[TransformerBlockSubmodules, ModuleSpec]): Specification for the + transformer block submodules. Can be either a TransformerBlockSubmodules + instance or a ModuleSpec. + + Returns: + TransformerBlockSubmodules: The submodules for the transformer block. + """ + + # Transformer block submodules. + if isinstance(spec, TransformerBlockSubmodules): + return spec + + # ModuleSpec here is generally assumed to be for a transformer layer that + # is implemented in `transformer_layer.py` or if it subclasses + # `BaseTransformerLayer` from the `transformer_layer.py` file. + elif isinstance(spec, ModuleSpec): + if issubclass(spec.module, TransformerBlock): + return spec.submodules + elif issubclass(spec.module, BaseTransformerLayer): + num_layers = get_num_layers_to_build(config) + return TransformerBlockSubmodules( + layer_specs=[spec] * num_layers, layer_norm=LayerNormImpl + ) + else: + raise Exception(f"specialize for {spec.module.__name__}.") + else: + raise Exception(f"specialize for {type(spec).__name__}.") + + +class TransformerBlock(MegatronModule): + """Transformer class.""" + + def __init__( + self, + config: TransformerConfig, + spec: Union[TransformerBlockSubmodules, ModuleSpec], + post_layer_norm: bool = True, + pre_process: bool = True, + post_process: bool = True, + ): + super().__init__(config=config) + + self.submodules = _get_block_submodules(config, spec) + self.post_layer_norm = post_layer_norm + self.pre_process = pre_process + self.post_process = post_process + # Dictionary to store CUDA graphs. Number of items in the dictionary = len(self.layers). + # Item `i` in the dictionary is a list of `N` CUDA graphs for layer 'i' where N is the + # number of microbatches. Multiple CUDA graphs per layer is required to support + # pipelining which requires running FWD graph of multiple microbatches before BWD graph. + # To enable CUDA graph, this dictionary should be populated in the model training script + # with the graphs returned by make_graphed_callables API before the first trainng step. + self.cuda_graphs = {} + self.current_microbatch = -1 + + # required for pipeline parallel schedules + self.input_tensor = None + + self.checkpoint_core_attention = self.config.recompute_granularity == 'selective' + + if get_cpu_offload_context is not None: + (self.offload_context, self.group_prefetch_offload_commit_async) = ( + get_cpu_offload_context( + self.config.cpu_offloading, + self.config.cpu_offloading_num_layers, + self.config.num_layers, + self.config.cpu_offloading_activations, + self.config.cpu_offloading_weights, + ) + ) + self.config._cpu_offloading_context = ( + self.offload_context if self.config.cpu_offloading else None + ) + else: + assert ( + self.config.cpu_offloading is False + ), "CPU Offloading is enabled when TE is not present" + + self.offload_context, self.group_prefetch_offload_commit_async = nullcontext(), None + self.config._cpu_offloading_context = None + + self._build_layers() + self.num_layers_per_pipeline_rank = len(self.layers) + self.tp_only_amax_red = config.tp_only_amax_red + + def _build_layers(self): + # Transformer layers. + # @jcasper can we improve how we deal with layer_number? + # currently it's only used in CoreAttention? + # if self.apply_query_key_layer_scaling: + # coeff = self.layer_number + # self.norm_factor *= coeff + def build_layer(layer_spec, layer_number): + return build_module(layer_spec, config=self.config, layer_number=layer_number) + + # offset is implicit in TransformerLayer + self.layers = torch.nn.ModuleList( + [ + build_layer(layer_spec, i + 1) + for i, layer_spec in enumerate(self.submodules.layer_specs) + ] + ) + + # @TODO: add back standalone_embedding_stage (see issue #293) + # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline + # self.post_process and self.post_layer_norm guide this behavior + if self.submodules.layer_norm and self.post_process and self.post_layer_norm: + self.final_layernorm = build_module( + self.submodules.layer_norm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + else: + self.final_layernorm = None # Either this or nn.Identity + + def _get_layer(self, layer_number: int): + return self.layers[layer_number] + + def _checkpointed_forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor, + context_mask: Tensor, + rotary_pos_emb: Tensor, + attention_bias: Tensor, + packed_seq_params: PackedSeqParams, + ): + """Forward method with activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward( + hidden_states, attention_mask, context, context_mask, rotary_pos_emb + ): + for index in range(start, end): + layer = self._get_layer(index) + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + inference_params=None, + packed_seq_params=packed_seq_params, + ) + return hidden_states, context + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + + if self.config.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + hidden_states, context = checkpoint_handler( + custom(layer_idx, layer_idx + self.config.recompute_num_layers) + ) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. + if self.config.fp8 and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if ( + layer_idx >= recompute_skip_num_layers + and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers + ): + hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states, context = custom(layer_idx, layer_idx + 1)( + hidden_states, attention_mask, context, context_mask, rotary_pos_emb + ) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def set_input_tensor(self, input_tensor: Tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def get_cuda_graph_optional_args( + self, + attention_mask: Tensor, + context: Tensor, + context_mask: Tensor, + rotary_pos_emb: Tensor, + attention_bias: Tensor, + inference_params: InferenceParams, + packed_seq_params: PackedSeqParams, + ): + """Get optional tensor arguments for CUDA graph.""" + + optional_inputs = {} + optional_inputs['is_first_microbatch'] = self.current_microbatch == 0 + try: + import transformer_engine.pytorch as te # pylint: disable=unused-import + + if is_te_min_version("1.10.0", check_equality=False): + assert not any( + [attention_mask, context, context_mask, rotary_pos_emb] + ), "Keyword Arguments not supported with CUDA graph." + else: + optional_inputs['attention_mask'] = attention_mask + optional_inputs['context'] = context + optional_inputs['context_mask'] = context_mask + optional_inputs['rotary_pos_emb'] = rotary_pos_emb + except ImportError: + raise RuntimeError("CUDAGraph requires TransformerEngine, but not installed") + return optional_inputs + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor = None, + context_mask: Tensor = None, + rotary_pos_emb: Tensor = None, + rotary_pos_cos: Tensor = None, + rotary_pos_sin: Tensor = None, + attention_bias: Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + ): + """ + Perform the forward pass through the transformer block. + + This method handles the core computation of the transformer, including + self-attention, optional cross-attention, and feed-forward operations. + + Args: + hidden_states (Tensor): Input tensor of shape [s, b, h] where s is the + sequence length, b is the batch size, and h is the hidden size. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention. + context_mask (Tensor, optional): Mask for cross-attention context + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable + to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_params (InferenceParams, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + if self.config.fp8: + import transformer_engine # To keep out TE dependency when not training in fp8 + + if self.config.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif self.config.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + + fp8_recipe = TEDelayedScaling( + config=self.config, + fp8_format=fp8_format, + override_linear_precision=(False, False, not self.config.fp8_wgrad), + ) + fp8_group = None + if parallel_state.model_parallel_is_initialized(): + fp8_group = parallel_state.get_amax_reduction_group( + with_context_parallel=True, tp_only_amax_red=self.tp_only_amax_red + ) + fp8_context = transformer_engine.pytorch.fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group + ) + else: + fp8_context = nullcontext() + + with rng_context, fp8_context: + # Forward pass. + if self.config.recompute_granularity == 'full' and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + for l_no, layer in enumerate(self.layers): + with self.offload_context: + layer.use_cudagraph = True + if (len(self.cuda_graphs) == 0) or (not self.training): + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + ) + else: + # CUDA graph replay for layer `l_no` and microbatch + # `self.current_microbatch`. TransformerEngine versions>=1.10 + # allow keyword arguments with CUDA graph. However, CUDA graph + # acccepts only Tensor inputs and Tensor outputs. Hence, + # `inference_params` and `packed_seq_params` are excluded from + # input list while output is limited to `hidden_states`. + cg_index = self.current_microbatch % len(self.cuda_graphs[l_no]) + assert not any( + [inference_params, packed_seq_params] + ), "CUDA graph accepts only Tensor inputs." + optional_inputs = self.get_cuda_graph_optional_args( + attention_mask, + context, + context_mask, + rotary_pos_emb, + attention_bias, + inference_params, + packed_seq_params, + ) + hidden_states = self.cuda_graphs[l_no][cg_index]( + hidden_states, **optional_inputs + ) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + # Final layer norm. + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + # TENorm produces a "viewed" tensor. This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + + return hidden_states + + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: dict = None + ) -> ShardedStateDict: + """ + Generate a sharded state dictionary for the transformer block. + + Args: + prefix (str, optional): Prefix to be added to all keys in the state dict. + Defaults to an empty string. + sharded_offsets (tuple, optional): Tuple of sharding offsets. + metadata (dict, optional): Additional metadata for sharding. + Can specify if layers are non-homogeneous. Defaults to None. + + Returns: + ShardedStateDict: A dictionary containing the sharded state of the model. + """ + assert not sharded_offsets, "Unexpected sharded offsets" + non_homogeneous_layers = metadata is not None and metadata.get( + 'non_homogeneous_layers', False + ) + if self.config.num_moe_experts is not None: + non_homogeneous_layers = True + + sharded_state_dict = {} + + layer_prefix = f'{prefix}layers.' + num_layers = self.config.num_layers + for layer in self.layers: + offset = TransformerLayer._get_layer_offset(self.config) + + global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 + state_dict_prefix = f'{layer_prefix}{global_layer_offset - offset}.' # module list index in TransformerBlock # pylint: disable=line-too-long + if non_homogeneous_layers: + sharded_prefix = f'{layer_prefix}{global_layer_offset}.' + sharded_pp_offset = [] + else: + sharded_prefix = layer_prefix + sharded_pp_offset = [ + (0, global_layer_offset, num_layers) + ] # PP sharding offset for ShardedTensors + layer_sharded_state_dict = layer.sharded_state_dict( + state_dict_prefix, sharded_pp_offset, metadata + ) + replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) + + sharded_state_dict.update(layer_sharded_state_dict) + + # Add modules other than self.layers + for name, module in self.named_children(): + if not module is self.layers: + sharded_state_dict.update( + sharded_state_dict_default( + module, f'{prefix}{name}.', sharded_offsets, metadata + ) + ) + + return sharded_state_dict diff --git a/toolbox/Megatron-LM/patch/megatron/legacy/model/transformer.py b/toolbox/Megatron-LM/patch/megatron/legacy/model/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c807d99c4dc4d09c947499d37a18de67112390 --- /dev/null +++ b/toolbox/Megatron-LM/patch/megatron/legacy/model/transformer.py @@ -0,0 +1,1806 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. + +"""Transformer.""" +import math +import os +from contextlib import nullcontext +from typing import Optional + +import numpy as np +import torch +import torch.nn.functional as F + +from megatron import core +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.legacy.model.enums import AttnMaskType, LayerType, AttnType +from megatron.legacy.model.fused_softmax import FusedScaleMaskSoftmax +from megatron.legacy.model.fused_bias_gelu import bias_gelu_impl +from megatron.core.models.common.embeddings import apply_rotary_pos_emb +from megatron.core.jit import jit_fuser +from megatron.core.num_microbatches_calculator import get_num_microbatches +from megatron.core.parallel_state import ( + get_expert_tensor_and_model_parallel_group, + get_tensor_model_parallel_group, +) +from megatron.core.tensor_parallel import ( + gather_from_sequence_parallel_region, + reduce_scatter_to_sequence_parallel_region, + get_cuda_rng_tracker, + get_data_parallel_rng_tracker_name, +) +from megatron.legacy.model.enums import AttnMaskType, AttnType, LayerType +from megatron.legacy.model.fused_bias_gelu import bias_gelu_impl +from megatron.legacy.model.fused_softmax import FusedScaleMaskSoftmax +from megatron.legacy.model.utils import ( + attention_mask_func, + erf_gelu, + get_norm, + openai_gelu, +) +from megatron.training import get_args, get_timers + +from .module import MegatronModule + +try: + from einops import rearrange +except ImportError: + rearrange = None + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func +except ImportError: + try: + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_func as flash_attn_unpadded_func, + ) + except ImportError: + flash_attn_unpadded_func = None + +""" We use the following notation throughout this file: + h: hidden size + n: number of attention heads + p: number of model parallel partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + l: number of layers + Transformer takes input of size [s, b, h] and returns a + tensor of the same size. We use the following arguments: + hyperparameters: transformer hyperparameters +""" + +class DropPath(MegatronModule): + """Drop paths (Stochastic Depth) per sample + (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=0.): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_state): + if self.drop_prob == 0. or not self.training: + return hidden_state + keep_prob = 1 - self.drop_prob + # work with diff dim tensors, not just 2D ConvNets + # hidden_state: [s, b, h] + shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2) + random_tensor = keep_prob + \ + torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device) + random_tensor.floor_() # binarize + output = hidden_state.div(keep_prob) * random_tensor + return output + +class ParallelMLP(MegatronModule): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config, is_expert=False): + super(ParallelMLP, self).__init__() + args = get_args() + + self.add_bias = config.add_bias_linear + + ffn_hidden_size = config.ffn_hidden_size + if config.gated_linear_unit: + ffn_hidden_size *= 2 + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + ffn_hidden_size, + config=config, + init_method=config.init_method, + bias=self.add_bias, + gather_output=False, + skip_bias_add=True, + is_expert=is_expert, + ) + + self.bias_gelu_fusion = False + self.activation_func = None + self.swiglu = args.swiglu + + if args.openai_gelu: + self.activation_func = openai_gelu + elif args.onnx_safe: + self.activation_func = erf_gelu + elif args.swiglu: + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + self.activation_func = swiglu + elif args.squared_relu: + def squared_relu(x): + return torch.pow(F.relu(x), 2) + self.activation_func = squared_relu + else: + self.bias_gelu_fusion = args.bias_gelu_fusion + self.activation_func = F.gelu + + # Project back to h. + self.dense_4h_to_h = tensor_parallel.RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + config=config, + init_method=config.output_layer_init_method, + bias=self.add_bias, + skip_bias_add=True, + input_is_parallel=True, + is_expert=is_expert, + ) + + def forward(self, hidden_states): + + # [s, b, 4hp] + intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) + + if self.bias_gelu_fusion: + assert self.add_bias is True + assert self.activation_func == F.gelu + intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + if bias_parallel is not None: + intermediate_parallel = intermediate_parallel + bias_parallel + intermediate_parallel = self.activation_func(intermediate_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias + +def sinkhorn(cost, tol=0.0001): + cost = torch.exp(cost) + d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype) + d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype) + + eps = 0.00000001 + error = 1e9 + d1_old = d1 + while error > tol: + d0 = (1/d0.size(0))*1/(torch.sum(d1*cost,1) + eps) + d1 = (1/d1.size(0))*1/(torch.sum(d0.unsqueeze(1)*cost,0)+eps) + error = torch.mean(torch.abs(d1_old-d1)) + d1_old = d1 + return d1*cost*d0.unsqueeze(1) + + +def get_router_linear_layer(config): + args = get_args() + router = torch.nn.Linear(args.hidden_size, args.num_experts, bias=False) + with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()): + config.init_method(router.weight) + setattr(router.weight, 'sequence_parallel',config.sequence_parallel) + return router + + +class SwitchMLP(MegatronModule): + """ + Routes input to one of N MLP "experts" + """ + def __init__(self, config): + super(SwitchMLP, self).__init__() + args = get_args() + self.router = get_router_linear_layer(config) + self.expert_parallel_size = mpu.get_expert_model_parallel_world_size() + self.sequence_parallel = config.sequence_parallel + self.add_bias = config.add_bias_linear + + assert args.num_experts % self.expert_parallel_size == 0 + self.num_local_experts = args.num_experts // self.expert_parallel_size + local_expert_indices_offset = mpu.get_expert_model_parallel_rank() * self.num_local_experts + self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)] + + self.local_experts = torch.nn.ModuleList() + for i in range(self.num_local_experts): + self.local_experts.append(ParallelMLP(config, is_expert=True)) + + self.tp_ep_group = get_expert_tensor_and_model_parallel_group() + + def gather_indices(self, local_indices): + """ Gather tensors and concatinate along the first dimension.""" + world_size = torch.distributed.get_world_size(group=self.tp_ep_group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return local_indices + + dim_size = list(local_indices.size()) + dim_size[0] = dim_size[0] * world_size + + # TODO pre allocate memory + output = torch.empty(dim_size, dtype=local_indices.dtype, + device=torch.cuda.current_device()) + torch.distributed._all_gather_base( + output, local_indices.contiguous(), group=self.tp_ep_group + ) + return output + + def forward(self, hidden_states): + # hidden_states: [b, s, h] + args = get_args() + s = hidden_states.size(0) + b = hidden_states.size(1) + h = hidden_states.size(2) + route = self.router(hidden_states).view(-1, args.num_experts) + + # TODO (rprenger) Right now we're just using the sinkhorn algorithm + # for load balancing. There should be an option to do no load balancing + # and the algorithm and parametets should be further tested + if self.training: + with torch.no_grad(): + sinkroute = sinkhorn(route.detach().to(dtype=torch.float32)) + _, max_ind = torch.max(sinkroute, dim=1) + route = torch.sigmoid(route) + max_prob = route[torch.arange(route.size(0)), max_ind] + else: + route = torch.sigmoid(route) + max_prob, max_ind = torch.max(route, dim=1) + + max_prob = torch.unsqueeze(max_prob, 1) + hidden_states = hidden_states.view(-1, hidden_states.size(2)) + + # TODO (rprenger) TODO this could be made easier to read + # Converting [s, b, h] to [s*b, h]. + # Each vector could be routed differently + if self.sequence_parallel or (self.expert_parallel_size > 1): + global_hidden_states = \ + gather_from_sequence_parallel_region(hidden_states, group=self.tp_ep_group) + global_indices = self.gather_indices(max_ind) + else: + global_hidden_states = hidden_states + global_indices = max_ind + + output_total = torch.zeros_like(global_hidden_states) + if self.add_bias: + output_bias_total = torch.zeros_like(global_hidden_states) + + for expert_num, expert in enumerate(self.local_experts): + local_expert_index = self.local_expert_indices[expert_num] + local_indices = (global_indices == local_expert_index).nonzero() + hidden = global_hidden_states[local_indices, :] + output, output_bias = expert(hidden) + output_total[local_indices, :] = output + if self.add_bias: + output_bias = output_bias.expand_as(output) + output_bias_total[local_indices, :] = output_bias + + if self.sequence_parallel or (self.expert_parallel_size > 1): + output_total = \ + reduce_scatter_to_sequence_parallel_region(output_total, group=self.tp_ep_group) + if self.add_bias: + output_bias_total = \ + reduce_scatter_to_sequence_parallel_region(output_bias_total, group=self.tp_ep_group) + + # bias is duplicated across tensor parallelism ranks; + # reduce scatter reduces bias across tensor parallel_ranks + output_bias_total = \ + output_bias_total/mpu.get_tensor_model_parallel_world_size() + + output_total = output_total*max_prob + output_total = output_total.view(s, b, h) + if self.add_bias: + output_bias_total = output_bias_total*max_prob + output_bias_total = output_bias_total.view(s, b, h) + else: + output_bias_total = None + + return output_total, output_bias_total + + +class CoreAttention(MegatronModule): + + def __init__(self, layer_number, config, + attn_mask_type=AttnMaskType.padding): + super(CoreAttention, self).__init__() + self.fp16 = config.fp16 + self.bf16 = config.bf16 + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + self.sequence_parallel = config.sequence_parallel + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = core.utils.divide(projection_size, + world_size) + self.hidden_size_per_attention_head = core.utils.divide( + projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = core.utils.divide( + config.num_attention_heads, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + self.fp16, self.bf16, + self.attn_mask_type, + config.masked_softmax_fusion, + attention_mask_func, + self.attention_softmax_in_fp32, + coeff) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, + value_layer, attention_mask): + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.reshape(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor( + (output_size[0]*output_size[1], output_size[2], output_size[3]), + query_layer.dtype, "mpu") + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, alpha=(1.0/self.norm_factor)) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + if not self.sequence_parallel: + with tensor_parallel.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), + output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class FlashSelfAttention(torch.nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, + device=None, dtype=None): + super().__init__() + assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, ' + 'e.g., with pip install flash-attn') + assert rearrange is not None, 'Please install einops first, e.g., with pip install einops' + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, q, k, v): + """Implements the multihead softmax attention. + Arguments + --------- + q, k, v: The tensor containing the query, key, and value. (B, S, H, D) + """ + + assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v))) + assert all((i.is_cuda for i in (q,k,v))) + + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = k.shape[1] + + q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]] + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, + device=q.device) + + if self.training: + # during training q,k,v always have same seqlen + assert seqlen_k == seqlen_q + + is_causal = self.causal + cu_seqlens_k = cu_seqlens_q + dropout_p = self.dropout_p + else: + # turn off FA causal mask after first inference autoregressive iteration + # only on first autoregressive step q,k,v have same seqlen + is_causal = seqlen_q == seqlen_k + cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, + device=q.device) + dropout_p = 0 + + output = flash_attn_unpadded_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, + dropout_p, + softmax_scale=self.softmax_scale, causal=is_causal + ) + + output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + return output + + +class ParallelAttention(MegatronModule): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config, layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=AttnMaskType.padding): + super(ParallelAttention, self).__init__() + args = get_args() + self.layer_number = max(1, layer_number) + self.attention_type = attention_type + self.attn_mask_type = attn_mask_type + self.params_dtype = config.params_dtype + self.sequence_parallel = config.sequence_parallel + self.config = config + self.group_query_attention = args.group_query_attention + self.num_query_groups = args.num_query_groups + + query_projection_size = config.kv_channels * config.num_attention_heads + if self.group_query_attention: + kv_projection_size = args.kv_channels * args.num_query_groups + else: + kv_projection_size = args.kv_channels * args.num_attention_heads + + self.use_flash_attn = args.use_flash_attn \ + and attention_type == AttnType.self_attn \ + and self.attn_mask_type == AttnMaskType.causal + if self.use_flash_attn: + if flash_attn_unpadded_func is None: + raise ImportError('FlashAttention is not installed, please install with ' + 'pip install flash-attn') + assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports ' + 'self-attention for now') + assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only ' + 'supports causal mask for now') + if rearrange is None: + raise ImportError('einops is not installed, please install with pip install einops') + + # Per attention head and per partition values. + world_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_attention_head = core.utils.divide( + query_projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = core.utils.divide( + config.num_attention_heads, world_size) + + if self.group_query_attention: + if args.num_query_groups % world_size != 0: + raise NotImplementedError('Currently the num_query_groups should be ' + 'a multiple of the tensor parallel size') + self.num_query_groups_per_partition = core.utils.divide( + args.num_query_groups, world_size) + else: + self.num_query_groups_per_partition = self.num_attention_heads_per_partition + + # Strided linear layer. + if attention_type == AttnType.self_attn: + self.query_key_value = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + query_projection_size + 2 * kv_projection_size, + config=config, + init_method=config.init_method, + bias=args.add_bias_linear or args.add_qkv_bias, + gather_output=False) + else: + assert attention_type == AttnType.cross_attn + + if self.group_query_attention: + raise NotImplementedError("Grouped query attention not implemented for cross-attention.") + assert query_projection_size == kv_projection_size + + self.query = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + query_projection_size, + config=config, + init_method=config.init_method, + bias=config.add_bias_linear, + gather_output=False) + + self.key_value = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + 2 * kv_projection_size, + config=config, + init_method=config.init_method, + bias=config.add_bias_linear, + gather_output=False) + + self.core_attention = CoreAttention(self.layer_number, config, + self.attn_mask_type) + self.checkpoint_core_attention = config.recompute_granularity == 'selective' + + if self.use_flash_attn: + self.core_attention_flash = FlashSelfAttention( + causal=True, attention_dropout=config.attention_dropout + ) + + # Output. + self.dense = tensor_parallel.RowParallelLinear( + query_projection_size, + config.hidden_size, + config=config, + init_method=config.output_layer_init_method, + bias=args.add_bias_linear, + input_is_parallel=True, + skip_bias_add=True) + + def _checkpointed_attention_forward(self, query_layer, key_layer, + value_layer, attention_mask, + rotary_pos_emb=None): + """Forward method with activation checkpointing.""" + def custom_forward(*inputs): + query_layer = inputs[0] + key_layer = inputs[1] + value_layer = inputs[2] + attention_mask = inputs[3] + output_ = self.core_attention(query_layer, key_layer, + value_layer, attention_mask) + return output_ + + q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \ + else rotary_pos_emb + + hidden_states = tensor_parallel.checkpoint( + custom_forward, + False, query_layer, key_layer, value_layer, attention_mask, + q_pos_emb, k_pos_emb) + + return hidden_states + + def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads): + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=self.params_dtype, + device=torch.cuda.current_device()) + + def forward(self, hidden_states, attention_mask, + encoder_output=None, inference_params=None, + rotary_pos_emb=None): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + is_first_step = False + if inference_params: + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_len = inference_params.max_sequence_length + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size, + self.num_query_groups_per_partition) + inference_value_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size, + self.num_query_groups_per_partition) + + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, inference_value_memory) + is_first_step = True + else: + inference_key_memory, inference_value_memory = \ + inference_params.key_value_memory_dict[self.layer_number] + + # ===================== + # Query, Key, and Value + # ===================== + if self.attention_type == AttnType.self_attn: + + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query_layer, + key_layer, + value_layer) = torch.split( + mixed_x_layer, + [ + ( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head + ], + dim=3) + + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] - + query_layer = query_layer.reshape(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) + else: + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + mixed_kv_layer, _ = self.key_value(encoder_output) + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head) + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key_layer, + value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2) + + # Attention head [sq, b, h] --> [sq, b, hp] + query_layer, _ = self.query(hidden_states) + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + query_layer = query_layer.view(*new_tensor_shape) + + # ================================== + # Adjust key and value for inference + # ================================== + + # duplicate the pos_emb for self attention + if rotary_pos_emb is not None: + if isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = rotary_pos_emb + else: + rotary_pos_emb = ((rotary_pos_emb,) * 2) + + if inference_params: + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + # Copy key and values. + inference_key_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = key_layer + inference_value_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = value_layer + key_layer = inference_key_memory[ + :sequence_end, batch_start:batch_end, ...] + value_layer = inference_value_memory[ + :sequence_end, batch_start:batch_end, ...] + + + # adjust the key rotary positional embedding + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + # need to cross check this condition during inference + # if not set_inference_key_value_memory: + if not is_first_step: + # In inference, we compute one token at a time. + # Select the correct positional embedding + # (only the last token in the sequence) + q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end] + else: + # In the first forward pass of inference, + # we use the entire provided prefix. + # q_pos_emb here has the rope embeddings of the entire + # prefix + to-be-generated output so + # we slice to just the prefix. + q_pos_emb = q_pos_emb[:sequence_end, :, :, :] + k_pos_emb = k_pos_emb[:sequence_end, :, :, :] + rotary_pos_emb = (q_pos_emb, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + + # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn] + if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1: + key_layer = key_layer.repeat_interleave( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition, + dim = 2 + ) + value_layer = value_layer.repeat_interleave( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition, + dim = 2 + ) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb,self.config) + key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb,self.config) + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + if not self.use_flash_attn: + if self.checkpoint_core_attention: + context_layer = self._checkpointed_attention_forward( + query_layer, key_layer, value_layer, attention_mask) + else: + context_layer = self.core_attention( + query_layer, key_layer, value_layer, attention_mask) + else: + q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() + for x in (query_layer, key_layer, value_layer)] + if not self.sequence_parallel: + with tensor_parallel.get_cuda_rng_tracker().fork(): + context_layer = self.core_attention_flash(q, k, v) + else: + context_layer = self.core_attention_flash(q, k, v) + context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + return output, bias + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor + if bias is not None: + x = x + bias + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out + return out + + +def get_bias_dropout_add(training): + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + return _bias_dropout_add + + +@jit_fuser +def bias_dropout_add_fused_train(x: torch.Tensor, + bias: Optional[torch.Tensor], + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@jit_fuser +def bias_dropout_add_fused_inference(x: torch.Tensor, + bias: Optional[torch.Tensor], + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) + + +class ParallelTransformerLayer(MegatronModule): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config, + layer_number, layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + drop_path_rate=0.): + args = get_args() + + super(ParallelTransformerLayer, self).__init__() + self.layer_number = layer_number + self.layer_type = layer_type + + self.apply_residual_connection_post_norm \ + = config.apply_residual_connection_post_layernorm + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + + # Normalize the input data. + self.input_norm = get_norm(config) + + # Self attention. + self.self_attention = ParallelAttention( + config, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=self_attn_mask_type) + self.hidden_dropout = config.hidden_dropout + self.bias_dropout_fusion = config.bias_dropout_fusion + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None + + # Normalize the attention output + self.post_attention_norm = get_norm(config) + + # Cross attention. + if self.layer_type in (LayerType.decoder, + LayerType.retro_decoder, + LayerType.retro_decoder_with_retriever, + LayerType.retro_encoder): + self.inter_attention = ParallelAttention( + config, + layer_number, + attention_type=AttnType.cross_attn) + # Normalize the attention output. + self.post_inter_attention_norm = get_norm(config) + + # MLP + if args.num_experts is not None: + self.mlp = SwitchMLP(config) + else: + self.mlp = ParallelMLP(config) + + # Set bias+dropout+add fusion grad_enable execution handler. + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) + self.bias_dropout_add_exec_handler = \ + nullcontext if use_nvfuser else torch.enable_grad + + if args.retro_add_retriever: + self.retro_num_neighbors = args.retro_num_neighbors + self.retro_chunk_length = args.retro_chunk_length + self.retro_retrieved_length = \ + args.retro_num_retrieved_chunks * args.retro_chunk_length + + # Retriever (bi-directional transformer with cross attention) + if layer_type == LayerType.retro_decoder_with_retriever: + self.retriever = ParallelTransformer( + config=config, + model_type=ModelType.retro_encoder, + self_attn_mask_type=AttnMaskType.padding, + pre_process=True, + post_process=False, + ) + self._retriever_key = 'retriever' + else: + self.retriever = None + + def default_decoder_cross_attention(self, + encoder_output, + enc_dec_attn_mask, + norm_input, + norm_output, + bias_dropout_add_func): + '''Cross attention for a standard encoder-decoder model.''' + + # Attention. + attention_output, attention_bias = \ + self.inter_attention(norm_output, + enc_dec_attn_mask, + encoder_output=encoder_output) + + # Residual connection. + if self.apply_residual_connection_post_norm: + residual = norm_output + else: + residual = norm_input + + if attention_bias is not None: + attention_bias = attention_bias.expand_as(residual) + + # Bias-dropout-add. + with self.bias_dropout_add_exec_handler(): + norm_input = bias_dropout_add_func( + attention_output, + attention_bias, + residual, + self.hidden_dropout) + + # Normalize. + norm_output = self.post_inter_attention_norm(norm_input) + + return norm_input, norm_output + + def retro_encoder_cross_attention(self, + retriever_output, + norm_input, + norm_output, + bias_dropout_add_func): + """Cross attention for Retro encoder. + + Notation: + ns : Sequence length. + bs : Batch size. + d : Hidden size. + l : Number of chunks per sample (i.e., seq_length/chunk_length). + k : Number of neighbors. + r : Number of retrieved tokens (neighbors + continuation). + """ + + ns, bs, d = norm_output.shape # [r, bs * l * k, d] + + # Divide sequence dimension into chunks. + chunked_outputs = norm_output.reshape(self.retro_retrieved_length, + -1, + self.retro_num_neighbors, + d) + chunked_outputs_before_norm = \ + norm_input.reshape(self.retro_retrieved_length, -1, + self.retro_num_neighbors, d) # [r, bs*l, k, d] + + # Per-chunk attention. + norm_inputs = [] + norm_outputs = [] + for k in range(self.retro_num_neighbors): + + # Attention. + chunked_output = chunked_outputs[:,:,k].contiguous() + attention_output, attention_bias = \ + self.inter_attention( + chunked_output, # Q (neighbor embedding) + None, + encoder_output=retriever_output) # K, V (hidden act) + + # Residual connection. + if self.apply_residual_connection_post_norm: + residual = chunked_output + else: + residual = chunked_outputs_before_norm[:,:,k] + + # Re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + norm_input = bias_dropout_add_func( + attention_output, + None if attention_bias is None else attention_bias.expand_as(residual), + residual, + self.hidden_dropout) + norm_inputs.append(norm_input) + + # Layer norm. + norm_output = self.post_inter_attention_norm(norm_input) + norm_outputs.append(norm_output) + + # Concatenate layer norms. + # norm_input : [r, k * bs * l, d] + # norm_output : [r, k * bs * l, d] + norm_input = torch.stack(norm_inputs, dim=1).reshape(ns, bs, d) + norm_output = torch.stack(norm_outputs, dim=1).reshape(ns, bs, d) + + return norm_input, norm_output + + def retro_decoder_cross_attention(self, + retriever_input, + retriever_output, + retriever_attn_mask, + norm_input, + norm_output, + inference_params, + bias_dropout_add_func): + """Cross attention for Retro decoder. + + Notation: + ns : Sequence length. + bs : Batch size. + d : Hidden size. + l : Number of chunks per sample (i.e., seq_length/chunk_length). + m : Number of tokens per chunk. + k : Number of neighbors. + r : Number of retrieved tokens (neighbors + continuation). + """ + + ns, bs, d = norm_output.shape + l = int(np.ceil(ns / self.retro_chunk_length)) + + # Retrieve neighbors. + if self.layer_type == LayerType.retro_decoder_with_retriever: + first_ns = ns % self.retro_chunk_length + if first_ns > 0: + first_chunk, rest_chunk = \ + norm_output[:first_ns], norm_output[first_ns:] + first_chunk = torch.nn.functional.pad( + first_chunk, + (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), + 'constant', + 0) + chunked_output = \ + torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d] + else: + chunked_output = norm_output # [l * m, bs, d] + chunked_output = chunked_output \ + .reshape(l, self.retro_chunk_length, bs, d) \ + .permute(1, 2, 0, 3) \ + .reshape(self.retro_chunk_length, bs * l, d) \ + .contiguous() + + # Get Encoder Output + retriever_output = self.retriever( + hidden_states=retriever_input, + attention_mask=retriever_attn_mask, + retriever_output=chunked_output, + retriever_attn_mask=retriever_attn_mask, + inference_params=inference_params) # [r, k * bs * l , d] + retriever_output = retriever_output.reshape( + self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d] + + # Chunks. + pad = (ns - 1) % self.retro_chunk_length + attending_chunks = norm_output[pad:] + padded_chunks = torch.nn.functional.pad( + attending_chunks, + (0, 0, 0, 0, 0, self.retro_chunk_length - 1), + 'constant', 0) + padded_chunked_output = padded_chunks \ + .reshape(l, self.retro_chunk_length, bs, d) \ + .permute(1, 2, 0, 3) + padded_chunked_output = padded_chunked_output.reshape( + self.retro_chunk_length, bs * l, d).contiguous() + + # Encoder output. + attention_output, attention_bias = \ + self.inter_attention(padded_chunked_output, + None, + encoder_output=retriever_output) + + # Residual connection. + if self.apply_residual_connection_post_norm: + residual = norm_output + else: + residual = norm_input + + # Re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + norm_input = bias_dropout_add_func( + attention_output, + None if attention_bias is None else attention_bias.expand_as(attention_output), + torch.zeros_like(attention_output), + self.hidden_dropout) + norm_input = norm_input \ + .reshape(self.retro_chunk_length, bs, l, d) \ + .permute(2, 0, 1, 3) # [l, m, bs, d] + norm_input = norm_input.reshape(self.retro_chunk_length * l, bs, d) + norm_input = torch.nn.functional.pad( + norm_input, + (0, 0, 0, 0, pad, 0), + 'constant', 0)[:ns] # [ns, b, d] + # TODO: better redesign with inference param + args = get_args() + norm_input = args.retro_attention_gate * norm_input + residual + + # Layer norm post the decoder attention + norm_output = self.post_inter_attention_norm(norm_input) + + return retriever_output, norm_input, norm_output + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None): + + # Update the params in case the retro param changes during inference + # TODO: better redesign with inference param + args = get_args() + if args.retro_add_retriever: + self.retro_num_neighbors = args.retro_num_neighbors + self.retro_chunk_length = args.retro_chunk_length + self.retro_retrieved_length = \ + args.retro_num_retrieved_chunks * args.retro_chunk_length + + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + norm_output = self.input_norm(hidden_states) + + # Self attention. + attention_output, attention_bias = \ + self.self_attention( + norm_output, + attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb) + + # Residual connection. + if self.apply_residual_connection_post_norm: + residual = norm_output + else: + residual = hidden_states + + if self.drop_path is None: + # jit scripting for a nn.module (with dropout) is not + # trigerring the fusion kernel. For now, we use two + # different nn.functional routines to account for varying + # dropout semantics during training and inference phases. + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = bias_dropout_add_fused_train + else: + bias_dropout_add_func = bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + if attention_bias is not None: + attention_bias = attention_bias.expand_as(residual) + with self.bias_dropout_add_exec_handler(): + norm_input = bias_dropout_add_func( + attention_output, + attention_bias, + residual, + self.hidden_dropout) + else: + out = torch.nn.functional.dropout(attention_output + attention_bias, + p=self.hidden_dropout, + training=self.training) + norm_input = residual + self.drop_path(out) + + # Layer norm post the self attention. + norm_output = self.post_attention_norm(norm_input) + + # Cross attention. + if self.layer_type == LayerType.encoder: + pass + elif self.layer_type == LayerType.decoder: + norm_input, norm_output = \ + self.default_decoder_cross_attention( + encoder_output, + enc_dec_attn_mask, + norm_input, + norm_output, + bias_dropout_add_func) + elif self.layer_type == LayerType.retro_encoder: + norm_input, norm_output = \ + self.retro_encoder_cross_attention( + retriever_output, + norm_input, + norm_output, + bias_dropout_add_func) + elif self.layer_type in (LayerType.retro_decoder, + LayerType.retro_decoder_with_retriever): + retriever_output, norm_input, norm_output = \ + self.retro_decoder_cross_attention( + retriever_input, + retriever_output, + retriever_attn_mask, + norm_input, + norm_output, + inference_params, + bias_dropout_add_func) + else: + raise Exception("Unsupported layer type, '%s'." % + self.layer_type.name) + + # MLP. + mlp_output, mlp_bias = self.mlp(norm_output) + + # Second residual connection. + if self.apply_residual_connection_post_norm: + residual = norm_output + else: + residual = norm_input + + if self.drop_path is None: + if mlp_bias is not None: + mlp_bias = mlp_bias.expand_as(residual) + with self.bias_dropout_add_exec_handler(): + output = bias_dropout_add_func( + mlp_output, + mlp_bias, + residual, + self.hidden_dropout) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = core.utils.make_viewless_tensor(inp = output, + requires_grad = output.requires_grad, + keep_graph = True) + + else: + if mlp_bias is not None: + mlp_output = mlp_output + mlp_bias + out = torch.nn.functional.dropout(mlp_output, + p=self.hidden_dropout, + training=self.training) + output = residual + self.drop_path(out) + + if self.layer_type == LayerType.retro_decoder_with_retriever: + return output, retriever_output + else: + return output + + +class NoopTransformerLayer(MegatronModule): + """A single 'no-op' transformer layer. + + The sole purpose of this layer is for when a standalone embedding layer + is used (i.e., args.standalone_embedding_stage == True). In this case, + zero transformer layers are assigned when pipeline rank == 0. Additionally, + when virtual pipeline rank >= 1, zero total model parameters are created + (virtual rank 0 contains the input embedding). This results in the model's + input and output tensors being the same, which causes an error when + performing certain memory optimiations on the output tensor (e.g., + deallocating it). Thus, this layer disconnects the input from the output + via a clone. Since ranks containing a no-op layer are generally under- + utilized (both compute and memory), there's no worry of any performance + degredation. + """ + + def __init__(self, layer_number): + super().__init__() + self.layer_number = layer_number + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + inference_params=None): + return hidden_states.clone() + + +def _get_num_layers(args, model_type, is_decoder=False): + """Compute the number of transformer layers resident on the current rank.""" + is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder) + if model_type == ModelType.retro_encoder: + num_layers = args.retro_encoder_layers + elif mpu.get_pipeline_model_parallel_world_size() > 1: + assert not is_encoder_and_decoder_model, "This is no longer supported." + assert args.num_layers == args.encoder_num_layers + assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \ + 'num_layers must be divisible by transformer_pipeline_model_parallel_size' + + # When a standalone embedding stage is used, all transformer layers + # are divided among pipeline rank >= 1, while on pipeline rank 0, + # ranks either contain the input embedding layer (virtual pp rank 0), + # or no layers at all (virtual pp rank >= 1). + num_layers = ( + 0 + if args.standalone_embedding_stage + and mpu.get_pipeline_model_parallel_rank() == 0 else + args.num_layers // args.transformer_pipeline_model_parallel_size + ) + else: + if not is_decoder: + num_layers = args.encoder_num_layers + else: + num_layers = args.decoder_num_layers + return num_layers + + +def _get_layer_type(model_type, default_layer_type, retro_layer_numbers, + layer_number): + args = get_args() + if args.retro_add_retriever and layer_number in retro_layer_numbers: + if model_type == ModelType.retro_decoder: + return LayerType.retro_decoder_with_retriever \ + if layer_number == retro_layer_numbers[0] \ + else LayerType.retro_decoder + elif model_type == ModelType.retro_encoder: + return LayerType.retro_encoder + else: + raise Exception("Unsupported model type, '%s'." % model_type) + else: + return default_layer_type + + +class ParallelTransformer(MegatronModule): + """Transformer class.""" + + def __init__(self, config, + model_type, layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + post_norm=True, + pre_process=True, + post_process=True, + drop_path_rate=0.0): + super(ParallelTransformer, self).__init__() + args = get_args() + + self.layer_type = layer_type + self.model_type = model_type + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + self.post_norm = post_norm + self.pre_process = pre_process + self.post_process = post_process + self.input_tensor = None + self.drop_path_rate = drop_path_rate + self.transformer_impl = args.transformer_impl + self.retro_add_retriever = args.retro_add_retriever + + # Store activation checkpoiting flag. + self.recompute_granularity = config.recompute_granularity + self.recompute_method = config.recompute_method + self.recompute_num_layers = config.recompute_num_layers + self.distribute_saved_activations = \ + config.distribute_saved_activations and not config.sequence_parallel + + self.sequence_parallel = config.sequence_parallel + + # Transformer Engine Init. + self.transformer_engine_v_0_10 = False + self.transformer_engine_v_0_11 = False + self.transformer_engine_v_0_8 = False + if self.transformer_impl == 'transformer_engine': + global transformer_engine + import transformer_engine + + if core.utils.is_te_min_version("0.8.0"): + self.transformer_engine_v_0_8 = True + if core.utils.is_te_min_version("0.10.0"): + self.transformer_engine_v_0_10 = True + if core.utils.is_te_min_version("0.11.0"): + self.transformer_engine_v_0_11 = True + + assert not args.squared_relu, ("TransformerEngine does not support squared " + "relu activation.") + + self.use_fp8 = args.fp8 is not None + self.fp8_recipe = None + self.fp8_group = None + if self.use_fp8: + assert args.transformer_impl == 'transformer_engine', \ + 'transformer-engine required for fp8 training and inference' + self.fp8_group = mpu.get_amax_reduction_group(tp_only_amax_red=config.tp_only_amax_red) + if args.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif args.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.") + self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + margin=args.fp8_margin, + interval=args.fp8_interval, + fp8_format=fp8_format, + amax_history_len=args.fp8_amax_history_len, + amax_compute_algo=args.fp8_amax_compute_algo, + override_linear_precision=(False, False, not args.fp8_wgrad), + ) + + self.num_microbatches_in_previous_step = -1 + self.microbatch_count = 0 + self.checkpoint_core_attention = config.recompute_granularity == 'selective' + + # Number of layers. + self.num_layers = _get_num_layers(args, model_type, + layer_type==LayerType.decoder) + + self.drop_path_rates = [ + rate.item() for rate in + torch.linspace(0, self.drop_path_rate, config.num_layers)] + + self.retro_layer_numbers = None + if model_type == ModelType.retro_decoder: + retro_layer_start = 6 if config.num_layers <= 15 else 9 + self.retro_layer_numbers = \ + np.arange(retro_layer_start, args.num_layers + 1, 3).tolist() + if model_type == ModelType.retro_encoder: + self.retro_layer_numbers = [1] + + # Transformer layers. + if args.retro_add_retriever: + assert self.recompute_granularity != 'full', \ + "Full recompute not supported for Retro." + assert args.transformer_impl == 'local', \ + "Transformer engine does not support Retro layers." + def build_layer(layer_number): + if args.transformer_impl == 'local': + current_layer_type = _get_layer_type( + model_type, layer_type, self.retro_layer_numbers, + layer_number) + return ParallelTransformerLayer( + config, + layer_number, + layer_type=current_layer_type, + self_attn_mask_type=self_attn_mask_type, + drop_path_rate=self.drop_path_rates[layer_number - 1]) + else: + # This argument is only available from TE v0.10 onwards. + extra_transformer_engine_kwargs = {} + if self.transformer_engine_v_0_8: + extra_transformer_engine_kwargs["bias"] = args.add_bias_linear + if self.transformer_engine_v_0_10: + extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu" + if self.transformer_engine_v_0_11: + extra_transformer_engine_kwargs["normalization"] = args.normalization + assert config.attention_softmax_in_fp32, "TransformerEngine only supports softmax compute in FP32." + assert ( + (bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and args.fp16) == config.apply_query_key_layer_scaling + ), ("Unsupported config for apply_query_key_layer_scaling in TransformerEngine. If --apply-query-key-layer-scaling is " + "provided, set env-var NVTE_APPLY_QK_LAYER_SCALING=1 and you must be using fp16.") + return transformer_engine.pytorch.TransformerLayer( + config.hidden_size, + config.ffn_hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.layernorm_epsilon, + hidden_dropout=config.hidden_dropout, + attention_dropout=config.attention_dropout, + init_method=config.init_method, + output_layer_init_method=config.output_layer_init_method, + layer_number=layer_number, + kv_channels=config.kv_channels, + self_attn_mask_type=self_attn_mask_type.name, + tp_group=mpu.get_tensor_model_parallel_group() if mpu.is_initialized() else None, + tp_size=mpu.get_tensor_model_parallel_world_size(), + get_rng_state_tracker=get_cuda_rng_tracker + if get_cuda_rng_tracker().is_initialized() + else None, + fuse_wgrad_accumulation=config.gradient_accumulation_fusion, + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + sequence_parallel=config.sequence_parallel, + params_dtype=config.params_dtype, + apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm, + output_layernorm=False, + layer_type="encoder", + drop_path_rate=self.drop_path_rates[layer_number - 1], + set_parallel_mode=True, + fuse_qkv_params=True, + **extra_transformer_engine_kwargs) + + if config.virtual_pipeline_model_parallel_size is not None: + assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \ + 'num_layers_per_stage must be divisible by ' \ + 'virtual_pipeline_model_parallel_size' + assert args.model_type != ModelType.encoder_and_decoder + # Number of layers in each model chunk is the number of layers in the stage, + # divided by the number of model chunks in a stage. + self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size + # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0] [2] [4] [6] + # Stage 1: [1] [3] [5] [7] + # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0, 1] [4, 5] + # Stage 1: [2, 3] [6, 7] + offset = mpu.get_virtual_pipeline_model_parallel_rank() * ( + config.num_layers // config.virtual_pipeline_model_parallel_size) + \ + (mpu.get_pipeline_model_parallel_rank() * self.num_layers) + else: + # Each stage gets a contiguous set of layers. + if args.model_type == ModelType.encoder_and_decoder and \ + mpu.get_pipeline_model_parallel_world_size() > 1: + pipeline_rank = mpu.get_pipeline_model_parallel_rank() + if layer_type == LayerType.encoder: + offset = pipeline_rank * self.num_layers + else: + num_ranks_in_enc = args.pipeline_model_parallel_split_rank + offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers + else: + offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers + + if self.num_layers == 0: + # When a standalone embedding stage is used (e.g., + # args.standalone_embedding_stage == True), virtual pipeline ranks + # on pipeline rank 0 will have zero transformer layers assigned to + # them. This results in the model's input and output tensors to be + # the same, which will cause failure for certain output tensor + # optimizations (e.g., pipeline output deallocation). To remedy + # this, we assign a 'no-op' layer on these ranks, which will + # disconnect the input tensor from the output tensor. + self.num_layers = 1 + self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ]) + else: + self.layers = torch.nn.ModuleList( + [build_layer(i + 1 + offset) for i in range(self.num_layers)]) + + # Update dropout rate for Retro encoder. + if model_type == ModelType.retro_encoder: + for layer in self.layers: + if layer.self_attention.use_flash_attn: + layer.self_attention.core_attention_flash.dropout_p = \ + torch.nn.Dropout(args.retro_encoder_attention_dropout) + else: + layer.self_attention.core_attention.attention_dropout.p =\ + args.retro_encoder_attention_dropout + layer.hidden_dropout = args.retro_encoder_hidden_dropout + + if self.post_process and self.post_norm: + # Final layer norm before output. + self.final_norm = get_norm(config) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def _checkpointed_forward(self, hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + rotary_pos_emb, is_first_microbatch): + """Forward method with activation checkpointing.""" + def custom(start, end): + def custom_forward(*args, **kwargs): + x_, *args = args + for index in range(start, end): + layer = self._get_layer(index) + x_ = layer(x_, *args, **kwargs) + return x_ + return custom_forward + + te_forward_kwargs = {} + if self.transformer_impl == 'transformer_engine': + te_forward_kwargs['is_first_microbatch'] = is_first_microbatch + if self.transformer_engine_v_0_10: + te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + + if self.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and + # checkpoint the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + l = 0 + while l < self.num_layers: + if self.transformer_impl == 'transformer_engine': + hidden_states = transformer_engine.pytorch.checkpoint( + custom(l, l + self.recompute_num_layers), + self.distribute_saved_activations, + tensor_parallel.get_cuda_rng_tracker, + mpu.get_tensor_model_parallel_group(), + hidden_states, attention_mask, encoder_output, + enc_dec_attn_mask, **te_forward_kwargs) + else: + hidden_states = tensor_parallel.checkpoint( + custom(l, l + self.recompute_num_layers), + self.distribute_saved_activations, + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb) + + l += self.recompute_num_layers + + elif self.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + for l in range(self.num_layers): + if l < self.recompute_num_layers: + if self.transformer_impl == 'transformer_engine': + hidden_states = transformer_engine.pytorch.checkpoint( + custom(l, l + 1), + self.distribute_saved_activations, + tensor_parallel.get_cuda_rng_tracker, + mpu.get_tensor_model_parallel_group(), + hidden_states, attention_mask, encoder_output, + enc_dec_attn_mask, **te_forward_kwargs) + else: + hidden_states = tensor_parallel.checkpoint( + custom(l, l + 1), + self.distribute_saved_activations, + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb) + else: + if self.transformer_impl == 'transformer_engine': + hidden_states = custom(l, l + 1)( + hidden_states, attention_mask, encoder_output, + enc_dec_attn_mask, **te_forward_kwargs) + else: + hidden_states = custom(l, l + 1)( + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None): + # hidden_states: [s, b, h] + + # Checks. + if inference_params: + assert self.recompute_granularity is None, \ + 'inference does not work with activation checkpointing' + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = core.utils.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + # RNG context. + if self.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # Forward layers. + with rng_context: + # The fp8_autocast context manager is a no-op when enabled=True + # The if...else serves to short circuit name resolution for fp8_autocast + with transformer_engine.pytorch.fp8_autocast( + enabled=self.use_fp8, + fp8_recipe=self.fp8_recipe, + fp8_group=self.fp8_group + ) if self.use_fp8 else nullcontext(): + # Determine if the current iteration is first microbatch + if self.num_microbatches_in_previous_step != get_num_microbatches(): + self.microbatch_count = 0 # Reset count on new batch size rampup interval + self.num_microbatches_in_previous_step = get_num_microbatches() + is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0 + + # Forward pass. + if self.recompute_granularity == 'full': + hidden_states = self._checkpointed_forward(hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + rotary_pos_emb, + is_first_microbatch) + else: + forward_kwargs = { + 'encoder_output': encoder_output, + 'enc_dec_attn_mask': enc_dec_attn_mask, + 'inference_params': inference_params, + } + + if self.transformer_impl == 'transformer_engine': + forward_kwargs['is_first_microbatch'] = is_first_microbatch + forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention + if self.transformer_engine_v_0_10: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + else: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + forward_kwargs['retriever_input'] = retriever_input + forward_kwargs['retriever_output'] = retriever_output + forward_kwargs['retriever_attn_mask'] = retriever_attn_mask + + for index in range(self.num_layers): + layer = self._get_layer(index) + + hidden_states = layer( + hidden_states, + attention_mask, + **forward_kwargs) + + # First Retro decoder layer returns both hidden_states + # and retriever_output. Make retriever_output available + # to subsequence Retro layers. + if isinstance(hidden_states, tuple): + assert len(hidden_states) == 2 + hidden_states, retriever_output = hidden_states + forward_kwargs["retriever_output"] = retriever_output + + # Skip counter update for eval and activation checkpointing + if torch.is_grad_enabled() and self.training: + self.microbatch_count += 1 + + # Final layer norm. + if self.post_process and self.post_norm: + hidden_states = self.final_norm(hidden_states) + + return hidden_states + + def load_state_dict(self, state_dict, strict=True): + """Customize load.""" + + # Handle renaming layernorm -> norm in component names + state_dict_ = {} + for key in state_dict.keys(): + # Bypass TransformerEngine module parameters. + if "layernorm_qkv" in key or "layernorm_mlp" in key: + state_dict_[key] = state_dict[key] + continue + newkey = key.replace("layernorm", "norm") + state_dict_[newkey] = state_dict[key] + + super().load_state_dict(state_dict_, strict)