如何将tensorflow中通过py_func自定义的操作用到keras中?

通过keras自定义层的时候,如果使用了tensorflow的py_func函数自定义的操作后,会导致在创建模型时提示AttributeError: 'NoneType' object has no attribute '_inbound_nodes'

测试代码是:

import tensorflow as tf
from keras import layers
from keras.layers import Layer
from keras.models import Model


class T(Layer):
    def __init__(self, **kwargs):
        super(T, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        return tf.zeros(shape=(1, 1))

    def compute_output_shape(self, input_shape):
        return 1, 1


def direct_return(tensor1, tensor2):
    return tensor1, tensor2


def main():
    input1 = layers.Input(shape=(2, 2))
    input2 = layers.Input(shape=(2, 2))

    ret1, ret2 = tf.py_func(direct_return, [input1, input2], [tf.float32, tf.float32])
    ret1.set_shape((2, 2))
    ret2.set_shape((2, 2))

    t1 = T()(ret1)
    t2 = T()(ret2)

    model = Model(inputs=[input1, input2], outputs=[t1, t2])


main()

报错为:

Using TensorFlow backend.
Traceback (most recent call last):
  File "XXX/test.py", line 36, in <module>
    main()
  File "XXX/test.py", line 33, in main
    model = Model(inputs=[input1, input2], outputs=[t1, t2])
  File "……\conda\envs\tensorflow\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 93, in __init__
    self._init_graph_network(*args, **kwargs)
  File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 237, in _init_graph_network
    self.inputs, self.outputs)
  File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 1353, in _map_graph_network
    tensor_index=tensor_index)
  File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 1340, in build_map
    node_index, tensor_index)
  File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 1312, in build_map
    node = layer._inbound_nodes[node_index]
AttributeError: 'NoneType' object has no attribute '_inbound_nodes'

如果将t1t2的两行修改为:

    t1 = T()(input1)
    t2 = T()(input2)

则一切正常。

那么请问应该如何使keras自定义的层包含自定义的操作呢?

阅读 6k
1 个回答
新手上路,请多包涵

同样遇到这个问题,想在keras调用tf的py_func 如果解决了,能加我q775301251 交流下吗?谢谢您!

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题