Files

169 lines
6.6 KiB
C++

// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include "openvino/op/constant.hpp"
#include "openvino/op/op.hpp"
#include "openvino/op/util/topk_base.hpp"
namespace ov {
namespace op {
namespace v1 {
/// \brief Computes indices and values of the k maximum/minimum values
///        for each slice along the specified axis (opset1 version).
/// \ingroup ov_ops_cpp_api
class OPENVINO_API TopK : public util::TopKBase {
public:
    OPENVINO_OP("TopK", "opset1", op::util::TopKBase);

    using SortType = TopKSortType;
    using Mode = TopKMode;

    /// \brief Constructs a TopK operation
    TopK() = default;

    /// \brief Constructs a TopK operation with two outputs: values and indices.
    ///        By default the indices output is described by i32 data type.
    ///
    /// \param data The input tensor
    /// \param k Specifies how many maximum/minimum elements should be computed
    ///          (note: scalar input tensor)
    /// \param axis The axis along which to compute top k indices
    /// \param mode Specifies whether the maximum or the minimum elements are
    ///             selected (min or max), passed as a string
    /// \param sort Specifies order of output elements and/or indices
    ///             Accepted values: none, index, value
    /// \param index_element_type Specifies type of produced indices
    TopK(const Output<Node>& data,
         const Output<Node>& k,
         const int64_t axis,
         const std::string& mode,
         const std::string& sort,
         const element::Type& index_element_type = element::i32);

    /// \brief Constructs a TopK operation with two outputs: values and indices.
    ///        By default the indices output is described by i32 data type.
    ///
    /// Takes the same arguments as the string-based constructor, except that
    /// \p mode and \p sort are passed as Mode / SortType values.
    ///
    /// \param data The input tensor
    /// \param k Specifies how many maximum/minimum elements should be computed
    ///          (note: scalar input tensor)
    /// \param axis The axis along which to compute top k indices
    /// \param mode Specifies whether the maximum or the minimum elements are selected
    /// \param sort Specifies order of output elements and/or indices
    /// \param index_element_type Specifies type of produced indices
    TopK(const Output<Node>& data,
         const Output<Node>& k,
         const int64_t axis,
         const Mode mode,
         const SortType sort,
         const element::Type& index_element_type = element::i32);

    // Overrides of virtuals declared in the base classes.
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
    bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
    bool has_evaluate() const override;

protected:
    /// \brief Overrides the base-class check of the same name.
    /// \param k_element_type Element type to check (going by the name, that of the `k` input)
    void k_type_check(const element::Type& k_element_type) const override;
};
} // namespace v1
namespace v3 {
/// \brief For each slice along the specified axis, computes the values and the
///        indices of the k maximum/minimum elements (opset3 version).
/// \ingroup ov_ops_cpp_api
class OPENVINO_API TopK : public util::TopKBase {
public:
    OPENVINO_OP("TopK", "opset3", op::util::TopKBase);

    /// \brief Constructs a TopK operation
    TopK() = default;

    /// \brief Constructs a TopK operation producing two outputs: values and indices.
    ///        Unless overridden, the indices output uses the i32 data type.
    ///
    /// \param data               The input tensor
    /// \param k                  How many maximum/minimum elements should be computed
    ///                           (note: scalar input tensor)
    /// \param axis               The axis along which to compute top k indices
    /// \param mode               Whether the maximum or the minimum elements are
    ///                           selected (min or max), passed as a string
    /// \param sort               Order of output elements and/or indices.
    ///                           Accepted values: none, index, value
    /// \param index_element_type Type of the produced indices
    TopK(const Output<Node>& data,
         const Output<Node>& k,
         const int64_t axis,
         const std::string& mode,
         const std::string& sort,
         const element::Type& index_element_type = element::i32);

    /// \brief Constructs a TopK operation producing two outputs: values and indices.
    ///        Unless overridden, the indices output uses the i32 data type.
    ///
    /// Takes the same arguments as the string-based constructor, except that
    /// \p mode and \p sort are passed as TopKMode / TopKSortType values.
    TopK(const Output<Node>& data,
         const Output<Node>& k,
         const int64_t axis,
         const TopKMode mode,
         const TopKSortType sort,
         const element::Type& index_element_type = element::i32);

    // Overrides of virtuals declared in the base classes.
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
    bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
    bool has_evaluate() const override;
};
} // namespace v3
namespace v11 {
/// \brief Computes the top K elements of a given tensor along the specified axis
///        (opset11 version, which adds the `stable` attribute).
/// \ingroup ov_ops_cpp_api
class OPENVINO_API TopK : public util::TopKBase {
public:
    OPENVINO_OP("TopK", "opset11", op::util::TopKBase);

    /// \brief Constructs a TopK operation
    TopK() = default;

    /// \brief Constructs a TopK operation with two outputs: values and indices.
    ///        This overload takes \p mode and \p sort as strings.
    ///
    /// \param data               The input tensor
    /// \param k                  How many maximum/minimum elements should be computed
    /// \param axis               The axis along which the TopK operation should be executed
    /// \param mode               Whether TopK selects the largest or the smallest elements
    ///                           from each slice
    /// \param sort               The order of corresponding elements of the output tensor
    /// \param index_element_type The data type of the elements in the 'indices' output tensor.
    /// \param stable             Whether equivalent elements should maintain their relative
    ///                           order from the input tensor during sorting.
    TopK(const Output<Node>& data,
         const Output<Node>& k,
         const int64_t axis,
         const std::string& mode,
         const std::string& sort,
         const element::Type& index_element_type = element::i32,
         const bool stable = false);

    /// \brief Constructs a TopK operation with two outputs: values and indices.
    ///        This overload takes \p mode and \p sort as TopKMode / TopKSortType values.
    ///
    /// \param data               The input tensor
    /// \param k                  How many maximum/minimum elements should be computed
    /// \param axis               The axis along which the TopK operation should be executed
    /// \param mode               Whether TopK selects the largest or the smallest elements
    ///                           from each slice
    /// \param sort               The order of corresponding elements of the output tensor
    /// \param index_element_type The data type of the elements in the 'indices' output tensor.
    /// \param stable             Whether equivalent elements should maintain their relative
    ///                           order from the input tensor during sorting.
    TopK(const Output<Node>& data,
         const Output<Node>& k,
         const int64_t axis,
         const TopKMode mode,
         const TopKSortType sort,
         const element::Type& index_element_type = element::i32,
         const bool stable = false);

    // Overrides of virtuals declared in the base classes.
    void validate_and_infer_types() override;
    bool visit_attributes(AttributeVisitor& visitor) override;
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
    bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
    bool has_evaluate() const override;

    /// \brief Returns the current value of the `stable` flag.
    bool get_stable() const { return m_stable; }

    /// \brief Replaces the `stable` flag with the given value.
    void set_stable(const bool stable) { m_stable = stable; }

private:
    // Backing storage for get_stable()/set_stable(); false until set otherwise.
    bool m_stable{false};
};
} // namespace v11
} // namespace op
} // namespace ov