我最近在使用 MinkowskiEngine 给 ResNet 添加注意力模块，其中 NonLocal 注意力模块的 PyTorch 实现代码如下：
class NonLocalModule(nn.Module):
    """Non-local (self-attention) block over a batch of 1-D feature sequences.

    Query, key and value projections are bias-free 1x1 ``Conv1d`` layers that
    reduce the channel count from ``C`` to ``C // latent``, each followed by
    BatchNorm and ReLU. An ``n x n`` attention matrix is computed from the
    query/key projections and used to aggregate the value projection. The
    result is projected back to ``C`` channels (Conv1d -> BatchNorm -> ReLU),
    scaled by the learnable scalar ``gamma`` and added to the input.

    Args:
        C: number of input (and output) channels.
        latent: channel reduction factor for the query/key/value projections.

    Raises:
        ValueError: if ``latent < 1`` or ``C // latent < 1``. In that case the
            projections would have no channels and the block would silently
            degenerate.

    Shape:
        input ``x``: ``(b, C, n)``; output: ``(b, C, n)``.
    """

    def __init__(self, C, latent=8):
        super().__init__()
        if latent < 1:
            raise ValueError(f"latent must be >= 1, got {latent}")
        if C // latent < 1:
            raise ValueError(
                f"C // latent must be >= 1, got C={C}, latent={latent}")
        self.inputChannel = C
        self.latentChannel = C // latent
        # The BN layers are registered both as direct attributes and inside
        # the Sequential containers (index 1). This keeps the state_dict keys
        # (bn1.*, cov1.1.*, ...) identical to the original layout, so existing
        # checkpoints keep loading.
        self.bn1 = nn.BatchNorm1d(self.latentChannel)
        self.bn2 = nn.BatchNorm1d(self.latentChannel)
        self.bn3 = nn.BatchNorm1d(self.latentChannel)
        self.bn4 = nn.BatchNorm1d(C)
        self.cov1 = self._conv_bn_relu(C, self.latentChannel, self.bn1)  # query
        self.cov2 = self._conv_bn_relu(C, self.latentChannel, self.bn2)  # key
        self.cov3 = self._conv_bn_relu(C, self.latentChannel, self.bn3)  # value
        self.out_conv = self._conv_bn_relu(self.latentChannel, C, self.bn4)
        # Zero-initialised, so the block starts out as the identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, bn):
        """Build a bias-free 1x1 Conv1d -> given BatchNorm -> ReLU stack."""
        return nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=1, bias=False),
            bn,
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply non-local attention to ``x`` of shape ``(b, C, n)``.

        Returns ``gamma * out_conv(attended values) + x`` with shape
        ``(b, C, n)``.

        Raises:
            ValueError: if ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (b, C, n), got {tuple(x.shape)}")
        # Conv1d already yields (b, channels, n), so no view/reshape is needed.
        query = self.cov1(x).permute(0, 2, 1)  # (b, n, c/latent)
        key = self.cov2(x)  # (b, c/latent, n)
        # attention_matrix[:, i, j] is softmax over j of <query_i, key_j>.
        attention_matrix = self.softmax(torch.bmm(query, key))  # (b, n, n)
        value = self.cov3(x)  # (b, c/latent, n)
        # Output position i = sum_j attention_matrix[i, j] * value[:, :, j],
        # i.e. value @ attention_matrix^T.
        attention = torch.bmm(value, attention_matrix.permute(0, 2, 1))
        out = self.out_conv(attention)  # (b, c, n)
        return self.gamma * out + x
`nn.BatchNorm1d`、`nn.Conv1d` 和 `nn.ReLU` 在 MinkowskiEngine 中都有对应的实现版本。但我查阅了 MinkowskiEngine 官方文档，没有找到 PyTorch 中的 `Tensor.view()`、`torch.bmm` 和 `torch.nn.Parameter()` 在 MinkowskiEngine 中该如何实现。有人知道如何把上述代码改写成 MinkowskiEngine 版本吗？
大概思路：