机器翻译入门
机器翻译入门
什么是机器翻译
机器翻译(Machine Translation,MT)是将文本从一种语言自动翻译成另一种语言的任务,是自然语言处理(NLP)中的经典问题。
Seq2Seq模型
编码器-解码器架构是机器翻译的基础:
import torch
import torch.nn as nn
class Encoder(nn.Module):
    """Embed source token indices and encode them with a single GRU.

    Args:
        input_dim: Number of rows in the embedding table (source vocabulary
            size).
        embed_dim: Width of each embedding vector; also the GRU input size.
        hidden_dim: Size of the GRU hidden state.
    """

    def __init__(self, input_dim, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=input_dim,
            embedding_dim=embed_dim,
        )
        # batch_first=True makes the GRU treat dimension 0 as the batch axis.
        self.rnn = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )

    def forward(self, src):
        """Encode ``src`` and return the GRU's ``(outputs, hidden)`` pair.

        Args:
            src: Integer tensor of token indices, presumably shaped
                ``(batch, src_len)`` given ``batch_first=True``.

        Returns:
            The tuple produced by the GRU: the per-position outputs and the
            final hidden state, which keeps its leading layer axis.
        """
        return self.rnn(self.embedding(src))
class Decoder(nn.Module):
    """Single-step GRU decoder that is fed an attention context vector.

    Args:
        output_dim: Target vocabulary size; sets both the embedding table
            size and the width of the output projection.
        embed_dim: Width of each target embedding vector.
        hidden_dim: GRU hidden size. The context vector concatenated to the
            embedding must have this width as well, because the GRU input
            size is ``embed_dim + hidden_dim``.
    """

    def __init__(self, output_dim, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, embed_dim)
        self.rnn = nn.GRU(embed_dim + hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input, hidden, context):
        """Run one decoding step.

        Args:
            input: Integer token indices for the current step, expected as
                ``(batch, 1)`` so that ``squeeze(1)`` below drops the step
                axis.
            hidden: Previous decoder state, either ``(batch, hidden_dim)``
                or ``(1, batch, hidden_dim)``.
            context: Context vector concatenated to the embedding along the
                last axis, i.e. ``(batch, 1, hidden_dim)``.

        Returns:
            ``(prediction, hidden, output)``: vocabulary logits of shape
            ``(batch, output_dim)``, the new GRU state of shape
            ``(1, batch, hidden_dim)``, and the raw GRU output.
        """
        embedded = self.embedding(input)
        rnn_input = torch.cat([embedded, context], dim=-1)
        # The GRU wants a leading layer axis. Only add it when it is missing,
        # so the 3-D state returned by this method (or by an encoder GRU) can
        # be passed straight back in on the next step instead of becoming 4-D.
        if hidden.dim() == 2:
            hidden = hidden.unsqueeze(0)
        output, hidden = self.rnn(rnn_input, hidden)
        prediction = self.fc(output.squeeze(1))
        return prediction, hidden, output
# Example models: 1000-entry vocabularies, 256-d embeddings, 512-d GRU state.
encoder = Encoder(
    input_dim=1000,
    embed_dim=256,
    hidden_dim=512,
)
decoder = Decoder(
    output_dim=1000,
    embed_dim=256,
    hidden_dim=512,
)
注意力机制
注意力机制帮助解码器关注源序列的关键位置:
class Attention(nn.Module):
    """Additive attention: score = V(tanh(W_encoder(h) + W_decoder(s))).

    Args:
        encoder_dim: Feature size of each encoder output vector.
        decoder_dim: Feature size of the decoder state; also the width of
            the shared space both inputs are projected into.
    """

    def __init__(self, encoder_dim, decoder_dim):
        super().__init__()
        self.W_encoder = nn.Linear(in_features=encoder_dim, out_features=decoder_dim)
        self.W_decoder = nn.Linear(in_features=decoder_dim, out_features=decoder_dim)
        self.V = nn.Linear(in_features=decoder_dim, out_features=1)

    def forward(self, encoder_outputs, decoder_hidden):
        """Weight the encoder outputs by their relevance to the decoder state.

        Args:
            encoder_outputs: ``(batch, src_len, encoder_dim)`` tensor.
            decoder_hidden: ``(batch, decoder_dim)`` tensor; a length-1 axis
                is inserted at position 1 so it broadcasts over ``src_len``.

        Returns:
            ``(context, attention_weights)`` where ``context`` is
            ``(batch, 1, encoder_dim)`` and ``attention_weights`` is
            ``(batch, src_len, 1)``, normalised over the ``src_len`` axis.
        """
        projected_keys = self.W_encoder(encoder_outputs)
        projected_query = self.W_decoder(decoder_hidden.unsqueeze(1))
        energy = torch.tanh(projected_keys + projected_query)
        score = self.V(energy)

        # Normalise across source positions (dim 1), not the trailing size-1 axis.
        attention_weights = score.softmax(dim=1)

        # (batch, 1, src_len) x (batch, src_len, encoder_dim) -> (batch, 1, encoder_dim)
        context = torch.bmm(attention_weights.permute(0, 2, 1), encoder_outputs)
        return context, attention_weights
# Smoke-test the attention module on random tensors:
# batch of 32, 50 source positions, 512 features.
attention = Attention(
    encoder_dim=512,
    decoder_dim=512,
)
encoder_outputs = torch.randn(32, 50, 512)
decoder_hidden = torch.randn(32, 512)
context, weights = attention(
    encoder_outputs=encoder_outputs,
    decoder_hidden=decoder_hidden,
)
print("上下文向量形状:", context.size())
Transformer架构
Transformer完全基于自注意力机制,其核心组件是多头注意力:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across several heads.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly
            so that every head receives ``d_model // num_heads`` features.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Without this check the floor division below drops features and the
        # reshape in forward() either fails later or silently regroups them.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: ``(batch, q_len, d_model)`` tensor.
            key: ``(batch, k_len, d_model)`` tensor.
            value: ``(batch, k_len, d_model)`` tensor.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, q_len, k_len)``; positions where it
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size = query.size(0)

        # Project, then split the feature axis into heads:
        # (batch, len, d_model) -> (batch, num_heads, len, d_k).
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scale by sqrt(d_k) to keep the logits in a range where softmax
        # still has useful gradients.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # A large negative logit drives the masked weight to zero.
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = torch.softmax(scores, dim=-1)

        output = torch.matmul(attn_weights, V)
        # Merge the heads back: (batch, num_heads, q_len, d_k) -> (batch, q_len, d_model).
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(output)
        return output
# Self-attention demo: the same random tensor serves as query, key and value.
mha = MultiHeadAttention(
    d_model=512,
    num_heads=8,
)
query = torch.randn(32, 50, 512)
output = mha(query, key=query, value=query)
print("输出形状:", output.size())
BLEU评估
BLEU是机器翻译最常用的自动评估指标,通过比较译文与参考译文的n-gram重合程度来打分:
import math
from collections import Counter
def compute_bleu(reference, hypothesis, max_order=4):
    """Compute sentence-level BLEU for one hypothesis against one reference.

    Both strings are tokenised on whitespace. For every n-gram order from 1
    to ``max_order`` the clipped precision is computed (a hypothesis n-gram
    is credited at most as many times as it occurs in the reference). The
    score is the geometric mean of those precisions multiplied by the
    brevity penalty, so a hypothesis shorter than the reference cannot reach
    a perfect score just by being a correct fragment.

    Args:
        reference: Reference translation.
        hypothesis: Candidate translation to score.
        max_order: Highest n-gram order to include; must be at least 1.

    Returns:
        A float in ``[0, 1]``. The result is ``0.0`` when any order has zero
        precision, which includes a hypothesis with fewer than ``max_order``
        tokens (no smoothing is applied).

    Raises:
        ValueError: If ``max_order`` is smaller than 1.
    """
    if max_order < 1:
        raise ValueError("max_order must be at least 1")

    reference_tokens = reference.split()
    hypothesis_tokens = hypothesis.split()

    matches_by_order = [0] * max_order
    possible_matches_by_order = [0] * max_order
    for n in range(1, max_order + 1):
        ref_ngrams = Counter(
            tuple(reference_tokens[i:i + n])
            for i in range(len(reference_tokens) - n + 1)
        )
        hyp_ngrams = Counter(
            tuple(hypothesis_tokens[i:i + n])
            for i in range(len(hypothesis_tokens) - n + 1)
        )
        for ngram, count in hyp_ngrams.items():
            # Clip so repeating a matching n-gram cannot inflate precision.
            matches_by_order[n - 1] += min(count, ref_ngrams.get(ngram, 0))
            possible_matches_by_order[n - 1] += count

    precisions = [
        matched / possible if possible > 0 else 0.0
        for matched, possible in zip(matches_by_order, possible_matches_by_order)
    ]
    # log(0) is undefined, and a zero precision zeroes the geometric mean.
    if min(precisions) <= 0:
        return 0.0

    geometric_mean = math.exp(sum(math.log(p) for p in precisions) / max_order)

    # Brevity penalty: 1 if the hypothesis is longer than the reference,
    # otherwise exp(1 - r/c). The hypothesis is non-empty here because every
    # precision is positive.
    hypothesis_length = len(hypothesis_tokens)
    reference_length = len(reference_tokens)
    if hypothesis_length > reference_length:
        brevity_penalty = 1.0
    else:
        brevity_penalty = math.exp(1 - reference_length / hypothesis_length)

    return brevity_penalty * geometric_mean
# Example pair differing in one word. No 4-gram of the hypothesis occurs in
# the reference, so the unsmoothed 4-gram score printed below is 0.0000.
reference = "the cat is on the mat"
hypothesis = "the cat sat on the mat"
bleu = compute_bleu(reference=reference, hypothesis=hypothesis)
print(f"BLEU分数: {bleu:.4f}")
总结
机器翻译是NLP的重要应用。从Seq2Seq到Transformer,注意力机制的引入显著提升了翻译质量;BLEU分数则是评估翻译效果的关键指标。