头图

本篇博客是【深度学习个人笔记】Pytorch基本知识博客的扩展,采用里面介绍的Pytorch中的常用模块来搭建经典的VGG网络,使用的是Pytorch的官方代码,通过官方代码来体会如何用Pytorch搭建神经网络,并且学习如何保存神经网络的参数,并在GPU上进行训练
部分参考资料如下:
【Pytorch学习】-- 使用pytorch自带模型 -- VGG16
PyTorch中模型的parameters()方法浅析
使用pytorch搭建VGG网络 学习笔记
VGG网络的Pytorch官方实现过程解读
Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict()
pytorch_神经网络模型搭建系列(1):自定义神经网络模型
PyTorch中的model.modules(), model.children(), model.named_children(), model.parameters(), model.nam...
深入浅出Pytorch函数——torch.nn.init.normal_
typing -- 给你的 Python 加上类型注解
【python笔记-2】cls含义及使用方法

VGG基本介绍

VGG是非常经典的神经网络模型,其中比较常用的结构是VGG-16和VGG-19,16和19代表神经网络的层数,其具体网络结构如下表所示:

网络结构VGG-16VGG-19
模块1conv3-64
conv3-64
conv3-64
conv3-64
下采样maxpoolmaxpool
模块2conv3-128
conv3-128
conv3-128
conv3-128
下采样maxpoolmaxpool
模块3conv3-256
conv3-256
conv3-256
conv3-256
conv3-256
conv3-256
conv3-256
下采样maxpoolmaxpool
模块4conv3-512
conv3-512
conv3-512
conv3-512
conv3-512
conv3-512
conv3-512
下采样maxpoolmaxpool
模块5conv3-512
conv3-512
conv3-512
conv3-512
conv3-512
conv3-512
conv3-512
下采样maxpoolmaxpool
全连接层FC-4096FC-4096
全连接层FC-4096FC-4096
全连接层FC-1000FC-1000
softmax分类输出层softmaxsoftmax

上表中卷积层的参数表示为conv<卷积核大小>-<卷积核个数>,采用ReLU激活函数
可见VGG-16中包含了16个隐藏层(13个卷积层和3个全连接层),VGG-19中包含了19个隐藏层(16个卷积层和3个全连接层)
 
 
可以用可视化的形式来理解VGG的网络结构,以VGG-16为例:
VGG-16的可视化形式
上图中黑色边框代表卷积层和ReLU激活函数。红色边框代表最大池化,黄色边框代表softmax,蓝色边框代表全连接层和ReLU激活函数

Pytorch搭建VGG-16和VGG-19

在Pytorch中调用实现好的VGG-16和VGG-19

Pytorch已经实现了很多经典模型,VGG就是其中之一,同时现在我们使用VGG网络通常是采用预训练模型。在Pytorch中直接调用VGG可以用torchvision.models,并可以自己选择是否使用预训练的模型:

import torchvision
vgg16 = torchvision.models.vgg16(pretrained = False)
vgg16_pretrained = torchvision.models.vgg16(pretrained = True)
vgg19 = torchvision.models.vgg19(pretrained = False)
vgg19_pretrained = torchvision.models.vgg19(pretrained = True)

可以直接通过print查看网络结构和网络参数(将下面代码中的net替换为上面的模型变量即可):

print(net)
print(net.parameters())
print(net.named_parameters())

对于直接打印模型变量,输出结果如下(以上面的vgg16为例):

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

可见与之前表中的VGG-16的网络结构一致
 
而对于直接打印模型变量,输出结果如下(以上面的vgg16为例):

<generator object Module.named_parameters at 0x00000183F6A9EF90>
<generator object Module.parameters at 0x00000183F6A9EF90>

可见.parameters()方法和.named_parameters()方法会返回一个生成器(迭代器),可以直接将其转换为列表并打印出来(这样子做的结果就是一堆tensor),也可以用enumerate,只不过对于.named_parameters()方法而言返回的是len为2的tuple,第一个元素是name,第二个元素是name对应的值,如下所示:

for _,param in enumerate(net.named_parameters()):
    print(param[0])
    print(param[1])
    print('----------------')

for _,param in enumerate(net.parameters()):
    print(param)
    print('----------------')

Pytorch官方实现VGG的过程

如何参考Pytorch官方实现VGG的过程

在上一小节中,调用VGG模型使用了torchvision.models,实际上说的更具体些,是models里的vgg.py,具体在PC上的路径如下(其中的torch是自己在anaconda创建的环境名,需要根据实际情况修改):

Anaconda3\envs\torch\lib\site-packages\torchvision\models\vgg.py

下面就看一下Pytorch官方实现VGG-16和VGG-19的过程,学习一下Pytorch的用法,顺便学习一些里面用到的Python知识,值得一提的是,这里使用的是Python版本是3.9,Pytorch版本是1.12,CUDA版本是11.6,可以通过如下的代码获取:

print("Python version:", sys.version)
print("Pytorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)

在我的电脑上的运行结果如下:

Python version: 3.9.13 (main, Oct 13 2022, 21:23:06) [MSC v.1916 64 bit (AMD64)]
Pytorch version: 1.12.0
CUDA version: 11.6

 
 
 
下面就从头到尾来阅读一下vgg.py,为了阅读的方便,我将vgg.py分为几个部分进行说明,对应下面的几个标题

vgg.py的第一部分

在vgg.py中首先定义了一个列表,很显然列表里的内容对应着Pytorch实现的与VGG相关的不同模型种类,具体这些模型有什么区别可以等后面看到具体的类或函数再来深究(这里只关注和VGG-16和VGG-19相关的模型):

__all__ = [
    "VGG",
    "VGG11_Weights",
    "VGG11_BN_Weights",
    "VGG13_Weights",
    "VGG13_BN_Weights",
    "VGG16_Weights",
    "VGG16_BN_Weights",
    "VGG19_Weights",
    "VGG19_BN_Weights",
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19",
    "vgg19_bn",
]

 

vgg.py的第二部分

接下来在vgg.py中定义了VGG通用网络结构,这是非常标准的定义自己的神经网络模型的方法,通过继承nn.Module类并写好forward方法:

class VGG(nn.Module):
    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

可见Pytorch官方在forward方法中将VGG通用网络结构分为了三个部分:featuresavgpoolclassifier,其实这对应的是上面的VGG网络结构表中的模块5的后面的部分,即通过前面的卷积神经网络(不同的VGG类型,比如VGG-16和VGG-19对应的卷积神经网络结构不同)提取特征,通过自适应平均池化(nn.AdaptiveAvgPool2d)将特征图池化到固定尺寸7×7大小,然后将特征图展平为1维向量(torch.flatten(x, 1)),最后通过分类器,即三层全连接层输出对应类别数量维度的向量(这里默认是1000维,即在类的初始化中输入的num_classes: int = 1000)

上面的代码除了定义VGG通用网络结构外,还有关于权重初始化的部分,即在类的初始化方法中if init_weights:后面的内容,这里的判断条件init_weights默认为True(init_weights: bool = True),而self.modules()会返回一个迭代器,会迭代遍历模型的所有子层(即nn.Module的子类,这里VGG、features、avgpool、classifier、池化、Linear等都是nn.Module的子类);isinstance()用于判断一个对象是否是一个已知的类型,同时会认为子类是一种父类类型,考虑继承关系。

根据上面的分析,这里的if判断语句分为了三个部分:nn.Conv2d、nn.BatchNorm2d和nn.Linear,分别对这三部分进行初始化,初始化函数有三种,其功能分别如下:

(1)nn.init.kaiming_normal_:何恺明大神提出的参数初始化(论文:Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification)
(2)nn.init.constant_(tensor, val):用val的值填充整个张量
(3)nn.init.normal_(tensor, mean=0.0, std=1.0):从给定均值和标准差的正态分布中生成值,填充输入的张量

介绍了Pytorch搭建神经网络的相关知识,下面就介绍一下代码中有关Python的技巧:

(1)在初始化方法和forward方法传入的参数中,每一个都有一个冒号,比如num_classes: int = 1000,这些参数的冒号其实是参数的类型建议符,目的是告诉使用程序的人希望传入的实参的类型,类型建议符并非强制规定和检查,即使传入的实际参数与建议参数不符,也不会报错
(2)同时初始化方法和forward方法后面都跟着一个箭头,比如-> torch.Tensor,这是说明该方法返回的值是什么类型
(3)_log_api_usage_once(self)

 

vgg.py的第三部分

接下来在vgg.py中定义了函数make_layers和_vgg以及字典cfgs,这些都是构成后面具体网络结构比如VGG-16和VGG-19的基础:

def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    layers: List[nn.Module] = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs: Dict[str, List[Union[str, int]]] = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
    if weights is not None:
        kwargs["init_weights"] = False
        if weights.meta["categories"] is not None:
            _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))
    return model

可见make_layers函数是根据配置表返回模型层列表;而_vgg函数实现了整个网络模型的生成以及预训练权重的导入,有关这两个函数的讲解先放在后面(原因在下一段)。

由于上面的代码涉及到一些新的数据类型,这里先对这些进行讲解,即List、Union、Dict,这些数据类型来自typing模块。在前面vgg.py的第二部分的说明中提到了一些有关类型注解的说明,但是对于Python中的列表、元组和字典而言,其中包含元素的复合类型,仅仅使用简单的list、dict和tuple不能准确说明内部元素的具体类型,因此需要用到typing模块提供的复合注解的功能,具体的看一下下面给出的代码就可以理解了:

  • List

    import typing
    a: typing.List = []        # 表示参数 a 是一个 list 类型的参数
    a: typing.List[str] = ["string1", "string2"]    # 表示参数 a 是一个 list 类型的参数, list中的元素为 str 类型
  • Union:联合类型, Union[X, Y] 等价于 X | Y ,意味着满足 X 或 Y 之一

    a: typing.Union[int, str] = 100        # 表示 a 是 int 或者 str 类型,二选一
  • Dict

    import typing
    a: typing.Dict = {}        # 表示参数 a 是一个 dict 类型的参数
    a: typing.Dict[str, int] = {"string1": 10}        # 表示参数 a 是一个 dict 类型的参数, dict 中 key 为 str 类型, value 为 int 类型

    综上所述,cfgs就是一个字典,只不过其里面的类型需要符合定义,"A"、"B"、"D"、"E"就是不同的VGG网络配置对应的代号,这个在VGG论文中可以看到,我们需要的VGG-16和VGG-19分别是D和E:
    不同VGG模型对应的字母
    那么字典中每一个键对应的值的意义其实也就确定了,就是网络结构,这里用列表列出是为了用函数生成VGG网络中的卷积部分,可以参考上面的具体网络结构表
     
    接下来就具体说明一下两个函数(make_layers和_vgg)的作用:

  • make_layers函数:layers就是一个列表,列表里的内容就是配置表对应的神经网络的各个层;cfg可以根据它的类型发现就是具体的配置表,表明有几层卷积层和池化层,如果cfg中的某一个元素是"M",则需要在layers中加入最大池化层;如果cfg中的某一个元素是int类型的,则需要在layers中加入卷积层,cfg中的元素就是输出的通道数;同时还需要对batch_norm进行判断,以决定是否在layers中加入Batch Normalization
  • _vgg函数:根据配置表调用make_layers来创建卷积神经网络的部分,然后作为参数(即features)输入到VGG通用网络结构,同时如果输入的weights参数不为None,则使用.load_state_dict方法将权重加载到模型中(progress表示是否显示下载进度条)

在_vgg函数中也使用了一些typing模块中的一些东西,即Any、Optional和cast,下面对这些进行说明:

  • Any:静态类型检查器视Any与任何类型兼容
  • Optional:可选类型,作用几乎和带默认值的参数等价,不同的是使用Optional可以表示参数除了给定的默认值外还可以是None
  • cast:在编写代码时强制指定变量的类型,这个函数将返回一个被转换成指定类型的变量

最后也介绍一下这一段代码使用的Python技巧:

  • 在make_layers函数的返回值中,其返回中传入的参数带有*,即nn.Sequential(*layers),后面的_vgg函数的传入参数中带有**,即**kwargs,这其实与Python代码中常见的*args和**kwargs有关

    • *args:arguments,表示位置参数
    • **kwargs:keyword arguments,表示关键字参数
    • 这两种参数是Python中可变参数的两种形式,且*args必须放在**kwargs的前面,*args传递了一个可变参数元组给函数实参,这个参数列表的数目未知,长度可以为0;**kwargs将一个可变的关键字参数的字典传给函数实参,同样参数列表长度可以为0或为其他值
    • 单星号还可以解压参数列表,如下面的程序所示:
    def foo(runoob_1, runoob_2):
        print(runoob_1, runoob_2)
      l = [1, 2]
      foo(*l)

vgg.py的第四部分

vgg.py的第四部分定义了一个字典_COMMON_META以及一些以_Weights为结尾的类,这些类是原作者将模型预训练权重和其他配置等打包好,放在该类中提供给我们使用,因此如果需要自己实现模型一般不用这些类(这里仅对VGG-16和VGG-19相关的进行说明):

_COMMON_META = {
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
    "_docs": """These weights were trained from scratch by using a simplified training recipe.""",
}
class VGG16_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg16-397923af.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 138357544,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.592,
                    "acc@5": 90.382,
                }
            },
        },
    )
    IMAGENET1K_FEATURES = Weights(
        # Weights ported from https://github.com/amdegroot/ssd.pytorch/
        url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            mean=(0.48235, 0.45882, 0.40784),
            std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
        ),
        meta={
            **_COMMON_META,
            "num_params": 138357544,
            "categories": None,
            "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": float("nan"),
                    "acc@5": float("nan"),
                }
            },
            "_docs": """
                These weights can't be used for classification because they are missing values in the `classifier`
                module. Only the `features` module has valid values and can be used for feature extraction. The weights
                were trained using the original input standardization method as described in the paper.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class VGG16_BN_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 138365992,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.360,
                    "acc@5": 91.516,
                }
            },
        },
    )
    DEFAULT = IMAGENET1K_V1


class VGG19_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 143667240,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.376,
                    "acc@5": 90.876,
                }
            },
        },
    )
    DEFAULT = IMAGENET1K_V1


class VGG19_BN_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 143678248,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.218,
                    "acc@5": 91.842,
                }
            },
        },
    )
    DEFAULT = IMAGENET1K_V1

首先_COMMON_META定义了一些基本信息,而后面几种类继承自WeightsEnum,Pytorch中的每个模型的权重类都继承于该类,代码中的Weights也是一个类,该类存储模型的信息,有三个参数:

  • url:预训练权重下载地址(str)
  • transfomer:模型的预处理方法(callable)
  • meta:Stores meta-data related to the weights of the model and its configuration. These can be informative attributes (for example the number of parameters/flops, recipe link/methods used in training etc), configuration parameters (for example the num_classes) needed to construct the model or important meta-data (for example the classes of a classification model) needed to use the model.

因此,这一部分的vgg.py主要就是定义一些预训练权重类和它们的基本信息,方便搭建VGG网络使用

vgg.py的第五部分

vgg.py的第五部分定义了构建模型的函数,在实际使用时,我们直接调用第五部分的函数来完成VGG网络搭建就可以了,同样这里只关注VGG-16和VGG-19相关的函数:

@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG16_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG16_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG16_Weights
        :members:
    """
    weights = VGG16_Weights.verify(weights)

    return _vgg("D", False, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG16_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG16_BN_Weights
        :members:
    """
    weights = VGG16_BN_Weights.verify(weights)

    return _vgg("D", True, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG19_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG19_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG19_Weights
        :members:
    """
    weights = VGG19_Weights.verify(weights)

    return _vgg("E", False, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-19_BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG19_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG19_BN_Weights
        :members:
    """
    weights = VGG19_BN_Weights.verify(weights)

    return _vgg("E", True, weights, progress, **kwargs)

可见这些函数都是类似的,首先每个函数的开头都有一个@handle_legacy_interface处理旧接口,这里指定如果使用pretrained参数,将其映射到相应的权重,这是为了兼容之前的接口;其次调用了两个东西,先调用第四部分中的以_Weights为结尾的类来生成权重,然后调用第三部分中的_vgg函数生成整个网络模型,而对于带_bn的模型的生成,调用_vgg函数的第二个参数(batch_norm)传入为True,否则为False

另外,这里调用以_Weights为结尾的类时使用了.verify方法,可以看一下源码:

    @classmethod
    def verify(cls, obj: Any) -> Any:
        if obj is not None:
            if type(obj) is str:
                obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
            elif not isinstance(obj, cls):
                raise TypeError(
                    f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
                )
        return obj

其中有一个cls参数,在Python中用于表示一个类本身,在类方法中被使用,用于引用调用该方法的类。类方法是一种特殊的方法,它与类相关联,而不是与实例相关联,这种方法在定义时需要使用@classmethod装饰器来标识:

  • 在类方法中,第一个参数通常被命名为cls,这是一个惯例,但不是强制要求,使用cls参数可以在类方法中访问和修改类的属性,或者创建类的新实例,可以通过以下的程序进行理解:

    class Person(object):
      def __init__(self, name, age):
          self.name = name
          self.age = age
          print('self:', self)
    
      # 定义一个build方法,返回一个person实例对象,这个方法等价于Person()
      @classmethod
      def build(cls):
          # cls()等于Person()
          p = cls("Tom", 18)
          print('cls:', cls)
          return p
    
    
    if __name__ == '__main__':
      person = Person.build()
      print(person, person.name, person.age)

最后也来说一下这一部分使用的Python技巧:

  • 在这一部分调用的函数中第一个参数都为*,比如def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:中的第一个参数,其实在Python中在定义函数的时候可以有/和*这两种参数,下面对这两种参数的作用进行说明:

    • /:/符号之前的参数都只能使用位置参数而不能使用关键字参数传参(positional-only),可以参考如下的程序:
    def f1(a, b, /):
      return a + b

    调用f1时参数a、b只能使用特定的值,而不能以关键字传参,即f1(2, 3)执行正确,但是f1(a=2, 3)和f1(2, b=3)执行错误

    • *:*符号之后的参数只能用关键字参数的形式传参(keyword-only),可以参考如下的程序:
    def f1(a, *, b, c):
      return a + b + c

    调用f1时参数a可以任意值,但是b和c一定要以关键字参数的形式传参,比如f1(1, b=4, c=5)


Kazusa
5 声望7 粉丝

梦里不觉秋已深,余情岂是为他人


引用和评论

0 条评论