做过 NLP 的同学看到 TrOCR 的架构会觉得极其眼熟。因为它就是一个标准的 Encoder-Decoder 架构。
- Encoder (ViT):把图片切成 16×16 的 Patch,变成一串向量。
- Decoder (BERT):接收这串向量,利用 Cross-Attention,一个字一个字地把文本“吐”出来。
这种架构最大的好处是:它懂语言模型(LM)。当图片模糊时,它会根据上下文猜出是 “apple” 而不是 “apqle”。
1. 快速推理:Hello World
TrOCR 已经被集成在 transformers 中,开箱即用。微软开源了几个在 IAM(英文手写体数据集)上微调好的模型,效果惊人。
环境准备:
Bash
pip install transformers torch pillow
推理代码:
Python
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
# 1. 加载模型
# 微软提供了三个尺寸: small, base, large
# - microsoft/trocr-base-handwritten: 手写体 (强推)
# - microsoft/trocr-base-printed: 印刷体
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading model...")
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device)
# 2. 加载图片
# 注意:TrOCR 对图片尺寸不敏感,ViT 会自动 resize 和 normalize
url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg" # 一个经典的 IAM 手写样本
# 如果是本地图片: image = Image.open("path/to/image.jpg").convert("RGB")
# 这里为了演示方便,假设你已经下载了图片
image = Image.open("handwritten_sample.jpg").convert("RGB")
# 3. 预处理
# pixel_values 是 ViT 需要的输入张量
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
# 4. 生成 (Generation)
# 本质上是在跑 Beam Search
print("Generating text...")
generated_ids = model.generate(pixel_values)
# 5. 解码
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(f"Result: {generated_text}")
坑点预警:
- 速度慢:因为是 Autoregressive 生成,输出 10 个字符就要跑 10 次 Decoder。在 CPU 上可能要几百毫秒。
- 幻觉:如果图片极度模糊,它可能会生成一句通顺但错误的英文句子。
2. 微调实战:让它认得你的字
预训练模型只认识英文手写体。如果你要识别 中文手写体、古籍、或者 特定字体的验证码,必须进行 Fine-tuning。
好消息是,你不需要重头训练 ViT,只需要微调。
步骤一:准备数据
你需要构建一个 Dataset 类。假设你有 images/ 目录和 labels.txt(格式:文件名 \t 文本)。
Python
import os
from torch.utils.data import Dataset
from PIL import Image
class OCRDataset(Dataset):
def __init__(self, root_dir, df, processor, max_target_length=128):
self.root_dir = root_dir
self.df = df
self.processor = processor
self.max_target_length = max_target_length
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
# 1. 读图片
file_name = self.df['file_name'][idx]
text = self.df['text'][idx]
image = Image.open(os.path.join(self.root_dir, file_name)).convert("RGB")
# 2. 处理图片 (Encoder Input)
pixel_values = self.processor(image, return_tensors="pt").pixel_values
# 3. 处理文本 (Decoder Label)
# 关键点:padding_side 设为 right
labels = self.processor.tokenizer(
text,
padding="max_length",
max_length=self.max_target_length
).input_ids
# 4. 处理 Padding
# HuggingFace 的 Loss 计算会自动忽略 -100 的 label
# 我们把 pad_token_id 替换为 -100
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
return encoding
步骤二:配置 Trainer
使用 Seq2SeqTrainer,这比自己写 Training Loop 要稳得多。
Python
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
from datasets import load_metric
# 加载 CER (Character Error Rate) 评估指标
cer_metric = load_metric("cer")
def compute_metrics(pred):
labels_ids = pred.label_ids
pred_ids = pred.predictions
# 解码预测结果
pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
# 解码 Label,要把 -100 还原回去
labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
cer = cer_metric.compute(predictions=pred_str, references=label_str)
return {"cer": cer}
# 设置参数
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True, # 评估时生成文本
evaluation_strategy="steps",
per_device_train_batch_size=8, # 显存不够就调小
per_device_eval_batch_size=8,
fp16=True, # 开启混合精度,显存省一半,速度快一倍
output_dir="./trocr-finetuned",
logging_steps=100,
save_steps=500,
eval_steps=500,
save_total_limit=2,
learning_rate=5e-5, # 学习率不要太大
num_train_epochs=3,
)
# 初始化 Trainer
trainer = Seq2SeqTrainer(
model=model,
tokenizer=processor.feature_extractor, # 注意这里传的是 feature_extractor
args=training_args,
compute_metrics=compute_metrics,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=default_data_collator,
)
# 开始炼丹
trainer.train()
3. 模型初始化:关于中文微调的特别说明
如果你要微调中文模型,直接加载 microsoft/trocr-base-handwritten 是不行的,因为它的 Decoder (RoBERTa) 词表里只有英文。
你需要做一个 “嫁接手术”:
- Encoder:保留 TrOCR 的 ViT Encoder(它是通用的)。
- Decoder:替换为一个中文预训练模型(比如
bert-base-chinese)。
Python
from transformers import VisionEncoderDecoderModel, AutoTokenizer, TrOCRProcessor
# 1. 初始化一个混合模型
# Encoder 用 google/vit-base-patch16-224-in21k
# Decoder 用 bert-base-chinese
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
"google/vit-base-patch16-224-in21k",
"bert-base-chinese"
)
# 2. 设置 Decoder 的特殊 Token
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
# 3. 设置生成参数
model.config.eos_token_id = tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
# 接下来就和上面的微调流程一样了...
4. 总结:TrOCR 适合你吗?
TrOCR 是一个典型的 High Accuracy, Low Throughput(高精度,低吞吐)模型。
- 不要用它做:实时视频流字幕提取、通用小票识别(PaddleOCR 更快更好)。
- 一定要用它做:
- 手写体识别:精度目前是 SOTA。
- 古籍/历史档案:那些 CRNN 根本切分不对的连笔字,TrOCR 能搞定。
- 复杂验证码:能抗住扭曲和干扰线。
在 Hugging Face 生态下,TrOCR 把“看图说话”的门槛降到了最低。只要你有数据,你就能训练出一个专属的 OCR 专家。