diff --git a/ACL_PyTorch/built-in/cv/SAM/README.md b/ACL_PyTorch/built-in/cv/SAM/README.md
index 4ca3afa1f55cc9977408741eddf3a8e66c4a0e2a..b35a7a354dbf3b6beabc3a7c4efe4e947222738d 100644
--- a/ACL_PyTorch/built-in/cv/SAM/README.md
+++ b/ACL_PyTorch/built-in/cv/SAM/README.md
@@ -1,5 +1,6 @@
-# SAM 推理指导
+# SAM(ONNX)-推理指导
 
+## 概述
 Segment Anything Model (SAM) 是由 Meta 开源的图像分割大模型,在计算机视觉领域(CV)取得了新的突破。SAM 可在不需要任何标注的情况下,对任何图像中的任何物体进行分割,SAM 的开源引起了业界的广泛反响,被称为计算机视觉领域的 GPT。
 
 - 论文:
@@ -17,7 +18,48 @@ Segment Anything Model (SAM) 是由 Meta 开源的图像分割大模型,在计
   model_name=sam_vit_b_01ec64
   ```
 
-## 1. 输入输出数据
+## 推理环境准备
+
+- 该模型需要以下插件与驱动
+
+**表 1** 版本配套表
+
+ | 配套 | 版本 | 环境准备指导 |
+ | ---- | ---- | ---- |
+ | 固件与驱动 | 25.2.RC1 | [Pytorch框架推理环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+ | CANN | 8.2.RC1 | - |
+ | MindIE | 2.1.RC1 | - |
+ | Python | 3.11.10 | - |
+ | PyTorch | 2.1.0 | - |
+ | 说明:Atlas 300I Duo 推理卡请以CANN版本选择实际固件与驱动版本。 | \ | \ |
+
+## 快速上手
+
+### 1. 获取源码
+
+```bash
+git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
+cd ModelZoo-PyTorch/ACL_PyTorch/built-in/cv/SAM
+git clone https://github.com/facebookresearch/segment-anything.git
+cd segment-anything
+git reset --hard 6fdee8f2727f4506cfbbe553e23b895e27956588
+git apply ../segment_anything_diff.patch
+pip3 install -e .
+cd ..
+```
+
+### 2. 安装依赖
+
+- 安装基础环境。
+
+```bash
+pip3 install -r requirements.txt
+```
+说明:如果某些库通过此方式安装失败,可使用 pip3 install 单独进行安装。
+
+- 安装 [msit](https://gitee.com/ascend/msit/tree/master/msit/) 的 surgeon 组件和 benchmark 组件。
+
+### 3. 输入输出数据描述
 
 SAM 首先会自动分割图像中的所有内容,但是如果你需要分割某一个目标物体,则需要你输入一个目标物体上的坐标,比如一张图片你想让SAM分割Cat或Dog这个目标的提示坐标,SAM会自动在照片中猫或狗进行分割,在离线推理时,会转成encoder模型和decoder模型,其输入输出详情如下:
 
@@ -53,51 +95,9 @@ SAM 首先会自动分割图像中的所有内容,但是如果你需要分割
 | low_res_masks | FLOAT32 | -1 x 1 x -1 x -1 | ND |
 
 
-## 2. 推理环境准备
-
-- 该模型需要以下插件与驱动
+### 4. 准备数据集
 
-  **表 1** 版本配套表
-
-| 配套 | 版本 | 环境准备指导 |
-| ---- | ---- | ---- |
-| 固件与驱动 | 25.2.RC1 | [Pytorch框架推理环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
-| CANN | 8.2.RC1 | - |
-| MindIE | 2.1.RC1 | - |
-| Python | 3.11.10 | - |
-| PyTorch | 2.1.0 | - |
-| 说明:Atlas 300I Duo 推理卡请以CANN版本选择实际固件与驱动版本。 | \ | \ |
-
-## 3. 快速上手
-
-### 3.1 获取源码
-
-```
-git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
-cd ModelZoo-PyTorch/ACL_PyTorch/built-in/cv/SAM
-git clone https://github.com/facebookresearch/segment-anything.git
-cd segment-anything
-git reset --hard 6fdee8f2727f4506cfbbe553e23b895e27956588
-patch -p2 < ../segment_anything_diff.patch
-pip3 install -e .
-cd ..
-```
-
-### 3.2 安装依赖。
-
-1. 安装基础环境。
-
-   ```bash
-   pip3 install -r requirements.txt
-   ```
-
-   说明:如果某些库通过此方式安装失败,可使用 pip3 install 单独进行安装。
-
-2. 安装 [msit](https://gitee.com/ascend/msit/tree/master/msit/) 的 surgeon 组件和 benchmark 组件。
-
-### 3.3 准备数据集
-
-GitHub 仓库没有提供精度和性能的测试手段,这里取仓库里的 demo 图片进行测试。
+- 取仓库里的 demo 图片进行端到端测试。
 
 ```bash
 mkdir data
@@ -106,9 +106,23 @@ wget -O demo.jpg https://raw.githubusercontent.com/facebookresearch/segment-anyt
 cd ..
 ```
 
-### 3.4 模型转换
+- 下载 coco2017 数据集进行精度测试。
 
-#### 3.4.1 获取权重文件
+下载COCO-2017数据集的[图片](https://gitee.com/link?target=http%3A%2F%2Fimages.cocodataset.org%2Fzips%2Fval2017.zip)与[标注](https://gitee.com/link?target=http%3A%2F%2Fimages.cocodataset.org%2Fannotations%2Fannotations_trainval2017.zip),放置于 coco2017 目录下:
+
+  ```
+  coco2017
+  ├── annotations/
+  │   └── instances_val2017.json
+  └── val2017/
+      ├── 000000000139.jpg
+      ├── 000000000285.jpg
+      └── ...
+  ```
+
+### 5. 
模型转换 + +#### 5.1 获取权重文件 GitHub 仓库提供了三种大小的权重文件:vit_h、vit_l、vit_b。这里以 vit_b 为例。 @@ -119,7 +133,7 @@ wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth cd .. ``` -#### 3.4.2 导出 ONNX 模型 +#### 5.2 导出 ONNX 模型 ```bash python3 segment-anything/scripts/export_onnx_model.py \ @@ -140,7 +154,7 @@ python3 segment-anything/scripts/export_onnx_model.py \ - decoder-output:保存decoder模型的输出ONNX模型的文件路径。 - return-single-mask:设置最优mask模式。 -#### 3.4.3 使用 onnxsim 简化 ONNX 模型 +#### 5.3 使用 onnxsim 简化 ONNX 模型 这里以 batchsize=1 为例。 @@ -156,7 +170,7 @@ onnxsim models/decoder.onnx models/decoder_sim.onnx - 第二个参数:简化后的 ONNX 保存路径。 - overwrite-input-shape:指定输入的维度。 -#### 3.4.4 运行改图脚本,修改 ONNX 模型以适配昇腾芯片 +#### 5.4 运行改图脚本,修改 ONNX 模型以适配昇腾芯片 ```bash python3 encoder_onnx_modify.py \ @@ -169,9 +183,9 @@ python3 encoder_onnx_modify.py \ - 第一个参数:原 ONNX 路径。 - 第二个参数:适配后的 ONNX 保存路径。 -#### 3.4.5 使用 ATC 工具将 ONNX 模型转为 OM 模型 +#### 5.5 使用 ATC 工具将 ONNX 模型转为 OM 模型 -1. 配置环境变量。 +- 配置环境变量。 ```bash source /usr/local/Ascend/ascend-toolkit/set_env.sh @@ -180,7 +194,7 @@ python3 encoder_onnx_modify.py \ > **说明:** 该脚本中环境变量仅供参考,请以实际安装环境配置环境变量。详细介绍请参见《[CANN 开发辅助工具指南 \(推理\)](https://support.huawei.com/enterprise/zh/ascend-computing/cann-pid-251168373?category=developer-documents&subcategory=auxiliary-development-tools)》。 -2. 执行命令查看芯片名称($\{chip\_name\})。 +- 执行命令查看芯片名称($\{chip\_name\})。 ```bash npu-smi info @@ -198,7 +212,7 @@ python3 encoder_onnx_modify.py \ +===================+=================+======================================================+ ``` -3. 执行 atc 命令。 +- 执行 atc 命令。 ```bash atc \ @@ -234,9 +248,9 @@ python3 encoder_onnx_modify.py \ 更多参数说明请参考 [ATC 参数概览](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC1alpha002/devaids/auxiliarydevtool/atlasatc_16_0039.html)(如果链接失效,请从 [CANN 社区版文档](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition) 查找目录「应用开发 > ATC 模型转换 > 参数说明 > 参数概览」) -### 3.5 推理验证 +### 6 推理验证 -1. 端到端推理。成功执行下述命令后会在save-path参数指定的目录生成离线推理的结果。 +6.1 端到端推理。成功执行下述命令后会在save-path参数指定的目录生成离线推理的结果。 ```bash python3 sam_end2end_infer.py \ @@ -271,7 +285,7 @@ python3 encoder_onnx_modify.py \ ![](./assets/om_truck_result.JPG) -2. 性能验证。 +6.2 性能验证。 1. encoder 纯推理性能验证。 @@ -305,9 +319,38 @@ python3 encoder_onnx_modify.py \ - loop: 循环次数 - batchsize: 模型batch size -## 4. 模型推理性能 & 精度 +6.3 精度验证。 + +SAM 官方未提供精度评测手段,这里提供对应脚本,基于 COCO 验证集标注框作为输入提示,使用 SAM 预测分割掩码,并与 COCO 标注掩码逐实例进行 IoU 计算,最后对所有实例的 IoU 结果取平均,得到整体的平均交并比(mIoU)。 + + ```bash + python sam_coco_metric.py \ + --dataset-path coco2017 \ + --save-path outputs \ + --encoder-model-path models/encoder_sim.om \ + --decoder-model-path models/decoder_sim.om \ + --device-id 0 \ + --max-instances 0 + ``` +参数说明: + +- dataset-path: coco数据集目录 +- save-path: SAM预测掩码结果存储路径 +- encoder-model-path:encoder的OM模型路径 +- decoder-model-path:decoder的OM模型路径 +- device-id: 指定推理的NPU设备ID +- max-instances: 评测的最大实例数量,默认为0表示测评完整验证集 + +## 4. 
模型推理性能 & 精度 +性能结果: | 芯片型号 | 模型 | Batch Size | 性能 | | ---- | ---- | ---- | ---- | | 300I Pro | encoder | 1 | 4.43 fps | | 300I Pro | decoder | 1 | 679.77 fps | + +精度结果: +| 芯片型号 | 模型 | Batch Size | 精度(mIoU) | +| ---- | ---- | ---- | ---- | +| 300I Pro | SAM | 1 | 0.7654 | + diff --git a/ACL_PyTorch/built-in/cv/SAM/requirements.txt b/ACL_PyTorch/built-in/cv/SAM/requirements.txt index 6722adb464f8d1cbc54c229ff0e6e974124339b4..b4969b22635980cd34e286a64d4bbf4968e918c1 100644 --- a/ACL_PyTorch/built-in/cv/SAM/requirements.txt +++ b/ACL_PyTorch/built-in/cv/SAM/requirements.txt @@ -1,5 +1,5 @@ torch==2.1.0 -torch_npu==2.1.0.post17.dev20250905 +torch_npu==2.1.0.post10 torchvision==0.16.0 torchaudio==2.1.0 decorator diff --git a/ACL_PyTorch/built-in/cv/SAM/sam_coco_metric.py b/ACL_PyTorch/built-in/cv/SAM/sam_coco_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..fe44855be3e335d333eb40f47055f54fccb3c827 --- /dev/null +++ b/ACL_PyTorch/built-in/cv/SAM/sam_coco_metric.py @@ -0,0 +1,191 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import cv2 +import argparse +import numpy as np +from tqdm import tqdm +from pycocotools.coco import COCO +from pycocotools import mask as maskUtils + +from ais_bench.infer.interface import InferSession\ + +from sam_preprocessing_pytorch import encoder_preprocessing, decoder_preprocessing +from sam_postprocessing_pytorch import sam_postprocessing + + +def rle_to_mask(rle, h, w): + """COCO segmentation → binary mask (h,w) uint8.""" + if isinstance(rle, list): + rles = maskUtils.frPyObjects(rle, h, w) + rle = maskUtils.merge(rles) + elif isinstance(rle, dict) and isinstance(rle.get("counts"), list): + rle = maskUtils.frPyObjects(rle, h, w) + return maskUtils.decode(rle).astype(np.uint8) + + +def compute_iou(pred_mask, gt_mask): + pred = (pred_mask > 0).astype(np.uint8) + gt = (gt_mask > 0).astype(np.uint8) + inter = (pred & gt).sum() + union = (pred | gt).sum() + return float(inter) / float(union) if union > 0 else 0.0 + + +def coco_bbox_to_xyxy(bbox_xywh): + x, y, w, h = bbox_xywh + return [x, y, x + w, y + h] + + +def encoder_infer(session_encoder, x): + encoder_outputs = session_encoder.infer([x]) + image_embedding = encoder_outputs[0] + return image_embedding + + +def decoder_infer(session_decoder, decoder_inputs): + decoder_outputs = session_decoder.infer(decoder_inputs, mode="dymdims", custom_sizes=[1000, 1000000]) + low_res_masks = decoder_outputs[1] + return low_res_masks + + +def save_mask_overlay(masks, image, save_dir, image_name): + overlay = image.copy() + alpha = 0.5 + + for mask in masks: + if mask.sum() == 0: + continue + color = np.random.randint(0, 255, (3,), dtype=np.uint8) # 每个实例随机颜色 + overlay[mask > 0] = (overlay[mask > 0] * (1 - alpha) + color * alpha).astype(np.uint8) + + base, ext = os.path.splitext(image_name) + save_path = os.path.join(save_dir, f"{base}_sam_pre{ext}") + cv2.imwrite(save_path, overlay) + + +def evaluate_sam_on_coco(coco_root, save_path, encoder, 
decoder, max_instances=0): + ann_file = os.path.join(coco_root, "annotations", "instances_val2017.json") + img_root = os.path.join(coco_root, "val2017") + if not os.path.isfile(ann_file): + raise FileNotFoundError(f"COCO annotations not found: {ann_file}") + if not os.path.isdir(img_root): + raise FileNotFoundError(f"COCO val2017 images not found: {img_root}") + + coco = COCO(ann_file) + img_ids = coco.getImgIds() + + session_encoder = encoder + session_decoder = decoder + + ious = [] + counted = 0 + + for img_id in tqdm(img_ids, desc="Evaluating"): + img_info = coco.loadImgs(img_id)[0] + img_path = os.path.join(img_root, img_info["file_name"]) + image = cv2.imread(img_path) + + H, W = image.shape[:2] + + x = encoder_preprocessing(image) + image_embedding = encoder_infer(session_encoder, x) + + ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False) + anns = coco.loadAnns(ann_ids) + + mask_list = [] + for ann in anns: + + if max_instances > 0 and counted >= max_instances: + break + + box_xyxy = coco_bbox_to_xyxy(ann["bbox"]) + + decoder_inputs = decoder_preprocessing(image_embedding, box=box_xyxy, image=image) + low_res_masks = decoder_infer(session_decoder, decoder_inputs) + masks = sam_postprocessing(low_res_masks, image) + + pred2d = masks[0][0].astype(np.uint8) + mask_list.append(pred2d) + pred_bin = pred2d.astype(np.uint8) + + gt_mask = rle_to_mask(ann["segmentation"], H, W) + iou = compute_iou(pred_bin, gt_mask) + ious.append(iou) + counted += 1 + + if save_path is not None and len(mask_list) > 0: + save_mask_overlay(mask_list, image, save_path, img_info["file_name"]) + + if max_instances > 0 and counted >= max_instances: + break + + miou = float(np.mean(ious)) if counted > 0 else 0.0 + print("\n=========== COCO Evaluation (Box Prompt) ===========") + print(f"Instances Evaluated : {counted}") + print(f"Mean IoU (mIoU) : {miou:.4f}") + print("====================================================\n") + return miou + + +def check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [ int(v) for v in value.split(',') ] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError("{} of device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError("device:{} is invalid. 
valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataset-path', type=str, default='./datasets/', help='input path to coco dataset') + parser.add_argument('--save-path', type=str, default=None, help='output path to image') + parser.add_argument('--encoder-model-path', type=str, default='./models/encoder_sim.om', help='path to encoder model') + parser.add_argument('--decoder-model-path', type=str, default='./models/decoder_sim.om', help='path to decoder model') + parser.add_argument('--device-id', type=check_device_range_valid, default=0, help='NPU device id.') + parser.add_argument('--max-instances', type=int, default=0, help='Maximum number of instances to evaluate (0 = all).') + args = parser.parse_args() + + if args.save_path and not os.path.exists(args.save_path): + os.makedirs(os.path.realpath(args.save_path), mode=0o744) + + session_encoder = InferSession(args.device_id, args.encoder_model_path) + session_decoder = InferSession(args.device_id, args.decoder_model_path) + + evaluate_sam_on_coco( + args.dataset_path, + args.save_path, + session_encoder, + session_decoder, + max_instances=args.max_instances + ) + +if __name__ == "__main__": + main() + diff --git a/ACL_PyTorch/built-in/cv/SAM/sam_end2end_infer.py b/ACL_PyTorch/built-in/cv/SAM/sam_end2end_infer.py index 25db4ffd006b5603580ff1206c7326dfdeb7f797..952c95520dc5c0ec4829c6ce8dfce3064569db87 100644 --- a/ACL_PyTorch/built-in/cv/SAM/sam_end2end_infer.py +++ b/ACL_PyTorch/built-in/cv/SAM/sam_end2end_infer.py @@ -69,11 +69,11 @@ def decoder_infer(session_decoder, decoder_inputs): return low_res_masks -def sam_infer(src_path, session_encoder, session_decoder, input_point, save_path): +def sam_infer(src_path, session_encoder, session_decoder, input_point=None, box=None, save_path="./"): image = cv2.imread(src_path) x = encoder_preprocessing(image) image_embedding = encoder_infer(session_encoder, x) - decoder_inputs = decoder_preprocessing(image_embedding, input_point, image) + decoder_inputs = decoder_preprocessing(image_embedding, input_point=input_point, box=box, image=image) low_res_masks = decoder_infer(session_decoder, decoder_inputs) masks = sam_postprocessing(low_res_masks, image) save_mask(masks, image, src_path, save_path, random_color=True) @@ -95,8 +95,7 @@ def main(): session_encoder = InferSession(args.device_id, args.encoder_model_path) session_decoder = InferSession(args.device_id, args.decoder_model_path) - sam_infer(args.src_path, session_encoder, session_decoder, args.input_point, args.save_path) - + sam_infer(args.src_path, session_encoder, session_decoder, input_point=args.input_point, save_path=args.save_path) if __name__ == '__main__': main() diff --git a/ACL_PyTorch/built-in/cv/SAM/sam_preprocessing_pytorch.py b/ACL_PyTorch/built-in/cv/SAM/sam_preprocessing_pytorch.py index d8b5a1474207a1fc4f554beac8d5f5bbbcbd2979..f79da035e85e4098dd95f80f9b0eecb19c646dd3 100644 --- a/ACL_PyTorch/built-in/cv/SAM/sam_preprocessing_pytorch.py +++ b/ACL_PyTorch/built-in/cv/SAM/sam_preprocessing_pytorch.py @@ -35,12 +35,27 @@ def encoder_preprocessing(image): return image -def decoder_preprocessing(image_embedding, input_point, image): - input_point = np.array(input_point) - input_label = [1] * len(input_point) - input_label = np.array(input_label) - onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :] - onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, 
:].astype(np.float32) +def decoder_preprocessing(image_embedding, input_point=None, box=None, image=None): ## box:[x0,y0,x1,y1] + coords_list = [] + labels_list = [] + + if input_point is not None and len(input_point) > 0: + input_point = np.array(input_point, dtype=np.float32) + input_label = np.ones(len(input_point), dtype=np.float32) + coords_list.append(input_point) + labels_list.append(input_label) + + coords_list.append(np.array([[0.0, 0.0]], dtype=np.float32)) + labels_list.append(np.array([-1], dtype=np.float32)) + + if box is not None: + box = np.array(box, dtype=np.float32).reshape(2, 2) + coords_list.append(box) + labels_list.append(np.array([2, 3], dtype=np.float32)) + + onnx_coord = np.concatenate(coords_list, axis=0)[None, :, :] # (1,N,2) + onnx_label = np.concatenate(labels_list, axis=0)[None, :].astype(np.float32) # (1,N) + transform = ResizeLongestSide(IMAGE_SIZE) onnx_coord = transform.apply_coords(onnx_coord, image.shape[: 2]).astype(np.float32) onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) diff --git a/ACL_PyTorch/built-in/cv/SAM/segment_anything_diff.patch b/ACL_PyTorch/built-in/cv/SAM/segment_anything_diff.patch index aec413383ecb17bc2ed7eb15def82ad337b44ff7..96284944bc9f131e8c6f4819e25304bcc553f8a2 100644 --- a/ACL_PyTorch/built-in/cv/SAM/segment_anything_diff.patch +++ b/ACL_PyTorch/built-in/cv/SAM/segment_anything_diff.patch @@ -1,6 +1,7 @@ -diff -Naru a/segment-anything/scripts/export_onnx_model.py b/segment-anything/scripts/export_onnx_model.py ---- a/segment-anything/scripts/export_onnx_model.py 2023-11-13 16:25:26.000000000 +0800 -+++ b/segment-anything/scripts/export_onnx_model.py 2023-11-18 16:15:20.088025762 +0800 +diff --git a/scripts/export_onnx_model.py b/scripts/export_onnx_model.py +index 5c6f838..0bfaff2 100644 +--- a/scripts/export_onnx_model.py ++++ b/scripts/export_onnx_model.py @@ -6,8 +6,12 @@ import torch @@ -14,7 +15,7 @@ diff -Naru a/segment-anything/scripts/export_onnx_model.py b/segment-anything/sc import argparse import warnings -@@ -24,11 +28,30 @@ +@@ -24,11 +28,30 @@ parser = argparse.ArgumentParser( ) parser.add_argument( @@ -47,7 +48,7 @@ diff -Naru a/segment-anything/scripts/export_onnx_model.py b/segment-anything/sc ) parser.add_argument( -@@ -56,11 +79,21 @@ +@@ -56,11 +79,21 @@ parser.add_argument( ) parser.add_argument( @@ -71,7 +72,7 @@ diff -Naru a/segment-anything/scripts/export_onnx_model.py b/segment-anything/sc "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize." 
), ) -@@ -97,7 +130,9 @@ +@@ -97,7 +130,9 @@ parser.add_argument( def run_export( model_type: str, checkpoint: str, @@ -82,7 +83,7 @@ diff -Naru a/segment-anything/scripts/export_onnx_model.py b/segment-anything/sc opset: int, return_single_mask: bool, gelu_approximate: bool = False, -@@ -107,6 +142,74 @@ +@@ -107,6 +142,74 @@ def run_export( print("Loading model...") sam = sam_model_registry[model_type](checkpoint=checkpoint) @@ -157,7 +158,7 @@ diff -Naru a/segment-anything/scripts/export_onnx_model.py b/segment-anything/sc onnx_model = SamOnnxModel( model=sam, return_single_mask=return_single_mask, -@@ -129,16 +232,17 @@ +@@ -129,16 +232,17 @@ def run_export( mask_input_size = [4 * x for x in embed_size] dummy_inputs = { "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), @@ -178,7 +179,7 @@ diff -Naru a/segment-anything/scripts/export_onnx_model.py b/segment-anything/sc with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) -@@ -164,7 +268,7 @@ +@@ -164,7 +268,7 @@ def run_export( providers = ["CPUExecutionProvider"] ort_session = onnxruntime.InferenceSession(output, providers=providers) _ = ort_session.run(None, ort_inputs) @@ -187,7 +188,7 @@ diff -Naru a/segment-anything/scripts/export_onnx_model.py b/segment-anything/sc def to_numpy(tensor): -@@ -176,7 +280,9 @@ +@@ -176,7 +280,9 @@ if __name__ == "__main__": run_export( model_type=args.model_type, checkpoint=args.checkpoint, @@ -198,7 +199,7 @@ diff -Naru a/segment-anything/scripts/export_onnx_model.py b/segment-anything/sc opset=args.opset, return_single_mask=args.return_single_mask, gelu_approximate=args.gelu_approximate, -@@ -184,18 +290,34 @@ +@@ -184,18 +290,34 @@ if __name__ == "__main__": return_extra_metrics=args.return_extra_metrics, ) @@ -238,10 +239,11 @@ diff -Naru a/segment-anything/scripts/export_onnx_model.py b/segment-anything/sc + ) + print("Done!") \ No newline at end of file -diff -Naru a/segment-anything/segment_anything/modeling/image_encoder.py b/segment-anything/segment_anything/modeling/image_encoder.py ---- a/segment-anything/segment_anything/modeling/image_encoder.py 2023-11-13 16:25:26.000000000 +0800 -+++ b/segment-anything/segment_anything/modeling/image_encoder.py 2023-11-13 19:26:32.000000000 +0800 -@@ -253,8 +253,8 @@ +diff --git a/segment_anything/modeling/image_encoder.py b/segment_anything/modeling/image_encoder.py +index 66351d9..31d622c 100644 +--- a/segment_anything/modeling/image_encoder.py ++++ b/segment_anything/modeling/image_encoder.py +@@ -253,8 +253,8 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T """ B, H, W, C = x.shape @@ -252,7 +254,7 @@ diff -Naru a/segment-anything/segment_anything/modeling/image_encoder.py b/segme if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w -@@ -322,6 +322,15 @@ +@@ -322,6 +322,15 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor return rel_pos_resized[relative_coords.long()] @@ -268,7 +270,7 @@ diff -Naru a/segment-anything/segment_anything/modeling/image_encoder.py b/segme def add_decomposed_rel_pos( attn: torch.Tensor, q: torch.Tensor, -@@ -351,8 +360,8 @@ +@@ -351,8 +360,8 @@ def add_decomposed_rel_pos( B, _, dim = q.shape r_q = q.reshape(B, q_h, q_w, dim) @@ -279,10 +281,33 @@ diff -Naru a/segment-anything/segment_anything/modeling/image_encoder.py b/segme attn = ( attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] -diff -Naru 
a/segment-anything/segment_anything/utils/onnx.py b/segment-anything/segment_anything/utils/onnx.py ---- a/segment-anything/segment_anything/utils/onnx.py 2023-11-13 16:25:26.000000000 +0800 -+++ b/segment-anything/segment_anything/utils/onnx.py 2023-11-18 16:14:01.512027850 +0800 -@@ -112,7 +112,6 @@ +diff --git a/segment_anything/modeling/mask_decoder.py b/segment_anything/modeling/mask_decoder.py +index 5d2fdb0..ee8da94 100644 +--- a/segment_anything/modeling/mask_decoder.py ++++ b/segment_anything/modeling/mask_decoder.py +@@ -123,9 +123,15 @@ class MaskDecoder(nn.Module): + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask +- src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) ++ N = tokens.shape[0] ++ B, C, H, W = image_embeddings.shape ++ src = image_embeddings.unsqueeze(1).expand(B, N, C, H, W).reshape(B * N, C, H, W) ++ + src = src + dense_prompt_embeddings +- pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) ++ ++ B, C, H, W = image_pe.shape ++ pos_src = image_pe.unsqueeze(1).expand(B, N, C, H, W).reshape(B * N, C, H, W) ++ + b, c, h, w = src.shape + + # Run the transformer +diff --git a/segment_anything/utils/onnx.py b/segment_anything/utils/onnx.py +index 3196bdf..e718afc 100644 +--- a/segment_anything/utils/onnx.py ++++ b/segment_anything/utils/onnx.py +@@ -112,7 +112,6 @@ class SamOnnxModel(nn.Module): point_labels: torch.Tensor, mask_input: torch.Tensor, has_mask_input: torch.Tensor, @@ -290,7 +315,7 @@ diff -Naru a/segment-anything/segment_anything/utils/onnx.py b/segment-anything/ ): sparse_embedding = self._embed_points(point_coords, point_labels) dense_embedding = self._embed_masks(mask_input, has_mask_input) -@@ -131,14 +130,4 @@ +@@ -131,14 +130,4 @@ class SamOnnxModel(nn.Module): if self.return_single_mask: masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) @@ -306,3 +331,4 @@ diff -Naru a/segment-anything/segment_anything/utils/onnx.py b/segment-anything/ - - return upscaled_masks, scores, masks + return scores, masks +
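
附:下面给出新增脚本 sam_coco_metric.py 的一个最小调用示意(仅供参考,模型路径、数据集路径与 device id 均为假设值,请按实际环境调整;完整评测请直接使用 README 中的命令行方式):

```python
# 最小调用示意:基于新增的 sam_coco_metric.py 做小规模精度自检
# 假设:models/encoder_sim.om、models/decoder_sim.om 与 coco2017/ 均已按 README 准备好
import os

from ais_bench.infer.interface import InferSession

from sam_coco_metric import evaluate_sam_on_coco

save_path = "outputs"                      # 预测掩码叠加图的保存目录,传 None 则不保存
os.makedirs(save_path, exist_ok=True)

# 加载 encoder/decoder 的 OM 模型(device_id=0 仅为示例)
session_encoder = InferSession(0, "models/encoder_sim.om")
session_decoder = InferSession(0, "models/decoder_sim.om")

# max_instances=100 表示只评测前 100 个实例,便于快速自检;设为 0 则评测完整验证集
miou = evaluate_sam_on_coco(
    "coco2017",                            # coco 数据集根目录(包含 annotations/ 与 val2017/)
    save_path,
    session_encoder,
    session_decoder,
    max_instances=100,
)
print(f"mIoU (前 100 个实例): {miou:.4f}")
```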