对于许多数据科学家和开发者而言,
einsum
通常被视为NumPy文档中那个晦涩难懂的高级函数——功能强大但难以理解。不过一旦掌握其基本原理,
einsum
将成为Python科学计算生态系统中处理多维数组运算最为优雅高效的工具之一。它不仅语法简洁,表达力强,而且在众多应用场景中通常比常规方法更具计算效率。
本文将全面介绍
einsum()
函数——其数学基础、实现原理以及实际应用场景。我们将深入剖析其符号系统,通过实用示例展示其功能,探讨性能优化策略,并提供一个完整的参考速查表辅助实际应用。
无论您是深度学习研究人员、数值计算工程师,还是追求代码简洁高效的Python开发者,本文都将帮助您充分理解并有效应用
einsum()
函数。
1、爱因斯坦求和约定基础
我们从一个经典案例入手:矩阵乘法。在线性代数中,两个矩阵A和B相乘的标准定义是计算A中每一行与B中每一列的点积。图形化表示如下:
图1:标准矩阵乘法示意图。
注意观察,对于结果矩阵C中的每个元素,我们取A的第i行和B的第k列对应位置的元素相乘,然后对索引j求和。
这种元素级乘法和求和模式在张量运算中极为常见,尤其在阿尔伯特·爱因斯坦的广义相对论研究中。爱因斯坦为简化复杂张量表达式,提出了一种简洁表示法,省略显式的求和符号。其核心原则是:当一个索引在表达式中出现两次,则默认对该索引进行求和。
因此,对于矩阵乘法,我们可以表示为:
图2:使用爱因斯坦表示法的矩阵乘法。
在NumPy中(后文将讨论其他框架如PyTorch、TensorFlow等),
np.einsum()
函数允许直接应用这种表示法进行数组操作。下面是一个基本示例,展示了如何使用此函数进行矩阵乘法:
importnumpyasnp
A=np.random.randn(4,3)
B=np.random.randn(3,3)
C=np.einsum("ij,jk->ik", A, B)
该操作等效于
np.dot(A, B)
或
A @ B
,但
einsum()
的真正价值在于它能够将这一原理推广到更复杂的多维张量运算,这正是本文将要深入探讨的内容。
关于索引的技术说明:j是求和索引(也称为哑索引)。它可以是任何字母,不会改变表达式的实际含义,故称为"哑"索引。它在多个张量中出现,但不出现在结果中。而索引i和k称为自由索引,它们不参与求和运算,在每个张量中仅出现一次,并保留在输出结果中。
接下来,我们将详细分析
einsum()
的语法结构。
2、理解基本的 einsum() 语法
einsum()函数的语法初见时可能显得复杂,但一旦理解其逻辑,您会发现它是一个极为强大的数组操作工具。以下分析主要基于NumPy实现,但这些概念同样适用于PyTorch和TensorFlow等其他科学计算库,只有少量实现细节的差异。
本文将重点关注显式语法,因为从工程实践角度看,显式语法更不易出错且更具可读性。所谓显式语法,是指在einsum字符串中明确指定输出张量的索引结构。
虽然隐式语法更为简洁,但它缺乏灵活性,且容易因索引字母的顺序问题导致错误。这主要是个人偏好问题,但建议初学者先掌握显式语法。如需讨论隐式语法的细节,可在评论区交流。
以下是我们在矩阵乘法示例中使用的einsum字符串结构:
图3:einsum()基本语法结构。
这张图清晰展示了einsum字符串的基本结构。现在让我们专注于理解这一字符串的核心含义。
在einsum中,输入和输出张量的每个维度都由一个索引字母表示。例如,3维张量需要3个字母,4维张量需要4个字母,依此类推。这些索引实质上代表着循环变量。在NumPy实现中,可以使用任何小写或大写字母作为索引,且在显式语法中,字母的顺序并不影响计算结果。
einsum()
执行的具体操作取决于索引在字符串中的出现位置和方式。主要有四种情况:
图4:显式
einsum()
语法规则。
在张量计算中,"收缩"(contraction)这一术语经常与"求和"一起使用,但两者有细微差别。求和专指对特定索引进行规约操作,例如对索引j求和。而收缩是一个更广泛的概念,它包含求和操作,但同时也表示张量维度的减少。例如,矩阵乘法可视为在一个维度上的收缩。因此,每个求和操作都构成一种收缩,但并非所有收缩都是简单的求和。
为了建立直观理解,我们来分析几个基本示例:
外积
外积是线性代数中的基本运算,它通过将第一个向量的每个元素与第二个向量的每个元素相乘,生成一个二维矩阵。
图5:
einsum()
表示的外积运算。
观察einsum字符串
"i,k->ik"
,我们发现输入中的所有索引都保留在输出中,因此没有执行求和操作。这表明第一个向量中的每个元素都与第二个向量中的每个元素相乘,形成结果矩阵。
批处理外积
批处理外积是外积的扩展应用,适用于同时计算多对向量的外积情况。
图6:
einsum()
表示的批处理外积。
在einsum字符串
"bi,bk->bik"
中,所有输入索引都保留在输出中,表明没有执行求和。这里的b索引出现在所有张量中,表示在批处理维度上执行外积运算。
提取对角线
从方阵提取对角线元素是一个典型的单输入操作示例。矩阵的对角线是指从左上角到右下角的元素集合。
图7:
einsum()
表示的对角线提取操作。
einsum字符串
"ii->i"
表明我们有一个二维输入和一维输出。由于使用了相同的索引i,我们并非对其求和,而是提取对角线元素。理解einsum字符串的最佳方式是分析其隐含的循环结构。
计算迹
如果要计算对角线元素的和(即矩阵的迹),可以使用相同的输入,但将einsum字符串改为
"ii->"
。
图8:
einsum()
表示的矩阵迹计算。
输出是一个标量(0维张量)。由于索引
i
出现在输入中但不在输出中,因此对其执行求和操作。
加权和
加权和是机器学习中的常见操作,特别是在线性回归和神经网络中。给定一个值矩阵和一个权重向量,可以执行如下计算:
图9:
einsum()
表示的加权和计算。
分析einsum字符串
"ij,j->i"
,索引
j
出现在两个输入中但不在输出中,因此对其求和。而索引
i
同时出现在输入和输出中,表明不对其求和。由于
i
在第一个输入的第一个轴上,可以推断输出是一个列向量,与图示一致。
简单转置
矩阵转置是线性代数中最基本的操作之一,它将矩阵的行和列互换。用
einsum()
实现如下:
图10:
einsum()
表示的矩阵转置操作。
这是一个没有收缩的单输入操作,因为所有索引都同时出现在输入和输出中,只是顺序发生了变化。
一步完成矩阵乘法和转置
einsum()
允许在一个操作中同时完成矩阵乘法和转置。以下示例展示了转置第二个输入的矩阵乘法:
图11:
einsum()
表示的带转置的矩阵乘法。
这里的收缩发生在索引
j
上,它出现在两个输入中但不在输出中。
图像中的轴交换(通道优先 vs. 通道置后)
广义上,转置只是轴交换的一种特例。在图像处理中,深度学习框架常用两种不同的数据格式:PyTorch使用通道优先格式(B,C,H,W),而TensorFlow使用通道置后格式(B,H,W,C)。在框架间转换时,可能需要调整图像的轴顺序:
图12:
einsum()
表示的图像数据格式转换。
这种操作的优势在于其明确性——einsum字符串直观地展示了维度的重排方式。由于所有索引都出现在输入和输出中,不执行任何收缩,仅改变轴的顺序。
值得注意的是,尽管看起来这些操作可能涉及大量数据复制,但
einsum()
通常会创建原始数组的视图,避免额外的内存分配,这是NumPy和
einsum()
的重要优化特性。
einsum的高级语法
基本语法已经展示了
einsum()
的强大功能,但这仅仅是开始。
einsum()
还提供了一些高级特性,使其在复杂计算场景中更加灵活高效。
本节将探讨
einsum()
的高级应用,特别是在深度学习中的应用,尤其是Transformer架构中的应用实例。
具体而言,我们将学习省略号(...)运算符的使用、如何处理多输入操作,以及如何通过
einsum_path()
优化
einsum()
的计算效率。
多头注意力机制概述
我们先简要回顾Transformer架构中的多头注意力机制。Transformer是一种序列到序列模型,由多个Transformer层堆叠而成,每层包含多头注意力和前馈神经网络(为简洁起见,此处省略了层归一化和残差连接)。其结构如下图所示:
图13:Transformer层的简化结构图。
注意多头注意力机制中,多个注意力头是并行计算的,每个头为每个批次计算独立的注意力分数。也就是说,注意力分数的计算在多个头和批次维度上进行广播。
用于批处理多头注意力的省略号 (…) 运算符
省略号运算符(...)是
einsum()
中一个强大功能,允许在不明确指定每个维度的情况下表示任意数量的维度。这在处理具有多个维度(如多头注意力中的头维度或批次维度)的操作时特别有用。
以一个简单例子说明:在多头注意力机制中,注意力分数通过计算查询矩阵Q和键矩阵K的矩阵乘积,并应用SoftMax函数得到。然后,这些注意力分数用于计算值矩阵V的加权和(为简化起见,此处省略了除以sqrt(d)的缩放步骤):
图14:注意力块内部结构(简化版)。
让我们关注图中绿色高亮的查询矩阵和键矩阵的乘法。常见表达式为
S=QK^T
,但这种表示方式有些简化,因为实际上我们需要为每个头和每个批次执行此计算,并且还需要转置键矩阵。使用
einsum()
,这些操作可以在一行代码中完成:
Q=np.random.rand(batch_size, num_heads, head_dim, num_query_tokens)
K=np.random.rand(batch_size, num_heads, head_dim, num_key_tokens)
# 带转置的批处理多头矩阵乘法 (Batched Multi-Head Matrix Multiplication with transpose)
S=np.einsum("bhdi,bhdj->bhij", Q, K) # (batch_size, num_heads, num_query_tokens, num_key_tokens)
这一操作不仅为每个头和批次执行矩阵乘法,还同时完成了键矩阵的转置。
在此示例中,我区分了
num_query_tokens
和
num_key_tokens
,因为虽然在自注意力机制中查询和键矩阵的token数量相同,但在交叉注意力场景中可能不同(如多模态系统中,文本提示可能只有少量token,而图像可能有成千上万个token)。
如果我们想将相同的函数应用于单头注意力或单个批次,可以使用省略号运算符:
S=np.einsum("...di,...dj->...ij", Q, K)
这种表示方式不关心头和批次的具体数量,只关注最后两个维度的运算,非常灵活。我们可以用这个函数计算任意数量的头和批次的注意力分数:
图15:多头自注意力中任意批次和头数的矩阵乘法。
多输入与优化——推导线性 Transformer
到目前为止,我们主要使用了一个或两个输入的
einsum()
示例。但
einsum()
能够处理更多输入。为了展示多输入处理能力,我们将分析一个实际案例:推导Linear Transformer,这是一种计算复杂度与序列长度呈线性关系而非二次关系的Transformer变体。
我们需要理解为什么标准Transformer的复杂度与序列长度呈O(n²)关系。这主要源于注意力机制。回顾之前的例子,当我们计算查询矩阵和键矩阵的乘积时,得到一个形状为
(batch_size, num_heads, num_tokens, num_tokens)
的矩阵,即一个
num_tokens × num_tokens
的方阵。随着token数量的增加,计算量和存储需求呈二次增长。
这带来两个问题:计算复杂度高且内存需求大。线性Transformer采用核技巧来近似SoftMax函数:
图16:使用线性Transformer中的核技巧近似SoftMax函数。
对于本文而言,具体的
Φ(.)
函数定义不是重点,关键是理解它同时应用于查询和键矩阵。结构上,我们可以执行如下操作:
图17:为线性注意力重构的注意力块图。
这还不是最终结构,但它展示了如何使用
einsum()
在一步中计算包含三个输入的两个连续矩阵乘法:
图18:三矩阵乘法的Einsum表示法。
通过省略号运算符,该操作支持任意数量的头和批次,并且前两个输入还处理了必要的转置。但我们仍面临二次复杂度的问题。
解决方案核心在于利用矩阵乘法的结合律重新排列计算顺序:
图19:线性Transformer中应用矩阵乘法结合律。
仅通过改变矩阵乘法的顺序,计算复杂度从O(n²)降低到O(n)。这是线性Transformer的核心优化。但对于复杂的
einsum()
表达式,如何确定最优执行顺序?
这里可以使用
einsum_path()
函数,它能够为给定的einsum表达式确定最优收缩顺序。该函数支持"greedy"(贪婪)和"optimal"(最优)两种优化策略。贪婪策略计算速度快但可能不是最优,而最优策略计算较慢但保证最佳结果。
首先,定义输入并测量计算时间:
importnumpyasnp
B=8 # 批处理大小 (batch size)
H=16 # 头的数量 (number of heads)
D=64 # 特征维度 (feature dimension)
I=J=2048# 标记数量 (number of tokens)
Q_phi=np.random.rand(B, H, D, I) # 查询 (Query)
K_phi=np.random.rand(B, H, D, J) # 键 (Key)
V=np.random.rand(B, H, D, J) # 值 (Value)
%timeitnp.einsum("bhdi, bhdj, bhdj -> bhdi", Q_phi, K_phi, V)
# 每个循环 15.6 秒 ± 41.2 毫秒 (平均值 ± 标准差,7 次运行,每次 1 个循环) (15.6 s ± 41.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each))
使用
einsum_path()
寻找最优收缩顺序:
path_standard=np.einsum_path("bhdi, bhdj, bhdj -> bhdi", Q_phi, K_phi, V, optimize=False)
path_optimal=np.einsum_path("bhdi, bhdj, bhdj -> bhdi", Q_phi, K_phi, V, optimize="optimal")
print(path_standard[0])
# ['einsum_path', (0, 1, 2)]
print(path_optimal[0])
#['einsum_path', (1, 2), (0, 1)]
print(path_optimal[1])
# Complete contraction: bhdi,bhdj,bhdj->bhdi
# Naive scaling: 5
# Optimized scaling: 4
# Naive FLOP count: 1.031e+11
# Optimized FLOP count: 5.033e+07
# Theoretical speedup: 2048.000
# Largest intermediate: 1.678e+07 elements
# --------------------------------------------------------------------------
# scaling current remaining
# --------------------------------------------------------------------------
# 4 bhdj,bhdj->bhd bhdi,bhd->bhdi
# 4 bhd,bhdi->bhdi bhdi->bhdi
输出包含丰富信息,主要关注点如下:
- 收缩顺序:标准收缩顺序(0,1,2)表示先将前两个矩阵相乘,再与第三个矩阵相乘;最优收缩顺序(1,2),(0,1)表示先将第二个和第三个矩阵相乘,再与第一个矩阵相乘。
- 缩放因子:朴素缩放是einsum字符串中索引数量,优化缩放是最优收缩顺序中的索引数量。缩放值越低,计算越快。
- FLOP计数:浮点运算次数,优化后的FLOP计数大幅减少。
- 理论加速比:朴素FLOP计数与优化FLOP计数的比率,本例中为2048倍,恰好等于示例中的token数量,验证了线性Transformer的线性复杂度特性。
最后测量优化后的计算时间:
%timeitnp.einsum("bhdi, bhdj, bhdj -> bhdi", Q_phi, K_phi, V, optimize=path_optimal[0])
# 每个循环 36.6 毫秒 ± 1.23 毫秒 (平均值 ± 标准差,7 次运行,每次 1 个循环) (36.6 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each))
计算时间从15.6秒降至36.6毫秒,加速比达426倍!实际应用中,矩阵规模更大时加速效果会更明显。
线性Transformer的更新结构如下:
图20:通过改变矩阵乘法收缩顺序实现线性Transformer。
需要注意的是,理论加速比在实际应用中并不总能完全实现,因为实际性能还受
einsum
具体实现和硬件配置影响。内存访问模式、缓存使用效率等因素也会影响性能。此外,某些计算模式可能允许使用优化的BLAS子程序,而其他模式则不行。因此,在应用优化前后务必进行性能测试,确认优化效果。
使用 einsum时的注意事项
隐式语法 (Implicit Syntax)
隐式语法是
einsum()
的简写形式,允许在不明确指定输出形状的情况下编写einsum字符串。虽然更简洁,但灵活性较低,且有一些需要注意的特殊情况,特别是索引字母顺序的重要性。
以下示例说明这一点:
a=np.empty((10,20))
# 因为索引顺序为ba,形状从(10,20)变为(20,10)
np.einsum("ba",a).shape
# (20, 10)
虽然在简单情况下这种表示方式较为直观,但在处理多输入操作时可能导致混淆:
b=np.empty((20,30))
# 矩阵乘法:
np.einsum("bc,cd",a,b).shape
# (10, 30)
# 将索引d改为a,会转置结果!
np.einsum("bc,ca",a,b).shape
# (30, 10)
除混淆风险外,隐式语法在功能上也存在局限。例如,无法提取矩阵对角线,因为重复索引会自动进行求和:
c=np.ones((20,20))
np.einsum("ii",c)
# 输出: 20 (即对角线元素之和)
建议:坚持使用显式语法,它更具可读性且不易出错。
数据类型不会自动提升 (Data Types Not Being Promoted)
einsum()
不会自动提升数据类型,这在处理整数类型时尤其需要注意。如果计算结果超出数据类型范围,将导致溢出:
a=np.ones(200, dtype=np.int8)
print(np.sum(a))
# 输出: 200
print(np.einsum("i->",a))
# 输出: -56 (由于int8类型溢出)
由于
a
中元素和超过127(int8的最大值),结果发生溢出,返回负值。相比之下,
np.sum()
会自动将结果提升为更大的整数类型(此例中为np.int64)。
BLAS 子程序的使用 (Usage of BLAS Subroutines)
BLAS(基础线性代数子程序)是一系列高度优化的线性代数运算库。NumPy、PyTorch、TensorFlow等科学计算库通常使用这些库加速矩阵运算。
einsum()
在某些情况下可利用BLAS子程序优化计算,特别是矩阵乘法等标准线性代数运算。然而,并非所有
einsum
操作都能通过BLAS优化。例如,涉及多维广播的操作通常无法利用BLAS优化。
以下示例展示了BLAS优化的效果:
B=1# 批处理 (batch)
D=512# 特征维度 (feature dimension)
I=J=4096# 标记数量 (number of tokens)
# 非批处理矩阵 (Non-batched matrices)
Q= np.random.rand(D, I)
K= np.random.rand(D, J)
# 批处理矩阵 (Batched matrices)
Qb= np.random.rand(B, D, I)
Kb= np.random.rand(B, D, J)
%timeitnp.einsum("di,dj->ij", Q, K, optimize=False) # -> 无BLAS优化
# 每个循环 5.22 秒 ± 56.7 毫秒 (平均值 ± 标准差,7 次运行,每次 1 个循环)
%timeitnp.einsum("di,dj->ij", Q, K, optimize=True) # -> 使用BLAS优化
# 每个循环 125 毫秒 ± 5.3 毫秒 (平均值 ± 标准差,7 次运行,每次 10 个循环)
%timeitnp.einsum("bdi,bdj->bij", Qb, Kb, optimize=True) # -> 批处理模式下无BLAS优化
# 每个循环 5.28 秒 ± 72.3 毫秒 (平均值 ± 标准差,7 次运行,每次 1 个循环)
对于大型批处理矩阵,可以考虑使用GPU或支持批处理BLAS的库,如PyTorch:
importtorch
# PyTorch使用张量
Qbt=torch.tensor(Qb)
Kbt=torch.tensor(Kb)
%timeittorch.einsum("bdi,bdj->bij", Qbt, Kbt)
# 每个循环 76.3 毫秒 ± 2.15 毫秒 (平均值 ± 标准差,7 次运行,每次 10 个循环)
PyTorch的实现比NumPy中优化的
einsum()
快约两倍,这得益于其不同的实现方式。
不同框架中的不同行为 (Different Behavior in Different Frameworks)
如上一节所示,
einsum()
在不同框架中的行为可能存在差异。PyTorch和TensorFlow都有各自针对框架优化的
einsum()
实现,导致性能和某些行为细节的差异。
使用不同框架时,务必查阅相应文档了解具体实现细节。
另一个值得考虑的选择是einops库,它为张量操作提供了更一致、更用户友好的接口。该库在NumPy、PyTorch和TensorFlow之上构建,为三个框架提供统一接口。它支持与
einsum()
类似的语法,并提供额外功能,如模式匹配、打包和解包、多字母索引名等,还为PyTorch和TensorFlow实现了相应的层。
总结
恭喜您完成了这篇关于
einsum()
的全面指南!🎉 我们从基本语法入手,逐步探讨了高级应用,如线性Transformer的实现。我们还讨论了使用
einsum()
时的注意事项及其解决方案。
希望本文能帮助您更好地理解
einsum()
,并在实际工作中自信地应用这一强大工具。如有任何问题或建议,欢迎在评论区交流。您也可以分享自己的使用案例,促进社区共同学习!
https://avoid.overfit.cn/post/2c4838dc17614a43b57ac27b21747c37
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。