tfrun

An easy-to-use C++ wrapper over the stable C API of TensorFlow
git clone https://0xff.ir/g/tfrun.git
Log | Files | Refs | README | LICENSE

tfrun.test.cpp (1944B)


      1 // Copyright 2019-2021 Mohammad-Reza Nabipoor
      2 // SPDX-License-Identifier: Apache-2.0
      3 
      4 #include "tfrun.hpp"
      5 
      6 #define CATCH_CONFIG_MAIN
      7 #include "catch2/catch.hpp"
      8 
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <vector>
     11 
     12 #include "bmpread.hpp"
     13 
     14 std::vector<float>
     15 image(const std::vector<uint8_t>& in)
     16 {
     17   std::vector<float> out(in.size());
     18 
     19   for (auto i = 0u; i < in.size(); i++)
     20     out[i] = in[i] / 128.0f - 1.0f;
     21   return out;
     22 }
     23 
     24 TEST_CASE("default ctor")
     25 {
     26   tfrun::siso net;
     27 
     28   REQUIRE_THROWS(net.run(nullptr, 0, nullptr, 0));
     29 }
     30 
     31 TEST_CASE("MobileNet v2")
     32 {
     33   const char* mpath = getenv("MODEL_FILE");
     34 
     35   REQUIRE(mpath != nullptr);
     36 
     37   tfrun::siso net{ mpath, "input", "MobilenetV2/Predictions/Reshape_1" };
     38   enum
     39   {
     40     ndims = 4,
     41     nclasses = 1000 + 1 /*background*/,
     42     nbatch = 1,
     43   };
     44   std::array<int64_t, ndims> dims{ nbatch, 224, 224, 3 };
     45   int width;
     46   int height;
     47   int chan;
     48   std::vector<float> img;
     49 
     50   //--- beagle
     51 
     52   img = image(bmpread("dog.bmp", &width, &height, &chan));
     53   assert(img.size() == nbatch * 224 * 224 * 3);
     54 
     55   auto out =
     56     net.run(dims.data(), ndims, img.data(), img.size() * sizeof(img[0]));
     57   auto it = std::max_element(out.data.begin(), out.data.end());
     58   auto classIdx = it - out.data.begin();
     59 
     60   REQUIRE(out.dims == std::vector<int64_t>{
     61                         nbatch,
     62                         nclasses,
     63                       });
     64   REQUIRE(classIdx == 1 + 162 /*beagle*/);
     65 
     66   //--- panda
     67 
     68   img = image(bmpread("panda.bmp", &width, &height, &chan));
     69   assert(img.size() == nbatch * 224 * 224 * 3);
     70 
     71   out = net.run(dims.data(), ndims, img.data(), img.size() * sizeof(img[0]));
     72   it = std::max_element(out.data.begin(), out.data.end());
     73   classIdx = it - out.data.begin();
     74 
     75   REQUIRE(out.dims == std::vector<int64_t>{
     76                         nbatch,
     77                         nclasses,
     78                       });
     79   REQUIRE(classIdx == 1 +
     80     388 /*giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca*/);
     81 }