← 返回首页
🤖

机器翻译入门

📂 ai ⏱ 3 min 419 words

机器翻译入门

什么是机器翻译

机器翻译是将文本从一种语言自动翻译成另一种语言的任务,是自然语言处理(NLP)中的经典问题。

Seq2Seq模型

编码器-解码器架构是机器翻译的基础:编码器将源语言句子编码为隐藏状态,解码器再以该状态为条件逐词生成目标语言句子。下面用GRU实现这两部分:

import torch
import torch.nn as nn

class Encoder(nn.Module):
    """GRU encoder: embeds source token ids and runs them through a one-layer GRU.

    Args:
        input_dim: source vocabulary size (number of rows in the embedding table).
        embed_dim: dimensionality of each token embedding.
        hidden_dim: size of the GRU hidden state.
    """

    def __init__(self, input_dim, embed_dim, hidden_dim):
        super().__init__()
        # Submodules are created in the same order as before (embedding, then rnn).
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)

    def forward(self, src):
        """Encode a batch of source token ids.

        Args:
            src: integer token ids, assumed (batch, seq_len) since the GRU is
                batch_first.

        Returns:
            The GRU's ``(outputs, hidden)`` pair: per-step outputs
            (batch, seq_len, hidden_dim) and the final hidden state
            (1, batch, hidden_dim).
        """
        return self.rnn(self.embedding(src))

class Decoder(nn.Module):
    """Single-step GRU decoder conditioned on a context vector.

    At each step the embedded target token is concatenated with the context
    vector along the feature axis and fed to a one-layer GRU. A linear layer
    then maps the GRU output to vocabulary logits.

    Args:
        output_dim: target vocabulary size (embedding rows and number of logits).
        embed_dim: dimensionality of each target-token embedding.
        hidden_dim: size of the GRU hidden state; the context vector's last
            dimension must also be ``hidden_dim`` (GRU input is
            ``embed_dim + hidden_dim``).
    """

    def __init__(self, output_dim, embed_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_dim, embed_dim)
        # GRU input per step = [token embedding ; context vector].
        self.rnn = nn.GRU(embed_dim + hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input, hidden, context):
        """Run one decoding step.

        Args:
            input: target token ids, assumed (batch, 1) for a single step.
            hidden: previous hidden state, either (batch, hidden_dim) or the
                (1, batch, hidden_dim) tensor returned by a previous call to
                this method or by the GRU encoder.
            context: context vector, assumed (batch, 1, hidden_dim).

        Returns:
            (prediction, hidden, output): logits (batch, output_dim) for a
            single step, the new GRU hidden state (1, batch, hidden_dim) and
            the raw GRU output (batch, 1, hidden_dim).
        """
        embedded = self.embedding(input)
        rnn_input = torch.cat([embedded, context], dim=-1)
        if hidden.dim() < rnn_input.dim():
            # nn.GRU wants h_0 with a leading num_layers axis. Add it only when
            # it is missing, so the hidden state this method returns can be fed
            # straight back in on the next step.
            hidden = hidden.unsqueeze(0)
        output, hidden = self.rnn(rnn_input, hidden)
        prediction = self.fc(output.squeeze(1))
        return prediction, hidden, output

# Build an encoder/decoder pair: 1000-entry vocabularies, 256-dim embeddings
# and a 512-dim GRU hidden state on both sides.
encoder, decoder = (
    Encoder(input_dim=1000, embed_dim=256, hidden_dim=512),
    Decoder(output_dim=1000, embed_dim=256, hidden_dim=512),
)

注意力机制

注意力机制帮助解码器关注源序列中的关键位置:在每个解码步,它根据解码器当前的隐藏状态为源序列的每个位置打分,经softmax归一化后对编码器输出加权求和,得到上下文向量。实现如下:

class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over encoder outputs.

    score_j = V(tanh(W_encoder(h_j) + W_decoder(s))), softmax-normalised over
    source positions j, then used to take a weighted sum of the encoder outputs.

    Args:
        encoder_dim: feature size of each encoder output.
        decoder_dim: size of the decoder hidden state; also the width of the
            internal projection before the scoring layer ``V``.
    """

    def __init__(self, encoder_dim, decoder_dim):
        super(Attention, self).__init__()
        self.W_encoder = nn.Linear(encoder_dim, decoder_dim)
        self.W_decoder = nn.Linear(decoder_dim, decoder_dim)
        self.V = nn.Linear(decoder_dim, 1)

    def forward(self, encoder_outputs, decoder_hidden, mask=None):
        """Compute the context vector for one decoding step.

        Args:
            encoder_outputs: (batch, src_len, encoder_dim).
            decoder_hidden: (batch, decoder_dim). A GRU-style hidden state of
                shape (num_layers, batch, decoder_dim) is also accepted; its
                last layer is used.
            mask: optional (batch, src_len) tensor; source positions where
                ``mask == 0`` (e.g. padding) receive zero attention weight.

        Returns:
            (context, attention_weights): context (batch, 1, encoder_dim) and
            weights (batch, src_len, 1) that sum to 1 over ``src_len``.
        """
        if decoder_hidden.dim() == 3:
            # (num_layers, batch, dim) -> (batch, dim): attend with the top layer.
            decoder_hidden = decoder_hidden[-1]
        # (batch, src_len, dim) + (batch, 1, dim) broadcasts over src_len;
        # V maps each position to one score -> (batch, src_len, 1).
        score = self.V(torch.tanh(
            self.W_encoder(encoder_outputs) +
            self.W_decoder(decoder_hidden.unsqueeze(1))
        ))
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(-1)
            # Most negative finite value of the dtype (a literal like -1e9 overflows float16).
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)
        attention_weights = torch.softmax(score, dim=1)
        # (batch, 1, src_len) @ (batch, src_len, encoder_dim) -> (batch, 1, encoder_dim)
        context = torch.bmm(attention_weights.transpose(1, 2),
                            encoder_outputs)
        return context, attention_weights

# Demo: one attention step for a batch of 32 source sentences of 50 positions,
# with 512-dim encoder outputs and a 512-dim decoder state.
attention = Attention(encoder_dim=512, decoder_dim=512)
encoder_outputs, decoder_hidden = torch.randn(32, 50, 512), torch.randn(32, 512)
# context: (32, 1, 512); weights: (32, 50, 1)
context, weights = attention(encoder_outputs=encoder_outputs,
                             decoder_hidden=decoder_hidden)
print("上下文向量形状:", context.shape)

Transformer架构

Transformer完全基于注意力机制、不再使用循环结构,其核心组件是多头自注意力:

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention, as used in the Transformer.

    Args:
        d_model: feature size of queries, keys, values and the output.
        num_heads: number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        if d_model % num_heads != 0:
            # Otherwise the per-head reshape in forward() either fails or
            # silently mixes features from different sequence positions.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: (batch, q_len, d_model).
            key: (batch, k_len, d_model).
            value: (batch, k_len, d_model).
            mask: optional tensor broadcastable to (batch, num_heads, q_len, k_len);
                entries where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size = query.size(0)

        # Project, split d_model into (num_heads, d_k), move heads to dim 1:
        # (batch, len, d_model) -> (batch, num_heads, len, d_k)
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scale by sqrt(d_k) so dot-product magnitudes don't grow with head size.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # Use the dtype's most negative finite value; -1e9 overflows float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        # Merge heads back: (batch, num_heads, q_len, d_k) -> (batch, q_len, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(output)

        return output

# Demo: self-attention (query, key and value are the same tensor) over a batch
# of 32 sequences of length 50 with d_model = 512 split across 8 heads.
mha = MultiHeadAttention(d_model=512, num_heads=8)
query = torch.randn(32, 50, 512)
output = mha(query, key=query, value=query)
print("输出形状:", output.shape)

BLEU评估

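BLEU是机器翻译最常用的自动评估指标,它通过统计候选译文与参考译文之间n-gram(n=1…4)的重合程度来衡量翻译质量。注意:句子级BLEU若不做平滑,只要某一阶n-gram完全没有匹配,得分就是0——下面的例子中没有任何4-gram匹配,因此输出为0: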

import math
from collections import Counter

def compute_bleu(reference, hypothesis, max_order=4, smooth=False):
    """Compute sentence-level BLEU for one hypothesis against one reference.

    Both strings are tokenized on whitespace. The score is the geometric mean
    of the clipped (modified) n-gram precisions for n = 1..max_order,
    multiplied by the brevity penalty ``exp(1 - r/c)`` when the hypothesis
    length ``c`` is shorter than the reference length ``r``.

    Args:
        reference: reference translation.
        hypothesis: candidate translation to score.
        max_order: highest n-gram order used (must be >= 1).
        smooth: if True, apply add-one smoothing to every order's precision so
            that an order with no matches does not force the score to 0.

    Returns:
        A score in [0, 1]. Without smoothing it is 0 whenever some order has
        no matching n-gram (including when the hypothesis has fewer than
        ``max_order`` tokens). An empty hypothesis always scores 0.

    Raises:
        ValueError: if ``max_order`` is less than 1.
    """
    if max_order < 1:
        raise ValueError(f"max_order must be >= 1, got {max_order}")

    reference_tokens = reference.split()
    hypothesis_tokens = hypothesis.split()
    if not hypothesis_tokens:
        # Nothing to score; also avoids dividing by zero in the brevity penalty.
        return 0.0

    matches_by_order = [0] * max_order
    possible_matches_by_order = [0] * max_order
    for n in range(1, max_order + 1):
        ref_ngrams = Counter(
            tuple(reference_tokens[i:i + n])
            for i in range(len(reference_tokens) - n + 1)
        )
        hyp_ngrams = Counter(
            tuple(hypothesis_tokens[i:i + n])
            for i in range(len(hypothesis_tokens) - n + 1)
        )
        # Counter '&' keeps min(hyp_count, ref_count) per n-gram: the clipped
        # counts of BLEU's modified precision.
        matches_by_order[n - 1] = sum((hyp_ngrams & ref_ngrams).values())
        possible_matches_by_order[n - 1] = sum(hyp_ngrams.values())

    pairs = list(zip(matches_by_order, possible_matches_by_order))
    if smooth:
        precisions = [(m + 1.0) / (p + 1.0) for m, p in pairs]
    else:
        precisions = [m / p if p > 0 else 0.0 for m, p in pairs]

    # log(0) is undefined; a zero precision makes the geometric mean 0.
    if min(precisions) <= 0:
        return 0.0
    geo_mean = math.exp(sum(math.log(p) for p in precisions) / max_order)

    ref_len = len(reference_tokens)
    hyp_len = len(hypothesis_tokens)
    # Brevity penalty: without it, dropping words would only raise precision.
    if hyp_len >= ref_len:
        brevity_penalty = 1.0
    else:
        brevity_penalty = math.exp(1.0 - ref_len / hyp_len)
    return geo_mean * brevity_penalty

# Demo: one substituted word ("is" -> "sat"). No 4-gram of the hypothesis
# occurs in the reference, so unsmoothed sentence-level BLEU is 0 here.
reference, hypothesis = "the cat is on the mat", "the cat sat on the mat"
bleu = compute_bleu(reference=reference, hypothesis=hypothesis)
print(f"BLEU分数: {bleu:.4f}")

总结

机器翻译是NLP的重要应用。从Seq2Seq到Transformer,注意力机制的引入显著提升了翻译质量;BLEU分数则是评估翻译效果的关键指标。