diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/.keep" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/.keep" new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/a.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/a.py" new file mode 100644 index 0000000000000000000000000000000000000000..64810c46b9dfda07e80b572e8a0c08bf1b08a89a --- /dev/null +++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/a.py" @@ -0,0 +1,42 @@ +import torch + +from data.datautils import get_class +from config import load_config +from utils import get_logger + +logger = get_logger('TrainingSetup') +""" +""" + +def main(): + """加载和读取config文件,设置随机种子和cudnn参数使得达到可复现效果,设置训练模型类别为3DUNET网络, + 调用config文件中的builder参数作为指定的训练类别===》默认为UNet3DTrainerBuilder + 从trainer这个module中加载到UNet3DTrainerBuilder这个函数 + 通过UNet3DTrainerBuilder函数中的bulider来建立模型,将config参数集传到bulider函数中 + builder函数调用create_trainer函数 + 继而调用UNET3Dtrainer函数,在此函数中进行训练和验证 + """ + # Load and log experiment configuration + path = r".\home\zyyang\concile\segmentation\config_training.yaml" + config = load_config(path) + logger.info(config) + + manual_seed = config.get('manual_seed', None) + if manual_seed is not None: + logger.info(f'Seed the RNG for all devices with {manual_seed}') + torch.manual_seed(manual_seed) + # see https://pytorch.org/docs/stable/notes/randomness.html + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + # create trainer + default_trainer_builder_class = 'UNet3DTrainerBuilder' + trainer_builder_class = config['trainer'].get('builder', default_trainer_builder_class) + trainer_builder = get_class(trainer_builder_class, modules=['pytorch3dunet.unet3d.trainer']) + trainer = trainer_builder.build(config) + # Start training + trainer.fit() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/config.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/config.py" new file mode 100644 index 0000000000000000000000000000000000000000..4e9b77806aa13385e64226a67070da4fc5ecb8bc --- /dev/null +++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/config.py" @@ -0,0 +1,37 @@ +import argparse + +import torch +import yaml + +import utils + +logger = utils.get_logger('ConfigLoader') + + +def load_config(path): + """ + + + """ + #parser = argparse.ArgumentParser(description='UNet3D') + #parser.add_argument('--config',default=r"D:\concile\segmentation\config_training.yaml", type=str, help='Path to the YAML config file', required=True) + #args = parser.parse_args() + config = _load_config_yaml(path) + # Get a device to train on + device_str = config.get('device', None) + if device_str is not None: + logger.info(f"Device specified in config: 
'{device_str}'") + if device_str.startswith('cuda') and not torch.cuda.is_available(): + logger.warn('CUDA not available, using CPU') + device_str = 'cpu' + else: + device_str = "cuda:0" if torch.cuda.is_available() else 'cpu' + logger.info(f"Using '{device_str}' device") + + device = torch.device(device_str) + config['device'] = device + return config + + +def _load_config_yaml(config_file): + return yaml.safe_load(open(config_file, 'r')) diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/config_training.yaml" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/config_training.yaml" new file mode 100644 index 0000000000000000000000000000000000000000..528786cae10adc212eab737f83f3b3d9fb238fba --- /dev/null +++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/config_training.yaml" @@ -0,0 +1,156 @@ +# use a fixed random seed to guarantee that when you run the code twice you will get the same outcome +manual_seed: 0 +# model configuration +model: + # model class, e.g. UNet3D, ResidualUNet3D + name: UNet3D + # number of input channels to the model + in_channels: 1 + # number of output channels + out_channels: 1 + # determines the order of operators in a single layer (gcr - GroupNorm+Conv3d+ReLU) + layer_order: gcr + # feature maps scale factor + f_maps: 32 + # number of groups in the groupnorm + num_groups: 8 + # apply element-wise nn.Sigmoid after the final 1x1 convolution, otherwise apply nn.Softmax + final_sigmoid: true + # if True applies the final normalization layer (sigmoid or softmax), otherwise the networks returns the output from the final convolution layer; use False for regression problems, e.g. de-noising + is_segmentation: true +# trainer configuration +trainer: + # path to the checkpoint directory + checkpoint_dir: "./checkpoint" + # path to latest checkpoint; if provided the training will be resumed from that checkpoint + resume: null + # how many iterations between validations + validate_after_iters: 20 + # how many iterations between tensorboard logging + log_after_iters: 20 + # max number of epochs + max_num_epochs: 50 + # max number of iterations + max_num_iterations: 100000 + # model with higher eval score is considered better + eval_score_higher_is_better: True +# optimizer configuration +optimizer: + # initial learning rate + learning_rate: 0.0002 + # weight decay + weight_decay: 0.0001 +# loss function configuration +loss: + # loss function to be used during training + name: DiceLoss + # A manual rescaling weight given to each class. + weight: null + # a target value that is ignored and does not contribute to the input gradient + ignore_index: null +# evaluation metric configuration +eval_metric: + name: MeanIoU + # a target label that is ignored during metric evaluation + ignore_index: null +# learning rate scheduler configuration +lr_scheduler: + name: MultiStepLR + milestones: [10, 30, 60] + gamma: 0.2 +# data loaders configuration +loaders: + # class of the HDF5 dataset, currently StandardHDF5Dataset and LazyHDF5Dataset are supported. + # When using LazyHDF5Dataset make sure to set `num_workers = 1`, due to a bug in h5py which corrupts the data + # when reading from multiple threads. 
diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/config_training.yaml" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/config_training.yaml"
new file mode 100644
index 0000000000000000000000000000000000000000..528786cae10adc212eab737f83f3b3d9fb238fba
--- /dev/null
+++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/config_training.yaml"
@@ -0,0 +1,156 @@
+# use a fixed random seed to guarantee that when you run the code twice you will get the same outcome
+manual_seed: 0
+# model configuration
+model:
+  # model class, e.g. UNet3D, ResidualUNet3D
+  name: UNet3D
+  # number of input channels to the model
+  in_channels: 1
+  # number of output channels
+  out_channels: 1
+  # determines the order of operators in a single layer (gcr - GroupNorm+Conv3d+ReLU)
+  layer_order: gcr
+  # feature maps scale factor
+  f_maps: 32
+  # number of groups in the groupnorm
+  num_groups: 8
+  # apply element-wise nn.Sigmoid after the final 1x1 convolution, otherwise apply nn.Softmax
+  final_sigmoid: true
+  # if True applies the final normalization layer (sigmoid or softmax), otherwise the network returns the output from the final convolution layer; use False for regression problems, e.g. de-noising
+  is_segmentation: true
+# trainer configuration
+trainer:
+  # path to the checkpoint directory
+  checkpoint_dir: "./checkpoint"
+  # path to the latest checkpoint; if provided the training will be resumed from that checkpoint
+  resume: null
+  # how many iterations between validations
+  validate_after_iters: 20
+  # how many iterations between tensorboard logging
+  log_after_iters: 20
+  # max number of epochs
+  max_num_epochs: 50
+  # max number of iterations
+  max_num_iterations: 100000
+  # model with higher eval score is considered better
+  eval_score_higher_is_better: True
+# optimizer configuration
+optimizer:
+  # initial learning rate
+  learning_rate: 0.0002
+  # weight decay
+  weight_decay: 0.0001
+# loss function configuration
+loss:
+  # loss function to be used during training
+  name: DiceLoss
+  # a manual rescaling weight given to each class
+  weight: null
+  # a target value that is ignored and does not contribute to the input gradient
+  ignore_index: null
+# evaluation metric configuration
+eval_metric:
+  name: MeanIoU
+  # a target label that is ignored during metric evaluation
+  ignore_index: null
+# learning rate scheduler configuration
+lr_scheduler:
+  name: MultiStepLR
+  milestones: [10, 30, 60]
+  gamma: 0.2
+# data loaders configuration
+loaders:
+  # class of the HDF5 dataset, currently StandardHDF5Dataset and LazyHDF5Dataset are supported.
+  # When using LazyHDF5Dataset make sure to set `num_workers = 1`, due to a bug in h5py which corrupts the data
+  # when reading from multiple threads.
+  dataset: MyDataset
+  # batch dimension; if the number of GPUs is N > 1, then a batch_size of N * batch_size will automatically be taken for DataParallel
+  batch_size: 1
+  # how many subprocesses to use for data loading
+  num_workers: 4
+  # path to the raw data within the H5
+  raw_internal_path: raw
+  # path to the label data within the H5
+  label_internal_path: label
+  # path to the pixel-wise weight map within the H5 if present
+  weight_internal_path: null
+  # configuration of the train loader
+  train:
+    # absolute paths to the training datasets; if a given path is a directory all H5 files ('*.h5', '*.hdf', '*.hdf5', '*.hd5')
+    # inside this directory will be included as well (non-recursively)
+    file_paths:
+      - "../train"
+
+    # SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
+    slice_builder:
+      # SliceBuilder class
+      name: SliceBuilder
+      # train patch size given to the network (adapt to fit in your GPU mem; generally, the bigger the patch the better)
+      patch_shape: [32, 64, 64]
+      # train stride between patches
+      stride_shape: [8, 16, 16]
+
+    # data transformations/augmentations
+    transformer:
+      raw:
+        # re-scale the values to be 0-mean and 1-std
+        - name: Standardize
+        # randomly flips an image across a randomly chosen axis
+        - name: RandomFlip
+        # rotate an image by 90 degrees around a randomly chosen plane
+        - name: RandomRotate90
+        # rotate an image by a random angle taken from the (-angle_spectrum, angle_spectrum) interval
+        - name: RandomRotate
+          # rotate only in the ZY plane since most volumetric data is anisotropic
+          axes: [[2, 1]]
+          angle_spectrum: 15
+          mode: reflect
+        # apply elastic deformations of 3D patches on a per-voxel mesh
+        - name: ElasticDeformation
+          spline_order: 3
+        # randomly adjust contrast
+        - name: RandomContrast
+        # apply additive Gaussian noise
+        - name: AdditiveGaussianNoise
+        # apply additive Poisson noise
+        - name: AdditivePoissonNoise
+        # convert to torch tensor
+        - name: ToTensor
+          # add an additional 'channel' axis when the input data is 3D
+          expand_dims: true
+      label:
+        - name: RandomFlip
+        - name: RandomRotate90
+        - name: RandomRotate
+          # rotate only in the ZY plane since most volumetric data is anisotropic
+          axes: [[2, 1]]
+          angle_spectrum: 15
+          mode: reflect
+        - name: ElasticDeformation
+          spline_order: 0
+        - name: ToTensor
+          expand_dims: true
+
+  # configuration of the validation loaders
+  val:
+    # paths to the validation datasets; if a given path is a directory all H5 files ('*.h5', '*.hdf', '*.hdf5', '*.hd5')
+    # inside this directory will be included as well (non-recursively)
+    file_paths:
+      - "../valid"
+
+    # SliceBuilder configuration
+    slice_builder:
+      # SliceBuilder class
+      name: SliceBuilder
+      # validation patch (can be bigger than the train patch since there is no backprop)
+      patch_shape: [32, 64, 64]
+      # validation stride (validation patches don't need to overlap)
+      stride_shape: [32, 64, 64]
+    transformer:
+      raw:
+        - name: Standardize
+        - name: ToTensor
+          expand_dims: true
+      label:
+        - name: ToTensor
+          expand_dims: true
\ No newline at end of file
diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/data/.keep" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/data/.keep"
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
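The `slice_builder` settings above fully determine how many patches each volume yields. A minimal sketch of the per-axis index generation that `SliceBuilder._gen_indices` (defined in data/datautils.py below) performs, assuming an illustrative 32x128x128 volume and the training patch/stride from this YAML:

```python
def gen_indices(i, k, s):
    # start positions of windows of size k with stride s along an axis of size i,
    # plus one final window flush with the end if the stride does not land there
    for j in range(0, i - k + 1, s):
        yield j
    if j + k < i:
        yield i - k

shape, patch, stride = (32, 128, 128), (32, 64, 64), (8, 16, 16)
counts = [len(list(gen_indices(i, k, s))) for i, k, s in zip(shape, patch, stride)]
print(counts)                             # [1, 5, 5] start positions per axis
print(counts[0] * counts[1] * counts[2])  # 25 patches per volume
```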
"a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/data/datautils.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/data/datautils.py" new file mode 100644 index 0000000000000000000000000000000000000000..d92dad55de87548165f7d7dd1b236ff40783cf14 --- /dev/null +++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/data/datautils.py" @@ -0,0 +1,419 @@ +import collections +import importlib + +import numpy as np +import torch +from torch.utils.data import DataLoader, ConcatDataset, Dataset + +from utils import get_logger + +logger = get_logger('Dataset') + + +class ConfigDataset(Dataset): + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + @classmethod + def create_datasets(cls, dataset_config, phase): + """ + Factory method for creating a list of datasets based on the provided config. + + Args: + dataset_config (dict): dataset configuration + phase (str): one of ['train', 'val', 'test'] + + Returns: + list of `Dataset` instances + """ + raise NotImplementedError + + @classmethod + def prediction_collate(cls, batch): + """Default collate_fn. Override in child class for non-standard datasets.""" + return default_prediction_collate(batch) + + +class SliceBuilder: + """ + Builds the position of the patches in a given raw/label/weight ndarray based on the the patch and stride shape + """ + + def __init__(self, raw_datasets, label_datasets, weight_dataset, patch_shape, stride_shape, **kwargs): + """ + :param raw_datasets: ndarray of raw data + :param label_datasets: ndarray of ground truth labels + :param weight_dataset: ndarray of weights for the labels + :param patch_shape: the shape of the patch DxHxW + :param stride_shape: the shape of the stride DxHxW + :param kwargs: additional metadata + """ + + patch_shape = tuple(patch_shape) + stride_shape = tuple(stride_shape) + skip_shape_check = kwargs.get('skip_shape_check', False) + if not skip_shape_check: + self._check_patch_shape(patch_shape) + + self._raw_slices = self._build_slices(raw_datasets[0], patch_shape, stride_shape) + if label_datasets is None: + self._label_slices = None + else: + # take the first element in the label_datasets to build slices + self._label_slices = self._build_slices(label_datasets[0], patch_shape, stride_shape) + assert len(self._raw_slices) == len(self._label_slices) + if weight_dataset is None: + self._weight_slices = None + else: + self._weight_slices = self._build_slices(weight_dataset[0], patch_shape, stride_shape) + assert len(self.raw_slices) == len(self._weight_slices) + + @property + def raw_slices(self): + return self._raw_slices + + @property + def label_slices(self): + return self._label_slices + + @property + def weight_slices(self): + return self._weight_slices + + @staticmethod + def _build_slices(dataset, patch_shape, stride_shape): + """Iterates over a given n-dim dataset patch-by-patch with a given stride + and builds an array of slice positions. + + Returns: + list of slices, i.e. + [(slice, slice, slice, slice), ...] if len(shape) == 4 + [(slice, slice, slice), ...] 
if len(shape) == 3 + """ + slices = [] + if dataset.ndim == 4: + in_channels, i_z, i_y, i_x = dataset.shape + else: + i_z, i_y, i_x = dataset.shape + + k_z, k_y, k_x = patch_shape + s_z, s_y, s_x = stride_shape + z_steps = SliceBuilder._gen_indices(i_z, k_z, s_z) + for z in z_steps: + y_steps = SliceBuilder._gen_indices(i_y, k_y, s_y) + for y in y_steps: + x_steps = SliceBuilder._gen_indices(i_x, k_x, s_x) + for x in x_steps: + slice_idx = ( + slice(z, z + k_z), + slice(y, y + k_y), + slice(x, x + k_x) + ) + if dataset.ndim == 4: + slice_idx = (slice(0, in_channels),) + slice_idx + slices.append(slice_idx) + return slices + + @staticmethod + def _gen_indices(i, k, s): + assert i >= k, 'Sample size has to be bigger than the patch size' + for j in range(0, i - k + 1, s): + yield j + if j + k < i: + yield i - k + + @staticmethod + def _check_patch_shape(patch_shape): + assert len(patch_shape) == 3, 'patch_shape must be a 3D tuple' + assert patch_shape[1] >= 64 and patch_shape[2] >= 64, 'Height and Width must be greater or equal 64' + + +class FilterSliceBuilder(SliceBuilder): + """ + Filter patches containing more than `1 - threshold` of ignore_index label + """ + + def __init__(self, raw_datasets, label_datasets, weight_datasets, patch_shape, stride_shape, ignore_index=(0,), + threshold=0.6, slack_acceptance=0.01, **kwargs): + super().__init__(raw_datasets, label_datasets, weight_datasets, patch_shape, stride_shape, **kwargs) + if label_datasets is None: + return + + rand_state = np.random.RandomState(47) + + def ignore_predicate(raw_label_idx): + label_idx = raw_label_idx[1] + patch = np.copy(label_datasets[0][label_idx]) + for ii in ignore_index: + patch[patch == ii] = 0 + non_ignore_counts = np.count_nonzero(patch != 0) + non_ignore_counts = non_ignore_counts / patch.size + return non_ignore_counts > threshold or rand_state.rand() < slack_acceptance + + zipped_slices = zip(self.raw_slices, self.label_slices) + # ignore slices containing too much ignore_index + filtered_slices = list(filter(ignore_predicate, zipped_slices)) + # unzip and save slices + raw_slices, label_slices = zip(*filtered_slices) + self._raw_slices = list(raw_slices) + self._label_slices = list(label_slices) + + +class EmbeddingsSliceBuilder(FilterSliceBuilder): + """ + Filter patches containing more than `1 - threshold` of ignore_index label and patches containing more than + `patch_max_instances` labels + """ + + def __init__(self, raw_datasets, label_datasets, weight_datasets, patch_shape, stride_shape, ignore_index=(0,), + threshold=0.8, slack_acceptance=0.01, patch_max_instances=48, patch_min_instances=5, **kwargs): + super().__init__(raw_datasets, label_datasets, weight_datasets, patch_shape, stride_shape, ignore_index, + threshold, slack_acceptance, **kwargs) + + if label_datasets is None: + return + + rand_state = np.random.RandomState(47) + + def ignore_predicate(raw_label_idx): + label_idx = raw_label_idx[1] + patch = label_datasets[0][label_idx] + num_instances = np.unique(patch).size + + # patch_max_instances is a hard constraint + if num_instances <= patch_max_instances: + # make sure that we have at least patch_min_instances in the batch and allow some slack + return num_instances >= patch_min_instances or rand_state.rand() < slack_acceptance + + return False + + zipped_slices = zip(self.raw_slices, self.label_slices) + # ignore slices containing too much ignore_index + filtered_slices = list(filter(ignore_predicate, zipped_slices)) + # unzip and save slices + raw_slices, label_slices = 
+        self._raw_slices = list(raw_slices)
+        self._label_slices = list(label_slices)
+
+
+class RandomFilterSliceBuilder(EmbeddingsSliceBuilder):
+    """
+    Filter out patches containing more than `1 - threshold` of ignore_index label and return only a random sample of those.
+    """
+
+    def __init__(self, raw_datasets, label_datasets, weight_datasets, patch_shape, stride_shape, ignore_index=(0,),
+                 threshold=0.8, slack_acceptance=0.01, patch_max_instances=48, patch_acceptance_probab=0.1,
+                 max_num_patches=25, **kwargs):
+        super().__init__(raw_datasets, label_datasets, weight_datasets, patch_shape, stride_shape,
+                         ignore_index=ignore_index, threshold=threshold, slack_acceptance=slack_acceptance,
+                         patch_max_instances=patch_max_instances, **kwargs)
+
+        self.max_num_patches = max_num_patches
+
+        if label_datasets is None:
+            return
+
+        rand_state = np.random.RandomState(47)
+
+        def ignore_predicate(raw_label_idx):
+            result = rand_state.rand() < patch_acceptance_probab
+            if result:
+                self.max_num_patches -= 1
+
+            return result and self.max_num_patches > 0
+
+        zipped_slices = zip(self.raw_slices, self.label_slices)
+        # ignore slices containing too much ignore_index
+        filtered_slices = list(filter(ignore_predicate, zipped_slices))
+        # unzip and save slices
+        raw_slices, label_slices = zip(*filtered_slices)
+        self._raw_slices = list(raw_slices)
+        self._label_slices = list(label_slices)
+
+
+def get_class(class_name, modules):
+    for module in modules:
+        m = importlib.import_module(module)
+        clazz = getattr(m, class_name, None)
+        if clazz is not None:
+            return clazz
+    raise RuntimeError(f'Unsupported dataset class: {class_name}')
+
+
+def _loader_classes(class_name):
+    modules = [
+        'data.mydata',
+        'data.datautils'
+    ]
+    return get_class(class_name, modules)
+
+
+def get_slice_builder(raws, labels, weight_maps, config):
+    assert 'name' in config
+    logger.info(f"Slice builder config: {config}")
+    slice_builder_cls = _loader_classes(config['name'])
+    return slice_builder_cls(raws, labels, weight_maps, **config)
+
+
+def get_train_loaders(config):
+    """
+    Returns a dictionary containing the training and validation loaders (torch.utils.data.DataLoader).
+
+    :param config: a top level configuration object containing the 'loaders' key
+    :return: dict {
+        'train': <train loader>
+        'val': <val loader>
+    }
+    Reads the 'loaders' section of the config, which holds all training- and
+    validation-related parameters. The 'dataset' entry selects the dataset class,
+    whose create_datasets factory method builds the train and val datasets; each
+    dataset is then wrapped in a DataLoader and the two loaders are returned as a dict.
+    """
+    assert 'loaders' in config, 'Could not find data loaders configuration'
+    loaders_config = config['loaders']
+
+    logger.info('Creating training and validation set loaders...')
+
+    # get dataset class
+    dataset_cls_str = loaders_config.get('dataset', None)
+    if dataset_cls_str is None:
+        dataset_cls_str = 'MyDataset'
+        logger.warning(f"Cannot find dataset class in the config. Using default '{dataset_cls_str}'.")
+    dataset_class = _loader_classes(dataset_cls_str)
+
+    assert set(loaders_config['train']['file_paths']).isdisjoint(loaders_config['val']['file_paths']), \
+        "Train and validation 'file_paths' overlap. One cannot use validation data for training!"
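Before the datasets are created below, note that all of the config plumbing rests on the string-to-class lookup above. A short illustration of what `_loader_classes` resolves, assuming the package layout in this diff:

```python
# names from the YAML ('dataset:' and the slice_builder 'name:') become classes
dataset_cls = _loader_classes('MyDataset')           # -> data.mydata.MyDataset
slice_builder_cls = _loader_classes('SliceBuilder')  # -> data.datautils.SliceBuilder

# unknown names fail fast:
# _loader_classes('NoSuchDataset')  # RuntimeError: Unsupported dataset class: NoSuchDataset
```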
+
+    train_datasets = dataset_class.create_datasets(loaders_config, phase='train')
+
+    val_datasets = dataset_class.create_datasets(loaders_config, phase='val')
+
+    num_workers = loaders_config.get('num_workers', 1)
+    logger.info(f'Number of workers for train/val dataloader: {num_workers}')
+    batch_size = loaders_config.get('batch_size', 1)
+    if torch.cuda.device_count() > 1 and not config['device'].type == 'cpu':
+        logger.info(
+            f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}')
+        batch_size = batch_size * torch.cuda.device_count()
+
+    logger.info(f'Batch size for train/val loader: {batch_size}')
+    # when training with volumetric data use batch_size of 1 due to GPU memory constraints
+    return {
+        'train': DataLoader(ConcatDataset(train_datasets), batch_size=batch_size, shuffle=True,
+                            num_workers=num_workers),
+        # don't shuffle during validation: useful when showing how predictions for a given batch get better over time
+        'val': DataLoader(ConcatDataset(val_datasets), batch_size=batch_size, shuffle=False, num_workers=num_workers)
+    }
+
+
+def get_test_loaders(config):
+    """
+    Returns the test DataLoaders.
+
+    :return: generator of DataLoader objects
+    """
+
+    assert 'loaders' in config, 'Could not find data loaders configuration'
+    loaders_config = config['loaders']
+
+    logger.info('Creating test set loaders...')
+
+    # get dataset class
+    dataset_cls_str = loaders_config.get('dataset', None)
+    if dataset_cls_str is None:
+        dataset_cls_str = 'StandardHDF5Dataset'
+        logger.warning(f"Cannot find dataset class in the config. Using default '{dataset_cls_str}'.")
+    dataset_class = _loader_classes(dataset_cls_str)
+
+    test_datasets = dataset_class.create_datasets(loaders_config, phase='test')
+
+    num_workers = loaders_config.get('num_workers', 1)
+    logger.info(f'Number of workers for the dataloader: {num_workers}')
+
+    batch_size = loaders_config.get('batch_size', 1)
+    if torch.cuda.device_count() > 1 and not config['device'].type == 'cpu':
+        logger.info(
+            f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}')
+        batch_size = batch_size * torch.cuda.device_count()
+
+    logger.info(f'Batch size for dataloader: {batch_size}')
+
+    # use a generator in order to create data loaders lazily one by one
+    for test_dataset in test_datasets:
+        logger.info(f'Loading test set from: {test_dataset.file_path}...')
+        if hasattr(test_dataset, 'prediction_collate'):
+            collate_fn = test_dataset.prediction_collate
+        else:
+            collate_fn = default_prediction_collate
+
+        yield DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,
+                         collate_fn=collate_fn)
+
+
+def default_prediction_collate(batch):
+    """
+    Default collate_fn to form a mini-batch of Tensor(s) for HDF5 based datasets
+    """
+    error_msg = "batch must contain tensors or slice; found {}"
+    if isinstance(batch[0], torch.Tensor):
+        return torch.stack(batch, 0)
+    elif isinstance(batch[0], tuple) and isinstance(batch[0][0], slice):
+        return batch
+    elif isinstance(batch[0], collections.abc.Sequence):
+        # note: the bare collections.Sequence alias was removed in Python 3.10
+        transposed = zip(*batch)
+        return [default_prediction_collate(samples) for samples in transposed]
+
+    raise TypeError((error_msg.format(type(batch[0]))))
+
+
+def calculate_stats(images):
+    """
+    Calculates min, max, mean, std given a list of ndarrays
+    """
+    # flatten first since the images might not be the same size
+    flat = np.concatenate(
+        [img.ravel() for img in images]
+    )
+    return np.min(flat), np.max(flat), np.mean(flat), np.std(flat)
+
+
+def sample_instances(label_img, instance_ratio, random_state, ignore_labels=(0,)):
+    """
+    Given the labelled volume `label_img`, this function takes a random subset of object instances specified by `instance_ratio`
+    and zeros out the remaining labels.
+
+    Args:
+        label_img(nd.array): labelled image
+        instance_ratio(float): a number from (0, 1]
+        random_state: RNG state
+        ignore_labels: labels to be ignored during sampling
+
+    Returns:
+        labelled volume of the same size as `label_img` with a random subset of object instances.
+    """
+    unique = np.unique(label_img)
+    for il in ignore_labels:
+        unique = np.setdiff1d(unique, il)
+
+    # shuffle labels
+    random_state.shuffle(unique)
+    # pick instance_ratio objects
+    num_objects = round(instance_ratio * len(unique))
+    if num_objects == 0:
+        # if there are no objects left, just return an empty patch
+        return np.zeros_like(label_img)
+
+    # sample the labels
+    sampled_instances = unique[:num_objects]
+
+    result = np.zeros_like(label_img)
+    # keep only the sampled_instances
+    for si in sampled_instances:
+        result[label_img == si] = si
+
+    return result
diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/data/mydata.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/data/mydata.py"
new file mode 100644
index 0000000000000000000000000000000000000000..03245578a6560cf4ec2667e69b229aee37c54906
--- /dev/null
+++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/data/mydata.py"
@@ -0,0 +1,170 @@
+import collections.abc
+import os
+import nrrd
+import imageio
+import numpy as np
+import torch
+import nibabel as nib
+import data.transforms as transforms
+from data.datautils import ConfigDataset, calculate_stats, sample_instances
+from utils import get_logger
+from skimage.transform import resize
+logger = get_logger('MyDataset')
+
+
+def dsb_prediction_collate(batch):
+    """
+    Forms a mini-batch of (images, paths) during test time for the DSB-like datasets.
+    """
+    error_msg = "batch must contain tensors or str; found {}"
+    if isinstance(batch[0], torch.Tensor):
+        return torch.stack(batch, 0)
+    elif isinstance(batch[0], str):
+        return list(batch)
+    elif isinstance(batch[0], collections.abc.Sequence):
+        # transpose tuples, i.e. [[1, 2], ['a', 'b']] to be [[1, 'a'], [2, 'b']]
+        transposed = zip(*batch)
+        return [dsb_prediction_collate(samples) for samples in transposed]
+
+    raise TypeError((error_msg.format(type(batch[0]))))
+
+
+class MyDataset(ConfigDataset):
+    """
+    get_train_loaders in datautils calls create_datasets to build two instances of this DSB-style class:
+    cls(file_paths[0], 'train', transformer_config, mirror_padding=None, expand_dims=True, instance_ratio=None, random_seed=0)
+    cls(file_paths[0], 'val', transformer_config, mirror_padding=None, expand_dims=True, instance_ratio=None, random_seed=0)
+    For training, mirror_padding is None, `phase` is set accordingly, and root_dir contains the `images` and `masks` folders.
+    The raw volumes in images_dir are read first (which required rewriting _load_files),
+    then the min/max/mean/std statistics over the whole dataset are computed,
+    and from these parameters the raw-data transform is built.
+    The mask volumes are then loaded from masks_dir with the same _load_files helper,
+    and the label transform is built from the same parameters.
+    """
+    def __init__(self, root_dir, phase, transformer_config, mirror_padding=(0, 32, 32), expand_dims=True,
+                 instance_ratio=None, random_seed=0):
+        assert os.path.isdir(root_dir), f'{root_dir} is not a directory'
+        assert phase in ['train', 'val', 'test']
+
+        # use mirror padding only during the 'test' phase
+        if phase in ['train', 'val']:
+            mirror_padding = None
+        if mirror_padding is not None:
+            assert len(mirror_padding) == 3, f"Invalid mirror_padding: {mirror_padding}"
+        self.mirror_padding = mirror_padding
+
+        self.phase = phase
+
+        # load raw images
+        images_dir = os.path.join(root_dir, 'images')
+        assert os.path.isdir(images_dir)
+        self.images, self.paths = self._load_files(images_dir, expand_dims)
+        self.file_path = images_dir
+        self.instance_ratio = instance_ratio
+
+        min_value, max_value, mean, std = calculate_stats(self.images)
+        logger.info(f'Input stats: min={min_value}, max={max_value}, mean={mean}, std={std}')
+
+        transformer = transforms.get_transformer(transformer_config, min_value=min_value, max_value=max_value,
+                                                 mean=mean, std=std)
+
+        # load raw images transformer
+        self.raw_transform = transformer.raw_transform()
+
+        if phase != 'test':
+            # load labeled images
+            masks_dir = os.path.join(root_dir, 'masks')
+            assert os.path.isdir(masks_dir)
+            self.masks, self.pa = self._load_files(masks_dir, expand_dims)
+            # prepare for training with sparse object supervision (allow sparse objects only in training phase)
+            if self.instance_ratio is not None and phase == 'train':
+                assert 0 < self.instance_ratio <= 1
+                rs = np.random.RandomState(random_seed)
+                self.masks = [sample_instances(m, self.instance_ratio, rs) for m in self.masks]
+            assert len(self.images) == len(self.masks)
+            # load label images transformer
+            self.masks_transform = transformer.label_transform()
+        else:
+            self.masks = None
+            self.masks_transform = None
+
+            # add mirror padding if needed
+            if self.mirror_padding is not None:
+                z, y, x = self.mirror_padding
+                pad_width = ((z, z), (y, y), (x, x))
+                padded_imgs = []
+                for img in self.images:
+                    padded_img = np.pad(img, pad_width=pad_width, mode='reflect')
+                    padded_imgs.append(padded_img)
+
+                self.images = padded_imgs
+
+    def __getitem__(self, idx):
+        if idx >= len(self):
+            raise StopIteration
+
+        img = self.images[idx]
+        if self.phase != 'test':
+            mask = self.masks[idx]
+            return self.raw_transform(img), self.masks_transform(mask)
+        else:
+            return self.raw_transform(img), self.paths[idx]
+
+    def __len__(self):
+        return len(self.images)
+
+    @classmethod
+    def prediction_collate(cls, batch):
+        return dsb_prediction_collate(batch)
+
+    @classmethod
+    def create_datasets(cls, dataset_config, phase):
+        """
+        Class method: can be called on the class itself without creating an instance.
+        `phase` is either 'train' or 'val'. The transformer settings for that phase
+        in the config define the augmentations (transformer_config), and the
+        corresponding file_paths give the dataset directory for that phase.
+        Returns a list with one instance of this dataset class, configured with the
+        parameters gathered here (dataset path, phase, mirror_padding, expand_dims,
+        instance_ratio, random seed).
+        """
+        phase_config = dataset_config[phase]
+        # load data augmentation configuration
+        transformer_config = phase_config['transformer']
+        # load files to process
+        file_paths = phase_config['file_paths']
+        # mirror padding conf
+        mirror_padding = dataset_config.get('mirror_padding', None)
+        expand_dims = dataset_config.get('expand_dims', True)
+        instance_ratio = phase_config.get('instance_ratio', None)
+        random_seed = phase_config.get('random_seed', 0)
+        return [cls(file_paths[0], phase, transformer_config, mirror_padding, expand_dims, instance_ratio, random_seed)]
+
+    @staticmethod
+    def _load_files(dir, expand_dims):
+        size = (256, 256, 16)
+        files_data = []
+        paths = []
+        ss = os.listdir(dir)
+        ss.sort()
+        for file in ss:
+            path = os.path.join(dir, file)
+            if path.endswith(".nii.gz"):
+                img = nib.load(path).get_fdata()
+            elif path.endswith(".nrrd"):
+                img, options = nrrd.read(path)
+                img = resize(img, size, mode='reflect', anti_aliasing=False, preserve_range=True)
+            else:
+                # skip files that are neither NIfTI nor NRRD so `img` is never stale
+                continue
+            if expand_dims:
+                dims = img.ndim
+                img = np.expand_dims(img, axis=0)
+                if dims == 3:
+                    img = np.transpose(img, (0, 3, 1, 2))
+            files_data.append(img)
+            paths.append(path)
+
+        return files_data, paths
diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/data/transforms.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/data/transforms.py"
new file mode 100644
index 0000000000000000000000000000000000000000..8d34fdfbf087af0ae54514dcb386b60dcfd251e7
--- /dev/null
+++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/data/transforms.py"
@@ -0,0 +1,777 @@
+import importlib
+
+import numpy as np
+import torch
+from scipy.ndimage import rotate, map_coordinates, gaussian_filter
+from scipy.ndimage.filters import convolve
+from skimage import measure
+from skimage.filters import gaussian
+from skimage.segmentation import find_boundaries
+from torchvision.transforms import Compose
+import torch.nn.functional as F
+# WARN: use fixed random state for reproducibility; if you want to randomize on each run seed with `time.time()` e.g.
+GLOBAL_RANDOM_STATE = np.random.RandomState(47)
+
+
+class RandomFlip:
+    """
+    Randomly flips the image across the given axes. Image can be either 3D (DxHxW) or 4D (CxDxHxW).
+
+    When creating an instance, make sure that the provided RandomStates are consistent between raw and labeled datasets,
+    otherwise the models won't converge.
+ """ + + def __init__(self, random_state, axis_prob=0.5, **kwargs): + assert random_state is not None, 'RandomState cannot be None' + self.random_state = random_state + self.axes = (0, 1, 2) + self.axis_prob = axis_prob + + def __call__(self, m): + assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' + + for axis in self.axes: + if self.random_state.uniform() > self.axis_prob: + if m.ndim == 3: + m = np.flip(m, axis) + else: + channels = [np.flip(m[c], axis) for c in range(m.shape[0])] + m = np.stack(channels, axis=0) + + return m + + +class RandomRotate90: + """ + Rotate an array by 90 degrees around a randomly chosen plane. Image can be either 3D (DxHxW) or 4D (CxDxHxW). + + When creating make sure that the provided RandomStates are consistent between raw and labeled datasets, + otherwise the models won't converge. + + IMPORTANT: assumes DHW axis order (that's why rotation is performed across (1,2) axis) + """ + + def __init__(self, random_state, **kwargs): + self.random_state = random_state + # always rotate around z-axis + self.axis = (1, 2) + + def __call__(self, m): + assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' + + # pick number of rotations at random + k = self.random_state.randint(0, 4) + # rotate k times around a given plane + if m.ndim == 3: + m = np.rot90(m, k, self.axis) + else: + channels = [np.rot90(m[c], k, self.axis) for c in range(m.shape[0])] + m = np.stack(channels, axis=0) + + return m + + +class RandomRotate: + """ + Rotate an array by a random degrees from taken from (-angle_spectrum, angle_spectrum) interval. + Rotation axis is picked at random from the list of provided axes. + """ + + def __init__(self, random_state, angle_spectrum=30, axes=None, mode='reflect', order=0, **kwargs): + if axes is None: + axes = [(1, 0), (2, 1), (2, 0)] + else: + assert isinstance(axes, list) and len(axes) > 0 + + self.random_state = random_state + self.angle_spectrum = angle_spectrum + self.axes = axes + self.mode = mode + self.order = order + + def __call__(self, m): + axis = self.axes[self.random_state.randint(len(self.axes))] + angle = self.random_state.randint(-self.angle_spectrum, self.angle_spectrum) + + if m.ndim == 3: + m = rotate(m, angle, axes=axis, reshape=False, order=self.order, mode=self.mode, cval=-1) + else: + channels = [rotate(m[c], angle, axes=axis, reshape=False, order=self.order, mode=self.mode, cval=-1) for c + in range(m.shape[0])] + m = np.stack(channels, axis=0) + + return m + + +class RandomContrast: + """ + Adjust contrast by scaling each voxel to `mean + alpha * (v - mean)`. + """ + + def __init__(self, random_state, alpha=(0.5, 1.5), mean=0.0, execution_probability=0.1, **kwargs): + self.random_state = random_state + assert len(alpha) == 2 + self.alpha = alpha + self.mean = mean + self.execution_probability = execution_probability + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + alpha = self.random_state.uniform(self.alpha[0], self.alpha[1]) + result = self.mean + alpha * (m - self.mean) + return np.clip(result, -1, 1) + + return m + + +# it's relatively slow, i.e. ~1s per patch of size 64x200x200, so use multiple workers in the DataLoader +# remember to use spline_order=0 when transforming the labels +class ElasticDeformation: + """ + Apply elasitc deformations of 3D patches on a per-voxel mesh. Assumes ZYX axis order (or CZYX if the data is 4D). 
+ Based on: https://github.com/fcalvet/image_tools/blob/master/image_augmentation.py#L62 + """ + + def __init__(self, random_state, spline_order, alpha=2000, sigma=50, execution_probability=0.1, apply_3d=True, + **kwargs): + """ + :param spline_order: the order of spline interpolation (use 0 for labeled images) + :param alpha: scaling factor for deformations + :param sigma: smoothing factor for Gaussian filter + :param execution_probability: probability of executing this transform + :param apply_3d: if True apply deformations in each axis + """ + self.random_state = random_state + self.spline_order = spline_order + self.alpha = alpha + self.sigma = sigma + self.execution_probability = execution_probability + self.apply_3d = apply_3d + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + assert m.ndim in [3, 4] + + if m.ndim == 3: + volume_shape = m.shape + else: + volume_shape = m[0].shape + + if self.apply_3d: + dz = gaussian_filter(self.random_state.randn(*volume_shape), self.sigma, mode="reflect") * self.alpha + else: + dz = np.zeros_like(m) + + dy, dx = [ + gaussian_filter( + self.random_state.randn(*volume_shape), + self.sigma, mode="reflect" + ) * self.alpha for _ in range(2) + ] + + z_dim, y_dim, x_dim = volume_shape + z, y, x = np.meshgrid(np.arange(z_dim), np.arange(y_dim), np.arange(x_dim), indexing='ij') + indices = z + dz, y + dy, x + dx + + if m.ndim == 3: + return map_coordinates(m, indices, order=self.spline_order, mode='reflect') + else: + channels = [map_coordinates(c, indices, order=self.spline_order, mode='reflect') for c in m] + return np.stack(channels, axis=0) + + return m + + +def blur_boundary(boundary, sigma): + boundary = gaussian(boundary, sigma=sigma) + boundary[boundary >= 0.5] = 1 + boundary[boundary < 0.5] = 0 + return boundary + + +class CropToFixed: + def __init__(self, random_state, size=(256, 256), centered=False, **kwargs): + self.random_state = random_state + self.crop_y, self.crop_x = size + self.centered = centered + + def __call__(self, m): + def _padding(pad_total): + half_total = pad_total // 2 + return (half_total, pad_total - half_total) + + def _rand_range_and_pad(crop_size, max_size): + """ + Returns a tuple: + max_value (int) for the corner dimension. 
The corner dimension is chosen as `self.random_state(max_value)` + pad (int): padding in both directions; if crop_size is lt max_size the pad is 0 + """ + if crop_size < max_size: + return max_size - crop_size, (0, 0) + else: + return 1, _padding(crop_size - max_size) + + def _start_and_pad(crop_size, max_size): + if crop_size < max_size: + return (max_size - crop_size) // 2, (0, 0) + else: + return 0, _padding(crop_size - max_size) + + assert m.ndim in (3, 4) + if m.ndim == 3: + _, y, x = m.shape + else: + _, _, y, x = m.shape + + if not self.centered: + y_range, y_pad = _rand_range_and_pad(self.crop_y, y) + x_range, x_pad = _rand_range_and_pad(self.crop_x, x) + + y_start = self.random_state.randint(y_range) + x_start = self.random_state.randint(x_range) + + else: + y_start, y_pad = _start_and_pad(self.crop_y, y) + x_start, x_pad = _start_and_pad(self.crop_x, x) + + if m.ndim == 3: + result = m[:, y_start:y_start + self.crop_y, x_start:x_start + self.crop_x] + return np.pad(result, pad_width=((0, 0), y_pad, x_pad), mode='reflect') + else: + channels = [] + for c in range(m.shape[0]): + result = m[c][:, y_start:y_start + self.crop_y, x_start:x_start + self.crop_x] + channels.append(np.pad(result, pad_width=((0, 0), y_pad, x_pad), mode='reflect')) + return np.stack(channels, axis=0) + + +class AbstractLabelToBoundary: + AXES_TRANSPOSE = [ + (0, 1, 2), # X + (0, 2, 1), # Y + (2, 0, 1) # Z + ] + + def __init__(self, ignore_index=None, aggregate_affinities=False, append_label=False, **kwargs): + """ + :param ignore_index: label to be ignored in the output, i.e. after computing the boundary the label ignore_index + will be restored where is was in the patch originally + :param aggregate_affinities: aggregate affinities with the same offset across Z,Y,X axes + :param append_label: if True append the orignal ground truth labels to the last channel + :param blur: Gaussian blur the boundaries + :param sigma: standard deviation for Gaussian kernel + """ + self.ignore_index = ignore_index + self.aggregate_affinities = aggregate_affinities + self.append_label = append_label + + def __call__(self, m): + """ + Extract boundaries from a given 3D label tensor. 
+ :param m: input 3D tensor + :return: binary mask, with 1-label corresponding to the boundary and 0-label corresponding to the background + """ + assert m.ndim == 3 + + kernels = self.get_kernels() + boundary_arr = [np.where(np.abs(convolve(m, kernel)) > 0, 1, 0) for kernel in kernels] + channels = np.stack(boundary_arr) + results = [] + if self.aggregate_affinities: + assert len(kernels) % 3 == 0, "Number of kernels must be divided by 3 (one kernel per offset per Z,Y,X axes" + # aggregate affinities with the same offset + for i in range(0, len(kernels), 3): + # merge across X,Y,Z axes (logical OR) + xyz_aggregated_affinities = np.logical_or.reduce(channels[i:i + 3, ...]).astype(np.int) + # recover ignore index + xyz_aggregated_affinities = _recover_ignore_index(xyz_aggregated_affinities, m, self.ignore_index) + results.append(xyz_aggregated_affinities) + else: + results = [_recover_ignore_index(channels[i], m, self.ignore_index) for i in range(channels.shape[0])] + + if self.append_label: + # append original input data + results.append(m) + + # stack across channel dim + return np.stack(results, axis=0) + + @staticmethod + def create_kernel(axis, offset): + # create conv kernel + k_size = offset + 1 + k = np.zeros((1, 1, k_size), dtype=np.int) + k[0, 0, 0] = 1 + k[0, 0, offset] = -1 + return np.transpose(k, axis) + + def get_kernels(self): + raise NotImplementedError + + +class StandardLabelToBoundary: + def __init__(self, ignore_index=None, append_label=False, blur=False, sigma=1, mode='thick', foreground=False, + **kwargs): + self.ignore_index = ignore_index + self.append_label = append_label + self.blur = blur + self.sigma = sigma + self.mode = mode + self.foreground = foreground + + def __call__(self, m): + assert m.ndim == 3 + + boundaries = find_boundaries(m, connectivity=2, mode=self.mode) + if self.blur: + boundaries = blur_boundary(boundaries, self.sigma) + + results = [] + if self.foreground: + foreground = (m > 0).astype('uint8') + results.append(_recover_ignore_index(foreground, m, self.ignore_index)) + + results.append(_recover_ignore_index(boundaries, m, self.ignore_index)) + + if self.append_label: + # append original input data + results.append(m) + + return np.stack(results, axis=0) + + +class BlobsWithBoundary: + def __init__(self, mode=None, append_label=False, blur=False, sigma=1, **kwargs): + if mode is None: + mode = ['thick', 'inner', 'outer'] + self.mode = mode + self.append_label = append_label + self.blur = blur + self.sigma = sigma + + def __call__(self, m): + assert m.ndim == 3 + + # get the segmentation mask + results = [(m > 0).astype('uint8')] + + for bm in self.mode: + boundary = find_boundaries(m, connectivity=2, mode=bm) + if self.blur: + boundary = blur_boundary(boundary, self.sigma) + results.append(boundary) + + if self.append_label: + results.append(m) + + return np.stack(results, axis=0) + + +class BlobsToMask: + """ + Returns binary mask from labeled image, i.e. every label greater than 0 is treated as foreground. 
+ + """ + + def __init__(self, append_label=False, boundary=False, cross_entropy=False, **kwargs): + self.cross_entropy = cross_entropy + self.boundary = boundary + self.append_label = append_label + + def __call__(self, m): + assert m.ndim == 3 + + # get the segmentation mask + mask = (m > 0).astype('uint8') + results = [mask] + + if self.boundary: + outer = find_boundaries(m, connectivity=2, mode='outer') + if self.cross_entropy: + # boundary is class 2 + mask[outer > 0] = 2 + results = [mask] + else: + results.append(outer) + + if self.append_label: + results.append(m) + + return np.stack(results, axis=0) + + +class RandomLabelToAffinities(AbstractLabelToBoundary): + """ + Converts a given volumetric label array to binary mask corresponding to borders between labels. + One specify the max_offset (thickness) of the border. Then the offset is picked at random every time you call + the transformer (offset is picked form the range 1:max_offset) for each axis and the boundary computed. + One may use this scheme in order to make the network more robust against various thickness of borders in the ground + truth (think of it as a boundary denoising scheme). + """ + + def __init__(self, random_state, max_offset=10, ignore_index=None, append_label=False, z_offset_scale=2, **kwargs): + super().__init__(ignore_index=ignore_index, append_label=append_label, aggregate_affinities=False) + self.random_state = random_state + self.offsets = tuple(range(1, max_offset + 1)) + self.z_offset_scale = z_offset_scale + + def get_kernels(self): + rand_offset = self.random_state.choice(self.offsets) + axis_ind = self.random_state.randint(3) + # scale down z-affinities due to anisotropy + if axis_ind == 2: + rand_offset = max(1, rand_offset // self.z_offset_scale) + + rand_axis = self.AXES_TRANSPOSE[axis_ind] + # return a single kernel + return [self.create_kernel(rand_axis, rand_offset)] + + +class LabelToAffinities(AbstractLabelToBoundary): + """ + Converts a given volumetric label array to binary mask corresponding to borders between labels (which can be seen + as an affinity graph: https://arxiv.org/pdf/1706.00120.pdf) + One specify the offsets (thickness) of the border. The boundary will be computed via the convolution operator. 
+ """ + + def __init__(self, offsets, ignore_index=None, append_label=False, aggregate_affinities=False, z_offsets=None, + **kwargs): + super().__init__(ignore_index=ignore_index, append_label=append_label, + aggregate_affinities=aggregate_affinities) + + assert isinstance(offsets, list) or isinstance(offsets, tuple), 'offsets must be a list or a tuple' + assert all(a > 0 for a in offsets), "'offsets must be positive" + assert len(set(offsets)) == len(offsets), "'offsets' must be unique" + if z_offsets is not None: + assert len(offsets) == len(z_offsets), 'z_offsets length must be the same as the length of offsets' + else: + # if z_offsets is None just use the offsets for z-affinities + z_offsets = list(offsets) + self.z_offsets = z_offsets + + self.kernels = [] + # create kernel for every axis-offset pair + for xy_offset, z_offset in zip(offsets, z_offsets): + for axis_ind, axis in enumerate(self.AXES_TRANSPOSE): + final_offset = xy_offset + if axis_ind == 2: + final_offset = z_offset + # create kernels for a given offset in every direction + self.kernels.append(self.create_kernel(axis, final_offset)) + + def get_kernels(self): + return self.kernels + + +class LabelToZAffinities(AbstractLabelToBoundary): + """ + Converts a given volumetric label array to binary mask corresponding to borders between labels (which can be seen + as an affinity graph: https://arxiv.org/pdf/1706.00120.pdf) + One specify the offsets (thickness) of the border. The boundary will be computed via the convolution operator. + """ + + def __init__(self, offsets, ignore_index=None, append_label=False, **kwargs): + super().__init__(ignore_index=ignore_index, append_label=append_label) + + assert isinstance(offsets, list) or isinstance(offsets, tuple), 'offsets must be a list or a tuple' + assert all(a > 0 for a in offsets), "'offsets must be positive" + assert len(set(offsets)) == len(offsets), "'offsets' must be unique" + + self.kernels = [] + z_axis = self.AXES_TRANSPOSE[2] + # create kernels + for z_offset in offsets: + self.kernels.append(self.create_kernel(z_axis, z_offset)) + + def get_kernels(self): + return self.kernels + + +class LabelToBoundaryAndAffinities: + """ + Combines the StandardLabelToBoundary and LabelToAffinities in the hope + that that training the network to predict both would improve the main task: boundary prediction. + """ + + def __init__(self, xy_offsets, z_offsets, append_label=False, blur=False, sigma=1, ignore_index=None, mode='thick', + foreground=False, **kwargs): + # blur only StandardLabelToBoundary results; we don't want to blur the affinities + self.l2b = StandardLabelToBoundary(blur=blur, sigma=sigma, ignore_index=ignore_index, mode=mode, + foreground=foreground) + self.l2a = LabelToAffinities(offsets=xy_offsets, z_offsets=z_offsets, append_label=append_label, + ignore_index=ignore_index) + + def __call__(self, m): + boundary = self.l2b(m) + affinities = self.l2a(m) + return np.concatenate((boundary, affinities), axis=0) + + +class FlyWingBoundary: + """ + Use if the volume contains a single pixel boundaries between labels. 
Gives the single pixel boundary in the 1st + channel and the 'thick' boundary in the 2nd channel and optional z-affinities + """ + + def __init__(self, append_label=False, thick_boundary=True, ignore_index=None, z_offsets=None, **kwargs): + self.append_label = append_label + self.thick_boundary = thick_boundary + self.ignore_index = ignore_index + self.lta = None + if z_offsets is not None: + self.lta = LabelToZAffinities(z_offsets, ignore_index=ignore_index) + + def __call__(self, m): + boundary = (m == 0).astype('uint8') + results = [boundary] + + if self.thick_boundary: + t_boundary = find_boundaries(m, connectivity=1, mode='outer', background=0) + results.append(t_boundary) + + if self.lta is not None: + z_affs = self.lta(m) + for z_aff in z_affs: + results.append(z_aff) + + if self.ignore_index is not None: + for b in results: + b[m == self.ignore_index] = self.ignore_index + + if self.append_label: + # append original input data + results.append(m) + + return np.stack(results, axis=0) + + +class LabelToMaskAndAffinities: + def __init__(self, xy_offsets, z_offsets, append_label=False, background=0, ignore_index=None, **kwargs): + self.background = background + self.l2a = LabelToAffinities(offsets=xy_offsets, z_offsets=z_offsets, append_label=append_label, + ignore_index=ignore_index) + + def __call__(self, m): + mask = m > self.background + mask = np.expand_dims(mask.astype(np.uint8), axis=0) + affinities = self.l2a(m) + return np.concatenate((mask, affinities), axis=0) + + +class Standardize: + """ + Apply Z-score normalization to a given input tensor, i.e. re-scaling the values to be 0-mean and 1-std. + """ + + def __init__(self, eps=1e-10, mean=None, std=None, channelwise=False, **kwargs): + if mean is not None or std is not None: + assert mean is not None and std is not None + self.mean = mean + self.std = std + self.eps = eps + self.channelwise = channelwise + + def __call__(self, m): + if self.mean is not None: + mean, std = self.mean, self.std + else: + if self.channelwise: + # normalize per-channel + axes = list(range(m.ndim)) + # average across channels + axes = tuple(axes[1:]) + mean = np.mean(m, axis=axes, keepdims=True) + std = np.std(m, axis=axes, keepdims=True) + else: + mean = np.mean(m) + std = np.std(m) + + return (m - mean) / np.clip(std, a_min=self.eps, a_max=None) + + +class PercentileNormalizer: + def __init__(self, pmin, pmax, channelwise=False, eps=1e-10, **kwargs): + self.eps = eps + self.pmin = pmin + self.pmax = pmax + self.channelwise = channelwise + + def __call__(self, m): + if self.channelwise: + axes = list(range(m.ndim)) + # average across channels + axes = tuple(axes[1:]) + pmin = np.percentile(m, self.pmin, axis=axes, keepdims=True) + pmax = np.percentile(m, self.pmax, axis=axes, keepdims=True) + else: + pmin = np.percentile(m, self.pmin) + pmax = np.percentile(m, self.pmax) + + return (m - pmin) / (pmax - pmin + self.eps) + + +class Normalize: + """ + Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data in a fixed range of [-1, 1]. 
+ """ + + def __init__(self, min_value, max_value, **kwargs): + assert max_value > min_value + self.min_value = min_value + self.value_range = max_value - min_value + + def __call__(self, m): + norm_0_1 = (m - self.min_value) / self.value_range + return np.clip(2 * norm_0_1 - 1, -1, 1) + + +class AdditiveGaussianNoise: + def __init__(self, random_state, scale=(0.0, 1.0), execution_probability=0.1, **kwargs): + self.execution_probability = execution_probability + self.random_state = random_state + self.scale = scale + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + std = self.random_state.uniform(self.scale[0], self.scale[1]) + gaussian_noise = self.random_state.normal(0, std, size=m.shape) + return m + gaussian_noise + return m + + +class AdditivePoissonNoise: + def __init__(self, random_state, lam=(0.0, 1.0), execution_probability=0.1, **kwargs): + self.execution_probability = execution_probability + self.random_state = random_state + self.lam = lam + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + lam = self.random_state.uniform(self.lam[0], self.lam[1]) + poisson_noise = self.random_state.poisson(lam, size=m.shape) + return m + poisson_noise + return m + + +class ToTensor: + """ + Converts a given input numpy.ndarray into torch.Tensor. Adds additional 'channel' axis when the input is 3D + and expand_dims=True (use for raw data of the shape (D, H, W)). + """ + + def __init__(self, expand_dims, dtype=np.float32, **kwargs): + self.expand_dims = expand_dims + self.dtype = dtype + + def __call__(self, m): + assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' + # add channel dimension + if self.expand_dims and m.ndim == 3: + m = np.expand_dims(m, axis=0) + + return torch.from_numpy(m.astype(dtype=self.dtype)) + + +class Relabel: + """ + Relabel a numpy array of labels into a consecutive numbers, e.g. + [10, 10, 0, 6, 6] -> [2, 2, 0, 1, 1]. Useful when one has an instance segmentation volume + at hand and would like to create a one-hot-encoding for it. Without a consecutive labeling the task would be harder. 
+ """ + + def __init__(self, append_original=False, run_cc=True, ignore_label=None, **kwargs): + self.append_original = append_original + self.ignore_label = ignore_label + self.run_cc = run_cc + + if ignore_label is not None: + assert append_original, "ignore_label present, so append_original must be true, so that one can localize the ignore region" + + def __call__(self, m): + orig = m + if self.run_cc: + # assign 0 to the ignore region + m = measure.label(m, background=self.ignore_label) + + _, unique_labels = np.unique(m, return_inverse=True) + result = unique_labels.reshape(m.shape) + if self.append_original: + result = np.stack([result, orig]) + return result + + +class Identity: + def __init__(self, **kwargs): + pass + + def __call__(self, m): + return m + + +class RgbToLabel: + def __call__(self, img): + img = np.array(img) + assert img.ndim == 3 and img.shape[2] == 3 + result = img[..., 0] * 65536 + img[..., 1] * 256 + img[..., 2] + return result + + +class LabelToTensor: + def __call__(self, m): + m = np.array(m) + return torch.from_numpy(m.astype(dtype='int64')) + + +class ImgNormalize: + def __call__(self, tensor): + mean = torch.mean(tensor, dim=(1, 2)) + std = torch.std(tensor, dim=(1, 2)) + return F.normalize(tensor, mean, std) + + +def get_transformer(config, min_value, max_value, mean, std): + base_config = {'min_value': min_value, 'max_value': max_value, 'mean': mean, 'std': std} + return Transformer(config, base_config) + + +class Transformer: + def __init__(self, phase_config, base_config): + self.phase_config = phase_config + self.config_base = base_config + self.seed = GLOBAL_RANDOM_STATE.randint(10000000) + + def raw_transform(self): + return self._create_transform('raw') + + def label_transform(self): + return self._create_transform('label') + + def weight_transform(self): + return self._create_transform('weight') + + @staticmethod + def _transformer_class(class_name): + m = importlib.import_module('data.transforms') + clazz = getattr(m, class_name) + return clazz + + def _create_transform(self, name): + assert name in self.phase_config, f'Could not find {name} transform' + return Compose([ + self._create_augmentation(c) for c in self.phase_config[name] + ]) + + def _create_augmentation(self, c): + config = dict(self.config_base) + config.update(c) + config['random_state'] = np.random.RandomState(self.seed) + aug_class = self._transformer_class(config['name']) + return aug_class(**config) + + +def _recover_ignore_index(input, orig, ignore_index): + if ignore_index is not None: + mask = orig == ignore_index + input[mask] = ignore_index + + return input diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/lingwu.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/lingwu.py" new file mode 100644 index 0000000000000000000000000000000000000000..ed91c9308998800998ccda91f4f08be52c48d097 --- /dev/null +++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/lingwu.py" @@ -0,0 +1,60 @@ +import os +files_data = [] +paths = [] +paths1 = [] +paths2 = [] +paths3 = [] +dir = r"/home/zyyang/concile/segmentation/mydataset/valid/images" +dir1 = r"/home/zyyang/concile/segmentation/mydataset/valid/masks" +dir2 = r"/home/zyyang/concile/segmentation/mydataset/train/images" +dir3 = 
r"/home/zyyang/concile/segmentation/mydataset/train/masks" +ss = os.listdir(dir) +#print(ss) +ss.sort() +print(ss) +for file in enumerate(ss): + #print(file) + #print(file[1]) + path = os.path.join(dir, file[1]) + # img = np.asarray(imageio.imread(path)) + #if path.endswith(".nii.gz"): + #img = nib.load(path).get_fdata() + #elif path.endswith(".nrrd"): + #img, options = nrrd.read(path) + paths.append(path) +print(paths) +print(len(paths)) + + +for file in os.listdir(dir1): + path = os.path.join(dir1, file) + # img = np.asarray(imageio.imread(path)) + #if path.endswith(".nii.gz"): + #img = nib.load(path).get_fdata() + #elif path.endswith(".nrrd"): + #img, options = nrrd.read(path) + paths1.append(path) +print(paths1) +print(len(paths1)) +for file in os.listdir(dir2): + path = os.path.join(dir2, file) + # img = np.asarray(imageio.imread(path)) + # if path.endswith(".nii.gz"): + # img = nib.load(path).get_fdata() + # elif path.endswith(".nrrd"): + # img, options = nrrd.read(path) + paths2.append(path) +print(paths2) +print(len(paths2)) +for file in os.listdir(dir3): + path = os.path.join(dir3, file) + # img = np.asarray(imageio.imread(path)) + # if path.endswith(".nii.gz"): + # img = nib.load(path).get_fdata() + # elif path.endswith(".nrrd"): + # img, options = nrrd.read(path) + paths3.append(path) + +print(paths3) +print(len(paths3)) + diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/loss.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/loss.py" new file mode 100644 index 0000000000000000000000000000000000000000..33ebd3bdbd2f1f35e4209710c2d647f3c85939c2 --- /dev/null +++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/loss.py" @@ -0,0 +1,372 @@ +import torch +import torch.nn.functional as F +from torch import nn as nn +from torch.autograd import Variable +from torch.nn import MSELoss, SmoothL1Loss, L1Loss + +from utils import expand_as_one_hot + + +def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None): + """ + Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target. + Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function. + + Args: + input (torch.Tensor): NxCxSpatial input tensor + target (torch.Tensor): NxCxSpatial target tensor + epsilon (float): prevents division by zero + weight (torch.Tensor): Cx1 tensor of weight per channel/class + 在此处判断一下target是否需要添加通道维 + """ + + # input and target shapes must match + assert input.size() == target.size(), "'input' and 'target' must have the same shape" + + input = flatten(input) + target = flatten(target) + target = target.float() + + # compute per channel Dice Coefficient + intersect = (input * target).sum(-1) + if weight is not None: + intersect = weight * intersect + + # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1) + denominator = (input * input).sum(-1) + (target * target).sum(-1) + return 2 * (intersect / denominator.clamp(min=epsilon)) + + +class _MaskingLossWrapper(nn.Module): + """ + Loss wrapper which prevents the gradient of the loss to be computed where target is equal to `ignore_index`. 
+ """ + + def __init__(self, loss, ignore_index): + super(_MaskingLossWrapper, self).__init__() + assert ignore_index is not None, 'ignore_index cannot be None' + self.loss = loss + self.ignore_index = ignore_index + + def forward(self, input, target): + mask = target.clone().ne_(self.ignore_index) + mask.requires_grad = False + + # mask out input/target so that the gradient is zero where on the mask + input = input * mask + target = target * mask + + # forward masked input and target to the loss + return self.loss(input, target) + + +class SkipLastTargetChannelWrapper(nn.Module): + """ + Loss wrapper which removes additional target channel + """ + + def __init__(self, loss, squeeze_channel=False): + super(SkipLastTargetChannelWrapper, self).__init__() + self.loss = loss + self.squeeze_channel = squeeze_channel + + def forward(self, input, target): + assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel' + + # skips last target channel if needed + target = target[:, :-1, ...] + + if self.squeeze_channel: + # squeeze channel dimension if singleton + target = torch.squeeze(target, dim=1) + return self.loss(input, target) + + +class _AbstractDiceLoss(nn.Module): + """ + Base class for different implementations of Dice loss. + """ + + def __init__(self, weight=None, normalization='sigmoid'): + super(_AbstractDiceLoss, self).__init__() + self.register_buffer('weight', weight) + # The output from the network during training is assumed to be un-normalized probabilities and we would + # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data, + # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems. + # However if one would like to apply Softmax in order to get the proper probability distribution from the + # output, just specify `normalization=Softmax` + assert normalization in ['sigmoid', 'softmax', 'none'] + if normalization == 'sigmoid': + self.normalization = nn.Sigmoid() + elif normalization == 'softmax': + self.normalization = nn.Softmax(dim=1) + else: + self.normalization = lambda x: x + + def dice(self, input, target, weight): + # actual Dice score computation; to be implemented by the subclass + raise NotImplementedError + + def forward(self, input, target): + # get probabilities from logits + input = self.normalization(input) + + # compute per channel Dice coefficient + per_channel_dice = self.dice(input, target, weight=self.weight) + + # average Dice score across all channels/classes + return 1. - torch.mean(per_channel_dice) + + +class DiceLoss(_AbstractDiceLoss): + """Computes Dice Loss according to https://arxiv.org/abs/1606.04797. + For multi-class segmentation `weight` parameter can be used to assign different weights per class. + The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function. 
+    weight is the per-class weighting used when scoring the pixels.
+    DiceLoss inherits from _AbstractDiceLoss with normalization='sigmoid'.
+    Looking at _AbstractDiceLoss and DiceLoss together, their core is: the input (i.e. the
+    model output, passed through a Sigmoid first -- I was initially puzzled why an extra
+    Sigmoid is needed here, but it is simply how Dice loss expects normalized probabilities)
+    and the target (i.e. our mask) are fed into the per-channel Dice computation,
+    that is, compute_per_channel_dice, which first flattens output and target from
+    (N, C, D, H, W) to (C, N*D*H*W) and then produces the Dice loss from them.
+    Since our segmentation task is binary, there is only a single channel.
+    """
+
+    def __init__(self, weight=None, normalization='sigmoid'):
+        super().__init__(weight, normalization)
+
+    def dice(self, input, target, weight):
+        return compute_per_channel_dice(input, target, weight=self.weight)
+
+
+class GeneralizedDiceLoss(_AbstractDiceLoss):
+    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.
+    """
+
+    def __init__(self, normalization='sigmoid', epsilon=1e-6):
+        super().__init__(weight=None, normalization=normalization)
+        self.epsilon = epsilon
+
+    def dice(self, input, target, weight):
+        assert input.size() == target.size(), "'input' and 'target' must have the same shape"
+
+        input = flatten(input)
+        target = flatten(target)
+        target = target.float()
+
+        if input.size(0) == 1:
+            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
+            # put foreground and background voxels in separate channels
+            input = torch.cat((input, 1 - input), dim=0)
+            target = torch.cat((target, 1 - target), dim=0)
+
+        # GDL weighting: the contribution of each label is corrected by the inverse of its volume
+        w_l = target.sum(-1)
+        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
+        w_l.requires_grad = False
+
+        intersect = (input * target).sum(-1)
+        intersect = intersect * w_l
+
+        denominator = (input + target).sum(-1)
+        denominator = (denominator * w_l).clamp(min=self.epsilon)
+
+        return 2 * (intersect.sum() / denominator.sum())
+
+
+class BCEDiceLoss(nn.Module):
+    """Linear combination of BCE and Dice losses"""
+
+    def __init__(self, alpha, beta):
+        super(BCEDiceLoss, self).__init__()
+        self.alpha = alpha
+        self.bce = nn.BCEWithLogitsLoss()
+        self.beta = beta
+        self.dice = DiceLoss()
+
+    def forward(self, input, target):
+        return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target)
+
+
+class WeightedCrossEntropyLoss(nn.Module):
+    """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
+    """
+
+    def __init__(self, ignore_index=-1):
+        super(WeightedCrossEntropyLoss, self).__init__()
+        self.ignore_index = ignore_index
+
+    def forward(self, input, target):
+        weight = self._class_weights(input)
+        return F.cross_entropy(input, target, weight=weight, ignore_index=self.ignore_index)
+
+    @staticmethod
+    def _class_weights(input):
+        # normalize the input first
+        input = F.softmax(input, dim=1)
+        flattened = flatten(input)
+        nominator = (1.
- flattened).sum(-1) + denominator = flattened.sum(-1) + class_weights = Variable(nominator / denominator, requires_grad=False) + return class_weights + + +class PixelWiseCrossEntropyLoss(nn.Module): + def __init__(self, class_weights=None, ignore_index=None): + super(PixelWiseCrossEntropyLoss, self).__init__() + self.register_buffer('class_weights', class_weights) + self.ignore_index = ignore_index + self.log_softmax = nn.LogSoftmax(dim=1) + + def forward(self, input, target, weights): + assert target.size() == weights.size() + # normalize the input + log_probabilities = self.log_softmax(input) + # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW) + target = expand_as_one_hot(target, C=input.size()[1], ignore_index=self.ignore_index) + # expand weights + weights = weights.unsqueeze(0) + weights = weights.expand_as(input) + + # create default class_weights if None + if self.class_weights is None: + class_weights = torch.ones(input.size()[1]).float().to(input.device) + else: + class_weights = self.class_weights + + # resize class_weights to be broadcastable into the weights + class_weights = class_weights.view(1, -1, 1, 1, 1) + + # multiply weights tensor by class weights + weights = class_weights * weights + + # compute the losses + result = -weights * target * log_probabilities + # average the losses + return result.mean() + + +class WeightedSmoothL1Loss(nn.SmoothL1Loss): + def __init__(self, threshold, initial_weight, apply_below_threshold=True): + super().__init__(reduction="none") + self.threshold = threshold + self.apply_below_threshold = apply_below_threshold + self.weight = initial_weight + + def forward(self, input, target): + l1 = super().forward(input, target) + + if self.apply_below_threshold: + mask = target < self.threshold + else: + mask = target >= self.threshold + + l1[mask] = l1[mask] * self.weight + + return l1.mean() + + +def flatten(tensor): + """Flattens a given tensor such that the channel axis is first. 
+    The shapes are transformed as follows:
+       (N, C, D, H, W) -> (C, N * D * H * W)
+    """
+    # number of channels
+    C = tensor.size(1)
+    # new axis order
+    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
+    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
+    transposed = tensor.permute(axis_order)
+    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
+    return transposed.contiguous().view(C, -1)
+
+
+def get_loss_criterion(config):
+    """
+    Returns the loss function based on provided configuration
+    :param config: (dict) a top level configuration object containing the 'loss' key
+    :return: an instance of the loss function
+    Reads the loss settings from the config. With the config shipped here this resolves to:
+    name = DiceLoss, ignore_index = None, skip_last_target = False, weight = None,
+    pos_weight = None, so _create_loss builds a DiceLoss() and, since no wrapper applies,
+    the loss is returned directly.
+    """
+    assert 'loss' in config, 'Could not find loss function configuration'
+    loss_config = config['loss']
+    name = loss_config.pop('name')
+
+    ignore_index = loss_config.pop('ignore_index', None)
+    skip_last_target = loss_config.pop('skip_last_target', False)
+    weight = loss_config.pop('weight', None)
+
+    if weight is not None:
+        # convert to cuda tensor if necessary
+        weight = torch.tensor(weight).to(config['device'])
+
+    pos_weight = loss_config.pop('pos_weight', None)
+    if pos_weight is not None:
+        # convert to cuda tensor if necessary
+        pos_weight = torch.tensor(pos_weight).to(config['device'])
+
+    loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight)
+
+    if not (ignore_index is None or name in ['CrossEntropyLoss', 'WeightedCrossEntropyLoss']):
+        # use MaskingLossWrapper only for non-cross-entropy losses, since CE losses allow specifying 'ignore_index' directly
+        loss = _MaskingLossWrapper(loss, ignore_index)
+
+    if skip_last_target:
+        loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False))
+
+    return loss
+
+
+#######################################################################################################################
+
+def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
+    """
+    With the default config this returns DiceLoss(weight=None, normalization='sigmoid').
+    """
+    if name == 'BCEWithLogitsLoss':
+        return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
+    elif name == 'BCEDiceLoss':
+        alpha = loss_config.get('alpha', 1.)
+        beta = loss_config.get('beta', 1.)
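+        # Note (added commentary): with these values the returned criterion computes
+        #     loss = alpha * BCEWithLogitsLoss(input, target) + beta * DiceLoss(input, target)
+        # where alpha and beta both default to 1.0 when absent from the config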
+ return BCEDiceLoss(alpha, beta) + elif name == 'CrossEntropyLoss': + if ignore_index is None: + ignore_index = -100 # use the default 'ignore_index' as defined in the CrossEntropyLoss + return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index) + elif name == 'WeightedCrossEntropyLoss': + if ignore_index is None: + ignore_index = -100 # use the default 'ignore_index' as defined in the CrossEntropyLoss + return WeightedCrossEntropyLoss(ignore_index=ignore_index) + elif name == 'PixelWiseCrossEntropyLoss': + return PixelWiseCrossEntropyLoss(class_weights=weight, ignore_index=ignore_index) + elif name == 'GeneralizedDiceLoss': + normalization = loss_config.get('normalization', 'sigmoid') + return GeneralizedDiceLoss(normalization=normalization) + elif name == 'DiceLoss': + normalization = loss_config.get('normalization', 'sigmoid') + return DiceLoss(weight=weight, normalization=normalization) + elif name == 'MSELoss': + return MSELoss() + elif name == 'SmoothL1Loss': + return SmoothL1Loss() + elif name == 'L1Loss': + return L1Loss() + elif name == 'WeightedSmoothL1Loss': + return WeightedSmoothL1Loss(threshold=loss_config['threshold'], + initial_weight=loss_config['initial_weight'], + apply_below_threshold=loss_config.get('apply_below_threshold', True)) + else: + raise RuntimeError(f"Unsupported loss function: '{name}'") + diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/main.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/main.py" new file mode 100644 index 0000000000000000000000000000000000000000..14eb1b0a58fcebe3056966165718ff0d96ad039b --- /dev/null +++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/main.py" @@ -0,0 +1,42 @@ +import torch + +from data.datautils import get_class +from config import load_config +from utils import get_logger + +logger = get_logger('TrainingSetup') +""" +""" + +def main(): + """加载和读取config文件,设置随机种子和cudnn参数使得达到可复现效果,设置训练模型类别为3DUNET网络, + 调用config文件中的builder参数作为指定的训练类别===》默认为UNet3DTrainerBuilder + 从trainer这个module中加载到UNet3DTrainerBuilder这个函数 + 通过UNet3DTrainerBuilder函数中的bulider来建立模型,将config参数集传到bulider函数中 + builder函数调用create_trainer函数 + 继而调用UNET3Dtrainer函数,在此函数中进行训练和验证 + """ + # Load and log experiment configuration + path = r"./config_training.yaml" + config = load_config(path) + logger.info(config) + + manual_seed = config.get('manual_seed', None) + if manual_seed is not None: + logger.info(f'Seed the RNG for all devices with {manual_seed}') + torch.manual_seed(manual_seed) + # see https://pytorch.org/docs/stable/notes/randomness.html + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + # create trainer + default_trainer_builder_class = 'UNet3DTrainerBuilder' + trainer_builder_class = config['trainer'].get('builder', default_trainer_builder_class) + trainer_builder = get_class(trainer_builder_class, modules=['train']) + trainer = trainer_builder.build(config) + # Start training + trainer.fit() + + +if __name__ == '__main__': + main() diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/metrics.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/metrics.py" 
new file mode 100644
index 0000000000000000000000000000000000000000..a6998dffb59e3c8c70500b5bd600326bcf342540
--- /dev/null
+++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/metrics.py"
@@ -0,0 +1,444 @@
+import importlib
+
+import numpy as np
+import torch
+from skimage import measure
+from skimage.metrics import adapted_rand_error, peak_signal_noise_ratio
+
+from loss import compute_per_channel_dice
+from seg_metrics import AveragePrecision, Accuracy
+from utils import get_logger, expand_as_one_hot, convert_to_numpy
+
+logger = get_logger('EvalMetric')
+
+
+class DiceCoefficient:
+    """Computes Dice Coefficient.
+    Generalized to multiple channels by computing per-channel Dice Score
+    (as described in https://arxiv.org/pdf/1707.03237.pdf) and then simply taking the average.
+    Input is expected to be probabilities instead of logits.
+    This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets).
+    DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss.
+    """
+
+    def __init__(self, epsilon=1e-6, **kwargs):
+        self.epsilon = epsilon
+
+    def __call__(self, input, target):
+        # Average across channels in order to get the final score
+        return torch.mean(compute_per_channel_dice(input, target, epsilon=self.epsilon))
+
+
+class MeanIoU:
+    """
+    Computes IoU for each class separately and then averages over all classes.
+    """
+
+    def __init__(self, skip_channels=(), ignore_index=None, **kwargs):
+        """
+        :param skip_channels: list/tuple of channels to be ignored from the IoU computation
+        :param ignore_index: id of the label to be ignored from IoU computation
+        """
+        self.ignore_index = ignore_index
+        self.skip_channels = skip_channels
+
+    def __call__(self, input, target):
+        """
+        :param input: 5D probability maps torch float tensor (NxCxDxHxW)
+        :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
+        :return: intersection over union averaged over all channels
+        The target is first expanded with a channel dimension; in the single-channel case the
+        prediction is simply thresholded, which yields a mask-like binary tensor.
+        For a 5D batch of outputs and targets, zip() yields the N sample pairs; IoU
+        (the intersection of target and output over their union) is computed per channel of
+        each sample, averaged over the channels, and finally averaged over the samples in
+        the batch.
+        """
+        assert input.dim() == 5
+
+        n_classes = input.size()[1]
+
+        if target.dim() == 4:
+            target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)
+
+        assert input.size() == target.size()
+
+        per_batch_iou = []
+        for _input, _target in zip(input, target):
+            binary_prediction = self._binarize_predictions(_input, n_classes)
+
+            if self.ignore_index is not None:
+                # zero out ignore_index
+                mask = _target == self.ignore_index
+                binary_prediction[mask] = 0
+                _target[mask] = 0
+
+            # convert to uint8 just in case
+            binary_prediction = binary_prediction.byte()
+            _target = _target.byte()
+
+            per_channel_iou = []
+            for c in range(n_classes):
+                if c in self.skip_channels:
+                    continue
+
+                per_channel_iou.append(self._jaccard_index(binary_prediction[c], _target[c]))
+
+            assert per_channel_iou, "All channels were ignored from the computation"
+            mean_iou = torch.mean(torch.tensor(per_channel_iou))
+            per_batch_iou.append(mean_iou)
+
+        return torch.mean(torch.tensor(per_batch_iou))
+
+    def _binarize_predictions(self, input, n_classes):
+        """
+        Puts 1 for the class/channel with the highest probability and 0 in other channels. Returns byte tensor of the
+        same size as the input tensor.
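+
+        For example (illustrative): a voxel with per-channel probabilities [0.2, 0.7, 0.1]
+        becomes [0, 1, 0]; with n_classes == 1 the map is instead thresholded at 0.5.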
+ """ + if n_classes == 1: + # for single channel input just threshold the probability map + result = input > 0.5 + return result.long() + + _, max_index = torch.max(input, dim=0, keepdim=True) + return torch.zeros_like(input, dtype=torch.uint8).scatter_(0, max_index, 1) + + def _jaccard_index(self, prediction, target): + """ + Computes IoU for a given target and prediction tensors + """ + return torch.sum(prediction & target).float() / torch.clamp(torch.sum(prediction | target).float(), min=1e-8) + + +class AdaptedRandError: + """ + A functor which computes an Adapted Rand error as defined by the SNEMI3D contest + (http://brainiac2.mit.edu/SNEMI3D/evaluation). + + This is a generic implementation which takes the input, converts it to the segmentation image (see `input_to_segm()`) + and then computes the ARand between the segmentation and the ground truth target. Depending on one's use case + it's enough to extend this class and implement the `input_to_segm` method. + + Args: + use_last_target (bool): use only the last channel from the target to compute the ARand + """ + + def __init__(self, use_last_target=False, **kwargs): + self.use_last_target = use_last_target + + def __call__(self, input, target): + """ + Compute ARand Error for each input, target pair in the batch and return the mean value. + + Args: + input (torch.tensor): 5D (NCDHW) output from the network + target (torch.tensor): 4D (NDHW) ground truth segmentation + + Returns: + average ARand Error across the batch + """ + + def _arand_err(gt, seg): + n_seg = len(np.unique(seg)) + if n_seg == 1: + return 0. + return adapted_rand_error(gt, seg)[0] + + # converts input and target to numpy arrays + input, target = convert_to_numpy(input, target) + if self.use_last_target: + target = target[:, -1, ...] # 4D + else: + # use 1st target channel + target = target[:, 0, ...] # 4D + + # ensure target is of integer type + target = target.astype(np.int) + + per_batch_arand = [] + for _input, _target in zip(input, target): + n_clusters = len(np.unique(_target)) + # skip ARand eval if there is only one label in the patch due to the zero-division error in Arand impl + # xxx/skimage/metrics/_adapted_rand_error.py:70: RuntimeWarning: invalid value encountered in double_scalars + # precision = sum_p_ij2 / sum_a2 + logger.info(f'Number of ground truth clusters: {n_clusters}') + if n_clusters == 1: + logger.info('Skipping ARandError computation: only 1 label present in the ground truth') + per_batch_arand.append(0.) + continue + + # convert _input to segmentation CDHW + segm = self.input_to_segm(_input) + assert segm.ndim == 4 + + # compute per channel arand and return the minimum value + per_channel_arand = [_arand_err(_target, channel_segm) for channel_segm in segm] + logger.info(f'Min ARand for channel: {np.argmin(per_channel_arand)}') + per_batch_arand.append(np.min(per_channel_arand)) + + # return mean arand error + mean_arand = torch.mean(torch.tensor(per_batch_arand)) + logger.info(f'ARand: {mean_arand.item()}') + return mean_arand + + def input_to_segm(self, input): + """ + Converts input tensor (output from the network) to the segmentation image. E.g. if the input is the boundary + pmaps then one option would be to threshold it and run connected components in order to return the segmentation. 
+ + :param input: 4D tensor (CDHW) + :return: segmentation volume either 4D (segmentation per channel) + """ + # by deafult assume that input is a segmentation volume itself + return input + + +class BoundaryAdaptedRandError(AdaptedRandError): + """ + Compute ARand between the input boundary map and target segmentation. + Boundary map is thresholded, and connected components is run to get the predicted segmentation + """ + + def __init__(self, thresholds=None, use_last_target=True, input_channel=None, invert_pmaps=True, + save_plots=False, plots_dir='.', **kwargs): + super().__init__(use_last_target=use_last_target, save_plots=save_plots, plots_dir=plots_dir, **kwargs) + if thresholds is None: + thresholds = [0.3, 0.4, 0.5, 0.6] + assert isinstance(thresholds, list) + self.thresholds = thresholds + self.input_channel = input_channel + self.invert_pmaps = invert_pmaps + + def input_to_segm(self, input): + if self.input_channel is not None: + input = np.expand_dims(input[self.input_channel], axis=0) + + segs = [] + for predictions in input: + for th in self.thresholds: + # threshold probability maps + predictions = predictions > th + + if self.invert_pmaps: + # for connected component analysis we need to treat boundary signal as background + # assign 0-label to boundary mask + predictions = np.logical_not(predictions) + + predictions = predictions.astype(np.uint8) + # run connected components on the predicted mask; consider only 1-connectivity + seg = measure.label(predictions, background=0, connectivity=1) + segs.append(seg) + + return np.stack(segs) + + +class GenericAdaptedRandError(AdaptedRandError): + def __init__(self, input_channels, thresholds=None, use_last_target=True, invert_channels=None, **kwargs): + + super().__init__(use_last_target=use_last_target, **kwargs) + assert isinstance(input_channels, list) or isinstance(input_channels, tuple) + self.input_channels = input_channels + if thresholds is None: + thresholds = [0.3, 0.4, 0.5, 0.6] + assert isinstance(thresholds, list) + self.thresholds = thresholds + if invert_channels is None: + invert_channels = [] + self.invert_channels = invert_channels + + def input_to_segm(self, input): + # pick only the channels specified in the input_channels + results = [] + for i in self.input_channels: + c = input[i] + # invert channel if necessary + if i in self.invert_channels: + c = 1 - c + results.append(c) + + input = np.stack(results) + + segs = [] + for predictions in input: + for th in self.thresholds: + # run connected components on the predicted mask; consider only 1-connectivity + seg = measure.label((predictions > th).astype(np.uint8), background=0, connectivity=1) + segs.append(seg) + + return np.stack(segs) + + +class GenericAveragePrecision: + def __init__(self, min_instance_size=None, use_last_target=False, metric='ap', **kwargs): + self.min_instance_size = min_instance_size + self.use_last_target = use_last_target + assert metric in ['ap', 'acc'] + if metric == 'ap': + # use AveragePrecision + self.metric = AveragePrecision() + else: + # use Accuracy at 0.5 IoU + self.metric = Accuracy(iou_threshold=0.5) + + def __call__(self, input, target): + if target.dim() == 5: + if self.use_last_target: + target = target[:, -1, ...] # 4D + else: + # use 1st target channel + target = target[:, 0, ...] 
# 4D + + input1 = input2 = input + multi_head = isinstance(input, tuple) + if multi_head: + input1, input2 = input + + input1, input2, target = convert_to_numpy(input1, input2, target) + + batch_aps = [] + i_batch = 0 + # iterate over the batch + for inp1, inp2, tar in zip(input1, input2, target): + if multi_head: + inp = (inp1, inp2) + else: + inp = inp1 + + segs = self.input_to_seg(inp, tar) # expects 4D + assert segs.ndim == 4 + # convert target to seg + tar = self.target_to_seg(tar) + + # filter small instances if necessary + tar = self._filter_instances(tar) + + # compute average precision per channel + segs_aps = [self.metric(self._filter_instances(seg), tar) for seg in segs] + + logger.info(f'Batch: {i_batch}. Max Average Precision for channel: {np.argmax(segs_aps)}') + # save max AP + batch_aps.append(np.max(segs_aps)) + i_batch += 1 + + return torch.tensor(batch_aps).mean() + + def _filter_instances(self, input): + """ + Filters instances smaller than 'min_instance_size' by overriding them with 0-index + :param input: input instance segmentation + """ + if self.min_instance_size is not None: + labels, counts = np.unique(input, return_counts=True) + for label, count in zip(labels, counts): + if count < self.min_instance_size: + input[input == label] = 0 + return input + + def input_to_seg(self, input, target=None): + raise NotImplementedError + + def target_to_seg(self, target): + return target + + +class BlobsAveragePrecision(GenericAveragePrecision): + """ + Computes Average Precision given foreground prediction and ground truth instance segmentation. + """ + + def __init__(self, thresholds=None, metric='ap', min_instance_size=None, input_channel=0, **kwargs): + super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric) + if thresholds is None: + thresholds = [0.4, 0.5, 0.6, 0.7, 0.8] + assert isinstance(thresholds, list) + self.thresholds = thresholds + self.input_channel = input_channel + + def input_to_seg(self, input, target=None): + input = input[self.input_channel] + segs = [] + for th in self.thresholds: + # threshold and run connected components + mask = (input > th).astype(np.uint8) + seg = measure.label(mask, background=0, connectivity=1) + segs.append(seg) + return np.stack(segs) + + +class BlobsBoundaryAveragePrecision(GenericAveragePrecision): + """ + Computes Average Precision given foreground prediction, boundary prediction and ground truth instance segmentation. + Segmentation mask is computed as (P_mask - P_boundary) > th followed by a connected component + """ + + def __init__(self, thresholds=None, metric='ap', min_instance_size=None, **kwargs): + super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric) + if thresholds is None: + thresholds = [0.3, 0.4, 0.5, 0.6, 0.7] + assert isinstance(thresholds, list) + self.thresholds = thresholds + + def input_to_seg(self, input, target=None): + # input = P_mask - P_boundary + input = input[0] - input[1] + segs = [] + for th in self.thresholds: + # threshold and run connected components + mask = (input > th).astype(np.uint8) + seg = measure.label(mask, background=0, connectivity=1) + segs.append(seg) + return np.stack(segs) + + +class BoundaryAveragePrecision(GenericAveragePrecision): + """ + Computes Average Precision given boundary prediction and ground truth instance segmentation. 
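+    As in `input_to_seg` below, boundary maps are thresholded and inverted before running
+    connected components, so that boundary voxels are treated as background.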
+ """ + + def __init__(self, thresholds=None, min_instance_size=None, input_channel=0, **kwargs): + super().__init__(min_instance_size=min_instance_size, use_last_target=True) + if thresholds is None: + thresholds = [0.3, 0.4, 0.5, 0.6] + assert isinstance(thresholds, list) + self.thresholds = thresholds + self.input_channel = input_channel + + def input_to_seg(self, input, target=None): + input = input[self.input_channel] + segs = [] + for th in self.thresholds: + seg = measure.label(np.logical_not(input > th).astype(np.uint8), background=0, connectivity=1) + segs.append(seg) + return np.stack(segs) + + +class PSNR: + """ + Computes Peak Signal to Noise Ratio. Use e.g. as an eval metric for denoising task + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, input, target): + input, target = convert_to_numpy(input, target) + return peak_signal_noise_ratio(target, input) + + +def get_evaluation_metric(config): + """ + Returns the evaluation metric function based on provided configuration + :param config: (dict) a top level configuration object containing the 'eval_metric' key + :return: an instance of the evaluation metric + 获取对应的metric函数===>基于mean_IOU来衡量分割性能 + """ + + def _metric_class(class_name): + m = importlib.import_module('metrics') + clazz = getattr(m, class_name) + return clazz + + assert 'eval_metric' in config, 'Could not find evaluation metric configuration' + metric_config = config['eval_metric'] + metric_class = _metric_class(metric_config['name']) + return metric_class(**metric_config) diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/model/.keep" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/model/.keep" new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/model/buildingblocks.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/model/buildingblocks.py" new file mode 100644 index 0000000000000000000000000000000000000000..40dd8cb22162651764496141e684980a5a133eb9 --- /dev/null +++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/model/buildingblocks.py" @@ -0,0 +1,458 @@ +from functools import partial + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +def conv3d(in_channels, out_channels, kernel_size, bias, padding): + return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) + + +def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding): + """ + Create a list of modules with together constitute a single conv layer with non-linearity + and optional batchnorm/groupnorm. + + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + kernel_size(int or tuple): size of the convolving kernel + order (string): order of things, e.g. 
+ 'cr' -> conv + ReLU + 'gcr' -> groupnorm + conv + ReLU + 'cl' -> conv + LeakyReLU + 'ce' -> conv + ELU + 'bcr' -> batchnorm + conv + ReLU + num_groups (int): number of groups for the GroupNorm + padding (int or tuple): add zero-padding added to all three sides of the input + + Return: + list of tuple (name, module) + """ + assert 'c' in order, "Conv layer MUST be present" + assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' + + modules = [] + for i, char in enumerate(order): + if char == 'r': + modules.append(('ReLU', nn.ReLU(inplace=True))) + elif char == 'l': + modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True))) + elif char == 'e': + modules.append(('ELU', nn.ELU(inplace=True))) + elif char == 'c': + # add learnable bias only in the absence of batchnorm/groupnorm + bias = not ('g' in order or 'b' in order) + modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding))) + elif char == 'g': + is_before_conv = i < order.index('c') + if is_before_conv: + num_channels = in_channels + else: + num_channels = out_channels + + # use only one group if the given number of groups is greater than the number of channels + if num_channels < num_groups: + num_groups = 1 + + assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}' + modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) + elif char == 'b': + is_before_conv = i < order.index('c') + if is_before_conv: + modules.append(('batchnorm', nn.BatchNorm3d(in_channels))) + else: + modules.append(('batchnorm', nn.BatchNorm3d(out_channels))) + else: + raise ValueError(f"Unsupported layer type '{char}'. 
MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
+
+    return modules
+
+
+class SingleConv(nn.Sequential):
+    """
+    Single conv layer configuration for order='gcr':
+    builds the module list with num_channels = in_channels for the GroupNorm.
+    Across the four encoders and three decoders there are 14 SingleConv modules in total;
+    their input channel counts are [1, 16, 32, 32, 64, 64, 128, 128] on the encoder path
+    and [384, 128, 192, 64, 96, 32] on the decoder path.
+    Each 'gcr' SingleConv is
+        nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
+        + conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)
+        + nn.ReLU(inplace=True)
+    """

+    def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8, padding=1):
+        super(SingleConv, self).__init__()
+
+        for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
+            self.add_module(name, module)
+
+
+class DoubleConv(nn.Sequential):
+    """
+    DoubleConv(in_channels, out_channels, encoder=True, kernel_size=3, order='gcr',
+               num_groups=8, padding=1)
+    On the encoder path the channel counts are:
+        conv1_in  = 1  , 32 , 64  , 128
+        conv1_out = 16 , 32 , 64  , 128
+        conv2_in  = 16 , 32 , 64  , 128
+        conv2_out = 32 , 64 , 128 , 256
+    """
+
+    def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr', num_groups=8, padding=1):
+        super(DoubleConv, self).__init__()
+        if encoder:
+            # we're in the encoder path
+            conv1_in_channels = in_channels
+            conv1_out_channels = out_channels // 2
+            if conv1_out_channels < in_channels:
+                conv1_out_channels = in_channels
+            conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
+        else:
+            # we're in the decoder path, decrease the number of channels in the 1st convolution
+            conv1_in_channels, conv1_out_channels = in_channels, out_channels
+            conv2_in_channels, conv2_out_channels = out_channels, out_channels
+
+        # conv1
+        self.add_module('SingleConv1',
+                        SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups,
+                                   padding=padding))
+        # conv2
+        self.add_module('SingleConv2',
+                        SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups,
+                                   padding=padding))
+
+
+class Encoder(nn.Module):
+    """
+    Encoder channel counts (pooling followed by DoubleConv):
+        conv1_in  = 1  , 32 , 64  , 128
+        conv1_out = 16 , 32 , 64  , 128
+        conv2_in  = 16 , 32 , 64  , 128
+        conv2_out = 32 , 64 , 128 , 256
+    """
+
+    def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
+                 pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
+                 num_groups=8, padding=1):
+        super(Encoder, self).__init__()
+        assert pool_type in ['max', 'avg']
+        if apply_pooling:
+            if pool_type == 'max':
+                self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
+            else:
+                self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
+        else:
+            self.pooling = None
+
+        self.basic_module = basic_module(in_channels, out_channels,
+                                         encoder=True,
+                                         kernel_size=conv_kernel_size,
+                                         order=conv_layer_order,
+                                         num_groups=num_groups,
+                                         padding=padding)
+
+    def forward(self, x):
+        if self.pooling is not None:
+            x = self.pooling(x)
+        x = self.basic_module(x)
+        return x
+
+
+def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups,
+                    pool_kernel_size):
+    """
+    Iterates over f_maps=[32, 64, 128, 256] and collects four Encoder modules into `encoders`:
+        0: [DoubleConv]               DoubleConv(in_channels=1,   out_channels=32,  encoder=True, kernel_size=3, order='gcr', num_groups=8, padding=1)
+        1: [max pooling + DoubleConv] DoubleConv(in_channels=32,  out_channels=64,  encoder=True, kernel_size=3, order='gcr', num_groups=8, padding=1)
+        2: [max pooling + DoubleConv] DoubleConv(in_channels=64,  out_channels=128, encoder=True, kernel_size=3, order='gcr', num_groups=8, padding=1)
+        3: [max pooling + DoubleConv] DoubleConv(in_channels=128, out_channels=256, encoder=True, kernel_size=3, order='gcr', num_groups=8, padding=1)
+    """
+    # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
+    encoders = []
+    for i, out_feature_num in enumerate(f_maps):
+        if i == 0:
+            encoder = Encoder(in_channels, out_feature_num,
+                              apply_pooling=False,  # skip pooling in the first encoder
+                              basic_module=basic_module,
+                              conv_layer_order=layer_order,
+                              conv_kernel_size=conv_kernel_size,
+                              num_groups=num_groups,
+                              padding=conv_padding)
+        else:
+            # TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations
+            encoder = Encoder(f_maps[i - 1], out_feature_num,
+                              basic_module=basic_module,
+                              conv_layer_order=layer_order,
+                              conv_kernel_size=conv_kernel_size,
+                              num_groups=num_groups,
+                              pool_kernel_size=pool_kernel_size,
+                              padding=conv_padding)
+
+        encoders.append(encoder)
+
+    return nn.ModuleList(encoders)
+
+
+class Decoder(nn.Module):
+    """
+    Decoder(in_feature_num=[384, 192, 96], out_feature_num=[128, 64, 32],
+            basic_module=DoubleConv, conv_layer_order='gcr', conv_kernel_size=3,
+            num_groups=8, padding=1, upsample=True)
+    With DoubleConv as the basic module, only the first branch below is taken:
+        if upsample:
+            if basic_module == DoubleConv:
+                # use interpolation for upsampling and concatenation joining
+                self.upsampling = InterpolateUpsampling(mode=mode)
+                self.joining = partial(self._joining, concat=True)
+    This yields self.upsampling, self.joining and self.basic_module:
+        self.upsampling: InterpolateUpsampling(mode='nearest')
+        self.joining:    partial(self._joining, concat=True) ==> torch.cat((encoder_features, x), dim=1)
+        self.basic_module: basic_module(in_channels=[384, 192, 96], out_channels=[128, 64, 32],
+                                        encoder=False, kernel_size=conv_kernel_size,
+                                        order=conv_layer_order, num_groups=num_groups, padding=padding)
+    Note the encoder=False setting of this DoubleConv:
+        conv1_in  = 384 , 192 , 96
+        conv1_out = 128 , 64  , 32
+        conv2_in  = 128 , 64  , 32
+        conv2_out = 128 , 64  , 32
+    """
+
+    def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv,
+                 conv_layer_order='gcr', num_groups=8, mode='nearest', padding=1, upsample=True):
+        super(Decoder, self).__init__()
+
+        if upsample:
+            if basic_module == DoubleConv:
+                # if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining
+                self.upsampling = InterpolateUpsampling(mode=mode)
+                # concat joining
+                self.joining = partial(self._joining, concat=True)
+            else:
+                # if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining
+                self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels,
+                                                          kernel_size=conv_kernel_size, scale_factor=scale_factor)
+                # sum joining
+                self.joining = partial(self._joining, concat=False)
+                # adapt the number of in_channels for the ExtResNetBlock
+                in_channels = out_channels
+        else:
+            # no upsampling
+            self.upsampling = NoUpsampling()
+            # concat joining
+            self.joining = partial(self._joining, concat=True)
+
+        self.basic_module = basic_module(in_channels, out_channels,
+                                         encoder=False,
+                                         kernel_size=conv_kernel_size,
+                                         order=conv_layer_order,
+                                         num_groups=num_groups,
+                                         padding=padding)
+
+    def forward(self, encoder_features, x):
+        x = self.upsampling(encoder_features=encoder_features, x=x)
+        x = self.joining(encoder_features, x)
+        x = self.basic_module(x)
+        return x
+
+    @staticmethod
+    def _joining(encoder_features, x, concat):
+        if concat:
+            return torch.cat((encoder_features, x), dim=1)
+        else:
+            return encoder_features + x
+
+
+def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, upsample):
+    """
+    Called with f_maps=(32, 64, 128, 256), basic_module=DoubleConv, conv_kernel_size=3,
+    conv_padding=1, layer_order='gcr', num_groups=8, upsample=True.
+    Builds the decoder list: f_maps is reversed to [256, 128, 64, 32]; since the basic module
+    is DoubleConv, the (input, output) feature counts are
+        [(256+128) -> 128, (128+64) -> 64, (64+32) -> 32],
+    which yields three Decoder modules (all with upsample=True).
+    """
+    # create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1`
+    decoders = []
+    reversed_f_maps = list(reversed(f_maps))
+    for i in range(len(reversed_f_maps) - 1):
+        if basic_module == DoubleConv:
+            in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
+        else:
+            in_feature_num = reversed_f_maps[i]
+
+        out_feature_num = reversed_f_maps[i + 1]
+
+        # TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv
+        # currently strides with a constant stride: (2, 2, 2)
+
+        _upsample = True
+        if i == 0:
+            # upsampling can be skipped only for the 1st decoder, afterwards it should always be present
+            _upsample = upsample
+
+        decoder = Decoder(in_feature_num, out_feature_num,
+                          basic_module=basic_module,
+                          conv_layer_order=layer_order,
+                          conv_kernel_size=conv_kernel_size,
+                          num_groups=num_groups,
+                          padding=conv_padding,
+                          upsample=_upsample)
+        decoders.append(decoder)
+    return nn.ModuleList(decoders)
+
+
+class AbstractUpsampling(nn.Module):
+    """
+    Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either
+    interpolation or learned transposed convolution.
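+
+    For instance, InterpolateUpsampling below effectively computes (illustrative):
+
+        F.interpolate(x, size=encoder_features.size()[2:], mode='nearest')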
+ """ + + def __init__(self, upsample): + super(AbstractUpsampling, self).__init__() + self.upsample = upsample + + def forward(self, encoder_features, x): + # get the spatial dimensions of the output given the encoder_features + output_size = encoder_features.size()[2:] + # upsample the input and return + return self.upsample(x, output_size) + + +class InterpolateUpsampling(AbstractUpsampling): + """ + 按照代码要求,本代码继承AbstractUpsampling,同时mode为nearest + 假如输入为(N,1,16,128,128)==》经过4个encoder之后的输出分别为 + 0:(N,16,16,128,128)==》(N,32,16,128,128) + 1:(N,32,8,64,64)==>(N,64,8,64,64) + 2:(N,64,4,32,32)==>(N,128,4,32,32) + 3:(N,128,2,16,16)==>(N,256,2,16,16) + 前三个输出是保留在encoder_features中的 + 本代码继承abstractupsampling,就是要将x的大小上采样(和encoder_features中对应元素的维度相同) + """ + + def __init__(self, mode='nearest'): + upsample = partial(self._interpolate, mode=mode) + super().__init__(upsample) + + @staticmethod + def _interpolate(x, size, mode): + return F.interpolate(x, size=size, mode=mode) + + +class TransposeConvUpsampling(AbstractUpsampling): + """ + Args: + in_channels (int): number of input channels for transposed conv + used only if transposed_conv is True + out_channels (int): number of output channels for transpose conv + used only if transposed_conv is True + kernel_size (int or tuple): size of the convolving kernel + used only if transposed_conv is True + scale_factor (int or tuple): stride of the convolution + used only if transposed_conv is True + + """ + + def __init__(self, in_channels=None, out_channels=None, kernel_size=3, scale_factor=(2, 2, 2)): + # make sure that the output size reverses the MaxPool3d from the corresponding encoder + upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, + padding=1) + super().__init__(upsample) + + +class NoUpsampling(AbstractUpsampling): + def __init__(self): + super().__init__(self._no_upsampling) + + @staticmethod + def _no_upsampling(x, size): + return x + +class ExtResNetBlock(nn.Module): + """ + Basic UNet block consisting of a SingleConv followed by the residual block. + The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number + of output channels is compatible with the residual block that follows. + This block can be used instead of standard DoubleConv in the Encoder module. + Motivated by: https://arxiv.org/pdf/1706.00120.pdf + + Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. 
+ """ + + def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs): + super(ExtResNetBlock, self).__init__() + + # first convolution + self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) + # residual block + self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) + # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual + n_order = order + for c in 'rel': + n_order = n_order.replace(c, '') + self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, + num_groups=num_groups) + + # create non-linearity separately + if 'l' in order: + self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif 'e' in order: + self.non_linearity = nn.ELU(inplace=True) + else: + self.non_linearity = nn.ReLU(inplace=True) + + def forward(self, x): + # apply first convolution and save the output as a residual + out = self.conv1(x) + residual = out + + # residual block + out = self.conv2(out) + out = self.conv3(out) + + out += residual + out = self.non_linearity(out) + + return out + diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/model/get_model.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/model/get_model.py" new file mode 100644 index 0000000000000000000000000000000000000000..ceddff9afc01cc3d9ecddff5a71a55dcc6bd63b4 --- /dev/null +++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/model/get_model.py" @@ -0,0 +1,12 @@ +import importlib +def get_model(model_config): + def _model_class(class_name): + modules = ['model.modell'] + for module in modules: + m = importlib.import_module(module) + clazz = getattr(m, class_name, None) + if clazz is not None: + return clazz + + model_class = _model_class(model_config['name']) + return model_class(**model_config) diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/model/modell.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/model/modell.py" new file mode 100644 index 0000000000000000000000000000000000000000..1c67634a744af0c52d75c0eef6341a30311ce107 --- /dev/null +++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/model/modell.py" @@ -0,0 +1,168 @@ +from functools import partial +import importlib +import torch +from torch import nn as nn +from torch.nn import functional as F +from model.buildingblocks import DoubleConv, ExtResNetBlock, create_encoders, \ + create_decoders + +def number_of_features_per_level(init_channel_number, num_levels): + return [init_channel_number * 2 ** k for k in range(num_levels)] + + +class Abstract3DUNet(nn.Module): + """ + 我们先看编码器和解码器的构建代码,再将其组成最终的unet3d模型 + 本代码首先将参数传递给了创建两个模组的代码=》参数为 + in_channels=1, f_maps=(32,64,128,256), basic_module=doubleconv, conv_kernel_size=3, conv_padding=1, layer_order=grc, + num_groups=8, pool_kernel_size=2 + + f_maps=(32,64,128,256), basic_module=doubleconv, conv_kernel_size=3, conv_padding=1, layer_order=gcr, 
num_groups=8, + upsample=True +可以看到模型中设置输出通道为1,说明mask作为label是二分类的分割任务 + + """ + + def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=4, is_segmentation=True, testing=False, + conv_kernel_size=3, pool_kernel_size=2, conv_padding=1, **kwargs): + super(Abstract3DUNet, self).__init__() + + self.testing = testing + + if isinstance(f_maps, int): + f_maps = number_of_features_per_level(f_maps, num_levels=num_levels) + + assert isinstance(f_maps, list) or isinstance(f_maps, tuple) + assert len(f_maps) > 1, "Required at least 2 levels in the U-Net" + + # create encoder path + self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, + num_groups, pool_kernel_size) + + # create decoder path + self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, + upsample=True) + + # in the last layer a 1×1 convolution reduces the number of output + # channels to the number of labels + self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) + + if is_segmentation: + # semantic segmentation problem + if final_sigmoid: + self.final_activation = nn.Sigmoid() + else: + self.final_activation = nn.Softmax(dim=1) + else: + # regression problem + self.final_activation = None + + def forward(self, x): + # encoder part + """ + 模型的加载过程 + 设置encoder_features列表 + 将输入经过4个encoder的结果分别输入encoder_features列表(从后向前输入) + 同时舍弃最后的一层encoder的输出结果(也就是只要保存前三个encoder的3个输出结果,分别为32通道,64通道输出以及128通道输出) + 将encoder_features第一个元素(128通道)伴随最后一个encoder的256维输出结合输入第一个decoder,获取一个128维的输出 + 将encoder_features第二个元素(64通道)伴随第一个decoder的128维输出结合来输入第二个decoder,获取一个64维输出 + 将encoder_features第三个元素(32通道)伴随第二个decoder的64维输出来输入第三个decoder,获取一个32维输出 + 接着输入最后一层卷积层(1*1卷积,只改变通道数量,32==》1) + 最终判断是否进行sigmoid或者softmax操作 + """ + encoders_features = [] + for encoder in self.encoders: + x = encoder(x) + # reverse the encoder outputs to be aligned with the decoder + encoders_features.insert(0, x) + + # remove the last encoder's output from the list + # !!remember: it's the 1st in the list + encoders_features = encoders_features[1:] + + # decoder part + for decoder, encoder_features in zip(self.decoders, encoders_features): + # pass the output from the corresponding encoder and the output + # of the previous decoder + x = decoder(encoder_features, x) + + x = self.final_conv(x) + + # apply final_activation (i.e. Sigmoid or Softmax) only during prediction. 
During training the network outputs + # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric + if self.testing and self.final_activation is not None: + x = self.final_activation(x) + + return x + + +class UNet3D(Abstract3DUNet): + """ + 将f_maps生成为(32,64,128,256) + 创建unet3d所需的结构(编码器和解码器) + self.encoders = create_encoders(in_channels=1, f_maps=(32,64,128,256), basic_module=doubleconv, conv_kernel_size=3, conv_padding=1, layer_order=grc, + num_groups=8, pool_kernel_size=2) + + self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, + upsample=True) + 创建最后一层网络 + self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) + +通过是否进行分割等参数来决定最后一层的激活层设置 + if is_segmentation: + # semantic segmentation problem + if final_sigmoid: + self.final_activation = nn.Sigmoid() + else: + self.final_activation = nn.Softmax(dim=1) + 这样的设置应该是为了达到对于每个像素得到分类概率总和为1的结果 + else: + # regression problem + self.final_activation = None + + """ + + def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, **kwargs): + super(UNet3D, self).__init__(in_channels=in_channels, + out_channels=out_channels, + final_sigmoid=final_sigmoid, + basic_module=DoubleConv, + f_maps=f_maps, + layer_order=layer_order, + num_groups=num_groups, + num_levels=num_levels, + is_segmentation=is_segmentation, + conv_padding=conv_padding, + **kwargs) + + +class ResidualUNet3D(Abstract3DUNet): + """ + Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. + Uses ExtResNetBlock as a basic building block, summation joining instead + of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts). + Since the model effectively becomes a residual net, in theory it allows for deeper UNet. 
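+
+    Illustrative construction (hypothetical parameter values):
+
+        net = ResidualUNet3D(in_channels=1, out_channels=1, f_maps=32, num_levels=5)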
+ """ + + def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1, **kwargs): + super(ResidualUNet3D, self).__init__(in_channels=in_channels, + out_channels=out_channels, + final_sigmoid=final_sigmoid, + basic_module=ExtResNetBlock, + f_maps=f_maps, + layer_order=layer_order, + num_groups=num_groups, + num_levels=num_levels, + is_segmentation=is_segmentation, + conv_padding=conv_padding, + **kwargs) + + + + + + + diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/seg_metrics.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/seg_metrics.py" new file mode 100644 index 0000000000000000000000000000000000000000..f9d6e03d39486c51007cd06e6f2d887ef8ac49f5 --- /dev/null +++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/seg_metrics.py" @@ -0,0 +1,123 @@ +import numpy as np +from skimage.metrics import contingency_table + + +def precision(tp, fp, fn): + return tp / (tp + fp) if tp > 0 else 0 + + +def recall(tp, fp, fn): + return tp / (tp + fn) if tp > 0 else 0 + + +def accuracy(tp, fp, fn): + return tp / (tp + fp + fn) if tp > 0 else 0 + + +def f1(tp, fp, fn): + return (2 * tp) / (2 * tp + fp + fn) if tp > 0 else 0 + + +def _relabel(input): + _, unique_labels = np.unique(input, return_inverse=True) + return unique_labels.reshape(input.shape) + + +def _iou_matrix(gt, seg): + # relabel gt and seg for smaller memory footprint of contingency table + gt = _relabel(gt) + seg = _relabel(seg) + + # get number of overlapping pixels between GT and SEG + n_inter = contingency_table(gt, seg).A + + # number of pixels for GT instances + n_gt = n_inter.sum(axis=1, keepdims=True) + # number of pixels for SEG instances + n_seg = n_inter.sum(axis=0, keepdims=True) + + # number of pixels in the union between GT and SEG instances + n_union = n_gt + n_seg - n_inter + + iou_matrix = n_inter / n_union + # make sure that the values are within [0,1] range + assert 0 <= np.min(iou_matrix) <= np.max(iou_matrix) <= 1 + + return iou_matrix + + +class SegmentationMetrics: + """ + Computes precision, recall, accuracy, f1 score for a given ground truth and predicted segmentation. + Contingency table for a given ground truth and predicted segmentation is computed eagerly upon construction + of the instance of `SegmentationMetrics`. 
+
+    Args:
+        gt (ndarray): ground truth segmentation
+        seg (ndarray): predicted segmentation
+    """
+
+    def __init__(self, gt, seg):
+        self.iou_matrix = _iou_matrix(gt, seg)
+
+    def metrics(self, iou_threshold):
+        """
+        Computes precision, recall, accuracy and f1 score at a given IoU threshold
+        """
+        # ignore background
+        iou_matrix = self.iou_matrix[1:, 1:]
+        detection_matrix = (iou_matrix > iou_threshold).astype(np.uint8)
+        n_gt, n_seg = detection_matrix.shape
+
+        # if the iou_matrix is empty or all values are 0
+        trivial = min(n_gt, n_seg) == 0 or np.all(detection_matrix == 0)
+        if trivial:
+            tp = fp = fn = 0
+        else:
+            # count non-zero rows to get the number of TP
+            tp = np.count_nonzero(detection_matrix.sum(axis=1))
+            # count zero rows to get the number of FN
+            fn = n_gt - tp
+            # count zero columns to get the number of FP
+            fp = n_seg - np.count_nonzero(detection_matrix.sum(axis=0))
+
+        return {
+            'precision': precision(tp, fp, fn),
+            'recall': recall(tp, fp, fn),
+            'accuracy': accuracy(tp, fp, fn),
+            'f1': f1(tp, fp, fn)
+        }
+
+
+class Accuracy:
+    """
+    Computes accuracy between ground truth and predicted segmentation at a given IoU threshold value.
+    Defined as: AC = TP / (TP + FP + FN).
+    Kaggle DSB2018 calls it Precision, see:
+    https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric.
+    """
+
+    def __init__(self, iou_threshold):
+        self.iou_threshold = iou_threshold
+
+    def __call__(self, input_seg, gt_seg):
+        metrics = SegmentationMetrics(gt_seg, input_seg).metrics(self.iou_threshold)
+        return metrics['accuracy']
+
+
+class AveragePrecision:
+    """
+    Average precision taken over the IoU range (0.5, 0.95) with a step of 0.05 as defined in:
+    https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric
+    """
+
+    def __init__(self):
+        self.iou_range = np.linspace(0.50, 0.95, 10)
+
+    def __call__(self, input_seg, gt_seg):
+        # compute the IoU matrix (contingency table) once and reuse it for every threshold
+        sm = SegmentationMetrics(gt_seg, input_seg)
+        # compute accuracy for each threshold
+        acc = [sm.metrics(iou)['accuracy'] for iou in self.iou_range]
+        # return the average
+        return np.mean(acc)
diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/test.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/test.py"
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/train.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/train.py"
new file mode 100644
index 0000000000000000000000000000000000000000..26863864d8e6982d9275fc9a7415de98c0e52f3d
--- /dev/null
+++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/train.py"
@@ -0,0 +1,572 @@
+import os
+import torch
+import torch.nn as nn
+from tensorboardX import SummaryWriter
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from data.datautils import get_train_loaders
+from loss import get_loss_criterion
+from metrics import get_evaluation_metric
+from model.get_model import get_model
+from utils import get_logger, get_tensorboard_formatter, create_sample_plotter, create_optimizer, \
+    create_lr_scheduler, get_number_of_learnable_parameters
+import utils
+
+logger = get_logger('UNet3DTrainer')
+
+
+def _create_trainer(config, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders):
+    """
+    Besides manual_seed and the device entry produced by load_config, training_config holds
+    seven parameter dicts: model, trainer, optimizer, loss, eval_metric, lr_scheduler and loaders.
+    The model dict is used when the network is built; optimizer and lr_scheduler when the
+    optimizer and LR scheduler are created; loss and eval_metric when the loss function and
+    validation metric are chosen; and loaders when the dataset loaders are constructed.
+
+    The remaining training settings live in the trainer dict. Four further values are derived
+    from it here - resume, pre_trained, tensorboard_formatter and sample_plotter - each
+    defaulting to None.
+
+    The function returns
+        UNet3DTrainer(model=model,  # from get_model(config['model'])
+                      optimizer=optimizer,
+                      lr_scheduler=lr_scheduler,
+                      loss_criterion=loss_criterion,
+                      eval_criterion=eval_criterion,
+                      device=config['device'],
+                      loaders=loaders,
+                      tensorboard_formatter=tensorboard_formatter,
+                      sample_plotter=sample_plotter,
+                      **trainer_config)
+
+    tensorboard_formatter: since the trainer dict carries no related settings, a
+    DefaultTensorboardFormatter instance is returned.
+    sample_plotter: None.
+    """
+    assert 'trainer' in config, 'Could not find trainer configuration'
+    trainer_config = config['trainer']
+
+    resume = trainer_config.get('resume', None)
+    pre_trained = trainer_config.get('pre_trained', None)
+
+    # get tensorboard formatter
+    tensorboard_formatter = get_tensorboard_formatter(trainer_config.pop('tensorboard_formatter', None))
+    # get sample plotter
+    sample_plotter = create_sample_plotter(trainer_config.pop('sample_plotter', None))
+
+    if resume is not None:
+        # continue training from a given checkpoint
+        return UNet3DTrainer.from_checkpoint(model=model,
+                                             optimizer=optimizer,
+                                             lr_scheduler=lr_scheduler,
+                                             loss_criterion=loss_criterion,
+                                             eval_criterion=eval_criterion,
+                                             loaders=loaders,
+                                             tensorboard_formatter=tensorboard_formatter,
+                                             sample_plotter=sample_plotter,
+                                             **trainer_config)
+    elif pre_trained is not None:
+        # fine-tune a given pre-trained model
+        return UNet3DTrainer.from_pretrained(model=model,
+                                             optimizer=optimizer,
+                                             lr_scheduler=lr_scheduler,
+                                             loss_criterion=loss_criterion,
+                                             eval_criterion=eval_criterion,
+                                             tensorboard_formatter=tensorboard_formatter,
+                                             sample_plotter=sample_plotter,
+                                             device=config['device'],
+                                             loaders=loaders,
+                                             **trainer_config)
+    else:
+        # start training from scratch
+        return UNet3DTrainer(model=model,
+                             optimizer=optimizer,
+                             lr_scheduler=lr_scheduler,
+                             loss_criterion=loss_criterion,
+                             eval_criterion=eval_criterion,
+                             device=config['device'],
+                             loaders=loaders,
+                             tensorboard_formatter=tensorboard_formatter,
+                             sample_plotter=sample_plotter,
+                             **trainer_config)
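+
+# Note: resume and pre_trained are read with .get() rather than .pop(), so they remain in
+# trainer_config and also reach from_checkpoint()/from_pretrained() through **trainer_config;
+# tensorboard_formatter and sample_plotter are popped because they are passed explicitly.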
+
+
+class UNet3DTrainerBuilder:
+    """Referenced from main(): the config is handed to build(), and config['model'] is passed
+    to model.get_model() to construct the 3D U-Net.
+    get_model() forwards the model parameters to the UNet3D class, which builds the network by
+    subclassing Abstract3DUNet with in_channels, out_channels, final_sigmoid=True, f_maps=32,
+    layer_order='gcr', num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, plus
+    the defaults basic_module=DoubleConv, testing=False, conv_kernel_size=3, pool_kernel_size=2.
+    device = torch.device('cuda:0'); config['device'] = device
+
+    Loss criterion: DiceLoss
+    Metric: MeanIoU
+    Data loaders: the loaders section of config_training also holds some train/validation set
+    parameters whose exact purpose is not yet clear
+    Optimizer: torch.optim.Adam()
+    LR scheduler: torch.optim.lr_scheduler.MultiStepLR() - the optimizer is inserted into
+    lr_config as its optimizer entry, and lr_config is then passed to MultiStepLR
+
+    With the parameters and model defined above, _create_trainer is called:
+        _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
+                        loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders)
+    """
+
+    @staticmethod
+    def build(config):
+        # Create the model
+        model = get_model(config['model'])
+        # use DataParallel if more than 1 GPU available
+        device = config['device']
+        if torch.cuda.device_count() > 1 and not device.type == 'cpu':
+            model = nn.DataParallel(model)
+            logger.info(f'Using {torch.cuda.device_count()} GPUs for training')
+
+        # put the model on GPUs
+        logger.info(f"Sending the model to '{config['device']}'")
+        model = model.to(device)
+
+        # Log the number of learnable parameters
+        logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')
+
+        # Create loss criterion
+        loss_criterion = get_loss_criterion(config)
+        # Create evaluation metric
+        eval_criterion = get_evaluation_metric(config)
+
+        # Create data loaders
+        loaders = get_train_loaders(config)
+
+        # Create the optimizer
+        optimizer = create_optimizer(config['optimizer'], model)
+
+        # Create learning rate adjustment strategy
+        lr_scheduler = create_lr_scheduler(config.get('lr_scheduler', None), optimizer)
+
+        # Create model trainer
+        trainer = _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
+                                  loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders)
+
+        return trainer
+
+
+class UNet3DTrainer:
+    """3D UNet trainer.
+
+    Args:
+        model (Unet3D): UNet 3D model to be trained
+        optimizer (nn.optim.Optimizer): optimizer used for training
+        lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler
+            WARN: bear in mind that lr_scheduler.step() is invoked after every validation step
+            (i.e. validate_after_iters) not after every epoch. So e.g. if one uses StepLR with step_size=30
+            the learning rate will be adjusted after every 30 * validate_after_iters iterations.
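+            For example, with StepLR(step_size=30) and validate_after_iters=20 the learning
+            rate would first change only after 30 * 20 = 600 training iterations.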
+        loss_criterion (callable): loss function
+        eval_criterion (callable): used to compute training/validation metric (such as Dice, IoU, AP or Rand score)
+            saving the best checkpoint is based on the result of this function on the validation set
+        device (torch.device): device to train on
+        loaders (dict): 'train' and 'val' loaders
+        checkpoint_dir (string): dir for saving checkpoints and tensorboard logs
+        max_num_epochs (int): maximum number of epochs
+        max_num_iterations (int): maximum number of iterations
+        validate_after_iters (int): validate after that many iterations
+        log_after_iters (int): number of iterations before logging to tensorboard
+        validate_iters (int): number of validation iterations, if None validate
+            on the whole validation set
+        eval_score_higher_is_better (bool): if True higher eval scores are considered better
+        best_eval_score (float): best validation score so far (higher better)
+        num_iterations (int): useful when loading the model from the checkpoint
+        num_epoch (int): useful when loading the model from the checkpoint
+        tensorboard_formatter (callable): converts a given batch of input/output/target image to a series of images
+            that can be displayed in tensorboard
+        sample_plotter (callable): saves sample inputs, network outputs and targets to a given directory
+            during validation phase
+        skip_train_validation (bool): if True eval_criterion is not evaluated on the training set (used mostly when
+            evaluation is expensive)
+    """
+
+    def __init__(self, model, optimizer, lr_scheduler, loss_criterion,
+                 eval_criterion, device, loaders, checkpoint_dir,
+                 max_num_epochs=100, max_num_iterations=int(1e5),
+                 validate_after_iters=100, log_after_iters=100,
+                 validate_iters=None, num_iterations=1, num_epoch=0,
+                 eval_score_higher_is_better=True, best_eval_score=None,
+                 tensorboard_formatter=None, sample_plotter=None,
+                 skip_train_validation=False, **kwargs):
+        """
+        Initializes the trainer instance returned by _create_trainer (see its docstring for the
+        exact call). With config_training.yaml the effective values are:
+        :param max_num_epochs: 50
+        :param max_num_iterations: 100000
+        :param validate_after_iters: 20
+        :param log_after_iters: 20
+        :param validate_iters: None
+        :param num_iterations: 1
+        :param num_epoch: 0
+        :param eval_score_higher_is_better: True
+        :param best_eval_score: None
+        :param tensorboard_formatter: None
+        :param sample_plotter: None
+        :param skip_train_validation: False
+        :param kwargs: ***
+        Depending on eval_score_higher_is_better, best_eval_score is initialized to -inf or +inf.
+        """
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = lr_scheduler
+        self.loss_criterion = loss_criterion
+        self.eval_criterion = eval_criterion
+        self.device = device
+        self.loaders = loaders
+        self.checkpoint_dir = checkpoint_dir
+        self.max_num_epochs = max_num_epochs
+        self.max_num_iterations = max_num_iterations
+        self.validate_after_iters = validate_after_iters
+        self.log_after_iters = log_after_iters
+        self.validate_iters = validate_iters
+        self.eval_score_higher_is_better = eval_score_higher_is_better
+
+        logger.info(model)
+        logger.info(f'eval_score_higher_is_better: {eval_score_higher_is_better}')
+
+        if best_eval_score is not None:
+            self.best_eval_score = best_eval_score
+        else:
+            # initialize the best_eval_score
+            if eval_score_higher_is_better:
+                self.best_eval_score = float('-inf')
+            else:
+                self.best_eval_score = float('+inf')
+
+        self.writer = SummaryWriter(log_dir=os.path.join(checkpoint_dir, 'logs'))
+
+        assert tensorboard_formatter is not None, 'TensorboardFormatter must be provided'
+        self.tensorboard_formatter = tensorboard_formatter
+        self.sample_plotter = sample_plotter
+
+        self.num_iterations = num_iterations
+        self.num_epoch = num_epoch
+        self.skip_train_validation = skip_train_validation
+
+    @classmethod
+    def from_checkpoint(cls, resume, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders,
+                        tensorboard_formatter=None, sample_plotter=None, **kwargs):
+        logger.info(f"Loading checkpoint '{resume}'...")
+        state = utils.load_checkpoint(resume, model, optimizer)
+        logger.info(
+            f"Checkpoint loaded. Epoch: {state['epoch']}. Best val score: {state['best_eval_score']}. Num_iterations: {state['num_iterations']}")
+        checkpoint_dir = os.path.split(resume)[0]
+        return cls(model, optimizer, lr_scheduler,
+                   loss_criterion, eval_criterion,
+                   torch.device(state['device']),
+                   loaders, checkpoint_dir,
+                   eval_score_higher_is_better=state['eval_score_higher_is_better'],
+                   best_eval_score=state['best_eval_score'],
+                   num_iterations=state['num_iterations'],
+                   num_epoch=state['epoch'],
+                   max_num_epochs=state['max_num_epochs'],
+                   max_num_iterations=state['max_num_iterations'],
+                   validate_after_iters=state['validate_after_iters'],
+                   log_after_iters=state['log_after_iters'],
+                   validate_iters=state['validate_iters'],
+                   skip_train_validation=state.get('skip_train_validation', False),
+                   tensorboard_formatter=tensorboard_formatter,
+                   sample_plotter=sample_plotter)
+
+    @classmethod
+    def from_pretrained(cls, pre_trained, model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
+                        device, loaders,
+                        max_num_epochs=100, max_num_iterations=int(1e5),
+                        validate_after_iters=100, log_after_iters=100,
+                        validate_iters=None, num_iterations=1, num_epoch=0,
+                        eval_score_higher_is_better=True, best_eval_score=None,
+                        tensorboard_formatter=None, sample_plotter=None,
+                        skip_train_validation=False, **kwargs):
+        logger.info(f"Loading pre-trained model from '{pre_trained}'...")
+        utils.load_checkpoint(pre_trained, model, None)
+        if 'checkpoint_dir' not in kwargs:
+            checkpoint_dir = os.path.split(pre_trained)[0]
+        else:
+            checkpoint_dir = kwargs.pop('checkpoint_dir')
+        return cls(model, optimizer, lr_scheduler,
+                   loss_criterion, eval_criterion,
+                   device, loaders, checkpoint_dir,
+                   eval_score_higher_is_better=eval_score_higher_is_better,
+                   best_eval_score=best_eval_score,
+                   num_iterations=num_iterations,
+                   num_epoch=num_epoch,
+                   max_num_epochs=max_num_epochs,
+                   max_num_iterations=max_num_iterations,
+                   validate_after_iters=validate_after_iters,
+                   log_after_iters=log_after_iters,
+                   validate_iters=validate_iters,
+                   tensorboard_formatter=tensorboard_formatter,
+                   sample_plotter=sample_plotter,
+                   skip_train_validation=skip_train_validation)
+
+    def fit(self):
+        """
+        main() builds the trainer through UNet3DTrainerBuilder.build() (which fixes all the
+        parameters) and then calls fit() on the returned instance.
+        """
+        for _ in range(self.num_epoch, self.max_num_epochs):
+            # train for one epoch
+            should_terminate = self.train()
+
+            if should_terminate:
+                logger.info('Stopping criterion is satisfied. Finishing training')
+                return
+
+            self.num_epoch += 1
+        logger.info(f"Reached maximum number of epochs: {self.max_num_epochs}. Finishing training...")
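+
+    # train() returns True once should_stop() fires (max_num_iterations exceeded or the
+    # learning rate drops below 1e-6), which lets fit() terminate before max_num_epochs.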
+
+    def train(self):
+        """Trains the model for 1 epoch.
+
+        Defines the loss and the evaluation score, both instances of the RunningAverage class.
+        Iterating t over loaders['train'] amounts to iterating over the iterable
+        DataLoader(ConcatDataset(train_datasets), batch_size=batch_size, shuffle=True,
+        num_workers=num_workers), with each batch split by _split_training_batch().
+        (An open question of the author: the data come through transforms.Compose(...), yet the
+        intent is data augmentation - has data augmentation been misunderstood here?)
+        Each step takes one batch, extracts input and target, and produces the output and the
+        loss from the forward pass.
+
+        Returns:
+            True if the training should be terminated immediately, False otherwise
+        """
+        train_losses = utils.RunningAverage()
+        train_eval_scores = utils.RunningAverage()
+
+        # sets the model in training mode
+        self.model.train()
+
+        for t in self.loaders['train']:
+            logger.info(f'Training iteration [{self.num_iterations}/{self.max_num_iterations}]. '
+                        f'Epoch [{self.num_epoch}/{self.max_num_epochs - 1}]')
+
+            input, target, weight = self._split_training_batch(t)
+
+            output, loss = self._forward_pass(input, target, weight)
+
+            train_losses.update(loss.item(), self._batch_size(input))
+
+            # compute gradients and update parameters
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+            if self.num_iterations % self.validate_after_iters == 0:
+                # set the model in eval mode
+                self.model.eval()
+                # evaluate on validation set
+                eval_score = self.validate()
+                # set the model back to training mode
+                self.model.train()
+
+                # adjust learning rate if necessary
+                if isinstance(self.scheduler, ReduceLROnPlateau):
+                    self.scheduler.step(eval_score)
+                else:
+                    self.scheduler.step()
+                # log current learning rate in tensorboard
+                self._log_lr()
+                # remember best validation metric
+                is_best = self._is_best_eval_score(eval_score)
+
+                # save checkpoint
+                self._save_checkpoint(is_best)
+
+            if self.num_iterations % self.log_after_iters == 0:
+                # if model contains final_activation layer for normalizing logits apply it, otherwise both
+                # the evaluation metric as well as images in tensorboard will be incorrectly computed
+                if hasattr(self.model, 'final_activation') and self.model.final_activation is not None:
+                    output = self.model.final_activation(output)
+
+                # compute eval criterion
+                if not self.skip_train_validation:
+                    eval_score = self.eval_criterion(output, target)
+                    train_eval_scores.update(eval_score.item(), self._batch_size(input))
+
+                # log stats, params and images
+                logger.info(
+                    f'Training stats. Loss: {train_losses.avg}. Evaluation score: {train_eval_scores.avg}')
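+                # mirror the console stats to tensorboard, together with parameter histograms and image slices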
+                self._log_stats('train', train_losses.avg, train_eval_scores.avg)
+                self._log_params()
+                self._log_images(input, target, output, 'train_')
+
+            if self.should_stop():
+                return True
+
+            self.num_iterations += 1
+
+        return False
+
+    def should_stop(self):
+        """
+        Training will terminate if the maximum number of iterations is exceeded or the learning rate drops below
+        some predefined threshold (1e-6 in our case)
+        """
+        if self.max_num_iterations < self.num_iterations:
+            logger.info(f'Maximum number of iterations {self.max_num_iterations} exceeded.')
+            return True
+
+        min_lr = 1e-6
+        lr = self.optimizer.param_groups[0]['lr']
+        if lr < min_lr:
+            logger.info(f'Learning rate below the minimum {min_lr}.')
+            return True
+
+        return False
+
+    def validate(self):
+        logger.info('Validating...')
+
+        val_losses = utils.RunningAverage()
+        val_scores = utils.RunningAverage()
+
+        if self.sample_plotter is not None:
+            self.sample_plotter.update_current_dir()
+
+        with torch.no_grad():
+            for i, t in enumerate(self.loaders['val']):
+                logger.info(f'Validation iteration {i}')
+
+                input, target, weight = self._split_training_batch(t)
+
+                output, loss = self._forward_pass(input, target, weight)
+                val_losses.update(loss.item(), self._batch_size(input))
+
+                # if model contains final_activation layer for normalizing logits apply it, otherwise
+                # the evaluation metric will be incorrectly computed
+                if hasattr(self.model, 'final_activation') and self.model.final_activation is not None:
+                    output = self.model.final_activation(output)
+
+                if i % 100 == 0:
+                    self._log_images(input, target, output, 'val_')
+
+                eval_score = self.eval_criterion(output, target)
+                val_scores.update(eval_score.item(), self._batch_size(input))
+
+                if self.sample_plotter is not None:
+                    self.sample_plotter(i, input, output, target, 'val')
+
+                if self.validate_iters is not None and self.validate_iters <= i:
+                    # stop validation
+                    break
+
+            self._log_stats('val', val_losses.avg, val_scores.avg)
+            logger.info(f'Validation finished. Loss: {val_losses.avg}. Evaluation score: {val_scores.avg}')
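+            # train() compares this running average against best_eval_score to decide checkpointing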
+            return val_scores.avg
+
+    def _split_training_batch(self, t):
+        def _move_to_device(input):
+            if isinstance(input, tuple) or isinstance(input, list):
+                return tuple([_move_to_device(x) for x in input])
+            else:
+                return input.to(self.device)
+
+        t = _move_to_device(t)
+        weight = None
+        if len(t) == 2:
+            input, target = t
+        else:
+            input, target, weight = t
+        return input, target, weight
+
+    def _forward_pass(self, input, target, weight=None):
+        # forward pass
+        output = self.model(input)
+
+        # compute the loss
+        if weight is None:
+            loss = self.loss_criterion(output, target)
+        else:
+            loss = self.loss_criterion(output, target, weight)
+
+        return output, loss
+
+    def _is_best_eval_score(self, eval_score):
+        if self.eval_score_higher_is_better:
+            is_best = eval_score > self.best_eval_score
+        else:
+            is_best = eval_score < self.best_eval_score
+
+        if is_best:
+            logger.info(f'Saving new best evaluation metric: {eval_score}')
+            self.best_eval_score = eval_score
+
+        return is_best
+
+    def _save_checkpoint(self, is_best):
+        # remove `module` prefix from layer names when using `nn.DataParallel`
+        # see: https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/20
+        if isinstance(self.model, nn.DataParallel):
+            state_dict = self.model.module.state_dict()
+        else:
+            state_dict = self.model.state_dict()
+
+        utils.save_checkpoint({
+            'epoch': self.num_epoch + 1,
+            'num_iterations': self.num_iterations,
+            'model_state_dict': state_dict,
+            'best_eval_score': self.best_eval_score,
+            'eval_score_higher_is_better': self.eval_score_higher_is_better,
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'device': str(self.device),
+            'max_num_epochs': self.max_num_epochs,
+            'max_num_iterations': self.max_num_iterations,
+            'validate_after_iters': self.validate_after_iters,
+            'log_after_iters': self.log_after_iters,
+            'validate_iters': self.validate_iters,
+            'skip_train_validation': self.skip_train_validation
+        }, is_best, checkpoint_dir=self.checkpoint_dir,
+            logger=logger)
+
+    def _log_lr(self):
+        lr = self.optimizer.param_groups[0]['lr']
+        self.writer.add_scalar('learning_rate', lr, self.num_iterations)
+
+    def _log_stats(self, phase, loss_avg, eval_score_avg):
+        tag_value = {
+            f'{phase}_loss_avg': loss_avg,
+            f'{phase}_eval_score_avg': eval_score_avg
+        }
+
+        for tag, value in tag_value.items():
+            self.writer.add_scalar(tag, value, self.num_iterations)
+
+    def _log_params(self):
+        logger.info('Logging model parameters and gradients')
+        for name, value in self.model.named_parameters():
+            self.writer.add_histogram(name, value.data.cpu().numpy(), self.num_iterations)
+            self.writer.add_histogram(name + '/grad', value.grad.data.cpu().numpy(), self.num_iterations)
+
+    def _log_images(self, input, target, prediction, prefix=''):
+        inputs_map = {
+            'inputs': input,
+            'targets': target,
+            'predictions': prediction
+        }
+        img_sources = {}
+        for name, batch in inputs_map.items():
+            if isinstance(batch, list) or isinstance(batch, tuple):
+                for i, b in enumerate(batch):
+                    img_sources[f'{name}{i}'] = b.data.cpu().numpy()
+            else:
+                img_sources[name] = batch.data.cpu().numpy()
+
+        for name, batch in img_sources.items():
+            for tag, image in self.tensorboard_formatter(name, batch):
+                self.writer.add_image(prefix + tag, image, self.num_iterations, dataformats='CHW')
+
+    @staticmethod
+    def _batch_size(input):
+        if isinstance(input, list) or isinstance(input, tuple):
+            return input[0].size(0)
+        else:
+            return input.size(0)
diff --git "a/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/utils.py" "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/utils.py"
new file mode 100644
index 0000000000000000000000000000000000000000..d5f714b5e2fe1e036e6fc69bb82a94e978486050
--- /dev/null
+++ "b/code/2021_autumn/\346\235\250\345\277\227\346\257\205-\345\214\273\345\255\246\345\233\276\345\203\217\345\244\204\347\220\206-3DUNET/utils.py"
@@ -0,0 +1,387 @@
+import importlib
+import logging
+import os
+import shutil
+import sys
+
+import h5py
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from torch import optim
+
+plt.ioff()
+plt.switch_backend('agg')
+
+
+def save_checkpoint(state, is_best, checkpoint_dir, logger=None):
+    """Saves model and training parameters at '{checkpoint_dir}/last_checkpoint.pytorch'.
+    If is_best==True saves '{checkpoint_dir}/best_checkpoint.pytorch' as well.
+
+    Args:
+        state (dict): contains model's state_dict, optimizer's state_dict, epoch
+            and best evaluation metric value so far
+        is_best (bool): if True state contains the best model seen so far
+        checkpoint_dir (string): directory where the checkpoints are to be saved
+    """
+
+    def log_info(message):
+        if logger is not None:
+            logger.info(message)
+
+    if not os.path.exists(checkpoint_dir):
+        log_info(
+            f"Checkpoint directory does not exist. Creating {checkpoint_dir}")
+        os.mkdir(checkpoint_dir)
+
+    last_file_path = os.path.join(checkpoint_dir, 'last_checkpoint.pytorch')
+    log_info(f"Saving last checkpoint to '{last_file_path}'")
+    torch.save(state, last_file_path)
+    if is_best:
+        best_file_path = os.path.join(checkpoint_dir, 'best_checkpoint.pytorch')
+        log_info(f"Saving best checkpoint to '{best_file_path}'")
+        shutil.copyfile(last_file_path, best_file_path)
+
+
+def load_checkpoint(checkpoint_path, model, optimizer=None,
+                    model_key='model_state_dict', optimizer_key='optimizer_state_dict'):
+    """Loads model and training parameters from a given checkpoint_path.
+    If optimizer is provided, loads the optimizer's state_dict as well.
+
+    Args:
+        checkpoint_path (string): path to the checkpoint to be loaded
+        model (torch.nn.Module): model into which the parameters are to be copied
+        optimizer (torch.optim.Optimizer, optional): optimizer instance into
+            which the parameters are to be copied
+
+    Returns:
+        state
+    """
+    if not os.path.exists(checkpoint_path):
+        raise IOError(f"Checkpoint '{checkpoint_path}' does not exist")
+
+    state = torch.load(checkpoint_path, map_location='cpu')
+    model.load_state_dict(state[model_key])
+
+    if optimizer is not None:
+        optimizer.load_state_dict(state[optimizer_key])
+
+    return state
+
+
+def save_network_output(output_path, output, logger=None):
+    if logger is not None:
+        logger.info(f'Saving network output to: {output_path}...')
+    output = output.detach().cpu()[0]
+    with h5py.File(output_path, 'w') as f:
+        f.create_dataset('predictions', data=output, compression='gzip')
+
+
+loggers = {}
+
+
+def get_logger(name, level=logging.INFO):
+    """Returns a logger with the given name; used throughout the code base.
+    Created loggers are cached in the module-level `loggers` dict, so repeated
+    calls with the same name return the same instance.
+    """
+    global loggers
+    if loggers.get(name) is not None:
+        return loggers[name]
+    else:
+        logger = logging.getLogger(name)
+        logger.setLevel(level)
+        # Logging to console
+        stream_handler = logging.StreamHandler(sys.stdout)
+        formatter = logging.Formatter(
+            '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s')
+        stream_handler.setFormatter(formatter)
+        logger.addHandler(stream_handler)
+
+        loggers[name] = logger
+
+        return logger
+
+
+def get_number_of_learnable_parameters(model):
+    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+    return sum([np.prod(p.size()) for p in model_parameters])
+
+
+class RunningAverage:
+    """Computes and stores the running average of a quantity
+    """
+
+    def __init__(self):
+        self.count = 0
+        self.sum = 0
+        self.avg = 0
+
+    def update(self, value, n=1):
+        self.count += n
+        self.sum += value * n
+        self.avg = self.sum / self.count
+
+
+def find_maximum_patch_size(model, device):
+    """Tries to find the biggest patch size that can be sent to the GPU for inference
+    without throwing CUDA out of memory"""
+    logger = get_logger('PatchFinder')
+    in_channels = model.in_channels
+
+    patch_shapes = [(64, 128, 128), (96, 128, 128),
+                    (64, 160, 160), (96, 160, 160),
+                    (64, 192, 192), (96, 192, 192)]
+
+    for shape in patch_shapes:
+        # generate random patch of a given size
+        patch = np.random.randn(*shape).astype('float32')
+
+        patch = torch \
+            .from_numpy(patch) \
+            .view((1, in_channels) + patch.shape) \
+            .to(device)
+
+        logger.info(f"Current patch size: {shape}")
+        model(patch)
+
+
+def remove_halo(patch, index, shape, patch_halo):
+    """
+    Remove the halo (`patch_halo`) voxels around the edges of a given patch.
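+    Faces that touch the volume border are kept intact, so only interior faces
+    are trimmed; e.g. with patch_halo=(4, 8, 8) a fully interior patch loses 4
+    voxels on each z face and 8 on each y/x face (values illustrative).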
+    """
+    assert len(patch_halo) == 3
+
+    def _new_slices(slicing, max_size, pad):
+        if slicing.start == 0:
+            p_start = 0
+            i_start = 0
+        else:
+            p_start = pad
+            i_start = slicing.start + pad
+
+        if slicing.stop == max_size:
+            p_stop = None
+            i_stop = max_size
+        else:
+            p_stop = -pad if pad != 0 else 1
+            i_stop = slicing.stop - pad
+
+        return slice(p_start, p_stop), slice(i_start, i_stop)
+
+    D, H, W = shape
+
+    i_c, i_z, i_y, i_x = index
+    p_c = slice(0, patch.shape[0])
+
+    p_z, i_z = _new_slices(i_z, D, patch_halo[0])
+    p_y, i_y = _new_slices(i_y, H, patch_halo[1])
+    p_x, i_x = _new_slices(i_x, W, patch_halo[2])
+
+    patch_index = (p_c, p_z, p_y, p_x)
+    index = (i_c, i_z, i_y, i_x)
+    return patch[patch_index], index
+
+
+class _TensorboardFormatter:
+    """
+    A tensorboard formatter converts a given batch of images (be it input/output to the network or the target
+    segmentation image) to a series of images that can be displayed in tensorboard. This is the parent class for
+    all tensorboard formatters, and it ensures that returned images are in the 'CHW' format.
+    """
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, name, batch):
+        """
+        Transform a batch to a series of tuples of the form (tag, img), where `tag` corresponds to the image tag
+        and `img` is the image itself.
+
+        Args:
+            name (str): one of 'inputs'/'targets'/'predictions'
+            batch (torch.tensor): 4D or 5D torch tensor
+        """
+
+        def _check_img(tag_img):
+            tag, img = tag_img
+
+            assert img.ndim == 2 or img.ndim == 3, 'Only 2D (HW) and 3D (CHW) images are accepted for display'
+
+            if img.ndim == 2:
+                img = np.expand_dims(img, axis=0)
+            else:
+                C = img.shape[0]
+                assert C == 1 or C == 3, 'Only (1, H, W) or (3, H, W) images are supported'
+
+            return tag, img
+
+        tagged_images = self.process_batch(name, batch)
+
+        return list(map(_check_img, tagged_images))
+
+    def process_batch(self, name, batch):
+        raise NotImplementedError
+
+
+class DefaultTensorboardFormatter(_TensorboardFormatter):
+    def __init__(self, skip_last_target=False, **kwargs):
+        super().__init__(**kwargs)
+        self.skip_last_target = skip_last_target
+
+    def process_batch(self, name, batch):
+        if name == 'targets' and self.skip_last_target:
+            batch = batch[:, :-1, ...]
+
+        tag_template = '{}/batch_{}/channel_{}/slice_{}'
+
+        tagged_images = []
+
+        if batch.ndim == 5:
+            # NCDHW
+            slice_idx = batch.shape[2] // 2  # get the middle slice
+            for batch_idx in range(batch.shape[0]):
+                for channel_idx in range(batch.shape[1]):
+                    tag = tag_template.format(name, batch_idx, channel_idx, slice_idx)
+                    img = batch[batch_idx, channel_idx, slice_idx, ...]
+                    tagged_images.append((tag, self._normalize_img(img)))
+        else:
+            # batch has no channel dim: NDHW
+            slice_idx = batch.shape[1] // 2  # get the middle slice
+            for batch_idx in range(batch.shape[0]):
+                tag = tag_template.format(name, batch_idx, 0, slice_idx)
+                img = batch[batch_idx, slice_idx, ...]
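+                # middle z-slice of a channel-less (NDHW) entry; _normalize_img rescales it to [0, 1]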
+                tagged_images.append((tag, self._normalize_img(img)))
+
+        return tagged_images
+
+    @staticmethod
+    def _normalize_img(img):
+        return np.nan_to_num((img - np.min(img)) / np.ptp(img))
+
+
+def _find_masks(batch, min_size=10):
+    """Picks, for each instance in the batch, the z-slice through the 'middle' of the instance
+
+    Args:
+        batch (ndarray): 5d numpy tensor (NCDHW)
+    """
+    result = []
+    for b in batch:
+        assert b.shape[0] == 1
+        patch = b[0]
+        z_sum = patch.sum(axis=(1, 2))
+        coords = np.where(z_sum > min_size)[0]
+        if len(coords) > 0:
+            ind = coords[len(coords) // 2]
+            result.append(b[:, ind:ind + 1, ...])
+        else:
+            ind = b.shape[1] // 2
+            result.append(b[:, ind:ind + 1, ...])
+
+    return np.stack(result, axis=0)
+
+
+def get_tensorboard_formatter(config):
+    if config is None:
+        return DefaultTensorboardFormatter()
+
+    class_name = config['name']
+    m = importlib.import_module('utils')
+    clazz = getattr(m, class_name)
+    return clazz(**config)
+
+
+def expand_as_one_hot(input, C, ignore_index=None):
+    """
+    Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
+    It is assumed that the batch dimension is present.
+    Args:
+        input (torch.Tensor): 3D/4D input image
+        C (int): number of channels/labels
+        ignore_index (int): ignore index to be kept during the expansion
+    Returns:
+        4D/5D output torch.Tensor (NxCxSPATIAL)
+    """
+    assert input.dim() == 4
+
+    # expand the input tensor to Nx1xSPATIAL before scattering
+    input = input.unsqueeze(1)
+    # create output tensor shape (NxCxSPATIAL)
+    shape = list(input.size())
+    shape[1] = C
+
+    if ignore_index is not None:
+        # create ignore_index mask for the result
+        mask = input.expand(shape) == ignore_index
+        # clone the src tensor and zero out ignore_index in the input
+        input = input.clone()
+        input[input == ignore_index] = 0
+        # scatter to get the one-hot tensor
+        result = torch.zeros(shape).to(input.device).scatter_(1, input, 1)
+        # bring back the ignore_index in the result
+        result[mask] = ignore_index
+        return result
+    else:
+        # scatter to get the one-hot tensor
+        return torch.zeros(shape).to(input.device).scatter_(1, input, 1)
+
+
+def convert_to_numpy(*inputs):
+    """
+    Converts input tensors to numpy ndarrays
+
+    Args:
+        inputs (iterable of torch.Tensor): torch tensors
+
+    Returns:
+        tuple of ndarrays
+    """
+
+    def _to_numpy(i):
+        assert isinstance(i, torch.Tensor), "Expected input to be torch.Tensor"
+        return i.detach().cpu().numpy()
+
+    return (_to_numpy(i) for i in inputs)
+
+
+def create_optimizer(optimizer_config, model):
+    """
+    Builds an Adam optimizer from the `optimizer` section of the config, e.g.:
+        learning_rate: 0.0002
+        weight_decay: 0.0001
+    """
+    learning_rate = optimizer_config['learning_rate']
+    weight_decay = optimizer_config.get('weight_decay', 0)
+    betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=betas, weight_decay=weight_decay)
+    return optimizer
+
+
+def create_lr_scheduler(lr_config, optimizer):
+    """
+    Builds an LR scheduler from the `lr_scheduler` section of the config, e.g.:
+        lr_scheduler:
+          name: MultiStepLR
+          milestones: [10, 30, 60]
+          gamma: 0.2
+    """
+    if lr_config is None:
+        return None
+    class_name = lr_config.pop('name')
+    m = importlib.import_module('torch.optim.lr_scheduler')
+    clazz = getattr(m, class_name)
+    # add optimizer to the config
+    lr_config['optimizer'] = optimizer
+    return clazz(**lr_config)
+
+
+def create_sample_plotter(sample_plotter_config):
+    if sample_plotter_config is None:
+        return None
+    class_name = sample_plotter_config['name']
+    m = importlib.import_module('utils')
+    clazz = getattr(m, class_name)
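+    # instantiate the configured plotter class (looked up by name in this utils module)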
+    return clazz(**sample_plotter_config)