Varibales

tensorflow的variable可以通过2种方式获得:

w1 = tf.Variable(init_value, name = "w1", dtype = tf.float32)
w2 = tf.get_variable("w", initializer = initializer, shape = (2,3), dtype = tf.float32)

这两种方式的区别是,w1是声明变量,如果有重名的,则tf会自动给你编号。
w2则本质是从一个全局dict里取value的过程,它以name为key,去全局dict中查找指定的变量,如果有,则复用;没有则通过initializer创建。

需要注意的是,复用变量需要通过如下方式显式指定:

with tf.variable_scope("vs") as scope:
    w1 = tf.get_variable("w", shape = (2,3))
    scope.reuse_variables() # set reuse mode
    w2 = tf.get_variable("w", shape = (2,3))

另外,2种方式在使用name_scope和variable_scope时也有轻微不同:
tf.Variable总是会根据scope增加前缀。但是tf.get_variable只会在variable_scope的时候增加前缀


flybywind
1.1k 声望38 粉丝

引用和评论

0 条评论