import tensorflow as tf
x = [0.8, 0.4]
y = [0.5, 0.5]
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
compare_results = tf.greater(x, y).eval()
results = [1. if compare_result==True else 0. for compare_result in compare_results]
# 转化为tensorflow张亮
tf_results = tf.cast(results, tf.float32)
print(sess.run(tf_results))
目前需求如下,x是我的模型输出向量,y是一个阈值,对于x向量中的每一个元素,如果大于阈值(此处为0.5)则置为1 否则置0 但是tensorflow好像没有对应的遍历列表的功能,所以我用上述代码实现了,不知道各位大牛有没有原生的优雅解决方法呢