采用Pytorch或者其他的深度学习框架导出ONNX模型后,通过Netron可视化该模型,能够看到模型的输入和输出尺寸。但是在导出一些自己手动搭建的神经网络结构或者导出较为复杂的网络结构时,往往需要知道每一层输入和输出的尺寸,并将每一层的结果可视化,这就需要采用ONNX官方提供的infershape()接口。接下来就用一个例子来深度剖析一下怎么进行形状的推理。

ONNX形状推理方法

from onnx.shape_inference import infer_shapes
from onnx import load_model, save_model
import torch
import torch.nn as nn
import numpy as np
from torch.nn.modules.upsampling import UpsamplingNearest2d

class TestNet(nn.Module):
    def __init__(self):
        super(TestNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(32)  
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.upsampling1 = nn.UpsamplingNearest2d(scale_factor=2)
        self.upsampling2 = nn.UpsamplingNearest2d(scale_factor=4)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.upsampling1(x)
        x = self.upsampling2(x)
        return x

x=torch.randn((1,3,12,12))

model=TestNet()
model.eval() 

output_onnx_name = 'test_net.onnx'

torch.onnx.export(model, 
    x,
    output_onnx_name, 
    input_names=["input"], 
    output_names=["output"],
    opset_version=11,
    # dynamic_axes={'input':{0:'batch', 2:'h', 3:'w'}, 'output':{0:'batch', 2:'h2', 3:'w2'}} 
)

onnx_model = load_model(output_onnx_name)
onnx_model = infer_shapes(onnx_model)

save_model(onnx_model, "infered_test_net.onnx")

形状推理最核心的方法就是onnx模块中的infer_shapes,先采用Pytorch框架搭建一个卷积网络,并在网络结构最后增加两个上采样的OP,使用torch.onnx.export()将该模型导出,该例导出一个定长输入模型。直接调用onnx中的infer_shapes方法,将重新加载的模型进行形状推理,最后保存成一个新的模型,用Netron打开两个模型进行比较,如图所示。

定长尺寸形状推理前和推理后的可视化

从图中不难看出,没有进行形状推理时,仅有输入尺寸和输出尺寸,在调用了infer_shapes()方法后,模型中记录了每层推理的形状尺寸,通过每层的输入输出尺寸,可以了解每层OP进行的操作变化,有针对性的进行优化,同时对于一些自定义的OP,可以通过形状推理判断出实现是否正确。

如果将定长尺寸改成可变长尺寸,导出的图如下所示。可以看到,相对于左图,右图将所有变长的部分在进行每一层的推理时都用一个新的占位符表示。这里值得注意的是,如果是一个变长的尺寸和一个定长的尺寸进行维度运算,得到的是一个新的尺寸。比如一个形状为[1, n]的数组和一个形状为[1, 3]的数组进行相加,得到的不是一个维度为[1, n+3]的数组,而是维度为[1, m]的数组。

变长尺寸形状推理前和推理后的可视化

源码走读

在github上搜索onnx项目可下载源码。

// implementation.cc::516
void InferShapes(...) {
    std::unordered_map<std::string, int> opset_imports;
    for (const auto& opset_import : m.opset_import()) {
        opset_imports[opset_import.domain()] = static_cast<int>(opset_import.version());
    }
    auto* g = m.mutable_graph();
    SymbolTableImpl symbolTable;
    traverseGraphsToAddExistingSymbols(*g, symbolTable);
    InferShapesImpl(...);
}

调用python的infer_shapes()接口后,最终调用的是http://implementation.cc中的InferShapes(...)函数,InferShapes(...)中申明了一个SymbolTableImpl类的对象symbolTable,这个对象存放着模型中动态维度的名称,在该例中是“batch”,“w”,“h”,“w2”,“h2”。在获取了模型中的symbol后才会调用InferShapes(...)的实现InfershapesImpl(...)。

// implementation.cc:241
static void InferShapesImpl(ISchemaRegistry schema_registry, ...) {
    std::unordered_map<std::string, TypeProto*> valueTypesByName{outer_scope_value_types_by_name};
    std::unordered_map<std::string, TypeProto*> undefinedValueTypesByName{outer_scope_value_types_by_name};
    std::unordered_map<std::string, TensorShapeProto> generatedShapeDataByName;

    GraphInferenceContext graphInferenceContext{valueTypesByName, opset_imports, symbolTable, schema_registry, ir_version};
    for (auto& vi : *g->mutable_value_info()) {
        if (vi.has_type()) {
        valueTypesByName[vi.name()] = vi.mutable_type();
        }
    }
    for (auto& vi : *g->mutable_input()) {
        if (vi.has_type()) {
        valueTypesByName[vi.name()] = vi.mutable_type();
        }
    }
    for (auto& vi : *g->mutable_output()) {
        if (vi.has_type()) {
        valueTypesByName[vi.name()] = vi.mutable_type();
        } else {
        // Some output type might be undefined in subgraph. e.g., Loop Op
        // Saving names of outputs with undefined types to allow assigning inferred types to them
        undefinedValueTypesByName[vi.name()] = vi.mutable_type();
        }
    }
...
}

在InferShapesImpl(...)中,定义了几个map去存放模型不同的参数。整个模型的输入输出类型以及模型参数的类型存放在valueTypeByName中,模型参数存放在inputDataByName或者inputSparseDataByName中。这里的目的是将用Proto表示的模型中推理部分需要用到的参数用另一种形式表示出来。

// implementation.cc:365
for (auto& n : *g->mutable_node()) {
    // Resolve domain for node
    auto dit = opset_imports.find(n.domain());
    if (dit == opset_imports.end()) {
      continue;
    }
    auto domain_version = dit->second;
    const auto schema = schema_registry->GetSchema(n.op_type(), domain_version, n.domain());
    InferenceContextImpl ctx(n, valueTypesByName, inputDataByName,inputSparseDataByName,
        &generatedShapeDataByName, &graphInferenceContext);
    if (!schema) {
        std::cerr << "Warning: Unsupported operator " << n.op_type() << ". No schema registered for this operator."
                    << std::endl;
        has_unsupported_op = true;
        continue;
    } else if (schema->has_type_and_shape_inference_function()) {
        ONNX_TRY {
            schema->GetTypeAndShapeInferenceFunction()(ctx);
        }
    ...
    }
}

观察InferShapesImpl(...)的接口,传入了一个单例类ISchemaRegistry对象schema_registry,顾名思义这个对象和注册器有关联。在使用schema_registry对象时,直接调用了该对象的GetSchema(...)方法,并返回了一个OpSchema类型的对象。

// schema.h:150
class OpSchema final {
    ...
    class FormalParameter final {
    public:
        FormalParameter() = default;

        explicit FormalParameter(
            std::string name,
            DataTypeSet allowed_type_set,
            std::string type_str,
            const std::string& description,
            FormalParameterOption param_option = Single,
            bool is_homogeneous = true,
            int min_arity = 1,
            DifferentiationCategory differentiation_category = Unknown)
            : name_(std::move(name)),
            type_set_(std::move(allowed_type_set)),
            type_str_(std::move(type_str)),
#ifndef __ONNX_NO_DOC_STRINGS
            description_(description),
#endif
            param_option_(param_option),
            is_homogeneous_(is_homogeneous),
            min_arity_(min_arity),
            differentiation_category_(differentiation_category) {
#ifdef __ONNX_NO_DOC_STRINGS
        ONNX_UNUSED_PARAMETER(description);
#endif
        }
    ...
    }
    ...
    struct Attribute final {
    Attribute(
        std::string name_,
        std::string description_,
        AttributeProto::AttributeType type_,
        bool required_)
        : name(std::move(name_)),
        description(std::move(description_)),
        type(type_),
        required(required_),
        default_value() {}

    Attribute(
        std::string name_,
        std::string description_,
        AttributeProto default_value_)
        : name(std::move(name_)),
        description(std::move(description_)),
        type(default_value_.type()),
        required(false),
        default_value(std::move(default_value_)) {}
};

打开schema.h,看看OpSchema类到底是如何实现的。OpSchema类是用来记录指定名称OP的公共接口(记录每个OP有几个输入输出,OP的说明,OP进行形状推理的规则等),OpSchema类中包含了一个内部类FormalParameter用来记录输入和输出的名字和类型,一个内部结构体Attribute用来记录一些属性参数,最重要的是提供了TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction)接口用来进行类型和形状的推导。不难理解,通过schema_registry调用GetSchema(...)方法,传入指定的OP类型,可以获取对应的OpSchema对象schema,通过调用schema的GetTypeAndShapeInferenceFunction()方法,可以获取对应的推导规则函数,传入上下文对象ctx,实现每个OP形状的推理。

// schema.h:1086
static const OpSchema* Schema(
        const std::string& key,
        const int maxInclusiveVersion,
        const std::string& domain = ONNX_DOMAIN) {
    auto& m = map();
    if (m.count(key) && m[key].count(domain)) {
        auto pos = m[key][domain].lower_bound(maxInclusiveVersion);
        if (m[key][domain].begin() == pos && pos->first > maxInclusiveVersion) {
            // All versions are greater than specified version.
            return nullptr;
        }
        if (m[key][domain].end() == pos || pos->first > maxInclusiveVersion) {
            // All versions are less than specified version, or,
            // The <pos> version is greater than specified version.
            pos--;
        }

        // Schema with exact version as specified one exists.
        return &(pos->second);
        } else {
            return nullptr;
        }
    }
    ...
}

// schema.cc:877
OpName_Domain_Version_Schema_Map& OpSchemaRegistry::map() {
    ...
#ifndef __ONNX_DISABLE_STATIC_REGISTRATION
    static SchemasRegisterer schemasRegisterer;
#endif

// operator_sets.h:1018
inline void RegisterOnnxOperatorSetSchema() {
    RegisterOpSetSchema<OpSet_Onnx_ver1>();
    RegisterOpSetSchema<OpSet_Onnx_ver2>();
    ...
    RegisterOpSetSchema<OpSet_Onnx_ver16>();
    OpSchemaRegistry::Instance()->SetLoadedSchemaVersion(0);
}

// schema.h:1172
template <class T>
void RegisterOpSetSchema(int opset_version_to_load=0) {
    T::ForEachSchema([opset_version_to_load](OpSchema&& schema) {
        RegisterSchema(schema, opset_version_to_load);
    });
};

// schema.h:1185
#define ONNX_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, Onnx, ONNX_DOMAIN, ver, true, impl)

#define ONNX_ML_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxML, AI_ONNX_ML_DOMAIN, ver, true, impl)

#define ONNX_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(                             \
      name, OnnxTraining, AI_ONNX_TRAINING_DOMAIN, ver, true, impl)

#define ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(                                     \
      name, OnnxPreview, AI_ONNX_PREVIEW_TRAINING_DOMAIN, ver, true, impl)

#define ONNX_OPERATOR_SET_SCHEMA_EX(                                        \
    name, domain, domain_str, ver, dbg_included_in_static_opset, impl)      \
  class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name);             \
  template <>                                                               \
  OpSchema                                                                  \
  GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name)>() {   \
    return impl.SetName(#name)                                              \
        .SetDomain(domain_str)                                              \
        .SinceVersion(ver)                                                  \
        .SetLocation(__FILE__, __LINE__);                                   \
  }                                                                         \
  size_t dbg_count_check_##name##_##domain##_ver##ver =                     \
      (dbg_included_in_static_opset) ? ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() \
                                     : 0;
#ifdef NDEBUG
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() 0
#else
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() \
  DbgOperatorSetTracker::Instance().IncrementCount()
#define ONNX_DBG_GET_COUNT_IN_OPSETS() \
  DbgOperatorSetTracker::Instance().GetCount()

知道了整个形状推理的大致流程,但是OP对应的OpSchema是如何注册和加载的还没有弄清楚。调试GetSchema()函数,最终调用的是schema.h文件中的Schema(...)方法,在该方法中调用了OpSchemaRegistry类的map()方法,map()方法中申明了一个SchemasRegisterer类,同时申明了一个静态实例,该实例只有一个,在SchemasRegisterer类的构造函数中,调用了几种不同类型的OP集注册方式,以RegsiterOnnxOperatorSetSchema()方法为例,在RegsiterOnnxOperatorSetSchema()中,注册了不同版本的OpSet集合,到目前为止,已经注册了16个版本的Op集合,如果后续有新的大版本更新,可继续增加新的Op集合。注意到这里的注册方法RegisterOpSetSchema(...)是一个模板函数,其内部调用的是T::ForEachSechema()函数,且这个函数传入的形参是一个调用RegisterShcema()的匿名函数,取OpSet_Onnx_ver1这个类的ForEachSchema()进行跟踪,里面包含了多个算子的注册。ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME是一个宏,将域名、版本号、和OP类型拼接成一个类名,调用GetOpSchema()方法,就是调用拼接类的GetOpSchema()方法,这个方法又是在哪里定义的呢?以Conv算子为例,进一步跟踪可以发现,在http://old.cc中,定义了Conv的GetOpSchema()方法,但是采用了ONNX_OPERATOR_SET_SCHEMA宏的方式构建了类,这个类的名称和拼接类的名称是一致的。宏的第三个参数传入的是impl,也就是对应的OpSchema().FillUsing(ConvOpSchemaGenerator_10("a filter")),ConvOpSchemaGenerator_10(const char* filter_desc) 函数中,才真正的实现了Conv算子参数的注册。

对每个算子进行了形状推理后可能会遇到部分的维度是动态的情况,在materializeSymbolicshape(...)中对这种情况进行了处理,如果遇到了是动态尺寸导致某个维度无法推理出具体的尺寸,则在符号前面添加一个unk__作为前缀。最后再将结果写入模型中。

自定义OP schema的注册

ONNX_OPERATOR_SCHEMA(name)和ONNX_OPERATOR_SET_SCHEMA(name, version, OpSchema())都可以注册一个新的OP Schema,前者定义了一个静态类,采用静态的方法,后者是非静态的方法。通过模仿项目中其它算子schema是如何注册的,可以给自定义算子添加新的形状推理方法,这里不再过多赘述。

专栏持续更新,码字不易,动动小手点个赞吧!