ONNX Runtime 会话模块
会话对象负责加载 ONNX 模型、查看输入输出信息以及执行推理。
运行时配置
onnxruntime.configure(opts)
assert(onnxruntime.configure({
log_severity_level = 2,
log_id = "my-runtime",
use_global_thread_pools = false,
global_intra_op_num_threads = 0,
global_inter_op_num_threads = 0,
}))
说明:
- 必须在创建任何 session 之前调用
- 一旦已有活动 session,继续调用会报错
支持字段:
log_severity_levellog_iduse_global_thread_poolsglobal_intra_op_num_threadsglobal_inter_op_num_threads
创建会话
onnxruntime.session(model_path[, opts])
会话对象, 错误信息 = onnxruntime.session(模型路径, 选项)
从文件路径加载 ONNX 模型。
onnxruntime.session_from_bytes(model_bytes[, opts])
会话对象, 错误信息 = onnxruntime.session_from_bytes(模型字节串, 选项)
从内存字节创建会话。
Session 选项
通用字段
providers或provider可传单个字符串或字符串数组;当前原生处理并支持"cpu"、"coreml",也接受CPUExecutionProvider、CoreMLExecutionProvider这类别名fallback_to_cpu布尔型,默认trueintra_op_num_threadsinter_op_num_threadslog_idsession_log_severity_levelsession_log_verbosity_leveloptimized_model_pathprofile_file_prefixfree_dimension_overridesconfig_entriesgraph_optimization_level可选"disable"、"basic"、"extended"、"all"execution_mode可选"sequential"、"parallel"deterministic_computedisable_per_session_threadsenable_cpu_mem_arenaenable_mem_patterncustom_op_libraries
补充说明:
free_dimension_overrides需要传数组表,每项结构为{ by = "name"|"denotation", key = "...", value = 整数 }config_entries必须是“字符串键 -> 字符串值”的 tablecustom_op_libraries可以是单个路径字符串、路径数组,或load_custom_op_library()返回的句柄;数组里也可以混用路径和句柄- 如果没有显式指定
providers,或者 provider 列表为空,当前实现会默认补上 CPU provider - 当 provider 列表里包含
"coreml"且fallback_to_cpu = true时,CoreML provider 初始化失败后可自动回退到 CPU - 如果你显式写成
providers = {"coreml", "cpu"},顺序就表示先 CoreML、后 CPU
CoreML provider 相关字段
当 providers 中包含 "coreml" 时,还可使用:
coreml_compute_units推荐使用"all"、"cpu_only"、"cpu_and_gpu"、"cpu_and_neural_engine";解析器也兼容CPUOnly、CPUAndGPU、CPUAndNeuralEngine、MLComputeUnits...这些别名coreml_create_mlprogramcoreml_require_static_input_shapescoreml_enable_on_subgraphcoreml_flagscoreml_use_cpu_onlycoreml_use_cpu_and_gpucoreml_only_enable_device_with_ane
补充说明:
coreml_flags、coreml_use_cpu_only、coreml_use_cpu_and_gpu、coreml_only_enable_device_with_ane都是兼容旧写法的字段- 新旧字段可以混用,但如果表达的含义互相冲突,session 创建会直接报错
coreml_only_enable_device_with_ane不能和coreml_compute_units = "cpu_only"/"cpu_and_gpu"这类互斥配置同时使用
会话对象方法
基础信息
session:input_names()session:output_names()session:overridable_initializer_names()session:input_count()session:output_count()session:overridable_initializer_count()
类型信息
session:input_info(name_or_index)session:output_info(name_or_index)session:overridable_initializer_info(name_or_index)
返回值是类型信息表,常见字段包括:
nameonnx_typeis_sparsedata_typetypehas_shapeshapesymbolic_shapeelementkey_typevalue
说明:
- tensor / sparse tensor 会带
data_type、shape、symbolic_shape - sequence / optional 会带嵌套的
element - map 会带
key_type和嵌套的value
内存信息
session:memory_info_for_inputs()session:memory_info_for_outputs()
返回值既可以按顺序访问,也可以按名字访问。单项通常包含:
nameidmem_typeallocator_typedevice_typedevice_mem_typevendor_id
元信息与生命周期
session:metadata()session:close()session:end_profiling()session:profiling_start_time_ns()session:set_ep_dynamic_options(opts)session:register_custom_op_library(path_or_handle)
说明:
end_profiling()返回 profiling 输出文件路径set_ep_dynamic_options()会把传入 table 的 key/value 都转成字符串再传给 ORTregister_custom_op_library()会在当前 session 选项基础上重建内部 sessionpath_or_handle既可以是路径,也可以是load_custom_op_library()返回的句柄
执行推理
session:run(inputs[, output_names[, run_options]])
输出表, 错误信息 = session:run({
input_ids = 输入张量,
attention_mask = 掩码张量,
}, {
"logits",
}, run_options)
session:run_into(inputs, outputs[, run_options])
输出表, 错误信息 = session:run_into({
x = 输入张量,
}, {
y = 复用输出张量,
}, run_options)
session:run_with_iobinding(binding[, run_options])
输出表, 错误信息 = session:run_with_iobinding(binding, run_options)
输入规则:
inputs可以是顺序数组,也可以是按输入名组织的字典- 顺序数组按模型输入顺序匹配,后面也可以继续覆盖 overridable initializer
- 字典形式下,键必须与输入名或 overridable initializer 名一致
- optional 输入可以省略,也可以传
onnxruntime.optional(nil, type_info)
输出规则:
- 返回值是一个 table
- 同一个输出既可以用数字索引访问,也可以用输出名访问
run_into()如果某个输出复用了已有 tensor,返回表里对应项就是原对象本身
Run Options
onnxruntime.run_options([opts])
local run_options = assert(onnxruntime.run_options({
tag = "session-run",
log_severity_level = 2,
log_verbosity_level = 1,
}))
支持字段:
taglog_severity_levellog_verbosity_level
对象方法:
run_options:tag([value])run_options:log_severity_level([value])run_options:log_verbosity_level([value])run_options:terminate()run_options:reset_terminate()
IOBinding
session:create_io_binding()
binding, 错误信息 = session:create_io_binding()
binding:bind_input(name, value)
绑定输入值。这里不接受空 optional。
binding:bind_output(name[, spec_or_tensor])
支持三种形式:
binding:bind_output("y")绑定到 CPU 内存,稍后通过get_outputs()取回binding:bind_output("y", existing_tensor)直接写入已有 tensorbinding:bind_output("y", {type = "float32", shape = {1, 2}})由接口创建一个输出 tensor 并返回
也支持:
binding:bind_output("y", {mode = "device"})
其它方法
binding:clear_inputs()binding:clear_outputs()binding:synchronize_inputs()binding:synchronize_outputs()binding:get_outputs()
示例
local ort = require("onnxruntime")
local session = assert(ort.session(XXT_HOME_PATH.."/models/demo/model.onnx", {
providers = {"coreml", "cpu"},
fallback_to_cpu = true,
coreml_compute_units = "all",
}))
local x = assert(ort.tensor("float32", {1, 2}, {1.0, 2.0}))
local bias = assert(ort.tensor("float32", {1, 2}, {0.5, -0.5}))
local run_options = assert(ort.run_options({tag = "demo"}))
local outputs = assert(session:run({
x = x,
bias = bias,
}, {"y"}, run_options))
print(outputs.y:to_table()[1])