请问在 MinkowskiEngine 中如何实现 PyTorch 的 view()、torch.bmm 和 torch.nn.Parameter()？

新手上路,请多包涵

我最近在用 MinkowskiEngine 给 ResNet 添加注意力模块，其中 NonLocal 注意力模块的 PyTorch 实现代码如下：

class NonLocalModule(nn.Module):
    """Non-local (self-attention) block for ``(batch, channels, length)`` inputs.

    Three bias-free 1x1 ``Conv1d -> BatchNorm1d -> ReLU`` branches project the
    input to ``C // latent`` channels: ``cov1`` (query), ``cov2`` (key) and
    ``cov3`` (value). The row-wise softmax of the query/key product gives one
    ``(n, n)`` attention matrix per sample. The values are averaged with it,
    and ``out_conv`` maps the result back to ``C`` channels. The block returns
    ``gamma * out + x``. ``gamma`` is a learnable scalar initialised to zero,
    so the block starts out as the identity mapping.

    Args:
        C: number of input (and output) channels.
        latent: channel reduction factor used for query/key/value.
    """

    def __init__(self, C, latent=8):
        super().__init__()
        reduced = C // latent
        self.inputChannel = C
        self.latentChannel = reduced

        # The normalisation layers are exposed as attributes and are also the
        # very same objects placed inside the Sequential branches below.
        self.bn1 = nn.BatchNorm1d(reduced)
        self.bn2 = nn.BatchNorm1d(reduced)
        self.bn3 = nn.BatchNorm1d(reduced)
        self.bn4 = nn.BatchNorm1d(C)

        def branch(c_in, c_out, norm):
            # bias-free 1x1 convolution -> supplied norm layer -> ReLU
            conv = nn.Conv1d(
                in_channels=c_in, out_channels=c_out, kernel_size=1, bias=False
            )
            return nn.Sequential(conv, norm, nn.ReLU())

        self.cov1 = branch(C, reduced, self.bn1)
        self.cov2 = branch(C, reduced, self.bn2)
        self.cov3 = branch(C, reduced, self.bn3)
        self.out_conv = branch(reduced, C, self.bn4)

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * out_conv(attended values) + x`` (same shape as ``x``)."""
        # Unpacking the shape also rejects inputs that are not 3-D.
        batch, _, length = x.shape

        def project(layer):
            # Run one projection branch and pin its output to (b, c', n).
            return layer(x).view(batch, -1, length)

        query = project(self.cov1).permute(0, 2, 1)  # (b, n, c')
        key = project(self.cov2)  # (b, c', n)
        weights = self.softmax(torch.bmm(query, key))  # (b, n, n), rows sum to 1

        value = project(self.cov3)  # (b, c', n)
        # Column i of the result is sum_j weights[:, i, j] * value[:, :, j].
        context = torch.bmm(value, weights.permute(0, 2, 1))  # (b, c', n)

        return self.gamma * self.out_conv(context) + x

nn.BatchNorm1d、nn.Conv1d 和 nn.ReLU 在 MinkowskiEngine 中都有对应的实现，但我在 MinkowskiEngine 官方文档中没有找到 PyTorch 的 view()、torch.bmm 和 torch.nn.Parameter() 该如何实现。有人知道怎么把上述代码改成 MinkowskiEngine 版本吗？

阅读 3k
1 个回答

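大致思路：MinkowskiEngine 的模块（如 ME.MinkowskiNetwork）本身就是 torch.nn.Module 的子类，所以 torch.nn.Parameter 可以直接照常使用；真正需要改写的是 view() 和 torch.bmm，它们要在稀疏张量的特征张量上手动处理：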

import MinkowskiEngine as ME
import torch

class NonLocalModule(ME.MinkowskiNetwork):
    """Non-local (self-attention) block for MinkowskiEngine sparse tensors.

    This is a sparse port of the dense ``NonLocalModule``:

    * the 1x1 ``Conv1d`` projections become ``MinkowskiLinear`` layers, a
      per-point linear map on the feature matrix;
    * ``view`` + ``torch.bmm`` are replaced by attention computed separately
      over each batch element's points;
    * ``gamma`` stays an ordinary ``torch.nn.Parameter``. MinkowskiEngine
      modules subclass ``torch.nn.Module``, so parameters work unchanged.

    Written against the MinkowskiEngine 0.5 API (``SparseTensor.F``,
    ``coordinate_map_key``, ``coordinate_manager``,
    ``decomposition_permutations``).

    Args:
        C: number of input (and output) feature channels.
        latent: channel reduction factor for the query/key/value projections.
        D: spatial dimension of the sparse tensors; it is forwarded to
            ``ME.MinkowskiNetwork``, whose constructor requires it.

    Raises:
        ValueError: if ``C // latent`` is smaller than 1.
    """

    def __init__(self, C, latent=8, D=3):
        # MinkowskiNetwork.__init__ takes the spatial dimension D.
        super().__init__(D)
        reduced = C // latent
        if reduced < 1:
            raise ValueError(
                f"C // latent must be at least 1 (got C={C}, latent={latent})"
            )
        self.inputChannel = C
        self.latentChannel = reduced

        def branch(c_in, c_out):
            # Per-point linear map (sparse analogue of a bias-free 1x1 Conv1d)
            # -> batch norm -> ReLU. These layers only transform features.
            return torch.nn.Sequential(
                ME.MinkowskiLinear(c_in, c_out, bias=False),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiReLU(),
            )

        self.cov1 = branch(C, reduced)  # query
        self.cov2 = branch(C, reduced)  # key
        self.cov3 = branch(C, reduced)  # value
        self.out_conv = branch(reduced, C)

        # Zero init makes the block start out as the identity mapping.
        self.gamma = torch.nn.Parameter(torch.zeros(1))
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply non-local attention within each batch element of ``x``.

        Args:
            x: ``ME.SparseTensor`` with ``C`` feature channels.

        Returns:
            ``ME.SparseTensor`` on the same coordinates as ``x``. Its features
            are ``gamma * out_conv(attention) + x.F``.
        """
        # All samples' points are stacked into one (N, C') feature matrix.
        # Row order matches x.F because the branches only transform features.
        query = self.cov1(x).F
        key = self.cov2(x).F
        value = self.cov3(x).F

        # There is no (b, c, n) layout to view(). Instead, gather each
        # sample's rows and run the attention (the torch.bmm step) per sample.
        perms = x.decomposition_permutations  # one row-index tensor per sample
        chunks = []
        for rows in perms:
            # (n_i, n_i) attention matrix; memory grows quadratically in n_i.
            weights = self.softmax(query[rows] @ key[rows].T)
            chunks.append(weights @ value[rows])  # (n_i, C')
        # Scatter the per-sample results back into x's row order.
        context = torch.zeros_like(value).index_copy(
            0, torch.cat(perms), torch.cat(chunks)
        )

        def as_sparse(features):
            # Reuse x's coordinates so the features stay aligned with the input.
            return ME.SparseTensor(
                features=features,
                coordinate_map_key=x.coordinate_map_key,
                coordinate_manager=x.coordinate_manager,
            )

        out = self.out_conv(as_sparse(context)).F
        return as_sparse(self.gamma * out + x.F)
撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题