NLLB 翻译模型实战


Hugging FaceGitHub 相似,专注于 AI 模型领域的开源社区,拥有丰富的模型可实现快速集成。社区吸引了众多的优秀的开发者及公司机构,源源不断的产出优秀成果。

今天,就让我们通过 Python 集成 facebook 开源的翻译模型,快速打造专属翻译引擎,以 API 接口为媒介为应用提供基础翻译服务。

一、基本准备

1. 模型下载

nllb-200-distilled-600M 模型是由 Facebook 开源的 AI 翻译模型,模型支持 200 种语言翻译。

开始之前需要下载相应的模型文件,访问 nllb 模型相应的仓库下载图中标注内容:nllb-200-distilled-600M

2. 依赖下载

新建 Python 项目并安装下述依赖,服务将通过 torch 加载模型,并基于 fastapi 框架以 API 接口形式对外开放服务。同时,利用 PyYAML 库实现 yml 文件内容配置化,并通过 langdetect 语言库实现语言检测。

新建 requirements.txt 依赖文件,类库依赖内容如下:

fastapi==0.110.0
uvicorn[standard]==0.29.0

transformers==4.39.3
torch>=1.13.1

sentencepiece==0.1.99

protobuf==3.20.3

pydantic==1.10.13

PyYAML~=6.0.2
langdetect~=1.0.9

二、系统配置

1. 配置文件

为了更好的方便系统维护,利用 PyYAML 库可实现 yml 文件的解析,实现配置化管理。

新建 application.yml 配置文件,定义对应的接口服务信息,以及默认上述下载后的模型存储路径。

app:
  name: Daily Word
  description: Word Translation
  version: 1.0
  host: 127.0.0.1
  port: 8080

model:
  path: "D:/Workspace/Model"

2. 实体定义

@dataclass3.7+ 版本标准库中的轻量级数据容器,提供了基础的对象定义。

这里通过类对象更便捷的交互配置文件,分别新建 AppConfig.pyModelConfig.py 文件,文件内容如下:

from dataclasses import dataclass

@dataclass
class AppConfig:
    name: str
    description: str
    version: str
    host: str
    port: int


@dataclass
class ModelConfig:
    path: str

2. 文件加载

定义类文件后便可通过 yaml.safe_load(file) 方法加载 yml 配置文件。

完整实现代码如下,后续便可访问 get_app_config()get_model_config() 方法或者配置信息。

import yaml
import os
from config.AppConfig import AppConfig
from config.ModelConfig import ModelConfig

# 配置文件
def load_config():
    # 读取相对路径
    base_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(base_dir, '../../config/application.yml')

    # 文件存在判断
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Not found config file of: {config_path}")

    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

# 加载配置文件
_config_data = load_config()

# 获取系统配置
def get_app_config() -> AppConfig:
    app_data = _config_data.get('app', {})
    return AppConfig(**app_data)

# 获取模型配置
def get_model_config() -> ModelConfig:
    model_data = _config_data.get('model', {})
    return ModelConfig(**model_data)

三、引擎实现

1. 语言检测

通过 langdetect 库可实现文本语言类型检测,服务调用时则无需手动声明类型,可由系统自行匹配。

langdetect 语言表示与 nllb 模型语言表达并不一致,因此系统内仍需要维护一份语言映射表,此处定义的方法作用描述参考下表:

方法 作用
detect_lang() 检测文本语言。
convert_type() 转化映射语言类型。
reverse_type() 反转映射语言类型。

完整实现代码如下:

from langdetect import detect

LANG_CODE_MAP = {
    'zh': 'zho_Hans',
    'en': 'eng_Latn',
    'fr': 'fra_Latn',
    'de': 'deu_Latn',
    'ja': 'jpn_Jpan',
    'ko': 'kor_Hang',
    'es': 'spa_Latn'
}

REVERSE_LANG_CODE_MAP = {v: k for k, v in LANG_CODE_MAP.items()}


# 语言检测
def detect_lang(text: str):
    try:
        detected = detect(text)
        if detected == "zh-cn" or detected =="zh-tw":
            detected = "zh"

        # 类型映射
        return LANG_CODE_MAP.get(detected, 'eng_Latn')
    except Exception:
        return 'eng_Latn'

# 类型转化
def convert_type(lang_type: str) -> str:
    try:
        return LANG_CODE_MAP[lang_type.lower()]
    except KeyError:
        raise ValueError(f"Unsupported language type: {lang_type}")

# 类型反转
def reverse_type(code: str) -> str:
    try:
        return REVERSE_LANG_CODE_MAP[code]
    except KeyError:
        raise ValueError(f"Unsupported language code: {code}")

2. 模型加载

下载 torchtransformers 类库后,便可开箱即用加载上述下载完成的模型。

代码逻辑十分简洁,其中 local_files_only=True 表示模型读取本地文件,当然也支持读取线上资源,但本地加载模型相对速度更快。

需注意一点,本地加载 NLLB 模型需要机器内存至少不低于 4G

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from pathlib import Path
# 导入工具
from tool.ConfigTool import get_model_config

# 配置加载
model_config = get_model_config()

# 模型加载
model_path = Path(model_config.path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, local_files_only=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)

四、接口服务

1. 实体定义

在上述的基本工作完成后,便可编写对应的接口服务。同样的,为接口的出入参定义相应的实体类 RequestDTOResponseDTO

可以看到此处定义方式并非通过 @dataclass 实现,而是由第三方类库 PydanticBaseModel 声明。

from pydantic import BaseModel 

class RequestDTO(BaseModel):
    text: str
    targetType: str

class ResponseDTO(BaseModel):
    sourceType: str
    sourceText: str
    targetType: str
    targetText: str

Python 标准库中自带的 @dataclass 虽然能够实现类定义,但功能简单适合简单的数据结构。而 BaseModel 额外包含了类型自动转化、Json 格式解析与导出等功能,提供了强大的数据模型。

例如下述示例中,由 BaseModel 定义的 User2 类成员变量 age 类型为 int,在输入字符串时将会自动转化格式,而 @dataclass 则无法实现。

from dataclasses import dataclass

@dataclass
class User1:
    name: str
    age: int

user1 = User1(name="Alice", age="30")
# name='Alice', age='30'
print(user1)


from pydantic import BaseModel

class User2(BaseModel):
    name: str
    age: int

user2 = User2(name="Alice", age="30")
# name='Alice' age=30
print(user2)

同时若 age 属性赋值非数字或非数据字符串将提示下述错误,因为 BaseModel 自带类型校验。

因此,在系统设计中 BaseModel 更多的应用于强业务逻辑中,如常见的方法与接口的出入参。

pydantic.error_wrappers.ValidationError: 1 validation error for User2
age
  value is not a valid integer (type=type_error.integer)

2. 接口定义

接下来便可基于 fastapi 框架编写对应的接口服务。

在上述内容中我们已经各个功能点拆分,接下要做的组装即可,步骤逻辑如下:

  • 检测输入文本语言;
  • 校验语言是否支持;
  • 调用翻译模型服务;
  • 封装翻译结果返回;

相对应的完整实现代码如下:

import torch
from fastapi import APIRouter, HTTPException
from schemas.RequestDTO import RequestDTO
from schemas.ResponseDTO import ResponseDTO
from tool.LanguageTool import detect_lang, convert_type, reverse_type
from tool.ModelTool import model, tokenizer

# 导出接口
router = APIRouter()

# 接口服务
@router.post("/nllb/translate")
def translate(req: RequestDTO) -> ResponseDTO:
    source_text = req.text
    # 检测输入语言
    source_type = detect_lang(source_text)

    # 语言类型校验
    target_type = convert_type(req.targetType)
    if not target_type or target_type not in tokenizer.lang_code_to_id:
        raise HTTPException(status_code = 400, detail = f"Unsupported target language: {target_type}")

    # 语言类型一致直接返回
    if source_type == target_type:
         return ResponseDTO(
            sourceType = req.targetType,
            sourceText = source_text,
            targetType = req.targetType,
            targetText = source_text
        )

    # 翻译内容
    tokenizer.src_lang = source_type
    inputs = tokenizer(source_text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_type),
            max_length=256
        )
    # 解码输出
    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    # 返回结果
    return ResponseDTO(
        # 反转类型
        sourceType = reverse_type(source_type),
        sourceText = source_text,
        targetType = req.targetType,
        targetText = result
    )

4. 启动入口

完成这一切工作之后,新建 main.py 文件编写程序入口。代码内容相对简单,定义 FastAPI 相应的信息后挂载上述定义的接口服务,完成后运行程序便可启动服务。

import uvicorn
from fastapi import FastAPI
from tool.ConfigTool import get_app_config
from api.WordResource import router as translate_router

# 系统信息
app_config = get_app_config()
app = FastAPI(
    title = app_config.name,
    description = app_config.description,
    version = app_config.version
)

# 挂在接口
app.include_router(translate_router)

# 启动服务
if __name__ == "__main__":
    uvicorn.run(
        app,
        host = app_config.host,
        port = app_config.port
    )

文中涉及代码已上传 GitHub,项目地址:translation-engine


文章作者: 烽火戏诸诸诸侯
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 烽火戏诸诸诸侯 !
  目录