跳到主要内容

ONNX Runtime 模块 (onnxruntime)

onnxruntime 模块用于在设备上直接加载和运行 ONNX 模型,适合文本、Embedding、分类、检测以及各种通用张量推理场景。

仅支持 iOS 13 及以上版本系统

加载模块

local ort = require("onnxruntime")

这是一个按需加载模块,不像 coreml 那样是内置全局模块。

require("onnxruntime") 成功执行后,还会向内置 coreml 模块注入两组桥接接口:

  • coreml.multi_array_from_ort_tensor(tensor[, data_type])
  • multi_array:to_ort_tensor([data_type])

这两组转换都在 native 层直接拷贝,不经过 Lua table。

模块级函数

运行时与基础信息

  • onnxruntime.version()
  • onnxruntime.providers()
  • onnxruntime.configure(opts)

说明:

  • providers() 返回当前 ORT 运行时实际可用的 Execution Provider 列表
  • configure() 用于设置全局运行时默认值,必须在创建任何 session 之前调用

张量、图像与数值辅助

  • onnxruntime.tensor(type, shape[, data])
  • onnxruntime.tensor_from_bytes(type, shape, bytes)
  • onnxruntime.tensor_from_cv_mat(mat[, opts])
  • onnxruntime.tensor_from_quad(mat, quad[, opts])
  • onnxruntime.tensor_from_quads(mat, quads[, opts])
  • onnxruntime.tensor_from_image(image[, opts])
  • onnxruntime.tensor_from_images(images[, opts])
  • onnxruntime.image_from_tensor(tensor[, opts])
  • onnxruntime.clamp(tensor, min, max)
  • onnxruntime.sigmoid(tensor)
  • onnxruntime.exp(tensor)
  • onnxruntime.where(condition, x, y)
  • onnxruntime.matmul(lhs, rhs)
  • onnxruntime.concat(tensors[, axis])
  • onnxruntime.stack(tensors[, axis])

说明:

  • clamp()sigmoid()exp()matmul() 与同名 tensor: 方法等价,只是把 tensor 作为第一个参数传入
  • where() 支持标量 / 布尔值 / tensor 混用,并按广播规则生成结果
  • 图像预处理、OpenCV 桥接和 image_from_tensor() 的细节,详见 张量模块

检测、解码与后处理辅助

  • onnxruntime.nms(boxes, scores[, opts])
  • onnxruntime.box_points(rotated_boxes)
  • onnxruntime.xywh_to_xyxy(boxes)
  • onnxruntime.xyxy_to_xywh(boxes)
  • onnxruntime.rotated_iou(lhs_box, rhs_box)
  • onnxruntime.rotated_nms(boxes, scores[, opts])
  • onnxruntime.create_decoder(schema)
  • onnxruntime.decode_yolo(output[, opts])
  • onnxruntime.decode_yolo_obb(output[, opts])
  • onnxruntime.decode_matrix_candidates(output, schema[, opts])
  • onnxruntime.decode_dense_detection(output, opts)
  • onnxruntime.records_from_boxes(boxes, scores, class_ids[, keep_indices])
  • onnxruntime.obb_records_from_rows(rows, scores, class_ids[, angles[, keep_indices[, opts]]])
  • onnxruntime.points_to_records(points[, opts])
  • onnxruntime.threshold_masks(masks, threshold)
  • onnxruntime.crop_masks_by_boxes(masks, boxes)
  • onnxruntime.resize_masks(masks, width, height[, opts])
  • onnxruntime.mask_iou(lhs_mask, rhs_mask)
  • onnxruntime.mask_to_polygon(mask[, opts])
  • onnxruntime.proto_masks(proto, coeffs, boxes, image_width, image_height[, opts])
  • onnxruntime.project_masks(proto, coeffs, boxes, image_width, image_height[, opts])
  • onnxruntime.db_postprocess(score_map[, opts])
  • onnxruntime.tracker([opts])
  • onnxruntime.reshape_keypoints(points[, keypoint_count[, keypoint_dim|opts]])
  • onnxruntime.scale_boxes(boxes, transform)
  • onnxruntime.clip_boxes(boxes, clip_width, clip_height)
  • onnxruntime.scale_points(points, transform[, opts])
  • onnxruntime.scale_keypoints(points, transform[, opts])
  • onnxruntime.clip_keypoints(points, clip_width, clip_height[, opts])
  • onnxruntime.ctc_greedy_decode(logits[, opts])
  • onnxruntime.sample_logits(logits[, opts])

说明:

  • tensor_from_quad() / tensor_from_quads() 需要先 require("image.cv"),适合 OCR 四边形裁剪后直接生成张量
  • box_points() 接收形状为 [5][1, 5][N, 5] 的旋转框 tensor,不是五个分离标量参数
  • create_decoder() 返回 decoder 对象,支持 :decode():task():schema()
  • tracker() 返回 tracker 对象,支持 :update():reset():state():close()
  • records_from_boxes()obb_records_from_rows()points_to_records() 会把 tensor 结果整理成更适合 Lua 侧消费的 record table
  • proto_masks()project_masks() 当前是同一套实现,后者只是别名
  • mask_iou() 用于直接计算两张 mask 的交并比,也支持第三个参数 opts,可传 compare_size = true,或显式传 width / height
  • db_postprocess() 适合 DB / DBNet 一类文本检测后处理;返回的每个检测项都带 scorepointsbox
  • decode_dense_detection() 要求 opts.strides 为非空正整数数组,并且还需要 decode_widthdecode_height;当前只支持 box_encoding = "grid_center_log_wh"
  • ctc_greedy_decode() 支持 blank_indexmerge_repeatedapply_softmaxreturn_probabilitiescharset
  • ctc_greedy_decode() 一定返回 indicestext 仅在传入 charset 时返回;confidence 仅在启用 apply_softmaxreturn_probabilities 时返回;probabilitiesprobability_confidence 仅在启用 return_probabilities 时返回
  • nms() / rotated_nms() 返回的是 int64 tensor,索引语义为 1-based
  • sample_logits() 支持 argmaxtemperaturetop_ktop_pmin_pseed
  • sample_logits() 对 1D logits 返回单个索引;对 batched logits 返回 int64 tensor

结构化值

  • onnxruntime.value(value)
  • onnxruntime.optional(value, type_info)
  • onnxruntime.sequence(items)
  • onnxruntime.map(key_type, value_type, pairs)
  • onnxruntime.sparse_tensor(type, dense_shape, indices, values)
  • onnxruntime.sparse_tensor_from_dense(tensor)

适合处理非纯 tensor 的输入输出,例如 optional、sequence、map 和 sparse tensor。

当前行为可以概括为:

  • onnxruntime.value(x) 如果 x 已经是 ORT tensor / value / sequence / map / sparse tensor,则原样返回;如果 x 是 Lua table,则按 sequence 处理;否则会把标量包装成 tensor
  • onnxruntime.optional(value, type_info) 第二个参数必填;type_info 可以是字符串,也可以直接传 session:input_info(...) / output_info(...) 返回的类型信息表;空 optional 用 onnxruntime.optional(nil, type_info) 表示
  • onnxruntime.map(key_type, value_type, pairs) 当前 key_type 仅支持 "string""int64"
  • onnxruntime.sparse_tensor(type, dense_shape, indices, values) 当前只支持数值 / bool 稀疏张量,按 COO 方式构造;indices 可以是扁平数组,也可以是坐标数组
  • onnxruntime.sparse_tensor_from_dense(tensor) 当前不支持 string tensor

常用对象方法:

  • value:type() / value:has_value() / value:get()
  • sequence:length() / sequence:get(i) / sequence:items()
  • map:get(key) / map:set(key, value) / map:keys() / map:pairs()
  • sparse_tensor:dense_shape() / sparse_tensor:values() / sparse_tensor:indices() / sparse_tensor:format() / sparse_tensor:to_dense()

会话与推理

  • onnxruntime.session(model_path[, opts])
  • onnxruntime.session_from_bytes(model_bytes[, opts])
  • onnxruntime.run_options([opts])
  • onnxruntime.load_custom_op_library(path)

会话对象负责模型加载、输入输出信息查询、执行推理以及 IOBinding。详见 会话模块

支持的数据类型

当前张量接口支持以下元素类型名称:

  • "float32" / "float"
  • "float16"
  • "bfloat16"
  • "uint8"
  • "uint16"
  • "uint32"
  • "uint64"
  • "int8"
  • "int16"
  • "int32"
  • "int64"
  • "double" / "float64"
  • "bool"
  • "string"

说明:

  • tensor_from_bytes()copy_from_bytes() 只支持数值和 bool 类型
  • bytes() 不支持 string tensor
  • tensor:to("string") 目前只支持 string -> string

Provider 说明

onnxruntime.providers() 会返回运行时可见的 provider 列表,但当前 session 选项里原生处理并支持的 provider 字符串是:

  • "cpu"
  • "coreml"

说明:

  • provider / providers 也接受 CPUExecutionProviderCoreMLExecutionProvider 这类别名,内部会归一化到 "cpu""coreml"
  • 如果没有显式指定 provider,或 provider 列表为空,session 创建阶段会自动补上 CPU provider
  • 如果 provider 列表里包含 "coreml"fallback_to_cpu = true,实现也可能补上 CPU 作为回退路径
  • 如果你传入 providers = {"coreml", "cpu"},表示优先尝试 CoreML,再尝试 CPU

与 CoreML 联动

如果你要复用 coreml 分词器或 MLMultiArray 预处理流程,推荐这样组合:

local ort = require("onnxruntime")

local tokenizer = assert(coreml.new_text_tokenizer({
type = "wordpiece",
vocab_path = XXT_HOME_PATH.."/models/demo/vocab.txt",
context_length = 52,
}))

local input_ids = assert(tokenizer:encode("hello", {
output = "ort_tensor",
}))

或者把已有的 MLMultiArray 直接转成 ORT tensor:

local ort = require("onnxruntime")
local tensor = assert(multi_array:to_ort_tensor("int64"))