From 5ba7f3af682720f25472aedf7697120facfd2100 Mon Sep 17 00:00:00 2001 From: "hui.sang" Date: Thu, 10 Nov 2022 19:51:00 +0800 Subject: [PATCH 1/2] cpm model add checkpoint link #I60KVO cpm model add checkpoint Signed-off-by: hui.sang --- nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py | 3 +++ nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py b/nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py index 9770a2314..c05eb1a4d 100755 --- a/nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py +++ b/nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py @@ -105,6 +105,9 @@ def main(): training_event.on_train_end() raw_train_end_time = logger.previous_log_time training_state.raw_train_time = (raw_train_end_time - raw_train_start_time) / 1e+3 + + trainer.save_checkpoint() + return config, training_state if __name__ == "__main__": diff --git a/nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py b/nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py index cf201efe9..20139a72c 100755 --- a/nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py +++ b/nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py @@ -150,3 +150,8 @@ class Trainer(): ]) return do_eval or state.global_steps >= self.config.max_steps + + def save_checkpoint(self): + if self.config.n_gpu == 1 or (self.config.n_gpu > 1 and self.config.device == 0): + print("save checkpoint...") + torch.save(self.model.module.state_dict(), "cpm_model_states_medium_end2end.pt") \ No newline at end of file -- Gitee From 03f5d17f252be7a2993a0064e36949c398c4a0a5 Mon Sep 17 00:00:00 2001 From: "hui.sang" Date: Thu, 10 Nov 2022 20:30:58 +0800 Subject: [PATCH 2/2] dlrm model add checkpoint lind #I60L07 dlrm model add checkpoint Signed-off-by: hui.sang --- .../cpm/pytorch/base/run_pretraining.py | 15 +++++++++++++++ .../cpm/pytorch/base/train/trainer.py | 15 +++++++++++++++ .../ctr/dlrm/pytorch/dlrm/dist_model.py | 7 +++++++ recommendation/ctr/dlrm/pytorch/scripts/train.py | 2 +- 4 files changed, 38 insertions(+), 1 deletion(-) diff --git a/nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py b/nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py index c05eb1a4d..c2e9a30a0 100755 --- a/nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py +++ b/nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py @@ -1,3 +1,18 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + """BERT Pretraining""" from __future__ import absolute_import diff --git a/nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py b/nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py index 20139a72c..4a2098732 100755 --- a/nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py +++ b/nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py @@ -1,3 +1,18 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + import math import time import os diff --git a/recommendation/ctr/dlrm/pytorch/dlrm/dist_model.py b/recommendation/ctr/dlrm/pytorch/dlrm/dist_model.py index 467fb9ba2..706a53e98 100644 --- a/recommendation/ctr/dlrm/pytorch/dlrm/dist_model.py +++ b/recommendation/ctr/dlrm/pytorch/dlrm/dist_model.py @@ -427,3 +427,10 @@ class DistDlrm(): def to(self, *args, **kwargs): self.bottom_model.to(*args, **kwargs) self.top_model.to(*args, **kwargs) + + def state_dict(self): + dlrm_state_dic = {} + dlrm_state_dic.update(self.bottom_model.state_dict()) + dlrm_state_dic.update(self.top_model.state_dict()) + + return dlrm_state_dic diff --git a/recommendation/ctr/dlrm/pytorch/scripts/train.py b/recommendation/ctr/dlrm/pytorch/scripts/train.py index 24eed49c7..682ddb0a8 100644 --- a/recommendation/ctr/dlrm/pytorch/scripts/train.py +++ b/recommendation/ctr/dlrm/pytorch/scripts/train.py @@ -53,7 +53,7 @@ flags.DEFINE_enum("dataset_type", "memmap", ["bin", "memmap", "dist"], "Which da flags.DEFINE_boolean("use_embedding_ext", True, "Use embedding cuda extension. If False, use Pytorch embedding") # Saving and logging flags -flags.DEFINE_string("output_dir", "/tmp", "path where to save") +flags.DEFINE_string("output_dir", ".", "path where to save") flags.DEFINE_integer("test_freq", None, "#steps test. If None, 20 tests per epoch per MLperf rule.") flags.DEFINE_float("test_after", 0, "Don't test the model unless this many epochs has been completed") flags.DEFINE_integer("print_freq", None, "#steps per pring") -- Gitee