transformer语言翻译!🌞
transformer语言翻译
1.数据来源和处理
1 |
|
2.使用 Transformer 的 Seq2Seq 网络
2.1 位置编码
1 |
|
tips:
1.一旦使用 self.register_buffer('pos_embedding', pos_embedding)
将 pos_embedding
注册为模型的缓冲区,就可以使用 self.pos_embedding
来调用它。
2.为什么要注册到缓冲区?在 PyTorch 中,模型的参数通常是通过 nn.Parameter
对象进行管理的,这些参数会随着模型的训练而更新。然而,并非所有的模型参数都需要在训练过程中被更新。有些参数是在模型的初始化阶段就固定下来的,比如在 Transformer 模型中的位置编码。
对于这种不需要更新的固定参数,将它们注册为模型的缓冲区是一个比较好的做法,这样做有几个好处:
- 状态保存和加载:注册为缓冲区的参数会被包含在模型的状态字典(state_dict)中,因此在保存模型时,这些参数会自动保存。在加载模型时,这些参数也会被自动加载,而不需要额外的处理。
- GPU/CPU 转移:当模型移动到 GPU 或者 CPU 上时,注册为缓冲区的参数也会自动跟着移动,而不需要额外的处理。这样可以使得代码更具有通用性,不需要针对不同设备编写不同的逻辑。
- 代码可读性和可维护性:通过将不需要更新的参数注册为缓冲区,可以更加清晰地表达模型结构。这样可以使得代码更易于理解和维护。
综上所述,将不需要更新的固定参数注册为模型的缓冲区是一种良好的实践,能够提高代码的可读性、可维护性,并且能够自动处理状态保存、加载和设备转移等问题。
2.2 字符编码
1 |
|
tips:
1.为什么要用tokens.long()?在PyTorch中,nn.Embedding
层的输入需要是长整型(LongTensor
)类型的数据。tokens.long()
的作用就是将输入的 tokens
张量中的数据类型转换为长整型。这是因为在实际应用中,tokens 往往是表示词汇表中某个词的索引,索引一般是整数类型,因此需要将其转换为长整型,以便与 nn.Embedding
层兼容。
2.为什么要乘math.sqrt(self.emb_size)?math.sqrt(self.emb_size)
被用来对嵌入向量进行缩放操作,可能是为了控制嵌入向量的数值范围或方差。
2.3 Transformer
1 |
|
tips:
在 Seq2Seq 模型中,通常分为编码器 (Encoder) 和解码器 (Decoder) 两个部分。编码器负责将输入序列编码为一个语义空间中的表示,而解码器则根据这个表示生成输出序列。
虽然在 forward
方法中已经调用了 Transformer 模型进行编码器-解码器的处理,但是在某些情况下,我们可能需要分别对编码器和解码器进行操作,这就是为什么额外定义了 encode
和 decode
方法的原因。
- 编码器操作:
encode
方法允许我们单独对输入序列进行编码,而不需要进行解码器的操作。这在某些情况下是有用的,比如对于一些无需生成输出的任务,只需要输入序列的表示即可。 - 解码器操作:
decode
方法允许我们在给定编码后的记忆的情况下,单独对目标序列进行解码。这在一些场景下也是有用的,比如基于已有的语义表示生成一些补充信息,或者在解码器训练中进行推理。
通过将编码器和解码器的操作分别定义成方法,可以使模型更加灵活,可以根据需求进行更细粒度的操作。
2.4 单词掩码
1 |
|
假如:
1 |
|
则输出结果为:
1 |
|
2.5 实例化
1 |
|
tips:
权重初始化时,这段代码遍历了 Transformer 模型中的所有参数,如果参数的维度大于 1(即不是偏置项),则使用 Xavier 均匀分布进行初始化。Xavier 初始化旨在使得每一层的输出方差保持相等,从而避免梯度消失或爆炸问题,有助于加速模型的收敛。
3.整理
3.1 数据转换
1 |
|
输出:
1 |
|
tips:
如果padded_sequences
中没加 batch_first=True
,则填充过程中张量会发生转置。
举例:
1 |
|
3.2 训练
1 |
|
tips:
1.logits
的shape为torch.Size([23, 128, 10837]),也就是经过一个线性层之后,会输出10837的维度,也就是输出词表的长度。
2.为什么最后一层不用softmax?通常,在使用交叉熵损失函数时,softmax 操作会被包含在损失函数中。因此,在训练过程中,我们通常不需要显式地在模型中加入 softmax 操作。
举例:
1 |
|
在上述示例中,logits
是模型的输出,targets
是目标标签。在计算损失时,我们只需要将 logits 和 targets 传递给交叉熵损失函数 nn.CrossEntropyLoss()
,而不需要显式地在模型中使用 softmax 激活函数。交叉熵损失函数内部将 logits 应用 softmax 操作,并计算损失。
3.3 评估
1 |
|
4.开始
1 |
|