请问为什么tf.py_function()中自定义的函数未被调用?

新手上路,请多包涵

TF2.0中,在最大池化层的源码当中使用了tf.py_function(),自定义了一个函数,想将Tensor转为Numpy矩阵从而进行一些操作,然而程序运行时,却将自定义的函数忽略掉了,自定义函数中的print()函数未作输出,返回值为<unknow>,如果py_function()中的参数写错了,系统还是会报错的,说明py_function()这个函数运行了,但是自定义的函数未运行,这是为什么呢,是缺少修饰器么。

def Myshow_all(self,inputs):
    def showTensor(inputs):
        a=inputs.numpy()
        print(a)
        return a
    y=tf.py_function(showTensor,[inputs],tf.float32)
    print(y.shape)
    return y
def _pooling_function(self,inputs,pool_size,strides,padding,data_format):
    output=K.pool2d(inputs,pool_size,strides,padding,data_format,pool_mode='max')
    a=self.Myshow_all(inputs)
    print(a)
    print(a.shape)
    return output
阅读 2.7k
1 个回答

感谢回答问题,问题已经解决。解决方法为在_pooling_function函数上加上@tf.funtion修饰器就可以用了,本人也是自己瞎尝试蒙出来的,个人猜测在静态图的情况下tf.py_function这个函数只能在,被@tf.function修饰器修饰的函数中调用,在Myshow_all上面加入@tf.function在这里是肯定不行的,因为showTensor是Myshow_all的内函数。如果自己搞学术研究的话还是用动态图吧

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进