(即插即用模块

(即插即用模块

文章目录

1、Cross-Attention Fusion Module2、代码实现

paper:Attention Multihop Graph and Multiscale Convolutional Fusion Network for Hyperspectral Image Classification

Code:https://github.com/EdwardHaoz/IEEE_TGRS_AMGCFN

1、Cross-Attention Fusion Module

在现有的CNN 和 GCN中,存在一些局限性: 即虽然CNN和GCN都能有效提取特征,但它们分别侧重于像素级和超像素级信息。直接融合两者得到的特征往往不够充分,难以有效提升分类性能。此外,现有融合方法也存在着一些不足: 现有的融合方法大多采用简单的加权融合,缺乏对特征重要性的考虑,无法有效地突出重要特征,导致融合效果不佳。所以这篇论文提出一种 交叉注意力融合模块(Cross-Attention Fusion Module)。

CAFM 的基本原理是通过交叉注意力机制,将 PMCsN 和 MGCsN 提取的特征进行交互和融合,以获得更具判别力的特征。

CAFM 包含两个部分:通道注意力交叉模块和空间注意力融合模块。其具体实现过程如下:

通道注意力交叉模块:首先对两个子网络的特征分别进行全局最大池化和平均池化,得到两个通道描述。其中,全局最大池化操作会提取每个通道的最大值,而全局平均池化操作会提取每个通道的平均值,从而分别得到两个不同的通道描述。

然后将两个通道描述输入到一个共享的两层神经网络,该神经网络包含一个 ReLU 激活函数。通过两层神经网络,得到两个通道权重系数。再将两个通道权重系数相乘,得到一个交叉矩阵。最后将交叉矩阵分别与两个子网络的特征相乘,得到融合后的通道特征。

空间注意力融合模块:在空间层面,首先对两个子网络的特征分别进行最大池化和平均池化,得到两个空间描述。最大池化操作会提取每个像素的最大值,而平均池化操作会提取每个像素的平均值,从而分别得到两个不同的空间描述。

然后将两个空间描述在通道维度进行拼接,得到一个新的特征图。再将拼接后的特征图输入到一个共享的卷积层。再用一个卷积层学习空间特征,并得到空间权重系数。最后将空间权重系数分别与两个子网络的特征相乘,得到融合后的空间特征。

残差连接:最后将融合后的特征与输入特征进行残差连接,得到最终的融合特征。残差连接可以增强网络的鲁棒性,并有助于网络学习更深层次的特征。

Cross-Attention Fusion Module 结构图:

2、代码实现

import torch

import torch.nn as nn

import torch.nn.functional as F

from einops.einops import rearrange

class CAFM(nn.Module): # Cross Attention Fusion Module

def __init__(self, channels):

super(CAFM, self).__init__()

self.conv1_spatial = nn.Conv2d(2, 1, 3, stride=1, padding=1, groups=1)

self.conv2_spatial = nn.Conv2d(1, 1, 3, stride=1, padding=1, groups=1)

self.avg1 = nn.Conv2d(channels, 64, 1, stride=1, padding=0)

self.avg2 = nn.Conv2d(channels, 64, 1, stride=1, padding=0)

self.max1 = nn.Conv2d(channels, 64, 1, stride=1, padding=0)

self.max2 = nn.Conv2d(channels, 64, 1, stride=1, padding=0)

self.avg11 = nn.Conv2d(64, channels, 1, stride=1, padding=0)

self.avg22 = nn.Conv2d(64, channels, 1, stride=1, padding=0)

self.max11 = nn.Conv2d(64, channels, 1, stride=1, padding=0)

self.max22 = nn.Conv2d(64, channels, 1, stride=1, padding=0)

def forward(self, f1, f2):

b, c, h, w = f1.size()

f1 = f1.reshape([b, c, -1])

f2 = f2.reshape([b, c, -1])

avg_1 = torch.mean(f1, dim=-1, keepdim=True).unsqueeze(-1)

max_1, _ = torch.max(f1, dim=-1, keepdim=True)

max_1 = max_1.unsqueeze(-1)

avg_1 = F.relu(self.avg1(avg_1))

max_1 = F.relu(self.max1(max_1))

avg_1 = self.avg11(avg_1).squeeze(-1)

max_1 = self.max11(max_1).squeeze(-1)

a1 = avg_1 + max_1

avg_2 = torch.mean(f2, dim=-1, keepdim=True).unsqueeze(-1)

max_2, _ = torch.max(f2, dim=-1, keepdim=True)

max_2 = max_2.unsqueeze(-1)

avg_2 = F.relu(self.avg2(avg_2))

max_2 = F.relu(self.max2(max_2))

avg_2 = self.avg22(avg_2).squeeze(-1)

max_2 = self.max22(max_2).squeeze(-1)

a2 = avg_2 + max_2

cross = torch.matmul(a1, a2.transpose(1, 2))

a1 = torch.matmul(F.softmax(cross, dim=-1), f1)

a2 = torch.matmul(F.softmax(cross.transpose(1, 2), dim=-1), f2)

a1 = a1.reshape([b, c, h, w])

avg_out = torch.mean(a1, dim=1, keepdim=True)

max_out, _ = torch.max(a1, dim=1, keepdim=True)

a1 = torch.cat([avg_out, max_out], dim=1)

a1 = F.relu(self.conv1_spatial(a1))

a1 = self.conv2_spatial(a1)

a1 = a1.reshape([b, 1, -1])

a1 = F.softmax(a1, dim=-1)

a2 = a2.reshape([b, c, h, w])

avg_out = torch.mean(a2, dim=1, keepdim=True)

max_out, _ = torch.max(a2, dim=1, keepdim=True)

a2 = torch.cat([avg_out, max_out], dim=1)

a2 = F.relu(self.conv1_spatial(a2))

a2 = self.conv2_spatial(a2)

a2 = a2.reshape([b, 1, -1])

a2 = F.softmax(a2, dim=-1)

f1 = f1 * a1 + f1

f2 = f2 * a2 + f2

f1 = f1.squeeze(0)

f2 = f2.squeeze(0)

return f1.transpose(0, 1), f2.transpose(0, 1)

if __name__ == '__main__':

"""

本来CAFM的输入通道是固定的128,我在这里加了个参数

CAFM 的结果有两个,并且维度顺序是乱的,可以先相加,再调维度顺序

"""

H, W = 7, 7

x = torch.randn(4, 512, 7, 7).cuda()

y = torch.randn(4, 512, 7, 7).cuda()

model = CAFM(512).cuda()

out_1,out_2 = model(x,y)

out = out_1 + out_2

out = out.permute(1, 2, 0)

out = rearrange(out, 'b (h w) c -> b c h w', h=H, w=W)

print(out.shape)

相关内容

世界杯爆笑图片合集
365提款需要多久

世界杯爆笑图片合集

⌛ 07-03 👁️ 9406
70首适合在婚礼上唱的歌曲推荐
365提款需要多久

70首适合在婚礼上唱的歌曲推荐

⌛ 07-26 👁️ 6303
每千克水相当于多少毫升?
365娱乐场体育投注

每千克水相当于多少毫升?

⌛ 07-19 👁️ 2324