tfrun

An easy-to-use C++ wrapper over the stable C API of TensorFlow
git clone https://0xff.ir/g/tfrun.git
Log | Files | Refs | README | LICENSE

tfrun.hpp (870B)


      1 // Copyright 2019-2021 Mohammad-Reza Nabipoor
      2 // SPDX-License-Identifier: Apache-2.0
      3 
      4 #pragma once
      5 
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include <tensorflow/c/c_api.h>
     11 
     12 namespace tfrun {
     13 inline namespace v1 {
     14 
     15 #define TFRUN_API __attribute__((visibility("default")))
     16 
     17 // single input single output
     18 class TFRUN_API siso
     19 {
     20 public:
     21   siso() = default;
     22   siso(const std::string& model_pbfile,
     23        const std::string& input_name,
     24        const std::string& output_name);
     25   ~siso();
     26 
     27   struct output
     28   {
     29     std::vector<float> data;
     30     std::vector<int64_t> dims;
     31   };
     32 
     33   output run(const int64_t* dims, int num_dims, float* data, std::size_t len);
     34 
     35 private:
     36   TF_Session* sess_{ nullptr };
     37   TF_Output in_{ nullptr, 0 };
     38   TF_Output out_{ nullptr, 0 };
     39 
     40   void session_dtor() noexcept;
     41 };
     42 
     43 #undef TFRUN_API
     44 
     45 } // inline namespace v1
     46 } // namespace tfrun