diff --git a/README.md b/README.md index d2118ca6a7a36c505909cabcf298e950fff7e902..6d673da028a24ea8719dbeb31280f714c192d1ab 100644 --- a/README.md +++ b/README.md @@ -440,6 +440,7 @@ DeepSparkHub甄选上百个应用算法和模型,覆盖AI和通用计算各领 [Bloom-7B1](nlp/llm/bloom-7b1/firefly/README.md) | PyTorch (Firefly) | school_math_0.25M & bloom-7b1 [ChatGLM-6B](nlp/llm/chatglm-6b/deepspeed/README.md) | PyTorch (DeepSpeed) | ADGEN & chatglm-6b [ChatGLM2-6B SFT](nlp/llm/ChatGLM2-6b-sft/README.md) | PyTorch (DeepSpeed) | ADGEN & chatglm2-6b +[ChatGLM3-6B SFT](nlp/llm/chatglm3-6b/deepspeed/finetune_demo/README.md) | PyTorch (DeepSpeed) | ADGEN & chatglm3-6b [Llama-7B](nlp/llm/llama-7b/colossalai/README.md) | PyTorch (Colossal-AI) | llama-7b-hf [Llama2-7B](nlp/llm/llama2-7b/megatron-deepspeed/README.md) | PyTorch (Megatron-DeepSpeed) | Bookcorpus [Llama2-7B Reward Model Finetuning](nlp/llm/llama2-7b_reward_sft/deepspeed/README.md) | PyTorch (DeepSpeed) | Dahoas/rm-static diff --git a/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/README.md b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..38d2d06f5eda97381172b7fdebc23f05ebad49de --- /dev/null +++ b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/README.md @@ -0,0 +1,64 @@ +# ChatGLM3-6B + +## Model description + +ChatGLM3 is a generation of pre-trained dialogue models jointly released by Zhipu AI and Tsinghua KEG. ChatGLM3-6B is the open-source model in the ChatGLM3 series, maintaining many excellent features of the first two generations such as smooth dialogue and low deployment threshold. + +## Step 1: Installation + +```bash +cd finetune_demo +pip3 install -r requirements.txt +``` + +## Step 2: Preparing datasets and checkpoints + +```bash +# Get AdvertiseGen.tar.gz +mkdir -p data + +pushd data +wget -O AdvertiseGen.tar.gz https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1 +tar xf AdvertiseGen.tar.gz +popd + +python3 process_data.py +``` + +```bash +# Get chatglm3-6b from https://modelscope.cn/models/ZhipuAI/chatglm3-6b or huggingface. +mkdir -p checkpoint + +pushd checkpoint +tar -zxvf chatglm3-6b.tar.gz +popd +``` + +## Step 3: Training + +```bash +bash run.sh {config_file} {num_gpus} + +# 1 GPU +bash run.sh configs/lora.yaml 1 +bash run.sh configs/ptuning_v2.yaml 1 + +# Multi GPUs +bash run.sh configs/lora.yaml 16 +bash run.sh configs/ptuning_v2.yaml 16 +bash run.sh configs/sft.yaml 16 +``` + +## Results + +| GPUs | model | peft | num_gpus | train_samples_per_second | +| ------- | ---------- | ---------- | -------- | ------------------------ | +| BI-V150 | ChatGLM-6B | Lora | 1 | 2.11 | +| BI-V150 | ChatGLM-6B | ptuning_v2 | 1 | 8.889 | +| BI-V150 | ChatGLM-6B | Lora | 16 | 32.639 | +| BI-V150 | ChatGLM-6B | ptuning_v2 | 16 | 115.763 | +| BI-V150 | ChatGLM-6B | sft | 16 | 5.99 | + +## Reference + +- [ChatGLM3](https://github.com/THUDM/ChatGLM3) diff --git a/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/ds_zero_2.json b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/ds_zero_2.json new file mode 100644 index 0000000000000000000000000000000000000000..1f560ac23fde080f58342812136bf5eca78058cc --- /dev/null +++ b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/ds_zero_2.json @@ -0,0 +1,29 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "contiguous_gradients": true + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/ds_zero_3.json b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/ds_zero_3.json new file mode 100644 index 0000000000000000000000000000000000000000..db21ea0916fab69d8eb58a35e235e3bf8df28091 --- /dev/null +++ b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/ds_zero_3.json @@ -0,0 +1,39 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_allow_untested_optimizer": true, + "bf16": { + "enabled": "auto" + }, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 +}, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "zero_optimization": { + "stage": 3, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "reduce_scatter": true, + "contiguous_gradients": true, + "overlap_comm": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + } +} diff --git a/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/lora.yaml b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..efcc91e0125a5872db6df9e77622b1ac96fbd34e --- /dev/null +++ b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/lora.yaml @@ -0,0 +1,48 @@ +data_config: + train_file: train.json + val_file: dev.json + test_file: dev.json + num_proc: 16 + constant_length: true +max_input_length: 256 +max_output_length: 512 +training_args: + # see `transformers.Seq2SeqTrainingArguments` + output_dir: ./output + max_steps: 100 + ddp_find_unused_parameters: False + # needed to be fit for the dataset + learning_rate: 5e-5 + # settings for data loading + per_device_train_batch_size: 2 + dataloader_num_workers: 16 + remove_unused_columns: false + # settings for saving checkpoints + save_strategy: steps + save_steps: 500 + # settings for logging + log_level: info + logging_strategy: steps + logging_steps: 10 + # settings for evaluation + per_device_eval_batch_size: 16 + evaluation_strategy: steps + eval_steps: 500 + # settings for optimizer + # adam_epsilon: 1e-6 + # uncomment the following line to detect nan or inf values + # debug: underflow_overflow + predict_with_generate: true + # see `transformers.GenerationConfig` + generation_config: + max_new_tokens: 512 + # set your absolute deepspeed path here + #deepspeed: ds_zero_2.json + # set to true if train with cpu. + use_cpu: false +peft_config: + peft_type: LORA + task_type: CAUSAL_LM + r: 8 + lora_alpha: 32 + lora_dropout: 0.1 diff --git a/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/ptuning_v2.yaml b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/ptuning_v2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5cd35fe83747ec432dd74dc4850f0194b2f18a65 --- /dev/null +++ b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/ptuning_v2.yaml @@ -0,0 +1,43 @@ +data_config: + train_file: train.json + val_file: dev.json + test_file: dev.json + num_proc: 16 +max_input_length: 256 +max_output_length: 512 +training_args: + # see `transformers.Seq2SeqTrainingArguments` + output_dir: ./output + max_steps: 100 + # needed to be fit for the dataset + learning_rate: 5e-5 + # settings for data loading + per_device_train_batch_size: 4 + dataloader_num_workers: 16 + remove_unused_columns: false + # settings for saving checkpoints + save_strategy: steps + save_steps: 500 + # settings for logging + log_level: info + logging_strategy: steps + logging_steps: 10 + # settings for evaluation + per_device_eval_batch_size: 16 + evaluation_strategy: steps + eval_steps: 500 + # settings for optimizer + # adam_epsilon: 1e-6 + # uncomment the following line to detect nan or inf values + # debug: underflow_overflow + predict_with_generate: true + # see `transformers.GenerationConfig` + generation_config: + max_new_tokens: 512 + # set your absolute deepspeed path here + #deepspeed: ds_zero_3.json + use_cpu: false +peft_config: + peft_type: PREFIX_TUNING + task_type: CAUSAL_LM + num_virtual_tokens: 128 diff --git a/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/sft.yaml b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/sft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ded47f80218f2430c99731f4ba5a40d490d402b7 --- /dev/null +++ b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/configs/sft.yaml @@ -0,0 +1,39 @@ +data_config: + train_file: train.json + val_file: dev.json + test_file: dev.json + num_proc: 16 + constant_length: true +max_input_length: 256 +max_output_length: 512 +training_args: + # see `transformers.Seq2SeqTrainingArguments` + output_dir: ./output + max_steps: 100 + fp16: true + # needed to be fit for the dataset + learning_rate: 5e-5 + # settings for data loading + per_device_train_batch_size: 1 + dataloader_num_workers: 16 + remove_unused_columns: false + # settings for saving checkpoints + save_strategy: steps + save_steps: 500 + # settings for logging + log_level: info + logging_strategy: steps + logging_steps: 1 + # settings for evaluation + per_device_eval_batch_size: 16 + evaluation_strategy: steps + eval_steps: 500 + # settings for optimizer + # adam_epsilon: 1e-6 + # uncomment the following line to detect nan or inf values + # debug: underflow_overflow + predict_with_generate: true + generation_config: + max_new_tokens: 512 + # set your absolute deepspeed path here + deepspeed: configs/ds_zero_2.json diff --git a/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/finetune_hf.py b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/finetune_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..4301771559c79ed8a4b17b6ca45ea1fb62da36ca --- /dev/null +++ b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/finetune_hf.py @@ -0,0 +1,709 @@ +# -*- coding: utf-8 -*- +import os +import jieba +import warnings +import random +import dataclasses as dc +import functools +from collections.abc import Callable, Mapping, Sequence +from pathlib import Path +from typing import Annotated, Any, Optional, Union +import numpy as np +import ruamel.yaml as yaml +import torch +from torch.utils.data import IterableDataset +import typer +from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset +from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu +from peft import ( + PeftConfig, + PeftModelForCausalLM, + get_peft_config, + get_peft_model +) +from rouge_chinese import Rouge +from torch import nn +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + EvalPrediction, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + Seq2SeqTrainingArguments, AutoConfig, +) +from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq + +from transformers import Seq2SeqTrainer as _Seq2SeqTrainer + + +ModelType = Union[PreTrainedModel, PeftModelForCausalLM] +TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +app = typer.Typer(pretty_exceptions_show_locals=False) + + +class ClosedToConstantLengthSplicedDataset(IterableDataset): + """ + Define an iterable dataset that returns a (close to) constant length data point spliced from multiple + original independent (pre-tokenized) data points. + """ + + def __init__( + self, + dataset, + max_length = 4096, + num_packed_sequences = 8, + fetch_sequence_func = None, + input_ids_field = "input_ids", + labels_field = "labels", + infinite = False, + shuffle = True, + error_strict = False, + ) -> None: + self.dataset = dataset + self.max_length = max_length + self.infinite = infinite + self.max_buffer_size = max_length * num_packed_sequences # e.g., 4096 * 16 + self.shuffle = shuffle + + # Callable[[Dict[str, Any]], Tuple[List[int], List[int]]], + # A function that fetch sequence input_ids and labels from the original data point + if fetch_sequence_func is None: + self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field]) + else: + self.fetch_sequence_func = fetch_sequence_func + self.input_ids_field = input_ids_field + self.labels_field = labels_field + + self.error_strict = error_strict + self.current_size = 0 # `int`, current packed data size. + + def __len__(self) -> int: + return len(self.dataset) + + def __iter__(self): + iterator = iter(self.dataset) + more_data_points = True + while more_data_points is True: + buffer, buffer_len = [], 0 + while True: + # ending condition. + if buffer_len >= self.max_buffer_size: + break + try: + # `Tuple[List[int], List[int]]` + seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator)) + buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels}) + buffer_len += len(buffer[-1][self.input_ids_field]) + except StopIteration: + if self.infinite is True: + iterator = iter(self.dataset) + warnings.warn("The dataset reached end and the iterator is reset to the start.") + else: + more_data_points = False + break + examples = [] # `List[Dict[str, List[int]]]`, save buffered spliced data points. + spliced_input_ids, spliced_labels = [], [] # `List[int]`, `List[int]` + for i, data_point in enumerate(buffer): + # TODO(2023-09-18) check errors for each unspliced tokenized data point + seq_input_ids = data_point[self.input_ids_field] + seq_labels = data_point[self.labels_field] + # Handle special case: + # If the length of an original data point (i.e., input_ids length of a data point before splicing) + # exceeds `max_length`, truncate it. + if len(seq_input_ids) > self.max_length: + truncated_seq_input_ids = seq_input_ids[: self.max_length] + truncated_seq_labels = seq_labels[: self.max_length] + + spliced_data_point = { + self.input_ids_field: truncated_seq_input_ids, + self.labels_field: truncated_seq_labels, + } + examples.append(spliced_data_point) + warnings.warn("Find a data point to be truncated.") + continue + + # Pre action judgment. + if len(spliced_input_ids) + len(seq_input_ids) > self.max_length: + spliced_input_ids.extend(seq_input_ids) + spliced_labels.extend(seq_labels) + + spliced_data_point = { + self.input_ids_field: spliced_input_ids[:self.max_length], + self.labels_field: spliced_labels[:self.max_length], + } # `Dict[str, List[int]]` + # Update. + spliced_input_ids, spliced_labels = [], [] + examples.append(spliced_data_point) + else: + spliced_input_ids.extend(seq_input_ids) + spliced_labels.extend(seq_labels) + # For residual spliced data point at the end of the data set + if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0: + examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels}) + if self.shuffle: + random.shuffle(examples) + for spliced_data_point in examples: + # TODO(2023-09-18): check errors for each spliced tokenized data point. + self.current_size += 1 + yield spliced_data_point + +class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq): + def __call__(self, features, return_tensors=None): + output_ids = ( + [feature['output_ids'] for feature in features] + if 'output_ids' in features[0].keys() + else None + ) + if output_ids is not None: + max_output_length = max(len(out) for out in output_ids) + if self.pad_to_multiple_of is not None: + max_output_length = ( + ( + max_output_length + self.pad_to_multiple_of - 1) // + self.pad_to_multiple_of * self.pad_to_multiple_of + ) + for feature in features: + remainder = [self.tokenizer.pad_token_id] * ( + max_output_length - len(feature['output_ids']) + ) + if isinstance(feature['output_ids'], list): + feature['output_ids'] = feature['output_ids'] + remainder + else: + feature['output_ids'] = np.concatenate( + [feature['output_ids'], remainder] + ).astype(np.int64) + return super().__call__(features, return_tensors) + + +class Seq2SeqTrainer(_Seq2SeqTrainer): + def prediction_step( + self, + model: nn.Module, + inputs: dict[str, Any], + prediction_loss_only: bool, + ignore_keys=None, + **gen_kwargs, + ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + if self.args.predict_with_generate: + output_ids = inputs.pop('output_ids') + input_ids = inputs['input_ids'] + loss, generated_tokens, labels = super().prediction_step( + model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs + ) + generated_tokens = generated_tokens[:, input_ids.size()[1]:] + if self.args.predict_with_generate: + labels = output_ids + return loss, generated_tokens, labels + # For P-Tuning a new save_model function is fine for the prefix_encoder model + # but may cost problems for the whole model loading + + # def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + # if output_dir is None: + # output_dir = self.args.output_dir + # os.makedirs(output_dir, exist_ok=True) + # ptuning_params = {k: v for k, v in self.model.transformer.prefix_encoder.state_dict().items()} + # + # torch.save(ptuning_params, os.path.join(output_dir, 'pytorch_model.bin')) + # + # print(f"P-Tuning model weights saved in {output_dir}") + # + # if self.tokenizer is not None: + # self.tokenizer.save_pretrained(output_dir) + + +def _resolve_path(path: Union[str, Path]) -> Path: + return Path(path).expanduser().resolve() + + +def _sanity_check( + input_ids: Sequence[int], + output_ids: Sequence[int], + tokenizer: PreTrainedTokenizer, +): + print('--> Sanity check') + for in_id, out_id in zip(input_ids, output_ids): + if in_id == 0: + continue + if in_id in tokenizer.tokenizer.index_special_tokens: + in_text = tokenizer.tokenizer.index_special_tokens[in_id] + else: + in_text = tokenizer.decode([in_id]) + print(f'{repr(in_text):>20}: {in_id} -> {out_id}') + + +@functools.cache +def _get_yaml_parser() -> yaml.YAML: + parser = yaml.YAML(typ='safe', pure=True) + parser.indent(mapping=2, offset=2, sequence=4) + parser.default_flow_style = False + return parser + + +@dc.dataclass +class DataConfig(object): + train_file: str + val_file: Optional[str] = None + test_file: Optional[str] = None + + num_proc: Optional[int] = None + constant_length: Optional[bool] = False + + @property + def data_format(self) -> str: + return Path(self.train_file).suffix + + @property + def data_files(self) -> dict[NamedSplit, str]: + return { + split: data_file + for split, data_file in zip( + [Split.TRAIN, Split.VALIDATION, Split.TEST], + [self.train_file, self.val_file, self.test_file], + ) + if data_file is not None + } + + +@dc.dataclass +class FinetuningConfig(object): + data_config: DataConfig + + max_input_length: int + max_output_length: int + + training_args: Seq2SeqTrainingArguments = dc.field( + default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output') + ) + peft_config: Optional[PeftConfig] = None + + def __post_init__(self): + if not self.training_args.do_eval or self.data_config.val_file is None: + # skips the evaluation stage when `do_eval` or `eval_file` is not provided + self.training_args.do_eval = False + self.training_args.evaluation_strategy = 'no' + self.data_config.val_file = None + else: + self.training_args.per_device_eval_batch_size = ( + self.training_args.per_device_eval_batch_size + or self.training_args.per_device_train_batch_size + ) + + @classmethod + def from_dict(cls, **kwargs) -> 'FinetuningConfig': + training_args = kwargs.get('training_args', None) + if training_args is not None and not isinstance( + training_args, Seq2SeqTrainingArguments + ): + gen_config = training_args.get('generation_config') + # TODO: a bit hacky + if not isinstance(gen_config, GenerationConfig): + training_args['generation_config'] = GenerationConfig( + **gen_config + ) + kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args) + + data_config = kwargs.get('data_config') + if not isinstance(data_config, DataConfig): + kwargs['data_config'] = DataConfig(**data_config) + + peft_config = kwargs.get('peft_config', None) + if peft_config is not None and not isinstance(peft_config, PeftConfig): + kwargs['peft_config'] = get_peft_config(peft_config) + return cls(**kwargs) + + @classmethod + def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig': + path = _resolve_path(path) + kwargs = _get_yaml_parser().load(path) + return cls.from_dict(**kwargs) + + +def _load_datasets( + data_dir: Path, + data_format: str, + data_files: dict[NamedSplit, str], + num_proc: Optional[int], +) -> DatasetDict: + if data_format in ('.csv', '.json', '.jsonl'): + dataset_dct = load_dataset( + data_format[1:], + data_dir=data_dir, + data_files=data_files, + num_proc=num_proc, + ) + else: + err_msg = f"Cannot load dataset in the '{data_format}' format." + raise NotImplementedError(err_msg) + + return dataset_dct + + +class DataManager(object): + def __init__(self, data_dir: str, data_config: DataConfig): + self._num_proc = data_config.num_proc + + self._dataset_dct = _load_datasets( + _resolve_path(data_dir), + data_config.data_format, + data_config.data_files, + self._num_proc, + ) + + def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]: + return self._dataset_dct.get(split, None) + + def get_dataset( + self, + split: NamedSplit, + process_fn: Callable[[dict[str, Any]], dict[str, Any]], + batched: bool = True, + remove_orig_columns: bool = True, + constant_length: bool = False, + max_length: int = 4096, + ) -> Optional[Dataset]: + orig_dataset = self._get_dataset(split) + if orig_dataset is None: + return + + if remove_orig_columns: + remove_columns = orig_dataset.column_names + else: + remove_columns = None + + res = orig_dataset.map( + process_fn, + batched=batched, + remove_columns=remove_columns, + num_proc=self._num_proc, + ) + if constant_length: + res = self.constant_length_dataset(res, max_length=max_length) + return res + + def constant_length_dataset(self, dataset, max_length): + res = ClosedToConstantLengthSplicedDataset(dataset, max_length=max_length) + return res + + +def print_model_size(model: PreTrainedModel): + print("--> Model") + total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"\n--> model has {total_params / 1e6}M params\n") + + +def process_batch( + batch: Mapping[str, Sequence], + tokenizer: PreTrainedTokenizer, + max_input_length: int, + max_output_length: int, +) -> dict[str, list]: + batched_tools = batch.get('tools', None) + batched_conv = batch['conversations'] + batched_input_ids = [] + batched_labels = [] + + if batched_tools is None: + batched_tools = [None] * len(batched_conv) + + for tools, conv in zip(batched_tools, batched_conv): + input_ids, loss_masks = [ + tokenizer.get_command('[gMASK]'), + tokenizer.get_command('sop'), + ], [False, False] + + if tools is not None: + raise NotImplementedError() + + for message in conv: + if message['role'] in ('system', 'user'): + loss_mask_val = False + else: + loss_mask_val = True + + if message['role'] == 'tool': + raise NotImplementedError() + else: + new_input_ids = tokenizer.build_single_message( + message['role'], '', message['content'] + ) + new_loss_masks = [loss_mask_val] * len(new_input_ids) + + input_ids += new_input_ids + loss_masks += new_loss_masks + + input_ids.append(tokenizer.eos_token_id) + loss_masks = [False, *loss_masks] + labels = [] + for input_id, mask in zip(input_ids, loss_masks): + if mask: + labels.append(input_id) + else: + labels.append(-100) + max_length = max_input_length + max_output_length + 1 + batched_input_ids.append(input_ids[:max_length]) + batched_labels.append(labels[:max_length]) + return {'input_ids': batched_input_ids, 'labels': batched_labels} + + +def process_batch_eval( + batch: Mapping[str, Sequence], + tokenizer: PreTrainedTokenizer, + max_input_length: int, + max_output_length: int, +) -> dict[str, list]: + batched_tools = batch.get('tools', None) + batched_conv = batch['conversations'] + batched_input_ids = [] + # To avoid computing loss, we do not provide the `labels` field in the input dictionary. + batched_output_ids = [] + + if batched_tools is None: + batched_tools = [None] * len(batched_conv) + + for tools, conv in zip(batched_tools, batched_conv): + input_ids = [ + tokenizer.get_command('[gMASK]'), + tokenizer.get_command('sop'), + ] + + if tools is not None: + raise NotImplementedError() + + for message in conv: + if len(input_ids) >= max_input_length: + break + if message['role'] == 'tool': + raise NotImplementedError() + else: + new_input_ids = tokenizer.build_single_message( + message['role'], '', message['content'] + ) + if message['role'] == 'assistant': + output_prompt, output_ids = ( + new_input_ids[:1], + new_input_ids[1:], + ) + output_ids.append(tokenizer.eos_token_id) + batched_input_ids.append( + input_ids[:max_input_length] + output_prompt[:1] + ) + batched_output_ids.append(output_ids[:max_output_length]) + input_ids += new_input_ids + return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids} + + +# Not sure if this is necessary, can set it to half. +# If train with cpu, cast all params to fp32 instead of trainable ones. +def _prepare_model_for_training(model: nn.Module, use_cpu: bool): + for param in model.parameters(): + if param.requires_grad or use_cpu: + param.data = param.data.to(torch.float32) + + +def load_tokenizer_and_model( + model_dir: str, + peft_config: Optional[PeftConfig] = None, +) -> tuple[PreTrainedTokenizer, nn.Module]: + tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + if peft_config is not None: + if peft_config.peft_type.name == "PREFIX_TUNING": + config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + config.pre_seq_len = peft_config.num_virtual_tokens + config.use_cache = False + model = AutoModelForCausalLM.from_pretrained( + model_dir, + trust_remote_code=True, + config=config, + ) + + if peft_config.peft_type.name == "LORA": + model = AutoModelForCausalLM.from_pretrained( + model_dir, + trust_remote_code=True, + empty_init=False, + use_cache=False + ) + + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + else: + model = AutoModelForCausalLM.from_pretrained( + model_dir, + trust_remote_code=True, + empty_init=False, + use_cache=False + ) + + print_model_size(model) + return tokenizer, model + + +def compute_metrics(eval_preds: EvalPrediction, tokenizer: PreTrainedTokenizer): + batched_pred_ids, batched_label_ids = eval_preds + + metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []} + for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids): + pred_txt = tokenizer.decode(pred_ids).strip() + label_txt = tokenizer.decode(label_ids).strip() + pred_tokens = list(jieba.cut(pred_txt)) + label_tokens = list(jieba.cut(label_txt)) + rouge = Rouge() + scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens)) + for k, v in scores[0].items(): + metrics_dct[k].append(round(v['f'] * 100, 4)) + metrics_dct['bleu-4'].append( + sentence_bleu( + [label_tokens], + pred_tokens, + smoothing_function=SmoothingFunction().method3, + ) + ) + return {k: np.mean(v) for k, v in metrics_dct.items()} + + +@app.command() +def main( + data_dir: Annotated[str, typer.Argument(help='')], + model_dir: Annotated[ + str, + typer.Argument( + help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.' + ), + ], + config_file: Annotated[str, typer.Argument(help='')], + auto_resume_from_checkpoint: str = typer.Argument( + default='', + help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training' + ), + +): + ft_config = FinetuningConfig.from_file(config_file) + tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config) + data_manager = DataManager(data_dir, ft_config.data_config) + + train_dataset = data_manager.get_dataset( + Split.TRAIN, + functools.partial( + process_batch, + tokenizer=tokenizer, + max_input_length=ft_config.max_input_length, + max_output_length=ft_config.max_output_length, + ), + batched=True, + constant_length=ft_config.data_config.constant_length, + max_length=ft_config.max_input_length+ft_config.max_output_length, + ) + print('train_dataset:', train_dataset) + val_dataset = data_manager.get_dataset( + Split.VALIDATION, + functools.partial( + process_batch_eval, + tokenizer=tokenizer, + max_input_length=ft_config.max_input_length, + max_output_length=ft_config.max_output_length, + ), + batched=True, + ) + if val_dataset is not None: + print('val_dataset:', val_dataset) + test_dataset = data_manager.get_dataset( + Split.TEST, + functools.partial( + process_batch_eval, + tokenizer=tokenizer, + max_input_length=ft_config.max_input_length, + max_output_length=ft_config.max_output_length, + ), + batched=True, + ) + if test_dataset is not None: + print('test_dataset:', test_dataset) + + # checks encoded dataset + if isinstance(train_dataset, Dataset): + _sanity_check( + train_dataset[0]["input_ids"], train_dataset[0]["labels"], tokenizer + ) + elif isinstance(train_dataset, IterableDataset): + example = next(iter(train_dataset)) + input_ids, labels = example["input_ids"], example["labels"] + _sanity_check( + input_ids, labels, tokenizer + ) + else: + raise KeyError + # turn model to fp32 + # _prepare_model_for_training(model, ft_config.training_args.use_cpu) + + ft_config.training_args.generation_config.pad_token_id = ( + tokenizer.pad_token_id + ) + ft_config.training_args.generation_config.eos_token_id = [ + tokenizer.eos_token_id, + tokenizer.get_command('<|user|>'), + tokenizer.get_command('<|observation|>'), + ] + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + + use_tokenizer = True + if ft_config.peft_config is not None: + use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True + + trainer = Seq2SeqTrainer( + model=model, + args=ft_config.training_args, + data_collator=DataCollatorForSeq2Seq( + tokenizer=tokenizer, + padding='longest', + return_tensors='pt', + ), + train_dataset=train_dataset, + eval_dataset=val_dataset.select(list(range(50))), + tokenizer=tokenizer if use_tokenizer else None, # LORA does not need tokenizer + compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer), + ) + + if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None: + trainer.train() + else: + def do_rf_checkpoint(sn): + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + checkpoint_directory = os.path.join(output_dir, "checkpoint-" + sn) + print("resume checkpoint from checkpoint-" + sn) + trainer.train(resume_from_checkpoint=checkpoint_directory) + + output_dir = ft_config.training_args.output_dir + + # resume from latest checkpoint + if auto_resume_from_checkpoint.upper() == "YES": + dirlist = os.listdir(output_dir) + checkpoint_sn = 0 + # get latest checkpoint + for checkpoint_str in dirlist: + if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1: + checkpoint = int(checkpoint_str.replace("checkpoint-", "")) + if checkpoint > checkpoint_sn: + checkpoint_sn = checkpoint + if checkpoint_sn > 0: + do_rf_checkpoint(str(checkpoint_sn)) + else: + trainer.train() + else: + # resume from specific checkpoint + if auto_resume_from_checkpoint.isdigit() and int(auto_resume_from_checkpoint) > 0: + do_rf_checkpoint(auto_resume_from_checkpoint) + else: + print(auto_resume_from_checkpoint, + "The specified checkpoint sn(" + auto_resume_from_checkpoint + ") has not been saved. Please search for the correct chkeckpoint in the model output directory") + + # test stage + if test_dataset is not None: + trainer.predict(test_dataset) + + +if __name__ == '__main__': + app() diff --git a/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/inference_hf.py b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/inference_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd004262042621e0cfea387bf68440b3c466ab7 --- /dev/null +++ b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/inference_hf.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from pathlib import Path +from typing import Annotated, Union + +import typer +from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) + +ModelType = Union[PreTrainedModel, PeftModelForCausalLM] +TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + +app = typer.Typer(pretty_exceptions_show_locals=False) + + +def _resolve_path(path: Union[str, Path]) -> Path: + return Path(path).expanduser().resolve() + + +def load_model_and_tokenizer(model_dir: Union[str, Path]) -> tuple[ModelType, TokenizerType]: + model_dir = _resolve_path(model_dir) + if (model_dir / 'adapter_config.json').exists(): + model = AutoPeftModelForCausalLM.from_pretrained( + model_dir, trust_remote_code=True, device_map='auto' + ) + tokenizer_dir = model.peft_config['default'].base_model_name_or_path + else: + model = AutoModelForCausalLM.from_pretrained( + model_dir, trust_remote_code=True, device_map='auto' + ) + tokenizer_dir = model_dir + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_dir, trust_remote_code=True + ) + return model, tokenizer + + +@app.command() +def main( + model_dir: Annotated[str, typer.Argument(help='')], + prompt: Annotated[str, typer.Option(help='')], +): + model, tokenizer = load_model_and_tokenizer(model_dir) + response, _ = model.chat(tokenizer, prompt) + print(response) + + +if __name__ == '__main__': + app() diff --git a/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/models/modeling_chatglm.py b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/models/modeling_chatglm.py new file mode 100644 index 0000000000000000000000000000000000000000..513d0419f3af915dad44a4e8a33381e442d9ff74 --- /dev/null +++ b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/models/modeling_chatglm.py @@ -0,0 +1,1304 @@ +# Copyright (c) 2023, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. + +""" PyTorch ChatGLM model. """ + +import math +import copy +import warnings +import re +import sys + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any +from copy import deepcopy + +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM" +_CONFIG_FOR_DOC = "ChatGLMConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm3-6b", + # See all ChatGLM models at https://huggingface.co/models?filter=chatglm +] + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, + config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl( + self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 + ): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl( + max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device + ) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split('.')[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, + device=query_layer.device + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: + attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], + device=attention_scores.device, dtype=torch.bool) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + self.config = config + self.projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) + self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, **_config_to_kwargs(config) + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, + device=device, **_config_to_kwargs(config) + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True + ): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + "dtype": args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache, + use_reentrant=False + ) + else: + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, + device=input_ids.device), full_attention_mask), dim=-1) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + return position_ids + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device + ) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + ) + + self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device, + dtype=config.torch_dtype) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, + dtype=config.torch_dtype, **init_kwargs) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def set_input_embeddings(self, value): + self.embedding.word_embeddings = value + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.multi_query_group_num, + self.kv_channels + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, + dtype=inputs_embeds.dtype) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask], dim=-1) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def quantize(self, weight_bit_width: int): + from .quantization import quantize + quantize(self.encoder, weight_bit_width) + return self + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + self.config = config + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) + + model_kwargs["is_first_forward"] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + is_first_forward: bool = True, + **kwargs + ) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids(input_ids, device=input_ids.device) + if not is_first_forward: + if past_key_values is not None: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "position_ids": position_ids, + "attention_mask": attention_mask, + "return_last_logit": True, + "use_cache": use_cache + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def process_response(self, output, history): + content = "" + history = deepcopy(history) + for response in output.split("<|assistant|>"): + if "\n" in response: + metadata, content = response.split("\n", maxsplit=1) + else: + metadata, content = "", response + if not metadata.strip(): + content = content.strip() + history.append({"role": "assistant", "metadata": metadata, "content": content}) + content = content.replace("[[训练时间]]", "2023年") + else: + history.append({"role": "assistant", "metadata": metadata, "content": content}) + if history[0]["role"] == "system" and "tools" in history[0]: + content = "\n".join(content.split("\n")[1:-1]) + def tool_call(**kwargs): + return kwargs + parameters = eval(content) + content = {"name": metadata.strip(), "parameters": parameters} + else: + content = {"name": metadata.strip(), "content": content} + return content, history + + @torch.inference_mode() + def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", + max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, + **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + inputs = tokenizer.build_chat_input(query, history=history, role=role) + inputs = inputs.to(self.device) + eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), + tokenizer.get_command("<|observation|>")] + outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] + response = tokenizer.decode(outputs) + history.append({"role": role, "content": query}) + response, history = self.process_response(response, history) + return response, history + + @torch.inference_mode() + def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", + past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, + logits_processor=None, return_past_key_values=False, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), + tokenizer.get_command("<|observation|>")] + gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if past_key_values is None: + inputs = tokenizer.build_chat_input(query, history=history, role=role) + else: + inputs = tokenizer.build_chat_input(query, role=role) + inputs = inputs.to(self.device) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + if self.transformer.pre_seq_len is not None: + past_length -= self.transformer.pre_seq_len + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) + inputs['attention_mask'] = attention_mask + history.append({"role": role, "content": query}) + for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, + eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, + **gen_kwargs): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] + response = tokenizer.decode(outputs) + if response and response[-1] != "�": + response, new_history = self.process_response(response, history) + if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + + @torch.inference_mode() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + return_past_key_values=False, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + model_kwargs["use_cache"] = generation_config.use_cache + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + + def quantize(self, bits: int, empty_init=False, device=None, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, + **kwargs) + return self + + +class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.num_labels = config.num_labels + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + + self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half) + if config.classifier_dropout is not None: + self.dropout = nn.Dropout(config.classifier_dropout) + else: + self.dropout = None + self.config = config + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + full_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + full_attention_mask=full_attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + pooled_hidden_states = hidden_states[-1] + if self.dropout is not None: + pooled_hidden_states = self.dropout(pooled_hidden_states) + logits = self.classifier_head(pooled_hidden_states) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze().float(), labels.squeeze()) + else: + loss = loss_fct(logits.float(), labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits.float(), labels.view(-1, self.num_labels)) + + if not return_dict: + output = (logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/process_data.py b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/process_data.py new file mode 100644 index 0000000000000000000000000000000000000000..08da8a916064c93a9b15a0a67f529f4ed316793d --- /dev/null +++ b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/process_data.py @@ -0,0 +1,46 @@ +import json +import os +import argparse + +def process(json_path, save_path, max_samples=None): + parsed_data = [] + data = [] + with open(json_path, 'r', encoding='utf-8') as file: + for line in file: + try: + data.append(json.loads(line)) + except json.JSONDecodeError as e: + print(f"JSONDecodeError: {e.msg} in line {line}") + + for item in data: + parsed_item = dict() + parsed_item["conversations"] = [{"role": "user", "content": item["content"]}] + parsed_item["conversations"].append({"role": "assistant", "content": item["summary"]}) + parsed_data.append(parsed_item) + + with open(save_path, 'w', encoding='utf-8') as outfile: + for i, item in enumerate(parsed_data): + if max_samples and i >= max_samples: + print(f"note: save just {max_samples} max_samples to outfile") + break + json.dump(item, outfile, ensure_ascii=False) + outfile.write('\n') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--eval_max_samples', type=int, default=None, help='can get a small eval file if just for test function') + args = parser.parse_args() + + train_json_path = "data/AdvertiseGen/train.json" + process_train_json = "data/AdvertiseGen_process/train.json" + eval_json_path = "data/AdvertiseGen/dev.json" + process_eval_path = "data/AdvertiseGen_process/dev.json" + if not os.path.exists(os.path.dirname(os.path.abspath(process_train_json))): + os.mkdir(os.path.dirname(os.path.abspath(process_train_json))) + + print("process train datasets ...") + process(train_json_path, process_train_json) + + print("process eval datasets ... ") + process(eval_json_path, process_eval_path, max_samples=args.eval_max_samples) diff --git a/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/requirements.txt b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..918c116c988cac4b37832c560d318873c16d77e3 --- /dev/null +++ b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/requirements.txt @@ -0,0 +1,13 @@ +jieba>=0.42.1 +ruamel_yaml>=0.18.6 +rouge_chinese>=1.0.3 +jupyter>=1.0.0 +datasets>=2.18.0 +peft>=0.10.0 +mpi4py>=3.1.5 +transformers==4.40.0 +accelerate>=0.33.0 +typer>=0.9.0 +nltk +sentencepiece + diff --git a/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/run.sh b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..9378fca288ad788d84b98096669f7f4edc26e1cf --- /dev/null +++ b/nlp/llm/chatglm3-6b/deepspeed/finetune_demo/run.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Copyright (c) 2023, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +NUM_GPUS=16 +CONFIG_FILE=$1 +MASTER_PORT=$(shuf -n 1 -i 10000-65535) + +# Check if a second argument is provided for NUM_GPUS +if [ ! -z "$2" ]; then + NUM_GPUS=$2 +fi + +echo "start training with num_gpus=$NUM_GPUS | and config_file=$CONFIG_FILE" + +# modeling_chatglm.py 有一点修改,用原生的会导致sft训练报错 +cp -r models/* checkpoint/chatglm3-6b + +torchrun --nnodes=1 --master_port=$MASTER_PORT --nproc_per_node=$NUM_GPUS finetune_hf.py \ + data/AdvertiseGen_process/ \ + checkpoint/chatglm3-6b \ + $CONFIG_FILE