How can a single transpose be replaced with multiple transposes?

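A question for the experts here: how can I use several transposes in place of a single transpose?
The need comes from a hardware limitation of the DMA engine itself.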

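Take the following case, for example:
A 5-D tensor with shape = [4, 8, 16, 32, 64] and permutation = [4, 3, 0, 2, 1].
NumPy can transpose it to shape [64, 32, 4, 16, 8] in a single call, or it can reach the same [64, 32, 4, 16, 8] through three successive transposes.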

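The two results are equivalent, but the second one uses three transposes, and each of them is a transpose of at most three dimensions once axes that stay adjacent are merged, which makes it easier for the hardware to accept.
So is there a general way to split an arbitrary permutation into multiple transposes like this?
Sample code: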

import numpy as np

in_data = np.arange(0, 4 * 8 * 16 * 32 * 64).reshape(4, 8, 16, 32, 64)

perm = [4, 3, 0, 2, 1]
gold = np.transpose(in_data, perm)

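test = in_data
# Axis order after each step, expressed in terms of the original axes:
test = np.transpose(test, [3, 4, 0, 1, 2])  # -> [3, 4, 0, 1, 2]
test = np.transpose(test, [0, 1, 2, 4, 3])  # -> [3, 4, 0, 2, 1]
test = np.transpose(test, [1, 0, 2, 3, 4])  # -> [4, 3, 0, 2, 1]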

print(gold.shape)
print(test.shape)
np.testing.assert_array_equal(test, gold)
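
To make the "at most three dimensions" claim checkable, here is a minimal sketch that counts how many axis groups remain once axes that stay adjacent and in order are merged. The helper name `merged_rank` is made up for illustration, and whether the DMA counts dimensions this way is an assumption.

```python
import numpy as np

def merged_rank(perm):
    """Rank of a transpose after merging runs of consecutive source axes.

    Assumes perm is a non-empty permutation of range(len(perm)).
    """
    # A new group starts wherever the next axis is not "previous axis + 1".
    return 1 + sum(b != a + 1 for a, b in zip(perm, perm[1:]))

steps = [[3, 4, 0, 1, 2], [0, 1, 2, 4, 3], [1, 0, 2, 3, 4]]
ranks = [merged_rank(p) for p in steps]
direct_rank = merged_rank([4, 3, 0, 2, 1])
print(ranks, direct_rank)

# The first step really is a plain 2-D transpose of a reshaped tensor:
x = np.arange(4 * 8 * 16 * 32 * 64).reshape(4, 8, 16, 32, 64)
step_5d = np.transpose(x, [3, 4, 0, 1, 2])
step_2d = x.reshape(4 * 8 * 16, 32 * 64).T.reshape(32, 64, 4, 8, 16)
np.testing.assert_array_equal(step_5d, step_2d)
```

Under this way of counting, the three steps have ranks 2, 3 and 3, while the direct permutation keeps all five axes separate.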

1 Answer

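I rewrote it, give this version a try. It searches for a sequence of reversals of adjacent axes, each spanning at most max_transpose_dim axes: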

import numpy as np
from collections import deque

def apply_transpose(seq, transpose):
    # Reverse the half-open slice seq[i:j]; every other axis stays in place.
    i, j = transpose
    return seq[:i] + seq[i:j][::-1] + seq[j:]

def search_transpose_seq(target_perm, max_transpose_dim=3):
    # Breadth-first search starting from the identity permutation.
    # Each step reverses a run of 2..max_transpose_dim adjacent axes.
    # The visited dict guarantees termination (an unbounded recursive search
    # can loop forever), and BFS returns a shortest sequence.
    # Note: max_transpose_dim only bounds how many axes a step reverses;
    # whether a step fits the hardware limit depends on how the DMA counts
    # the untouched axes around it.
    n = len(target_perm)
    start, target = tuple(range(n)), tuple(target_perm)
    seen = {start: []}  # permutation -> list of (i, j) steps that reach it
    queue = deque([start])
    while queue:
        cur = queue.popleft()
        if cur == target:
            return seen[cur]
        for i in range(n - 1):
            for j in range(i + 2, min(i + max_transpose_dim, n) + 1):
                nxt = apply_transpose(cur, (i, j))
                if nxt not in seen:
                    seen[nxt] = seen[cur] + [(i, j)]
                    queue.append(nxt)
    return None  # target cannot be reached with the allowed steps

src_shape = [4, 8, 16, 32, 64]
permutation = [4, 3, 0, 2, 1]

transpose_seq = search_transpose_seq(permutation)
print(transpose_seq)  # list of (start, stop) slices to reverse, in order

in_data = np.arange(np.prod(src_shape)).reshape(src_shape)

gold = np.transpose(in_data, permutation)

test = in_data
ndim = len(permutation)
for i, j in transpose_seq:
    # Axis permutation that reverses axes i..j-1 and keeps the rest.
    axes = list(range(i)) + list(range(j - 1, i - 1, -1)) + list(range(j, ndim))
    test = np.transpose(test, axes)

print(gold.shape)  # prints: (64, 32, 4, 16, 8)
print(test.shape)  # prints: (64, 32, 4, 16, 8)
np.testing.assert_array_equal(test, gold)
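
Here is an alternative sketch, assuming the hardware limit applies to the rank left after merging axes that stay adjacent (which is how the three-step example in the question stays within three dimensions). It runs a breadth-first search in which every step may be any permutation whose merged rank is at most `max_rank`. The names `merged_rank` and `decompose` are made up for illustration, and the search enumerates all permutations, so it only suits small ranks.

```python
from collections import deque
from itertools import permutations

import numpy as np

def merged_rank(perm):
    # Rank after merging runs of consecutive source axes (perm non-empty).
    return 1 + sum(b != a + 1 for a, b in zip(perm, perm[1:]))

def decompose(target, max_rank=3):
    """Shortest list of axis permutations, each with merged rank <= max_rank,
    whose successive application equals np.transpose(x, target)."""
    n = len(target)
    start, target = tuple(range(n)), tuple(target)
    # Allowed single steps: every non-identity permutation the hardware accepts.
    steps = [p for p in permutations(range(n))
             if p != start and merged_rank(p) <= max_rank]
    parent = {start: None}  # state -> (previous state, step applied)
    queue = deque([start])
    while queue:
        cur = queue.popleft()
        if cur == target:
            seq = []
            while parent[cur] is not None:
                cur, step = parent[cur]
                seq.append(list(step))
            return seq[::-1]
        for p in steps:
            # Applying np.transpose(..., p) to a tensor whose axes are `cur`.
            nxt = tuple(cur[k] for k in p)
            if nxt not in parent:
                parent[nxt] = (cur, p)
                queue.append(nxt)
    return None  # not reachable, e.g. when max_rank is too small

perm = [4, 3, 0, 2, 1]
seq = decompose(perm)
print(seq)

# Check on a small tensor with distinct axis lengths.
x = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
y = x
for p in seq:
    y = np.transpose(y, p)
np.testing.assert_array_equal(y, np.transpose(x, perm))
```

Because the search is breadth-first, the result is never longer than the hand-written three-step sequence from the question, which is itself a valid path under this constraint.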