From 0ae2005d2ffd48a1cc78276e3592bb386fb7a990 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E4=BC=9F=E6=A0=B9?= <1101204667@qq.com> Date: Fri, 8 Apr 2022 07:32:55 +0000 Subject: [PATCH] update /train_ctc.py. --- .../nlp/LSTM_ID0468_for_PyTorch/timit/steps/train_ctc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/PyTorch/built-in/nlp/LSTM_ID0468_for_PyTorch/timit/steps/train_ctc.py b/PyTorch/built-in/nlp/LSTM_ID0468_for_PyTorch/timit/steps/train_ctc.py index 98e3f7ae16..d3508aaf13 100644 --- a/PyTorch/built-in/nlp/LSTM_ID0468_for_PyTorch/timit/steps/train_ctc.py +++ b/PyTorch/built-in/nlp/LSTM_ID0468_for_PyTorch/timit/steps/train_ctc.py @@ -41,6 +41,7 @@ import time import yaml import argparse import numpy as np +import apex from apex import amp import torch import torch.nn as nn @@ -188,9 +189,9 @@ def main(args,conf): print(params) loss_fn = nn.CTCLoss(reduction='sum') - optimizer = torch.optim.Adam(model.parameters(), lr=init_lr, weight_decay=weight_decay) + optimizer = apex.optimizers.NpuFusedAdam(model.parameters(), lr=init_lr, weight_decay=weight_decay) if args.apex: - model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale) + model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale,combine_grad=True) #visualization for training # from visdom import Visdom # viz = Visdom() -- Gitee