// Copyright (C) 2018-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #pragma once #include #include #include #include "openvino/core/any.hpp" #include "openvino/core/core_visibility.hpp" #include "openvino/core/node.hpp" #include "openvino/core/symbol.hpp" namespace ov::pass::pattern { /// \brief Wrapper to uniformly store and access Matcher symbol information. class OPENVINO_API PatternSymbolValue { public: PatternSymbolValue(); PatternSymbolValue(const std::shared_ptr& s); PatternSymbolValue(int64_t i); PatternSymbolValue(double d); PatternSymbolValue(const std::vector& g); bool is_dynamic() const; bool is_static() const; bool is_group() const; bool is_integer() const; bool is_double() const; int64_t i() const; double d() const; std::shared_ptr s() const; const std::vector& g() const; bool operator==(const PatternSymbolValue& other) const; bool operator!=(const PatternSymbolValue& other) const; template >* = nullptr> static std::vector make_value_vector(const std::vector& v) { return {v.begin(), v.end()}; } private: bool is_valid() const; ov::Any m_value; }; using PatternSymbolMap = std::unordered_map; namespace op { using NodePredicate = std::function)>; using ValuePredicate = std::function&)>; /// \brief Wrapper over different types of predicates. It is used to add restrictions to the match /// Predicate types: /// - Value Predicate -- function)> // most popular version of predicate /// - Node Predicate -- function)> // legacy version, should be used with care /// - Symbol Predicate -- function)> // new version, collects / checks symbols /// class OPENVINO_API Predicate { public: Predicate(); Predicate(std::nullptr_t); template &>>* = nullptr> explicit Predicate(const TPredicate& predicate) { m_pred = [=](PatternSymbolMap&, const Output& out) { return predicate(out); }; } template &> && !std::is_invocable_r_v&>>* = nullptr> explicit Predicate(const TPredicate& predicate) { m_pred = [=](PatternSymbolMap&, const Output& out) { return predicate(out.get_node_shared_ptr()); }; } template < typename TPredicate, typename std::enable_if_t>>* = nullptr> explicit Predicate(const TPredicate& predicate) { m_pred = predicate; m_requires_map = true; } template explicit Predicate(const TPredicate& predicate, std::string name) : Predicate(predicate) { if (!name.empty()) m_name = std::move(name); } bool operator()(Matcher* m, const Output& output) const; bool operator()(const std::shared_ptr& node) const; bool operator()(const Output& output) const; template Predicate operator||(const TPredicate& other) const { return *this || Predicate(other); } template Predicate operator&&(const TPredicate& other) const { return *this && Predicate(other); } Predicate operator||(const Predicate& other) const { auto result = Predicate( [pred = m_pred, other_pred = other.m_pred](PatternSymbolMap& m, const Output& out) -> bool { return pred(m, out) || other_pred(m, out); }, m_name + " || " + other.m_name); result.m_requires_map = m_requires_map || other.m_requires_map; return result; } Predicate operator&&(const Predicate& other) const { auto result = Predicate( [pred = m_pred, other_pred = other.m_pred](PatternSymbolMap& m, const Output& out) -> bool { return pred(m, out) && other_pred(m, out); }, m_name + " && " + other.m_name); result.m_requires_map = m_requires_map || other.m_requires_map; return result; } private: bool m_requires_map = false; std::string m_name = "no_name"; std::function&)> m_pred; }; template && !std::is_same_v>* = nullptr> Predicate operator&&(const TPredicate& lhs, const Predicate& rhs) { return Predicate(lhs) && rhs; } template && !std::is_same_v>* = nullptr> Predicate operator||(const TPredicate& lhs, const Predicate& rhs) { return Predicate(lhs) || rhs; } } // namespace op } // namespace ov::pass::pattern