Files

95 lines
3.6 KiB
C++

// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <functional>
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/op/predicate.hpp"
namespace ov::pass::pattern {
// Forward declarations only: these classes are named here but not defined in the visible part of
// this header.
namespace op {
class Label;
}
class Matcher;
class MatcherState;
// Container aliases. The std::map-based ones are keyed by std::shared_ptr<Node> with the default
// comparator, i.e. keys are ordered and compared by pointer value, not by node content.
// Node pointer -> OutputVector (several outputs per key).
using RPatternValueMap = std::map<std::shared_ptr<Node>, OutputVector>;
// Node pointer -> a single Output<Node>.
using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>;
// Sequence of PatternValueMap objects.
using PatternValueMaps = std::vector<PatternValueMap>;
// Node pointer -> node pointer.
using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>;
// String key -> ov::Any value; this is the parameter type of attrs_match() declared in this header.
using Attributes = std::unordered_map<std::string, ov::Any>;
// Conversions between the two map flavours: as_pattern_map takes a PatternValueMap and returns a
// PatternMap, as_pattern_value_map goes the other way. The definitions are not in this header, so
// how an Output<Node> is mapped to a node (and back) must be confirmed against the implementation.
// NOTE(review): unlike the predicate factories declared further down in this header, these two
// declarations carry no OPENVINO_API macro — confirm whether that is intentional.
PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
/// \brief Creates a predicate that accepts a node when ov::is_type<T> reports true for it.
/// \tparam T Type forwarded to ov::is_type.
/// \return op::Predicate whose name is "has_class<" + typeid(T).name() + ">()".
template <typename T>
op::Predicate has_class() {
    // The closure must carry state: GCC treats an empty lambda (no captures) differently in
    // different ABI versions, so an always-true flag is captured purely to keep it non-empty.
    bool keep_closure_stateful = true;
    auto matches_type = [keep_closure_stateful](std::shared_ptr<Node> candidate) -> bool {
        if (!keep_closure_stateful) {
            return false;
        }
        return ov::is_type<T>(std::move(candidate));
    };
    return op::Predicate(matches_type, std::string("has_class<") + typeid(T).name() + ">()");
}
/// \brief Creates a predicate that accepts a node when ov::is_type<T> reports false for it.
/// \tparam T Type forwarded to ov::is_type.
/// \return op::Predicate whose name is "class_other_than<" + typeid(T).name() + ">()".
template <typename T>
op::Predicate class_other_than() {
    // The closure must carry state: GCC treats an empty lambda (no captures) differently in
    // different ABI versions, so an always-true flag is captured purely to keep it non-empty.
    bool keep_closure_stateful = true;
    auto differs_from_type = [keep_closure_stateful](std::shared_ptr<Node> candidate) -> bool {
        if (!keep_closure_stateful) {
            return false;
        }
        return !ov::is_type<T>(std::move(candidate));
    };
    return op::Predicate(differs_from_type, std::string("class_other_than<") + typeid(T).name() + ">()");
}
// Factory functions that each return an op::Predicate. All are marked OPENVINO_API and are only
// declared here. NOTE(review): the summaries below are inferred from names and signatures, not
// from the implementations — confirm before relying on them.
// Presumably constrain the number of consumers of the matched output relative to n.
OPENVINO_API op::Predicate consumers_count(size_t n);
OPENVINO_API op::Predicate consumers_more_than(size_t n);
// Presumably check whether one dimension / the listed dimensions / the whole shape / the rank is static.
OPENVINO_API op::Predicate has_static_dim(size_t pos);
OPENVINO_API op::Predicate has_static_dims(const std::vector<size_t>& dims);
OPENVINO_API op::Predicate has_static_shape();
OPENVINO_API op::Predicate has_static_rank();
// Presumably compare the rank against expected_rank, which is passed as a Dimension.
OPENVINO_API op::Predicate rank_equals(const Dimension& expected_rank);
OPENVINO_API op::Predicate rank_more_than(const Dimension& expected_rank);
// Presumably compare the element type against a single type or against a list of candidates.
OPENVINO_API op::Predicate type_matches(const element::Type& type);
OPENVINO_API op::Predicate type_matches_any(const std::vector<element::Type>& types);
// Takes a list of callables with signature bool(Output<Node>); presumably requires all of them to hold.
OPENVINO_API op::Predicate all_of(const std::vector<std::function<bool(Output<Node>)>>& predicates);
// Takes the expected attributes as a string -> ov::Any map (see the Attributes alias).
OPENVINO_API op::Predicate attrs_match(const Attributes& expected_attrs);
// Take a textual notation; its grammar is not described in this header.
OPENVINO_API op::Predicate shape_matches(const std::string& shape_notation);
OPENVINO_API op::Predicate value_matches(const std::string& value_notation);
namespace op {
// Deprecated adapter that takes a NodePredicate and returns a Predicate; the deprecation message
// directs callers to the Predicate constructor instead.
// NOTE(review): the message spells the type as ov::pass::pattern::Predicate, while this header
// refers to it as op::Predicate inside ov::pass::pattern — confirm the qualified name is right.
OPENVINO_DEPRECATED("This method is deprecated. Use constructor of ov::pass::pattern::Predicate instead")
OPENVINO_API Predicate as_value_predicate(NodePredicate pred);
/// \brief A base class for all the utility operators used to describe a pattern to match
class OPENVINO_API Pattern : public Node {
public:
/// \brief Default constructor (takes no pattern inputs and no predicate argument).
Pattern();
/// \brief Constructs from pattern inputs given either as outputs or as nodes, without an
/// explicit predicate argument. Both overloads are explicit, so no implicit conversion applies.
explicit Pattern(const OutputVector& patterns);
explicit Pattern(const NodeVector& patterns);
/// \brief Same inputs as above plus a predicate.
/// NOTE(review): presumably stored in m_predicate — the definition is not in this header.
Pattern(const OutputVector& patterns, const Predicate& pred);
Pattern(const NodeVector& patterns, const Predicate& pred);
/// \brief Cloning is not supported: the body unconditionally invokes OPENVINO_THROW("Uncopyable")
/// and the argument is ignored.
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& /* new_args */) const override {
OPENVINO_THROW("Uncopyable");
}
/// \brief Override of the Node hook that writes a description to \p out; defined outside this header.
std::ostream& write_description(std::ostream& out, uint32_t depth) const override;
/// \brief Virtual hook introduced by this class so that derived classes can override how the
/// type description is written to \p out; defined outside this header.
virtual std::ostream& write_type_description(std::ostream& out) const;
protected:
// Predicate associated with this pattern; protected, so derived classes can access it directly.
Predicate m_predicate;
};
} // namespace op
} // namespace ov::pass::pattern