tfrun.cpp (3868B)
1 // Copyright 2019-2021 Mohammad-Reza Nabipoor 2 // SPDX-License-Identifier: Apache-2.0 3 4 #include "tfrun.hpp" 5 6 #include <cstdlib> 7 #include <cstring> 8 #include <memory> 9 10 #include "mio/mio.hpp" 11 #include "scope_guard.hpp" 12 13 using namespace std; 14 15 namespace tfrun { 16 inline namespace v1 { 17 18 siso::siso(const string& model, const string& in, const string& out) 19 { 20 auto status = TF_NewStatus(); 21 22 if (status == nullptr) 23 throw runtime_error{ "TF_NewStatus() failed" }; 24 25 SCOPE_EXIT { TF_DeleteStatus(status); }; 26 27 auto graph = TF_NewGraph(); 28 29 if (graph == nullptr) 30 throw runtime_error{ "TF_NewGraph() failed" }; 31 32 SCOPE_EXIT { TF_DeleteGraph(graph); }; 33 34 // Graph 35 { 36 // load model 37 auto buf = [&] { 38 mio::mmap_source m{ model }; 39 auto f = m.cbegin(); 40 auto l = m.cend(); 41 auto len = distance(f, l); 42 43 return TF_NewBufferFromString(&*f, len); 44 }(); 45 46 if (buf == nullptr) 47 throw runtime_error{ "TF_NewBuffer() failed" }; 48 49 SCOPE_EXIT { TF_DeleteBuffer(buf); }; 50 51 auto opts = TF_NewImportGraphDefOptions(); 52 53 if (opts == nullptr) 54 throw runtime_error{ "TF_NewImportGraphDefOptions() failed" }; 55 56 SCOPE_EXIT { TF_DeleteImportGraphDefOptions(opts); }; 57 58 TF_GraphImportGraphDef(graph, buf, opts, status); 59 60 if (TF_GetCode(status) != TF_OK) 61 throw runtime_error{ 62 // NOTE because of `+` operator, the message of `status` will be copied 63 "TF_GraphImportGraphDef() failed: "s + string{ TF_Message(status) } 64 }; 65 } 66 67 // Session 68 { 69 TF_SessionOptions* opts = TF_NewSessionOptions(); 70 71 SCOPE_EXIT { TF_DeleteSessionOptions(opts); }; 72 73 sess_ = TF_NewSession(graph, opts, status); 74 if (sess_ == nullptr) 75 throw runtime_error{ 76 // NOTE because of `+` operator, the message of `status` will be copied 77 "TF_NewSession() failed: "s + string{ TF_Message(status) } 78 }; 79 } 80 81 SCOPE_FAIL { session_dtor(); }; 82 83 // Input/Output 84 { 85 in_ = TF_Output{ TF_GraphOperationByName(graph, in.c_str()), 0 }; 86 87 if (in_.oper == nullptr) 88 throw runtime_error{ "Input operation '"s + in + "' not found"s }; 89 90 out_ = TF_Output{ TF_GraphOperationByName(graph, out.c_str()), 0 }; 91 92 if (out_.oper == nullptr) 93 throw runtime_error{ "Output operation '"s + out + "' not found"s }; 94 } 95 } 96 97 siso::output 98 siso::run(const int64_t* dims, int num_dims, float* data, std::size_t len) 99 { 100 if (sess_ == nullptr) 101 throw logic_error{ "invalid use" }; 102 103 auto status = TF_NewStatus(); 104 TF_Tensor* tout{ nullptr }; 105 auto tin = TF_NewTensor( 106 TF_FLOAT, dims, num_dims, data, len, [](void*, size_t, void*) {}, nullptr); 107 108 SCOPE_EXIT { TF_DeleteTensor(tin); }; 109 110 if (tin == nullptr) 111 throw runtime_error{ "TF_NewTensor() failed" }; 112 113 TF_SessionRun(sess_, 114 nullptr, 115 &in_, 116 &tin, 117 1, 118 &out_, 119 &tout, 120 1, 121 nullptr, 122 0, 123 nullptr, 124 status); 125 126 if (TF_GetCode(status) != TF_OK) 127 throw runtime_error{ "TF_Session() failed: "s + 128 string{ TF_Message(status) } }; 129 130 SCOPE_EXIT { TF_DeleteTensor(tout); }; 131 132 auto d{ static_cast<float*>(TF_TensorData(tout)) }; 133 output o{ 134 vector<float>{ d, d + TF_TensorByteSize(tout) / sizeof(float) }, 135 vector<int64_t>(TF_NumDims(tout)), 136 }; 137 138 { 139 const auto sz = o.dims.size(); 140 141 for (auto i = 0u; i < sz; i++) 142 o.dims[i] = TF_Dim(tout, i); 143 } 144 145 return o; 146 } 147 148 void 149 siso::session_dtor() noexcept 150 { 151 if (sess_ == nullptr) 152 return; 153 154 auto status = TF_NewStatus(); 155 156 // assert(status != nullptr); 157 158 SCOPE_EXIT { TF_DeleteStatus(status); }; 159 160 TF_CloseSession(sess_, status); 161 TF_DeleteSession(sess_, status); 162 163 sess_ = nullptr; 164 } 165 166 siso::~siso() 167 { 168 session_dtor(); 169 } 170 171 } // inline namespace v1 172 } // namespace siso