git clone https://0xff.ir/g/tfrun.git
tfrun
An easy-to-use C++ wrapper over the stable C API of TensorFlow.
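```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>
// ...plus the tfrun header and the image helpers used below.

void
tfrun_example(const char* model_path)
{
    // `model_path` is the path of a pre-trained MobileNetV2 network
    tfrun::siso net{ model_path, "input", "MobilenetV2/Predictions/Reshape_1" };

    enum
    {
        ndims = 4,
        nclasses = 1000 + 1 /*background*/,
        nbatch = 1,
    };
    // NHWC input shape: batch, height, width, channels
    std::array<int64_t, ndims> dims{ nbatch, 224, 224, 3 };

    int width;
    int height;
    int chan;
    std::vector<float> img;
    // `bmpread` and `image` are helpers not shown here: they load dog.bmp
    // and turn it into a flat buffer of floats
    img = image(bmpread("dog.bmp", &width, &height, &chan));
    assert(img.size() == nbatch * 224 * 224 * 3);

    // the last argument is the size of the input in bytes, not elements
    auto out =
      net.run(dims.data(), ndims, img.data(), img.size() * sizeof(img[0]));
    auto it = std::max_element(out.data.begin(), out.data.end());
    auto classIdx = it - out.data.begin();

    // extra parentheses: otherwise the comma in the braced list splits
    // the argument of the `assert` macro and the code does not compile
    assert((out.dims == std::vector<int64_t>{ nbatch, nclasses }));
    assert(classIdx == 1 + 162 /*beagle*/);
}
```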
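The byte count passed to `net.run` and the `1 + 162` check both follow from the shapes above. As a minimal sketch that does not depend on tfrun, the two hypothetical helpers below (`element_count` and `imagenet_class` are illustrative names, not part of the library) compute the number of elements of a shape and map the index of the top score back to a 0-based ImageNet class id. They assume, as the `nclasses` comment suggests, that index 0 of the 1001 outputs is the background class.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper: number of elements in a dense tensor with the
// given shape (the product of all dimensions).
std::int64_t element_count(const std::vector<std::int64_t>& dims)
{
    return std::accumulate(dims.begin(), dims.end(), std::int64_t{ 1 },
                           std::multiplies<std::int64_t>());
}

// Hypothetical helper: index of the highest score, shifted down by one so
// that the background class (index 0) maps to -1 and the remaining 1000
// scores map to 0-based ImageNet class ids. `scores` must be non-empty.
std::ptrdiff_t imagenet_class(const std::vector<float>& scores)
{
    assert(!scores.empty());
    auto it = std::max_element(scores.begin(), scores.end());
    return (it - scores.begin()) - 1;
}
```

For the `{ 1, 224, 224, 3 }` input, `element_count` gives 150528 floats, i.e. 602112 bytes of `float` data; a top score at output index 163 maps to class 162, the beagle in the example.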
Please download the pre-trained MobileNetV2 model, set the environment variable `MODEL_PATH` to the path of the downloaded model, and run the `test/tfrun.test` program.
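Since the model location is passed through `MODEL_PATH`, a program has to read that variable before it can build a `tfrun::siso`. The sketch below shows one way to do it; `model_path_from_env` is a hypothetical helper, not part of tfrun, and is not necessarily how `test/tfrun.test` handles the variable.

```cpp
#include <cstdlib>
#include <optional>
#include <string>

// Hypothetical helper (not part of tfrun): returns the value of MODEL_PATH,
// or std::nullopt when the variable is unset or empty.
std::optional<std::string> model_path_from_env()
{
    const char* p = std::getenv("MODEL_PATH");
    if (p == nullptr || *p == '\0')
        return std::nullopt;
    return std::string(p);
}
```

A caller could then invoke `tfrun_example(path->c_str())` when a path is present and report a missing `MODEL_PATH` otherwise, rather than passing a null pointer to the wrapper.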