🌀Jarson Cai's Blog
头脑是日用品,不是装饰品
TensorRT动态Batch和动态宽高的实现
学习一下TensorRT中动态宽高和动态Batch的实现方式。

TensorRT之动态Batch和动态宽高

动态Batch

该特性的需求主要源于TensorRT编译时对batch的处理,若静态batch则意味着无论你有多少图,都按照固定大小batch推理。耗时是固定的。

实现动态Batch的注意点:

1.onnx导出模型是,注意view操作不能固定batch维度数值,通常写-1。 2.onnx导出模型是,通常可以指定dynamic_axes(通常用于指定动态维度),实际上不指定也没关系。

动态宽高

该特性需求来自onnx导出时指定的宽高是固定的,TensorRT编译时也需要固定大小引擎,若你想得到另外一个不同大小的TensorRT引擎(一个eng模型只能支持一个输入分辨率)时,就需要动态宽高的存在。而直接使用TensorRT的动态宽高(一个eng模型能支持不同输入分辨率的推理)会带来不必要的复杂度,所以使用中间方案:在编译时修改onnx输入实现相对动态(一个onnx模型,修改参数可以得到不同输入分辨率大小的eng模型),避免重回Pytorch再做导出。

实现动态宽高的注意点:

1.不建议使用dynamic_axes指定Batch以外的维度为动态,这样带来的复杂度太高,并且存在有的layer不支持。 2.如果onnx文件已经导出,但是输入的shape固定了,此时希望修改onnx的输入shape:    步骤一:使用TRT::compile函数的inputsDimsSetup参数重新定义输入的shape。    步骤二:使用TRT::set_layer_hook_reshape钩子动态修改reshape的参数实现适配。

动态Batch demo:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
int max_batch_size = 5;
/** 模型编译,onnx到trtmodel **/
TRT::compile(
    TRT::Model::FP32,
    max_batch_size,               //最大batch size
    "model_name.onnx",
    "model_name.fp32.trtmodel"
);

/** 加载编译好的引擎 **/
auto infer = TRT::load_infer("model_name.fp32.trtmodel");

/** 设置输入的值 **/
/** 修改input的0维度为1,最大可以是5 **/
infer->input(0)->resize_single_dim(0, 2);
infer->input(0)->set_to(1.0f);

/** 引擎进行推理 **/
infer->forward();

/** 取出引擎的输出并打印 **/
auto out = infer->output(0);
INFO("out.shape = %s", out->shape_string());

动态宽高 demo:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
/** 这里的动态宽高是相对的,仅仅调整onnx输入大小为目的 **/
static void model_name() {
    //钩子函数
    TRT::set_layer_hook_reshape([](const string& name, const vector<int64_t>& shape)->vector<int64_t>{
        INFO("name: %s, shape: %s", name.c_str(), iLogger::join_dims(shape).c_str());
        return {-1, 25}; //25代表5*5的宽高,-1代表的是Batch的维度
    });

    /** 模型编译 **/
    TRT::compile(
        TRT::Model::FP32,
        1,
        "model_name.onnx",
        "model_name.fp32.trtmodel",
        {{1, 1, 5, 5}}             //对输入的重定义
    );

    auto infer = TRT::load_infer("model_name.fp32.trtmodel");
    auto model = infer->output(0);
    INFO("out.shape = %s", out->shape_string());
} 

最后修改于 2023-05-16

知识共享许可协议
本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。