模型形状修复与改进 | AI生成和翻译

Home 2025.09

以下是你的脚本逐段执行的内容,以及你可能需要进行的几项重要修正。

模型结构

预期的张量形状(关键)

由于你没有设置 batch_first=True,PyTorch LSTM的默认输入形状是 (seq_len, batch_size, input_size)

你的代码可能出错的地方(形状)

你执行了:

inputs_one_hot = inputs_one_hot.permute(0, 2, 1)

如果 one_hot_encode_sequence 返回 (batch, seq_len, vocab),那么对于默认的LSTM,正确的置换是 permute(1, 0, 2)(以得到 (seq_len, batch, vocab))。你当前的 permute(0,2,1) 会给出 (batch, vocab, seq_len),这将导致不匹配,如果之前的形状与假设不同,要么会出错,要么会静默地学习无意义的内容。

两个简洁的选项(选择其一):

A) 保持默认的LSTM(批次第二):

# 需要 (seq_len, batch, input_size)
inputs_one_hot = inputs_one_hot.permute(1, 0, 2)  # (seq_len, 1, vocab_size)

B) 使其批次优先(通常更简单):

self.lstm = nn.LSTM(input_size=vocab_size, hidden_size=50, batch_first=True)
# 需要 (batch, seq_len, input_size)
# 然后 不要 置换;保持为 (1, seq_len, vocab_size)

训练循环逻辑

推理与打印输出

小型但高影响力的改进

  1. 修正输入形状(最重要)。使用上述A或B选项之一。如果你选择批次优先,将前向传播改为:

    x, (h, c) = self.lstm(x)         # x: (batch, seq, hidden)
    x = x.reshape(-1, self.lstm.hidden_size)  # (batch*seq, hidden)
    
  2. 在评估/推理时使用 with torch.no_grad() 以避免跟踪梯度,并提高速度/内存效率:

    net.eval()
    with torch.no_grad():
        outputs = net(inputs_one_hot)
        loss = criterion(outputs, targets_idx)
    

    以及用于numpy转换:

    outputs = net(inputs_one_hot).detach().numpy()
    
  3. 也为PyTorch设置随机种子 以确保可复现性:

    torch.manual_seed(42)
    
  4. 设备处理(如果你使用GPU):

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    inputs_one_hot = inputs_one_hot.to(device)
    targets_idx = targets_idx.to(device)
    
  5. 目标对齐:确保 targets 是与每个输入时间步对齐的下一词元。如果你的数据集以这种方式配对 (inputs, targets),那就没问题。否则,创建 targets = inputs[1:] 并丢弃最后一个输入或类似操作。
  6. 除非有理由,否则去掉 .bias=False;输出层上的偏置通常是无害/有帮助的。
  7. 填充/可变长度(可选):如果序列长度不同,考虑使用 pack_padded_sequence 以避免在填充词元上训练。
  8. 监控:在下一词元预测上的准确率比仅看损失更易解释——在验证集上计算 % correct argmax

最小修正代码片段(批次优先版本)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=vocab_size, hidden_size=50, num_layers=1, batch_first=True)
        self.l_out = nn.Linear(50, vocab_size)

    def forward(self, x):
        # x: (batch, seq, vocab_size)
        x, _ = self.lstm(x)                    # (batch, seq, hidden)
        x = x.reshape(-1, self.lstm.hidden_size)  # (batch*seq, hidden)
        x = self.l_out(x)                      # (batch*seq, vocab)
        return x

# ... 在你的循环中,移除置换;确保 inputs_one_hot 是 (batch, seq, vocab)
# 验证/推理使用 `with torch.no_grad()` 和 `.detach()`

如果你分享 one_hot_encode_sequence 返回的确切形状(例如,(seq, vocab)(1, seq, vocab)),我可以为你的情况精确调整置换/重塑行。


Back Donate