/* // Copyright (C) 2020-2024 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. */ #pragma once #include #include #include #include #include #include #include #include struct InferenceResult; struct InputData; struct InternalModelData; struct ResultBase; class ModelBase { public: ModelBase() {}; ModelBase(const std::string& modelFileName, const std::string& layout = "") : modelFileName(modelFileName), inputsLayouts(parseLayoutString(layout)) {} virtual ~ModelBase() {} virtual std::shared_ptr preprocess(const InputData& inputData, ov::InferRequest& request) = 0; // Virtual method to be overridden in derived classes virtual void updateImageSize(const cv::Size& inputImgSize) { // Optionally leave empty or add any base logic here } virtual ov::CompiledModel compileModel(const ModelConfig& config, ov::Core& core); virtual void onLoadCompleted(const std::vector& requests) {} virtual std::unique_ptr postprocess(InferenceResult& infResult) = 0; void Initilise(const std::string& _modelFileName, const std::string& _layout = "") { modelFileName = _modelFileName; inputsLayouts = parseLayoutString(_layout); } const std::vector& getOutputsNames() const { return outputsNames; } const std::vector& getInputsNames() const { return inputsNames; } std::string getModelFileName() { return modelFileName; } void setInputsPreprocessing(bool reverseInputChannels, const std::string& meanValues, const std::string& scaleValues) { this->inputTransform = InputTransform(reverseInputChannels, meanValues, scaleValues); } protected: virtual void prepareInputsOutputs(std::shared_ptr& model) = 0; virtual void setBatch(std::shared_ptr& model); std::shared_ptr prepareModel(ov::Core& core); InputTransform inputTransform = InputTransform(); std::vector inputsNames; std::vector outputsNames; ov::CompiledModel compiledModel; std::string modelFileName; ModelConfig config = {}; std::map inputsLayouts; ov::Layout getInputLayout(const ov::Output& input); };