// Copyright (C) 2018-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #pragma once #include #include #include #include "openvino/core/rtti.hpp" #include "openvino/pass/matcher_pass.hpp" #define _OPENVINO_GRAPH_REWRITE_RTTI_WITH_TYPE(TYPE_NAME) _OPENVINO_GRAPH_REWRITE_RTTI_WITH_TYPE_VERSION(TYPE_NAME, "0") #define _OPENVINO_GRAPH_REWRITE_RTTI_WITH_TYPE_VERSION(TYPE_NAME, VERSION_NAME) \ _OPENVINO_RTTI_WITH_TYPE_VERSION_PARENT(TYPE_NAME, VERSION_NAME, ::ov::pass::GraphRewrite) #define OPENVINO_GRAPH_REWRITE_RTTI(...) \ _OPENVINO_RTTI_EXPAND(_OPENVINO_RTTI_DEFINITION_SELECTOR_2(__VA_ARGS__, \ _OPENVINO_GRAPH_REWRITE_RTTI_WITH_TYPE_VERSION, \ _OPENVINO_GRAPH_REWRITE_RTTI_WITH_TYPE)(__VA_ARGS__)) namespace ov { namespace pass { /// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function /// in /// efficient way /// /// Graph rewrite pass is used for matcher passes execution on Function. /// To register MatcherPass use \sa add_matcher(args) method where T is a MatcherPass /// class. /// As a default algorithm graph rewrite pass traverse Function in topological order and /// applies /// registered matcher passes for each node. But if all registered matcher passes have type /// based /// root node in Matcher pattern then efficient mechanism is used to execute them. /// Matcher pattern root is type based if it's operation from opset or /// pattern::op::WrapType. /// Note: when implementing pattern for Matcher make sure that root node is an operation /// from opset /// or has ov::pattern::op::WrapType. That will help GraphRewrite to execute matcher /// passes more /// efficient. /// \ingroup ov_pass_cpp_api class OPENVINO_API GraphRewrite : public ModelPass { public: OPENVINO_MODEL_PASS_RTTI("ov::pass::GraphRewrite"); GraphRewrite() = default; explicit GraphRewrite(const std::shared_ptr& pass) : ModelPass() { m_matchers.push_back(pass); } /// \brief Register given transformation class type to GraphRewrite execution list /// All registered transformations will be executed in a single graph traversal. /// Example below show the basic usage of pass::GraphRewrite /// /// pass::Manager manager; /// auto anchor = manager.register_pass(); /// anchor->add_matcher(); /// anchor->add_matcher(); /// anchor->set_name("CommonMatchers"); /// manager.run_passes(f); /// /// For some purposes transformation can be registered and disabled by default. /// /// anchor->add_matcher(); /// /// \return shared_ptr to the transformation instance template ::value, bool>::type = true> std::shared_ptr add_matcher(Args&&... args) { static_assert(std::is_base_of::value, "pass not derived from MatcherPass"); auto pass = std::make_shared(std::forward(args)...); auto pass_config = get_pass_config(); pass->set_pass_config(pass_config); if (!Enabled && !pass_config->is_enabled()) { pass_config->disable(); } m_matchers.push_back(pass); return pass; } /// \brief Register passes from GraphRewrite class that contains sequence of matcher /// passes registered in its ctor. /// For example: /// /// class ov::pass::LinFusions: public ov::pass::GraphRewrite { /// public: /// OPENVINO_GRAPH_REWRITE_RTTI("LinFusion"); /// Fusions() { /// add_matcher(); /// add_matcher(); /// } /// }; /// /// pass::Manager manager; /// auto anchor = manager.register_pass(); /// anchor->add_matcher(); /// anchor->add_matcher(); /// anchor->set_name("CommonFusions"); /// manager.run_passes(f); /// /// In this case all matcher passes from LinFusions pass will be united with other /// registered matchers. template ::value, bool>::type = true> void add_matcher(Args&&... args) { static_assert(std::is_base_of::value, "pass not derived from GraphRewrite"); auto pass = std::make_shared(std::forward(args)...); auto pass_config = get_pass_config(); for (auto& matcher : pass->m_matchers) { pass->set_pass_config(pass_config); m_matchers.push_back(matcher); } } std::shared_ptr add_matcher(const std::shared_ptr& pass); bool run_on_model(const std::shared_ptr& m) override; void set_pass_config(const std::shared_ptr& pass_config) override; protected: bool apply_matcher_passes(std::shared_ptr f, std::deque> nodes_to_run); bool m_enable_shape_inference = false; std::vector> m_matchers; }; } // namespace pass } // namespace ov