tfrun.test.cpp (1944B)
1 // Copyright 2019-2021 Mohammad-Reza Nabipoor 2 // SPDX-License-Identifier: Apache-2.0 3 4 #include "tfrun.hpp" 5 6 #define CATCH_CONFIG_MAIN 7 #include "catch2/catch.hpp" 8 9 #include <array> 10 #include <cassert> 11 12 #include "bmpread.hpp" 13 14 std::vector<float> 15 image(const std::vector<uint8_t>& in) 16 { 17 std::vector<float> out(in.size()); 18 19 for (auto i = 0u; i < in.size(); i++) 20 out[i] = in[i] / 128.0f - 1.0f; 21 return out; 22 } 23 24 TEST_CASE("default ctor") 25 { 26 tfrun::siso net; 27 28 REQUIRE_THROWS(net.run(nullptr, 0, nullptr, 0)); 29 } 30 31 TEST_CASE("MobileNet v2") 32 { 33 const char* mpath = getenv("MODEL_FILE"); 34 35 REQUIRE(mpath != nullptr); 36 37 tfrun::siso net{ mpath, "input", "MobilenetV2/Predictions/Reshape_1" }; 38 enum 39 { 40 ndims = 4, 41 nclasses = 1000 + 1 /*background*/, 42 nbatch = 1, 43 }; 44 std::array<int64_t, ndims> dims{ nbatch, 224, 224, 3 }; 45 int width; 46 int height; 47 int chan; 48 std::vector<float> img; 49 50 //--- beagle 51 52 img = image(bmpread("dog.bmp", &width, &height, &chan)); 53 assert(img.size() == nbatch * 224 * 224 * 3); 54 55 auto out = 56 net.run(dims.data(), ndims, img.data(), img.size() * sizeof(img[0])); 57 auto it = std::max_element(out.data.begin(), out.data.end()); 58 auto classIdx = it - out.data.begin(); 59 60 REQUIRE(out.dims == std::vector<int64_t>{ 61 nbatch, 62 nclasses, 63 }); 64 REQUIRE(classIdx == 1 + 162 /*beagle*/); 65 66 //--- panda 67 68 img = image(bmpread("panda.bmp", &width, &height, &chan)); 69 assert(img.size() == nbatch * 224 * 224 * 3); 70 71 out = net.run(dims.data(), ndims, img.data(), img.size() * sizeof(img[0])); 72 it = std::max_element(out.data.begin(), out.data.end()); 73 classIdx = it - out.data.begin(); 74 75 REQUIRE(out.dims == std::vector<int64_t>{ 76 nbatch, 77 nclasses, 78 }); 79 REQUIRE(classIdx == 1 + 80 388 /*giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca*/); 81 }