Deformable-DETR (two-stage version): Encoder Proposal
Deformable-DETR variants: two-stage Deformable DETR
Preface
[Figure: the two-stage Deformable DETR section, excerpted from the paper]
The figure above is the paper's discussion of the two-stage variant, which is brief. DETR and its variants fall into one-stage and two-stage designs. In the one-stage design, the decoder queries are initialized as content queries (initially set to zero and not learned) plus position embeddings (randomly initialized and learnable). The two-stage design works like R-CNN: the memory output by the encoder serves as a shared feature map for region proposals, and those proposals are then used to initialize the queries of the subsequent decoder. This speeds up convergence and improves the stability of the decoder.
DINO groups the existing query-initialization schemes into three categories (a sketch contrasting them follows this list):
1. static anchors, exemplified by DETR;
2. dynamic anchors and contents, exemplified by Deformable DETR;
3. dynamic anchors with static contents, proposed by the DINO authors.
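To make the contrast concrete, here is a minimal, hypothetical sketch of the one-stage initialization (names such as num_queries and embed_dims are assumptions, not code from the paper or mmdetection); it mirrors the else branch of the forward method shown later, where a single learnable embedding is split into a positional half and a content half:

import torch
import torch.nn as nn

embed_dims, num_queries, bs = 256, 300, 2

# One-stage: a learnable embedding holds both halves; it is split into
# a positional part (query_pos) and a content part (query).
query_embedding = nn.Embedding(num_queries, embed_dims * 2)
query_pos, query = torch.split(query_embedding.weight, embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)  # (bs, num_queries, C)
query = query.unsqueeze(0).expand(bs, -1, -1)          # (bs, num_queries, C)

# Two-stage instead derives the reference boxes (and the content queries)
# from the encoder's top-scoring proposals, as in the code below.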
Source code
gen_encoder_output_proposals (mmdetection/mmdet/models/utils/transformer.py)
def gen_encoder_output_proposals(self, memory, memory_padding_mask,
                                 spatial_shapes):
    """Generate proposals from encoded memory.

    Args:
        memory (Tensor): The output of the encoder, with shape
            (bs, num_key, embed_dim). num_key is equal to the number of
            points on the feature maps from all levels.
        memory_padding_mask (Tensor): Padding mask for memory, with
            shape (bs, num_key).
        spatial_shapes (Tensor): The shape of all feature maps, with
            shape (num_level, 2).

    Returns:
        tuple: A tuple of feature map and bbox prediction.

            - output_memory (Tensor): The input of the decoder, with
              shape (bs, num_key, embed_dim).
            - output_proposals (Tensor): The normalized proposals
              after an inverse sigmoid, with shape (bs, num_key, 4).
    """
    N, S, C = memory.shape
    proposals = []
    _cur = 0  # start index of the current level in the flattened memory
    for lvl, (H, W) in enumerate(spatial_shapes):
        # Recover the 2D padding mask of this level and count the valid
        # (non-padded) rows and columns per image.
        mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].view(
            N, H, W, 1)
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

        # Build a grid of pixel coordinates for this level.
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(
                0, H - 1, H, dtype=torch.float32, device=memory.device),
            torch.linspace(
                0, W - 1, W, dtype=torch.float32, device=memory.device))
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

        # Normalize the cell centers by the valid region of each image.
        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)],
                          1).view(N, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
        # Proposal width/height: 0.05 at the finest level, doubled per level.
        wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
        proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
        proposals.append(proposal)
        _cur += (H * W)

    output_proposals = torch.cat(proposals, 1)
    # A proposal is valid only if cx, cy, w, h all lie in (0.01, 0.99).
    output_proposals_valid = ((output_proposals > 0.01) &
                              (output_proposals < 0.99)).all(
                                  -1, keepdim=True)
    # Inverse sigmoid: proposals are kept in unactivated (logit) space.
    output_proposals = torch.log(output_proposals / (1 - output_proposals))
    output_proposals = output_proposals.masked_fill(
        memory_padding_mask.unsqueeze(-1), float('inf'))
    output_proposals = output_proposals.masked_fill(
        ~output_proposals_valid, float('inf'))

    # Zero out the memory at padded/invalid positions, then project it.
    output_memory = memory
    output_memory = output_memory.masked_fill(
        memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid,
                                              float(0))
    output_memory = self.enc_output_norm(self.enc_output(output_memory))
    return output_memory, output_proposals
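As a quick sanity check, here is a minimal standalone sketch (not part of mmdetection) of the same grid-plus-wh construction for a single toy level, assuming no padding so the valid region is the whole map:

import torch

# Toy single-level example: a 2x3 feature map, batch size 1, no padding.
N, H, W, lvl = 1, 2, 3, 0
grid_y, grid_x = torch.meshgrid(
    torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W))
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
# With no padding the valid region is the full (W, H).
scale = torch.tensor([W, H], dtype=torch.float32).view(1, 1, 1, 2)
centers = (grid.unsqueeze(0) + 0.5) / scale        # normalized cx, cy
wh = torch.ones_like(centers) * 0.05 * (2.0**lvl)  # base size 0.05
proposal = torch.cat((centers, wh), -1).view(N, -1, 4)
print(proposal.shape)  # torch.Size([1, 6, 4]): one box per grid cell
# First proposal: cx = 0.5/3, cy = 0.5/2, w = h = 0.05.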
class DeformableDetrTransformer(Transformer):

    def forward(self,
                mlvl_feats,
                mlvl_masks,
                query_embed,
                mlvl_pos_embeds,
                reg_branches=None,
                cls_branches=None,
                **kwargs):
        # Two-stage mode derives the queries itself; otherwise a learnable
        # query_embed must be provided.
        assert self.as_two_stage or query_embed is not None

        # Flatten multi-level features, masks and positional embeddings
        # into single (bs, sum(H*W), ...) sequences.
        feat_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (feat, mask, pos_embed) in enumerate(
                zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
            bs, c, h, w = feat.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            feat = feat.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            # Add a learnable per-level embedding so levels stay separable.
            lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            feat_flatten.append(feat)
            mask_flatten.append(mask)
        feat_flatten = torch.cat(feat_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=feat_flatten.device)
        # Start index of each level in the flattened sequence.
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack(
            [self.get_valid_ratio(m) for m in mlvl_masks], 1)

        reference_points = \
            self.get_reference_points(spatial_shapes,
                                      valid_ratios,
                                      device=feat.device)

        feat_flatten = feat_flatten.permute(1, 0, 2)  # (H*W, bs, embed_dims)
        lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(
            1, 0, 2)  # (H*W, bs, embed_dims)
        memory = self.encoder(
            query=feat_flatten,
            key=None,
            value=None,
            query_pos=lvl_pos_embed_flatten,
            query_key_padding_mask=mask_flatten,
            spatial_shapes=spatial_shapes,
            reference_points=reference_points,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            **kwargs)

        memory = memory.permute(1, 0, 2)
        bs, _, c = memory.shape
        if self.as_two_stage:
            # Turn the encoder memory into per-point proposals, then score
            # and refine them with the extra (num_layers-th) branches.
            output_memory, output_proposals = \
                self.gen_encoder_output_proposals(
                    memory, mask_flatten, spatial_shapes)
            enc_outputs_class = cls_branches[self.decoder.num_layers](
                output_memory)
            enc_outputs_coord_unact = \
                reg_branches[self.decoder.num_layers](
                    output_memory) + output_proposals

            # Keep the top-k scoring proposals as decoder references.
            topk = self.two_stage_num_proposals
            topk_proposals = torch.topk(
                enc_outputs_class[..., 0], topk, dim=1)[1]
            topk_coords_unact = torch.gather(
                enc_outputs_coord_unact, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()
            init_reference_out = reference_points
            # Derive query and query_pos from the proposal coordinates.
            pos_trans_out = self.pos_trans_norm(
                self.pos_trans(
                    self.get_proposal_pos_embed(topk_coords_unact)))
            query_pos, query = torch.split(pos_trans_out, c, dim=2)
        else:
            # One-stage: split the learnable embedding into pos/content.
            query_pos, query = torch.split(query_embed, c, dim=1)
            query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
            query = query.unsqueeze(0).expand(bs, -1, -1)
            reference_points = self.reference_points(query_pos).sigmoid()
            init_reference_out = reference_points

        # Decoder.
        query = query.permute(1, 0, 2)
        memory = memory.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=memory,
            query_pos=query_pos,
            key_padding_mask=mask_flatten,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=reg_branches,
            **kwargs)

        inter_references_out = inter_references
        if self.as_two_stage:
            return inter_states, init_reference_out, \
                inter_references_out, enc_outputs_class, \
                enc_outputs_coord_unact
        return inter_states, init_reference_out, \
            inter_references_out, None, None
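The key two-stage step above is picking the top-k proposals by foreground score and gathering their unactivated boxes. A minimal sketch on dummy tensors (the shapes are assumptions matching the code above, not mmdetection output):

import torch

bs, num_key, topk = 2, 10, 4
enc_outputs_class = torch.randn(bs, num_key, 1)        # per-point class logits
enc_outputs_coord_unact = torch.randn(bs, num_key, 4)  # logit-space boxes

# Indices of the k highest-scoring points per image.
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
# Gather their 4 box coordinates; detach so the decoder's reference
# points do not backpropagate into the proposal-selection stage.
topk_coords_unact = torch.gather(
    enc_outputs_coord_unact, 1,
    topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).detach()
reference_points = topk_coords_unact.sigmoid()  # back to (0, 1) boxes
print(reference_points.shape)  # torch.Size([2, 4, 4])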