tfrun

An easy-to-use C++ wrapper over the stable C API of TensorFlow
git clone https://0xff.ir/g/tfrun.git
Log | Files | Refs | README | LICENSE

tfrun.cpp (3868B)


      1 // Copyright 2019-2021 Mohammad-Reza Nabipoor
      2 // SPDX-License-Identifier: Apache-2.0
      3 
      4 #include "tfrun.hpp"
      5 
      6 #include <cstdlib>
      7 #include <cstring>
      8 #include <memory>
      9 
     10 #include "mio/mio.hpp"
     11 #include "scope_guard.hpp"
     12 
     13 using namespace std;
     14 
     15 namespace tfrun {
     16 inline namespace v1 {
     17 
     18 siso::siso(const string& model, const string& in, const string& out)
     19 {
     20   auto status = TF_NewStatus();
     21 
     22   if (status == nullptr)
     23     throw runtime_error{ "TF_NewStatus() failed" };
     24 
     25   SCOPE_EXIT { TF_DeleteStatus(status); };
     26 
     27   auto graph = TF_NewGraph();
     28 
     29   if (graph == nullptr)
     30     throw runtime_error{ "TF_NewGraph() failed" };
     31 
     32   SCOPE_EXIT { TF_DeleteGraph(graph); };
     33 
     34   // Graph
     35   {
     36     // load model
     37     auto buf = [&] {
     38       mio::mmap_source m{ model };
     39       auto f = m.cbegin();
     40       auto l = m.cend();
     41       auto len = distance(f, l);
     42 
     43       return TF_NewBufferFromString(&*f, len);
     44     }();
     45 
     46     if (buf == nullptr)
     47       throw runtime_error{ "TF_NewBuffer() failed" };
     48 
     49     SCOPE_EXIT { TF_DeleteBuffer(buf); };
     50 
     51     auto opts = TF_NewImportGraphDefOptions();
     52 
     53     if (opts == nullptr)
     54       throw runtime_error{ "TF_NewImportGraphDefOptions() failed" };
     55 
     56     SCOPE_EXIT { TF_DeleteImportGraphDefOptions(opts); };
     57 
     58     TF_GraphImportGraphDef(graph, buf, opts, status);
     59 
     60     if (TF_GetCode(status) != TF_OK)
     61       throw runtime_error{
     62         // NOTE because of `+` operator, the message of `status` will be copied
     63         "TF_GraphImportGraphDef() failed: "s + string{ TF_Message(status) }
     64       };
     65   }
     66 
     67   // Session
     68   {
     69     TF_SessionOptions* opts = TF_NewSessionOptions();
     70 
     71     SCOPE_EXIT { TF_DeleteSessionOptions(opts); };
     72 
     73     sess_ = TF_NewSession(graph, opts, status);
     74     if (sess_ == nullptr)
     75       throw runtime_error{
     76         // NOTE because of `+` operator, the message of `status` will be copied
     77         "TF_NewSession() failed: "s + string{ TF_Message(status) }
     78       };
     79   }
     80 
     81   SCOPE_FAIL { session_dtor(); };
     82 
     83   // Input/Output
     84   {
     85     in_ = TF_Output{ TF_GraphOperationByName(graph, in.c_str()), 0 };
     86 
     87     if (in_.oper == nullptr)
     88       throw runtime_error{ "Input operation '"s + in + "' not found"s };
     89 
     90     out_ = TF_Output{ TF_GraphOperationByName(graph, out.c_str()), 0 };
     91 
     92     if (out_.oper == nullptr)
     93       throw runtime_error{ "Output operation '"s + out + "' not found"s };
     94   }
     95 }
     96 
     97 siso::output
     98 siso::run(const int64_t* dims, int num_dims, float* data, std::size_t len)
     99 {
    100   if (sess_ == nullptr)
    101     throw logic_error{ "invalid use" };
    102 
    103   auto status = TF_NewStatus();
    104   TF_Tensor* tout{ nullptr };
    105   auto tin = TF_NewTensor(
    106     TF_FLOAT, dims, num_dims, data, len, [](void*, size_t, void*) {}, nullptr);
    107 
    108   SCOPE_EXIT { TF_DeleteTensor(tin); };
    109 
    110   if (tin == nullptr)
    111     throw runtime_error{ "TF_NewTensor() failed" };
    112 
    113   TF_SessionRun(sess_,
    114                 nullptr,
    115                 &in_,
    116                 &tin,
    117                 1,
    118                 &out_,
    119                 &tout,
    120                 1,
    121                 nullptr,
    122                 0,
    123                 nullptr,
    124                 status);
    125 
    126   if (TF_GetCode(status) != TF_OK)
    127     throw runtime_error{ "TF_Session() failed: "s +
    128                          string{ TF_Message(status) } };
    129 
    130   SCOPE_EXIT { TF_DeleteTensor(tout); };
    131 
    132   auto d{ static_cast<float*>(TF_TensorData(tout)) };
    133   output o{
    134     vector<float>{ d, d + TF_TensorByteSize(tout) / sizeof(float) },
    135     vector<int64_t>(TF_NumDims(tout)),
    136   };
    137 
    138   {
    139     const auto sz = o.dims.size();
    140 
    141     for (auto i = 0u; i < sz; i++)
    142       o.dims[i] = TF_Dim(tout, i);
    143   }
    144 
    145   return o;
    146 }
    147 
    148 void
    149 siso::session_dtor() noexcept
    150 {
    151   if (sess_ == nullptr)
    152     return;
    153 
    154   auto status = TF_NewStatus();
    155 
    156   // assert(status != nullptr);
    157 
    158   SCOPE_EXIT { TF_DeleteStatus(status); };
    159 
    160   TF_CloseSession(sess_, status);
    161   TF_DeleteSession(sess_, status);
    162 
    163   sess_ = nullptr;
    164 }
    165 
    166 siso::~siso()
    167 {
    168   session_dtor();
    169 }
    170 
    171 } // inline namespace v1
    172 } // namespace tfrun