107 lines
3.5 KiB
C++
107 lines
3.5 KiB
C++
// Copyright (C) 2018-2025 Intel Corporation
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
|
|
#pragma once
|
|
|
|
#include <list>
|
|
#include <memory>
|
|
#include <typeinfo>
|
|
#include <vector>
|
|
|
|
#include "openvino/pass/pass.hpp"
|
|
#include "openvino/pass/validate.hpp"
|
|
|
|
namespace ov {
|
|
namespace pass {
|
|
/**
|
|
* @brief Manager class allows to manage transformation passes
|
|
* @ingroup ov_pass_cpp_api
|
|
*/
|
|
class OPENVINO_API Manager {
|
|
public:
|
|
Manager();
|
|
virtual ~Manager();
|
|
|
|
//// \brief Construct Manager with a provided name.
|
|
explicit Manager(std::string name);
|
|
|
|
//// \brief Construct Manager with shared PassConfig instance
|
|
explicit Manager(std::shared_ptr<PassConfig> pass_config, std::string name = "UnnamedManager");
|
|
|
|
/// \brief Register given transformation class type to execution list
|
|
/// Example below show the basic usage of pass::Manager
|
|
///
|
|
/// pass::Manager manager;
|
|
/// manager.register_pass<MyTransformation>(/* transformation constructor args */);
|
|
/// manager.run_passes(f);
|
|
///
|
|
/// For some purposes transformation can be registered and disabled by default.
|
|
///
|
|
/// manager.register_pass<MyTransformation, false>();
|
|
///
|
|
/// \return shared_ptr to the transformation instance
|
|
template <typename T, bool Enable = true, class... Args>
|
|
std::shared_ptr<T> register_pass(Args&&... args) {
|
|
auto rc = push_pass<T>(std::forward<Args>(args)...);
|
|
rc->set_pass_config(m_pass_config);
|
|
if (m_per_pass_validation) {
|
|
push_pass<Validate>();
|
|
}
|
|
if (!Enable && !m_pass_config->is_enabled<T>()) {
|
|
m_pass_config->disable<T>();
|
|
}
|
|
return rc;
|
|
}
|
|
|
|
std::shared_ptr<PassBase> register_pass_instance(std::shared_ptr<PassBase> pass) {
|
|
pass->set_pass_config(m_pass_config);
|
|
m_pass_list.push_back(pass);
|
|
if (m_per_pass_validation) {
|
|
push_pass<Validate>();
|
|
}
|
|
return pass;
|
|
}
|
|
|
|
/// \brief Runs registered transformations on a given model
|
|
///
|
|
/// \param model Input model
|
|
///
|
|
/// \return Returns true if the model was changed by transformations,
|
|
/// false otherwise.
|
|
bool run_passes(const std::shared_ptr<Model>& model);
|
|
|
|
/// \brief Set flag to enable/disable running Validate pass after executing
|
|
/// each registered pass
|
|
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
|
|
void set_per_pass_validation(bool new_state);
|
|
|
|
/// \return PassConfig shared object. This object is used for transformations pipeline
|
|
/// configuration.
|
|
/// This object allows to disable/enable transformations execution, set callback to
|
|
/// particular
|
|
/// transformation. For more details see PassConfig class.
|
|
std::shared_ptr<PassConfig> get_pass_config() {
|
|
return m_pass_config;
|
|
}
|
|
|
|
protected:
|
|
template <typename T, class... Args>
|
|
std::shared_ptr<T> push_pass(Args&&... args) {
|
|
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
|
|
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
|
m_pass_list.push_back(pass);
|
|
return pass;
|
|
}
|
|
|
|
std::shared_ptr<PassConfig> m_pass_config;
|
|
std::vector<std::shared_ptr<PassBase>> m_pass_list;
|
|
bool m_per_pass_validation = true;
|
|
std::string m_name = "UnnamedManager";
|
|
|
|
private:
|
|
bool run_pass(const std::shared_ptr<PassBase>& pass, const std::shared_ptr<Model>& model, bool needs_validate);
|
|
};
|
|
} // namespace pass
|
|
} // namespace ov
|