Article 1 - What is torch.fx

Today, let's talk about torch.fx, an important tool, and take this opportunity to organize my earlier torch.fx notes. The notes fall roughly into three parts, corresponding to three articles:

  • What is torch.fx
  • Quantization based on torch.fx
  • Deploying to TensorRT based on torch.fx quantization

This is the first of the three articles; it mainly introduces torch.fx and its basic usage. Without further ado, let's get started!

What is torch.fx

torch.fx is a toolkit (library) introduced in PyTorch 1.8 for Python-to-Python code transformation. In effect, it lets you transform the Python forward code of a PyTorch model into whatever you want. The official introduction is as follows:

We apply this principle in torch.fx, a program capture and
transformation library for PyTorch written entirely in Python and optimized for high developer productivity by ML practitioners
The quote above is from the FX paper, TORCH.FX: PRACTICAL PROGRAM CAPTURE AND TRANSFORMATION FOR DEEP LEARNING IN PYTHON, which you can read if you are interested. There is also a good interpretation of it on Zhihu, so I will not repeat it here. This article does touch on the paper's content, but more from a practical point of view.

The core keywords are program capture and transformation library; these two concepts are very important.

So how do we use FX? To make it concrete, let's define a torch.nn.Module:

import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

This is very simple: the class inherits from torch.nn.Module (anyone familiar with PyTorch will recognize this), and its forward function records the concrete computation the module performs.

Suppose we want to replace the clamp in this forward logic, self.linear(x + self.param).clamp(min=0.0, max=1.0), with sigmoid. What should we do?

Of course, you can change the code directly, but if there are many such operations, if you have written many modules, or if you want to run many experiments (changing some modules and leaving others alone), this quickly becomes cumbersome.

This is where FX comes in. Instead of modifying the code by hand (that is, rewriting forward ourselves), we define the rules, pass the model instance to torch.fx, and run the code; the forward of MyModule then becomes self.linear(x + self.param).sigmoid(). The first step is to capture the module with symbolic tracing:

from torch.fx import symbolic_trace

module = MyModule()

# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
# Print FX's IR
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
# The code generated by FX; it can be viewed as the module's forward code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

In this way, FX modifies the Module for you, and the modified model can be used as usual. Note that FX captures the forward code you wrote, then transforms it by modifying its operations.

Of course, this is only a very simple use of FX. With FX we can also:

  • Fuse two ops, such as conv and bn
  • Remove some ops
  • Replace some ops
  • Insert ops (or other operations) after certain ops

And so on.
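To make "remove some ops" concrete, here is a small sketch under an assumed toy module (Toy and remove_double_relu are hypothetical names) whose forward applies relu twice. The pass reroutes the users of the redundant outer relu to the inner one and then erases the now-dead node.

```python
import torch
import torch.fx


class Toy(torch.nn.Module):
    # Hypothetical module: the second relu is redundant.
    def forward(self, x):
        return torch.relu(torch.relu(x)) + 1


def remove_double_relu(m: torch.nn.Module) -> torch.fx.GraphModule:
    gm = torch.fx.symbolic_trace(m)
    for node in list(gm.graph.nodes):
        # Look for relu(relu(...)): a relu whose input is also a relu
        if node.op == "call_function" and node.target == torch.relu:
            inp = node.args[0]
            if isinstance(inp, torch.fx.Node) and inp.target == torch.relu:
                # Reroute every user of the outer relu to the inner one,
                # then drop the outer relu, which no longer has users.
                node.replace_all_uses_with(inp)
                gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


gm = remove_double_relu(Toy())
print(gm.code)
```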

You may notice that these operations closely resemble the passes in an AI compiler, and the object they operate on is likewise the DAG (directed acyclic graph) of a neural network. In fact, FX can also be understood as a compiler, but what it finally produces is not an executable: it goes from Python to Python, and the final product is still Python code that follows PyTorch's rules. This is why FX keeps describing itself as a Python-to-Python (or Module-to-Module) transformation toolkit rather than a compiler.

At present, most of FX's APIs are stable (FX was declared stable in torch 1.10), so there is little legacy baggage when using it.


The relationship between torch.fx and quantization

The first advantage of FX is PyTorch-based quantization tooling, which is one of the reasons I am introducing FX. With FX, quantizing a PyTorch model is very convenient. SenseTime previously released MQBench, a quantization tool based on FX.

For quantization, whether it is PTQ (which needs observer ops inserted to collect the activation and weight distributions of each layer) or QAT (which needs fake-quantization nodes inserted to simulate quantization), FX's functionality is involved. So if you want to do quantization on top of PyTorch, I recommend starting directly with torch.fx.

FX is already stable in PyTorch 1.10, and most of its APIs have stabilized. I have quantized several models with torch.fx and eventually deployed them on TensorRT; they involve convolution, BN, deconvolution, add, concat and other basic operations. The versions used were PyTorch 1.10 and TensorRT 8.2.

I modified the FX source and added some ops. Here I take the FX part directly from the latest PyTorch release and then pip install torch-1.10.0+cu113-cp38-cp38-linux_x86_64.whl, using the two together.

Difference from TorchScript

When torch.fx first appeared, I also wondered how it differs from TorchScript. Both first analyze the model, then generate an IR, then optimize based on that IR, and finally produce an optimized model. Is one simply the Python version and the other the C++ version? It is certainly not that simple. Once you use FX a lot, you will find that FX and TorchScript are positioned differently. FX focuses on making functional changes to the model (such as adding batching, modifying certain operations, adding statistics ops, or quantization), while TorchScript focuses on optimizing the performance of the current model and can run without Python, in a C++-only environment.

To borrow an official answer:

torch.fx is different from TorchScript in that it is a platform for Python-to-Python transformations of PyTorch code. TorchScript, on the other hand, is more targeted at moving PyTorch programs outside of Python for deployment purposes. In this sense, FX and TorchScript are orthogonal to each other, and can even be composed with each other (e.g. transform PyTorch programs with FX, then subsequently export to TorchScript for deployment).

The gist is that FX only does Python-to-Python conversion, unlike TorchScript, which converts for deployment (leaving the Python environment and running in C++). The two do not conflict: a model converted with FX can still be converted with TorchScript, so the two are orthogonal.
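As a sketch of that composition, assuming a small hypothetical Net: trace it with FX, then hand the resulting GraphModule to torch.jit.script like any other nn.Module. Only the trace step is shown here; a real FX transform would sit between the two steps.

```python
import torch
import torch.fx


class Net(torch.nn.Module):
    # Hypothetical example module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.relu(self.linear(x))


# Step 1: Python-to-Python with FX (a real pass would modify gm.graph here)
gm = torch.fx.symbolic_trace(Net())
# Step 2: the result is an ordinary nn.Module, so TorchScript can take over
scripted = torch.jit.script(gm)

x = torch.randn(2, 4)
print(torch.allclose(scripted(x), gm(x)))
```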

Python to Python?

Note that FX's code generation goes from Python to Python. In other words, the code FX generates is no different from a network we would normally build with nn.Module, and it runs directly in PyTorch's eager mode. TorchScript is different: it has its own runtime (when we run TorchScript, we actually invoke a VM, a virtual machine, which runs the exported TorchScript model in C++).

Therefore, a model converted by FX has the same type as nn.Module, so whatever you can do with an nn.Module you can also do with the converted model, and we can apply transformations one after another:

  • A Module you wrote yourself -> FX -> still a Module -> further FX transforms -> the final FX model
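A minimal sketch of such a chain, assuming a hypothetical helper swap_function that implements one FX pass: because each pass returns an nn.Module, the second pass can simply trace the output of the first.

```python
import torch
import torch.fx


def swap_function(m: torch.nn.Module, old, new) -> torch.fx.GraphModule:
    """Hypothetical helper: one FX pass that swaps a call_function target."""
    gm = torch.fx.symbolic_trace(m)  # m may itself be a GraphModule
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == old:
            node.target = new
    gm.recompile()
    return gm


class Net(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(torch.add(x, y))


step1 = swap_function(Net(), torch.add, torch.mul)       # Module -> GraphModule
step2 = swap_function(step1, torch.relu, torch.sigmoid)  # GraphModule -> GraphModule
print(isinstance(step1, torch.nn.Module), isinstance(step2, torch.nn.Module))
print(step2.code)
```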

IR for FX and IR for Jit

These two IRs are different. Compared with the JIT's IR, FX's IR has two advantages:

  • FX is tightly integrated into the Python runtime, so it captures program representations more accurately, whereas jit.trace sometimes fails.
  • An FX Graph is no different from a torch.nn.Module, and its IR is not as low-level, so it is easier to use and improves productivity.

Here is a brief list of FX's IR node types. It is very simple: there are only six, which roughly serve to call functions, fetch attributes, and represent inputs and outputs:

  • placeholder represents a function input. The name attribute specifies the name this value will take on. target is similarly the name of the argument. args holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. kwargs is don't-care. Placeholders correspond to the function parameters (e.g. x) in the graph printout.
  • get_attr retrieves a parameter from the module hierarchy. name is similarly the name the result of the fetch is assigned to. target is the fully-qualified name of the parameter's position in the module hierarchy. args and kwargs are don't-care.
  • call_function applies a free function to some values. name is similarly the name of the value to assign to. target is the function to be applied. args and kwargs represent the arguments to the function, following the Python calling convention.
  • call_module applies a module in the module hierarchy's forward() method to given arguments. name is as previous. target is the fully-qualified name of the module in the module hierarchy to call. args and kwargs represent the arguments to invoke the module on, excluding the self argument.
  • call_method calls a method on a value. name is similar. target is the string name of the method to apply to the self argument. args and kwargs represent the arguments to invoke the module on, including the self argument.
  • output contains the output of the traced function in its args[0] attribute. This corresponds to the "return" statement in the Graph printout.
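The six opcodes can be seen directly by walking the nodes of the earlier MyModule (repeated here so the block is self-contained). This module happens to exercise all six, in the same order as the graph printout at the start of the article.

```python
import torch
import torch.fx


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


gm = torch.fx.symbolic_trace(MyModule())
for node in gm.graph.nodes:
    # op is one of the six opcodes; what target means depends on the opcode
    print(f"{node.op:<14} name={node.name:<8} target={node.target}")

ops = [node.op for node in gm.graph.nodes]
```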

Compared with TorchScript's IR, FX's IR is much simpler, so it is easy to understand and use.

symbolic tracer

Back to the example at the beginning: one line reads symbolic_traced : torch.fx.GraphModule = symbolic_trace(module). The core here is the symbolic_trace function, the starting point from which FX parses and transforms a model. The function actually looks like this:

@compatibility(is_backward_compatible=True)
def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None,
                   enable_cpatching: bool = False) -> GraphModule:
    """
    Symbolic tracing API

    Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
    constructed by recording operations seen while tracing through ``root``.

    ...
    """
    tracer = Tracer(enable_cpatching=enable_cpatching)
    graph = tracer.trace(root, concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return GraphModule(tracer.root, graph, name)

It first creates a Tracer and then calls its member function trace on our torch.nn.Module. Once the model is traced, we can modify it:

import torch
import torch.nn as nn
import torch.fx

def transform(m: nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`
    # Trace model m with a Tracer instance.
    # This spells out what torch.fx.symbolic_trace does in one call.
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: modify `graph` freely here; this is the key part
    # (e.g. change node.target, insert or erase nodes)

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

The modified model can be used directly, or you can use graph_module.to_folder to export it as a standalone module (more on this later). The overall process is roughly:

symbolic tracing -> intermediate representation -> transforms -> Python code generation.

The respective functions are:

  • symbolic tracer
The symbolic tracer performs “symbolic execution” of the Python code. It feeds fake values, called Proxies, through the code. Operations on these Proxies are recorded. More information about symbolic tracing can be found in the symbolic_trace() and Tracer documentation.
  • intermediate representation
The intermediate representation is the container for the operations that were recorded during symbolic tracing. It consists of a list of Nodes that represent function inputs, callsites (to functions, methods, or torch.nn.Module instances), and return values. More information about the IR can be found in the documentation for Graph. The IR is the format on which transformations are applied.
  • Python code generation
Python code generation is what makes FX a Python-to-Python (or Module-to-Module) transformation toolkit. For each Graph IR, we can create valid Python code matching the Graph's semantics. This functionality is wrapped up in GraphModule, which is a torch.nn.Module instance that holds a Graph as well as a forward method generated from the Graph.

The above are the three core functions of FX.

Proxy/Retracing is the core of symbolic tracing. Since my understanding of Proxy/Retracing is not very deep, I will not describe it here to avoid misleading anyone. Here is the official introduction instead:

Proxy objects are Node wrappers that flow through the program during symbolic tracing and record all the operations (torch function calls, method calls, operators) that they touch into the growing FX Graph.

If you're doing graph transforms, you can wrap your own Proxy method around a raw Node so that you can use the overloaded operators to add additional things to a Graph.
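The official FX documentation illustrates Proxy/Retracing with a decomposition pass, and the sketch below follows that idea. It assumes torch.fx.proxy.GraphAppendingTracer is available (the official docs' example uses it), and the names relu_decomposition, decomposition_rules, decompose and Net are illustrative. The rule is ordinary tensor code, but because it receives Proxies, running it appends (x > 0) * x to the new graph instead of computing anything.

```python
import torch
import torch.fx
import torch.nn.functional as F
from torch.fx.proxy import GraphAppendingTracer


def relu_decomposition(x):
    # Plain tensor code; when x is a Proxy, each op is recorded as a node
    return (x > 0) * x


decomposition_rules = {F.relu: relu_decomposition}


def decompose(model: torch.nn.Module) -> torch.fx.GraphModule:
    graph = torch.fx.symbolic_trace(model).graph
    new_graph = torch.fx.Graph()
    tracer = GraphAppendingTracer(new_graph)  # Proxies built with it append to new_graph
    env = {}
    for node in graph.nodes:
        if node.op == "call_function" and node.target in decomposition_rules:
            # Wrap the already-copied inputs in Proxies and "run" the rule on them
            proxy_args = [torch.fx.Proxy(env[a.name], tracer) if isinstance(a, torch.fx.Node) else a
                          for a in node.args]
            env[node.name] = decomposition_rules[node.target](*proxy_args).node
        else:
            # Copy the node unchanged, remapping its inputs into new_graph
            env[node.name] = new_graph.node_copy(node, lambda n: env[n.name])
    return torch.fx.GraphModule(model, new_graph)


class Net(torch.nn.Module):
    def forward(self, x):
        return F.relu(x) + 1


gm = decompose(Net())
print(gm.code)
```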

Related structures

The main structures of FX are Graph and GraphModule. A Graph is a data structure that represents a method on a GraphModule. You can think of the Graph as storing the network's most important element, the Node: nodes are the operations in the network (such as convolution, relu, add, concat), and each node records its method together with its input and output information, so that the nodes can be strung together to form the network's logic.

Through print_tabular() you can print out the node information in the graph:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)
# Print the nodes of the module's graph (print_tabular needs the `tabulate` package)
gm.graph.print_tabular()

The printed output is as follows:

[Figure: the nodes in the graph, as printed by print_tabular()]

For the input x, the IR type is placeholder; for the weight, it is get_attr; the concrete operations (add, linear, relu, sum, topk) correspond to call_function, call_module and call_method (relu is invoked as a method here); and the final output corresponds to the output IR.

It also prints each node's inputs and extra arguments, and it is through these inputs that the nodes are connected to one another.
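If you want the connections rather than the table, each Node exposes its producers through all_input_nodes and its consumers through users. A small sketch on the same MyModule as above (repeated for self-containment); node names can differ slightly across PyTorch versions, so the printed names are illustrative.

```python
import torch
import torch.fx


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)


gm = torch.fx.symbolic_trace(MyModule())
for node in gm.graph.nodes:
    inputs = [n.name for n in node.all_input_nodes]  # nodes this node reads from
    users = [n.name for n in node.users]             # nodes that read this node's value
    print(f"{node.name:<14} inputs={inputs} users={users}")
```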

A graph alone is not enough, though; you also need a GraphModule. GraphModule inherits from torch.nn.Module and holds the parameters and submodules that the forward function needs, which the graph's nodes reference when they are called.

To sum up: the nodes in the graph carry the network's logic, FX reassembles the calling relationships between these nodes into the forward code of the GraphModule (which can be printed with traced.code), and that generated code relies on the parameters held by the GraphModule to execute.
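One consequence of this design is that the GraphModule only copies the attributes that get_attr and call_module nodes actually reference. A quick way to see this, assuming the same MyModule as above: its param is never used in forward, so it does not appear in the traced module's state_dict.

```python
import torch
import torch.fx


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))  # never used in forward
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)


m = MyModule()
gm = torch.fx.symbolic_trace(m)
print(sorted(m.state_dict().keys()))   # includes 'param'
print(sorted(gm.state_dict().keys()))  # only what get_attr/call_module nodes reference
print(gm.code)                         # the generated forward reads self.linear
```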

Modify Graph

Since the graph holds the network's sequential execution information, you can modify the network by modifying its nodes directly:

import torch
import torch.fx as fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph : fx.Graph = tracer_class().trace(m)
    # FX represents the network as an ordered sequence of nodes,
    # so we can simply iterate over it with a for loop:
    for node in graph.nodes:
        # Check whether this node's IR type is call_function
        if node.op == 'call_function':
            # Change node.target to torch.mul; the network changes accordingly
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # Does some checks to make sure the
                 # Graph is well-formed.

    return fx.GraphModule(m, graph)

Briefly, node.target indicates what a call_function node calls. torch.add is a built-in PyTorch op, so when this node runs it actually calls torch.add.
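A short usage sketch of this transform: M and transform are condensed from the code above so the block runs on its own. The rewritten module now multiplies its inputs. fx.GraphModule(m, graph) generates forward from the graph when it is constructed, so no extra recompile() is needed here.

```python
import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)


def transform(m: torch.nn.Module, tracer_class: type = fx.Tracer) -> torch.nn.Module:
    graph: fx.Graph = tracer_class().trace(m)
    for node in graph.nodes:
        if node.op == "call_function" and node.target == torch.add:
            node.target = torch.mul
    graph.lint()
    return fx.GraphModule(m, graph)


transformed = transform(M())
x, y = torch.tensor([2.0, 3.0]), torch.tensor([4.0, 5.0])
print(transformed(x, y))  # elementwise product of x and y
print(transformed.code)
```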

Elegantly modify the graph network

Direct modification like the above is simple and blunt. FX also provides Graph rewrite tools that make it easy to add or remove nodes:

# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))
    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)
    # new_node is itself a user of `node`, so the call above also rewired
    # its input to new_node; point it back at `node`
    new_node.args = (node,)
# Regenerate forward() from the modified graph
traced.recompile()
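Here is the snippet in a complete, runnable form. The module M is hypothetical, chosen only so there is a node to insert after: the relu is inserted after torch.neg, every user of torch.neg is redirected to the relu, and relu's own input is then restored, because replace_all_uses_with also rewrites the new node, which is itself a user of node.

```python
import torch
import torch.fx


class M(torch.nn.Module):
    # Hypothetical module used only to demonstrate node insertion
    def forward(self, x):
        return torch.neg(x) + 1


traced = torch.fx.symbolic_trace(M())
node = next(n for n in traced.graph.nodes if n.target == torch.neg)

with traced.graph.inserting_after(node):
    new_node = traced.graph.call_function(torch.relu, args=(node,))
# Route every user of `node` (the `+ 1`) through the relu...
node.replace_all_uses_with(new_node)
# ...which also pointed relu's own input at itself, so restore it
new_node.args = (node,)

traced.graph.lint()
traced.recompile()
print(traced.code)
```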

Modify the network with the help of replace_pattern

Since there are Graph rewrite tools (a concept borrowed from compilers), there is naturally pattern matching as well: replace_pattern() rewrites every match of a pattern across the whole graph. For the pattern, you can use the ones that come with FX or write your own rules:

import torch
from torch.fx import symbolic_trace, replace_pattern

# Sample module
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w1, w2):
        val1 = torch.neg(w1)
        m1 = torch.cat([val1, w2]).sum()
        val2 = torch.neg(w1)
        m2 = torch.cat([val2, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

# Symbolically trace an instance of `M`
traced = symbolic_trace(M())

# Define the pattern. 
def pattern(a1, a2):
    val1 = torch.neg(a1)
    return torch.cat([val1, a2]).sum()

# Define the replacement (same rules as the pattern)
def replacement(w1, w2):
    return torch.stack([w1, w2])

# Replace `pattern` with `replacement` in `traced`
replace_pattern(traced, pattern, replacement)

# After calling `replace_pattern`, the generated code is:
'''
def forward(self, x, w1, w2):
    stack = torch.stack([w1, w2])
    max_1 = torch.max(stack);  stack = None
    add = x + max_1;  x = max_1 = None
    stack_1 = torch.stack([w1, w2]);  w1 = w2 = None
    max_2 = torch.max(stack_1);  stack_1 = None
    add_1 = add + max_2;  add = max_2 = None
    return add_1
'''
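Writing your own rule works the same way. As a sketch, assuming a hypothetical Net that computes SiLU by hand as sigmoid(x) * x, the pattern below matches that subgraph and the replacement swaps in torch.nn.functional.silu. replace_pattern returns the list of matches it rewrote.

```python
import torch
import torch.nn.functional as F
from torch.fx import replace_pattern, symbolic_trace


class Net(torch.nn.Module):
    # Hypothetical module with a hand-written SiLU: sigmoid(x) * x
    def forward(self, x):
        return torch.sigmoid(x) * x + 1


def pattern(x):
    return torch.sigmoid(x) * x


def replacement(x):
    return F.silu(x)


traced = symbolic_trace(Net())
matches = replace_pattern(traced, pattern, replacement)
print(len(matches))
print(traced.code)
```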

Interpreter

Interpreter is an apt name: it walks the nodes of a Graph and executes them one by one in a more structured way, doing extra work along the way. For example, suppose we want to see how the shape changes at each layer while the model runs:

import torch
import torch.fx
from torch.fx.node import Node

from typing import Any, Dict

class ShapeProp:
    """
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Any] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'output':
                # The output node's args[0] holds the graph's return value
                return load_arg(node.args[0])

            # This is the only code specific to shape propagation.
            # you can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

The propagate function above is very simple: it traverses the nodes once and records the results in node.shape and node.dtype. FX also provides an Interpreter class that packages these utilities, and we can use it directly by subclassing it (in the same spirit as ShapeProp above).
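With the built-in torch.fx.Interpreter, the same idea shrinks to overriding run_node. A minimal sketch (ShapeRecorder is a hypothetical name) that lets the base class execute each node and stores output shapes in a dict on the recorder rather than on the Node objects:

```python
import torch
import torch.fx


class ShapeRecorder(torch.fx.Interpreter):
    """Sketch: record each node's output shape while the graph runs."""

    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__(gm)
        self.shapes = {}

    def run_node(self, n: torch.fx.Node):
        result = super().run_node(n)  # execute the node as usual
        if isinstance(result, torch.Tensor):
            self.shapes[n.name] = tuple(result.shape)
        return result


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


gm = torch.fx.symbolic_trace(MyModule())
recorder = ShapeRecorder(gm)
out = recorder.run(torch.rand(3, 4))
for name, shape in recorder.shapes.items():
    print(name, shape)
```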

Transformer

A Transformer transforms a torch.nn.Module. We can wrap such transformations in a function or write them as a class; in effect, a Transformer is a PASS that, in short, makes some modifications to the network. For example:

import torch
import torch.fx

def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`

    # trace nn.Module
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: modify the Graph here
    # (e.g. change node.target, insert or erase nodes)

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

Your transform will take in a torch.nn.Module, acquire a Graph from it, do some modifications, and return a new torch.nn.Module. You should think of the torch.nn.Module that your FX transform returns as identical to a regular torch.nn.Module – you can pass it to another FX transform, you can pass it to TorchScript, or you can run it. Ensuring that the inputs and outputs of your FX transform are a torch.nn.Module will allow for composability.
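FX also ships a torch.fx.Transformer class built on Interpreter: instead of executing nodes, it re-emits them into a new graph, and you override handlers such as call_function to change what gets emitted. A sketch of the add-to-mul rewrite in that style, with the class name AddToMul as a hypothetical example:

```python
import torch
import torch.fx


class AddToMul(torch.fx.Transformer):
    """Sketch of an FX Transformer: rewrites torch.add calls into torch.mul."""

    def call_function(self, target, args, kwargs):
        if target == torch.add:
            target = torch.mul
        # The base class emits the (possibly rewritten) call into the new graph
        return super().call_function(target, args, kwargs)


class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)


gm = torch.fx.symbolic_trace(M())
new_gm = AddToMul(gm).transform()  # returns a fresh GraphModule
print(new_gm.code)
```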

Of course, you can also modify the GraphModule in place; there is no need to return a new one:

import torch
import torch.fx

def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # Modify gm.graph here
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm

Note that gm.recompile() is required: after modifying the graph, we must recompile to regenerate the forward code.

An FX example

After all this groundwork, let's walk through a real FX example. Here we use FX to quantize an object detection model built on the CenterNet framework, with a ResNet-50 backbone. Due to space limitations, this article only covers tracing the model and the fusion step; quantization and TensorRT export will be covered in the following articles.

First build the CenterNet model, and then trace:

from torch.fx import GraphModule, Tracer

model = FXCenterNet()
tracer = Tracer()
graph_module = GraphModule(model, tracer.trace(model))

The trace function is shown below. Roughly, it walks through the operations in the model, converts them into nodes according to its rules and stores them in the graph (including attributes, ops, inputs, outputs and other information), and finally returns the graph as the IR:

@compatibility(is_backward_compatible=True)
def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
    # root is the FXCenterNet instance here
    if isinstance(root, torch.nn.Module):
        self.root = root
        fn = type(root).forward
        self.submodule_paths = {mod: name for name, mod in root.named_modules()}
    else:
        self.root = torch.nn.Module()
        fn = root

    tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None)
    self.graph = Graph(tracer_cls=tracer_cls)
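    # Roughly: walk through the operations in root, convert them into nodes according to the rules
    # and store them in the graph (attributes, ops, inputs and outputs), finally returning the graph IR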
    ... 
    return self.graph

The resulting self.graph is of type torch.fx.graph.Graph:

self.graph
<torch.fx.graph.Graph object at 0x7f57f59efdf0>

Calling self.graph.print_tabular() prints the graph's node information, where you can see the familiar ResNet-50 backbone structure organized as IR:

[Figure: node information of the generated CenterNet graph]

With the graph generated, the next step is to assemble the GraphModule. The GraphModule is built from the graph, and it copies the parameters and submodules referenced by the graph's nodes onto itself:

@compatibility(is_backward_compatible=True)
class GraphModule(torch.nn.Module):
    def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
        for t in cls.__mro__:
            c = t.__qualname__.split('.')[-1]
            if c != 'GraphModuleImpl':
                cls = t
                break

        class GraphModuleImpl(cls):  # type: ignore[misc, valid-type]
            pass
        return super().__new__(GraphModuleImpl)

    @compatibility(is_backward_compatible=True)
    def __init__(self,
                 root: Union[torch.nn.Module, Dict[str, Any]],
                 graph: Graph,
                 class_name: str = 'GraphModule'):
        super().__init__()
        self.__class__.__name__ = class_name
        if isinstance(root, torch.nn.Module):
            if hasattr(root, 'training'):
                self.training = root.training
            # Copy the parameters and submodules referenced by graph nodes onto self, i.e. the GraphModule
            for node in graph.nodes:
                if node.op in ['get_attr', 'call_module']:
                    assert isinstance(node.target, str)
                    _copy_attr(root, self, node.target)
        elif isinstance(root, dict):
            targets_to_copy = []
            for node in graph.nodes:
                if node.op in ['get_attr', 'call_module']:
                    assert isinstance(node.target, str)
                    if node.target not in root:
                        raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target +
                                           ' but that target was not provided in ``root``!')
                    targets_to_copy.append(node.target)
            targets_to_copy.sort(key=lambda t: t.count('.'))
            for target_to_copy in targets_to_copy:
                _assign_attr(root[target_to_copy], self, target_to_copy)
        else:
            raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!')

        self.graph = graph
        self._tracer_cls = None
        if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
            self._tracer_cls = self.graph._tracer_cls
    __jit_unused_properties__ = ['graph']

The final graph_module contains the generated code, which can be printed with print(graph_module.code):

def forward(self, input):
    input_1 = input
    upsampler_deconv_layers_0_bias = getattr(self.upsampler.deconv_layers, "0").bias
    ...
    head_angle_0 = getattr(self.head.angle, "0")(upsampler_deconv_layers_11);  upsampler_deconv_layers_11 = None
    head_angle_1 = getattr(self.head.angle, "1")(head_angle_0);  head_angle_0 = None
    head_angle_2 = getattr(self.head.angle, "2")(head_angle_1);  head_angle_1 = None
    return {'hm': head_hm_2, 'wh': head_wh_2, 'reg': head_reg_2, 'angle': head_angle_2}

At this point we have the traced Module. It is no different from the original model: its forward function is generated from the original model's forward. Since we have only traced it, the same input produces the same output, graph_module(input) == original_model(input); nothing special has happened yet.
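FXCenterNet itself is not shown in this article, so as a stand-in, here is a sketch with a hypothetical TinyHeads model that returns a dict of head outputs, the way the generated forward above does. It traces the model in the same way and checks that both produce the same outputs for the same input.

```python
import torch
from torch.fx import GraphModule, Tracer


class TinyHeads(torch.nn.Module):
    """Hypothetical stand-in for FXCenterNet: a shared trunk plus dict-valued heads."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.ReLU())
        self.hm = torch.nn.Conv2d(8, 2, 1)
        self.wh = torch.nn.Conv2d(8, 2, 1)

    def forward(self, input):
        feat = self.backbone(input)
        return {'hm': self.hm(feat), 'wh': self.wh(feat)}


model = TinyHeads().eval()
graph_module = GraphModule(model, Tracer().trace(model))

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    ref, out = model(x), graph_module(x)
for key in ref:
    print(key, torch.allclose(ref[key], out[key]))
```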

OP fusion

Next comes fusion, which directly calls the fuse function provided by FX; under the hood it uses Fuser:

def _fuse_fx(
    graph_module: GraphModule,
    is_qat: bool,
    fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    r""" Internal helper function to fuse modules in preparation for quantization

    Args:
        graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
    """
    _check_is_graph_module(graph_module)
    fuser = Fuser()
    return fuser.fuse(
        graph_module, is_qat, fuse_custom_config_dict, backend_config_dict)

Let's see what Fuser does. It is actually simple: it traverses the nodes in input_graph = model.graph and fuses them according to the specified fusion rules. Fusion involves modifying the graph structure:

class Fuser:
    def fuse(
        self,
        model: GraphModule,
        is_qat: bool,
        fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
        backend_config_dict: Optional[Dict[str, Any]] = None,
    ) -> GraphModule:
        if fuse_custom_config_dict is None:
            fuse_custom_config_dict = {}

        input_root = model
        input_graph = model.graph
        # First copy the original model's named_modules; they are updated later depending on what gets fused
        self.modules = dict(input_root.named_modules())  
        ... 
        # Find the matching fusion patterns
        fusion_pairs = self._find_matches(
            input_root, input_graph, fusion_pattern_to_fuse_handler_cls)
        self.fused_graph = Graph()
        env: Dict[Any, Any] = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        def get_root_node(node_pattern):
            while not isinstance(node_pattern[-1], Node):
                node_pattern = node_pattern[-1]
            return node_pattern[-1]

        for node in input_graph.nodes:
            maybe_last_node, pattern, matched_node_pattern, obj = \
                fusion_pairs.get(node.name, (None, None, None, None))
            if maybe_last_node is node:
                assert obj is not None
                # TODO: currently we hard code the root node, which only works for
                # a sequence of ops and assume the root node is the last node,
                # we want to make this more general to support more complex patterns
                root_node = get_root_node(matched_node_pattern)  # find the root node of the fused pattern
                env[node.name] = obj.fuse( # pass self in; the handler modifies it
                    self, load_arg, root_node, matched_node_pattern,  # type: ignore[arg-type]
                    fuse_custom_config_dict, fuser_method_mapping, is_qat)
            elif maybe_last_node is None:
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # node matched in patterns and is not root is removed here

        preserved_attributes = set(fuse_custom_config_dict.get("preserved_attributes", []))
        model = FusedGraphModule(input_root, self.fused_graph, preserved_attributes)
        return model

    def _find_matches(
            self, root: GraphModule, graph: Graph,
            patterns: Dict[Pattern, Callable]
    ) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]]:
        modules = dict(root.named_modules())
        match_map : Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]] = {}  # node name -> (root_node, match_value)

        def apply_match(pattern, node, match, matched_node_pattern):
            if isinstance(pattern, tuple):
                s, *args = pattern
                current_node_pattern: List[Node] = []
                apply_match(s, node, match, current_node_pattern)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match, current_node_pattern)
                matched_node_pattern.append(tuple(current_node_pattern))
            else:
                # the first pattern matches will take precedence
                if node.name not in match_map:
                    matched_node_pattern.append(node)
                    root_node, pattern, handler = match
                    match_map[node.name] = (root_node, pattern, matched_node_pattern, handler)
        # This is the matching step (nodes are visited in reverse order)
        for node in reversed(graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    matched_node_pattern: List[Node] = []
                    if is_match(modules, node, pattern):
                        apply_match(pattern, node, (node, pattern, value(self, node)), matched_node_pattern)

        return match_map
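The matching above walks the graph backwards and tests each node against patterns written last-op-first, so (ReLU, (BatchNorm2d, Conv2d)) means "a ReLU whose input is a BatchNorm2d whose input is a Conv2d". The torch-free sketch below mimics that shape under simplifying assumptions: ToyNode, toy_is_match, covered_nodes and find_matches are hypothetical names, op types are plain strings, and the real is_match additionally looks up the module behind each call_module node.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass
class ToyNode:
    """Hypothetical stand-in for an fx Node: an op type plus its input nodes."""
    name: str
    op_type: str
    args: Tuple["ToyNode", ...] = ()


def toy_is_match(node: ToyNode, pattern: Any) -> bool:
    """Match a last-op-first pattern such as ("relu", ("bn", "conv"))."""
    if isinstance(pattern, tuple):
        head, *rest = pattern
        # The head of the pattern must match this node itself ...
        if not toy_is_match(node, head) or len(rest) > len(node.args):
            return False
        # ... and each remaining sub-pattern must match the corresponding input.
        return all(toy_is_match(arg, sub) for sub, arg in zip(rest, node.args))
    return node.op_type == pattern


def covered_nodes(node: ToyNode, pattern: Any) -> List[ToyNode]:
    """List every node absorbed by a successful match of `pattern` at `node`."""
    if isinstance(pattern, tuple):
        head, *rest = pattern
        nodes = covered_nodes(node, head)
        for sub, arg in zip(rest, node.args):
            nodes += covered_nodes(arg, sub)
        return nodes
    return [node]


def find_matches(nodes: List[ToyNode], patterns: List[Any]) -> Dict[str, Tuple[str, Any]]:
    """Walk nodes in reverse, like _find_matches; the first match wins."""
    matches: Dict[str, Tuple[str, Any]] = {}  # node name -> (last node of match, pattern)
    for node in reversed(nodes):
        if node.name in matches:
            continue  # already absorbed into a larger match
        for pattern in patterns:
            if toy_is_match(node, pattern):
                for covered in covered_nodes(node, pattern):
                    matches.setdefault(covered.name, (node.name, pattern))
                break
    return matches


x = ToyNode("x", "input")
conv = ToyNode("conv", "conv", (x,))
bn = ToyNode("bn", "bn", (conv,))
relu = ToyNode("relu", "relu", (bn,))
# Longer patterns come first so they win over their sub-patterns.
patterns = [("relu", ("bn", "conv")), ("bn", "conv"), ("relu", "conv")]

print(find_matches([x, conv, bn, relu], patterns))
```

Because the walk starts from the last node, the whole conv -> bn -> relu chain is claimed by the three-op pattern before the shorter ("bn", "conv") pattern ever gets a chance.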

The fusion rules themselves are defined in pytorch/torch/ao/quantization/fx/fusion_patterns.py:

# /ao/quantization/fx/fusion_patterns.py
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
class DefaultFuseHandler(FuseHandler):
    def __init__(
            self,
            quantizer: QuantizerCls,
            node: Node):
        super().__init__(quantizer, node)

    def fuse(...):
        # The actual fusion is performed here
        ...
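The stacked decorators above just fill a pattern -> handler table. Below is a minimal sketch of such a registry decorator, with hypothetical names (FUSION_PATTERNS, register_pattern, ToyFuseHandler) and strings in place of module types. One detail is plain Python semantics: stacked decorators are applied bottom-up, so the pattern written closest to the class is registered first.

```python
from collections import OrderedDict

# Hypothetical registry: pattern -> handler class, in registration order.
FUSION_PATTERNS = OrderedDict()


def register_pattern(pattern):
    """Return a decorator that records `pattern` for the decorated class."""
    def insert(handler_cls):
        FUSION_PATTERNS[pattern] = handler_cls
        return handler_cls  # the class itself is returned unchanged
    return insert


@register_pattern(("relu", "conv2d"))
@register_pattern(("bn2d", "conv2d"))
@register_pattern(("relu", ("bn2d", "conv2d")))
class ToyFuseHandler:
    def __init__(self, node):
        self.node = node


print(list(FUSION_PATTERNS))
```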

Inside the fuse method of DefaultFuseHandler, the matching fuser_method is looked up and called, returning the fused module fused_module. setattr then swaps it into the network's modules, and node_copy adds the corresponding node to the new graph:

matched_module_types = get_matched_types(matched_modules)
module_parent_name, module_name = _parent_name(root_node.target)
fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
# TODO: change the signature for fuser_method to take matched module patterns
# as input
fused_module = fuser_method(is_qat, *matched_modules)
# TODO: maybe add a pass to cleanup bn modules?
setattr(quantizer.modules[module_parent_name], module_name, fused_module) # put the new fused module into the model built by the fuser
return quantizer.fused_graph.node_copy(root_node, load_arg)                # copy the root node (now calling the fused module) into the fuser's new graph
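The module swap relies on splitting a dotted target such as backbone.layer1.0.conv1 into its parent path and attribute name, then calling setattr on the parent. A torch-free sketch, assuming split_parent_name behaves like the _parent_name helper above (split at the last dot) and using SimpleNamespace objects as hypothetical stand-ins for nested modules:

```python
from types import SimpleNamespace


def split_parent_name(target: str):
    """Split 'a.b.c' into ('a.b', 'c'); a top-level name gets parent ''."""
    parent, _, name = target.rpartition(".")
    return parent, name


def named_objects(root, prefix=""):
    """Collect {dotted_name: object} for a SimpleNamespace tree,
    loosely mirroring dict(module.named_modules())."""
    table = {prefix: root}
    for attr, child in vars(root).items():
        if isinstance(child, SimpleNamespace):
            child_name = f"{prefix}.{attr}" if prefix else attr
            table.update(named_objects(child, child_name))
    return table


model = SimpleNamespace(
    backbone=SimpleNamespace(conv1="Conv2d", bn1="BatchNorm2d", relu="ReLU"))
modules = named_objects(model)

parent_name, attr_name = split_parent_name("backbone.conv1")
# Replace the conv in place with a placeholder for the fused module.
setattr(modules[parent_name], attr_name, "ConvReLU2d(folded conv, ReLU)")
print(model.backbone.conv1)
```

Because the table holds references to the same objects, setattr on modules["backbone"] changes the model itself, which is the reason the Fuser keeps self.modules around.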

For Conv+BN+ReLU, the fusion itself is done by the fuse_conv_bn_relu function in pytorch/torch/ao/quantization/fuser_method_mappings.py:

def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    assert(conv.training == bn.training == relu.training),\
        "Conv and BN both must be in the same mode (train or eval)."
    fused_module : Optional[Type[nn.Sequential]] = None
    # Excerpt: only the eval path (is_qat=False) is shown; the QAT path is omitted
    map_to_fused_module_eval = {
        nn.Conv1d: nni.ConvReLU1d,
        nn.Conv2d: nni.ConvReLU2d,
        nn.Conv3d: nni.ConvReLU3d,
    }
    fused_module = map_to_fused_module_eval.get(type(conv), None)
    if fused_module is not None:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return fused_module(fused_conv, relu)
    else:
        raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))

For a Conv2d, the fused_module above is the torch.nn.intrinsic.modules.fused.ConvReLU2d class. Before it is constructed, fuse_conv_bn_eval folds the BN parameters into the conv:

import copy

import torch


def fuse_conv_bn_eval(conv, bn, transpose=False):
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    if transpose:
        shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
    else:
        shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape)
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
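Per output channel, fuse_conv_bn_weights scales the weights by bn_w / sqrt(bn_rv + bn_eps) and rewrites the bias as (conv_b - bn_rm) * bn_w / sqrt(bn_rv + bn_eps) + bn_b. A quick scalar check in plain Python shows why this is exact: a single-channel 1x1 "conv" is just y = w * x + b, and all numbers below are made up for illustration.

```python
import math


def bn_eval(y, mean, var, eps, gamma, beta):
    """BatchNorm in eval mode for one channel, using the running statistics."""
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_bn(w, b, mean, var, eps, gamma, beta):
    """Scalar version of the folding done in fuse_conv_bn_weights."""
    scale = gamma / math.sqrt(var + eps)  # bn_w * bn_var_rsqrt
    return w * scale, (b - mean) * scale + beta


w, b = 0.8, 0.1                                          # conv weight and bias
mean, var, eps, gamma, beta = 0.5, 2.0, 1e-5, 1.5, -0.3  # BN parameters

w_fused, b_fused = fold_bn(w, b, mean, var, eps, gamma, beta)
for x in (-2.0, 0.0, 1.0, 3.5):
    reference = bn_eval(w * x + b, mean, var, eps, gamma, beta)  # conv then BN
    folded = w_fused * x + b_fused                               # folded conv only
    assert math.isclose(reference, folded, rel_tol=1e-12, abs_tol=1e-12)
print(w_fused, b_fused)
```

The identity only holds with fixed running statistics, which is why fuse_conv_bn_eval asserts that both modules are in eval mode.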

After folding we get a new conv, which is then wrapped together with the ReLU in the ConvReLU2d class:

class ConvReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, relu):
        assert type(conv) == Conv2d and type(relu) == ReLU, \
            'Incorrect types for input modules{}{}'.format(
                type(conv), type(relu))
        super().__init__(conv, relu)

The overall process is conv + bn -> conv, and then conv + relu -> ConvReLU2d.
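To see both steps end to end, here is a small sketch using the public fuse_fx entry point, assuming a recent PyTorch release that provides torch.ao.quantization.quantize_fx.fuse_fx; the ConvBNReLU toy model is hypothetical. The model is put in eval mode because this is the eval-time folding path:

```python
import torch
import torch.nn as nn
from torch.ao.quantization.quantize_fx import fuse_fx


class ConvBNReLU(nn.Module):
    """Hypothetical toy model with the classic conv -> bn -> relu chain."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


model = ConvBNReLU().eval()   # eval mode: BN uses its running statistics
fused = fuse_fx(model)        # traces the model, then runs the fusion pass

print(type(fused.conv).__name__)  # the conv slot now holds the fused module
print(fused.code)                 # the generated forward no longer calls bn or relu
```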

The generated code after fusion is much cleaner: the bn and relu calls are gone because they have been fused into the preceding conv (other fusions are applied as well):

def forward(self, input):
    input_1 = input
    backbone_conv1 = self.backbone.conv1(input_1)
    backbone_maxpool = self.backbone.maxpool(backbone_relu)
    backbone_layer1_0_conv1 = getattr(self.backbone.layer1, "0").conv1(backbone_maxpool)
    backbone_layer1_0_conv2 = getattr(self.backbone.layer1, "0").conv2(backbone_layer1_0_relu)
    backbone_layer1_0_conv3 = getattr(self.backbone.layer1, "0").conv3(backbone_layer1_0_relu_1)
    ...
    head_reg_0 = getattr(self.head.reg, "0")(upsampler_deconv_layers_11)
    head_reg_2 = getattr(self.head.reg, "2")(head_reg_1)
    head_angle_0 = getattr(self.head.angle, "0")(upsampler_deconv_layers_11)
    head_angle_2 = getattr(self.head.angle, "2")(head_angle_1)
    return {'hm': head_hm_2, 'wh': head_wh_2, 'reg': head_reg_2, 'angle': head_angle_2}

At this point we have the model after tracing and fusion, and the fused ConvReLU2d modules can be seen in it.

Figure: the module after tracing and after fusion

This GraphModule is used exactly like a regular torch.nn.Module; you can simply feed it an image to verify it.
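Since the fused result is an ordinary nn.Module, verifying it really is just a forward pass. A minimal sketch, again assuming torch.ao.quantization.quantize_fx.fuse_fx is available; the small nn.Sequential network stands in for a real model, and its BN statistics are randomized so that the folding actually changes the weights:

```python
import copy

import torch
import torch.nn as nn
from torch.ao.quantization.quantize_fx import fuse_fx

torch.manual_seed(0)

# Hypothetical toy network standing in for a real backbone.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 4, 1),
).eval()

# Give BN non-trivial statistics so that folding is not a no-op.
with torch.no_grad():
    bn = model[1]
    bn.running_mean.uniform_(-0.5, 0.5)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.2, 0.2)

fused = fuse_fx(copy.deepcopy(model))  # fuse a copy, keep the original intact

image = torch.rand(1, 3, 32, 32)       # a fake "image" batch
with torch.no_grad():
    expected = model(image)
    actual = fused(image)

print(torch.allclose(expected, actual, atol=1e-5))
```

A small tolerance is used because folding reorders floating-point operations, so the outputs match closely but not bit for bit.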

In the next article we will quantize this GraphModule.

How to debug

Now that we have the final GraphModule, how do we debug it, that is, step through it line by line? There are three ways to debug a model generated by FX:

Debug directly through pdb

We can step into the FX-generated code with pdb, or set breakpoints in it:

Figure: the FX-generated code can be stepped into with a debugger

Print the generated code and combine it with the original Module

The nodes in the graph hold the computation logic, while the GraphModule holds the model weights and other state taken from the original Module. So we can paste the generated code directly into the forward of a subclass of the original Module, forming a new Module that we can call.

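# Assume that `traced` is a GraphModule that has undergone some
# number of transforms. To keep the snippet self-contained, `M` below is a
# hypothetical stand-in for the original Module, and `traced` is simply its trace.
import torch
import torch.fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x = torch.nn.Parameter(torch.rand(3))

    def forward(self, y):
        return self.x + y

traced = torch.fx.symbolic_trace(M())

# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs something like:
"""
def forward(self, y):
    x = self.x
    add_1 = x + y;  x = y = None
    return add_1
"""

# Inherit from the original Module
class SubclassM(M):
    def __init__(self):
        super().__init__()

    # Paste the generated code here
    def forward(self, y):
        x = self.x
        add_1 = x + y;  x = y = None
        return add_1

# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
# Share the weights so that the two outputs are directly comparable
post_trace.load_state_dict(pre_trace.state_dict())
y = torch.rand(3)
assert torch.allclose(pre_trace(y), post_trace(y))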

Pretty straightforward, isn't it?

Use the to_folder function

As mentioned in the earlier example, GraphModule.to_folder() is a handy function that exports the FX-generated module directly to a folder containing the parameters the model needs (in .pt format) and the model definition.

Figure: FX code exported to a folder

And the code of module.py is also generated for you:

# In the exported module.py
import torch
from torch.nn import *
class FusedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The weights and parameters are loaded here
        self.backbone = torch.load(r'fx_debug/backbone.pt')
        self.load_state_dict(torch.load(r'fx_debug/state_dict.pt'))
        ...

    def forward(self, input):
        # This is the generated code, which is also written into forward for you
        input_1 = input
        backbone_conv1 = self.backbone.conv1(input_1)
        backbone_maxpool = self.backbone.maxpool(backbone_relu)
        backbone_layer1_0_conv1 = getattr(self.backbone.layer1, "0").conv1(backbone_maxpool)
        ...
        head_angle_0 = getattr(self.head.angle, "0")(upsampler_deconv_layers_11)
        head_angle_2 = getattr(self.head.angle, "2")(head_angle_1)
        return {'hm': head_hm_2, 'wh': head_wh_2, 'reg': head_reg_2, 'angle': head_angle_2}

Pretty powerful, isn't it?

We can also modify the generated code to run other experiments (although in my experience this export has some bugs; perhaps I am just using it the wrong way).
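A minimal sketch of the export, assuming a throwaway temporary directory and a hypothetical TinyNet model; module_name="FusedModel" mirrors the class name in the listing above. Submodules that to_folder cannot write out as a plain constructor call, such as the traced backbone container here, are pickled into their own .pt file, which is where lines like self.backbone = torch.load(r'fx_debug/backbone.pt') come from:

```python
import os
import tempfile

import torch
import torch.fx
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model used only to demonstrate to_folder."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        feat = self.backbone(x).mean(dim=(2, 3))  # global average pooling
        return self.head(feat)


traced = torch.fx.symbolic_trace(TinyNet())

out_dir = os.path.join(tempfile.mkdtemp(), "fx_debug")
traced.to_folder(out_dir, module_name="FusedModel")

print(sorted(os.listdir(out_dir)))  # module.py, state_dict.pt, pickled submodules, ...
with open(os.path.join(out_dir, "module.py")) as f:
    print(f.read())
```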

Some limitations

torch.fx also has some limitations (after all, it can't be perfect).

They mainly stem from the limitations of symbolic execution, as this error message shows:
Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors.

The main limitation of symbolic tracing is it does not currently support dynamic control flow. That is, loops or if statements where the condition may depend on the input values of the program.
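To make the control-flow limitation concrete, here is a small sketch with two hypothetical modules, Gate and FlagGate. Branching on a tensor value makes symbolic_trace raise torch.fx.proxy.TraceError, while a branch on a plain Python argument can be resolved at trace time by passing it through concrete_args, which specializes the graph to that value:

```python
import torch
import torch.fx
import torch.nn as nn


class Gate(nn.Module):
    """Hypothetical module whose branch depends on the input values."""

    def forward(self, x):
        if x.sum() > 0:          # data-dependent condition: not traceable
            return torch.relu(x)
        return torch.sigmoid(x)


failed = False
try:
    torch.fx.symbolic_trace(Gate())
except torch.fx.proxy.TraceError as err:
    failed = True
    print("tracing failed:", err)


class FlagGate(nn.Module):
    """Hypothetical variant where the branch depends on a plain Python flag."""

    def forward(self, x, use_relu: bool):
        if use_relu:
            return torch.relu(x)
        return torch.sigmoid(x)


# concrete_args fixes `use_relu`, so the `if` is resolved while tracing and
# only the relu branch ends up in the graph.
gm = torch.fx.symbolic_trace(FlagGate(), concrete_args={"use_relu": True})
x = torch.randn(4)
print(torch.equal(gm(x, True), torch.relu(x)))
```

The specialized gm only contains the relu branch, so it should be called with the same flag value it was traced for.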

For a more detailed list of limitations, see the official torch.fx documentation.

That's it for now. FX's capabilities really show in the quantization workflow: in the quantization practice of the next article, I will look at FX through the quantization process, and also summarize the PTQ workflow and the points that need attention. I'm Lao Pan, see you in the next article~



Lao Pan's Blog