diff --git a/cv/classification/resnet50/pytorch/train.py b/cv/classification/resnet50/pytorch/train.py index eb61e7d9b8fe0dc151310101ace8266747fed420..a19d1f7534ba4912d00c907218510bfbb3ac384c 100644 --- a/cv/classification/resnet50/pytorch/train.py +++ b/cv/classification/resnet50/pytorch/train.py @@ -342,7 +342,7 @@ def get_args_parser(add_help=True): parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help) parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset') - parser.add_argument('--model', default='resnet18', help='model') + parser.add_argument('--model', default='resnet50', help='model') parser.add_argument('--device', default='cuda', help='device') parser.add_argument('-b', '--batch-size', default=32, type=int) parser.add_argument('--epochs', default=90, type=int, metavar='N',