跳到主要内容

ONNX Runtime 会话模块

会话对象负责加载 ONNX 模型、查看输入输出信息以及执行推理。

运行时配置

onnxruntime.configure(opts)

assert(onnxruntime.configure({
log_severity_level = 2,
log_id = "my-runtime",
use_global_thread_pools = false,
global_intra_op_num_threads = 0,
global_inter_op_num_threads = 0,
}))

说明:

  • 必须在创建任何 session 之前调用
  • 一旦已有活动 session,继续调用会报错

支持字段:

  • log_severity_level
  • log_id
  • use_global_thread_pools
  • global_intra_op_num_threads
  • global_inter_op_num_threads

创建会话

onnxruntime.session(model_path[, opts])

会话对象, 错误信息 = onnxruntime.session(模型路径, 选项)

从文件路径加载 ONNX 模型。

onnxruntime.session_from_bytes(model_bytes[, opts])

会话对象, 错误信息 = onnxruntime.session_from_bytes(模型字节串, 选项)

从内存字节创建会话。

Session 选项

通用字段

  • providersprovider 可传单个字符串或字符串数组;当前原生处理并支持 "cpu""coreml",也接受 CPUExecutionProviderCoreMLExecutionProvider 这类别名
  • fallback_to_cpu 布尔型,默认 true
  • intra_op_num_threads
  • inter_op_num_threads
  • log_id
  • session_log_severity_level
  • session_log_verbosity_level
  • optimized_model_path
  • profile_file_prefix
  • free_dimension_overrides
  • config_entries
  • graph_optimization_level 可选 "disable""basic""extended""all"
  • execution_mode 可选 "sequential""parallel"
  • deterministic_compute
  • disable_per_session_threads
  • enable_cpu_mem_arena
  • enable_mem_pattern
  • custom_op_libraries

补充说明:

  • free_dimension_overrides 需要传数组表,每项结构为 { by = "name"|"denotation", key = "...", value = 整数 }
  • config_entries 必须是“字符串键 -> 字符串值”的 table
  • custom_op_libraries 可以是单个路径字符串、路径数组,或 load_custom_op_library() 返回的句柄;数组里也可以混用路径和句柄
  • 如果没有显式指定 providers,或者 provider 列表为空,当前实现会默认补上 CPU provider
  • 当 provider 列表里包含 "coreml"fallback_to_cpu = true 时,CoreML provider 初始化失败后可自动回退到 CPU
  • 如果你显式写成 providers = {"coreml", "cpu"},顺序就表示先 CoreML、后 CPU

CoreML provider 相关字段

providers 中包含 "coreml" 时,还可使用:

  • coreml_compute_units 推荐使用 "all""cpu_only""cpu_and_gpu""cpu_and_neural_engine";解析器也兼容 CPUOnlyCPUAndGPUCPUAndNeuralEngineMLComputeUnits... 这些别名
  • coreml_create_mlprogram
  • coreml_require_static_input_shapes
  • coreml_enable_on_subgraph
  • coreml_flags
  • coreml_use_cpu_only
  • coreml_use_cpu_and_gpu
  • coreml_only_enable_device_with_ane

补充说明:

  • coreml_flagscoreml_use_cpu_onlycoreml_use_cpu_and_gpucoreml_only_enable_device_with_ane 都是兼容旧写法的字段
  • 新旧字段可以混用,但如果表达的含义互相冲突,session 创建会直接报错
  • coreml_only_enable_device_with_ane 不能和 coreml_compute_units = "cpu_only" / "cpu_and_gpu" 这类互斥配置同时使用

会话对象方法

基础信息

  • session:input_names()
  • session:output_names()
  • session:overridable_initializer_names()
  • session:input_count()
  • session:output_count()
  • session:overridable_initializer_count()

类型信息

  • session:input_info(name_or_index)
  • session:output_info(name_or_index)
  • session:overridable_initializer_info(name_or_index)

返回值是类型信息表,常见字段包括:

  • name
  • onnx_type
  • is_sparse
  • data_type
  • type
  • has_shape
  • shape
  • symbolic_shape
  • element
  • key_type
  • value

说明:

  • tensor / sparse tensor 会带 data_typeshapesymbolic_shape
  • sequence / optional 会带嵌套的 element
  • map 会带 key_type 和嵌套的 value

内存信息

  • session:memory_info_for_inputs()
  • session:memory_info_for_outputs()

返回值既可以按顺序访问,也可以按名字访问。单项通常包含:

  • name
  • id
  • mem_type
  • allocator_type
  • device_type
  • device_mem_type
  • vendor_id

元信息与生命周期

  • session:metadata()
  • session:close()
  • session:end_profiling()
  • session:profiling_start_time_ns()
  • session:set_ep_dynamic_options(opts)
  • session:register_custom_op_library(path_or_handle)

说明:

  • end_profiling() 返回 profiling 输出文件路径
  • set_ep_dynamic_options() 会把传入 table 的 key/value 都转成字符串再传给 ORT
  • register_custom_op_library() 会在当前 session 选项基础上重建内部 session
  • path_or_handle 既可以是路径,也可以是 load_custom_op_library() 返回的句柄

执行推理

session:run(inputs[, output_names[, run_options]])

输出表, 错误信息 = session:run({
input_ids = 输入张量,
attention_mask = 掩码张量,
}, {
"logits",
}, run_options)

session:run_into(inputs, outputs[, run_options])

输出表, 错误信息 = session:run_into({
x = 输入张量,
}, {
y = 复用输出张量,
}, run_options)

session:run_with_iobinding(binding[, run_options])

输出表, 错误信息 = session:run_with_iobinding(binding, run_options)

输入规则:

  • inputs 可以是顺序数组,也可以是按输入名组织的字典
  • 顺序数组按模型输入顺序匹配,后面也可以继续覆盖 overridable initializer
  • 字典形式下,键必须与输入名或 overridable initializer 名一致
  • optional 输入可以省略,也可以传 onnxruntime.optional(nil, type_info)

输出规则:

  • 返回值是一个 table
  • 同一个输出既可以用数字索引访问,也可以用输出名访问
  • run_into() 如果某个输出复用了已有 tensor,返回表里对应项就是原对象本身

Run Options

onnxruntime.run_options([opts])

local run_options = assert(onnxruntime.run_options({
tag = "session-run",
log_severity_level = 2,
log_verbosity_level = 1,
}))

支持字段:

  • tag
  • log_severity_level
  • log_verbosity_level

对象方法:

  • run_options:tag([value])
  • run_options:log_severity_level([value])
  • run_options:log_verbosity_level([value])
  • run_options:terminate()
  • run_options:reset_terminate()

IOBinding

session:create_io_binding()

binding, 错误信息 = session:create_io_binding()

binding:bind_input(name, value)

绑定输入值。这里不接受空 optional。

binding:bind_output(name[, spec_or_tensor])

支持三种形式:

  • binding:bind_output("y") 绑定到 CPU 内存,稍后通过 get_outputs() 取回
  • binding:bind_output("y", existing_tensor) 直接写入已有 tensor
  • binding:bind_output("y", {type = "float32", shape = {1, 2}}) 由接口创建一个输出 tensor 并返回

也支持:

  • binding:bind_output("y", {mode = "device"})

其它方法

  • binding:clear_inputs()
  • binding:clear_outputs()
  • binding:synchronize_inputs()
  • binding:synchronize_outputs()
  • binding:get_outputs()

示例

local ort = require("onnxruntime")

local session = assert(ort.session(XXT_HOME_PATH.."/models/demo/model.onnx", {
providers = {"coreml", "cpu"},
fallback_to_cpu = true,
coreml_compute_units = "all",
}))

local x = assert(ort.tensor("float32", {1, 2}, {1.0, 2.0}))
local bias = assert(ort.tensor("float32", {1, 2}, {0.5, -0.5}))
local run_options = assert(ort.run_options({tag = "demo"}))

local outputs = assert(session:run({
x = x,
bias = bias,
}, {"y"}, run_options))

print(outputs.y:to_table()[1])