// Copyright 2017 Ninja Theory Ltd. All Rights Reserved.

#ifndef NTTENSORFLOW_NTTENSORFLOW_H_
#define NTTENSORFLOW_NTTENSORFLOW_H_

#ifdef _WIN32
#  ifdef NTTENSORFLOW_API_EXPORTS
#    define NTTENSORFLOW_EXPORT __declspec(dllexport)
#  else
#    define NTTENSORFLOW_EXPORT __declspec(dllimport)
#  endif
#else
#  define NTTENSORFLOW_EXPORT
#endif

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iterator>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

namespace NTTensorFlow
{

namespace internal
{
// Opaque private implementation types, defined by the library sources.
struct TfModelData;
struct TfRunnerData;

/**
 * Maps a user-facing value type ("complex" type) to the primitive type that
 * is handed to the implementation, and converts between the two.
 *
 * The primary template is an identity mapping: no conversion is needed and
 * both helpers return their argument unchanged.
 */
template<typename T>
struct ComplexTypeSupport
{
    using complex_type = T;
    using primitive_type = T;
    static constexpr bool needs_conversion = false;

    /** Identity conversion to the primitive representation. */
    static inline primitive_type ToPrimitive(const complex_type& value)
    {
        return value;
    }

    /** Identity conversion back to the user-facing representation. */
    static inline complex_type ToComplex(const primitive_type& value)
    {
        return value;
    }
};

/**
 * std::string values travel as pointers to their null-terminated character
 * data. Pointers produced by ToPrimitive borrow from the original string and
 * are only valid while that string is alive and unmodified.
 */
template<>
struct ComplexTypeSupport<std::string>
{
    using complex_type = std::string;
    using primitive_type = const complex_type::value_type*;
    static constexpr bool needs_conversion = true;

    /** Borrow the string's null-terminated character buffer. */
    static inline primitive_type ToPrimitive(const complex_type& value)
    {
        return value.c_str();
    }

    /** Copy a null-terminated character buffer into a new std::string. */
    static inline complex_type ToComplex(const primitive_type& value)
    {
        return complex_type(value);
    }
};
}

/**
 * A TensorFlow model.
 *
 * Instances are obtained through the static Load() and Read() functions,
 * which return a pointer to a new model. NOTE(review): ownership of the
 * returned pointer is defined by the implementation (presumably the caller
 * must delete it) — confirm.
 *
 * @see https://www.tensorflow.org
 */
class TfModel
{
private:

    /**
     * Private constructor.
     *
     * @param thisData TfModel private data.
     */
    TfModel(internal::TfModelData* thisData);

public:

    // Noncopyable.
    // Deleted functions have no symbol to export/import, and some compilers
    // (e.g. Clang) reject dllexport/dllimport on a deleted function, so
    // NTTENSORFLOW_EXPORT is intentionally not applied here.

    TfModel(const TfModel&) = delete;
    TfModel& operator=(const TfModel&) = delete;

    /** Move constructor. */
    NTTENSORFLOW_EXPORT TfModel(TfModel&& other);

    /** Move assignment. */
    NTTENSORFLOW_EXPORT TfModel& operator=(TfModel&& other);

    /** Destructor. */
    NTTENSORFLOW_EXPORT virtual ~TfModel();

    /**
     * Load a TensorFlow model from a file.
     *
     * Convenience overload forwarding to Load(const char*).
     *
     * @param filePath  Path of a TensorFlow GraphDef binary Protocol Buffers file.
     * @return The loaded model. NOTE(review): presumably nullptr on failure — confirm
     *         against the implementation.
     */
    static inline TfModel* Load(const std::string& filePath)
    {
        // c_str() makes the null-termination requirement of the C-string overload explicit.
        return Load(filePath.c_str());
    }

    /**
     * Load a TensorFlow model from a file.
     *
     * @param filePath  Null-terminated path of a TensorFlow GraphDef binary
     *                  Protocol Buffers file.
     * @return The loaded model. NOTE(review): presumably nullptr on failure — confirm
     *         against the implementation.
     */
    NTTENSORFLOW_EXPORT static TfModel* Load(const char* filePath);

    /**
     * Read a TensorFlow model from an external buffer.
     *
     * @param data  Data of a TensorFlow GraphDef binary Protocol Buffers serialization.
     * @param size  Size of the data in bytes.
     * @return The model read from the buffer. NOTE(review): presumably nullptr on
     *         failure — confirm against the implementation.
     */
    NTTENSORFLOW_EXPORT static TfModel* Read(const void* data, int size);

    /**
     * Write the model to an external buffer.
     *
     * ByteSize() reports the number of bytes the serialized model needs.
     *
     * @param data  Buffer where the model is written.
     * @param size  Size of the given buffer.
     * @return True if the model could be correctly written, false otherwise.
     */
    NTTENSORFLOW_EXPORT bool Write(void* data, int size) const;

    /**
     * @return The byte size of the model.
     */
    NTTENSORFLOW_EXPORT int ByteSize() const;

private:

    /** Private data. */
    internal::TfModelData* m_this;

    // Runners need access to the private model data.
    template<typename InputTypeT, typename StateTypeT>
    friend class TfRunner;
};

/**
 * Runner for TensorFlow models.
 *
 * The runner executes a TensorFlow GraphDef (loaded from a binary Protocol
 * Buffers file into a TfModel), lets the caller provide input(s) and retrieve
 * the output(s). It also supports stateful models such as RNNs.
 *
 * Every input variable must have the same type, as well as every state
 * variable. However, arbitrary independent output data types are supported.
 *
 * @see https://www.tensorflow.org
 *
 * @tparam InputTypeT Input variables type.
 * @tparam StateTypeT State variables type.
 */
template<typename InputTypeT = float, typename StateTypeT = InputTypeT>
class TfRunner
{
private:

    /**
     * Private constructor.
     *
     * @param thisData TfRunner private data.
     */
    TfRunner(internal::TfRunnerData* thisData);

public:

    /** Input variables type. */
    typedef InputTypeT InputType;

    /** State variables type. */
    typedef StateTypeT StateType;

    // Noncopyable

    TfRunner(TfRunner<InputTypeT, StateTypeT>& other) = delete;
    TfRunner& operator=(TfRunner<InputTypeT, StateTypeT>& other) = delete;

    /**
     * Move constructor.
     */
    TfRunner(TfRunner<InputTypeT, StateTypeT>&& other);

    /** Move assignment. */
    TfRunner& operator=(TfRunner<InputTypeT, StateTypeT>&& other);

    /** Destructor. */
    virtual ~TfRunner();

    /**
     * Create a runner for a TensorFlow model.
     *
     * The input, output and state nodes must be specified. If the model has no
     * state, states should be left empty.
     *
     * @param model         TensorFlow model.
     * @param inputs        Input node names.
     * @param outputs       Output node names.
     * @param states        State node names. Each pair in the vector must contain
     *                      the name of the state input first and the name of the
     *                      state output second.
     * @return New runner for the loaded model, or nullptr if the runner could not be created.
     */
    static TfRunner<InputTypeT, StateTypeT>* Create(
        const TfModel& model,
        const std::vector<std::string>& inputs,
        const std::vector<std::string>& outputs,
        const std::vector<std::pair<std::string, std::string>>& states =
            std::vector<std::pair<std::string, std::string>>())
    {
        // Build C-string arrays for the implementation. The pointers borrow
        // from the caller's strings and are only used during this call.
        std::vector<const char*> inputNames;
        for (const std::string& name : inputs)
        {
            inputNames.push_back(name.data());
        }

        std::vector<const char*> outputNames;
        for (const std::string& name : outputs)
        {
            outputNames.push_back(name.data());
        }

        // Split the (input, output) state pairs into two parallel arrays.
        std::vector<const char*> stateInNames;
        std::vector<const char*> stateOutNames;
        for (const auto& statePair : states)
        {
            stateInNames.push_back(statePair.first.data());
            stateOutNames.push_back(statePair.second.data());
        }

        return CreateImpl(model,
                          inputNames.data(), inputs.size(),
                          outputNames.data(), outputs.size(),
                          stateInNames.data(), stateOutNames.data(), states.size());
    }

    /**
     * Create a runner for a TensorFlow model, matching state nodes by prefix.
     *
     * The input, output and state nodes must be specified. If the model has no
     * state, stateInputPrefix and stateOutputPrefix should be empty strings.
     *
     * The state nodes must be given as prefixes of names for the actual nodes.
     * The model must contain exactly the same number of nodes with each prefix,
     * and for each node prefixed by the input prefix there must be a node
     * prefixed by the output prefix with the same suffix representing the output
     * of that state variable. For example, for a state variable with input node
     * "StateIn_1" and output node "StateOut_1", using "StateIn" and "StateOut" as
     * state input and output prefixes would match them properly. Empty suffixes
     * are valid too. This is useful for models where the exact number of state
     * variables is not known in advance.
     *
     * @param model             TensorFlow model.
     * @param inputs            Input node names.
     * @param outputs           Output node names.
     * @param stateInputPrefix  State input prefix.
     * @param stateOutputPrefix State output prefix.
     * @return New runner for the loaded model, or nullptr if the runner could not be created.
     */
    static TfRunner<InputTypeT, StateTypeT>* Create(
        const TfModel& model,
        const std::vector<std::string>& inputs,
        const std::vector<std::string>& outputs,
        const std::string& stateInputPrefix,
        const std::string& stateOutputPrefix)
    {
        // Build C-string arrays for the implementation. The pointers borrow
        // from the caller's strings and are only used during this call.
        std::vector<const char*> inputNames;
        for (const std::string& name : inputs)
        {
            inputNames.push_back(name.data());
        }

        std::vector<const char*> outputNames;
        for (const std::string& name : outputs)
        {
            outputNames.push_back(name.data());
        }

        return CreateImpl(model,
                          inputNames.data(), inputs.size(),
                          outputNames.data(), outputs.size(),
                          stateInputPrefix.data(), stateOutputPrefix.data());
    }

    /**
     * Create a runner for a TensorFlow model, addressing nodes through aliases.
     *
     * The input, output and state nodes must be specified. If the model has no
     * state, stateAliases should be left empty.
     *
     * Instead of using actual placeholder or variable names this function
     * expects "aliases". An alias is the name of an arbitrary node with one
     * input connected to the output of the real node (for example, an identity
     * node). This makes it easier to work with models generated by frameworks
     * that do not allow providing custom node names.
     *
     * @param model         TensorFlow model.
     * @param inputAliases  Aliases for the input nodes.
     * @param outputAliases Aliases for the output nodes.
     * @param stateAliases  Aliases for the state nodes. Each pair in the vector
     *                      must contain an alias of the state input first and an
     *                      alias of the state output second.
     * @return New runner for the loaded model, or nullptr if the runner could not be created.
     */
    static TfRunner<InputTypeT, StateTypeT>* CreateAliased(
        const TfModel& model,
        const std::vector<std::string>& inputAliases,
        const std::vector<std::string>& outputAliases,
        const std::vector<std::pair<std::string, std::string>>& stateAliases =
            std::vector<std::pair<std::string, std::string>>())
    {
        // Build C-string arrays for the implementation. The pointers borrow
        // from the caller's strings and are only used during this call.
        std::vector<const char*> inAliases;
        for (const std::string& alias : inputAliases)
        {
            inAliases.push_back(alias.data());
        }

        std::vector<const char*> outAliases;
        for (const std::string& alias : outputAliases)
        {
            outAliases.push_back(alias.data());
        }

        // Split the (input, output) state alias pairs into two parallel arrays.
        std::vector<const char*> stateInAliases;
        std::vector<const char*> stateOutAliases;
        for (const auto& aliasPair : stateAliases)
        {
            stateInAliases.push_back(aliasPair.first.data());
            stateOutAliases.push_back(aliasPair.second.data());
        }

        return CreateAliasedImpl(model,
                                 inAliases.data(), inputAliases.size(),
                                 outAliases.data(), outputAliases.size(),
                                 stateInAliases.data(), stateOutAliases.data(), stateAliases.size());
    }

    /**
     * Create a runner for a TensorFlow model, addressing nodes through aliases
     * and matching state aliases by prefix.
     *
     * The input, output and state nodes must be specified. If the model has no
     * state, stateInputAliasPrefix and stateOutputAliasPrefix should be empty
     * strings.
     *
     * Instead of using actual placeholder or variable names this function
     * expects "aliases". An alias is the name of an arbitrary node with one
     * input connected to the output of the real node (for example, an identity
     * node). This makes it easier to work with models generated by frameworks
     * that do not allow providing custom node names.
     *
     * The state nodes must be given as prefixes of aliases for the actual nodes.
     * The model must contain exactly the same number of nodes with each prefix,
     * and for each node prefixed by the input prefix there must be a node
     * prefixed by the output prefix with the same suffix representing the output
     * of that state variable. For example, for a state variable with input alias
     * "StateIn_1" and output alias "StateOut_1", using "StateIn" and "StateOut" as
     * state input and output alias prefixes would match them properly. Empty
     * suffixes are valid too. This is useful for models where the exact number of
     * state variables is not known in advance.
     *
     * @param model                  TensorFlow model.
     * @param inputAliases           Aliases for the input nodes.
     * @param outputAliases          Aliases for the output nodes.
     * @param stateInputAliasPrefix  State input alias prefix.
     * @param stateOutputAliasPrefix State output alias prefix.
     * @return New runner for the loaded model, or nullptr if the runner could not be created.
     */
    static TfRunner<InputTypeT, StateTypeT>* CreateAliased(
        const TfModel& model,
        const std::vector<std::string>& inputAliases,
        const std::vector<std::string>& outputAliases,
        const std::string& stateInputAliasPrefix,
        const std::string& stateOutputAliasPrefix)
    {
        // The implementation takes C-string arrays; the pointers borrow from
        // the caller's strings and are only used during this call.
        auto stringGetData = [](const std::string& s) { return s.data(); };
        std::vector<const char*> inputAliasesChar;
        inputAliasesChar.reserve(inputAliases.size());
        std::transform(inputAliases.begin(), inputAliases.end(),
                       std::back_inserter(inputAliasesChar), stringGetData);
        std::vector<const char*> outputAliasesChar;
        outputAliasesChar.reserve(outputAliases.size());
        std::transform(outputAliases.begin(), outputAliases.end(),
                       std::back_inserter(outputAliasesChar), stringGetData);
        // State aliases are passed as two prefixes, so no per-state arrays are needed.
        return CreateAliasedImpl(model, inputAliasesChar.data(), inputAliases.size(),
                                 outputAliasesChar.data(), outputAliases.size(),
                                 stateInputAliasPrefix.data(), stateOutputAliasPrefix.data());
    }

    /**
     * Reset the state of the model.
     *
     * For stateful models, this method resets the value of the state nodes to
     * zero. For non-stateful models, it does nothing to the model itself.
     * NumFeeds() counts feeds since the latest reset, so it restarts from here.
     */
    void Reset();

    /**
     * Feed an input to the model.
     *
     * Runs the model with the given input values and update its outputs. If the
     * model is stateful the state variables will get updated accordingly.
     *
     * The input values must be given as flattened vectors with the content of the
     * input tensors in the order specified when the model was loaded. Input vectors
     * must be flattened according to "row-major" storage, meaning that tensor
     * indices are iterated from the last to the first dimension.
     *
     * @param inputValues     Fed input values.
     * @return true if the input was fed correctly, false otherwise.
     */
    bool Feed(const std::vector<std::vector<InputType>>& inputValues)
    {
        typedef typename std::vector<InputType> InputVector;
        std::vector<const InputType*> inputValuePtrs;
        std::transform(inputValues.begin(), inputValues.end(),
                       std::back_inserter(inputValuePtrs),
                       [](const std::vector<InputType>& v) { return v.data(); });
        std::vector<size_t> inputValueSizes;
        std::transform(inputValues.begin(), inputValues.end(),
                       std::back_inserter(inputValueSizes),
                       [](const std::vector<InputType>& v) { return v.size(); });
        return Feed(inputValuePtrs.data(), inputValueSizes.data(), inputValues.size());
    }

    /**
     * Feed an input to the model.
     *
     * Runs the model with the given input values and update its outputs. If the
     * model is stateful the state variables will get updated accordingly.
     *
     * The input values must be given as flattened vectors with the content of the
     * input tensors in the order specified when the model was loaded. Input vectors
     * must be flattened according to "row-major" storage, meaning that tensor
     * indices are iterated from the last to the first dimension.
     *
     * @param inputValues      Fed input values.
     * @param inputValueSizes  Size of the input values.
     * @param numInputValues   Number of fed input values.
     * @return true if the input was fed correctly, false otherwise.
     */
    bool Feed(const InputType* const* inputValues, const size_t* inputValueSizes, uint32_t numInputValues)
    {
        if (numInputValues != NumInputs())
        {
            return false;
        }
        typedef typename internal::ComplexTypeSupport<InputType>::primitive_type PrimitiveType;
        if (internal::ComplexTypeSupport<InputType>::needs_conversion)
        {
          std::vector<std::vector<PrimitiveType>> inputPrimitiveValues;
          std::vector<const PrimitiveType*> inputPrimitiveValuesPtr;
          for (uint32_t i = 0; i < numInputValues; i++)
          {
              std::vector<PrimitiveType> convertedValues;
              std::transform(inputValues[i], inputValues[i] + inputValueSizes[i],
                             std::back_inserter(convertedValues), internal::ComplexTypeSupport<InputType>::ToPrimitive);
              inputPrimitiveValues.push_back(std::move(convertedValues));
              inputPrimitiveValuesPtr.push_back(inputPrimitiveValues.back().data());
          }
            return FeedImpl(inputPrimitiveValuesPtr.data(), inputValueSizes, numInputValues);
        }
        else
        {
            return FeedImpl(reinterpret_cast<const PrimitiveType* const*>(inputValues), inputValueSizes, numInputValues);
        }
    }

    /**
     * The number of inputs fed to the model after the latest reset of the runner,
     * or after its creation if the runner has not been reset.
     *
     * Outputs should not be retrieved while this value is zero.
     *
     * @return The number of inputs fed after the latest runner reset or creation.
     */
    uint32_t NumFeeds() const;

    /**
     * The number of inputs in the model.
     *
     * Feed() returns false when given a number of input values different from
     * this value.
     *
     * @return The number of inputs.
     */
    uint32_t NumInputs() const;

    /**
     * The number of outputs in the model.
     *
     * @return The number of outputs.
     */
    uint32_t NumOutputs() const;

    /**
     * Retrieve an output from the model as a new vector.
     *
     * At least one input must have been fed to the model after creating it or
     * resetting it before calling this method.
     *
     * The result holds the contents of the output tensor flattened according to
     * "row-major" storage, meaning that tensor indices are iterated from the
     * last to the first dimension.
     *
     * @tparam OutputTypeT  Output variable type. Must match the type declared in the model.
     * @param outputIndex   Index of the desired output, as indicated by the position
     *                       of its name or alias in the outputs vector when the model was
     *                       loaded.
     * @return The value of the output as a flattened vector.
     */
    template<typename OutputTypeT>
    inline std::vector<OutputTypeT> GetOutput(uint32_t outputIndex) const
    {
        // Delegate to the out-parameter overload and hand back its result.
        std::vector<OutputTypeT> result;
        GetOutput<OutputTypeT>(outputIndex, result);
        return result;
    }

    /**
     * Retrieve an output from the model.
     *
     * At least one input must have been fed to the model after creating it or
     * resetting it before calling this method.
     *
     * The result is a vector with the contents of the output tensor
     * flattened according to "row-major" storage, meaning that tensor indices
     * are iterated from the last to the first dimension.
     *
     * Retrieval failure is asserted in debug builds. On success the vector is
     * resized to the number of elements actually retrieved.
     *
     * @tparam OutputTypeT  Output variable type. Must match the type declared in the model.
     * @param outputIndex   Index of the desired output, as indicated by the position
     *                       of its name or alias in the outputs vector when the model was
     *                       loaded.
     * @param[out] output   The value of the output as a flattened vector.
     */
    template<typename OutputTypeT>
    inline void GetOutput(uint32_t outputIndex, std::vector<OutputTypeT>& output) const
    {
        output.resize(GetOutputSize(outputIndex));
        size_t size = output.size();
        const bool ok = GetOutput(outputIndex, output.data(), size);
        assert(ok);
        (void)ok;  // Keep NDEBUG builds free of an unused-variable warning.
        if (ok)
        {
            // The pointer overload updates size to the retrieved element count.
            output.resize(size);
        }
    }

    /**
     * Retrieve an output from the model.
     *
     * At least one input must have been fed to the model after creating it or
     * resetting it before calling this method.
     *
     * The return value is a vector with the contents of the output tensor
     * flattened according to "row-major" storage, meaning that tensor indices
     * are iterated from the last to the first dimension.
     *
     * @tparam OutputTypeT  Output variable type. Must match the type declared in the model.
     * @param outputIndex   Index of the desired output, as indicated by the position
     *                       of its name or alias in the outputs vector when the model was
     *                       loaded.
     * @param[out] output  The value of the output as a flattened vector.
     * @param[inout] size  The size of the pointed memory area on input, set to size of
     *                      the retrieved output on return.
     * @return  true if the output value could be copied to the memory area, false otherwise.
     */
    template<typename OutputTypeT>
    inline bool GetOutput(uint32_t outputIndex, OutputTypeT* output, size_t& size) const
    {
        typedef internal::ComplexTypeSupport<OutputTypeT>::primitive_type PrimitiveType;
        if (internal::ComplexTypeSupport<OutputTypeT>::needs_conversion)
        {
            std::vector<PrimitiveType> outputPrimitive(size);
            auto ok = GetOutputImpl<OutputTypeT>(outputIndex, outputPrimitive.data(), size);
            if (!ok) return false;
            std::transform(outputPrimitive.begin(), outputPrimitive.begin() + size,
                           output, internal::ComplexTypeSupport<OutputTypeT>::ToComplex);
            return true;
        }
        else
        {
            return GetOutputImpl<OutputTypeT>(outputIndex, reinterpret_cast<PrimitiveType*>(output), size);
        }
    }

    /**
     * The number of state variables in the model.
     *
     * This is useful if the runner was created using state prefixes, where the
     * number of state variables is not known in advance, but state inspection
     * is needed.
     *
     * @return The number of state variables.
     */
    uint32_t NumStates() const;

    /**
     * Retrieve an state variable from the model.
     *
     * This method can be called at any moment, although do note that the returned
     * value will be all zeros if called right after construction or a reset.
     *
     * The return value is a vector with the contents of the state tensor
     * flattened according to "row-major" storage, meaning that tensor indices
     * are iterated from the last to the first dimension.
     *
     * @param stateIndex Index of the desired state, as indicated by the position
     *                   of its name or alias in the states vector when the model
     *                   was loaded. If the model was loaded using state prefixes
     *                   make sure the index is smaller than the total number of
     *                   states.
     * @return The value of the output as a flattened vector.
     */
    inline std::vector<StateType> GetState(uint32_t stateIndex) const
    {
        std::vector<StateType> state;
        auto size = state.size();
        GetState(stateIndex, state, size);
        return state;
    }

    /**
     * Retrieve an state variable from the model.
     *
     * This method can be called at any moment, although do note that the returned
     * value will be all zeros if called right after construction or a reset.
     *
     * The return value is a vector with the contents of the state tensor
     * flattened according to "row-major" storage, meaning that tensor indices
     * are iterated from the last to the first dimension.
     *
     * @param stateIndex Index of the desired state, as indicated by the position
     *                   of its name or alias in the states vector when the model
     *                   was loaded. If the model was loaded using state prefixes
     *                   make sure the index is smaller than the total number of
     *                   states.
     * @param[out] state The value of the output as a flattened vector.
     */
    inline void GetState(uint32_t stateIndex, std::vector<StateType>& state) const
    {
        state.resize(GetStateSize(stateIndex));
        GetState(stateIndex, state.data());
    }

    /**
     * Retrieve an state variable from the model.
     *
     * This method can be called at any moment, although do note that the returned
     * value will be all zeros if called right after construction or a reset.
     *
     * The return value is a vector with the contents of the state tensor
     * flattened according to "row-major" storage, meaning that tensor indices
     * are iterated from the last to the first dimension.
     *
     * @param stateIndex Index of the desired state, as indicated by the position
     *                   of its name or alias in the states vector when the model
     *                   was loaded. If the model was loaded using state prefixes
     *                   make sure the index is smaller than the total number of
     *                   states.
     * @param[out] state The value of the output as a flattened vector.
     * @param[inout] size  The size of the pointed memory area on input, set to size of
     *                      the retrieved state on return.
     * @return  true if the state value could be copied to the memory area, false otherwise.
     */
    inline bool GetState(uint32_t stateIndex, StateType* state, size_t& size) const
    {
        typedef internal::ComplexTypeSupport<StateType>::primitive_type PrimitiveType;
        if (internal::ComplexTypeSupport<StateType>::needs_conversion)
        {
            std::vector<PrimitiveType> statePrimitive(size);
            auto ok = GetStateImpl(stateIndex, statePrimitive.data());
            if (!ok) return false;
            std::transform(statePrimitive.begin(), statePrimitive.begin() + size,
                           state, internal::ComplexTypeSupport<StateType>::ToComplex);
            return true;
        }
        else
        {
            return GetStateImpl(stateIndex, reinterpret_cast<PrimitiveType*>(state), size);
        }
    }

    /**
     * Get the rank (number of dimensions) of an input variable.
     *
     * This method can be called at any moment.
     *
     * @param inputIndex Index of the desired input rank, as indicated by the
     *                   position of its name or alias in the inputs vector
     *                   when the model was loaded.
     * @return The rank of the input variable (the number of entries returned
     *         by GetInputShape for the same index).
     */
    uint32_t GetInputRank(uint32_t inputIndex) const;

    /**
     * Get the shape of an input variable.
     *
     * The shape of a variable can guide the construction of flattened input vectors.
     *
     * This method can be called at any moment.
     *
     * @param inputIndex Index of the desired input shape, as indicated by the
     *                   position of its name or alias in the inputs vector
     *                   when the model was loaded.
     * @return The shape of the input variable, expressed as a vector indicating
     *         the number of elements in each tensor dimension.
     */
    std::vector<size_t> GetInputShape(uint32_t inputIndex) const
    {
        std::vector<size_t> shape(GetInputRank(inputIndex));
        for (uint32_t i = 0; i < shape.size(); i++)
        {
            shape[i] = GeInputDimensionSize(i);
        }
        return shape;
    }

    /**
     * Get the size of an input variable.
     *
     * The size of a variable can guide the construction of flattened input vectors.
     *
     * This method can be called at any moment.
     *
     * @param inputIndex Index of the desired input size, as indicated by the
     *                   position of its name or alias in the inputs vector
     *                   when the model was loaded.
     * @return The size of the input variable.
     */
    size_t GetInputSize(uint32_t inputIndex) const;

    /**
     * Get the size of one dimension of an input variable.
     *
     * This method can be called at any moment.
     *
     * @param inputIndex  Index of the desired input, as indicated by the
     *                     position of its name or alias in the inputs vector
     *                     when the model was loaded.
     * @param dimension   Dimension whose size is requested; presumably must be
     *                     smaller than GetInputRank(inputIndex).
     * @return The size of the given dimension in the input variable.
     */
    size_t GetInputDimensionSize(uint32_t inputIndex, uint32_t dimension) const;

    /**
     * Get the rank (number of dimensions) of a state variable.
     *
     * This method can be called at any moment.
     *
     * @param stateIndex Index of the desired state rank, as indicated by the
     *                   position of its name or alias in the states vector
     *                   when the model was loaded. If the model was loaded using
     *                   state prefixes make sure the index is smaller than the
     *                   total number of states (see NumStates()).
     * @return The rank of the state variable.
     */
    uint32_t GetStateRank(uint32_t stateIndex) const;

    /**
     * Get the shape of a state variable.
     *
     * The shape of a variable can guide the processing of flattened state vectors.
     *
     * This method can be called at any moment.
     *
     * @param stateIndex Index of the desired state shape, as indicated by the
     *                   position of its name or alias in the states vector
     *                   when the model was loaded. If the model was loaded using
     *                   state prefixes make sure the index is smaller than the
     *                   total number of states.
     * @return The shape of the state variable, expressed as a vector indicating
     *         the number of elements in each tensor dimension.
     */
    inline std::vector<size_t> GetStateShape(uint32_t stateIndex) const
    {
        std::vector<size_t> shape(GetStateRank(stateIndex));
        for (uint32_t i = 0; i < shape.size(); i++)
        {
            // Query dimension i of the requested state (not of state i).
            shape[i] = GetStateDimensionSize(stateIndex, i);
        }
        return shape;
    }

    /**
     * Get the size of a state variable.
     *
     * The size of a variable can guide the processing of flattened state vectors.
     *
     * This method can be called at any moment.
     *
     * @param stateIndex Index of the desired state size, as indicated by the
     *                   position of its name or alias in the states vector
     *                   when the model was loaded. If the model was loaded using
     *                   state prefixes make sure the index is smaller than the
     *                   total number of states.
     * @return The size of the state variable.
     */
    size_t GetStateSize(uint32_t stateIndex) const;

    /**
     * Get the size of one dimension of a state variable.
     *
     * This method can be called at any moment.
     *
     * @param stateIndex  Index of the desired state, as indicated by the
     *                     position of its name or alias in the states vector
     *                     when the model was loaded. If the model was loaded using
     *                     state prefixes make sure the index is smaller than the
     *                     total number of states.
     * @param dimension   Dimension whose size is requested; presumably must be
     *                     smaller than GetStateRank(stateIndex).
     * @return The size of the given dimension in the state variable.
     */
    size_t GetStateDimensionSize(uint32_t stateIndex, uint32_t dimension) const;

    /**
     * Get the rank (number of dimensions) of an output variable.
     *
     * At least one input must have been fed to the model after creating it or
     * resetting it before calling this method.
     *
     * @param outputIndex  Index of the desired output, as indicated by the
     *                     position of its name or alias in the outputs vector
     *                     when the model was loaded.
     * @return The rank of the output variable.
     */
    uint32_t GetOutputRank(uint32_t outputIndex) const;

    /**
     * Get the shape of an output variable.
     *
     * The shape of a variable can guide the processing of flattened output vectors.
     *
     * At least one input must have been fed to the model after creating it or
     * resetting it before calling this method.
     *
     * @param outputIndex  Index of the desired output, as indicated by the
     *                     position of its name or alias in the outputs vector
     *                     when the model was loaded.
     * @return The shape of the output variable, expressed as a vector holding
     *         the number of elements in each tensor dimension, from the first
     *         (index 0) to the last dimension.
     */
    std::vector<size_t> GetOutputShape(uint32_t outputIndex) const
    {
        // One entry per dimension of the requested output variable.
        std::vector<size_t> shape(GetOutputRank(outputIndex));
        for (uint32_t dimension = 0; dimension < shape.size(); ++dimension)
        {
            // The size lookup needs both indices: which output variable, and
            // which (zero-based) dimension of it.
            shape[dimension] = GetOutputDimensionSize(outputIndex, dimension);
        }
        return shape;
    }

    /**
     * Get the size of an output variable.
     *
     * The size of a variable can guide the construction of flattened output vectors.
     *
     * At least one input must have been fed to the model after creating it or
     * resetting it before calling this method.
     *
     * NOTE(review): "size" presumably means the total number of elements, i.e.
     * the product of the entries returned by GetOutputShape — confirm.
     *
     * @param outputIndex  Index of the desired output, as indicated by the
     *                     position of its name or alias in the outputs vector
     *                     when the model was loaded.
     * @return The size of the output variable.
     */
    size_t GetOutputSize(uint32_t outputIndex) const;

    /**
     * Get the size of one dimension of an output variable.
     *
     * At least one input must have been fed to the model after creating it or
     * resetting it before calling this method.
     *
     * @param outputIndex  Index of the desired output, as indicated by the
     *                     position of its name or alias in the outputs vector
     *                     when the model was loaded.
     * @param dimension    Index of the dimension whose size is requested
     *                     (presumably zero-based and smaller than the output's
     *                     rank as returned by GetOutputRank — confirm).
     * @return The size of the given dimension in the output variable.
     */
    size_t GetOutputDimensionSize(uint32_t outputIndex, uint32_t dimension) const;

private:

    /**
     * Opaque pointer to the runner's private implementation data.
     * internal::TfRunnerData is only forward-declared in this header; its
     * definition lives with the out-of-line implementation.
     */
    internal::TfRunnerData* m_this;

    /**
     * Create a runner for a TensorFlow model.
     *
     * The input, output and state nodes must be specified. If the model has no
     * state, no state nodes should be given (empty state arrays, numStates of 0).
     *
     * @param model         TensorFlow model.
     * @param inputs        Array of numInputs input node names.
     * @param numInputs     Number of inputs.
     * @param outputs       Array of numOutputs output node names.
     * @param numOutputs    Number of outputs.
     * @param stateInputs   Array of numStates state input node names.
     * @param stateOutputs  Array of numStates state output node names
     *                      (presumably paired with stateInputs by position —
     *                      confirm).
     * @param numStates     Number of states.
     * @return New runner for the loaded model, or nullptr if the runner could not be created.
     */
    static TfRunner<InputTypeT, StateTypeT>* CreateImpl(
        const TfModel& model,
        const char* const* inputs, uint32_t numInputs,
        const char* const* outputs, uint32_t numOutputs,
        const char* const* stateInputs,
        const char* const* stateOutputs,
        uint32_t numStates);

    /**
     * Create a runner for a TensorFlow model.
     *
     * The input, output and state nodes must be specified. If the model has no
     * state, stateInputPrefix and stateOutputPrefix should be left empty.
     *
     * The state nodes must be given as prefixes of names for the actual nodes.
     * The model must contain exactly the same number of nodes with each prefix,
     * and for each node prefixed by the input prefix there must be a node
     * prefixed by the output prefix with the same suffix representing the output
     * of that state variable. For example, for a state variable with input node
     * "StateIn_1" and output node "StateOut_1", using "StateIn" and "StateOut" as
     * state input and output prefixes would match them properly. Empty suffixes
     * are valid too. This is useful for models where the exact number of state
     * variables is not known in advance.
     *
     * @param model              TensorFlow model.
     * @param inputs             Array of numInputs input node names.
     * @param numInputs          Number of inputs.
     * @param outputs            Array of numOutputs output node names.
     * @param numOutputs         Number of outputs.
     * @param stateInputPrefix   State input name prefix.
     * @param stateOutputPrefix  State output name prefix.
     * @return New runner for the loaded model, or nullptr if the runner could not be created.
     */
    static TfRunner<InputTypeT, StateTypeT>* CreateImpl(
        const TfModel& model,
        const char* const* inputs, uint32_t numInputs,
        const char* const* outputs, uint32_t numOutputs,
        const char*stateInputPrefix, const char* stateOutputPrefix);

    /**
     * Create a runner for a TensorFlow model.
     *
     * The input, output and state nodes must be specified. If the model has no
     * state, stateInputAliases and stateOutputAliases should be left empty
     * (numStates of 0).
     *
     * Instead of using actual placeholder or variable names this function
     * expects "aliases". An alias is the name of an arbitrary node with one
     * input connected to the output of the real node (for example, an identity
     * node). This makes it easier to work with models generated by frameworks
     * that do not allow custom node names to be specified.
     *
     * @param model               TensorFlow model.
     * @param inputAliases        Array of numInputs aliases for the input nodes.
     * @param numInputs           Number of inputs.
     * @param outputAliases       Array of numOutputs aliases for the output nodes.
     * @param numOutputs          Number of outputs.
     * @param stateInputAliases   Array of numStates aliases for the state input nodes.
     * @param stateOutputAliases  Array of numStates aliases for the state output
     *                            nodes (presumably paired with stateInputAliases
     *                            by position — confirm).
     * @param numStates           Number of states.
     * @return New runner for the loaded model, or nullptr if the runner could not be created.
     */
    static TfRunner<InputTypeT, StateTypeT>* CreateAliasedImpl(
        const TfModel& model,
        const char* const* inputAliases, uint32_t numInputs,
        const char* const* outputAliases, uint32_t numOutputs,
        const char* const* stateInputAliases,
        const char* const* stateOutputAliases,
        uint32_t numStates);

    /**
     * Create a runner for a TensorFlow model.
     *
     * The input, output and state nodes must be specified. If the model has no
     * state, stateInputAliasPrefix and stateOutputAliasPrefix should be empty
     * strings.
     *
     * Instead of using actual placeholder or variable names this function
     * expects "aliases". An alias is the name of an arbitrary node with one
     * input connected to the output of the real node (for example, an identity
     * node). This makes it easier to work with models generated by frameworks
     * that do not allow custom node names to be specified.
     *
     * The state nodes must be given as prefixes of aliases for the actual nodes.
     * The model must contain exactly the same number of nodes with each prefix,
     * and for each node prefixed by the input prefix there must be a node
     * prefixed by the output prefix with the same suffix representing the output
     * of that state variable. For example, for a state variable with input alias
     * "StateIn_1" and output alias "StateOut_1", using "StateIn" and "StateOut" as
     * state input and output alias prefixes would match them properly. Empty
     * suffixes are valid too. This is useful for models where the exact number of
     * state variables is not known in advance.
     *
     * @param model                   TensorFlow model.
     * @param inputAliases            Array of numInputs aliases for the input nodes.
     * @param numInputs               Number of inputs.
     * @param outputAliases           Array of numOutputs aliases for the output nodes.
     * @param numOutputs              Number of outputs.
     * @param stateInputAliasPrefix   State input alias prefix.
     * @param stateOutputAliasPrefix  State output alias prefix.
     * @return New runner for the loaded model, or nullptr if the runner could not be created.
     */
    static TfRunner<InputTypeT, StateTypeT>* CreateAliasedImpl(
        const TfModel& model,
        const char* const* inputAliases, uint32_t numInputs,
        const char* const* outputAliases, uint32_t numOutputs,
        const char* stateInputAliasPrefix,
        const char* stateOutputAliasPrefix);

    /**
     * Feed an input to the model.
     *
     * Runs the model with the given input values and updates its outputs. If the
     * model is stateful the state variables will get updated accordingly.
     *
     * The input values must be given as flattened vectors with the content of the
     * input tensors in the order specified when the model was loaded. Input vectors
     * must be flattened according to "row-major" storage, meaning that the index of
     * the last dimension varies fastest and the index of the first dimension varies
     * slowest.
     *
     * Elements are passed in their primitive form as defined by
     * internal::ComplexTypeSupport (e.g. std::string inputs are passed as C strings).
     *
     * @param inputValues      Array of numInputValues pointers, each to the
     *                         flattened data of one input.
     * @param inputValueSizes  Array of numInputValues sizes, one per input value
     *                         (NOTE(review): presumably element counts rather than
     *                         bytes — confirm).
     * @param numInputValues   Number of fed input values.
     * @return true if the input was fed correctly, false otherwise.
     */
    bool FeedImpl(const typename internal::ComplexTypeSupport<InputType>::primitive_type* const* inputValues, const size_t* inputValueSizes, uint32_t numInputValues);

    /**
     * Private implementation of GetState.
     *
     * @param stateIndex      Index of the desired state, as indicated by the position
     *                        of its name or alias in the states vector when the model
     *                        was loaded. If the model was loaded using state prefixes,
     *                        make sure the index is smaller than the total number of
     *                        states.
     * @param[out] state      Memory area that receives the value of the state as a
     *                        flattened vector.
     * @param[in,out] size    On input, the size of the memory area pointed to by
     *                        state; on return, the size of the retrieved state.
     * @return true if the state value could be copied to the memory area, false otherwise.
     */
    bool GetStateImpl(uint32_t stateIndex, typename internal::ComplexTypeSupport<StateType>::primitive_type* state, size_t& size) const;

    /**
     * Private implementation of GetOutput.
     *
     * @tparam OutputTypeT   Output variable type. Must match the type declared in the model.
     * @param outputIndex    Index of the desired output, as indicated by the position
     *                       of its name or alias in the outputs vector when the model
     *                       was loaded.
     * @param[out] output    Memory area that receives the value of the output as a
     *                       flattened vector.
     * @param[in,out] size   On input, the size of the memory area pointed to by
     *                       output; on return, the size of the retrieved output.
     * @return true if the output value could be copied to the memory area, false otherwise.
     */
    template<typename OutputTypeT>
    bool GetOutputImpl(uint32_t outputIndex, typename internal::ComplexTypeSupport<OutputTypeT>::primitive_type* output, size_t& size) const;
};

}

#ifndef NTTENSORFLOW_API_EXPORTS
#include "NTTensorFlow.inl"
#endif

#endif  // NTTENSORFLOW_NTTENSORFLOW_H_
