diff --git a/cv/semantic_segmentation/torchvision/pytorch/dataloader/segmentation.py b/cv/semantic_segmentation/torchvision/pytorch/dataloader/segmentation.py index 3791b8b631369a5e3e72de02a4bc49fd9ee1e906..990b7e2153343faa5a4f6a49c4feb71308c879b8 100644 --- a/cv/semantic_segmentation/torchvision/pytorch/dataloader/segmentation.py +++ b/cv/semantic_segmentation/torchvision/pytorch/dataloader/segmentation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -13,8 +13,6 @@ # License for the specific language governing permissions and limitations # under the License. - - import torchvision from .utils.coco_seg_utils import get_coco @@ -30,12 +28,14 @@ Examples: """ -def get_transform(train, base_size, crop_size): - return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(crop_size) +def get_transform(train): + base_size = 520 + crop_size = 480 + return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size) -def get_dataset(dir_path, name, image_set, base_size=540, crop_size=512): - transform = get_transform(image_set == 'train', base_size, crop_size) +def get_dataset(dir_path, name, image_set): + transform = get_transform(image_set == 'train') # name = 'camvid' def sbd(*args, **kwargs): return torchvision.datasets.SBDataset(*args, mode='segmentation', **kwargs) @@ -48,4 +48,4 @@ def get_dataset(dir_path, name, image_set, base_size=540, crop_size=512): p, ds_fn, num_classes = paths[name] ds = ds_fn(p, image_set=image_set, transforms=transform) - return ds, num_classes \ No newline at end of file + return ds, num_classes diff --git a/cv/semantic_segmentation/torchvision/pytorch/train.py b/cv/semantic_segmentation/torchvision/pytorch/train.py index 9f997a3abcb4a67e138bade68592dbcc5508c520..4046319665cb28a4caabc82cf808fa4a680a3bfb 100644 --- a/cv/semantic_segmentation/torchvision/pytorch/train.py +++ b/cv/semantic_segmentation/torchvision/pytorch/train.py @@ -1,21 +1,16 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. -# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. 
import datetime import os +import sys import time - +import math import torch +from torch import nn import torch.utils.data import torchvision -from torch import nn -import torch.nn.functional as TF - -try: - from apex import amp as apex_amp -except: - apex_amp = None import utils from dataloader.segmentation import get_dataset @@ -26,16 +21,25 @@ try: except: autocast = None scaler = None +import ssl +ssl._create_default_https_context = ssl._create_unverified_context + +import torchvision.models.resnet +print("WARN: Using pretrained weights from torchvision-0.9.") +torchvision.models.resnet.model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} def criterion(inputs, target): - if isinstance(inputs, (tuple, list)): - inputs = {str(i): x for i, x in enumerate(inputs)} - inputs["out"] = inputs.pop("0") - - if not isinstance(inputs, dict): - inputs = dict(out=inputs) - losses = {} for name, x in inputs.items(): losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255) @@ -43,8 +47,7 @@ def criterion(inputs, target): if len(losses) == 1: return losses['out'] - loss = losses.pop("out") - return loss + 0.5 * sum(losses.values()) + return losses['out'] + 0.5 * losses['aux'] def evaluate(model, data_loader, device, num_classes): @@ -56,13 +59,7 @@ def evaluate(model, data_loader, device, num_classes): for image, target in metric_logger.log_every(data_loader, 100, header): image, target = image.to(device), target.to(device) output = model(image) - if isinstance(output, dict): - output = output['out'] - if isinstance(output, (tuple, list)): - output = output[0] - - if output.shape[2:] != image.shape[2:]: - output = TF.upsample(output, image.shape[2:], mode="bilinear") + output = output['out'] confmat.update(target.flatten(), output.argmax(1).flatten()) @@ -71,10 +68,7 @@ def evaluate(model, data_loader, device, num_classes): return confmat -def train_one_epoch(model, criterion, optimizer, - data_loader, lr_scheduler, - device, epoch, print_freq, - use_amp=False, use_nhwc=False): +def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, amp=False): model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) @@ -84,22 +78,33 @@ def train_one_epoch(model, criterion, optimizer, all_fps = [] for image, target in metric_logger.log_every(data_loader, print_freq, header): start_time = time.time() - image, target = image.to(device, non_blocking=True), target.to(device, non_blocking=True) - - output = model(image) - loss = criterion(output, target) + image, target = image.to(device), target.to(device) - if use_amp: - with apex_amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() + if autocast is None or not amp: + output = model(image) + loss = 
criterion(output, target) else: - loss.backward() + with autocast(): + output = model(image) + loss = criterion(output, target) - optimizer.step() optimizer.zero_grad() + if scaler is not None and amp: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + optimizer.step() + torch.cuda.synchronize() end_time = time.time() + lr_scheduler.step() + loss_value = loss.item() + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) fps = image.shape[0] / (end_time - start_time) * utils.get_world_size() @@ -108,7 +113,6 @@ def train_one_epoch(model, criterion, optimizer, print(header, 'Avg img/s:', sum(all_fps) / len(all_fps)) - def main(args): if args.output_dir: utils.mkdir(args.output_dir) @@ -120,16 +124,8 @@ def main(args): torch.backends.cudnn.benchmark = True - dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", - crop_size=args.crop_size, base_size=args.base_size) - dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", - crop_size=args.crop_size, base_size=args.base_size) - args.num_classes = num_classes - - if args.nhwc: - collate_fn = utils.nhwc_collate_fn(fp16=args.amp, padding_channel=args.padding_channel) - else: - collate_fn = utils.collate_fn + dataset, num_classes = get_dataset(args.data_path, args.dataset, "train") + dataset_test, _ = get_dataset(args.data_path, args.dataset, "val") if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) @@ -141,52 +137,36 @@ def main(args): data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, - collate_fn=collate_fn, drop_last=True) + collate_fn=utils.collate_fn, drop_last=True) data_loader_test = torch.utils.data.DataLoader( dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, - collate_fn=collate_fn) + collate_fn=utils.collate_fn) - if hasattr(args, "model_cls"): - model = args.model_cls(args) - else: - model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, - aux_loss=args.aux_loss, - pretrained=args.pretrained) - if args.padding_channel: - if hasattr(model, "backbone") and hasattr(model.backbone, "conv1"): - model.backbone.conv1 = utils.padding_conv_channel_to_4(model.backbone.conv1) - else: - print("WARN: Cannot convert first conv to N4HW.") + model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, + aux_loss=args.aux_loss, + pretrained=args.pretrained) model.to(device) if args.distributed: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - if args.nhwc: - model = model.cuda().to(memory_format=torch.channels_last) + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module params_to_optimize = [ - {"params": [p for p in model.parameters() if p.requires_grad]}, + {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]}, + {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]}, ] - + if args.aux_loss: + params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad] + params_to_optimize.append({"params": params, "lr": args.lr * 10}) optimizer = torch.optim.SGD( params_to_optimize, lr=args.lr, momentum=args.momentum, 
weight_decay=args.weight_decay) - if args.amp: - model, optimizer = apex_amp.initialize(model, optimizer, opt_level="O2", - loss_scale=args.loss_scale, - master_weights=True) - - model_without_ddp = model - if args.distributed: - model = torch.nn.parallel.DistributedDataParallel( - model, device_ids=[args.gpu], - find_unused_parameters=args.find_unused_parameters - ) - model_without_ddp = model.module - lr_scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9) @@ -209,24 +189,22 @@ def main(args): epoch_start_time = time.time() if args.distributed: train_sampler.set_epoch(epoch) - train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, - args.amp, args.nhwc) + train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, args.amp) confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) print(confmat) - if args.output_dir is not None: - checkpoint = { - 'model': model_without_ddp.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'epoch': epoch, - 'args': args - } - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'checkpoint.pth')) + checkpoint = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'args': args + } + utils.save_on_master( + checkpoint, + os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) + utils.save_on_master( + checkpoint, + os.path.join(args.output_dir, 'checkpoint.pth')) epoch_total_time = time.time() - epoch_start_time epoch_total_time_str = str(datetime.timedelta(seconds=int(epoch_total_time))) print('epoch time {}'.format(epoch_total_time_str)) @@ -242,7 +220,7 @@ def get_args_parser(add_help=True): parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset path') parser.add_argument('--dataset', default='camvid', help='dataset name') - parser.add_argument('--model', default='deeplabv3_resnet50', help='model') + parser.add_argument('--model', default='fcn_resnet101', help='model') parser.add_argument('--aux-loss', action='store_true', help='auxiliar loss') parser.add_argument('--device', default='cuda', help='device') parser.add_argument('-b', '--batch-size', default=8, type=int) @@ -258,7 +236,7 @@ def get_args_parser(add_help=True): metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay') parser.add_argument('--print-freq', default=10, type=int, help='print frequency') - parser.add_argument('--output-dir', default=None, help='path where to save') + parser.add_argument('--output-dir', default='.', help='path where to save') parser.add_argument('--resume', default='', help='resume from checkpoint') parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='start epoch') @@ -275,45 +253,20 @@ def get_args_parser(add_help=True): action="store_true", ) # distributed training parameters - parser.add_argument('--local_rank', default=-1, type=int, + parser.add_argument('--local_rank', '--local-rank', default=-1, type=int, help='Local rank') parser.add_argument('--world-size', default=1, type=int, help='number of distributed processes') parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') parser.add_argument('--amp', 
action='store_true', help='Automatic Mixed Precision training') - parser.add_argument('--padding-channel', action='store_true', help='Padding the channels of image to 4') - parser.add_argument('--loss_scale', default="dynamic", type=str) - parser.add_argument('--nhwc', action='store_true', help='Use NHWC') - parser.add_argument('--find_unused_parameters', action='store_true') - parser.add_argument('--crop-size', default=512, type=int) - parser.add_argument('--base-size', default=540, type=int) return parser -def check_agrs(args): - try: - args.loss_scale = float(args.loss_scale) - except: pass - - if args.padding_channel: - if not args.nhwc: - print("Turning nhwc when padding the channel of image.") - args.nhwc = True - - if args.amp: - if apex_amp is None: - raise RuntimeError("Not found apex in installed packages, cannot enable amp.") - - -def train_model(model_cls=None): - args = get_args_parser().parse_args() - check_agrs(args) - if model_cls is not None: - args.model_cls = model_cls - main(args) - - if __name__ == "__main__": args = get_args_parser().parse_args() - check_agrs(args) + try: + from dltest import show_training_arguments + show_training_arguments(args) + except: + pass main(args) diff --git a/cv/semantic_segmentation/torchvision/pytorch/utils.py b/cv/semantic_segmentation/torchvision/pytorch/utils.py index 48b348ba61f8af69ff5456178630c301c6a50b9e..bab74dfd6a27edeeed5dd381886c6466b81dcf6b 100644 --- a/cv/semantic_segmentation/torchvision/pytorch/utils.py +++ b/cv/semantic_segmentation/torchvision/pytorch/utils.py @@ -1,4 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. from collections import defaultdict, deque import datetime @@ -70,38 +72,3 @@ def collate_fn(batch): batched_imgs = cat_list(images, fill_value=0) batched_targets = cat_list(targets, fill_value=255) return batched_imgs, batched_targets - - -def nhwc_collate_fn(fp16=False, padding_channel=False): - dtype = torch.float32 - if fp16: - dtype = torch.float16 - def _collect_fn(batch): - batch = collate_fn(batch) - if not padding_channel: - return batch - batch = list(batch) - image = batch[0] - zeros = image.new_zeros(image.shape[0], image.shape[2], image.shape[3], 1) - image = torch.cat([image.permute(0, 2, 3, 1), zeros], dim=-1).permute(0, 3, 1, 2) - image = image.to(memory_format=torch.channels_last, dtype=dtype) - batch[0] = image - return batch - - return _collect_fn - - -def padding_conv_channel_to_4(conv: torch.nn.Conv2d): - new_conv = torch.nn.Conv2d( - 4, conv.out_channels, - kernel_size=conv.kernel_size, - stride=conv.stride, - padding=conv.padding, - dilation=conv.dilation, - bias=conv.bias is not None - ) - weight_shape = conv.weight.shape - padding_weight = conv.weight.new_zeros(weight_shape[0], 1, *weight_shape[2:]) - new_conv.weight = torch.nn.Parameter(torch.cat([conv.weight, padding_weight], dim=1)) - new_conv.bias = conv.bias - return new_conv
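
The rewritten criterion() above assumes the model returns a dict keyed 'out' (plus 'aux' when --aux-loss is enabled), instead of also handling tuples and lists. A minimal self-contained sketch of that behavior; the shapes in the toy check are synthetic, chosen only for illustration:

import torch
from torch import nn

def criterion(inputs, target):
    losses = {}
    for name, x in inputs.items():
        # 255 marks void pixels in the masks and is excluded from the loss
        losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)
    if len(losses) == 1:
        return losses['out']
    # the auxiliary head is down-weighted by 0.5, as in the patch
    return losses['out'] + 0.5 * losses['aux']

# Toy check: batch of 2, 21 classes, 4x4 logit maps, one ignored pixel.
logits = {'out': torch.randn(2, 21, 4, 4), 'aux': torch.randn(2, 21, 4, 4)}
target = torch.randint(0, 21, (2, 4, 4))
target[0, 0, 0] = 255
print(criterion(logits, target).item())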
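
With the apex path removed, train_one_epoch now relies on torch.cuda.amp when --amp is passed. A minimal sketch of that autocast + GradScaler pattern, assuming a CUDA device is available; the tiny linear model and random batches are stand-ins, not the script's actual objects:

import torch

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    x = torch.randn(4, 8, device='cuda')
    y = torch.randint(0, 2, (4,), device='cuda')
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():   # run forward + loss in mixed precision
        loss = torch.nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()     # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)            # unscales grads; skips step on inf/nan
    scaler.update()                   # adapt the loss scale for the next step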
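
The optimizer now receives explicit parameter groups, mirroring the torchvision reference: backbone and classifier train at the base lr, while the auxiliary classifier gets 10x the base lr when --aux-loss is set. A sketch of that layout using fcn_resnet50 as an illustrative model (the script's new default is fcn_resnet101):

import torch
import torchvision

model = torchvision.models.segmentation.fcn_resnet50(num_classes=21,
                                                     aux_loss=True)
lr = 0.01
params_to_optimize = [
    {"params": [p for p in model.backbone.parameters() if p.requires_grad]},
    {"params": [p for p in model.classifier.parameters() if p.requires_grad]},
    # the auxiliary head is trained from scratch, so it gets a larger step
    {"params": [p for p in model.aux_classifier.parameters() if p.requires_grad],
     "lr": lr * 10},
]
optimizer = torch.optim.SGD(params_to_optimize, lr=lr,
                            momentum=0.9, weight_decay=1e-4)
print([g['lr'] for g in optimizer.param_groups])  # [0.01, 0.01, 0.1]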
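
Finally, lr_scheduler.step() moved inside the batch loop, so the LambdaLR argument counts iterations rather than epochs: the lr follows the poly schedule (1 - iter / (len(data_loader) * epochs)) ** 0.9. A sketch with made-up sizes showing the decay:

import torch

iters_per_epoch, epochs = 100, 30             # illustrative sizes
total_iters = iters_per_epoch * epochs
param = torch.zeros(1, requires_grad=True)
opt = torch.optim.SGD([param], lr=0.01)
sched = torch.optim.lr_scheduler.LambdaLR(
    opt, lambda x: (1 - x / total_iters) ** 0.9)

for it in range(total_iters):
    opt.step()                                # the training step stands in here
    sched.step()                              # once per batch, as in the patch
    if it in (0, total_iters // 2, total_iters - 1):
        print(it, opt.param_groups[0]['lr'])  # decays from ~0.01 toward 0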