73 lines
3.0 KiB
Python
73 lines
3.0 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
||
|
|
# Copyright (C) 2018-2025 Intel Corporation
|
||
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
|
||
|
|
|
||
|
|
from typing import Optional, cast, overload
|
||
|
|
from collections.abc import Callable
|
||
|
|
from openvino import Op, Type, Shape, Tensor, PartialShape, TensorVector
|
||
|
|
|
||
|
|
|
||
|
|
class PostponedConstant(Op):
    """Postponed Constant is a way to materialize a big constant only when it is going to be serialized to IR and then immediately dispose.

    The node itself keeps only the element type, the shape and the ``maker`` callable;
    the tensor data is obtained from ``maker`` when ``evaluate`` is called.
    """

    @overload
    def __init__(self, element_type: Type, shape: Shape, maker: Callable[[], Tensor], name: Optional[str] = None) -> None:
        ...

    @overload
    def __init__(self, element_type: Type, shape: Shape, maker: Callable[[Tensor], None], name: Optional[str] = None) -> None:
        ...

    def __init__(self, element_type: Type, shape: Shape, maker: Callable, name: Optional[str] = None) -> None:
        """Creates a PostponedConstant.

        :param element_type: Element type of the constant.
        :type element_type: openvino.Type
        :param shape: Shape of the constant.
        :type shape: openvino.Shape
        :param maker: A callable that returns a Tensor or modifies the provided Tensor to represent the constant.
                      Any callable is accepted: a plain function, a lambda, a ``functools.partial``
                      or an object defining ``__call__``.
                      Note: It's recommended to use a callable without arguments (returns Tensor) to avoid unnecessary tensor data copies.
        :type maker: Union[Callable[[], Tensor], Callable[[Tensor], None]]
        :param name: Optional name for the constant.
        :type name: Optional[str]

        :Example of a maker that returns a Tensor:

        .. code-block:: python

            class Maker:
                def __call__(self) -> ov.Tensor:
                    tensor_data = np.array([2, 2, 2, 2], dtype=np.float32)
                    return ov.Tensor(tensor_data)
        """
        # The Op base constructor is handed the Python instance itself.
        super().__init__(self)
        self.get_rt_info()["postponed_constant"] = True  # value doesn't matter
        self.m_element_type = element_type
        self.m_shape = shape
        self.m_maker = maker
        if name is not None:
            self.friendly_name = name
        self.constructor_validate_and_infer_types()

    def evaluate(self, outputs: TensorVector, _: list[Tensor]) -> bool:  # type: ignore
        """Materializes the constant into ``outputs[0]`` by invoking the maker.

        A maker that accepts no positional arguments is called without arguments and its
        result replaces ``outputs[0]``; any other maker is called with ``outputs[0]`` and
        is expected to fill that tensor.

        :param outputs: Output tensors; only the first element is used.
        :param _: Input tensors; ignored.
        :return: Always True.
        """
        # Imported locally so this block does not depend on a new module-level import.
        import inspect

        # inspect.signature() works for plain functions, lambdas, partials and callable
        # objects alike (and already omits the bound ``self``), unlike reading
        # ``maker.__call__.__code__``, which exists only for Python-defined ``__call__`` methods.
        positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        num_args = sum(
            1 for param in inspect.signature(self.m_maker).parameters.values() if param.kind in positional_kinds
        )
        if num_args == 0:
            outputs[0] = cast(Callable[[], Tensor], self.m_maker)()
        else:
            cast(Callable[[Tensor], None], self.m_maker)(outputs[0])
        return True

    def validate_and_infer_types(self) -> None:
        """Sets the type and shape of output 0 from the stored element type and shape."""
        self.set_output_type(0, self.m_element_type, PartialShape(self.m_shape))

    def clone_with_new_inputs(self, new_inputs: list[Tensor]) -> Op:
        """Returns a new PostponedConstant sharing this node's type, shape, maker and friendly name.

        :param new_inputs: Ignored; the clone is built only from this node's stored attributes.
        :return: The cloned operation.
        """
        return PostponedConstant(self.m_element_type, self.m_shape, self.m_maker, self.friendly_name)

    def has_evaluate(self) -> bool:
        """Reports that this operation implements ``evaluate``.

        :return: Always True.
        """
        return True
|
||
|
|
|
||
|
|
|
||
|
|
# `maker` is a callable that either returns an ov.Tensor representing the target Constant or fills the Tensor passed to it
|
||
|
|
def make_postponed_constant(element_type: Type, shape: Shape, maker: Callable, name: Optional[str] = None) -> Op:
    """Builds a PostponedConstant operation from the given arguments.

    :param element_type: Element type of the constant.
    :param shape: Shape of the constant.
    :param maker: Callable forwarded unchanged to PostponedConstant.
    :param name: Optional name forwarded unchanged to PostponedConstant.
    :return: The newly created PostponedConstant.
    """
    postponed = PostponedConstant(element_type, shape, maker, name)
    return postponed
|