This article was first published on my personal blog at https://kezunlin.me/post/5898412/; please visit it for the latest version!
How to load a model from a file and from a stream with Caffe and PyTorch (C++).
<!--more-->
## Guide

### caffe

#### load from file
```cpp
caffe::Phase phase = caffe::Phase::TEST;
std::string proto_filepath = "yolov3.prototxt";
std::string weight_filepath = "yolov3.caffemodel";

// construct the net from the prototxt, then load the trained weights
caffe::Net<float> net(proto_filepath, phase);
net.CopyTrainedLayersFrom(weight_filepath);
```
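Continuing from the snippet above, a minimal sketch of running inference with the loaded net. The input blob name `data` is an assumption and depends on your prototxt; the output layout depends on the model:

```cpp
// a minimal sketch, assuming the prototxt defines an input blob named "data"
boost::shared_ptr<caffe::Blob<float>> input = net.blob_by_name("data");
// fill input->mutable_cpu_data() with preprocessed image data here

net.Forward();  // run the forward pass

// output_blobs() holds the network outputs; their meaning depends on the model
const std::vector<caffe::Blob<float>*>& outputs = net.output_blobs();
const float* result = outputs[0]->cpu_data();
```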
#### load from stream

Caffe provides no method to load a model directly from a stream. We can override `ReadProtoFromTextFile` and `ReadProtoFromBinaryFile` in `src/caffe/util/io.cpp` to implement this function.
Replace
```cpp
// original implementation in src/caffe/util/io.cpp: parse directly from file descriptors
bool ReadProtoFromTextFile(const char* filename, Message* proto) {
  int fd = open(filename, O_RDONLY);
  CHECK_NE(fd, -1) << "File not found: " << filename;
  FileInputStream* input = new FileInputStream(fd);
  bool success = google::protobuf::TextFormat::Parse(input, proto);
  delete input;
  close(fd);
  return success;
}

bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
  int fd = open(filename, O_RDONLY);
  CHECK_NE(fd, -1) << "File not found: " << filename;
  ZeroCopyInputStream* raw_input = new FileInputStream(fd);
  CodedInputStream* coded_input = new CodedInputStream(raw_input);
  coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
  bool success = proto->ParseFromCodedStream(coded_input);
  delete coded_input;
  delete raw_input;
  close(fd);
  return success;
}
```
which reads `demo.prototxt` and `demo.caffemodel` directly from disk, with
```cpp
// new implementation: decrypt the file into memory, then parse from the stream.
// IstreamInputStream comes from <google/protobuf/io/zero_copy_stream_impl.h>;
// Encryption is our own helper that returns the decrypted content as a string.
bool ReadProtoFromTextFile(const char *filename, Message *proto) {
  Encryption encryption;
  string res = encryption.decryptTextFile(filename); // demo.prototxt
  istringstream ss(res);
  IstreamInputStream *input = new IstreamInputStream(&ss);
  bool success = google::protobuf::TextFormat::Parse(input, proto);
  delete input;
  return success;
}

bool ReadProtoFromBinaryFile(const char *filename, Message *proto) {
  Encryption encryption;
  string res = encryption.decryptModelFile(filename); // demo.caffemodel
  istringstream ss(res);
  IstreamInputStream *input = new IstreamInputStream(&ss);
  CodedInputStream *coded_input = new CodedInputStream(input);
  coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
  bool success = proto->ParseFromCodedStream(coded_input);
  delete coded_input;
  delete input;
  return success;
}
```
which loads from `demo_encrypt.prototxt` and `demo_encrypt.caffemodel` instead, decrypting them in memory before parsing.
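The `Encryption` class is project-specific and not shown here. Purely as an illustration, a hypothetical version of its interface could look like the sketch below; the method names match the calls above, but the XOR scheme and the helpers are assumptions, not the actual implementation:

```cpp
#include <fstream>
#include <sstream>
#include <string>

// hypothetical sketch of the Encryption helper used above;
// a real project would use a proper cipher instead of XOR.
class Encryption {
 public:
  std::string decryptTextFile(const std::string& filename) {
    return decrypt(readAll(filename));
  }
  std::string decryptModelFile(const std::string& filename) {
    return decrypt(readAll(filename));
  }

 private:
  static std::string readAll(const std::string& filename) {
    std::ifstream in(filename, std::ios::binary);
    std::ostringstream ss;
    ss << in.rdbuf();  // slurp the whole file into memory
    return ss.str();
  }
  static std::string decrypt(std::string data) {
    for (char& c : data) c ^= 0x5A;  // toy XOR "decryption", illustration only
    return data;
  }
};
```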
### pytorch

libtorch provides two overloads of `torch::jit::load`, one taking a file path and one taking an input stream:

```cpp
torch::jit::script::Module load(const std::string& filename, ...);
torch::jit::script::Module load(std::istream& in, ...);
```
#### load from file

```cpp
std::string model_path = "model.libpt";
// torch::jit::load throws c10::Error if the model cannot be loaded
torch::jit::script::Module net = torch::jit::load(model_path);
```
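Once loaded, the module can be run directly. A minimal sketch, assuming libtorch 1.2+ and a model traced with a single 1x3x224x224 float input (the file name, input shape, and dtype are assumptions; adjust to your model):

```cpp
#include <torch/script.h>
#include <iostream>
#include <vector>

int main() {
  torch::jit::script::Module net = torch::jit::load("model.libpt");

  // assumed input: one 1x3x224x224 float tensor
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 3, 224, 224}));

  at::Tensor output = net.forward(inputs).toTensor();
  std::cout << output.sizes() << std::endl;
  return 0;
}
```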
#### load from stream

```cpp
std::string model_content = ""; // fill with the serialized model bytes (e.g. read from file)
std::istringstream ss(model_content);
// as above, load throws c10::Error on failure
torch::jit::script::Module net = torch::jit::load(ss);
```
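The `model_content` placeholder above has to be filled with the serialized model bytes. A minimal sketch, assuming the model sits in `model.libpt` on disk; in practice this is also where an encrypted model would be decrypted before loading:

```cpp
#include <torch/script.h>
#include <fstream>
#include <sstream>
#include <string>

int main() {
  // read the serialized TorchScript model into memory (binary mode matters)
  std::ifstream file("model.libpt", std::ios::binary);
  std::stringstream buffer;
  buffer << file.rdbuf();
  std::string model_content = buffer.str();

  // optionally decrypt / transform model_content here before loading

  std::istringstream ss(model_content);
  torch::jit::script::Module net = torch::jit::load(ss);
  return 0;
}
```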
## History

- 20191014: created.

## Copyright

- Post author: kezunlin
- Post link: https://kezunlin.me/post/5898412/
- Copyright Notice: All articles in this blog are licensed under CC BY-NC-SA 3.0 unless stated otherwise.