問題描述
要將經過訓練的網絡導入 C++,您需要導出網絡才能執行此操作.在搜索了很多并且幾乎沒有找到關于它的信息之后,澄清我們應該使用 freeze_graph() 能夠做到這一點.
For importing your trained network to the C++ you need to export your network to be able to do so. After searching a lot and finding almost no information about it, it was clarified that we should use freeze_graph() to be able to do it.
感謝 Tensorflow 的新 0.7 版本,他們添加了 文檔.
Thanks to the new 0.7 version of Tensorflow, they added documentation of it.
查看文檔后發現類似的方法很少,你能說一下freeze_graph()
和:tf.train.export_meta_graph
因為它有類似的參數,但它似乎也可以用于將模型導入 C++(我只是猜測不同之處在于,對于使用這種方法輸出的文件,您只能使用 import_graph_def()
還是別的什么?)
After looking into documentations, I found that there are few similar methods, can you tell what is the difference between freeze_graph()
and:
tf.train.export_meta_graph
as it has similar parameters, but it seems it can also be used for importing models to C++ (I just guess the difference is that for using the file output by this method you can only use import_graph_def()
or it's something else?)
還有一個關于如何使用 write_graph()
的問題:在文檔中,graph_def
由 sess.graph_def
給出,但在 freeze_graph()
的例子中,它是 sess.graph.as_graph_def()代碼>.這兩者有什么區別?
Also one question about how to use write_graph()
:
In documentations the graph_def
is given by sess.graph_def
but in examples in freeze_graph()
it is sess.graph.as_graph_def()
. What is the difference between these two?
這個問題與這個問題有關.
謝謝!
推薦答案
這是我利用 TF 0.12 中引入的 V2 檢查點的解決方案.
Here's my solution utilizing the V2 checkpoints introduced in TF 0.12.
無需將所有變量轉換為常量或凍結圖表.
There's no need to convert all variables to constants or freeze the graph.
為了清楚起見,我的目錄 models
中的 V2 檢查點如下所示:
Just for clarity, a V2 checkpoint looks like this in my directory models
:
checkpoint # some information on the name of the files in the checkpoint
my-model.data-00000-of-00001 # the saved weights
my-model.index # probably definition of data layout in the previous file
my-model.meta # protobuf of the graph (nodes and topology info)
Python 部分(保存)
with tf.Session() as sess:
tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')
如果您使用 tf.trainable_variables()
創建 Saver
,您可以節省一些頭痛和存儲空間.但也許一些更復雜的模型需要保存所有數據,然后將此參數刪除到 Saver
,只需確保您正在創建 Saver
after> 您的圖表已創建.給所有變量/層賦予唯一的名稱也是非常明智的,否則你可能會遇到不同的問題.
If you create the Saver
with tf.trainable_variables()
, you can save yourself some headache and storage space. But maybe some more complicated models need all data to be saved, then remove this argument to Saver
, just make sure you're creating the Saver
after your graph is created. It is also very wise to give all variables/layers unique names, otherwise you can run in different problems.
Python 部分(推理)
with tf.Session() as sess:
saver = tf.train.import_meta_graph('models/my-model.meta')
saver.restore(sess, tf.train.latest_checkpoint('models/'))
outputTensors = sess.run(outputOps, feed_dict=feedDict)
C++ 部分(推理)
請注意,checkpointPath
不是任何現有文件的路徑,只是它們的公共前綴.如果您錯誤地放置了 .index
文件的路徑,TF 不會告訴您這是錯誤的,但是由于未初始化的變量,它會在推理過程中死亡.
Note that checkpointPath
isn't a path to any of the existing files, just their common prefix. If you mistakenly put there path to the .index
file, TF won't tell you that was wrong, but it will die during inference due to uninitialized variables.
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
using namespace std;
using namespace tensorflow;
...
// set up your input paths
const string pathToGraph = "models/my-model.meta"
const string checkpointPath = "models/my-model";
...
auto session = NewSession(SessionOptions());
if (session == nullptr) {
throw runtime_error("Could not create Tensorflow session.");
}
Status status;
// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}
// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
throw runtime_error("Error creating graph: " + status.ToString());
}
// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
{{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
{},
{graph_def.saver_def().restore_op_name()},
nullptr);
if (!status.ok()) {
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}
// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);
這篇關于Tensorflow 在 C++ 中導出和運行圖的不同方式的文章就介紹到這了,希望我們推薦的答案對大家有所幫助,也希望大家多多支持html5模板網!