为什么你的 Flask 服务慢? 因为它是 串行 的。 请求 A 进来 -> CPU 预处理 -> GPU 推理 -> CPU 后处理 -> 返回。 在这个过程中,当 CPU 在做 cv2.resize 时,GPU 是闲着的;当 GPU 在算矩阵乘法时,CPU 是闲着的。

Triton 的核心价值在于:它是一个调度器。 它支持 Dynamic Batching(动态批处理)Concurrent Model Execution(模型并发执行)。它能把 0.1 秒内涌进来的 16 个请求,打包成一个 Batch 扔给 GPU,瞬间算完。

1. 架构改造:从 Paddle 到 ONNX

虽然 Triton 支持 Python Backend(直接跑 Paddle 代码),但这依然受限于 Python 性能。 为了极致性能,我们通常走这条路: Paddle 模型 -> ONNX 格式 -> Triton (ONNX Runtime Backend)

第一步:模型转换 使用 paddle2onnx 把你的 inference.pdmodel 转成 .onnx

Bash

# 安装转换工具
pip install paddle2onnx

# 转换检测模型 (Det)
paddle2onnx --model_dir ./ch_PP-OCRv4_det_infer \
            --model_filename inference.pdmodel \
            --params_filename inference.pdiparams \
            --save_file ./models/det/1/model.onnx \
            --opset_version 11 \
            --enable_onnx_checker True

# 识别模型 (Rec) 同理转换

2. Triton 目录结构:严谨的工业风

Triton 对目录结构有严格要求,这很多新手容易踩坑的地方。你需要建立一个 model_repository

Plaintext

model_repository/
  └── text_detection/           <-- 模型名称
      ├── config.pbtxt          <-- 配置文件 (灵魂所在)
      └── 1/                    <-- 版本号
          └── model.onnx        <-- 模型文件

3. 核心配置:config.pbtxt (榨干性能的关键)

这是 Triton 最强大的地方。你需要告诉它:输入是什么形状?最大 Batch 是多少?要不要开动态批处理?

以下是一个典型的 PaddleOCR 检测模型 (Det) 的配置文件:

Protocol Buffers

name: "text_detection"
platform: "onnxruntime_onnx"
max_batch_size: 16  # 关键:允许最大 16 个请求合并

# 输入定义 (根据你的模型导出时的 shape 填写)
input [
  {
    name: "x"
    data_type: TYPE_FP32
    dims: [ 3, -1, -1 ]  # 3通道,高宽动态
  }
]

# 输出定义
output [
  {
    name: "sigmoid_0.tmp_0"
    data_type: TYPE_FP32
    dims: [ 1, -1, -1 ]
  }
]

# 核心魔法:动态批处理
dynamic_batching {
  preferred_batch_size: [ 4, 8, 16 ]
  max_queue_delay_microseconds: 1000  # 等待 1ms 凑 Batch,凑不够也发车
}

# 实例组:多实例并行
instance_group [
  {
    count: 2  # 在一张卡上起 2 个实例,掩盖 CPU/GPU 传输延迟
    kind: KIND_GPU
  }
]

解读:

  • dynamic_batching:如果有 10 个请求同时到,Triton 会自动凑成一个 Batch=10 的矩阵给 GPU。这比处理 10 次 Batch=1 的矩阵快得多。
  • instance_group:在 GPU 显存允许的情况下,多开几个模型实例。当一个实例在做计算时,另一个实例可以做数据拷贝。

4. 客户端调用:gRPC 才是正道

别用 HTTP (requests) 了,那太慢。Triton 支持 gRPC,基于 Protocol Buffers,序列化效率极高。

你需要安装客户端库: pip install tritonclient[all]

Python 客户端代码示例:

Python

import tritonclient.grpc as grpcclient
import numpy as np
import cv2

# 1. 连接 Triton Server
triton_client = grpcclient.InferenceServerClient(url="localhost:8001")

def infer_det(image_path):
    # 2. 预处理 (Client Side Preprocessing)
    # 注意:Triton 也可以把预处理做进模型里,但通常我们在 Client 做
    img = cv2.imread(image_path)
    img = cv2.resize(img, (960, 960)) # 示例尺寸
    img = img.transpose((2, 0, 1))    # HWC -> CHW
    img = img.astype(np.float32) / 255.0
    img = np.expand_dims(img, axis=0) # Add Batch Dim: (1, 3, 960, 960)

    # 3. 构造输入
    inputs = []
    inputs.append(grpcclient.InferInput('x', [1, 3, 960, 960], "FP32"))
    inputs[0].set_data_from_numpy(img)

    # 4. 构造输出占位符
    outputs = []
    outputs.append(grpcclient.InferRequestedOutput('sigmoid_0.tmp_0'))

    # 5. 发送请求
    results = triton_client.infer(
        model_name="text_detection",
        inputs=inputs,
        outputs=outputs
    )

    # 6. 获取结果
    output_data = results.as_numpy('sigmoid_0.tmp_0')
    print(f"Inference Shape: {output_data.shape}")

if __name__ == "__main__":
    # 模拟并发,你可以开多线程跑这个函数
    infer_det("test.jpg")

5. 进阶架构:Ensemble (流水线编排)

你可能会问:“OCR 是先检测(Det)再识别(Rec),这中间还有个抠图(Crop)的操作,怎么在 Triton 里搞?”

Triton 提供了一种叫 Ensemble 的模式。你可以写一个 Python Backend 模型作为“胶水”,把 Det 和 Rec 串起来。

架构如下:

  1. Request -> Ensemble Model
  2. Ensemble 调用 -> Det Model (GPU)
  3. Det Output -> Python Processing (CPU, 负责 Crop 图片)
  4. Cropped Images -> Rec Model (GPU)
  5. Rec Output -> Response

这样,整个 OCR 流程都在 Server 内部完成,客户端只需要发一张图,拿回来的就是文字结果。这大大减少了网络传输的开销。

6. 性能对比:Flask vs Triton

在 T4 显卡上,处理 100 张发票的基准测试:

  • Flask (单进程)
    • QPS: ~12
    • GPU 利用率: ~25% (大量时间在等待 CPU resize)
    • 延迟: 80ms/张
  • Triton (Dynamic Batching + ONNX Runtime)
    • QPS: ~65 (提升 5 倍)
    • GPU 利用率: ~85% (吃满了!)
    • 延迟: 20ms/张 (在 batch 较大的情况下)

总结

如果你的 OCR 服务只是给公司内部财务提个报销,Flask 够用了。 但如果你的服务是面向 C 端用户的(比如扫描全能王这种),或者需要处理海量历史文档,Triton 是必经之路

它把“如何高效使用 GPU”这个问题,从代码层面剥离到了配置层面。你只需要写好 config.pbtxt,剩下的并发、调度、队列,NVIDIA 帮你搞定。

别让 Python 的龟速,限制了你显卡的野兽性能。