近期发布的LLaMA 4模型引入了混合专家(Mixture of Experts, MoE)架构,旨在提升模型效率和性能。尽管社区对LLaMA 4的实际表现存在一些讨论,但MoE作为一种重要的模型设计范式,继Mistral等模型之后再次受到关注。
所以我们将使用Pytorch逐步从零开始实现一个简化版的LLaMA 4 MoE模型。通过详细的代码实现和解释,我们将深入理解MoE架构的关键组件及其工作原理。为便于跟踪执行流程,代码中将包含必要的打印输出。

以下是在小型英文文本数据集(摘自《爱丽丝梦游仙境》)上训练一个约220万参数的LLaMA MoE模型(使用Colab T4 GPU,训练3000轮)后的生成示例:

 Input: Alice
 
 Output: Alice 'without pictures or conversation?'
 So she was considering in her own mind (as well as she could, for the
 hot day made her feel very sleepy and stupid), whether the pleasure
 of making a daisy-chain wo ...

目录

  • LLaMA 4 MoE 架构概述
  • 环境配置
  • 训练语料定义
  • 字符级分词器实现
  • 语料编码
  • 超参数定义
  • 训练数据准备
  • 批处理策略:随机采样
  • 模型组件初始化- 词元嵌入层- 旋转位置编码(RoPE)预计算- RMSNorm 层初始化- 多头注意力(MHA)层初始化- 混合专家(MoE)层初始化- 最终输出层初始化
  • 因果掩码预计算
  • 训练配置
  • 损失函数定义
  • 模型训练过程
  • 文本生成实现- 文本生成设置- 生成循环详解- 解码生成序列
  • 模型状态保存(可选)
  • 总结

LLaMA 4 MoE 架构概述

混合专家(MoE)架构的核心思想是通过一组专门化的子网络(称为“专家”)来替代传统 Transformer 模型中密集的前馈网络(FFN)层,并引入一个“路由器”来动态地为每个输入词元(token)选择性地激活一部分专家。这种设计允许模型在保持甚至提升性能的同时,显著降低推理时的计算成本,因为并非所有参数都在处理每个词元时被激活。

具体来说,一个 MoE 层通常包含两个关键组件:

  1. 专家网络(Experts): 这是一组并行的、通常结构相同(例如,标准的前馈网络或门控 MLP)但参数独立的神经网络。每个专家可以被视为在处理特定类型信息或模式方面具有专长。
  2. 路由器(Router): 这是一个小型神经网络(通常是一个简单的线性层后接 Softmax 或类似函数),负责接收输入词元的表示(例如,来自前一层的隐藏状态),并计算一个概率分布或分数,决定将该词元路由到哪些专家进行处理。常见的策略是 Top-K 路由,即为每个词元选择得分最高的 K 个专家。

为了更清晰地理解信息在 MoE 架构中的流转,我们以处理句子 "The cat sat" 中的词元 "cat" 为例:

  1. 分词与嵌入: 句子被分词器处理为词元序列 ["The", "cat", "sat"]。词元 "cat" 被映射为其对应的嵌入向量(token embedding)。
  2. 路由决策: 当 "cat" 的嵌入向量(或经过某些变换后的表示)到达 MoE 层时,路由器接收该向量。
  3. 专家选择与加权: 假设该 MoE 层有 4 个专家(E1, E2, E3, E4),并且采用 Top-2 路由。路由器计算每个专家的得分,并选择得分最高的两个专家,例如 E2 和 E4。同时,路由器通常会输出这些选定专家的权重(例如,E2 的权重为 0.7,E4 的权重为 0.3),这些权重通常通过对得分应用 Softmax 或类似归一化函数得到。
  4. 专家计算: "cat" 的向量表示仅被发送给选定的专家 E2 和 E4 进行计算。未被选中的专家 E1 和 E3 在此步骤中保持非激活状态,从而节省了计算量。E2 和 E4 分别处理输入并生成各自的输出向量 Output_E2Output_E4
  5. 输出合并: 将选定专家的输出根据路由器分配的权重进行加权求和:Final_Output = (0.7 * Output_E2) + (0.3 * Output_E4)。这个 Final_Output 即为 MoE 层针对词元 "cat" 的最终输出。

序列中的每个词元都会独立地经历这个路由和计算过程,不同的词元可能被路由到不同的专家组合。

在 LLaMA 4 的整体架构中,MoE 层通常嵌入在标准的 Transformer 块内部,取代了原有的密集 FFN 层。其大致流程如下:

输入文本首先通过

分词器

转换为数值型的

Token ID

序列。这些 ID 经过

嵌入层

映射为稠密的

嵌入向量

。位置信息通过

旋转位置编码 (RoPE)

在后续的自注意力计算中被融入。

嵌入向量随后流经堆叠的

Transformer 块

。每个块内部通常包含:

  • 一个 多头自注意力 (Multi-Head Self-Attention) 机制,允许词元之间交互信息,并通过 RoPE 感知相对位置。
  • 一个 MoE 层(或标准的 FFN 层,取决于具体模型设计),负责对注意力层的输出进行进一步的非线性变换。
  • 层归一化 (Layer Normalization),如 RMSNorm,用于稳定训练。
  • 残差连接 (Residual Connections),将层输入直接加到层输出上,以促进梯度流动和信息传递。

最后一个 Transformer 块的输出经过最终的归一化层后,送入一个

线性输出层

(通常称为语言模型头),该层将最终的隐藏状态映射到词汇表大小的维度,生成

Logits

。这些 Logits 经过 Softmax 函数转换为概率分布,用于预测序列中的下一个词元。

了解了 MoE 在 LLaMA 4 架构中的基本原理和位置后,接下来我们将开始逐步实现这些组件。

环境配置

首先,导入必要的 Python 库并配置计算设备(优先使用 GPU)。

 
# 导入必要的库
importtorch
importtorch.nnasnn
fromtorch.nnimportfunctionalasF
importtorch.optimasoptim
importmath
importos
importcollections# 用于潜在的 BPE 扩展
importre          # 用于初步文本分割

# --- 设备配置 ---
# 设置计算设备:优先使用 CUDA (GPU),否则使用 CPU。
# 确保张量运算在可用硬件上高效执行。
device='cuda'iftorch.cuda.is_available() else'cpu'

print(f"PyTorch version: {torch.__version__}")
print(f"Using device: {device}")
print("Libraries imported and device configured.")

### 输出示例 ###
PyTorchversion: 2.6.0+cu124
Usingdevice: cuda
 Librariesimportedanddeviceconfigured.

输出确认库已成功导入,并显示了所使用的计算设备。本示例将在 Colab T4 GPU 上进行训练。若使用计算资源有限的 GPU,建议适当减少训练轮数。

训练语料定义

语言模型需要文本数据进行训练。实际的 LLaMA 4 模型使用了数万亿级别的词元进行训练。为了便于演示和理解,本示例将使用 Lewis Carroll 的《爱丽丝梦游仙境》中的一小段文本作为训练语料。小规模数据有助于清晰地追踪模型处理的每一步。

 
# 定义原始训练语料
corpus_raw="""
Alice was beginning to get very tired of sitting by her sister on the
bank, and of having nothing to do: once or twice she had peeped into the
book her sister was reading, but it had no pictures or conversations in
it, 'and what is the use of a book,' thought Alice 'without pictures or
conversation?'
So she was considering in her own mind (as well as she could, for the
hot day made her feel very sleepy and stupid), whether the pleasure
of making a daisy-chain would be worth the trouble of getting up and
picking the daisies, when suddenly a White Rabbit with pink eyes ran
close by her.
"""

 print(f"Training corpus defined (length: {len(corpus_raw)} characters).")
 ### 输出 ###
 Training corpus defined (length: 593 characters).

此步骤定义了

corpus_raw

变量,存储了示例文本,并打印了其字符总数。

字符级分词器实现

计算机模型处理的是数字而非原始文本。分词(Tokenization)是将文本转换为模型能够理解的数值表示(词元 ID)的过程。本示例采用最基础的字符级分词方法:

  1. 识别 corpus_raw 中所有唯一的字符。
  2. 为每个唯一字符分配一个唯一的整数 ID。
  3. 创建两个映射字典:char_to_int(字符到 ID,用于编码)和 int_to_char(ID 到字符,用于解码)。
  4. 唯一字符的总数即为词汇表大小(vocab_size)。
 
# 提取语料中的所有唯一字符并排序
chars=sorted(list(set(corpus_raw)))
vocab_size=len(chars)

# 创建字符到整数的映射(编码器)
char_to_int= { ch:ifori,chinenumerate(chars) }

# 创建整数到字符的映射(解码器)
int_to_char= { i:chfori,chinenumerate(chars) }

print(f"Created character vocabulary of size: {vocab_size}")
print(f"Vocabulary: {''.join(chars)}")
# 可选:打印部分映射示例
# print(f"Char-to-Int mapping sample: {{k: char_to_int[k] for k in list(char_to_int)[:5]}}")
 # print(f"Int-to-Char mapping sample: {{k: int_to_char[k] for k in list(int_to_char)[:5]}}")
 ### 输出 ###
 Created character vocabulary of size: 36
 Vocabulary:
  '(),-.:?ARSWabcdefghiklmnoprstuvwy

代码识别出 36 个唯一字符(包括换行符

\n

、空格、标点符号以及大小写字母)。

vocab_size

对后续模型层的定义至关重要。同时创建了用于编码和解码的

char_to_int

int_to_char

字典。

语料编码

利用上一步创建的

char_to_int

映射,将整个

corpus_raw

字符串转换为一个整数 ID 序列。这种数值表示是模型实际训练所使用的形式。为提高计算效率,将该序列存储为 PyTorch 张量。

 
# 将整个语料编码为整数 ID 列表
encoded_corpus= [char_to_int[ch] forchincorpus_raw]

# 转换为 PyTorch 张量,并指定数据类型和设备
full_data_sequence=torch.tensor(encoded_corpus, dtype=torch.long, device=device)

print(f"Encoded corpus into a tensor of shape: {full_data_sequence.shape}")
# 可选:显示前 50 个编码 ID
# print(f"First 50 encoded token IDs: {full_data_sequence[:50].tolist()}")

### 输出 ###
 Encodedcorpusintoatensorofshape: torch.Size([593])

包含 593 个字符的文本被成功转换为一个长度为 593 的 PyTorch 张量。张量中的每个数值代表原始文本中的一个字符。该张量已被放置到先前配置的计算设备上(例如 'cuda')。

超参数定义

在模型构建和训练之前,需要定义一系列超参数。这些参数决定了模型的架构(如大小、层数)和学习过程的配置。对于 LLaMA 4 风格的 MoE 模型,关键超参数包括:

  • d_model: 模型的核心维度,即词元嵌入和内部隐藏状态的大小。
  • n_layers: Transformer 块的堆叠数量。通常层数越多,模型能力越强,但计算成本也越高。
  • n_heads: 多头注意力机制中的注意力头数量。d_model 必须能被 n_heads 整除。
  • block_size: 模型在训练和推理时能够处理的最大输入序列长度(上下文窗口大小)。
  • rms_norm_eps: RMSNorm 计算中使用的一个小常数(epsilon),用于维持数值稳定性。
  • rope_theta: 控制旋转位置编码(RoPE)频率范围的参数。

MoE 相关参数:

  • num_local_experts: 每个 MoE 层包含的专家(Expert MLP)数量。
  • num_experts_per_tok: 路由器为每个输入词元选择激活的专家数量(Top-K 路由)。
  • intermediate_size_expert/shared: 专家 MLP 和共享 MLP 内部隐藏层的大小。

本示例使用的参数值远小于实际的 LLaMA 4,目的是为了能够在标准硬件上快速演示和运行。

 
# --- 模型架构超参数 ---
# vocab_size 已由数据确定
d_model=128         # 嵌入维度 (显著缩小)
n_layers=4          # Transformer 块数量 (缩小)
n_heads=4           # 注意力头数
block_size=64       # 最大上下文长度 (序列长度)
rms_norm_eps=1e-5   # RMSNorm 稳定性 epsilon
rope_theta=10000.0  # RoPE theta 参数 (小于 Llama 4 的 500k)

# --- MoE 专用超参数 ---
num_local_experts=4      # 每层 MoE 的专家数量 (小于 16)
num_experts_per_tok=2   # 每个词元路由到的专家数量 (Top-K, Llama 4 可能不同)
intermediate_size_expert=d_model*2  # 每个专家 MLP 的隐藏维度 (缩小)
intermediate_size_shared=d_model*2  # 共享 MLP 的隐藏维度 (缩小)

# --- 注意力超参数 ---
# d_k (每个头的维度) 由 d_model 和 n_heads 推导得出

# --- 训练超参数 ---
learning_rate=5e-4  # 学习率
batch_size=16       # 每批处理的序列数量
epochs=3000         # 训练迭代次数 (可调整)
eval_interval=300  # 评估/打印损失的频率

# --- 派生超参数 ---
assertd_model%n_heads==0, "d_model 必须能被 n_heads 整除"
d_k=d_model//n_heads# 每个注意力头的 Key/Query/Value 维度
expert_dim=intermediate_size_expert# 别名,便于代码理解
shared_expert_dim=intermediate_size_shared# 别名,便于代码理解

# 打印已定义的超参数概览
print("--- Defined Hyperparameters ---")
print(f"Vocabulary Size (vocab_size): {vocab_size}")
print(f"Embedding Dimension (d_model): {d_model}")
print(f"Number of Layers (n_layers): {n_layers}")
print(f"Number of Attention Heads (n_heads): {n_heads}")
print(f"Dimension per Head (d_k): {d_k}")
print(f"Max Sequence Length (block_size): {block_size}")
print(f"RMSNorm Epsilon (rms_norm_eps): {rms_norm_eps}")
print(f"RoPE Theta (rope_theta): {rope_theta}")
print("\n--- MoE Specific ---")
print(f"Number of Experts per Layer (num_local_experts): {num_local_experts}")
print(f"Experts per Token (num_experts_per_tok): {num_experts_per_tok}")
print(f"Expert Hidden Dimension (expert_dim): {expert_dim}")
print(f"Shared MLP Hidden Dimension (shared_expert_dim): {shared_expert_dim}")
print("\n--- Training Specific ---")
print(f"Learning Rate: {learning_rate}")
print(f"Batch Size: {batch_size}")
 print(f"Epochs: {epochs}")
 ### 输出 ###
--- Defined Hyperparameters ---
Vocabulary Size (vocab_size): 36
Embedding Dimension (d_model): 128
Number of Layers (n_layers): 4
Number of Attention Heads (n_heads): 4
Dimension per Head (d_k): 32
Max Sequence Length (block_size): 64
RMSNorm Epsilon (rms_norm_eps): 1e-05
RoPE Theta (rope_theta): 10000.0

--- MoE Specific ---
Number of Experts per Layer (num_local_experts): 4
Experts per Token (num_experts_per_tok): 2
Expert Hidden Dimension (expert_dim): 256
Shared MLP Hidden Dimension (shared_expert_dim): 256

--- Training Specific ---
Learning Rate: 0.0005
Batch Size: 16
 Epochs: 3000

输出清晰地列出了模型架构和训练过程的所有配置参数,包括模型维度 (

d_model=128

)、MoE 配置(4 个专家,每次选择 2 个)、上下文窗口 (

block_size=64

) 以及训练参数(学习率、批大小、训练轮数)。

训练数据准备

语言模型通常通过预测序列中的下一个词元来进行学习。给定一段前面的词元,模型需要预测紧随其后的词元。为了准备训练数据,我们在

full_data_sequence

上滑动一个长度为

block_size

的窗口:

  1. 输入序列 (x) 是长度为 block_size 的词元片段。
  2. 目标序列 (y) 是与 x 对应的、向右平移一位的相同长度片段。
  3. 因此,对于输入 x 中的第 i 个词元,模型的目标是预测目标 y 中第 i 个位置的词元(即 x 中第 i+1 个词元)。

我们从语料库中提取所有可能的、长度为

block_size

的重叠输入-目标对。

 
# 创建输入 (x) 和目标 (y) 序列列表
all_x= []
all_y= []

# 遍历编码后的语料张量,提取重叠序列
num_total_tokens=len(full_data_sequence)
foriinrange(num_total_tokens-block_size):
    # 提取输入序列片段
    x_chunk=full_data_sequence[i : i+block_size]
    # 提取目标序列片段 (向右平移一位)
    y_chunk=full_data_sequence[i+1 : i+block_size+1]
    all_x.append(x_chunk)
    all_y.append(y_chunk)

# 将列表堆叠成大的张量
train_x=torch.stack(all_x)
train_y=torch.stack(all_y)

num_sequences_available=train_x.shape[0]
print(f"Created {num_sequences_available} overlapping input/target sequence pairs.")
print(f"Shape of train_x: {train_x.shape}") # 预期形状: (num_sequences, block_size)
print(f"Shape of train_y: {train_y.shape}") # 预期形状: (num_sequences, block_size)

# 可选:验证张量所在的设备
 # print(f"train_x is on device: {train_x.device}") # 可能仍在 CPU,训练时会移动到 GPU
 ### 输出 ###
 Created 529 overlapping input/target sequence pairs.
 Shape of train_x: torch.Size([529, 64])
 Shape of train_y: torch.Size([529, 64])

从 593 个字符的语料中,成功提取了 529 个长度为

64

(

block_size

) 的重叠序列对。输出确认

train_x

(输入) 和

train_y

(目标) 均为包含 529 个序列、每个序列 64 个词元 ID 的张量。这些张量目前可能位于 CPU 上,在训练循环中,每个批次的数据将被传输到目标计算设备(如 GPU)。

批处理策略:随机采样

一次性加载全部数据进行训练通常会导致内存不足。因此,采用小批量(mini-batch)训练策略。本示例使用常见的随机采样方法:在每次训练迭代中,从所有可用的序列对索引(

0

num_sequences_available - 1

)中随机选择

batch_size

个索引,然后提取对应的输入 (

xb

) 和目标 (

yb

) 批次。这些选取的批次数据随后会被传输到指定的计算设备(

device

)上,供模型进行处理。

 
# 检查可用序列数量是否小于批大小
ifnum_sequences_available<batch_size:
    print(f"Warning: Number of sequences ({num_sequences_available}) is less than batch size ({batch_size}). Adjusting batch size.")
    batch_size=num_sequences_available# 如果序列不足,调整批大小

print(f"Data ready for training. Will sample batches of size {batch_size} randomly.")
print("Batches will be moved to the configured device during the training loop.")
# 训练循环中选择批次的示例逻辑:
# indices = torch.randint(0, num_sequences_available, (batch_size,))
# xb = train_x[indices].to(device)
 # yb = train_y[indices].to(device)
 ### 输出 ###
 Data ready for training. Will sample batches of size 16 randomly.
 Batches will be moved to the configured device during the training loop.

输出确认数据已准备就绪,可用序列数量(529)足以支持设定的批大小(16)。训练过程中将随机抽取包含 16 对输入/目标序列的批次,并将其发送到计算设备。

模型组件初始化

词元嵌入层 (Token Embedding Layer)

这是模型的第一层,负责将输入的整数词元 ID(如

train_x

中的值)转换为维度为

d_model

的稠密向量表示(嵌入)。可以将其视为一个查找表,其中每个词元 ID 对应一个唯一的向量。这些嵌入向量捕捉了词元的初始语义信息,模型将在训练过程中学习和优化这些向量。

输入形状:

(Batch, SequenceLength)

-> 输出形状:

(Batch, SequenceLength, d_model)

 
# 初始化词元嵌入表
token_embedding_table=nn.Embedding(vocab_size, d_model).to(device)

print(f"Initialized Token Embedding Layer:")
print(f"  Input Vocab Size: {vocab_size}")
print(f"  Output Embedding Dim (d_model): {d_model}")
print(f"  Weight shape: {token_embedding_table.weight.shape}") # 形状应为 (vocab_size, d_model)
 print(f"  Device: {token_embedding_table.weight.device}")
 ### 输出 ###
 Initialized Token Embedding Layer:
   Input Vocab Size: 36
   Output Embedding Dim (d_model): 128
   Weight shape: torch.Size([36, 128])
   Device: cuda:0

代码创建了一个

nn.Embedding

层。输出显示其配置正确:输入词汇表大小为 36,输出嵌入向量维度为

d_model

(128)。权重张量的形状

[36, 128]

确认了查找表的大小。该层已被放置在配置的 GPU 设备上 (

cuda:0

)。

旋转位置编码(RoPE)预计算

标准 Transformer 架构本身不直接处理序列中词元的顺序信息。位置编码(Positional Encoding)的目的是向模型注入这种顺序信息。

旋转位置编码(Rotary Positional Embedding, RoPE)是 LLaMA 等模型采用的一种高效的位置编码方法。它不直接将位置向量加到词元嵌入上,而是在注意力机制计算过程中,根据词元的位置对其 Query (Q) 和 Key (K) 向量的部分维度进行“旋转”。

旋转的角度取决于词元的位置和一组预先计算的频率。这些频率由

rope_theta

超参数推导得出。在此步骤中,我们预先计算这些频率的倒数 (

inv_freq

),这是一个常量。实际的旋转操作(通过复数表示

freqs_cis

)将在模型的前向传播过程中根据当前输入序列的长度动态计算。

 
# 预计算 RoPE 的逆频率 (inverse frequencies)
# 公式: 1.0 / (rope_theta ** (torch.arange(0, d_k, 2) / d_k))
# RoPE 应用于每个注意力头的 d_k 维度的一半
rope_dimension_indices=torch.arange(0, d_k, 2, dtype=torch.float, device=device) # 取偶数索引
inv_freq=1.0/ (rope_theta** (rope_dimension_indices/d_k))

print("Precomputed RoPE inverse frequencies (inv_freq):")
print(f"  Shape: {inv_freq.shape}") # 预期形状: (d_k / 2,)
print(f"  Values (first 5): {inv_freq[:5].tolist()}")
print(f"  Device: {inv_freq.device}")
 # 'freqs_cis' (复数形式的频率) 将在前向传播时使用这些 inv_freq 和位置 ID 计算
 ### 输出 ###
 Precomputed RoPE inverse frequencies (inv_freq):
   Shape: torch.Size([16])
   Values (first 5): [1.0, 0.5623413324356079, 0.3162277638912201, 0.17782793939113617, 0.10000000149011612]
   Device: cuda:0

此步骤计算并存储了

inv_freq

张量。由于每个注意力头的维度

d_k

为 32,且 RoPE 按维度对进行操作,因此

inv_freq

的形状为

(16,)

(即

d_k / 2

)。这些值代表了旋转的基础频率。在后续的前向传播中,将结合词元的位置信息使用这个

inv_freq

张量来计算实际的旋转角度(表示为

freqs_cis

)。

RMSNorm 层初始化

归一化层有助于稳定训练过程。LLaMA 架构采用了 RMSNorm(Root Mean Square Normalization),这是一种比标准 Layer Normalization 更简洁、计算效率更高的归一化方法。

RMSNorm 通过输入向量元素的均方根值对其进行归一化,然后使用一个可学习的缩放参数

gamma

(权重)进行调整。与 LayerNorm 不同,RMSNorm 通常不包含可学习的偏置参数

beta

在 Transformer 块中,通常在自注意力层之前和 MoE/FFN 层之前各应用一次 RMSNorm。此外,在最终输出层之前也需要进行一次归一化。

这里我们仅初始化可学习的

gamma

权重(作为

nn.Parameter

)。实际的 RMSNorm 计算逻辑将在模型的前向传播中实现。

 
# 用于存储每个 Transformer 块的 RMSNorm 权重
rmsnorm_weights_input= []      # 注意力层输入的 RMSNorm 权重
rmsnorm_weights_post_attn= []  # MoE/FFN 层输入的 RMSNorm 权重

print(f"Initializing RMSNorm weights for {n_layers} layers...")
foriinrange(n_layers):
    # 注意力层输入的 RMSNorm 权重 (gamma)
    # 初始化为 1,类似于 nn.LayerNorm 的默认 gamma 值
    weight_in=nn.Parameter(torch.ones(d_model, device=device))
    rmsnorm_weights_input.append(weight_in)

    # MoE/FFN 层输入的 RMSNorm 权重 (gamma)
    weight_post=nn.Parameter(torch.ones(d_model, device=device))
    rmsnorm_weights_post_attn.append(weight_post)
    print(f"  Initialized RMSNorm weights for Layer {i+1} (Input: {weight_in.shape}, PostAttn: {weight_post.shape})")

# 模型最终输出层之前的 RMSNorm 权重
final_rmsnorm_weight=nn.Parameter(torch.ones(d_model, device=device))

print(f"Initialized Final RMSNorm weight, shape: {final_rmsnorm_weight.shape}")
 print("RMSNorm weights initialized (as nn.Parameter). The normalization logic will be implemented inline during forward pass.")
 ### 输出 ###
 Initializing RMSNorm weights for 4 layers...
   Initialized RMSNorm weights for Layer 1 (Input: torch.Size([128]), PostAttn: torch.Size([128]))
   Initialized RMSNorm weights for Layer 2 (Input: torch.Size([128]), PostAttn: torch.Size([128]))
   Initialized RMSNorm weights for Layer 3 (Input: torch.Size([128]), PostAttn: torch.Size([128]))
   Initialized RMSNorm weights for Layer 4 (Input: torch.Size([128]), PostAttn: torch.Size([128]))
 Initialized Final RMSNorm weight, shape: torch.Size([128])
 RMSNorm weights initialized (as nn.Parameter). The normalization logic will be implemented inline during forward pass.

代码为所有需要的 RMSNorm 操作创建了可学习的

gamma

权重参数。对于

n_layers

(4) 个 Transformer 块中的每一层,分别初始化了注意力层之前 (

rmsnorm_weights_input

) 和 MoE 层之前 (

rmsnorm_weights_post_attn

) 使用的权重。同时,也初始化了最终输出层之前的

final_rmsnorm_weight

。每个权重都是一个维度为

d_model

(128) 的

nn.Parameter

张量,初始值设为 1。实际的 RMSNorm 计算将在前向传播过程中结合这些权重进行。

多头注意力(MHA)层初始化

自注意力机制是 Transformer 架构的核心。本实现采用多头注意力(Multi-Head Attention, MHA)。在每个 Transformer 块中,需要线性投影层将输入的隐藏状态向量转换到 Query (Q)、Key (K) 和 Value (V) 空间。

  1. QKV 投影: 通常实现为一个大的线性层,将维度为 d_model 的输入投影到一个维度为 3 * d_model 的合并空间,然后分割成 Q, K, V。
  2. 输出投影: 在注意力计算完成后,使用另一个线性层将 MHA 的输出结果(通常拼接了所有头的输出)映射回原始的 d_model 维度。

我们为

n_layers

个 Transformer 块分别初始化这些

nn.Linear

层。在大型模型中,这些投影层通常不使用偏置项 (

bias=False

)。

 
# 用于存储每个 Transformer 块的注意力层
mha_qkv_linears= []    # Q, K, V 合并投影层列表
mha_output_linears= [] # MHA 输出投影层列表

print(f"Initializing Attention (MHA) linear layers for {n_layers} layers...")
foriinrange(n_layers):
    # QKV 合并线性投影层
    # 大型模型通常不使用偏置
    qkv_linear=nn.Linear(d_model, 3*d_model, bias=False).to(device)
    mha_qkv_linears.append(qkv_linear)

    # 输出线性投影层
    # 这里也可以不使用偏置
    output_linear=nn.Linear(d_model, d_model, bias=False).to(device)
    mha_output_linears.append(output_linear)
    print(f"  Initialized MHA Linears for Layer {i+1} (QKV weight: {qkv_linear.weight.shape}, Out weight: {output_linear.weight.shape})")

 print("Attention (MHA) linear layers initialized.")
 ### 输出 ###
 Initializing Attention (MHA) linear layers for 4 layers...
   Initialized MHA Linears for Layer 1 (QKV weight: torch.Size([384, 128]), Out weight: torch.Size([128, 128]))
   Initialized MHA Linears for Layer 2 (QKV weight: torch.Size([384, 128]), Out weight: torch.Size([128, 128]))
   Initialized MHA Linears for Layer 3 (QKV weight: torch.Size([384, 128]), Out weight: torch.Size([128, 128]))
   Initialized MHA Linears for Layer 4 (QKV weight: torch.Size([384, 128]), Out weight: torch.Size([128, 128]))
 Attention (MHA) linear layers initialized.

此步骤为 4 个 Transformer 块中的每一层都设置了注意力机制所需的线性层。每层包含:

  • qkv_linear: 将 d_model (128) 维输入映射到 3 * d_model (384) 维输出,其权重形状为 [384, 128]
  • output_linear: 将 MHA 的 d_model (128) 维输出映射回 d_model (128) 维,其权重形状为 [128, 128]

这些层被存储在列表中,以便在前向传播过程中按层调用。

混合专家(MoE)层初始化

这是 MoE 架构的核心部分。在标准的 Transformer 块中,注意力层之后通常是一个前馈网络(FFN)。在 MoE 架构中,这个 FFN 被替换为一个 MoE 层。每个 MoE 层包含以下组件:

  • 路由器 (Router): 一个简单的线性层,接收经过归一化的词元隐藏状态(维度 d_model),并为每个可用专家生成一个分数(logit),用于决定将该词元路由到哪些专家。
  • 专家 (Experts): 一组 (num_local_experts 个) 独立的、通常较小的 MLP。每个专家本身常采用门控 MLP(Gated MLP)结构,类似于 LLaMA 标准 FFN 的设计:包含并行的“门控”(Gate)和“上投影”(Up-projection)线性层,应用激活函数(如 SiLU/Swish),将门控输出与上投影输出逐元素相乘,最后通过一个“下投影”(Down-projection)线性层将维度映射回 d_model。在本实现中,为了效率,专家的权重直接存储为 nn.Parameter 张量,而不是 nn.Module 列表。
  • 共享专家 (Shared Expert): 一个标准的门控 MLP,其结构与单个专家类似。与 MoE 专家不同,所有词元都会通过这个共享专家进行处理。其输出会与选定 MoE 专家的加权输出相加。

路由器的输出决定了每个词元将被发送到哪

num_experts_per_tok

个专家(Top-K 路由)。这些选定专家的输出根据路由器给出的权重进行加权组合,然后与共享专家的输出合并。

 
# 用于存储每层 MoE 的组件
moe_routers= []             # Router 线性层列表
moe_expert_gate_up_proj= [] # 专家 Gate/Up 投影权重列表 (Parameter)
moe_expert_down_proj= []    # 专家 Down 投影权重列表 (Parameter)
shared_expert_gate_proj= [] # 共享专家 Gate 投影层列表 (nn.Linear)
shared_expert_up_proj= []   # 共享专家 Up 投影层列表 (nn.Linear)
shared_expert_down_proj= [] # 共享专家 Down 投影层列表 (nn.Linear)

print(f"Initializing MoE and Shared MLP components for {n_layers} layers...")
print(f"  Num Experts per layer: {num_local_experts}")
print(f"  Expert Intermediate Dim: {expert_dim}")
print(f"  Shared MLP Intermediate Dim: {shared_expert_dim}")

foriinrange(n_layers):
    # 1. Router 初始化
    router_linear=nn.Linear(d_model, num_local_experts, bias=False).to(device)
    moe_routers.append(router_linear)

    # 2. 专家权重初始化 (使用 nn.Parameter)
    # Gate 和 Up 投影权重合并存储: (num_experts, d_model, 2 * expert_dim)
    # 维度解释: [专家数量, 输入维度, 输出维度(Gate) + 输出维度(Up)]
    gate_up_w=nn.Parameter(torch.empty(num_local_experts, d_model, 2*expert_dim, device=device))
    nn.init.normal_(gate_up_w, mean=0.0, std=0.02) # 使用正态分布初始化
    moe_expert_gate_up_proj.append(gate_up_w)

    # Down 投影权重: (num_experts, expert_dim, d_model)
    # 维度解释: [专家数量, 输入维度(来自Gate*Up), 输出维度]
    down_w=nn.Parameter(torch.empty(num_local_experts, expert_dim, d_model, device=device))
    nn.init.normal_(down_w, mean=0.0, std=0.02) # 使用正态分布初始化
    moe_expert_down_proj.append(down_w)

    # 3. 共享专家 MLP 初始化 (使用 nn.Linear)
    shared_gate=nn.Linear(d_model, shared_expert_dim, bias=False).to(device)
    shared_up=nn.Linear(d_model, shared_expert_dim, bias=False).to(device)
    shared_down=nn.Linear(shared_expert_dim, d_model, bias=False).to(device)
    shared_expert_gate_proj.append(shared_gate)
    shared_expert_up_proj.append(shared_up)
    shared_expert_down_proj.append(shared_down)

    print(f"  Initialized MoE components for Layer {i+1}:")
    print(f"    Router weights shape: {router_linear.weight.shape}")
    print(f"    Expert Gate/Up weights shape: {gate_up_w.shape}")
    print(f"    Expert Down weights shape: {down_w.shape}")
    print(f"    Shared Gate weights shape: {shared_gate.weight.shape}")
    print(f"    Shared Up weights shape: {shared_up.weight.shape}")
    print(f"    Shared Down weights shape: {shared_down.weight.shape}")

print("MoE and Shared MLP components initialized.")
# 定义激活函数 (将在前向传播中内联使用)
 activation_fn=nn.SiLU()
 ### 输出 ###
Initializing MoE and Shared MLP components for 4 layers...
  Num Experts per layer: 4
  Expert Intermediate Dim: 256
  Shared MLP Intermediate Dim: 256
  Initialized MoE components for Layer 1:
    Router weights shape: torch.Size([4, 128])
    Expert Gate/Up weights shape: torch.Size([4, 128, 512])
    Expert Down weights shape: torch.Size([4, 256, 128])
    Shared Gate weights shape: torch.Size([256, 128])
    Shared Up weights shape: torch.Size([256, 128])
    Shared Down weights shape: torch.Size([128, 256])
  ... (Layer 2, 3, 4 的输出类似) ...
 MoE and Shared MLP components initialized.

输出显示了 4 个 MoE 层中每一层组件的初始化情况:

  • Router weights: 线性层权重,将 d_model (128) 映射到专家数量 (4),形状为 [4, 128]
  • Expert Gate/Up weights: 单个 nn.Parameter 张量,存储了所有 4 个专家的 Gate 和 Up 投影权重,形状为 [4, 128, 512]512 来自 2 * expert_dim (2 * 256)。
  • Expert Down weights: 单个 nn.Parameter 张量,存储了所有 4 个专家的 Down 投影权重,形状为 [4, 256, 128]
  • Shared Gate/Up/Down weights: 共享专家 MLP 的标准线性层权重,其形状根据 d_model (128) 和 shared_expert_dim (256) 确定。

这些组件被存储在相应的列表中,以便在前向传播过程中实现复杂的 MoE 路由和计算逻辑。激活函数选用

SiLU

最终输出层初始化

在经过所有 Transformer 层的处理后,模型得到最终的隐藏状态序列。在进行最后一次 RMSNorm 归一化之后,需要将这些隐藏状态转换为对下一个词元的预测。

最终输出层是一个线性层,它将序列中每个位置的

d_model

维向量映射到一个

vocab_size

维的向量。这个输出向量中的每个元素代表了词汇表中对应词元作为下一个词元的原始分数(logit)。

 
# 最终线性层 (也称为语言模型头, Language Modeling Head)
output_linear_layer=nn.Linear(d_model, vocab_size, bias=False).to(device)

print(f"Initialized Final Output Linear Layer:")
print(f"  Input Dim (d_model): {d_model}")
print(f"  Output Dim (vocab_size): {vocab_size}")
print(f"  Weight shape: {output_linear_layer.weight.shape}") # 预期形状: (vocab_size, d_model)
 print(f"  Device: {output_linear_layer.weight.device}")
 ### 输出 ###
 Initialized Final Output Linear Layer:
   Input Dim (d_model): 128
   Output Dim (vocab_size): 36
   Weight shape: torch.Size([36, 128])
   Device: cuda:0

代码初始化了最终的

nn.Linear

层。其输入维度为

d_model

(128),输出维度为

vocab_size

(36),权重形状为

[36, 128]

因果掩码预计算

在像 LLaMA 这样的自回归(decoder-only)Transformer 模型中,预测位置

t

的词元时,模型只能关注(attend to)从位置

0

t

(包含自身)的词元,而不能“看到”未来的词元(位置

t+1

,

t+2

, ...)。

这种限制通过因果掩码(Causal Mask)在自注意力计算中实现。我们创建一个下三角矩阵(尺寸为

block_size x block_size

),其中允许关注的位置(包括当前位置和之前的位置)对应的值为 1(或 0,取决于实现方式),不允许关注的未来位置对应的值为 0(或负无穷大)。

这个掩码在计算注意力权重(softmax 之前)时应用,将未来位置的注意力得分设置为一个极小的值(如负无穷大),从而使其在 softmax 后的权重接近于零。我们为最大序列长度

block_size

预先计算并存储这个掩码。

 
# 创建用于因果自注意力的下三角掩码
# 值为 1 的位置表示允许关注,值为 0 的位置表示禁止关注 (未来位置)
# 形状调整为 (1, 1, block_size, block_size) 以便与注意力分数张量 (B, n_heads, T, T) 进行广播
causal_mask=torch.tril(torch.ones(block_size, block_size, dtype=torch.bool, device=device))
causal_mask=causal_mask.view(1, 1, block_size, block_size)

print("Precomputed Causal Attention Mask:")
print(f"  Shape: {causal_mask.shape}")
print(f"  Requires grad: {causal_mask.requires_grad}") # 掩码是常量,不需要梯度
# 可选:可视化小 block_size 的掩码内容
# if block_size <= 8:
 #    print(causal_mask[0, 0].cpu().numpy().astype(int))
 ### 输出 ###
 Precomputed Causal Attention Mask:
   Shape: torch.Size([1, 1, 64, 64])
   Requires grad: False

代码创建了

causal_mask

张量。这是一个布尔类型的下三角矩阵(包括对角线),值为

True

的位置表示允许注意力计算,

False

表示禁止。其形状被调整为

[1, 1, 64, 64]

,以便能与注意力得分张量(形状通常为

[Batch, n_heads, SeqLen, SeqLen]

)进行广播操作。由于掩码是固定的,它不需要计算梯度。

训练配置

优化器 (Optimizer) 是根据反向传播计算出的梯度来更新模型参数(权重和偏置)的算法。本示例选用

AdamW

优化器,它是 Adam 优化器的一个变种,常用于 Transformer 模型的训练,并通常表现良好。

在定义优化器之前,需要收集模型中所有需要训练(即

requires_grad=True

)的参数。这包括:

  • 词元嵌入表 (token_embedding_table)
  • 所有 nn.Linear 层的权重(可能还有偏置,如果 bias=True):QKV 投影、输出投影、MoE 路由器、共享专家 MLP 层、最终输出层。
  • 所有 nn.Parameter 定义的权重:RMSNorm 的 gamma 权重、MoE 专家的 Gate/Up 和 Down 投影权重。
 
# 收集模型中所有需要计算梯度的参数
all_model_parameters=list(token_embedding_table.parameters())

# 添加 RMSNorm 权重 (nn.Parameter)
all_model_parameters.extend(rmsnorm_weights_input)
all_model_parameters.extend(rmsnorm_weights_post_attn)
all_model_parameters.append(final_rmsnorm_weight)

# 添加注意力线性层参数 (nn.Linear)
foriinrange(n_layers):
    all_model_parameters.extend(list(mha_qkv_linears[i].parameters()))
    all_model_parameters.extend(list(mha_output_linears[i].parameters()))

# 添加 MoE 路由器参数 (nn.Linear)
foriinrange(n_layers):
    all_model_parameters.extend(list(moe_routers[i].parameters()))

# 添加 MoE 专家权重 (nn.Parameter)
all_model_parameters.extend(moe_expert_gate_up_proj)
all_model_parameters.extend(moe_expert_down_proj)

# 添加共享专家 MLP 参数 (nn.Linear)
foriinrange(n_layers):
    all_model_parameters.extend(list(shared_expert_gate_proj[i].parameters()))
    all_model_parameters.extend(list(shared_expert_up_proj[i].parameters()))
    all_model_parameters.extend(list(shared_expert_down_proj[i].parameters()))

# 添加最终输出线性层参数 (nn.Linear)
all_model_parameters.extend(list(output_linear_layer.parameters()))

# 统计参数组数量和总参数量
num_param_groups=len(all_model_parameters)
total_params=sum(p.numel() forpinall_model_parametersifp.requires_grad)

# 定义 AdamW 优化器
optimizer=optim.AdamW(all_model_parameters, lr=learning_rate)

print("Optimizer Setup:")
print(f"  Optimizer Type: {type(optimizer).__name__}")
print(f"  Learning Rate: {learning_rate}")
print(f"  Number of Parameter Groups/Tensors Managed: {num_param_groups}")
 print(f"  Total Trainable Parameters: {total_params:,}") # 使用逗号分隔符格式化输出
 ### 输出 ###
 Optimizer Setup:
   Optimizer Type: AdamW
   Learning Rate: 0.0005
   Number of Parameter Groups/Tensors Managed: 43
   Total Trainable Parameters: 2,240,640

代码成功收集了模型中所有可训练的参数(共 43 个独立的权重/偏置张量或

nn.Parameter

对象),并使用指定的学习率

0.0005

初始化了

AdamW

优化器。同时,计算并打印了模型的总可训练参数量,约为 224 万。这远小于实际生产环境中的大型语言模型。

损失函数定义

为了训练模型,需要一个衡量标准来量化模型预测与真实目标词元之间的差异或“误差”。由于下一个词元预测本质上是一个分类问题(从词汇表中选择一个类别/字符),标准的损失函数是交叉熵损失(Cross-Entropy Loss)

PyTorch 提供了

nn.CrossEntropyLoss

,它结合了

LogSoftmax

NLLLoss

(负对数似然损失),可以直接处理模型输出的原始 logits 和整数形式的目标词元 ID,计算出损失值。

 
 # 定义损失函数
 # nn.CrossEntropyLoss 适用于多分类问题,内部处理 softmax 和 NLL loss
 criterion=nn.CrossEntropyLoss()
 
 print(f"Loss function defined: {type(criterion).__name__}")
 ### 输出 ###
 Loss function defined: CrossEntropyLoss

此步骤初始化了

nn.CrossEntropyLoss

。在训练循环中,将使用这个

criterion

对象来计算每个批次的损失值。

模型训练过程

现在,所有组件都已准备就绪,可以开始训练模型了。训练循环通常包含以下步骤:

  1. 迭代: 循环指定的 epochs 次数。
  2. 采样批次: 从 train_xtrain_y 中随机抽取一个批次的输入 (xb) 和目标 (yb),并将它们移动到计算设备 (device)。
  3. 前向传播: 将输入批次 xb 传递给模型,计算得到预测的 logits。这一步涉及所有先前定义的层:嵌入、RMSNorm、多头注意力(含 RoPE)、MoE 层(含路由和共享专家)、最终 RMSNorm 和输出线性层。
  4. 计算损失: 使用 criterion (交叉熵损失) 计算模型输出的 logits 与真实目标 yb 之间的损失。注意,需要调整 logits 和目标的形状以符合 CrossEntropyLoss 的要求(通常 logits 为 (Batch * SeqLen, VocabSize),目标为 (Batch * SeqLen))。
  5. 反向传播: 计算损失相对于所有模型参数的梯度。
  6. 参数更新: 使用 optimizer 根据计算出的梯度更新模型参数。
  7. 梯度清零: 清除优化器中累积的梯度,为下一次迭代做准备。
  8. (可选)评估与记录: 定期(例如每 eval_interval 轮)打印当前的训练损失,以监控训练进度。

注意: 下面的代码块将实现完整的模型前向传播逻辑,这在之前的组件初始化部分并未完全展示。它将整合所有已初始化的组件。

 
importtime

# --- 训练循环 ---
print("Starting training...")
start_time=time.time()

forepochinrange(epochs):
    # 1. 采样批次
    indices=torch.randint(0, num_sequences_available, (batch_size,))
    xb=train_x[indices].to(device) # (B, T)
    yb=train_y[indices].to(device) # (B, T)

    # --- 2. 前向传播 ---
    B, T=xb.shape# Batch size, Sequence length

    # 2.1 词元嵌入
    h=token_embedding_table(xb) # (B, T, d_model)

    # 2.2 计算 RoPE 频率 (freqs_cis) - 每次迭代根据当前序列长度 T 计算
    pos_indices=torch.arange(T, device=device) # (T,)
    # 外积得到 (T, d_k/2)
    freqs=torch.outer(pos_indices, inv_freq)
    # 转换为复数形式 (T, d_k/2)
    freqs_cis=torch.polar(torch.ones_like(freqs), freqs)

    # 2.3 遍历 Transformer 块
    foriinrange(n_layers):
        # --- 残差连接起点 ---
        residual_connection=h

        # --- a) RMSNorm + 多头注意力 ---
        # RMSNorm
        h_norm=h*torch.rsqrt(h.pow(2).mean(-1, keepdim=True) +rms_norm_eps) # (B, T, d_model)
        h_norm=h_norm*rmsnorm_weights_input[i] # 应用 gamma 缩放

        # QKV 投影
        qkv=mha_qkv_linears[i](h_norm) # (B, T, 3 * d_model)
        q, k, v=qkv.split(d_model, dim=-1) # 分割为 (B, T, d_model)

        # 调整形状以适应多头: (B, T, d_model) -> (B, n_heads, T, d_k)
        q=q.view(B, T, n_heads, d_k).transpose(1, 2) # (B, n_heads, T, d_k)
        k=k.view(B, T, n_heads, d_k).transpose(1, 2) # (B, n_heads, T, d_k)
        v=v.view(B, T, n_heads, d_k).transpose(1, 2) # (B, n_heads, T, d_k)

        # 应用 RoPE
        # 将 q, k 视为复数 (..., d_k/2, 2) -> (..., d_k/2)
        q_rope=torch.view_as_complex(q.float().reshape(B, n_heads, T, -1, 2))
        k_rope=torch.view_as_complex(k.float().reshape(B, n_heads, T, -1, 2))
        # 调整 freqs_cis 形状以广播: (T, d_k/2) -> (1, 1, T, d_k/2)
        freqs_cis_broadcast=freqs_cis.unsqueeze(0).unsqueeze(0)
        # 复数乘法实现旋转
        q_out=torch.view_as_real(q_rope*freqs_cis_broadcast).flatten(3) # (B, n_heads, T, d_k)
        k_out=torch.view_as_real(k_rope*freqs_cis_broadcast).flatten(3) # (B, n_heads, T, d_k)
        q, k=q_out.type_as(q), k_out.type_as(k) # 转换回原始类型

        # 计算注意力分数 (Scaled Dot-Product Attention)
        scores=torch.matmul(q, k.transpose(-2, -1)) * (d_k**-0.5) # (B, n_heads, T, T)
        # 应用因果掩码
        scores=scores.masked_fill(causal_mask[:,:,:T,:T] ==0, float('-inf')) # 使用当前序列长度 T
        attn_weights=F.softmax(scores, dim=-1) # (B, n_heads, T, T)
        # 计算注意力输出
        attn_output=torch.matmul(attn_weights, v) # (B, n_heads, T, d_k)

        # 重新组合多头输出
        attn_output=attn_output.transpose(1, 2).contiguous().view(B, T, d_model) # (B, T, d_model)

        # 输出投影
        attn_output=mha_output_linears[i](attn_output) # (B, T, d_model)

        # --- 第一个残差连接 ---
        h=residual_connection+attn_output

        # --- 残差连接起点 ---
        residual_connection_ffn=h

        # --- b) RMSNorm + MoE 层 (含共享专家) ---
        # RMSNorm
        h_norm=h*torch.rsqrt(h.pow(2).mean(-1, keepdim=True) +rms_norm_eps)
        h_norm=h_norm*rmsnorm_weights_post_attn[i] # 应用 gamma 缩放

        # --- MoE 计算 ---
        # Router Logits
        router_logits=moe_routers[i](h_norm) # (B, T, num_local_experts)

        # Top-K 路由
        routing_weights, selected_experts=torch.topk(router_logits, num_experts_per_tok, dim=-1) # (B, T, k), (B, T, k)
        routing_weights=F.softmax(routing_weights, dim=-1, dtype=torch.float).to(h_norm.dtype) # (B, T, k)

        # 初始化最终输出张量
        final_hidden_states=torch.zeros_like(h_norm) # (B, T, d_model)

        # 扁平化以进行批处理专家计算
        flat_hidden_states=h_norm.view(-1, d_model) # (B*T, d_model)
        flat_routing_weights=routing_weights.view(-1, num_experts_per_tok) # (B*T, k)
        flat_selected_experts=selected_experts.view(-1, num_experts_per_tok) # (B*T, k)

        # 遍历每个词元,计算其选定专家的输出
        fork_idxinrange(num_experts_per_tok):
            expert_indices=flat_selected_experts[:, k_idx] # (B*T,)
            current_routing_weights=flat_routing_weights[:, k_idx].unsqueeze(-1) # (B*T, 1)

            # 使用 gather/index_select 或更高效的方法收集需要计算的隐藏状态和专家权重
            # 这里为了清晰,使用循环和掩码 (效率较低)
            expert_outputs=torch.zeros_like(flat_hidden_states) # (B*T, d_model)
            forexp_idinrange(num_local_experts):
                token_mask= (expert_indices==exp_id) # (B*T,)
                iftoken_mask.any():
                    tokens_for_expert=flat_hidden_states[token_mask] # (N_tokens, d_model)

                    # 获取该专家的权重
                    gate_up_w_expert=moe_expert_gate_up_proj[i][exp_id] # (d_model, 2 * expert_dim)
                    down_w_expert=moe_expert_down_proj[i][exp_id]     # (expert_dim, d_model)

                    # 计算 Gate 和 Up 投影
                    gate_up_val=tokens_for_expert@gate_up_w_expert# (N_tokens, 2 * expert_dim)
                    gate_val, up_val=gate_up_val.chunk(2, dim=-1) # (N_tokens, expert_dim)

                    # 应用激活函数和门控
                    activated_up=activation_fn(gate_val) *up_val# (N_tokens, expert_dim)

                    # Down 投影
                    expert_result=activated_up@down_w_expert# (N_tokens, d_model)

                    # 将结果放回正确的位置
                    expert_outputs[token_mask] =expert_result

            # 加权专家输出
            final_hidden_states+= (expert_outputs.view(B, T, d_model) *current_routing_weights.view(B, T, 1))

        # --- 共享专家计算 ---
        shared_gate_val=shared_expert_gate_proj[i](h_norm) # (B, T, shared_expert_dim)
        shared_up_val=shared_expert_up_proj[i](h_norm)     # (B, T, shared_expert_dim)
        shared_activated_up=activation_fn(shared_gate_val) *shared_up_val# (B, T, shared_expert_dim)
        shared_output=shared_expert_down_proj[i](shared_activated_up) # (B, T, d_model)

        # --- 合并 MoE 和共享专家输出,并应用第二个残差连接 ---
        h=residual_connection_ffn+final_hidden_states+shared_output

    # 2.4 最终 RMSNorm
    h=h*torch.rsqrt(h.pow(2).mean(-1, keepdim=True) +rms_norm_eps)
    h=h*final_rmsnorm_weight# 应用 gamma 缩放

    # 2.5 最终输出层
    logits=output_linear_layer(h) # (B, T, vocab_size)

    # --- 3. 计算损失 ---
    # 调整形状以适应 CrossEntropyLoss: (B, T, C) -> (B*T, C), (B, T) -> (B*T)
    loss=criterion(logits.view(-1, vocab_size), yb.view(-1))

    # --- 4. 反向传播 ---
    loss.backward()

    # --- 5. 参数更新 ---
    optimizer.step()

    # --- 6. 梯度清零 ---
    optimizer.zero_grad(set_to_none=True)

    # --- 7. 评估与记录 ---
    ifepoch%eval_interval==0orepoch==epochs-1:
        elapsed_time=time.time() -start_time
        print(f"Epoch {epoch}/{epochs}, Loss: {loss.item():.4f}, Time: {elapsed_time:.2f}s")

 print("\nTraining finished.")
 ### 输出示例 (可能因随机性而异) ###
Starting training...
Epoch 0/3000, Loss: 3.6512, Time: 1.52s
Epoch 300/3000, Loss: 1.9875, Time: 15.88s
Epoch 600/3000, Loss: 1.3542, Time: 30.15s
Epoch 900/3000, Loss: 0.9811, Time: 44.52s
Epoch 1200/3000, Loss: 0.7539, Time: 58.90s
Epoch 1500/3000, Loss: 0.6015, Time: 73.25s
Epoch 1800/3000, Loss: 0.4887, Time: 87.61s
Epoch 2100/3000, Loss: 0.4053, Time: 101.98s
Epoch 2400/3000, Loss: 0.3398, Time: 116.35s
Epoch 2700/3000, Loss: 0.2881, Time: 130.72s
Epoch 2999/3000, Loss: 0.2519, Time: 144.85s

 Training finished.

训练循环成功执行了 3000 轮。输出显示了每隔 300 轮的训练损失值。可以看到损失值随着训练的进行稳步下降,表明模型正在学习预测序列中的下一个字符。最终损失约为 0.25。
**
**

**
**

注意: MoE 专家计算部分为了清晰起见使用了循环和掩码,这在实际大规模训练中效率较低。更优化的实现会使用

torch.gather

或专门的库(如

tutel

)来高效地处理稀疏的专家计算。

文本生成实现

训练完成后,模型就可以用来生成新的文本了。生成过程通常从一个初始的“提示”(prompt)或上下文开始,然后模型自回归地、一次一个词元地预测后续内容。

文本生成设置

首先,定义生成的起始上下文。这里我们使用单个换行符

\n

作为最简单的起始符,让模型从“零”开始生成。将其编码为词元 ID,并确保形状为

(1, 1)

(批大小为 1,序列长度为 1),放置在正确的设备上。

 
# 设置生成起始上下文 (一个换行符)
start_context_char='\n'
start_context_id=char_to_int[start_context_char]
# 转换为形状为 (1, 1) 的张量,表示批大小为 1,序列长度为 1
generation_context=torch.tensor([[start_context_id]], dtype=torch.long, device=device)

# 定义生成文本的最大长度
max_tokens_to_generate=150

print("Text Generation Setup:")
print(f"  Starting context: '{start_context_char}' (ID: {start_context_id})")
print(f"  Initial context tensor shape: {generation_context.shape}")
 print(f"  Max tokens to generate: {max_tokens_to_generate}")
 ### 输出 ###
 Text Generation Setup:
   Starting context: '
 ' (ID: 0)
   Initial context tensor shape: torch.Size([1, 1])
   Max tokens to generate: 150

设置了生成过程的起点为一个换行符(ID 为 0),并指定最多生成 150 个新词元。

生成循环详解

文本生成的核心是一个循环,在每次迭代中:

  1. 准备输入: 获取当前已生成的序列。由于模型训练时的最大上下文长度为 block_size,如果当前序列长度超过 block_size,需要截断,只保留最后 block_size 个词元作为输入。
  2. 模型推理: 将准备好的输入序列传递给模型进行一次前向传播。注意: 这一步与训练时的前向传播完全相同,但不需要计算梯度 (torch.no_grad() 上下文管理器可以提高效率并减少内存使用)。
  3. 获取 Logits: 从模型输出中提取最后一个时间步(对应序列末尾)的 logits。这些 logits 代表了对下一个词元的预测分数。
  4. 采样: 根据 logits 采样下一个词元 ID。最简单的方法是选择具有最高 logit 的词元(贪心采样,torch.argmax)。也可以基于 logits 形成的概率分布进行随机采样(torch.multinomial),这通常能产生更多样化的文本。
  5. 追加词元: 将采样到的新词元 ID 追加到当前序列的末尾。
  6. 重复: 重复以上步骤,直到生成了所需数量的词元。
 
# --- 生成循环 ---
print("\nStarting text generation...")
generated_sequence=generation_context# 从初始上下文开始

# 将模型切换到评估模式 (主要影响 Dropout 和 BatchNorm,本模型未使用)
# 但使用 torch.no_grad() 更为关键,以禁用梯度计算
# (此处省略显式 .eval() 调用,因为我们没有这些层)

withtorch.no_grad(): # 禁用梯度计算以提高效率
    for_inrange(max_tokens_to_generate):
        # 1. 准备输入 (截断到 block_size)
        current_input=generated_sequence[:, -block_size:] # (B, T_current <= block_size)
        B, T_current=current_input.shape

        # --- 2. 模型推理 (与训练前向传播逻辑相同) ---
        h=token_embedding_table(current_input) # (B, T_current, d_model)

        # 计算 RoPE 频率 (freqs_cis)
        pos_indices=torch.arange(T_current, device=device)
        freqs=torch.outer(pos_indices, inv_freq)
        freqs_cis=torch.polar(torch.ones_like(freqs), freqs)

        # Transformer 块
        foriinrange(n_layers):
            residual_connection=h
            # RMSNorm + MHA
            h_norm=h*torch.rsqrt(h.pow(2).mean(-1, keepdim=True) +rms_norm_eps)
            h_norm=h_norm*rmsnorm_weights_input[i]
            qkv=mha_qkv_linears[i](h_norm)
            q, k, v=qkv.split(d_model, dim=-1)
            q=q.view(B, T_current, n_heads, d_k).transpose(1, 2)
            k=k.view(B, T_current, n_heads, d_k).transpose(1, 2)
            v=v.view(B, T_current, n_heads, d_k).transpose(1, 2)
            # Apply RoPE
            q_rope=torch.view_as_complex(q.float().reshape(B, n_heads, T_current, -1, 2))
            k_rope=torch.view_as_complex(k.float().reshape(B, n_heads, T_current, -1, 2))
            freqs_cis_broadcast=freqs_cis.unsqueeze(0).unsqueeze(0)
            q_out=torch.view_as_real(q_rope*freqs_cis_broadcast).flatten(3)
            k_out=torch.view_as_real(k_rope*freqs_cis_broadcast).flatten(3)
            q, k=q_out.type_as(q), k_out.type_as(k)
            # Attention
            scores=torch.matmul(q, k.transpose(-2, -1)) * (d_k**-0.5)
            # 使用截断后的因果掩码
            current_causal_mask=causal_mask[:,:,:T_current,:T_current]
            scores=scores.masked_fill(current_causal_mask==0, float('-inf'))
            attn_weights=F.softmax(scores, dim=-1)
            attn_output=torch.matmul(attn_weights, v)
            attn_output=attn_output.transpose(1, 2).contiguous().view(B, T_current, d_model)
            attn_output=mha_output_linears[i](attn_output)
            h=residual_connection+attn_output

            residual_connection_ffn=h
            # RMSNorm + MoE
            h_norm=h*torch.rsqrt(h.pow(2).mean(-1, keepdim=True) +rms_norm_eps)
            h_norm=h_norm*rmsnorm_weights_post_attn[i]
            # MoE Router
            router_logits=moe_routers[i](h_norm)
            routing_weights, selected_experts=torch.topk(router_logits, num_experts_per_tok, dim=-1)
            routing_weights=F.softmax(routing_weights, dim=-1, dtype=torch.float).to(h_norm.dtype)
            # MoE Expert Calculation (与训练时相同)
            final_hidden_states=torch.zeros_like(h_norm)
            flat_hidden_states=h_norm.view(-1, d_model)
            flat_routing_weights=routing_weights.view(-1, num_experts_per_tok)
            flat_selected_experts=selected_experts.view(-1, num_experts_per_tok)
            fork_idxinrange(num_experts_per_tok):
                expert_indices=flat_selected_experts[:, k_idx]
                current_routing_weights=flat_routing_weights[:, k_idx].unsqueeze(-1)
                expert_outputs=torch.zeros_like(flat_hidden_states)
                forexp_idinrange(num_local_experts):
                    token_mask= (expert_indices==exp_id)
                    iftoken_mask.any():
                        tokens_for_expert=flat_hidden_states[token_mask]
                        gate_up_w_expert=moe_expert_gate_up_proj[i][exp_id]
                        down_w_expert=moe_expert_down_proj[i][exp_id]
                        gate_up_val=tokens_for_expert@gate_up_w_expert
                        gate_val, up_val=gate_up_val.chunk(2, dim=-1)
                        activated_up=activation_fn(gate_val) *up_val
                        expert_result=activated_up@down_w_expert
                        expert_outputs[token_mask] =expert_result
                final_hidden_states+= (expert_outputs.view(B, T_current, d_model) *current_routing_weights.view(B, T_current, 1))
            # Shared Expert
            shared_gate_val=shared_expert_gate_proj[i](h_norm)
            shared_up_val=shared_expert_up_proj[i](h_norm)
            shared_activated_up=activation_fn(shared_gate_val) *shared_up_val
            shared_output=shared_expert_down_proj[i](shared_activated_up)
            # Combine and Residual
            h=residual_connection_ffn+final_hidden_states+shared_output

        # 最终 RMSNorm
        h=h*torch.rsqrt(h.pow(2).mean(-1, keepdim=True) +rms_norm_eps)
        h=h*final_rmsnorm_weight

        # 最终输出层
        logits=output_linear_layer(h) # (B, T_current, vocab_size)

        # --- 3. 获取最后一个时间步的 Logits ---
        last_step_logits=logits[:, -1, :] # (B, vocab_size)

        # --- 4. 采样 ---
        # 应用 softmax 得到概率
        probs=F.softmax(last_step_logits, dim=-1) # (B, vocab_size)
        # 从概率分布中采样下一个词元 ID
        next_token_id=torch.multinomial(probs, num_samples=1) # (B, 1)

        # --- 5. 追加词元 ---
        generated_sequence=torch.cat((generated_sequence, next_token_id), dim=1) # (B, T_current + 1)

 print("Generation finished.")
 ### 输出 ###
 
 Starting text generation...
 Generation finished.

生成循环执行完成。

generated_sequence

张量现在包含了初始上下文和新生成的词元 ID 序列。

解码生成序列

最后一步是将包含词元 ID 的

generated_sequence

张量转换回人类可读的文本。这通过使用之前创建的

int_to_char

字典来完成。

 
# 将生成的 ID 序列解码回文本
# generated_sequence 形状为 (1, T_generated), 取第一个批次 [0]
generated_ids=generated_sequence[0].tolist()
generated_text=''.join([int_to_char[idx] foridxingenerated_ids])

print("\n--- Generated Text ---")
 print(generated_text)
 ### 输出示例 (可能因随机性而异) ###
 
 --- Generated Text ---
 
 nk, and of having nothing to do: once or twice she had peeped into the
 book her sister was reading, but it had no pictures or conversations in
 it, 'and

输出显示了解码后的生成文本。由于模型规模小且训练数据有限,生成的文本可能在语法或连贯性上存在不足,但它展示了模型从训练数据中学到的一些模式(例如单词、标点和换行)。如果使用文章开头的示例输入 "Alice",则会得到类似开头的输出。

模型状态保存(可选)

训练完成后,通常需要保存模型的参数(权重和偏置),以便将来可以重新加载模型进行推理或继续训练,而无需从头开始。PyTorch 推荐保存模型的

state_dict

,它是一个包含所有可学习参数及其当前值的 Python 字典。

 
# (可选) 保存模型状态字典
# 创建一个字典来存储所有参数
model_state= {
    'token_embedding_table': token_embedding_table.state_dict(),
    'rmsnorm_weights_input': [p.dataforpinrmsnorm_weights_input], # 直接保存张量数据
    'rmsnorm_weights_post_attn': [p.dataforpinrmsnorm_weights_post_attn],
    'final_rmsnorm_weight': final_rmsnorm_weight.data,
    'mha_qkv_linears': [layer.state_dict() forlayerinmha_qkv_linears],
    'mha_output_linears': [layer.state_dict() forlayerinmha_output_linears],
    'moe_routers': [layer.state_dict() forlayerinmoe_routers],
    'moe_expert_gate_up_proj': [p.dataforpinmoe_expert_gate_up_proj],
    'moe_expert_down_proj': [p.dataforpinmoe_expert_down_proj],
    'shared_expert_gate_proj': [layer.state_dict() forlayerinshared_expert_gate_proj],
    'shared_expert_up_proj': [layer.state_dict() forlayerinshared_expert_up_proj],
    'shared_expert_down_proj': [layer.state_dict() forlayerinshared_expert_down_proj],
    'output_linear_layer': output_linear_layer.state_dict(),
    # 保存超参数以便加载时重建模型结构
    'hyperparameters': {
        'vocab_size': vocab_size,
        'd_model': d_model,
        'n_layers': n_layers,
        'n_heads': n_heads,
        'block_size': block_size,
        'rms_norm_eps': rms_norm_eps,
        'rope_theta': rope_theta,
        'num_local_experts': num_local_experts,
        'num_experts_per_tok': num_experts_per_tok,
        'intermediate_size_expert': intermediate_size_expert,
        'intermediate_size_shared': intermediate_size_shared,
        'd_k': d_k
    }
}

# 定义保存路径
save_path='llama4_moe_simplified_model.pth'
# 保存状态字典到文件
torch.save(model_state, save_path)

print(f"\nModel state dictionary saved to: {save_path}")
# 加载示例 (需要先重新初始化模型结构):
# checkpoint = torch.load(save_path, map_location=device)
# token_embedding_table.load_state_dict(checkpoint['token_embedding_table'])
 # ... 加载其他参数 ...
 ### 输出 ###
 
 Model state dictionary saved to: llama4_moe_simplified_model.pth

此代码块将所有模型的

nn.Module

(如

nn.Embedding

,

nn.Linear

)的

state_dict

以及直接定义的

nn.Parameter

的数据张量收集到一个字典

model_state

中。同时,也将关键的超参数保存在字典里,这对于后续正确加载模型至关重要。最后,使用

torch.save

将这个字典保存到文件

llama4_moe_simplified_model.pth

。注释中提供了加载模型状态的基本思路。

总结

本文通过一个简化的实例,逐步从零开始实现了 LLaMA 4 风格的混合专家(MoE)Transformer 模型。我们涵盖了从数据准备、分词、模型架构定义(包括词元嵌入、RoPE、RMSNorm、多头注意力、MoE 层和共享专家)到训练和文本生成的完整流程。

关键实现点包括:

  • MoE 层: 实现了包含路由器(Top-K 路由)、独立专家(门控 MLP)和共享专家的 MoE 结构。
  • RoPE: 集成了旋转位置编码以处理序列位置信息。
  • RMSNorm: 使用了 RMSNorm 进行层归一化。
  • 从零实现: 代码避免了直接使用高级 Transformer 库,而是侧重于展示底层组件的构建和交互。

尽管本示例的模型规模和训练数据远小于实际的 LLaMA 4,但它清晰地揭示了 MoE 架构的核心机制:通过动态路由和稀疏激活专家,在潜在提升性能的同时控制计算成本。

希望本文能帮助读者更好地理解 MoE 的工作原理,并为进一步探索和实现更复杂的 Transformer 模型打下基础。完整的代码:

https://avoid.overfit.cn/post/27b7812def944fe0bad8ef1ecef5e739

注:本文根据Github上FareedKhan-dev的代码进行修改,代码和执行时间等会略有不同,原作者为Fareed Khan,感谢他的贡献。


deephub
125 声望107 粉丝