对于许多数据科学家和开发者而言,

einsum

通常被视为NumPy文档中那个晦涩难懂的高级函数——功能强大但难以理解。不过一旦掌握其基本原理,

einsum

将成为Python科学计算生态系统中处理多维数组运算最为优雅高效的工具之一。它不仅语法简洁,表达力强,而且在众多应用场景中通常比常规方法更具计算效率。

本文将全面介绍

einsum()

函数——其数学基础、实现原理以及实际应用场景。我们将深入剖析其符号系统,通过实用示例展示其功能,探讨性能优化策略,并提供一个完整的参考速查表辅助实际应用。

无论您是深度学习研究人员、数值计算工程师,还是追求代码简洁高效的Python开发者,本文都将帮助您充分理解并有效应用

einsum()

函数。

1、爱因斯坦求和约定基础

我们从一个经典案例入手:矩阵乘法。在线性代数中,两个矩阵AB相乘的标准定义是计算A中每一行与B中每一列的点积。图形化表示如下:

图1:标准矩阵乘法示意图。

注意观察,对于结果矩阵C中的每个元素,我们取A的第i行和B的第k列对应位置的元素相乘,然后对索引j求和。

这种元素级乘法和求和模式在张量运算中极为常见,尤其在阿尔伯特·爱因斯坦的广义相对论研究中。爱因斯坦为简化复杂张量表达式,提出了一种简洁表示法,省略显式的求和符号。其核心原则是:当一个索引在表达式中出现两次,则默认对该索引进行求和。

因此,对于矩阵乘法,我们可以表示为:

图2:使用爱因斯坦表示法的矩阵乘法。

在NumPy中(后文将讨论其他框架如PyTorch、TensorFlow等),

np.einsum()

函数允许直接应用这种表示法进行数组操作。下面是一个基本示例,展示了如何使用此函数进行矩阵乘法:

 importnumpyasnp  
   
 A=np.random.randn(4,3)  
 B=np.random.randn(3,3)  
   
 C=np.einsum("ij,jk->ik", A, B)

该操作等效于

np.dot(A, B)

A @ B

,但

einsum()

的真正价值在于它能够将这一原理推广到更复杂的多维张量运算,这正是本文将要深入探讨的内容。

关于索引的技术说明:j是求和索引(也称为哑索引)。它可以是任何字母,不会改变表达式的实际含义,故称为"哑"索引。它在多个张量中出现,但不出现在结果中。而索引i和k称为自由索引,它们不参与求和运算,在每个张量中仅出现一次,并保留在输出结果中。

接下来,我们将详细分析

einsum()

的语法结构。

2、理解基本的 einsum() 语法

einsum()函数的语法初见时可能显得复杂,但一旦理解其逻辑,您会发现它是一个极为强大的数组操作工具。以下分析主要基于NumPy实现,但这些概念同样适用于PyTorch和TensorFlow等其他科学计算库,只有少量实现细节的差异。

本文将重点关注显式语法,因为从工程实践角度看,显式语法更不易出错且更具可读性。所谓显式语法,是指在einsum字符串中明确指定输出张量的索引结构。

虽然隐式语法更为简洁,但它缺乏灵活性,且容易因索引字母的顺序问题导致错误。这主要是个人偏好问题,但建议初学者先掌握显式语法。如需讨论隐式语法的细节,可在评论区交流。

以下是我们在矩阵乘法示例中使用的einsum字符串结构:

图3:einsum()基本语法结构。

这张图清晰展示了einsum字符串的基本结构。现在让我们专注于理解这一字符串的核心含义。

在einsum中,输入和输出张量的每个维度都由一个索引字母表示。例如,3维张量需要3个字母,4维张量需要4个字母,依此类推。这些索引实质上代表着循环变量。在NumPy实现中,可以使用任何小写或大写字母作为索引,且在显式语法中,字母的顺序并不影响计算结果。

einsum()

执行的具体操作取决于索引在字符串中的出现位置和方式。主要有四种情况:

图4:显式

einsum()

语法规则。

在张量计算中,"收缩"(contraction)这一术语经常与"求和"一起使用,但两者有细微差别。求和专指对特定索引进行规约操作,例如对索引j求和。而收缩是一个更广泛的概念,它包含求和操作,但同时也表示张量维度的减少。例如,矩阵乘法可视为在一个维度上的收缩。因此,每个求和操作都构成一种收缩,但并非所有收缩都是简单的求和。

为了建立直观理解,我们来分析几个基本示例:

外积

外积是线性代数中的基本运算,它通过将第一个向量的每个元素与第二个向量的每个元素相乘,生成一个二维矩阵。

图5:

einsum()

表示的外积运算。

观察einsum字符串

"i,k->ik"

,我们发现输入中的所有索引都保留在输出中,因此没有执行求和操作。这表明第一个向量中的每个元素都与第二个向量中的每个元素相乘,形成结果矩阵。

批处理外积

批处理外积是外积的扩展应用,适用于同时计算多对向量的外积情况。

图6:

einsum()

表示的批处理外积。

在einsum字符串

"bi,bk->bik"

中,所有输入索引都保留在输出中,表明没有执行求和。这里的b索引出现在所有张量中,表示在批处理维度上执行外积运算。

提取对角线

从方阵提取对角线元素是一个典型的单输入操作示例。矩阵的对角线是指从左上角到右下角的元素集合。

图7:

einsum()

表示的对角线提取操作。

einsum字符串

"ii->i"

表明我们有一个二维输入和一维输出。由于使用了相同的索引i,我们并非对其求和,而是提取对角线元素。理解einsum字符串的最佳方式是分析其隐含的循环结构。

计算迹

如果要计算对角线元素的和(即矩阵的迹),可以使用相同的输入,但将einsum字符串改为

"ii->"


图8:

einsum()

表示的矩阵迹计算。

输出是一个标量(0维张量)。由于索引

i

出现在输入中但不在输出中,因此对其执行求和操作。

加权和

加权和是机器学习中的常见操作,特别是在线性回归和神经网络中。给定一个值矩阵和一个权重向量,可以执行如下计算:

图9:

einsum()

表示的加权和计算。

分析einsum字符串

"ij,j->i"

,索引

j

出现在两个输入中但不在输出中,因此对其求和。而索引

i

同时出现在输入和输出中,表明不对其求和。由于

i

在第一个输入的第一个轴上,可以推断输出是一个列向量,与图示一致。

简单转置

矩阵转置是线性代数中最基本的操作之一,它将矩阵的行和列互换。用

einsum()

实现如下:

图10:

einsum()

表示的矩阵转置操作。

这是一个没有收缩的单输入操作,因为所有索引都同时出现在输入和输出中,只是顺序发生了变化。

一步完成矩阵乘法和转置

einsum()

允许在一个操作中同时完成矩阵乘法和转置。以下示例展示了转置第二个输入的矩阵乘法:

图11:

einsum()

表示的带转置的矩阵乘法。

这里的收缩发生在索引

j

上,它出现在两个输入中但不在输出中。

图像中的轴交换(通道优先 vs. 通道置后)

广义上,转置只是轴交换的一种特例。在图像处理中,深度学习框架常用两种不同的数据格式:PyTorch使用通道优先格式(B,C,H,W),而TensorFlow使用通道置后格式(B,H,W,C)。在框架间转换时,可能需要调整图像的轴顺序:

图12:

einsum()

表示的图像数据格式转换。

这种操作的优势在于其明确性——einsum字符串直观地展示了维度的重排方式。由于所有索引都出现在输入和输出中,不执行任何收缩,仅改变轴的顺序。

值得注意的是,尽管看起来这些操作可能涉及大量数据复制,但

einsum()

通常会创建原始数组的视图,避免额外的内存分配,这是NumPy和

einsum()

的重要优化特性。

einsum的高级语法

基本语法已经展示了

einsum()

的强大功能,但这仅仅是开始。

einsum()

还提供了一些高级特性,使其在复杂计算场景中更加灵活高效。

本节将探讨

einsum()

的高级应用,特别是在深度学习中的应用,尤其是Transformer架构中的应用实例。

具体而言,我们将学习省略号(...)运算符的使用、如何处理多输入操作,以及如何通过

einsum_path()

优化

einsum()

的计算效率。

多头注意力机制概述

我们先简要回顾Transformer架构中的多头注意力机制。Transformer是一种序列到序列模型,由多个Transformer层堆叠而成,每层包含多头注意力和前馈神经网络(为简洁起见,此处省略了层归一化和残差连接)。其结构如下图所示:

图13:Transformer层的简化结构图。

注意多头注意力机制中,多个注意力头是并行计算的,每个头为每个批次计算独立的注意力分数。也就是说,注意力分数的计算在多个头和批次维度上进行广播。

用于批处理多头注意力的省略号 (…) 运算符

省略号运算符(...)是

einsum()

中一个强大功能,允许在不明确指定每个维度的情况下表示任意数量的维度。这在处理具有多个维度(如多头注意力中的头维度或批次维度)的操作时特别有用。

以一个简单例子说明:在多头注意力机制中,注意力分数通过计算查询矩阵Q和键矩阵K的矩阵乘积,并应用SoftMax函数得到。然后,这些注意力分数用于计算值矩阵V的加权和(为简化起见,此处省略了除以sqrt(d)的缩放步骤):

图14:注意力块内部结构(简化版)。

让我们关注图中绿色高亮的查询矩阵和键矩阵的乘法。常见表达式为

S=QK^T

,但这种表示方式有些简化,因为实际上我们需要为每个头和每个批次执行此计算,并且还需要转置键矩阵。使用

einsum()

,这些操作可以在一行代码中完成:

 Q=np.random.rand(batch_size, num_heads, head_dim, num_query_tokens)  
 K=np.random.rand(batch_size, num_heads, head_dim, num_key_tokens)  
   
 # 带转置的批处理多头矩阵乘法 (Batched Multi-Head Matrix Multiplication with transpose)
 S=np.einsum("bhdi,bhdj->bhij", Q, K) # (batch_size, num_heads, num_query_tokens, num_key_tokens)

这一操作不仅为每个头和批次执行矩阵乘法,还同时完成了键矩阵的转置。

在此示例中,我区分了

num_query_tokens

num_key_tokens

,因为虽然在自注意力机制中查询和键矩阵的token数量相同,但在交叉注意力场景中可能不同(如多模态系统中,文本提示可能只有少量token,而图像可能有成千上万个token)。

如果我们想将相同的函数应用于单头注意力或单个批次,可以使用省略号运算符:

 S=np.einsum("...di,...dj->...ij", Q, K)

这种表示方式不关心头和批次的具体数量,只关注最后两个维度的运算,非常灵活。我们可以用这个函数计算任意数量的头和批次的注意力分数:

图15:多头自注意力中任意批次和头数的矩阵乘法。

多输入与优化——推导线性 Transformer

到目前为止,我们主要使用了一个或两个输入的

einsum()

示例。但

einsum()

能够处理更多输入。为了展示多输入处理能力,我们将分析一个实际案例:推导Linear Transformer,这是一种计算复杂度与序列长度呈线性关系而非二次关系的Transformer变体。

我们需要理解为什么标准Transformer的复杂度与序列长度呈O(n²)关系。这主要源于注意力机制。回顾之前的例子,当我们计算查询矩阵和键矩阵的乘积时,得到一个形状为

(batch_size, num_heads, num_tokens, num_tokens)

的矩阵,即一个

num_tokens × num_tokens

的方阵。随着token数量的增加,计算量和存储需求呈二次增长。

这带来两个问题:计算复杂度高且内存需求大。线性Transformer采用核技巧来近似SoftMax函数:

图16:使用线性Transformer中的核技巧近似SoftMax函数。

对于本文而言,具体的

Φ(.)

函数定义不是重点,关键是理解它同时应用于查询和键矩阵。结构上,我们可以执行如下操作:

图17:为线性注意力重构的注意力块图。

这还不是最终结构,但它展示了如何使用

einsum()

在一步中计算包含三个输入的两个连续矩阵乘法:

图18:三矩阵乘法的Einsum表示法。

通过省略号运算符,该操作支持任意数量的头和批次,并且前两个输入还处理了必要的转置。但我们仍面临二次复杂度的问题。

解决方案核心在于利用矩阵乘法的结合律重新排列计算顺序:

图19:线性Transformer中应用矩阵乘法结合律。

仅通过改变矩阵乘法的顺序,计算复杂度从O(n²)降低到O(n)。这是线性Transformer的核心优化。但对于复杂的

einsum()

表达式,如何确定最优执行顺序?

这里可以使用

einsum_path()

函数,它能够为给定的einsum表达式确定最优收缩顺序。该函数支持"greedy"(贪婪)和"optimal"(最优)两种优化策略。贪婪策略计算速度快但可能不是最优,而最优策略计算较慢但保证最佳结果。

首先,定义输入并测量计算时间:

 importnumpyasnp  
B=8   # 批处理大小 (batch size)
H=16  # 头的数量 (number of heads)
D=64  # 特征维度 (feature dimension)
I=J=2048# 标记数量 (number of tokens)

Q_phi=np.random.rand(B, H, D, I) # 查询 (Query)
K_phi=np.random.rand(B, H, D, J) # 键 (Key)
V=np.random.rand(B, H, D, J) # 值 (Value)

%timeitnp.einsum("bhdi, bhdj, bhdj -> bhdi", Q_phi, K_phi, V)  
 # 每个循环 15.6 秒 ± 41.2 毫秒 (平均值 ± 标准差,7 次运行,每次 1 个循环) (15.6 s ± 41.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each))

使用

einsum_path()

寻找最优收缩顺序:

 path_standard=np.einsum_path("bhdi, bhdj, bhdj -> bhdi", Q_phi, K_phi, V, optimize=False)  
path_optimal=np.einsum_path("bhdi, bhdj, bhdj -> bhdi", Q_phi, K_phi, V, optimize="optimal")  

print(path_standard[0])  
# ['einsum_path', (0, 1, 2)]  

print(path_optimal[0])  
#['einsum_path', (1, 2), (0, 1)]  

print(path_optimal[1])  
#   Complete contraction:  bhdi,bhdj,bhdj->bhdi  
#          Naive scaling:  5  
#      Optimized scaling:  4  
#       Naive FLOP count:  1.031e+11  
#   Optimized FLOP count:  5.033e+07  
#    Theoretical speedup:  2048.000  
#   Largest intermediate:  1.678e+07 elements  
# --------------------------------------------------------------------------  
# scaling                  current                                remaining  
# --------------------------------------------------------------------------  
#    4              bhdj,bhdj->bhd                           bhdi,bhd->bhdi  
 #    4              bhd,bhdi->bhdi                               bhdi->bhdi

输出包含丰富信息,主要关注点如下:

  • 收缩顺序:标准收缩顺序(0,1,2)表示先将前两个矩阵相乘,再与第三个矩阵相乘;最优收缩顺序(1,2),(0,1)表示先将第二个和第三个矩阵相乘,再与第一个矩阵相乘。
  • 缩放因子:朴素缩放是einsum字符串中索引数量,优化缩放是最优收缩顺序中的索引数量。缩放值越低,计算越快。
  • FLOP计数:浮点运算次数,优化后的FLOP计数大幅减少。
  • 理论加速比:朴素FLOP计数与优化FLOP计数的比率,本例中为2048倍,恰好等于示例中的token数量,验证了线性Transformer的线性复杂度特性。

最后测量优化后的计算时间:

 %timeitnp.einsum("bhdi, bhdj, bhdj -> bhdi", Q_phi, K_phi, V, optimize=path_optimal[0])  
 # 每个循环 36.6 毫秒 ± 1.23 毫秒 (平均值 ± 标准差,7 次运行,每次 1 个循环) (36.6 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each))

计算时间从15.6秒降至36.6毫秒,加速比达426倍!实际应用中,矩阵规模更大时加速效果会更明显。

线性Transformer的更新结构如下:

图20:通过改变矩阵乘法收缩顺序实现线性Transformer。

需要注意的是,理论加速比在实际应用中并不总能完全实现,因为实际性能还受

einsum

具体实现和硬件配置影响。内存访问模式、缓存使用效率等因素也会影响性能。此外,某些计算模式可能允许使用优化的BLAS子程序,而其他模式则不行。因此,在应用优化前后务必进行性能测试,确认优化效果。

使用 einsum时的注意事项

隐式语法 (Implicit Syntax)

隐式语法是

einsum()

的简写形式,允许在不明确指定输出形状的情况下编写einsum字符串。虽然更简洁,但灵活性较低,且有一些需要注意的特殊情况,特别是索引字母顺序的重要性。

以下示例说明这一点:

 a=np.empty((10,20))  
   
 # 因为索引顺序为ba,形状从(10,20)变为(20,10)
 np.einsum("ba",a).shape  
 # (20, 10)

虽然在简单情况下这种表示方式较为直观,但在处理多输入操作时可能导致混淆:

 b=np.empty((20,30))  

# 矩阵乘法:
np.einsum("bc,cd",a,b).shape  
# (10, 30)  

# 将索引d改为a,会转置结果!
np.einsum("bc,ca",a,b).shape  
 # (30, 10)

除混淆风险外,隐式语法在功能上也存在局限。例如,无法提取矩阵对角线,因为重复索引会自动进行求和:

 c=np.ones((20,20))  
   
 np.einsum("ii",c)  
 # 输出: 20 (即对角线元素之和)

建议:坚持使用显式语法,它更具可读性且不易出错。

数据类型不会自动提升 (Data Types Not Being Promoted)

einsum()

不会自动提升数据类型,这在处理整数类型时尤其需要注意。如果计算结果超出数据类型范围,将导致溢出:

 a=np.ones(200, dtype=np.int8)  
   
 print(np.sum(a))  
 # 输出: 200
   
 print(np.einsum("i->",a))  
 # 输出: -56 (由于int8类型溢出)

由于

a

中元素和超过127(int8的最大值),结果发生溢出,返回负值。相比之下,

np.sum()

会自动将结果提升为更大的整数类型(此例中为np.int64)。

BLAS 子程序的使用 (Usage of BLAS Subroutines)

BLAS(基础线性代数子程序)是一系列高度优化的线性代数运算库。NumPy、PyTorch、TensorFlow等科学计算库通常使用这些库加速矩阵运算。

einsum()

在某些情况下可利用BLAS子程序优化计算,特别是矩阵乘法等标准线性代数运算。然而,并非所有

einsum

操作都能通过BLAS优化。例如,涉及多维广播的操作通常无法利用BLAS优化。

以下示例展示了BLAS优化的效果:

 B=1# 批处理 (batch)
D=512# 特征维度 (feature dimension)
I=J=4096# 标记数量 (number of tokens)

# 非批处理矩阵 (Non-batched matrices)
Q=  np.random.rand(D, I)  
K=  np.random.rand(D, J)  

# 批处理矩阵 (Batched matrices)
Qb=  np.random.rand(B, D, I)  
Kb=  np.random.rand(B, D, J)  

%timeitnp.einsum("di,dj->ij", Q, K, optimize=False) # -> 无BLAS优化
# 每个循环 5.22 秒 ± 56.7 毫秒 (平均值 ± 标准差,7 次运行,每次 1 个循环)

%timeitnp.einsum("di,dj->ij", Q, K, optimize=True) # -> 使用BLAS优化
# 每个循环 125 毫秒 ± 5.3 毫秒 (平均值 ± 标准差,7 次运行,每次 10 个循环)

%timeitnp.einsum("bdi,bdj->bij", Qb, Kb, optimize=True) # -> 批处理模式下无BLAS优化
 # 每个循环 5.28 秒 ± 72.3 毫秒 (平均值 ± 标准差,7 次运行,每次 1 个循环)

对于大型批处理矩阵,可以考虑使用GPU或支持批处理BLAS的库,如PyTorch:

 importtorch  

# PyTorch使用张量
Qbt=torch.tensor(Qb)  
Kbt=torch.tensor(Kb)  

%timeittorch.einsum("bdi,bdj->bij", Qbt, Kbt)  
 # 每个循环 76.3 毫秒 ± 2.15 毫秒 (平均值 ± 标准差,7 次运行,每次 10 个循环)

PyTorch的实现比NumPy中优化的

einsum()

快约两倍,这得益于其不同的实现方式。

不同框架中的不同行为 (Different Behavior in Different Frameworks)

如上一节所示,

einsum()

在不同框架中的行为可能存在差异。PyTorch和TensorFlow都有各自针对框架优化的

einsum()

实现,导致性能和某些行为细节的差异。

使用不同框架时,务必查阅相应文档了解具体实现细节。

另一个值得考虑的选择是einops库,它为张量操作提供了更一致、更用户友好的接口。该库在NumPy、PyTorch和TensorFlow之上构建,为三个框架提供统一接口。它支持与

einsum()

类似的语法,并提供额外功能,如模式匹配、打包和解包、多字母索引名等,还为PyTorch和TensorFlow实现了相应的层。

总结

恭喜您完成了这篇关于

einsum()

的全面指南!🎉 我们从基本语法入手,逐步探讨了高级应用,如线性Transformer的实现。我们还讨论了使用

einsum()

时的注意事项及其解决方案。

希望本文能帮助您更好地理解

einsum()

,并在实际工作中自信地应用这一强大工具。如有任何问题或建议,欢迎在评论区交流。您也可以分享自己的使用案例,促进社区共同学习!

https://avoid.overfit.cn/post/2c4838dc17614a43b57ac27b21747c37


deephub
125 声望111 粉丝