如何在 Tensorflow-2.0 中绘制 tf.keras 模型?

新手上路,请多包涵

我升级到 Tensorflow 2.0,没有 tf.summary.FileWriter("tf_graphs", sess.graph) 。我正在查看关于此的其他一些 StackOverflow 问题,他们说使用 tf.compat.v1.summary etc 。在 Tensorflow 版本 2 中肯定有一种方法可以绘制和可视化 tf.keras 模型。它是什么?我正在寻找如下所示的张量板输出。谢谢!

在此处输入图像描述

原文由 Colin Steidtmann 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 551
2 个回答

根据 文档,您可以在训练模型后使用 Tensorboard 可视化图形。

首先,定义您的模型并运行它。然后,打开 Tensorboard 并切换到 Graph 选项卡。


最小可编译示例

这个例子取自文档。首先,定义您的模型和数据。

 # Relevant imports.
%load_ext tensorboard

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
from packaging import version

import tensorflow as tf
from tensorflow import keras

# Define the model.
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

(train_images, train_labels), _ = keras.datasets.fashion_mnist.load_data()
train_images = train_images / 255.0

接下来,训练您的模型。在这里,您需要为 Tensorboard 定义回调以用于可视化统计数据和图表。

 # Define the Keras TensorBoard callback.
logdir="logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

# Train the model.
model.fit(
    train_images,
    train_labels,
    batch_size=64,
    epochs=5,
    callbacks=[tensorboard_callback])

训练结束后,在你的笔记本中运行

%tensorboard --logdir logs

并切换到导航栏中的图表选项卡:

在此处输入图像描述

你会看到一个看起来很像这样的图表:

在此处输入图像描述

原文由 cs95 发布,翻译遵循 CC BY-SA 4.0 许可协议

您可以可视化任何 tf.function 修饰函数的图形,但首先,您必须跟踪其执行。

可视化 Keras 模型的图形意味着可视化它的 call 方法。

默认情况下,此方法未 tf.function 修饰,因此您必须将模型调用包装在正确修饰的函数中并执行它。

 import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)

@tf.function
def traceme(x):
    return model(x)

logdir = "log"
writer = tf.summary.create_file_writer(logdir)
tf.summary.trace_on(graph=True, profiler=True)
# Forward pass
traceme(tf.zeros((1, 28, 28, 1)))
with writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)

原文由 nessuno 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进