ONNX Runtime 模块 (onnxruntime)
onnxruntime 模块用于在设备上直接加载和运行 ONNX 模型,适合文本、Embedding、分类、检测以及各种通用张量推理场景。
仅支持 iOS 13 及以上版本系统
加载模块
local ort = require("onnxruntime")
这是一个按需加载模块,不像 coreml 那样是内置全局模块。
当 require("onnxruntime") 成功执行后,还会向内置 coreml 模块注入两组桥接接口:
coreml.multi_array_from_ort_tensor(tensor[, data_type])multi_array:to_ort_tensor([data_type])
这两组转换都在 native 层直接拷贝,不经过 Lua table。
模块级函数
运行时与基础信息
onnxruntime.version()onnxruntime.providers()onnxruntime.configure(opts)
说明:
providers()返回当前 ORT 运行时实际可用的 Execution Provider 列表configure()用于设置全局运行时默认值,必须在创建任何 session 之前调用
张量、图像与数值辅助
onnxruntime.tensor(type, shape[, data])onnxruntime.tensor_from_bytes(type, shape, bytes)onnxruntime.tensor_from_cv_mat(mat[, opts])onnxruntime.tensor_from_quad(mat, quad[, opts])onnxruntime.tensor_from_quads(mat, quads[, opts])onnxruntime.tensor_from_image(image[, opts])onnxruntime.tensor_from_images(images[, opts])onnxruntime.image_from_tensor(tensor[, opts])onnxruntime.clamp(tensor, min, max)onnxruntime.sigmoid(tensor)onnxruntime.exp(tensor)onnxruntime.where(condition, x, y)onnxruntime.matmul(lhs, rhs)onnxruntime.concat(tensors[, axis])onnxruntime.stack(tensors[, axis])
说明:
clamp()、sigmoid()、exp()、matmul()与同名tensor:方法等价,只是把 tensor 作为第一个参数传入where()支持标量 / 布尔值 / tensor 混用,并按广播规则生成结果- 图像预处理、OpenCV 桥接和
image_from_tensor()的细节,详见 张量模块
检测、解码与后处理辅助
onnxruntime.nms(boxes, scores[, opts])onnxruntime.box_points(rotated_boxes)onnxruntime.xywh_to_xyxy(boxes)onnxruntime.xyxy_to_xywh(boxes)onnxruntime.rotated_iou(lhs_box, rhs_box)onnxruntime.rotated_nms(boxes, scores[, opts])onnxruntime.create_decoder(schema)onnxruntime.decode_yolo(output[, opts])onnxruntime.decode_yolo_obb(output[, opts])onnxruntime.decode_matrix_candidates(output, schema[, opts])onnxruntime.decode_dense_detection(output, opts)onnxruntime.records_from_boxes(boxes, scores, class_ids[, keep_indices])onnxruntime.obb_records_from_rows(rows, scores, class_ids[, angles[, keep_indices[, opts]]])onnxruntime.points_to_records(points[, opts])onnxruntime.threshold_masks(masks, threshold)onnxruntime.crop_masks_by_boxes(masks, boxes)onnxruntime.resize_masks(masks, width, height[, opts])onnxruntime.mask_iou(lhs_mask, rhs_mask)onnxruntime.mask_to_polygon(mask[, opts])onnxruntime.proto_masks(proto, coeffs, boxes, image_width, image_height[, opts])onnxruntime.project_masks(proto, coeffs, boxes, image_width, image_height[, opts])onnxruntime.db_postprocess(score_map[, opts])onnxruntime.tracker([opts])onnxruntime.reshape_keypoints(points[, keypoint_count[, keypoint_dim|opts]])onnxruntime.scale_boxes(boxes, transform)onnxruntime.clip_boxes(boxes, clip_width, clip_height)onnxruntime.scale_points(points, transform[, opts])onnxruntime.scale_keypoints(points, transform[, opts])onnxruntime.clip_keypoints(points, clip_width, clip_height[, opts])onnxruntime.ctc_greedy_decode(logits[, opts])onnxruntime.sample_logits(logits[, opts])
说明:
tensor_from_quad()/tensor_from_quads()需要先require("image.cv"),适合 OCR 四边形裁剪后直接生成张量box_points()接收形状为[5]、[1, 5]或[N, 5]的旋转框 tensor,不是五个分离标量参数create_decoder()返回 decoder 对象,支持:decode()、:task()、:schema()tracker()返回 tracker 对象,支持:update()、:reset()、:state()、:close()records_from_boxes()、obb_records_from_rows()、points_to_records()会把 tensor 结果整理成更适合 Lua 侧消费的 record tableproto_masks()与project_masks()当前是同一套实现,后者只是别名mask_iou()用于直接计算两张 mask 的交并比,也支持第三个参数opts,可传compare_size = true,或显式传width/heightdb_postprocess()适合 DB / DBNet 一类文本检测后处理;返回的每个检测项都带score、points和boxdecode_dense_detection()要求opts.strides为非空正整数数组,并且还需要decode_width、decode_height;当前只支持box_encoding = "grid_center_log_wh"ctc_greedy_decode()支持blank_index、merge_repeated、apply_softmax、return_probabilities、charsetctc_greedy_decode()一定返回indices;text仅在传入charset时返回;confidence仅在启用apply_softmax或return_probabilities时返回;probabilities与probability_confidence仅在启用return_probabilities时返回nms()/rotated_nms()返回的是int64tensor,索引语义为 1-basedsample_logits()支持argmax、temperature、top_k、top_p、min_p、seedsample_logits()对 1D logits 返回单个索引;对 batched logits 返回int64tensor
结构化值
onnxruntime.value(value)onnxruntime.optional(value, type_info)onnxruntime.sequence(items)onnxruntime.map(key_type, value_type, pairs)onnxruntime.sparse_tensor(type, dense_shape, indices, values)onnxruntime.sparse_tensor_from_dense(tensor)
适合处理非纯 tensor 的输入输出,例如 optional、sequence、map 和 sparse tensor。
当前行为可以概括为:
onnxruntime.value(x)如果x已经是 ORT tensor / value / sequence / map / sparse tensor,则原样返回;如果x是 Lua table,则按sequence处理;否则会把标量包装成 tensoronnxruntime.optional(value, type_info)第二个参数必填;type_info可以是字符串,也可以直接传session:input_info(...)/output_info(...)返回的类型信息表;空 optional 用onnxruntime.optional(nil, type_info)表示onnxruntime.map(key_type, value_type, pairs)当前key_type仅支持"string"或"int64"onnxruntime.sparse_tensor(type, dense_shape, indices, values)当前只支持数值 /bool稀疏张量,按 COO 方式构造;indices可以是扁平数组,也可以是坐标数组onnxruntime.sparse_tensor_from_dense(tensor)当前不支持stringtensor
常用对象方法:
value:type()/value:has_value()/value:get()sequence:length()/sequence:get(i)/sequence:items()map:get(key)/map:set(key, value)/map:keys()/map:pairs()sparse_tensor:dense_shape()/sparse_tensor:values()/sparse_tensor:indices()/sparse_tensor:format()/sparse_tensor:to_dense()
会话与推理
onnxruntime.session(model_path[, opts])onnxruntime.session_from_bytes(model_bytes[, opts])onnxruntime.run_options([opts])onnxruntime.load_custom_op_library(path)
会话对象负责模型加载、输入输出信息查询、执行推理以及 IOBinding。详见 会话模块。
支持的数据类型
当前张量接口支持以下元素类型名称:
"float32"/"float""float16""bfloat16""uint8""uint16""uint32""uint64""int8""int16""int32""int64""double"/"float64""bool""string"
说明:
tensor_from_bytes()和copy_from_bytes()只支持数值和bool类型bytes()不支持stringtensortensor:to("string")目前只支持string -> string
Provider 说明
onnxruntime.providers() 会返回运行时可见的 provider 列表,但当前 session 选项里原生处理并支持的 provider 字符串是:
"cpu""coreml"
说明:
provider/providers也接受CPUExecutionProvider、CoreMLExecutionProvider这类别名,内部会归一化到"cpu"、"coreml"- 如果没有显式指定 provider,或 provider 列表为空,session 创建阶段会自动补上 CPU provider
- 如果 provider 列表里包含
"coreml"且fallback_to_cpu = true,实现也可能补上 CPU 作为回退路径 - 如果你传入
providers = {"coreml", "cpu"},表示优先尝试 CoreML,再尝试 CPU
与 CoreML 联动
如果你要复用 coreml 分词器或 MLMultiArray 预处理流程,推荐这样组合:
local ort = require("onnxruntime")
local tokenizer = assert(coreml.new_text_tokenizer({
type = "wordpiece",
vocab_path = XXT_HOME_PATH.."/models/demo/vocab.txt",
context_length = 52,
}))
local input_ids = assert(tokenizer:encode("hello", {
output = "ort_tensor",
}))
或者把已有的 MLMultiArray 直接转成 ORT tensor:
local ort = require("onnxruntime")
local tensor = assert(multi_array:to_ort_tensor("int64"))