diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/DataLoader.py b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/DataLoader.py new file mode 100644 index 0000000000000000000000000000000000000000..c79b9ce0db3efdb5629e676c6dcc0cc5a12ca4bb --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/DataLoader.py @@ -0,0 +1,118 @@ +# +# BSD 3-Clause License +# +# Copyright (c) 2017 xxxx +# All rights reserved. +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# ============================================================================ +# +import torch +import torch.npu +import os +NPU_CALCULATE_DEVICE = 0 +if os.getenv('NPU_CALCULATE_DEVICE') and str.isdigit(os.getenv('NPU_CALCULATE_DEVICE')): + NPU_CALCULATE_DEVICE = int(os.getenv('NPU_CALCULATE_DEVICE')) +if torch.npu.current_device() != NPU_CALCULATE_DEVICE: + torch.npu.set_device(f'npu:{NPU_CALCULATE_DEVICE}') + +class MyDataSet( torch.utils.data.Dataset): + def __init__( self, data_path, user_map, material_map, category_map, max_length): + + user = []; material = []; category = [] + material_historical = []; category_historical = [] + material_historical_neg = []; category_historical_neg = [] + mask = []; sequential_length = [] + target = [] + + with open( data_path, 'r') as fin: + + for line in fin: + item = line.strip('\n').split('\t') + if not item: continue + + user.append( user_map.get( item[1], 0 ) ) + material.append( material_map.get( item[2], 0 ) ) + category.append( category_map.get( item[3], 0 ) ) + + # history fields are '\x02'-separated lists (mouna99/dien preprocessed data format) + material_historical_item = [0] * max_length + temp = item[4].split("\x02") + if( len( temp) >= max_length): temp = temp[ -max_length:] + for i, m in enumerate( temp): + material_historical_item[i] = material_map.get( m, 0 ) + material_historical.append( material_historical_item) + + category_historical_item = [0] * max_length + temp = item[5].split("\x02") + if( len( temp) >= max_length): temp = temp[ -max_length:] + for i, c in enumerate( temp): + category_historical_item[i] = category_map.get( c, 0 ) + category_historical.append( category_historical_item) + + # clipped history length, used for the mask and for mean pooling later + temp = min( len(temp), max_length) + mask_item = [1] * temp + [0] * ( max_length - temp) + + mask.append( mask_item) + sequential_length.append( temp) + + target.append( int( item[0])) + + self.user = torch.tensor( user) + + self.material = torch.tensor( material) + self.category = torch.tensor( category) + + self.material_historical = torch.tensor( material_historical) + self.category_historical = torch.tensor( category_historical) + + self.mask = torch.tensor( mask).type( torch.FloatTensor) + self.sequential_length = torch.tensor( sequential_length).type( torch.FloatTensor) + + self.target = torch.tensor( target).type( torch.FloatTensor) + + + def __len__( self): + return len( self.user) + + def __getitem__(self, index): + if torch.is_tensor( index): + index = index.tolist() + + user = self.user[ index] + + material_historical = self.material_historical[ index, :] + category_historical = self.category_historical[ index, :] + mask = self.mask[ index, :] + sequential_length = self.sequential_length[ index] + + material = self.material[ index] + category = self.category[ index] + + target = self.target[ index] + + # the two zeros are placeholders for the unused negative-sample histories + return user, material_historical, category_historical, mask, sequential_length , \ + material, category, 0, 0, target \ No newline at end of file diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/LICENSE b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..09d493bf1fc257505c1336f3f87425568ab9da3c --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2017, +All rights reserved.
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/README.md b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/README.md new file mode 100644 index 0000000000000000000000000000000000000000..09d4601fb61104bf96425847b391e307719e17e4 --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/README.md @@ -0,0 +1,45 @@ +# DIN/DIEN + +A PyTorch implementation of the DIN recommendation algorithm + + +## Attention + +1. For convenience, and following the authors' TensorFlow implementation, all feature embeddings share the same dimension. +2. No L1/L2 regularization or dropout is applied, so a suitable checkpoint should be selected manually based on the evaluation results.
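+ +## Data format + +`DataLoader.py` parses each line of `local_train_splitByUser` / `local_test_splitByUser` as six tab-separated fields; the two history fields are lists joined by the `\x02` control character, as in the preprocessed data from [mouna99/dien](https://github.com/mouna99/dien). Schematically (field names below are placeholders, not real data): + +``` +label \t user \t material \t category \t mat_1\x02mat_2\x02... \t cat_1\x02cat_2\x02... +```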
+ +## File description +|file name|description| +|--|----| +|main.py|Training and evaluation entry point| +|model.py|Definition of the target models (DIN/DIEN)| +|layer.py|Embedding, MLP and attention layers| +|rnn.py|AUGRU cell and DynamicGRU used by DIEN| +|DataLoader.py|Custom dataset and data loader| +|environment.yml|Conda environment YAML| + +## Original paper +[Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf) + +[Deep Interest Evolution Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1809.03672.pdf) + +## Source data +[meta_Books.json.gz](http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Books.json.gz) + +[reviews_Books.json.gz](http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Books.json.gz) + +Preprocessed data wrapped within `data.tar.gz` comes from [mouna99/dien](https://github.com/mouna99/dien) + +## Reference + +[mouna99/dien](https://github.com/mouna99/dien) + +[alibaba/x-deeplearning](https://github.com/alibaba/x-deeplearning) + +[shenweichen/DeepCTR-Torch](https://github.com/shenweichen/DeepCTR-Torch) + + +## To do list + +- [x] DIN +- [x] AUGRU +- [ ] DICE activation layer +- [ ] Auxiliary loss with neg_sample \ No newline at end of file diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/environment.yml b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..b21a399bff2773027ba05bc77a296671fa7cd331 --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/environment.yml @@ -0,0 +1,183 @@ +name: my_env +channels: + - pytorch + - https://conda.anaconda.org/anaconda + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - astroid=2.4.1=py38_0 + - attrs=19.3.0=py_0 + - backcall=0.1.0=py38_0 + - blas=1.0=mkl + - bleach=3.1.4=py_0 + - bzip2=1.0.8=h516909a_2 + - ca-certificates=2020.6.24=0 + - cairo=1.16.0=hcf35c78_1003 + - certifi=2020.6.20=py38_0 + - cloudpickle=1.4.1=py_0 + - cudatoolkit=10.2.89=hfd86e86_1 + - cycler=0.10.0=py_2 + - cytoolz=0.10.1=py38h7b6447c_0 + - dask-core=2.17.2=py_0 + - dbus=1.13.14=hb2f20db_0 + - decorator=4.4.2=py_0 + - defusedxml=0.6.0=py_0 + - entrypoints=0.3=py38_0 + - expat=2.2.9=he1b5a44_2 + - ffmpeg=4.2=h167e202_0 + - fontconfig=2.13.1=h86ecdb6_1001 + - freetype=2.9.1=h8a8886c_1 + - gettext=0.19.8.1=hc5be6a0_1002 + - giflib=5.2.1=h516909a_2 + - glib=2.64.3=h6f030ca_0 + - gmp=6.1.2=h6c8ec71_1 + - gnutls=3.6.13=h79a8f9a_0 + - graphite2=1.3.13=he1b5a44_1001 + - gst-plugins-base=1.14.5=h0935bb2_2 + - gstreamer=1.14.5=h36ae1b5_2 + - harfbuzz=2.4.0=h9f30f68_3 + - hdf5=1.10.6=nompi_h3c11f04_100 + - icu=64.2=he1b5a44_1 + - imageio=2.8.0=py_0 + - importlib_metadata=1.5.0=py38_0 + - intel-openmp=2020.0=166 + - ipykernel=5.1.4=py38h39e3cac_0 + - ipython=7.13.0=py38h5ca1d4c_0 + - ipython_genutils=0.2.0=py38_0 + - isort=4.3.21=py38_0 + - jasper=1.900.1=h07fcdf6_1006 + - jedi=0.17.0=py38_0 + - jinja2=2.11.2=py_0 + - joblib=0.16.0=py_0 + - jpeg=9d=h516909a_0 + - jsonschema=3.2.0=py38_0 + - jupyter_client=6.1.3=py_0 + - jupyter_core=4.6.3=py38_0 + - kiwisolver=1.2.0=py38hbf85e49_0 + - lame=3.100=h14c3975_1001 + - lazy-object-proxy=1.4.3=py38h7b6447c_0 + - ld_impl_linux-64=2.34=h53a641e_4 + - libblas=3.8.0=15_mkl + - libcblas=3.8.0=15_mkl + - libclang=9.0.1=default_hde54327_0 + - libedit=3.1.20181209=hc058e9b_0 + - libffi=3.2.1=he1b5a44_1007 + - libgcc-ng=9.1.0=hdf63c60_0 + - libgfortran-ng=7.3.0=hdf63c60_0 + - libiconv=1.15=h516909a_1006 + - liblapack=3.8.0=15_mkl + - liblapacke=3.8.0=15_mkl + - libllvm9=9.0.1=he513fc3_1 + - 
libopencv=4.2.0=py38_6 + - libpng=1.6.37=hbc83047_0 + - libsodium=1.0.16=h1bed415_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - libtiff=4.1.0=h2733197_0 + - libuuid=2.32.1=h14c3975_1000 + - libwebp=1.0.2=h56121f0_5 + - libxcb=1.13=h14c3975_1002 + - libxkbcommon=0.10.0=he1b5a44_0 + - libxml2=2.9.9=hea5a465_1 + - markupsafe=1.1.1=py38h7b6447c_0 + - matplotlib=3.2.1=0 + - matplotlib-base=3.2.1=py38h2af1d28_0 + - mccabe=0.6.1=py38_1 + - mistune=0.8.4=py38h7b6447c_1000 + - mkl=2020.0=166 + - mkl-service=2.3.0=py38he904b0f_0 + - mkl_fft=1.0.15=py38ha843d7b_0 + - mkl_random=1.1.0=py38h962f231_0 + - nb_conda_kernels=2.2.3=py38_0 + - nbconvert=5.6.1=py38_0 + - nbformat=5.0.6=py_0 + - ncurses=6.2=he6710b0_1 + - nettle=3.4.1=h1bed415_1002 + - networkx=2.4=py_0 + - ninja=1.9.0=py38hfd86e86_0 + - notebook=6.0.3=py38_0 + - nspr=4.25=he1b5a44_0 + - nss=3.47=he751ad9_0 + - numpy=1.18.1=py38h4f9e942_0 + - numpy-base=1.18.1=py38hde5b4d6_1 + - olefile=0.46=py_0 + - opencv=4.2.0=py38_6 + - openh264=1.8.0=hdbcaa40_1000 + - openssl=1.1.1g=h7b6447c_0 + - pandas=1.0.4=py38hcb8c335_0 + - pandoc=2.2.3.2=0 + - pandocfilters=1.4.2=py38_1 + - parso=0.7.0=py_0 + - pcre=8.44=he1b5a44_0 + - pexpect=4.8.0=py38_0 + - pickleshare=0.7.5=py38_1000 + - pillow=7.1.2=py38hb39fc2d_0 + - pip=20.0.2=py38_1 + - pixman=0.38.0=h516909a_1003 + - prometheus_client=0.7.1=py_0 + - prompt-toolkit=3.0.4=py_0 + - prompt_toolkit=3.0.4=0 + - pthread-stubs=0.4=h14c3975_1001 + - ptyprocess=0.6.0=py38_0 + - py-opencv=4.2.0=py38h23f93f0_6 + - pygments=2.6.1=py_0 + - pylint=2.5.2=py38_0 + - pyparsing=2.4.7=pyh9f0ad1d_0 + - pyqt=5.12.3=py38ha8c2ead_3 + - pyrsistent=0.16.0=py38h7b6447c_0 + - python=3.8.3=cpython_he5300dc_0 + - python-dateutil=2.8.1=py_0 + - python_abi=3.8=1_cp38 + - pytorch=1.5.0=py3.8_cuda10.2.89_cudnn7.6.5_0 + - pytz=2020.1=pyh9f0ad1d_0 + - pywavelets=1.1.1=py38h7b6447c_0 + - pyyaml=5.3.1=py38h7b6447c_0 + - pyzmq=18.1.1=py38he6710b0_0 + - qt=5.12.5=hd8c4c69_1 + - readline=8.0=h7b6447c_0 + - redis=5.0.3=h7b6447c_0 + - scikit-image=0.16.2=py38h0573a6f_0 + - scikit-learn=0.23.1=py38h423224d_0 + - scipy=1.4.1=py38h0b6359f_0 + - send2trash=1.5.0=py38_0 + - setuptools=46.1.3=py38_0 + - sip=4.19.13=py38he6710b0_0 + - six=1.14.0=py38_0 + - sqlite=3.31.1=h62c20be_1 + - terminado=0.8.3=py38_0 + - testpath=0.4.4=py_0 + - threadpoolctl=2.1.0=pyh5ca1d4c_0 + - tk=8.6.10=hed695b0_0 + - toml=0.10.0=pyh91ea838_0 + - toolz=0.10.0=py_0 + - torchvision=0.6.0=py38_cu102 + - tornado=6.0.4=py38h7b6447c_1 + - traitlets=4.3.3=py38_0 + - wcwidth=0.1.9=py_0 + - webencodings=0.5.1=py38_1 + - wheel=0.34.2=py38_0 + - wrapt=1.11.2=py38h7b6447c_0 + - x264=1!152.20180806=h14c3975_0 + - xorg-kbproto=1.0.7=h14c3975_1002 + - xorg-libice=1.0.10=h516909a_0 + - xorg-libsm=1.2.3=h84519dc_1000 + - xorg-libx11=1.6.9=h516909a_0 + - xorg-libxau=1.0.9=h14c3975_0 + - xorg-libxdmcp=1.1.3=h516909a_0 + - xorg-libxext=1.3.4=h516909a_0 + - xorg-libxrender=0.9.10=h516909a_1002 + - xorg-renderproto=0.11.1=h14c3975_1002 + - xorg-xextproto=7.3.0=h14c3975_1002 + - xorg-xproto=7.0.31=h14c3975_1007 + - xz=5.2.5=h7b6447c_0 + - yaml=0.1.7=h96e3832_1 + - zeromq=4.3.1=he6710b0_3 + - zipp=3.1.0=py_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.3.7=h0b5b093_0 + - pip: + - pyqt5-sip==4.19.18 + - pyqtchart==5.12 + - pyqtwebengine==5.12.1 +prefix: /home/juboge/opt/anaconda3/envs/my_env + diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/layer.py b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/layer.py new file mode 100644 index 
0000000000000000000000000000000000000000..ca24a1b63bdff8c460286ef6a7c1a7cfb8a33c1a --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/layer.py @@ -0,0 +1,130 @@ +# +# BSD 3-Clause License +# +# Copyright (c) 2017 xxxx +# All rights reserved. +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# ============================================================================ +# +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.npu +import os +NPU_CALCULATE_DEVICE = 0 +if os.getenv('NPU_CALCULATE_DEVICE') and str.isdigit(os.getenv('NPU_CALCULATE_DEVICE')): + NPU_CALCULATE_DEVICE = int(os.getenv('NPU_CALCULATE_DEVICE')) +if torch.npu.current_device() != NPU_CALCULATE_DEVICE: + torch.npu.set_device(f'npu:{NPU_CALCULATE_DEVICE}') + +class MLP( nn.Module): + + def __init__(self, input_dimension, hidden_size , target_dimension = 1, activation_layer = 'LeakyReLU'): + super().__init__() + + Activation = nn.LeakyReLU + + # if activation_layer == 'DICE': pass + # elif activation_layer == 'LeakyReLU': pass + + def _dense( in_dim, out_dim, bias = False): + return nn.Sequential( + nn.Linear( in_dim, out_dim, bias = bias), + nn.BatchNorm1d( out_dim), + Activation( 0.1 )) + + dimension_pair = [input_dimension] + hidden_size + layers = [ _dense( dimension_pair[i], dimension_pair[i+1]) for i in range( len( hidden_size))] + + layers.append( nn.Linear( hidden_size[-1], target_dimension)) + layers.insert( 0, nn.BatchNorm1d( input_dimension) ) + + self.model = nn.Sequential( *layers ) + + def forward( self, X): return self.model( X) + + +class InputEmbedding( nn.Module): + + def __init__(self, n_uid, n_mid, n_cid, embedding_dim ): + super().__init__() + self.user_embedding_unit = nn.Embedding( n_uid, embedding_dim) + self.material_embedding_unit = nn.Embedding( n_mid, embedding_dim) + self.category_embedding_unit = nn.Embedding( n_cid, embedding_dim) + + def forward( self, user, material, category, material_historical, category_historical, + material_historical_neg, category_historical_neg, neg_sample = False ): + + user_embedding = self.user_embedding_unit( user) + + material_embedding = self.material_embedding_unit( material) + material_historical_embedding = self.material_embedding_unit( material_historical) + + category_embedding = self.category_embedding_unit( category) + category_historical_embedding = self.category_embedding_unit( category_historical) + + material_historical_neg_embedding = self.material_embedding_unit( material_historical_neg) if neg_sample else None + category_historical_neg_embedding = self.category_embedding_unit( category_historical_neg) if neg_sample else None + + ans = [ user_embedding, material_historical_embedding, category_historical_embedding, + material_embedding, category_embedding, material_historical_neg_embedding, category_historical_neg_embedding ] + return tuple( map( lambda x: x.squeeze() if x is not None else None , ans) ) + + + +class AttentionLayer( nn.Module): + + def __init__(self, embedding_dim, hidden_size, activation_layer = 'sigmoid'): + super().__init__() + + Activation = nn.Sigmoid + if activation_layer == 'Dice': pass + + def _dense( in_dim, out_dim): + return nn.Sequential( nn.Linear( in_dim, out_dim), Activation() ) + + dimension_pair = [embedding_dim * 8] + hidden_size + layers = [ _dense( dimension_pair[i], dimension_pair[i+1]) for i in range( len( hidden_size))] + layers.append( nn.Linear( hidden_size[-1], 1) ) + self.model = nn.Sequential( *layers) + + def forward( self, query, fact, mask, return_scores = False): + B, T, D = fact.shape + + # broadcast the query to (B, T, D); fp16 ones match the O2 mixed-precision setup in main.py + query = torch.ones((B, T, 1), device=f'npu:{NPU_CALCULATE_DEVICE}', dtype=torch.float16) * query.view( (B, 1, D)) + # query = query.view(-1).expand( T, -1).view( T, B, D).permute( 1, 0, 2) + + combination = torch.cat( [ fact, query, fact * query, query - fact ], dim = 2) + +
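# attention input per history step: [fact, query, fact*query, query-fact], an 8*embedding_dim feature vector scored by the MLP below +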
scores = self.model( combination).squeeze() + # padded positions get a large negative score so the softmax assigns them ~0 weight + scores = torch.where( mask == 1, scores, torch.ones_like( scores) * ( -2 ** 31 ) ) + + scores = ( scores.softmax( dim = -1) * mask ).view( (B , 1, T)) + + if return_scores: return scores.squeeze() + return torch.matmul( scores, fact).squeeze() \ No newline at end of file diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/main.py b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/main.py new file mode 100644 index 0000000000000000000000000000000000000000..3f465f4b2f7f6086a85b38511ce3a2b9c19d3513 --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/main.py @@ -0,0 +1,211 @@ +# +# BSD 3-Clause License +# +# Copyright (c) 2017 xxxx +# All rights reserved. +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ============================================================================ +# +import torch +import torch.nn as nn +import os +import sys +import pickle as pk +import numpy as np +import random + +from sklearn.metrics import roc_auc_score +import time +import torch.npu +import apex +from apex import amp +NPU_CALCULATE_DEVICE = 0 +if os.getenv('NPU_CALCULATE_DEVICE') and str.isdigit(os.getenv('NPU_CALCULATE_DEVICE')): + NPU_CALCULATE_DEVICE = int(os.getenv('NPU_CALCULATE_DEVICE')) +if torch.npu.current_device() != NPU_CALCULATE_DEVICE: + torch.npu.set_device(f'npu:{NPU_CALCULATE_DEVICE}') + +workspace_dir = '.' +try: + from google.colab import drive + drive.mount( '/content/drive/' ) + + workspace_dir = os.path.join( '.' , 'drive', 'My Drive', 'DIN-pytorch') + sys.path.append( workspace_dir) + #! rm -rf data + #! tar zxf "{workspace_dir}/npu/traindata/ID2837_CarPeting_Pytorch_DIN.tar.gz" -C ./ + #! tar zxf "{workspace_dir}/loader.tar.gz" -C ./ + #! 
ls -al data +except ImportError: + pass + +from model import DIN, DIEN, DynamicGRU +from DataLoader import MyDataSet + + +# Model hyper-parameters +MAX_LEN = 100 +EMBEDDING_DIM = 32 +# HIDDEN_SIZE_ATTENTION = [80, 40] +# HIDDEN_SIZE_FC = [200, 80] +# ACTIVATION_LAYER = 'LeakyReLU' # lr = 0.01 + + +# Adam +LR = 1e-3 +BETA1 = 0.5 +BETA2 = 0.99 + +# Train +BATCH_SIZE = 128 +EPOCH_TIME = 20 +TEST_ITER = 1000 + +RANDOM_SEED = 19940808 + +USE_CUDA = True + +train_file = os.path.join( '/npu/traindata/ID2837_CarPeting_Pytorch_DIN', "local_train_splitByUser") +test_file = os.path.join( '/npu/traindata/ID2837_CarPeting_Pytorch_DIN', "local_test_splitByUser") +uid_voc = os.path.join( '/npu/traindata/ID2837_CarPeting_Pytorch_DIN', "uid_voc.pkl") +mid_voc = os.path.join( '/npu/traindata/ID2837_CarPeting_Pytorch_DIN', "mid_voc.pkl") +cat_voc = os.path.join( '/npu/traindata/ID2837_CarPeting_Pytorch_DIN', "cat_voc.pkl") + +if USE_CUDA and torch.npu.is_available(): + print( "NPU is available" ) + device = torch.device(f'npu:{NPU_CALCULATE_DEVICE}') + dtype = torch.npu.FloatTensor +else: + device = torch.device( f'npu:{NPU_CALCULATE_DEVICE}') + dtype = torch.FloatTensor + +# Fix the random seed for reproducibility +def same_seeds(seed = RANDOM_SEED): + torch.manual_seed(seed) + if torch.npu.is_available(): + torch.npu.manual_seed(seed) + torch.npu.manual_seed_all(seed) # for multi-device runs + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + +# Initialize parameters +def weights_init( m): + try: + classname = m.__class__.__name__ + if classname.find( 'BatchNorm') != -1: + nn.init.normal_( m.weight.data, 1.0, 0.02) + nn.init.constant_( m.bias.data, 0) + elif classname.find( 'Linear') != -1: + nn.init.normal_( m.weight.data, 0.0, 0.02) + elif classname.find( 'Embedding') != -1: + m.weight.data.uniform_(-1, 1) + except AttributeError: + print( "AttributeError:", classname) + + + +def eval_output( scores, target, loss_function = torch.nn.functional.binary_cross_entropy_with_logits): + loss = loss_function( scores, target) + + y_pred = scores.sigmoid().round() + + accuracy = ( y_pred == target).type( dtype).mean() + + auc = roc_auc_score( target.cpu().detach(), scores.cpu().detach() ) + return loss, accuracy, auc + +# The dicts map a description (string) to a type index (int) +# A more graceful API, sklearn's LabelEncoder (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html#sklearn.preprocessing.LabelEncoder), is not used in this project + +user_map = pk.load( open( uid_voc, 'rb')); n_uid = len( user_map) +material_map = pk.load( open( mid_voc, 'rb')); n_mid = len( material_map) +category_map = pk.load( open( cat_voc, 'rb')); n_cat = len( category_map) + +same_seeds( RANDOM_SEED) + +dataset_train = MyDataSet( train_file, user_map, material_map, category_map, max_length = MAX_LEN) +dataset_test = MyDataSet( test_file, user_map, material_map, category_map, max_length = MAX_LEN) + +loader_train = torch.utils.data.DataLoader( dataset_train, batch_size = BATCH_SIZE, shuffle = True, pin_memory=True) +loader_test = torch.utils.data.DataLoader( dataset_test, batch_size = BATCH_SIZE, shuffle = False) + +# with open( 'data/loader.pk', 'rb') as fin: +# loader_train, loader_test = pk.load(fin) + +# Get model and initialize it +# model = DIEN( n_uid, n_mid, n_cat, EMBEDDING_DIM).to( device) +model = DIN( n_uid, n_mid, n_cat, EMBEDDING_DIM ).to( f'npu:{NPU_CALCULATE_DEVICE}') +model.apply( weights_init) + +# Set loss function and optimizer +optimizer = 
apex.optimizers.NpuFusedAdam(model.parameters(), LR, (BETA1, BETA2)) +model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=128.0, combine_grad=True) +model.train(); iter = 0 +for epoch in range( EPOCH_TIME): + for i, data in enumerate( loader_train): + if i >= 1000:pass # the performance test script seds this 'pass' into 'break' + start_time = time.time() + iter += 1 + + # transform data to target device + + data = [ item.to( f'npu:{NPU_CALCULATE_DEVICE}', non_blocking=True) if item is not None else None for item in data] + target = data.pop(-1) + + model.zero_grad() + + scores = model( data, neg_sample = False) + + loss, accuracy, auc = eval_output( scores, target) + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + optimizer.step( ) + step_time = time.time() - start_time + FPS = BATCH_SIZE / step_time + print( "Epoch:{}, step:{}, loss:{:.4f}, Acc:{:.4f},Auc:{:.4f}, time/step(s):{:.4f},FPS:{:.3f}".format( epoch + 1, i + 1, loss.item(), accuracy.item(), auc.item(), step_time, FPS)) + + if iter % TEST_ITER == 0: + model.eval() + with torch.no_grad(): + score_list = []; target_list = [] + for data in loader_test: + data = [ item.to( f'npu:{NPU_CALCULATE_DEVICE}') if item is not None else None for item in data] + + target = data.pop(-1) + + scores = model( data, neg_sample = False) + score_list.append( scores) + target_list.append( target) + scores = torch.cat( score_list, dim = -1) + target = torch.cat( target_list, dim = -1) + loss, accuracy, auc = eval_output( scores, target) + print( "\tTest Set\tloss:%.5f\tacc:%.5f\tauc:%.5f"%( loss.item(), accuracy.item(), auc.item() ) ) + model.train() diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/model.py b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d84812b873531d87d31147478f16e36b1db0a751 --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/model.py @@ -0,0 +1,119 @@ +# +# BSD 3-Clause License +# +# Copyright (c) 2017 xxxx +# All rights reserved. +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ============================================================================ +# +import torch +import torch.nn as nn +import torch.nn.functional as F + +from layer import * +from rnn import * +import torch.npu +import os +NPU_CALCULATE_DEVICE = 0 +if os.getenv('NPU_CALCULATE_DEVICE') and str.isdigit(os.getenv('NPU_CALCULATE_DEVICE')): + NPU_CALCULATE_DEVICE = int(os.getenv('NPU_CALCULATE_DEVICE')) +if torch.npu.current_device() != NPU_CALCULATE_DEVICE: + torch.npu.set_device(f'npu:{NPU_CALCULATE_DEVICE}') + +class DIN( nn.Module): + def __init__(self, n_uid, n_mid, n_cid, embedding_dim): + super().__init__() + + self.embedding_layer = InputEmbedding( n_uid, n_mid, n_cid, embedding_dim ) + self.attention_layer = AttentionLayer( embedding_dim, hidden_size = [ 80, 40], activation_layer='sigmoid') + # self.output_layer = MLP( embedding_dim * 9, [ 200, 80], 1, 'ReLU') + self.output_layer = MLP( embedding_dim * 7, [ 200, 80], 1, 'ReLU') + + def forward( self, data, neg_sample = False): + + user, material_historical, category_historical, mask, sequential_length , material, category, \ + material_historical_neg, category_historical_neg = data + + user_embedding, material_historical_embedding, category_historical_embedding, \ + material_embedding, category_embedding, material_historical_neg_embedding, category_historical_neg_embedding = \ + self.embedding_layer( user, material, category, material_historical, category_historical, material_historical_neg, category_historical_neg, neg_sample) + + item_embedding = torch.cat( [ material_embedding, category_embedding], dim = 1) + item_historical_embedding = torch.cat( [ material_historical_embedding, category_historical_embedding], dim = 2 ) + + # masked sum over history divided by the true length = mean pooling over valid steps + item_historical_embedding_sum = torch.matmul( mask.unsqueeze( dim = 1), item_historical_embedding).squeeze() / sequential_length.unsqueeze( dim = 1) + + + attention_feature = self.attention_layer( item_embedding, item_historical_embedding, mask) + + # combination = torch.cat( [ user_embedding, item_embedding, item_historical_embedding_sum, attention_feature ], dim = 1) + combination = torch.cat( [ user_embedding, item_embedding, item_historical_embedding_sum, + # item_embedding * item_historical_embedding_sum, + attention_feature ], dim = 1) + + scores = self.output_layer( combination) + + return scores.squeeze() + +class DIEN( nn.Module): + def __init__(self, n_uid, n_mid, n_cid, embedding_dim): + super().__init__() + + self.embedding_layer = InputEmbedding( n_uid, n_mid, n_cid, embedding_dim ) + self.gru_based_layer = nn.GRU( embedding_dim * 2 , embedding_dim * 2, batch_first = True) + self.attention_layer = AttentionLayer( embedding_dim, hidden_size = [ 80, 40], activation_layer='sigmoid') + self.gru_customized_layer = DynamicGRU( embedding_dim * 2, embedding_dim * 2) + + self.output_layer = MLP( embedding_dim * 9, [ 200, 80], 1, 'ReLU') + + def forward( self, data, neg_sample = False): + + user, material_historical, category_historical, mask, sequential_length , material, category, \ + material_historical_neg, category_historical_neg = data + + user_embedding, material_historical_embedding, category_historical_embedding, \ + material_embedding, category_embedding, material_historical_neg_embedding, category_historical_neg_embedding = \ + self.embedding_layer( user, material, category, material_historical, category_historical, material_historical_neg, category_historical_neg, neg_sample) + + item_embedding = torch.cat( [ material_embedding, category_embedding], dim = 1) + item_historical_embedding = torch.cat( [ material_historical_embedding, category_historical_embedding], dim = 2 ) + + item_historical_embedding_sum = torch.matmul( mask.unsqueeze( dim = 1), item_historical_embedding).squeeze() / sequential_length.unsqueeze( dim = 1) + + output_based_gru, _ = self.gru_based_layer( item_historical_embedding) + attention_scores = self.attention_layer( item_embedding, output_based_gru, mask, return_scores = True) + output_customized_gru = self.gru_customized_layer( output_based_gru, attention_scores) + + # take the last valid hidden state per sequence; sequential_length is stored as float, so cast for indexing + attention_feature = output_customized_gru[ range( len( sequential_length)), sequential_length.long() - 1] + + combination = torch.cat( [ user_embedding, item_embedding, item_historical_embedding_sum, item_embedding * item_historical_embedding_sum, attention_feature ], dim = 1) + + scores = self.output_layer( combination) + + return scores.squeeze() \ No newline at end of file diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/modelzoo_level.txt b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/modelzoo_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..c45626e398eabe6022fe7b2e148f0ffce6400d6e --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/modelzoo_level.txt @@ -0,0 +1,3 @@ +FuncStatus:OK +PerfStatus:POK +PrecisionStatus:OK \ No newline at end of file diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/requirements.txt b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/rnn.py b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5b30b43b5049b2366edb87d1c41ad6f2740996aa --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/rnn.py @@ -0,0 +1,82 @@ +# +# BSD 3-Clause License +# +# Copyright (c) 2017 xxxx +# All rights reserved. +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission.
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ============================================================================ +# +import torch +import torch.nn as nn +import torch.npu +import os +NPU_CALCULATE_DEVICE = 0 +if os.getenv('NPU_CALCULATE_DEVICE') and str.isdigit(os.getenv('NPU_CALCULATE_DEVICE')): + NPU_CALCULATE_DEVICE = int(os.getenv('NPU_CALCULATE_DEVICE')) +if torch.npu.current_device() != NPU_CALCULATE_DEVICE: + torch.npu.set_device(f'npu:{NPU_CALCULATE_DEVICE}') + +class AUGRUCell(nn.Module): + def __init__(self, input_dim, hidden_dim, bias = True): + super(AUGRUCell, self).__init__() + + in_dim = input_dim + hidden_dim + self.reset_gate = nn.Sequential( nn.Linear( in_dim, hidden_dim, bias = bias), nn.Sigmoid()) + self.update_gate = nn.Sequential( nn.Linear( in_dim, hidden_dim, bias = bias), nn.Sigmoid()) + self.h_hat_gate = nn.Sequential( nn.Linear( in_dim, hidden_dim, bias = bias), nn.Tanh()) + + + def forward(self, X, h_prev, attention_score): + temp_input = torch.cat( [ h_prev, X ] , dim = -1) + r = self.reset_gate( temp_input) + u = self.update_gate( temp_input) + + h_hat = self.h_hat_gate( torch.cat( [ h_prev * r, X], dim = -1) ) + + u = attention_score.unsqueeze(1) * u + h_cur = (1. 
- u) * h_prev + u * h_hat + + return h_cur + + +class DynamicGRU(nn.Module): + def __init__(self, input_dim, hidden_dim, bias=True): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.rnn_cell = AUGRUCell( input_dim, hidden_dim, bias = bias) + + def forward(self, X, attention_scores , h0 = None ): + B, T, D = X.shape + H = self.hidden_dim + + output = torch.zeros( B, T, H ).type( X.type() ) + h_prev = torch.zeros( B, H ).type( X.type() ) if h0 is None else h0 + for t in range( T): + # write the step output and carry it forward as the next hidden state + h_prev = output[ : , t, :] = self.rnn_cell( X[ : , t, :], h_prev, attention_scores[ :, t] ) + return output diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/test/train_full_1p.sh b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/test/train_full_1p.sh new file mode 100644 index 0000000000000000000000000000000000000000..c47f7bc0e5ffd0a3bb309b7ad9348f1c6d37ed17 --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/test/train_full_1p.sh @@ -0,0 +1,186 @@ +#!/bin/bash + +# Current path; do not modify +cur_path=`pwd` +#export ASCEND_SLOG_PRINT_TO_STDOUT=1 +export NPU_CALCULATE_DEVICE=$ASCEND_DEVICE_ID + +# Collective communication parameters; do not modify +export RANK_SIZE=1 +export JOB_ID=10087 +RANK_ID_START=0 + + +# Dataset path; keep empty, do not modify +data_path="" + +# Basic parameters; review and modify per model +# Network name, same as the directory name +Network="DIN_ID2837_for_PyTorch" +# Training epochs +train_epochs=20 +# Training batch size +batch_size=128 +# Training steps +#train_steps=`expr 1281167 / ${batch_size}` +# Learning rate +learning_rate=0.495 + +# TF2.X only; do not modify +#export NPU_LOOP_SIZE=${train_steps} + +# Diagnostic parameters; precision_mode needs per-model review +precision_mode="allow_mix_precision" +# Maintenance parameters; do not modify below +over_dump=False +data_dump_flag=False +data_dump_step="10" +profiling=False +autotune=False + +# Help message; do not modify +if [[ $1 == --help || $1 == -h ]];then + echo "usage:./train_full_1p.sh " + echo " " + echo "parameter explain: + --precision_mode precision mode(allow_fp32_to_fp16/force_fp16/must_keep_origin_dtype/allow_mix_precision) + --over_dump whether to enable overflow detection, default is False + --data_dump_flag data dump flag, default is False + --data_dump_step data dump step, default is 10 + --profiling whether to enable profiling for performance debugging, default is False + --data_path source data of training + -h/--help show help message + " + exit 1 +fi + +# Argument validation; do not modify +for para in $* +do + if [[ $para == --precision_mode* ]];then + precision_mode=`echo ${para#*=}` + elif [[ $para == --over_dump* ]];then + over_dump=`echo ${para#*=}` + over_dump_path=${cur_path}/output/overflow_dump + mkdir -p ${over_dump_path} + elif [[ $para == --data_dump_flag* ]];then + data_dump_flag=`echo ${para#*=}` + data_dump_path=${cur_path}/output/data_dump + mkdir -p ${data_dump_path} + elif [[ $para == --data_dump_step* ]];then + data_dump_step=`echo ${para#*=}` + elif [[ $para == --profiling* ]];then + profiling=`echo ${para#*=}` + profiling_dump_path=${cur_path}/output/profiling + mkdir -p ${profiling_dump_path} + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + fi +done + +# Check that data_path was passed in; do not modify +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be configured" + exit 1 +fi + +# Training start time; do not modify +start_time=$(date +%s) + +# Enter the training script directory; review per model +cd $cur_path/../ + + +sed -i "s|./data|$data_path|g" main.py +#sed -i "s|EPOCH_TIME = 20|EPOCH_TIME = 1|g" main.py +#sed -i "s|continue|break|g" main.py + +#python3 setup.py install +#mkdir -p checkpoints +#mkdir -p /root/.cache/torch/hub/checkpoints +#cp $data_path/fcn_* /root/.cache/torch/hub/checkpoints + 
+for((RANK_ID=$RANK_ID_START;RANK_ID<$((RANK_SIZE+RANK_ID_START));RANK_ID++)); +do + # Set environment variables; do not modify + echo "Device ID: $ASCEND_DEVICE_ID" + export RANK_ID=$RANK_ID + + + + # Create the DeviceID output directory; do not modify + if [ -d ${cur_path}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${cur_path}/output/${ASCEND_DEVICE_ID} + mkdir -p ${cur_path}/output/$ASCEND_DEVICE_ID/ckpt + else + mkdir -p ${cur_path}/output/$ASCEND_DEVICE_ID/ckpt + fi + + # Core binding; remove for models that do not need it, adjust as needed for those that do + #cpucount=`lscpu | grep "CPU(s):" | head -n 1 | awk '{print $2}'` + #cpustep=`expr $cpucount / 8` + #echo "taskset c steps:" $cpustep + #let a=RANK_ID*$cpustep + #let b=RANK_ID+1 + #let c=b*$cpustep-1 + + # Run the training script; the arguments below need no modification, others need per-model review + nohup python3 main.py > ${cur_path}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +done +wait + +# Restore parameters +sed -i "s|$data_path|./data|g" main.py +#sed -i "s|EPOCH_TIME = 1|EPOCH_TIME = 20|g" main.py +#sed -i "s|break|continue|g" main.py + +#conda deactivate +# Training end time; do not modify +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# Print results; do not modify +echo "------------------ Final result ------------------" +# Output performance FPS; review per model +FPS=`grep FPS $cur_path/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F "FPS:" '{print $2}'|tail -n +4|awk '{sum+=$1} END {print"",sum/NR}'|sed s/[[:space:]]//g` +#FPS=`awk 'BEGIN{printf "%.2f\n",'${batch_size}'*'${perf}'}'` + + +# Print; do not modify +echo "Final Performance images/sec : $FPS" + +# Output training accuracy; review per model +train_accuracy=`grep "Test Set" $cur_path/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F ':' '{print $3}' | awk -F 'auc' '{print $1}' | awk 'NR==1{max=$1;next}{max=max>$1?max:$1}END{print max}'|sed s/[[:space:]]//g` +# Print; do not modify +echo "Final Train Accuracy : ${train_accuracy}" +echo "E2E Training Duration sec : $e2e_time" + +# Summary of stability/accuracy monitoring results +# Training case info; do not modify +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'acc' + +## Collect performance data +# Throughput; do not modify +ActualFPS=${FPS} +# Training time per iteration; do not modify +TrainingTime=`awk 'BEGIN{printf "%.2f\n",'${BatchSize}'*1000/'${FPS}'}'` + +# Extract Loss from train_$ASCEND_DEVICE_ID.log into train_${CaseName}_loss.txt; review per model +grep "FPS" $cur_path/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk -F "," '{print $3}'| awk -F "loss:" '{print $2}' >> $cur_path/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# Loss of the last iteration; do not modify +ActualLoss=`awk 'END {print}' $cur_path/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + +# Print key info into ${CaseName}.log; do not modify +echo "Network = ${Network}" > $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = ${train_accuracy}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/test/train_performance_1p.sh 
b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/test/train_performance_1p.sh new file mode 100644 index 0000000000000000000000000000000000000000..b3c181b9b8dd7510dba822295b00c4ec61512a9d --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/test/train_performance_1p.sh @@ -0,0 +1,186 @@ +#!/bin/bash + +# Current path; do not modify +cur_path=`pwd` +export ASCEND_SLOG_PRINT_TO_STDOUT=0 +export NPU_CALCULATE_DEVICE=$ASCEND_DEVICE_ID + +# Collective communication parameters; do not modify +export RANK_SIZE=1 +export JOB_ID=10087 +RANK_ID_START=0 + + +# Dataset path; keep empty, do not modify +data_path="" + +# Basic parameters; review and modify per model +# Network name, same as the directory name +Network="DIN_ID2837_for_PyTorch" +# Training epochs +train_epochs=1 +# Training batch size +batch_size=128 +# Training steps +#train_steps=`expr 1281167 / ${batch_size}` +# Learning rate +learning_rate=0.495 + +# TF2.X only; do not modify +#export NPU_LOOP_SIZE=${train_steps} + +# Diagnostic parameters; precision_mode needs per-model review +precision_mode="allow_mix_precision" +# Maintenance parameters; do not modify below +over_dump=False +data_dump_flag=False +data_dump_step="10" +profiling=False +autotune=False + +# Help message; do not modify +if [[ $1 == --help || $1 == -h ]];then + echo "usage:./train_performance_1p.sh " + echo " " + echo "parameter explain: + --precision_mode precision mode(allow_fp32_to_fp16/force_fp16/must_keep_origin_dtype/allow_mix_precision) + --over_dump whether to enable overflow detection, default is False + --data_dump_flag data dump flag, default is False + --data_dump_step data dump step, default is 10 + --profiling whether to enable profiling for performance debugging, default is False + --data_path source data of training + -h/--help show help message + " + exit 1 +fi + +# Argument validation; do not modify +for para in $* +do + if [[ $para == --precision_mode* ]];then + precision_mode=`echo ${para#*=}` + elif [[ $para == --over_dump* ]];then + over_dump=`echo ${para#*=}` + over_dump_path=${cur_path}/output/overflow_dump + mkdir -p ${over_dump_path} + elif [[ $para == --data_dump_flag* ]];then + data_dump_flag=`echo ${para#*=}` + data_dump_path=${cur_path}/output/data_dump + mkdir -p ${data_dump_path} + elif [[ $para == --data_dump_step* ]];then + data_dump_step=`echo ${para#*=}` + elif [[ $para == --profiling* ]];then + profiling=`echo ${para#*=}` + profiling_dump_path=${cur_path}/output/profiling + mkdir -p ${profiling_dump_path} + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + fi +done + +# Check that data_path was passed in; do not modify +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be configured" + exit 1 +fi + +# Training start time; do not modify +start_time=$(date +%s) + +# Enter the training script directory; review per model +cd $cur_path/../ + + +sed -i "s|./data|$data_path|g" main.py +sed -i "s|EPOCH_TIME = 20|EPOCH_TIME = 1|g" main.py +sed -i "s|if i >= 1000:pass|if i >= 1000:break|g" main.py + +#python3 setup.py install +#mkdir -p checkpoints +#mkdir -p /root/.cache/torch/hub/checkpoints +#cp $data_path/fcn_* /root/.cache/torch/hub/checkpoints + +for((RANK_ID=$RANK_ID_START;RANK_ID<$((RANK_SIZE+RANK_ID_START));RANK_ID++)); +do + # Set environment variables; do not modify + echo "Device ID: $ASCEND_DEVICE_ID" + export RANK_ID=$RANK_ID + + + + # Create the DeviceID output directory; do not modify + if [ -d ${cur_path}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${cur_path}/output/${ASCEND_DEVICE_ID} + mkdir -p ${cur_path}/output/$ASCEND_DEVICE_ID/ckpt + else + mkdir -p ${cur_path}/output/$ASCEND_DEVICE_ID/ckpt + fi + + # Core binding; remove for models that do not need it, adjust as needed for those that do + #cpucount=`lscpu | grep "CPU(s):" | head -n 1 | awk '{print $2}'` + #cpustep=`expr $cpucount / 8` + #echo "taskset c steps:" $cpustep + #let a=RANK_ID*$cpustep + #let b=RANK_ID+1 + #let c=b*$cpustep-1 + + # Run the training script; the arguments below need no modification, others need per-model review + nohup python3 main.py > 
${cur_path}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +done +wait + +# Restore parameters +sed -i "s|$data_path|./data|g" main.py +sed -i "s|EPOCH_TIME = 1|EPOCH_TIME = 20|g" main.py +sed -i "s|if i >= 1000:break|if i >= 1000:pass|g" main.py + +#conda deactivate +# Training end time; do not modify +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# Print results; do not modify +echo "------------------ Final result ------------------" +# Output performance FPS; review per model +FPS=`grep FPS $cur_path/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F "FPS:" '{print $2}'|tail -n +4|awk '{sum+=$1} END {print"",sum/NR}'|sed s/[[:space:]]//g` +#FPS=`awk 'BEGIN{printf "%.2f\n",'${batch_size}'*'${perf}'}'` + + +# Print; do not modify +echo "Final Performance images/sec : $FPS" + +# Output training accuracy; review per model +#train_accuracy=`grep eval_accuracy $cur_path/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|grep -v mlp_log|awk 'END {print $5}'| sed 's/,//g' |cut -c 1-5` +# Print; do not modify +#echo "Final Train Accuracy : ${train_accuracy}" +#echo "E2E Training Duration sec : $e2e_time" + +# Summary of stability/accuracy monitoring results +# Training case info; do not modify +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'perf' + +## Collect performance data +# Throughput; do not modify +ActualFPS=${FPS} +# Training time per iteration; do not modify +TrainingTime=`awk 'BEGIN{printf "%.2f\n",'${BatchSize}'*1000/'${FPS}'}'` + +# Extract Loss from train_$ASCEND_DEVICE_ID.log into train_${CaseName}_loss.txt; review per model +grep "FPS" $cur_path/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk -F "," '{print $3}'| awk -F "loss:" '{print $2}' >> $cur_path/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# Loss of the last iteration; do not modify +ActualLoss=`awk 'END {print}' $cur_path/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + +# Print key info into ${CaseName}.log; do not modify +echo "Network = ${Network}" > $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +#echo "TrainAccuracy = ${train_accuracy}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log diff --git a/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/test/train_performance_1p_success.sh b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/test/train_performance_1p_success.sh new file mode 100644 index 0000000000000000000000000000000000000000..b0c942106dc664528d46c9b46a86910bae0500ce --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DIN_ID2837_for_PyTorch/test/train_performance_1p_success.sh @@ -0,0 +1,186 @@ +#!/bin/bash + +# Current path; do not modify +cur_path=`pwd` +export ASCEND_SLOG_PRINT_TO_STDOUT=1 +export NPU_CALCULATE_DEVICE=$ASCEND_DEVICE_ID + +# Collective communication parameters; do not modify +export RANK_SIZE=1 +export JOB_ID=10087 +RANK_ID_START=0 + + +# Dataset path; keep empty, do not modify +data_path="" + +# Basic parameters; review and modify per model +# Network name, same as the directory name +Network="DIN_ID2837_for_PyTorch" +# Training epochs +train_epochs=1 +# Training batch size +batch_size=128 +# Training steps +#train_steps=`expr 1281167 / ${batch_size}` +# Learning rate +learning_rate=0.495 + +# TF2.X only; do not modify +#export 
NPU_LOOP_SIZE=${train_steps} + +# Diagnostic parameters; precision_mode needs per-model review +precision_mode="allow_mix_precision" +# Maintenance parameters; do not modify below +over_dump=False +data_dump_flag=False +data_dump_step="10" +profiling=False +autotune=False + +# Help message; do not modify +if [[ $1 == --help || $1 == -h ]];then + echo "usage:./train_performance_1p_success.sh " + echo " " + echo "parameter explain: + --precision_mode precision mode(allow_fp32_to_fp16/force_fp16/must_keep_origin_dtype/allow_mix_precision) + --over_dump whether to enable overflow detection, default is False + --data_dump_flag data dump flag, default is False + --data_dump_step data dump step, default is 10 + --profiling whether to enable profiling for performance debugging, default is False + --data_path source data of training + -h/--help show help message + " + exit 1 +fi + +# Argument validation; do not modify +for para in $* +do + if [[ $para == --precision_mode* ]];then + precision_mode=`echo ${para#*=}` + elif [[ $para == --over_dump* ]];then + over_dump=`echo ${para#*=}` + over_dump_path=${cur_path}/output/overflow_dump + mkdir -p ${over_dump_path} + elif [[ $para == --data_dump_flag* ]];then + data_dump_flag=`echo ${para#*=}` + data_dump_path=${cur_path}/output/data_dump + mkdir -p ${data_dump_path} + elif [[ $para == --data_dump_step* ]];then + data_dump_step=`echo ${para#*=}` + elif [[ $para == --profiling* ]];then + profiling=`echo ${para#*=}` + profiling_dump_path=${cur_path}/output/profiling + mkdir -p ${profiling_dump_path} + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + fi +done + +# Check that data_path was passed in; do not modify +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be configured" + exit 1 +fi + +# Training start time; do not modify +start_time=$(date +%s) + +# Enter the training script directory; review per model +cd $cur_path/../ + + +sed -i "s|./data|$data_path|g" main.py +sed -i "s|EPOCH_TIME = 20|EPOCH_TIME = 1|g" main.py +sed -i "s|continue|break|g" main.py + +#python3 setup.py install +#mkdir -p checkpoints +#mkdir -p /root/.cache/torch/hub/checkpoints +#cp $data_path/fcn_* /root/.cache/torch/hub/checkpoints + +for((RANK_ID=$RANK_ID_START;RANK_ID<$((RANK_SIZE+RANK_ID_START));RANK_ID++)); +do + # Set environment variables; do not modify + echo "Device ID: $ASCEND_DEVICE_ID" + export RANK_ID=$RANK_ID + + + + # Create the DeviceID output directory; do not modify + if [ -d ${cur_path}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${cur_path}/output/${ASCEND_DEVICE_ID} + mkdir -p ${cur_path}/output/$ASCEND_DEVICE_ID/ckpt + else + mkdir -p ${cur_path}/output/$ASCEND_DEVICE_ID/ckpt + fi + + # Core binding; remove for models that do not need it, adjust as needed for those that do + #cpucount=`lscpu | grep "CPU(s):" | head -n 1 | awk '{print $2}'` + #cpustep=`expr $cpucount / 8` + #echo "taskset c steps:" $cpustep + #let a=RANK_ID*$cpustep + #let b=RANK_ID+1 + #let c=b*$cpustep-1 + + # Run the training script; the arguments below need no modification, others need per-model review + nohup python3 main.py > ${cur_path}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +done +wait + +# Restore parameters +sed -i "s|$data_path|./data|g" main.py +sed -i "s|EPOCH_TIME = 1|EPOCH_TIME = 20|g" main.py +sed -i "s|break|continue|g" main.py + +#conda deactivate +# Training end time; do not modify +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# Print results; do not modify +echo "------------------ Final result ------------------" +# Output performance FPS; review per model +FPS=`grep FPS $cur_path/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F "FPS:" '{print $2}'|tail -n +2|awk '{sum+=$1} END {print"",sum/NR}'|sed s/[[:space:]]//g` +#FPS=`awk 'BEGIN{printf "%.2f\n",'${batch_size}'*'${perf}'}'` + + +# Print; do not modify +echo "Final Performance images/sec : $FPS" + +# Output training accuracy; review per model +#train_accuracy=`grep eval_accuracy $cur_path/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|grep -v 
mlp_log|awk 'END {print $5}'| sed 's/,//g' |cut -c 1-5` +# Print; do not modify +#echo "Final Train Accuracy : ${train_accuracy}" +#echo "E2E Training Duration sec : $e2e_time" + +# Summary of stability/accuracy monitoring results +# Training case info; do not modify +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'perf' + +## Collect performance data +# Throughput; do not modify +ActualFPS=${FPS} +# Training time per iteration; do not modify +TrainingTime=`awk 'BEGIN{printf "%.2f\n",'${BatchSize}'*1000/'${FPS}'}'` + +# Extract Loss from train_$ASCEND_DEVICE_ID.log into train_${CaseName}_loss.txt; review per model +grep "FPS" $cur_path/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk -F "Loss :" '{print $2}'|awk -F "," '{print $1}' >> $cur_path/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# Loss of the last iteration; do not modify +ActualLoss=`awk 'END {print}' $cur_path/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + +# Print key info into ${CaseName}.log; do not modify +echo "Network = ${Network}" > $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +#echo "TrainAccuracy = ${train_accuracy}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log