做过 NLP 的同学看到 TrOCR 的架构会觉得极其眼熟。因为它就是一个标准的 Encoder-Decoder 架构。

  • Encoder (ViT):把图片切成 16×16 的 Patch,变成一串向量。
  • Decoder (BERT):接收这串向量,利用 Cross-Attention,一个字一个字地把文本“吐”出来。

这种架构最大的好处是:它懂语言模型(LM)。当图片模糊时,它会根据上下文猜出是 “apple” 而不是 “apqle”。

1. 快速推理:Hello World

TrOCR 已经被集成在 transformers 中,开箱即用。微软开源了几个在 IAM(英文手写体数据集)上微调好的模型,效果惊人。

环境准备:

Bash

pip install transformers torch pillow

推理代码:

Python

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch

# 1. 加载模型
# 微软提供了三个尺寸: small, base, large
# - microsoft/trocr-base-handwritten: 手写体 (强推)
# - microsoft/trocr-base-printed: 印刷体
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading model...")
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device)

# 2. 加载图片
# 注意:TrOCR 对图片尺寸不敏感,ViT 会自动 resize 和 normalize
url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg" # 一个经典的 IAM 手写样本
# 如果是本地图片: image = Image.open("path/to/image.jpg").convert("RGB")
# 这里为了演示方便,假设你已经下载了图片
image = Image.open("handwritten_sample.jpg").convert("RGB")

# 3. 预处理
# pixel_values 是 ViT 需要的输入张量
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

# 4. 生成 (Generation)
# 本质上是在跑 Beam Search
print("Generating text...")
generated_ids = model.generate(pixel_values)

# 5. 解码
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

print(f"Result: {generated_text}")

坑点预警:

  • 速度慢:因为是 Autoregressive 生成,输出 10 个字符就要跑 10 次 Decoder。在 CPU 上可能要几百毫秒。
  • 幻觉:如果图片极度模糊,它可能会生成一句通顺但错误的英文句子。

2. 微调实战:让它认得你的字

预训练模型只认识英文手写体。如果你要识别 中文手写体古籍、或者 特定字体的验证码,必须进行 Fine-tuning。

好消息是,你不需要重头训练 ViT,只需要微调。

步骤一:准备数据

你需要构建一个 Dataset 类。假设你有 images/ 目录和 labels.txt(格式:文件名 \t 文本)。

Python

import os
from torch.utils.data import Dataset
from PIL import Image

class OCRDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # 1. 读图片
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]
        image = Image.open(os.path.join(self.root_dir, file_name)).convert("RGB")

        # 2. 处理图片 (Encoder Input)
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # 3. 处理文本 (Decoder Label)
        # 关键点:padding_side 设为 right
        labels = self.processor.tokenizer(
            text, 
            padding="max_length", 
            max_length=self.max_target_length
        ).input_ids
        
        # 4. 处理 Padding
        # HuggingFace 的 Loss 计算会自动忽略 -100 的 label
        # 我们把 pad_token_id 替换为 -100
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

步骤二:配置 Trainer

使用 Seq2SeqTrainer,这比自己写 Training Loop 要稳得多。

Python

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
from datasets import load_metric

# 加载 CER (Character Error Rate) 评估指标
cer_metric = load_metric("cer")

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # 解码预测结果
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    
    # 解码 Label,要把 -100 还原回去
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}

# 设置参数
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,     # 评估时生成文本
    evaluation_strategy="steps",
    per_device_train_batch_size=8,  # 显存不够就调小
    per_device_eval_batch_size=8,
    fp16=True,                      # 开启混合精度,显存省一半,速度快一倍
    output_dir="./trocr-finetuned",
    logging_steps=100,
    save_steps=500,
    eval_steps=500,
    save_total_limit=2,
    learning_rate=5e-5,             # 学习率不要太大
    num_train_epochs=3,
)

# 初始化 Trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor, # 注意这里传的是 feature_extractor
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)

# 开始炼丹
trainer.train()

3. 模型初始化:关于中文微调的特别说明

如果你要微调中文模型,直接加载 microsoft/trocr-base-handwritten 是不行的,因为它的 Decoder (RoBERTa) 词表里只有英文。

你需要做一个 “嫁接手术”

  1. Encoder:保留 TrOCR 的 ViT Encoder(它是通用的)。
  2. Decoder:替换为一个中文预训练模型(比如 bert-base-chinese)。

Python

from transformers import VisionEncoderDecoderModel, AutoTokenizer, TrOCRProcessor

# 1. 初始化一个混合模型
# Encoder 用 google/vit-base-patch16-224-in21k
# Decoder 用 bert-base-chinese
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", 
    "bert-base-chinese"
)

# 2. 设置 Decoder 的特殊 Token
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

# 3. 设置生成参数
model.config.eos_token_id = tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

# 接下来就和上面的微调流程一样了...

4. 总结:TrOCR 适合你吗?

TrOCR 是一个典型的 High Accuracy, Low Throughput(高精度,低吞吐)模型。

  • 不要用它做:实时视频流字幕提取、通用小票识别(PaddleOCR 更快更好)。
  • 一定要用它做
    • 手写体识别:精度目前是 SOTA。
    • 古籍/历史档案:那些 CRNN 根本切分不对的连笔字,TrOCR 能搞定。
    • 复杂验证码:能抗住扭曲和干扰线。

在 Hugging Face 生态下,TrOCR 把“看图说话”的门槛降到了最低。只要你有数据,你就能训练出一个专属的 OCR 专家。