136 lines
4.5 KiB
C++
136 lines
4.5 KiB
C++
// Copyright (C) 2018-2025 Intel Corporation
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
|
|
#pragma once
|
|
|
|
#include "openvino/core/node_output.hpp"
|
|
#include "openvino/core/type/element_type.hpp"
|
|
#include "openvino/frontend/variable.hpp"
|
|
#include "openvino/frontend/visibility.hpp"
|
|
|
|
namespace ov {
|
|
namespace frontend {
|
|
|
|
/// \brief HashTable is a special type of Variable that has a complex value including keys and values.
|
|
/// Keys and values are represented by two separate graphs at each time step.
|
|
class FRONTEND_API HashTable : public Variable {
public:
    using Ptr = std::shared_ptr<HashTable>;
    OPENVINO_OP("HashTable", "ov::frontend", Variable);

    /// \brief Creates a hash table that has no keys/values outputs attached yet.
    /// \param name Name forwarded to the Variable base class
    /// \param key_type Element type reported on output port 1 (keys)
    /// \param value_type Element type reported on output port 2 (values)
    /// \param decoder Decoder forwarded to the Variable base class, may be nullptr
    HashTable(const std::string& name,
              const ov::element::Type& key_type,
              const ov::element::Type& value_type,
              const std::shared_ptr<DecoderBase>& decoder = nullptr)
        : Variable(name, decoder),
          m_key_type(key_type),
          m_value_type(value_type) {
        validate_and_infer_types();
    }

    /// \brief Creates a copy of another hash table and attaches keys and values to it.
    /// The new object is marked as initialized and its initialization counter
    /// is one greater than the counter copied from \p other.
    /// \param other Hash table to copy
    /// \param keys Output that get_keys() returns for the new object
    /// \param values Output that get_values() returns for the new object
    HashTable(const HashTable& other, const ov::Output<ov::Node>& keys, const ov::Output<ov::Node>& values)
        : HashTable(other) {
        // m_is_initialized and m_init_counter are not declared in this class
        // (presumably they are inherited from Variable), so they are set in the body
        m_is_initialized = true;
        ++m_init_counter;
        m_keys = keys;
        m_values = values;
    }

    /// \brief Creates a hash table with every piece of state given explicitly.
    /// This constructor exists only to support cloning, any other use is illegal.
    /// \param name Name forwarded to the Variable base class
    /// \param key_type Element type reported on output port 1 (keys)
    /// \param value_type Element type reported on output port 2 (values)
    /// \param keys Stored keys output
    /// \param values Stored values output
    /// \param is_initialized Initialization flag to store
    /// \param init_counter Initialization counter to store
    /// \param decoder Decoder forwarded to the Variable base class, may be nullptr
    HashTable(const std::string& name,
              const ov::element::Type& key_type,
              const ov::element::Type& value_type,
              const ov::Output<ov::Node>& keys,
              const ov::Output<ov::Node>& values,
              bool is_initialized,
              uint64_t init_counter,
              const std::shared_ptr<DecoderBase>& decoder = nullptr)
        : Variable(name, decoder),
          m_key_type(key_type),
          m_value_type(value_type),
          m_keys(keys),
          m_values(values) {
        m_is_initialized = is_initialized;
        m_init_counter = init_counter;
        validate_and_infer_types();
    }

    /// \brief Sets element types and shapes of the three output ports.
    void validate_and_infer_types() override {
        // port 0 only stores a reference to the resource,
        // so neither a concrete type nor a concrete shape applies to it
        set_output_type(0, ov::element::dynamic, ov::PartialShape::dynamic());

        // ports 1 and 2 store keys and values of the resource, both are 1D tensors
        const auto one_dim_shape = ov::PartialShape::dynamic(1);
        set_output_type(1, m_key_type, one_dim_shape);
        set_output_type(2, m_value_type, one_dim_shape);
    }

    /// \brief Creates a new HashTable node from the state of this one.
    /// \param inputs Not used by this implementation
    /// \return Node built from the name, types, stored keys/values, initialization state,
    /// decoder and attributes of this object
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override {
        // NOTE(review): m_other_keys and m_other_values are not passed to the clone - confirm it is intended
        auto cloned_table = std::make_shared<HashTable>(m_name,
                                                        m_key_type,
                                                        m_value_type,
                                                        m_keys,
                                                        m_values,
                                                        m_is_initialized,
                                                        m_init_counter,
                                                        m_decoder);
        cloned_table->set_attrs(get_attrs());
        return cloned_table;
    }

    /// \brief Returns output port 0 of this node
    ov::Output<ov::Node> get_value() override {
        return output(0);
    }

    /// \brief Returns a value corresponding to keys of the hash table.
    /// The stored keys are used if the table is initialized, otherwise the first element
    /// of the additionally registered keys if there are any, otherwise output port 1.
    ov::Output<ov::Node> get_keys() {
        if (m_is_initialized) {
            return m_keys;
        }
        if (!m_other_keys.empty()) {
            // NOTE(review): keys and values live in two independent ordered sets, so the first
            // elements of both sets may originate from different add_other_keys_values calls
            return *m_other_keys.begin();
        }
        return output(1);
    }

    /// \brief Returns a value corresponding to values of the hash table.
    /// The stored values are used if the table is initialized, otherwise the first element
    /// of the additionally registered values if there are any, otherwise output port 2.
    ov::Output<ov::Node> get_values() {
        if (m_is_initialized) {
            return m_values;
        }
        if (!m_other_values.empty()) {
            return *m_other_values.begin();
        }
        return output(2);
    }

    /// \brief Returns element type of keys
    ov::element::Type get_key_type() const {
        return m_key_type;
    }

    /// \brief Returns element type of values
    ov::element::Type get_value_type() const {
        return m_value_type;
    }

    /// \brief Registers additional keys and values outputs.
    /// get_keys() and get_values() fall back to them while the table is not initialized.
    /// \param other_key Output to insert into the set of additional keys
    /// \param other_value Output to insert into the set of additional values
    void add_other_keys_values(const ov::Output<ov::Node>& other_key, const ov::Output<ov::Node>& other_value) {
        m_other_keys.insert(other_key);
        m_other_values.insert(other_value);
    }

    virtual ~HashTable();

private:
    // element types reported on output ports 1 and 2
    ov::element::Type m_key_type;
    ov::element::Type m_value_type;

    // outputs returned by get_keys() and get_values() when the table is initialized
    ov::Output<ov::Node> m_keys;
    ov::Output<ov::Node> m_values;

    // outputs collected by add_other_keys_values()
    std::set<ov::Output<ov::Node>> m_other_keys;
    std::set<ov::Output<ov::Node>> m_other_values;
};
|
|
|
|
} // namespace frontend
|
|
} // namespace ov
|