PyTorch 2.0 的正式发布,相信很多小伙伴已经使用过 PyTorch 2.0 的 compile 功能,也尝试写过自己的编译后端,对模型做一些定制化的优化。得益于 Dynamo 强大的字节码解析能力,我们能够在不关心代码解析过程的情况下,随心所欲地写编译优化后端。然而,由于字节码解析部分实现的复杂性,目前并没有比较完整的资料介绍其工作原理。今天我们就来由浅入深地好好聊一聊,PyTorch 2.0 中的 Dynamo,是如何完成 Graph trace 的。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
compile=True
目前 PyTorch Dynamo 的 dynamic_shape 功能还不完善,因此部分动态尺寸输入的算法,例如检测模型的编译可能会有一些问题。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
上一篇文章我们提到,Dynamo 是如何通过 PEP 523 改变 Python 默认的函数(帧评估)执行流程,将它从下图的 Default Python Behavior 转变为 TorchDynamo Behavior:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
在了解 Dynamo 设计的基石后,我们就可以一步一步地理解上图右侧栏各个流程框图的含义:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
- 在第一次执行被 torch.compile 编译的函数时,会走上图右侧的分支,从 PythonFrameObject(帧的定义可以见上篇文章)中解析出 PyCodeObject
- 基于 PyCodeObject 中的字节码解析出 fx graph,同时生成守卫(Guard),并在解析过程中使用指定后端对代码进行编译
- 将编译后的代码替换原有的代码,获得 Transformed PyCodeObject,函数实际运行时会调用编译后的代码
- 第二次执行时,守卫会判断是否需要重新编译,如果不需要则会从缓存中直接读取上次编译的代码,否则会触发重新编译
好的好的,一口气抛出这么多概念,相信不少小伙伴会有一种说了等于没说的感觉。没关系,今天我们由浅入深,详细介绍每一个步骤的内容。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
Dynamo 的帧执行流程
上篇文章我们提到,Dynamo 基于 PEP 523,设计了一个自定义的帧执行函数,而今天我们就来看看,这个函数具体做了哪些事(只保留了代码的主体逻辑,且不考虑 subgraph 等更复杂的情况):文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
- 调用 torch.compile 编译函数时,编译返回的函数实际为 _TorchDynamoContext 里定义的 _fn 函数
- _fn 会把 Python 默认的帧执行函数替换为 Dynamo 自定义的帧执行函数 _custom_eval_frame
- 执行目标函数时,会进入 _custom_eval_frame,并调用 callback 函数(关于 callback 函数的功能可以见上一篇文章)对帧进行解析,并返回编译结果 result
- callback: 即此处定义的 _compile 函数(去掉了多层 wrapper),用于解析字节码,进行 Graph trace,最后返回编译结果
- result:即 GuardedCode 实例,其中 code 属性为编译优化后的代码,check_fn 为检查代码,用于检查当前是否需要重新编译函数。
调用 callback 函数时还需要传入 cache_size 参数,表示当前是第几次编译该函数,第一次调用时其值为 0。当 cache_size 大于阈值时,不再编译该函数,按照原有逻辑执行。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
4. 将 result 缓存到链表 ,每执行一次编译链表都会新增一个元素。往后每次执行函数时都会根据当前帧的状态和链表中的往期编译结果来判断是否需要进行重新编译文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
5. 执行编译后的代码,返回结果文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
6. 第 2 次执行时,加载上次生成的 extra,进行查表操作(lookup)。遍历 extra 中的每个元素,执行 GuardedCode.check_fn文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
- 如果 extra 中某个元素的 check_fn 返回 True,则把该元素放到链表的最前端,方便下一次检查时优先遍历。同时终止遍历,运行之前编译好的代码。
- 如果所有的 check_fn 均返回 Fasle,则重复执行 2~4 步骤。需要注意的是,每执行一轮 2-4 步骤。
如果你觉得上述流程说得通,继续按照文章顺序阅读即可,如果你觉得上述流程存在逻辑缺陷,可以直接移步编译子图一节
如果你对 C 代码不是很熟,也可以跳过这部分的理解,只需要记住:
字节码解析最终会返回 GuardedCode 实例,该实例含有两属性,其中 check_fn 用来判断代码是否需要重新编译的,code 部分则存放编译好的代码。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
字节码解析与图生成
上一章提到的编译好的代码(GuardedCode.code)其实已经是 Dynamo 编译器前端解析+后端编译的最终产物了,而现在我们要介绍的字节码解析,正是前端解析的具体流程。本章我们会深入 callback 函数,理解如何从帧中解析字节码,获取模型图结构,最终生成 GuardedCode。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
在 CPython 中,Python 代码是在 CPython 的虚拟机中执行的,而执行的过程,正是上篇文章我们提到的 _PyEval_EvalFrameDefault 函数,它会将帧中函数的代码,解析成一系列的字节码,并在一大串的 switch-case 中逐条执行字节码,CPython 支持的所有字节码见 opcode.h。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
我们可以在 Python 代码中,通过使用 dis.dis 函数,来查看任意一个函数在 CPython 虚拟机中执行时的字节码:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
输出:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
这些字节码到底做了什么事呢,CPython 用非常复杂的 C 代码来解析每个字节码,而 Dynamo 则在 Python 层面对字节码进行解析,并 trace 模型的图结构的。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
如上图所示,字节码的解析可以大体分成以下 7 个步骤:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
- 解析输入参数 x,y,将其存储到局部变量(local_var)
- LOAD_FAST x:将变量 x push 入栈
- LOAD_FAST y:将变量 y push入栈
- BINARY_ADD:从 stack 中 pop 出 x,y ,计算出结果后将其 push 入栈
- STORE_FAST res:从 stack 中 pop 出栈顶元素(即上一步的结果),并将其存储到局部变量 res
- LOAD_FAST res:将变量 res push入栈
- RETURN_VALUE:pop 出栈中的 res 并返回
Dynamo 实现了 InstructionTranslator 来解析字节码,为了方便理解其核心内容,这边实现了简易版的 SimpleInstructionTranslator:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
在看样例代码之前,我们先介绍几个概念:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
torch.fx.Graphtorch.fx.GraphGraph.create_nodeGraph.python_code
runLOAD_FASTLOAD_FASTBINARY_ADDRETURN_VALUE
显然,SimpleInstructionTranslator 依旧很好地完成了 add_three 字节码解析和图 trace 的工作。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
由于实际解析的代码会更加的复杂,官方的 InstructionTranslator 实现了更多的字节码解析函数,处理各种各样的 corner case。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
事实上,PyTorch 并没有往 stack 里 push GraphNode 而选择往里面 push 一个新的抽象 VariableTracker,并在此基础上引入了 Guard 的概念。后续我们将会从原理和源码层面分析,为什么需要 VariableTracker 和 Guard,以及它们又是如何实现的。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
为什么需要 VariableTracker
字节码信息的不完整性文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
self.layer1(x)
LOAD_FASTLOAD_ATTRLOAD_FASTCALL_FUNCTION
self.layer1
Graph 的动态特性
Dynamo trace 出来图的动态特性,是由守卫(guard) 所赋予的,而守卫的载体就是 VariableTracker,这部分我们后续会进行详细介绍。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
字节码信息对于模型图结构是“冗余”的
Dynamo 基于字节码的 graph trace,其目的不是 trace 出一个完整 Python 的图表示,否则这和基于字节码重构抽象语法树也没有太大区别。这里给出一个简单的例子:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
LOAD_ATTR
VariableTracker
既然 Node 不适合直接作为字节码解析过程中,push pop 操作的载体,Dynamo 就设计了一个新的数据类,VariableTracker。其功能顾名思义,就是用来追踪字节码解析过程中产生的变量。VariableTracker 能够接受函数运行时的信息,并控制 Graph 的生成。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
设想一下,如果我们把样例代码中的所有 Node,都替换成 VariableTracker,直接面临的问题就有两个:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
- Node 是有 op type 的,不同类型 op type 的 Node 相互组合才可以生成 PythonCode 的 Graph,那么 VariableTracker 应该如何体现节点类型的不同呢?
- VariableTracker 又应该如何和 Node 关联,以生成最终的 Graph 呢?
不同类型的 VariableTrackers
正如问题里提到的,解析不同类型的字节码需要生成不同类型的 VariableTracker,例如在执行 CALL_FUNCTION 之前,我们需往先 stack 里 push 一个 UserFunctionVariable,再往 stack 里 push 一个 TensorVariable (假设函数的输入是 Tensor 类型)。最后在 CALL_FUNCTION 里将二者 pop 出来,调用 UserFunctionVariable 的方法模拟函数执行。
Dynamo 在 variables 文件夹中定义了所有的 VariableTracker 类型,感兴趣的话可以看看每个 VariableTracker 的功能。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
LOAD_ATTR
foo 中的 BINARAY_ADD 字节码作为内置函数(BuiltinVariable),解析时会生成新的节点, 而 foo1 作为是一个空函数 ,则不会生成新的节点。因此是否生成新的节点,是和 VariableTracker 实例本身相关,而如果要在 InstructionTranslator 这一层处理这些逻辑,这部分代码的可读性将是一个灾难。因此
Dynamo 新增了一层抽象 VariableBuilder 来负责 VariableTracker 的构建,并控制过程中是否生成新的 Node 等操作(包括生成 Guard,下一节会介绍)。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
这边贴一段 VariableBuilder 生成 TensorVariable 代码片段,大家自行感受一下(冰山一角):文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
这边给大家简单翻译一下(可以简单把 source 理解成数据源,用于帮助 Guard 生成检查代码):文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
- 如果这个 Tensor 是来自于一个 nn.Module 的(类似 register_buffer),那么他就会往 graph 里注册一个节点,并返回一个 TensorVariable
- 如果这个 Tensor 的数据来源是一个常量(torch.Tensor(1)),操作同上,只不过名字会有所不同
- ...
register_attr_or_module
VariableTracker
守卫(Guard)
前面介绍的种种只是在描述 Dynamo 是如何通过字节码生 trace graph,而为了让 trace 出来的 graph 保持动态特性,就离不开核心组件:Guard。在构建 VariableTracker 时,可能会绑定一个或多个 guard,用于生成监视变量的检查代码,也就是我们最初提到的 check_fn。需要注意的是,Graph trace 阶段可能会生成非常多的 guard,但是最后只有部分 guard 会被用于生成 check_fn,这其实也很好理解,因为只有部分变量都会造成模型的动态结构。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
Guard 功能的实现主要依赖两个模块: Guard 和 GuardBuilder文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
- Guard:Graph trace 过程中生成,记录最后生成检查代码阶段所需的额外信息,并最后存储生成后的代码。这边最主要介绍初始化阶段的两个核心参数:
- source:记录守护的变量名 name,例如 "self.layer1.state",变量名用于生成检查代码
- create_fn:用于生成检查代码的函数,其值通常为 GuardBuilder 的 method,在 GuardBuilder 部分展开介绍
2. GuardBuilder:Graph trace 完成后,基于 trace 过程中生成的 Guards ,生成最终的检查代码。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
我们通过一些代码示例来理解 Guard 和 GuardBuilder 是如何起作用的。首先修改 Dynamo 的配置,以输出 Guard 相关的日志:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
Guard 相关的输出日志:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
self.linear1self.linear2selfx
CONSTANT_MATCHTENSOR_MATCH
ID_MATCH
check_obj_id
检查 self 参数时,check_obj_id 会根据其 id 是否匹配,来决定是否需要进行重复编译文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
TENSOR_MATCH
对于 Tensor 类型数据的检查,出于效率方面的考虑,检查代码同样在 C++ 代码里实现:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
简单来说会检查以下几个内容:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
- 数据类型是否发生变化,例如原来数据类型为 float32,第二次输入时类型变成 float16,返回 False
- 数据所在设备是否发生变化,例如原来是在 GPU 0 上的,第二次输入变成在 GPU 1 上了,返回 False
- 数据的梯度属性是否发生变化,例如原来是需要计算梯度的,第二次却不再要求计算梯度,返回 False
- (Dynamic shape=Flase 时)数据的形状以及内存排布是否发生变化
___check_tensors
对于上例来说,其最终生成的检查代码 check_fn 的过程等价于:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
回到第一节编译与执行流程的第四步,其中提到的 check_fn 等价于上例中返回的 check_fn,如果 self 的 id 发生变化,亦或是 x 无法通过 TensorGuards.check,均会触发重新编译。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
编译子图
Guard 一节提到,check_fn 只会检查模型的输入,而不是实际运行一遍代码后,再判断是否应该重新编译一遍函数。这也是合情合理的,因为执行一遍代码才能完成代码检查,这样的开销是不可接受的。然而这样也会引入其他问题,真的能够仅仅根据输入去判断是否需要重新编译模型么?文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
对于比较简单的函数:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
Dynamo 并不需要为 a 和 b 生成 Guard 和 check_fn,因为只要 x 的形状不变,a.shape 就不会发生变化(假设 len 是 builtin func,且 linear1 保持不变),因此只需要对 x 构建 guard 并生成 check_fn 就足够了。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
那如果换一种写法:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
这里的 x.sum() 会返回一个 Tensor,此时无论如何都没有办法仅凭输入去判断会走哪个分支。对于这种情况,Dynamo 的做法是:编译子图。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
细心的同学可能会发现,编译与执行流程一节提到的执行顺序,是有漏洞的。因为在执行完第一步,将默认的执行函数替换成 _custom_eval_frame 后,这意味着 callback 执行过程中产生的函数栈,也会触发 _custom_eval_frame,这是不符合期望的。我们只希望执行被编译的函数时,能够触发 _custom_eval_frame,因此完整的执行流程如下:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
- 在 eval_frame.py 中,将帧执行函数替换 Dynamo 自定义的执行函数 _custom_eval_frame
- 进入 _custom_eval_frame 后,将帧执行函数替换回默认的执行函数
- 第一次执行待编译的函数时调用 callback 函数,对帧进行解析,返回一个 result。callback 可以理解成(去掉了多层 wrapper)这里定义的 _compile 函数, 而 result 则是 Python 层面生成的 GuardedCode 实例。调用 callback 函数时还会传入 cache_size 参数,第一次编译时其值为 0。
- 将 result 缓存到链表,每执行一次编译链表都会新增一个元素。往后每次执行函数时都会根据当前帧的状态和 extra 中的往期编译结果来判断是否需要进行重新编译
- 将默认的帧执行函数重新替换成 _custom_eval_frame
- 用默认的帧执行函数执行编译后的字节码,并返回结果
- 第 2 次执行时,加载上次生成的 extra,进行查表操作(lookup)。遍历 extra 中的每个元素,执行 GuardedCode.check_fn
torch._dynamo.config.cache_size_limit
8. 编译完整个函数后,将帧执行函数替换回默认的执行函数文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
第六步,划重点!在第五步我们将帧执行函数替换成 _custom_eval_frame 后,如果我们直接执行编译后的字节码,这就意味着会触发无限递归,因此需要调用默认的帧执行函数执行字节码。那既然如此,为什么还需要在第五步把帧执行函数替换成 _custom_eval_rame 呢?答案是,编译子图。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
编译后的字节码中还会存在 CALL_FUNCTION 字节码,在执行时会进入 _custom_eval_frame,进而触发对子图的编译。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
Model.forward
等价 Python 代码如下:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
x.sum() >= 1
__compiled_fn_0__resume_at_32_1 __resume_at_40_2__resume_at_32_1 __resume_at_40_2 _custom_eval_frame__resume_at_32_1__resume_at_40_2
细心的你可能会发现,这样 __compiled_fn_0 不是也会触发二次编译么。Dynamo 自然也考虑到了这一点,编译后的函数会经过 disable 处理,保证后续的调用不会再走 _custom_eval_frame 的逻辑。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
第一次,解析 forward 时生成的 guard:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
第二次,执行 forward 编译后的函数 compiled_fn,生成的字节码:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
动手试一试,相信你会理解的更加深刻,对于更加复杂的情况,子图中还会递归地执行 2-7 步,生成更细粒度的子图。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
InliningInstructionTranslator
如果编译的函数涉及比较复杂的函数调用,例如:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
InstructionTranslator 会在解析 CALL_FUNCTIONS 时,构建一个 InliningInstructionTranslator,获取函数的字节码,在解析字节码的过程中继续完成 graph trace。与编译子图不同的是,InliningInstructionTranslator 会进入函数,“连续”的解析字节码。函数中的字节码可以和之前解析的字节码一起进行编译优化,而编译子图意则是函数内外分开编译。此外, InliningInstructionTranslator 解析的函数也可以触发编译子图的逻辑。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
至此我们梳理完了 Dynamo trace graph 的主体逻辑,Dynamo 从字节码入手,首先实现了 Python 版的虚拟机,用于解析函数的字节码,以实现 Graph trace 的功能;在此基础上,为了能够根据输入信息实现动态的 Graph trace,Dynamo 引入了 VariableTracker 以及 Guard 的概念,能够根据模型输入信息去判断是否需要触发重新编译;最后,Dynamo 通过动态地调整帧评估函数,递归地去编译在上一次编译中,重新划分的子图,实现更加灵活地 Graph trace。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
MMEngine 目前已接入 PyTorch 2.0 的编译功能,各个算法库的训练速度都有了明显的提升:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/suanfa/37485.html
算法库 | 模型 | 训练速度 |
MMPreTrain | ResNet | 10.00% ↑ |
ViT | 4.60% ↑ | |
MMDetection | RTMDet | 3.16% ↑ |
MMSegmentation | PSPNet | 34.0% ↑ |
SegFormer | 7.12% ↑ | |
MMEditing | BasicVSR | 19.03% ↑ |
NAFNet | 15.18% ↑ | |
MMPose | RTMPose | 13.06% ↑ |
HRNet | 37.07% ↑ |