StarNet
Paper: [2403.19967] Rewrite the Stars
GitHub repository: ma-xu/Rewrite-the-Stars ([CVPR 2024] Rewrite the Stars)
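The CVPR 2024 paper Rewrite the Stars shows that the star operation (element-wise multiplication) can map inputs into a high-dimensional, non-linear feature space without widening the network. Building on this, the authors propose StarNet, which achieves strong accuracy and low latency with a compact architecture and a low energy budget.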
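To make the kernel-trick analogy concrete: for one output channel, the star operation multiplies two linear projections of the same input, (w1 · x)(w2 · x). Expanding the product shows that it equals a linear function of all pairwise products x_i x_j. The pure-Python sketch below checks this identity numerically, ignoring the ReLU6 that StarNet's Block applies to one branch. Appending a constant 1 to x folds in a bias, so a d-dimensional input yields (d + 2)(d + 1) / 2 implicit features. The helper names are illustrative and are not taken from the StarNet code.

```python
import itertools
import random


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def star(w1, w2, x):
    # One output channel of the star operation: (w1 . x) * (w2 . x)
    return dot(w1, x) * dot(w2, x)


def implicit_features(x):
    # Every pairwise product x_i * x_j with i <= j: the implicit quadratic features.
    pairs = itertools.combinations_with_replacement(range(len(x)), 2)
    return [x[i] * x[j] for i, j in pairs]


def implicit_weights(w1, w2):
    # Coefficient of x_i * x_j after expanding (w1 . x) * (w2 . x).
    pairs = itertools.combinations_with_replacement(range(len(w1)), 2)
    return [w1[i] * w2[i] if i == j else w1[i] * w2[j] + w1[j] * w2[i]
            for i, j in pairs]


random.seed(0)
d = 4
x = [random.uniform(-1, 1) for _ in range(d)] + [1.0]  # trailing 1.0 folds in a bias
w1 = [random.uniform(-1, 1) for _ in range(d + 1)]
w2 = [random.uniform(-1, 1) for _ in range(d + 1)]

lhs = star(w1, w2, x)  # computed from only d + 1 input values
rhs = dot(implicit_weights(w1, w2), implicit_features(x))  # linear map in the implicit space
print(abs(lhs - rhs) < 1e-12)     # True
print(len(implicit_features(x)))  # 15 == (d + 2) * (d + 1) // 2
```

The left-hand side is what the network actually computes. The right-hand side is the equivalent linear model over 15 implicit features, which is never materialized.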
Advantages
High-Dimensional and Non-Linear Feature Transformation
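StarNet uses the star operation to map features into a high-dimensional, non-linear space without adding computational complexity. Much like the kernel trick, the star operation implicitly obtains high-dimensional features from a low-dimensional input (ar5iv). For YOLO-family detectors, this means richer and more expressive feature representations at the same computational cost, which matters for capturing fine-grained features in object detection.
Efficient Network Design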
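With the star operation, StarNet obtains efficient feature representations without an elaborate architecture or extra computational overhead. Its distinctive property is that it computes in a low-dimensional space while implicitly accounting for extremely high-dimensional features (ar5iv). This makes StarNet a candidate backbone for YOLO-family detectors, combining efficient computation with better feature representations, which can help detection accuracy in resource-constrained environments.
Multi-Layer Implicit Feature Expansion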
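Stacking star operations over multiple layers recursively increases the implicit feature dimension, approaching infinite dimensionality. For wider and deeper networks, this property can substantially strengthen feature expressiveness (ar5iv). For YOLO-family detectors, choosing an appropriate depth and width can therefore improve the quality of the extracted features and, in turn, detection accuracy.
Problems Addressed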
Balance Between Computational Complexity and Performance
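The star operation lets StarNet map features into a high-dimensional space while keeping computational complexity low. This addresses the complexity-versus-performance trade-off faced by conventional efficient network designs (ar5iv). YOLO-family detectors must balance real-time speed against detection accuracy, and StarNet's efficiency fits this requirement.
Richness of Feature Representation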
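Conventional convolutional networks are limited in the high-dimensional non-linear transformations they can express, whereas StarNet obtains richer feature representations through the star operation (ar5iv). In object detection, especially with small objects and cluttered scenes, richer representations can noticeably improve detection results, which should help YOLO-family detectors in exactly these scenarios.
Simplified Network Design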
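StarNet offers a way to simplify network design: it achieves efficient feature representations without complex feature fusion or multi-branch structures (ar5iv). For YOLO-family detectors, this makes an efficient backbone easier to design and implement, and reduces design and debugging effort.
Replacing the YOLOv5 Backbone with StarNet in MMYOLO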
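1. Download imagenet/starnet.py from the repository linked above and place it in mmyolo/models/backbones/.
2. Modify the forward function in starnet.py and add an out_indices parameter so that it returns the feature maps of selected stages.
3. Register class StarNet with MMYOLO's MODELS registry and modify its __init__() accordingly.
4. Modify the config file, mainly the input and output channel numbers of the YOLOv5 neck and head.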
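Step 4 comes down to channel arithmetic. StarNet stage i outputs base_dim * 2 ** i channels. MMYOLO's YOLOv5PAFPN and YOLOv5HeadModule pass every entry of in_channels through make_divisible(c, widen_factor), which scales it by widen_factor and rounds up to a multiple of 8. The sketch below re-implements that rounding locally, as an assumption about MMYOLO's helper rather than an import of it, and derives the in_channels value that makes the widened neck and head match the real backbone outputs. The helper names backbone_channels and config_in_channels are hypothetical.

```python
import math


def make_divisible(x, widen_factor=1.0, divisor=8):
    # Local re-implementation (assumed) of the rounding MMYOLO applies to
    # neck/head channels: scale by widen_factor, round up to a multiple of 8.
    return math.ceil(x * widen_factor / divisor) * divisor


def backbone_channels(base_dim, out_indices):
    # StarNet stage i outputs base_dim * 2 ** i channels.
    return [base_dim * 2 ** i for i in out_indices]


def config_in_channels(base_dim, out_indices=(1, 2, 3), widen_factor=0.5):
    # Hypothetical helper: the in_channels to write in the config so that the
    # widened neck/head layers match the real backbone outputs.
    real = backbone_channels(base_dim, out_indices)
    nominal = [round(c / widen_factor) for c in real]
    built = [make_divisible(c, widen_factor) for c in nominal]
    if built != real:
        raise ValueError(f'neck/head would expect {built} channels, backbone gives {real}')
    return nominal


print(backbone_channels(24, (1, 2, 3)))  # [48, 96, 192]
print(config_in_channels(24))            # [96, 192, 384]
print(config_in_channels(20))            # [80, 160, 320]
```

With widen_factor = 0.5 the config value is simply twice the real channel count. With base_dim=20 and out_indices=(0, 1, 2), the helper raises ValueError, because 20 channels cannot survive the round-up to a multiple of 8.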
Modified starnet.py

```python
"""
Implementation of Proof-of-Concept Network: StarNet.

We make StarNet as simple as possible [to show the key contribution of element-wise multiplication]:
    - like NO layer-scale in network design,
    - and NO EMA during training,
    - which would improve the performance further.

Created by: Xu Ma (Email: ma.xu1@northeastern.edu)
Modified Date: Mar/29/2024
"""
from typing import Sequence

import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_

# timm's register_model is replaced by the MMYOLO registry
from mmyolo.registry import MODELS

model_urls = {
    "starnet_s1": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar",
    "starnet_s2": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar",
    "starnet_s3": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar",
    "starnet_s4": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar",
}


class ConvBN(torch.nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0,
                 dilation=1, groups=1, with_bn=True):
        super().__init__()
        self.add_module('conv', torch.nn.Conv2d(in_planes, out_planes, kernel_size, stride,
                                                padding, dilation, groups))
        if with_bn:
            self.add_module('bn', torch.nn.BatchNorm2d(out_planes))
            torch.nn.init.constant_(self.bn.weight, 1)
            torch.nn.init.constant_(self.bn.bias, 0)


class Block(nn.Module):
    def __init__(self, dim, mlp_ratio=3, drop_path=0.):
        super().__init__()
        self.dwconv = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=True)
        self.f1 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.f2 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True)
        self.dwconv2 = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=False)
        self.act = nn.ReLU6()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x1, x2 = self.f1(x), self.f2(x)
        x = self.act(x1) * x2  # star operation: element-wise multiplication
        x = self.dwconv2(self.g(x))
        x = input + self.drop_path(x)
        return x


@MODELS.register_module()
class StarNet(nn.Module):
    """StarNet backbone that returns multi-scale feature maps for detection.

    Stage i has stride 2 ** (i + 2) and base_dim * 2 ** i channels, so
    out_indices=(1, 2, 3) gives the stride-8/16/32 maps used by the YOLOv5 neck and head.
    """

    def __init__(self, base_dim=32, out_indices: Sequence[int] = (1, 2, 3),
                 depths=(3, 3, 12, 5), mlp_ratio=4, drop_path_rate=0.0,
                 num_classes=1000, **kwargs):
        super().__init__()
        self.num_classes = num_classes  # unused: the classification head is removed
        self.in_channel = 32
        self.out_indices = tuple(out_indices)
        self.depths = list(depths)
        # stem layer (stride 2)
        self.stem = nn.Sequential(ConvBN(3, self.in_channel, kernel_size=3, stride=2, padding=1), nn.ReLU6())
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]  # stochastic depth
        # build stages; each starts with a stride-2 down-sampler
        self.stages = nn.ModuleList()
        cur = 0
        for i_layer in range(len(self.depths)):
            embed_dim = base_dim * 2 ** i_layer
            down_sampler = ConvBN(self.in_channel, embed_dim, 3, 2, 1)
            self.in_channel = embed_dim
            blocks = [Block(self.in_channel, mlp_ratio, dpr[cur + i]) for i in range(self.depths[i_layer])]
            cur += self.depths[i_layer]
            self.stages.append(nn.Sequential(down_sampler, *blocks))
        # classification head, not needed for detection
        # self.norm = nn.BatchNorm2d(self.in_channel)
        # self.avgpool = nn.AdaptiveAvgPool2d(1)
        # self.head = nn.Linear(self.in_channel, num_classes)
        # self.apply(self._init_weights)

    def _init_weights(self, m):
        # isinstance needs a tuple; `nn.Linear or nn.Conv2d` only checked nn.Linear
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.stem(x)
        # collect the outputs of the stages listed in out_indices
        outs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)


def _load_pretrained(model, name):
    checkpoint = torch.hub.load_state_dict_from_url(url=model_urls[name], map_location="cpu")
    # strict=False: the ImageNet checkpoint still contains the removed norm/head weights
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model


@MODELS.register_module()
def starnet_s1(pretrained=False, **kwargs):
    model = StarNet(24, depths=[2, 2, 8, 3], **kwargs)
    return _load_pretrained(model, 'starnet_s1') if pretrained else model


@MODELS.register_module()
def starnet_s2(pretrained=False, **kwargs):
    model = StarNet(32, depths=[1, 2, 6, 2], **kwargs)
    return _load_pretrained(model, 'starnet_s2') if pretrained else model


@MODELS.register_module()
def starnet_s3(pretrained=False, **kwargs):
    model = StarNet(32, depths=[2, 2, 8, 4], **kwargs)
    return _load_pretrained(model, 'starnet_s3') if pretrained else model


@MODELS.register_module()
def starnet_s4(pretrained=False, **kwargs):
    model = StarNet(32, depths=[3, 3, 12, 5], **kwargs)
    return _load_pretrained(model, 'starnet_s4') if pretrained else model


# very small networks (no pretrained checkpoints) #
@MODELS.register_module()
def starnet_s050(pretrained=False, **kwargs):
    return StarNet(16, depths=[1, 1, 3, 1], mlp_ratio=3, **kwargs)


@MODELS.register_module()
def starnet_s100(pretrained=False, **kwargs):
    return StarNet(20, depths=[1, 2, 4, 1], mlp_ratio=4, **kwargs)


@MODELS.register_module()
def starnet_s150(pretrained=False, **kwargs):
    return StarNet(24, depths=[1, 2, 4, 2], mlp_ratio=3, **kwargs)


if __name__ == '__main__':
    model = StarNet(base_dim=24, depths=[2, 2, 8, 3])
    outputs = model(torch.randn(1, 3, 640, 640))
    for out in outputs:
        # torch.Size([1, 48, 80, 80]), [1, 96, 40, 40], [1, 192, 20, 20]
        print(out.shape)
```
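The neck and head in the config use strides = [8, 16, 32], so the backbone must hand over feature maps downsampled by exactly those factors. The following pure-Python sketch traces the spatial size through the StarNet stem and the four stage down-samplers using the standard PyTorch Conv2d output-size formula. The helper names conv_out and starnet_feature_sizes are hypothetical and are not part of the model code.

```python
def conv_out(size, kernel, stride, padding, dilation=1):
    # Output length of a Conv2d along one spatial axis (PyTorch formula).
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def starnet_feature_sizes(img_size, num_stages=4):
    size = conv_out(img_size, kernel=3, stride=2, padding=1)  # stem
    sizes = []
    for _ in range(num_stages):
        size = conv_out(size, kernel=3, stride=2, padding=1)  # down-sampler of each stage
        # Blocks use 7x7 depth-wise convs with stride 1 and padding 3: size unchanged.
        assert conv_out(size, kernel=7, stride=1, padding=3) == size
        sizes.append(size)
    return sizes


sizes = starnet_feature_sizes(640)
print(sizes)                      # [160, 80, 40, 20]
print([640 // s for s in sizes])  # [4, 8, 16, 32]
```

For a 640x640 input, stage 0 already sits at stride 4. The stride-8/16/32 maps (P3/P4/P5) therefore correspond to out_indices=(1, 2, 3). Selecting (0, 1, 2) would feed stride-4/8/16 maps to a head that assumes strides 8/16/32.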
Modified __init__.py (mmyolo/models/backbones/__init__.py)
```python
# Copyright (c) OpenMMLab. All rights reserved.
from .base_backbone import BaseBackbone
from .csp_darknet import YOLOv5CSPDarknet, YOLOv8CSPDarknet, YOLOXCSPDarknet
from .csp_resnet import PPYOLOECSPResNet
from .cspnext import CSPNeXt
from .efficient_rep import YOLOv6CSPBep, YOLOv6EfficientRep
from .yolov7_backbone import YOLOv7Backbone
from .starnet import StarNet  # importing the module runs @MODELS.register_module()

__all__ = [
    'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
    'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet',
    'YOLOv8CSPDarknet', 'StarNet'
]
```
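Registering the class and importing the module work together. The @MODELS.register_module() decorator only runs when starnet.py is imported. Without the import line in __init__.py, building a model with type='StarNet' fails because the name is not in the registry. The toy registry below is a hypothetical, heavily simplified stand-in for MMEngine's Registry. It only illustrates that mechanism and is not the real API.

```python
class ToyRegistry:
    """Hypothetical, heavily simplified stand-in for an MMEngine-style registry."""

    def __init__(self):
        self._modules = {}

    def register_module(self):
        def decorator(obj):
            # Runs once, when the module that defines `obj` is imported.
            self._modules[obj.__name__] = obj
            return obj
        return decorator

    def build(self, cfg):
        cfg = dict(cfg)
        obj_type = cfg.pop('type')
        if obj_type not in self._modules:
            raise KeyError(f'{obj_type} is not registered; was its module imported?')
        return self._modules[obj_type](**cfg)


MODELS = ToyRegistry()


@MODELS.register_module()
class StarNetStub:
    # Placeholder standing in for the real StarNet class.
    def __init__(self, base_dim=32, out_indices=(1, 2, 3), **kwargs):
        self.base_dim = base_dim
        self.out_indices = tuple(out_indices)


net = MODELS.build(dict(type='StarNetStub', base_dim=24))
print(net.base_dim, net.out_indices)  # 24 (1, 2, 3)
```

Instead of editing __init__.py, MMEngine configs should also accept custom_imports = dict(imports=['mmyolo.models.backbones.starnet'], allow_failed_imports=False). This triggers the same import before the model is built.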
Modified config file (using yolov5_s-v61_syncbn_8xb16-300e_coco.py as an example)
```python
_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']

# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/'  # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/'  # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/'  # Prefix of val image path

num_classes = 80  # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True

# -----model related-----
# Basic size of multi-scale prior box
anchors = [
    [(10, 13), (16, 30), (33, 23)],  # P3/8
    [(30, 61), (62, 45), (59, 119)],  # P4/16
    [(116, 90), (156, 198), (373, 326)]  # P5/32
]

# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300  # Maximum training epochs

model_test_cfg = dict(
    # The config of multi-label for multi-class prediction.
    multi_label=True,
    # The number of boxes before NMS
    nms_pre=30000,
    score_thr=0.001,  # Threshold to filter out boxes.
    nms=dict(type='nms', iou_threshold=0.65),  # NMS type and threshold
    max_per_img=300)  # Max number of detections of each image

# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640)  # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2

# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
    type='BatchShapePolicy',
    batch_size=val_batch_size_per_gpu,
    img_size=img_scale[0],
    # The image scale of padding should be divided by pad_size_divisor
    size_divisor=32,
    # Additional paddings for pixel scale
    extra_pad_ratio=0.5)

# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5

# Strides of multi-scale prior box
strides = [8, 16, 32]
num_det_layers = 3  # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)  # Normalization config

# -----train val related-----
affine_scale = 0.5  # YOLOv5RandomAffine scaling ratio
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4.  # Priori box matching threshold
# The obj loss weights of the three output layers
obj_level_weights = [4., 1., 0.4]
lr_factor = 0.01  # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)

'''
StarNet variants with out_indices=(1, 2, 3), i.e. stride-8/16/32 outputs.
The backbone outputs [2, 4, 8] * base_dim channels; starnet_channel is twice
that, because YOLOv5PAFPN and YOLOv5HeadModule scale every in_channels entry
by widen_factor (0.5 here).
variant       base_dim  starnet_channel    depths          mlp_ratio
starnet_s1    24        [96, 192, 384]     [2, 2, 8, 3]    4
starnet_s2    32        [128, 256, 512]    [1, 2, 6, 2]    4
starnet_s3    32        [128, 256, 512]    [2, 2, 8, 4]    4
starnet_s4    32        [128, 256, 512]    [3, 3, 12, 5]   4
starnet_s050  16        [64, 128, 256]     [1, 1, 3, 1]    3
starnet_s100  20        [80, 160, 320]     [1, 2, 4, 1]    4
starnet_s150  24        [96, 192, 384]     [1, 2, 4, 2]    3
'''
starnet_channel = [96, 192, 384]
depths = [2, 2, 8, 3]

# ===============================Unmodified in most cases====================
model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True),
    backbone=dict(
        type='StarNet',  # StarNet-s1
        base_dim=24,
        out_indices=(1, 2, 3),  # stride 8/16/32, matching `strides`
        depths=depths,
        mlp_ratio=4,
        num_classes=num_classes,  # ignored by StarNet (no classification head)
        # deepen_factor=deepen_factor,
        # widen_factor=widen_factor,
        # norm_cfg=norm_cfg,
        # act_cfg=dict(type='SiLU', inplace=True)
    ),
    neck=dict(
        type='YOLOv5PAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=starnet_channel,
        out_channels=starnet_channel,
        num_csp_blocks=3,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='YOLOv5Head',
        head_module=dict(
            type='YOLOv5HeadModule',
            num_classes=num_classes,
            in_channels=starnet_channel,
            widen_factor=widen_factor,
            featmap_strides=strides,
            num_base_priors=3),
        prior_generator=dict(
            type='mmdet.YOLOAnchorGenerator',
            base_sizes=anchors,
            strides=strides),
        # scaled based on number of detection layers
        loss_cls=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_cls_weight *
            (num_classes / 80 * 3 / num_det_layers)),
        # Change this entry to switch to a different IoU loss
        loss_bbox=dict(
            type='IoULoss',
            focal=True,
            iou_mode='ciou',
            bbox_format='xywh',
            eps=1e-7,
            reduction='mean',
            loss_weight=loss_bbox_weight * (3 / num_det_layers),
            return_iou=True),
        loss_obj=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_obj_weight *
            ((img_scale[0] / 640)**2 * 3 / num_det_layers)),
        prior_match_thr=prior_match_thr,
        obj_level_weights=obj_level_weights),
    test_cfg=model_test_cfg)

albu_train_transforms = [
    dict(type='Blur', p=0.01),
    dict(type='MedianBlur', p=0.01),
    dict(type='ToGray', p=0.01),
    dict(type='CLAHE', p=0.01)
]

pre_transform = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='LoadAnnotations', with_bbox=True)
]

train_pipeline = [
    *pre_transform,
    dict(
        type='Mosaic',
        img_scale=img_scale,
        pad_val=114.0,
        pre_transform=pre_transform),
    dict(
        type='YOLOv5RandomAffine',
        max_rotate_degree=0.0,
        max_shear_degree=0.0,
        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
        # img_scale is (width, height)
        border=(-img_scale[0] // 2, -img_scale[1] // 2),
        border_val=(114, 114, 114)),
    dict(
        type='mmdet.Albu',
        transforms=albu_train_transforms,
        bbox_params=dict(
            type='BboxParams',
            format='pascal_voc',
            label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
        keymap={
            'img': 'image',
            'gt_bboxes': 'bboxes'
        }),
    dict(type='YOLOv5HSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                   'flip_direction'))
]

train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=train_ann_file,
        data_prefix=dict(img=train_data_prefix),
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        pipeline=train_pipeline))

test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
]

val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(img=val_data_prefix),
        ann_file=val_ann_file,
        pipeline=test_pipeline,
        batch_shapes_cfg=batch_shapes_cfg))

test_dataloader = val_dataloader

param_scheduler = None
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD',
        lr=base_lr,
        momentum=0.937,
        weight_decay=weight_decay,
        nesterov=True,
        batch_size_per_gpu=train_batch_size_per_gpu),
    constructor='YOLOv5OptimizerConstructor')

default_hooks = dict(
    param_scheduler=dict(
        type='YOLOv5ParamSchedulerHook',
        scheduler_type='linear',
        lr_factor=lr_factor,
        max_epochs=max_epochs),
    checkpoint=dict(
        type='CheckpointHook',
        interval=save_checkpoint_intervals,
        save_best='auto',
        max_keep_ckpts=max_keep_ckpts))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        strict_load=False,
        priority=49)
]

val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + val_ann_file,
    metric='bbox')
test_evaluator = val_evaluator

train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=save_checkpoint_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
```
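You do not need to copy the whole file to switch to another StarNet variant. The following sketch assumes the modified config above has been saved as yolov5_starnet-s1_coco.py, which is a hypothetical file name. The child config inherits from it and overrides only the StarNet-related fields, and MMEngine merges these nested dicts into the base config. starnet_channel is redefined and passed explicitly because overriding a variable in a child config does not update values the base file already computed from it. out_indices is also set explicitly so the fragment does not depend on the base file's choice.

```python
# Hypothetical child config, e.g. yolov5_starnet-s2_coco.py
_base_ = './yolov5_starnet-s1_coco.py'  # hypothetical name of the modified config above

# StarNet-s2: base_dim=32, so stages 1-3 output [64, 128, 256] channels.
# Doubled because the neck and head multiply in_channels by widen_factor=0.5.
starnet_channel = [128, 256, 512]

model = dict(
    backbone=dict(base_dim=32, out_indices=(1, 2, 3), depths=[1, 2, 6, 2], mlp_ratio=4),
    neck=dict(in_channels=starnet_channel, out_channels=starnet_channel),
    bbox_head=dict(head_module=dict(in_channels=starnet_channel)))
```

Keys that are not listed here, such as the loss settings and pipelines, are inherited unchanged from the base config.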