当前位置:首页 » 《随便一记》 » 正文

【人工智能前沿弄潮】—— SAM系列:SAM自动生成物体mask

12 人参与  2024年02月05日 19:16  分类 : 《随便一记》  评论

点击全文阅读


SAM自动生成物体mask

由于SAM可以高效处理提示,可以通过在图像上抽样大量的提示来生成整个图像的mask。这种方法被用来生成数据集SA-1B。

SamAutomaticMaskGenerator实现了这个功能。它通过在图像上的网格中对单点输入提示进行抽样,从每个提示中SAM可以预测多个mask。然后,使用非极大值抑制对mask进行质量过滤和去重。其他选项允许进一步提高mask的质量和数量,例如在图像的多个裁剪上运行预测,或者对mask进行后处理以去除小的不连通区域和孔洞。

设置

import numpy as npimport torchimport matplotlib.pyplot as pltimport cv2def show_anns(anns):    if len(anns) == 0:        return    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)    ax = plt.gca()    ax.set_autoscale_on(False)    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))    img[:,:,3] = 0    for ann in sorted_anns:        m = ann['segmentation']        color_mask = np.concatenate([np.random.random(3), [0.35]])        img[m] = color_mask    ax.imshow(img)

示例图像

image = cv2.imread('images/dog.jpg')image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(20,20))plt.imshow(image)plt.axis('off')plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-z24a4zns-1691578969711)(output_13_0.png)]

自动mask生成

要运行自动mask生成,将一个SAM模型提供给SamAutomaticMaskGenerator类。在下面设置SAM检查点的路径。建议在CUDA上运行并使用默认模型。

import syssys.path.append("..")from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictorsam_checkpoint = "sam_vit_h_4b8939.pth"model_type = "vit_h"device = "cuda"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)sam.to(device=device)mask_generator = SamAutomaticMaskGenerator(sam)

要生成mask,只需在图像上运行generate

masks = mask_generator.generate(image)

mask生成返回一个mask列表,其中每个mask是一个包含有关mask的各种数据的字典。这些键包括:

segmentation:maskarea:mask的面积(以像素为单位)bbox:mask的边界框(XYWH格式)predicted_iou:模型对mask质量的预测point_coords:生成此mask的抽样输入点stability_score:mask质量的附加衡量指标crop_box:用于生成此mask的图像裁剪(XYWH格式)
print(len(masks))print(masks[0].keys())
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])

显示所有mask叠加在图像上。

plt.figure(figsize=(20,20))plt.imshow(image)show_anns(masks)plt.axis('off')plt.show() 


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pqdrIXEW-1691578969713)(output_22_0.png)]

自动mask生成选项

自动mask生成中有几个可调参数,用于控制点是如何抽样的,以及移除低质量或重复mask的阈值是什么。此外,可以自动在图像的裁剪上运行生成,以在较小的对象上获得更好的性能,并且后处理可以去除杂散像素和孔洞。以下是一个示例配置,用于抽样更多的mask:

mask_generator_2 = SamAutomaticMaskGenerator(    model=sam,    points_per_side=32,    pred_iou_thresh=0.86,    stability_score_thresh=0.92,    crop_n_layers=1,    crop_n_points_downscale_factor=2,    min_mask_region_area=100,  # Requires open-cv to run post-processing)
masks2 = mask_generator_2.generate(image)
len(masks2)
90
plt.figure(figsize=(20,20))plt.imshow(image)show_anns(masks2)plt.axis('off')plt.show() 

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jJXJpvVR-1691578969725)(output_28_0.png)]


点击全文阅读


本文链接:http://zhangshiyu.com/post/67516.html

<< 上一篇 下一篇 >>

  • 评论(0)
  • 赞助本站

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。

关于我们 | 我要投稿 | 免责申明

Copyright © 2020-2022 ZhangShiYu.com Rights Reserved.豫ICP备2022013469号-1