tfrun.hpp (870B)
1 // Copyright 2019-2021 Mohammad-Reza Nabipoor 2 // SPDX-License-Identifier: Apache-2.0 3 4 #pragma once 5 6 #include <cstdint> 7 #include <string> 8 #include <vector> 9 10 #include <tensorflow/c/c_api.h> 11 12 namespace tfrun { 13 inline namespace v1 { 14 15 #define TFRUN_API __attribute__((visibility("default"))) 16 17 // single input single output 18 class TFRUN_API siso 19 { 20 public: 21 siso() = default; 22 siso(const std::string& model_pbfile, 23 const std::string& input_name, 24 const std::string& output_name); 25 ~siso(); 26 27 struct output 28 { 29 std::vector<float> data; 30 std::vector<int64_t> dims; 31 }; 32 33 output run(const int64_t* dims, int num_dims, float* data, std::size_t len); 34 35 private: 36 TF_Session* sess_{ nullptr }; 37 TF_Output in_{ nullptr, 0 }; 38 TF_Output out_{ nullptr, 0 }; 39 40 void session_dtor() noexcept; 41 }; 42 43 #undef TFRUN_API 44 45 } // inline namespace v1 46 } // namespace tfrun