本文首发于个人博客https://kezunlin.me/post/5898412/,欢迎阅读最新内容!

Load a model from a file or from an in-memory stream in Caffe and PyTorch.
<!--more-->

Guide

caffe

load from file

// Build the network from its architecture (prototxt) in inference mode,
// then copy the trained weights (caffemodel) into it.
const caffe::Phase phase = caffe::TEST;

std::string proto_filepath = "yolov3.prototxt";
std::string weight_filepath = "yolov3.caffemodel";
// caffe::Net is non-copyable, so construct it in place instead of
// copy-initialising it from a temporary.
caffe::Net<float> net(proto_filepath, phase);
net.CopyTrainedLayersFrom(weight_filepath);

load from stream

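Caffe provides no method to load a model directly from a stream.
To add this, replace the implementations of ReadProtoFromTextFile and ReadProtoFromBinaryFile in src/caffe/util/io.cpp. They are free functions, so they must be edited in place rather than overridden.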

Replace

// Stock implementation from src/caffe/util/io.cpp: parses the text-format
// protobuf stored in `filename` (e.g. a .prototxt) into `proto`.
// Aborts via CHECK if the file cannot be opened; otherwise returns whether
// parsing succeeded.
bool ReadProtoFromTextFile(const char* filename, Message* proto) {
  int fd = open(filename, O_RDONLY);
  CHECK_NE(fd, -1) << "File not found: " << filename;
  FileInputStream* input = new FileInputStream(fd);
  bool success = google::protobuf::TextFormat::Parse(input, proto);
  delete input;
  close(fd);
  return success;
}

bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
  int fd = open(filename, O_RDONLY);
  CHECK_NE(fd, -1) << "File not found: " << filename;
  ZeroCopyInputStream* raw_input = new FileInputStream(fd);
  CodedInputStream* coded_input = new CodedInputStream(raw_input);
  coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);

  bool success = proto->ParseFromCodedStream(coded_input);

  delete coded_input;
  delete raw_input;
  close(fd);
  return success;
}
(The stock version above loads the plain files demo.prototxt and demo.caffemodel.)

with

bool ReadProtoFromTextFile(const char *filename, Message *proto) {
    Encryption encryption;
    string res = encryption.decryptTextFile(filename); // demo.prototxt
    istringstream ss(res);

    IstreamInputStream *input = new IstreamInputStream(&ss);

    bool success = google::protobuf::TextFormat::Parse(input, proto);
    delete input;
    return success;
}


bool ReadProtoFromBinaryFile(const char *filename, Message *proto) {
    Encryption encryption;
    string res = encryption.decryptModelFile(filename); // demo.caffemodel
    istringstream ss(res);
    
    IstreamInputStream *input = new IstreamInputStream(&ss);
    CodedInputStream *coded_input = new CodedInputStream(input);
    coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);

    bool success = proto->ParseFromCodedStream(coded_input);

    delete coded_input;
    delete input;
    return success;
}
(This version decrypts demo_encrypt.prototxt and demo_encrypt.caffemodel in memory and then loads them.)

pytorch

  • torch::jit::script::Module load(const std::string& filename,...);
  • torch::jit::script::Module load(std::istream& in,...);

load from file

std::string model_path = "model.libpt";
// torch::jit::load returns the Module by value, not a pointer, so there is
// nothing to compare against nullptr. It throws c10::Error if the file is
// missing or is not a valid TorchScript archive; catch that to handle failure.
torch::jit::script::Module net = torch::jit::load(model_path);

load from stream

// Placeholder: fill with the raw TorchScript archive bytes (e.g. read from
// disk, or decrypted in memory). An empty buffer is not a valid archive, so
// loading from it will fail.
std::string model_content = "";
std::istringstream ss(model_content);
// load(std::istream&) also returns the Module by value and throws c10::Error
// on a malformed stream, so there is no pointer to check against nullptr.
torch::jit::script::Module net = torch::jit::load(ss);

Reference

History

  • 20191014: created.

Copyright


kezunlin
7 声望3 粉丝

C++,Python