84 lines
3.8 KiB
C++
84 lines
3.8 KiB
C++
// Copyright (C) 2018-2025 Intel Corporation
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
|
|
#pragma once
|
|
|
|
#include "openvino/core/node.hpp"
|
|
#include "openvino/pass/pattern/op/pattern.hpp"
|
|
|
|
namespace ov::pass::pattern {
|
|
namespace op {
|
|
/// Fails if the predicate returns false on the graph value.
|
|
///
|
|
/// The graph value is added to the matched values list. If the Label is already
|
|
/// associated with a value, the match succeeds if the value is the same as the graph
|
|
/// value. Otherwise, the label is associated with the graph value and the match
|
|
/// succeeds if the pattern input matches the graph value.
|
|
///
|
|
/// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If
|
|
/// more than one inputs are given, an Or pattern of the inputs serves as the input.
|
|
class OPENVINO_API Label : public Pattern {
|
|
public:
|
|
OPENVINO_RTTI("patternLabel");
|
|
/// \brief creates a Label node containing a sub-pattern described by \sa type and
|
|
/// \sa shape.
|
|
///
|
|
/// this Label node can be bound only to the nodes in the input graph
|
|
/// that match the pattern specified by \sa wrapped_nodes
|
|
/// Example:
|
|
/// \code{.cpp}
|
|
/// auto add = a + b; // a and b are op::Parameter in this example
|
|
/// auto label = std::make_shared<pattern::op::Label>(element::f32,
|
|
/// PartialShape{2,2},
|
|
/// nullptr,
|
|
/// OutputVector{add});
|
|
/// \endcode
|
|
template <typename TPredicate, typename TArg = OutputVector>
|
|
Label(const element::Type& type,
|
|
const PartialShape& s,
|
|
const TPredicate& pred,
|
|
const TArg& wrapped_values = OutputVector{})
|
|
: Pattern(OutputVector{wrap_values(wrapped_values)}, Predicate(pred)) {
|
|
set_output_type(0, type, s);
|
|
}
|
|
|
|
/// \brief creates a Label node containing a sub-pattern described by the type and
|
|
/// shape of \sa node.
|
|
///
|
|
/// this Label node can be bound only to the nodes in the input graph
|
|
/// that match the pattern specified by \sa wrapped_values
|
|
/// Example:
|
|
/// \code{.cpp}
|
|
/// auto add = a + b; // a and b are op::Parameter in this example
|
|
/// auto label = std::make_shared<pattern::op::Label>(add,
|
|
/// nullptr,
|
|
/// OutputVector{add});
|
|
/// \endcode
|
|
template <typename TPredicate, typename TArg = OutputVector>
|
|
Label(const Output<Node>& value, const TPredicate& pred, const TArg& wrapped_values = OutputVector{})
|
|
: Label(value.get_element_type(), value.get_partial_shape(), Predicate(pred), wrapped_values) {}
|
|
|
|
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
|
|
: Label(type, s, nullptr, OutputVector{}) {}
|
|
explicit Label(const Output<Node>& value)
|
|
: Label(value.get_element_type(), value.get_partial_shape(), nullptr, OutputVector{}) {}
|
|
|
|
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
|
|
|
|
protected:
|
|
static Output<Node> wrap_values(const OutputVector& wrapped_values);
|
|
static Output<Node> wrap_values(const NodeVector& wrapped_values);
|
|
};
|
|
} // namespace op
|
|
|
|
/// \brief Attributes-based overload of any_input; \p attrs defaults to an empty
/// Attributes object. The definition is not in this header.
/// NOTE(review): presumably returns a wildcard pattern node constrained by \p attrs —
/// confirm against the implementation file.
OPENVINO_API std::shared_ptr<Node> any_input(const Attributes& attrs = {});
|
|
|
|
template <typename TPredicate,
|
|
typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate> &&
|
|
!std::is_constructible_v<Attributes, TPredicate>>* = nullptr>
|
|
std::shared_ptr<Node> any_input(const TPredicate& pred) {
|
|
return std::make_shared<pattern::op::Label>(element::dynamic, PartialShape::dynamic(), op::Predicate(pred));
|
|
}
|
|
} // namespace ov::pass::pattern
|