Files

118 lines
5.7 KiB
C++
Raw Permalink Normal View History

// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/pattern/op/op.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace ov::pass::pattern {
namespace op {
/// A submatch on the graph value which contains optional op types defined in constructor.
/// The `Optional` pattern supports multi-input operations. In this case, the pattern checks the inputs either
/// through a node of one of the optional types or through the 1st input alone.
/// The match succeeds if the graph matches the pattern fully, or matches it once extended by a node of one of
/// the optional types. Otherwise, matching fails.
//
// +------+ +------+ +------+ +------+ +------+
// | op_0 | | op_1 | | op_0 | | op_1 | | op_0 |
// +------+ +------+ +------+ +------+ +------+
// | | | | |
// V V V V |
// +-------------------+ +---------------------+ |
// | optional<op_type> | =======>>> | wrap_type<op_types> | |
// +-------------------+ +---------------------+ |
// | | |
// V +------------------+ |
// +------+ | |
// | op_3 | V V
// +------+ +--------+
// | Or |
// +--------+
// |
// V
// +--------+
// | op_3 |
// +--------+
// Known limitations:
// 1. Pattern matching does not support operations with optional inputs.
// For example, ov::op::v5::NonMaxSuppression can be created without some optional input nodes (like
// `max_output_boxes_per_class`): if such an input is not specified in the constructor, the node input is not
// created by default as a constant. Argument matching then fails because the pattern and the graph have a
// different number of input arguments. Issue: 139835
// 2. Optional nodes with cumulative inputs are matched by their 1st input.
// Issue: 139839
class OPENVINO_API Optional : public Pattern {
public:
    OPENVINO_RTTI("patternOptional");

    /// \brief Creates an optional pattern node matching one pattern.
    /// \param type_infos Optional operation types: they are excluded from the matching when nodes of these
    ///                   types do not exist in the graph being matched.
    /// \param inputs     Input patterns of the optional node (none by default).
    Optional(const std::vector<DiscreteTypeInfo>& type_infos, const OutputVector& inputs = {})
        : Pattern(inputs),
          optional_types(type_infos) {}

    /// \brief Creates an optional pattern node with an additional predicate.
    /// \tparam TPredicate Any type that `Predicate` can be constructed from.
    /// \param type_infos  Optional operation types (same meaning as in the constructor above).
    /// \param inputs      Input patterns of the optional node.
    /// \param pred        Predicate; it is wrapped into `Predicate` and handed to the base `Pattern`.
    template <typename TPredicate>
    Optional(const std::vector<DiscreteTypeInfo>& type_infos, const OutputVector& inputs, const TPredicate& pred)
        : Pattern(inputs, Predicate(pred)),
          optional_types(type_infos) {}

    /// \brief Tries to match `graph_value` against `pattern_value`.
    /// NOTE(review): defined out of line; the exact skip/extend semantics live in the implementation file.
    bool match_value(pattern::Matcher* matcher,
                     const Output<Node>& pattern_value,
                     const Output<Node>& graph_value) override;

    /// \return The optional operation types of this pattern (presumably a copy of `optional_types`;
    ///         defined out of line).
    std::vector<DiscreteTypeInfo> get_optional_types() const;

protected:
    /// Operation types that may be absent from the matched graph, as passed to the constructor.
    std::vector<DiscreteTypeInfo> optional_types;
};
} // namespace op
template <class NodeType>
void collect_type_info(std::vector<DiscreteTypeInfo>& type_info_vec) {
type_info_vec.push_back(NodeType::get_type_info_static());
}
/// \brief Appends the static type infos of `NodeType, NodeTypeArgs...` to `type_info_vec`, in order.
/// Enabled only for two or more types; a single type resolves to the one-type overload.
/// \param type_info_vec Destination vector; one element per template argument is appended.
template <class NodeType,
          class... NodeTypeArgs,
          typename std::enable_if<sizeof...(NodeTypeArgs) != 0, bool>::type = true>
void collect_type_info(std::vector<DiscreteTypeInfo>& type_info_vec) {
    // Head first, then every remaining type left-to-right (comma fold preserves ordering).
    collect_type_info<NodeType>(type_info_vec);
    (collect_type_info<NodeTypeArgs>(type_info_vec), ...);
}
/// \brief Creates an `op::Optional` pattern node for `NodeTypes...` with the given inputs and predicate.
/// \param inputs Input patterns of the optional node.
/// \param pred   Predicate (anything `op::Predicate` is constructible from).
/// \param attrs  When non-empty, an attribute check is combined with `pred` via `&&`.
template <class... NodeTypes,
          typename TPredicate,
          typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate>>* = nullptr>
std::shared_ptr<Node> optional(const PatternOps& inputs, const TPredicate& pred, const Attributes& attrs = {}) {
    // Type infos of every optional op type, in template-argument order.
    std::vector<DiscreteTypeInfo> types;
    collect_type_info<NodeTypes...>(types);

    // Attribute matching is only added when the caller actually supplied attributes.
    const auto predicate = attrs.empty() ? op::Predicate(pred) : attrs_match(attrs) && op::Predicate(pred);

    return std::make_shared<op::Optional>(types, ov::OutputVector(inputs), predicate);
}
/// \brief Creates an optional pattern node for `NodeTypes...` with no inputs, constrained by `pred`.
/// \param pred  Predicate (anything `op::Predicate` is constructible from).
/// \param attrs Forwarded to the (inputs, pred, attrs) overload; when non-empty it is combined with `pred`.
/// The second enable_if condition excludes types constructible into `std::vector<PatternOp>` --
/// presumably so that input lists are not taken for predicates (verify against callers).
template <class... NodeTypes,
typename TPredicate,
typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate> &&
!std::is_constructible_v<std::vector<PatternOp>, TPredicate>>* = nullptr>
std::shared_ptr<Node> optional(const TPredicate& pred, const Attributes& attrs = {}) {
// Empty input list: the optional node is created without inputs.
return optional<NodeTypes...>(OutputVector{}, op::Predicate(pred), attrs);
}
/// \brief Creates an optional pattern node for `NodeTypes...` from inputs and (optionally) attributes.
/// \param inputs Input patterns of the optional node (none by default).
/// \param attrs  Attributes to check; when empty, a default-constructed `op::Predicate` is used instead.
template <class... NodeTypes>
std::shared_ptr<Node> optional(const PatternOps& inputs = {}, const Attributes& attrs = {}) {
    const auto predicate = attrs.empty() ? op::Predicate() : attrs_match(attrs);
    return optional<NodeTypes...>(inputs, predicate);
}
/// \brief Creates an optional pattern node for `NodeTypes...` with no inputs that checks the given attributes.
/// \param attrs Braced list of name/value pairs.
/// NOTE(review): `(OutputVector{}, attrs)` presumably resolves to the (inputs, attrs) overload of `optional`;
/// confirm before changing argument types here, as they drive overload selection.
template <class... NodeTypes>
std::shared_ptr<Node> optional(std::initializer_list<std::pair<const std::string, ov::Any>>&& attrs) {
return optional<NodeTypes...>(OutputVector{}, attrs);
}
} // namespace ov::pass::pattern