Files
ANSLibs/OpenVINO/python/openvino/utils/decorators.py

162 lines
6.2 KiB
Python

# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from functools import wraps
from inspect import signature
from typing import Any, Optional, Union, get_origin, get_args
from collections.abc import Callable
from openvino import Node, Output
from openvino.utils.types import NodeInput, as_node, as_nodes
def _get_name(**kwargs: Any) -> Node:
if "name" in kwargs:
return kwargs["name"]
return None
def _set_node_friendly_name(node: Node, *, name: Optional[str] = None) -> Node:
    """Assign ``name`` as the friendly name of ``node`` when a name is provided.

    :param node: The node to label.
    :param name: Friendly name to apply; ``None`` leaves the node untouched.
    :return: The same ``node`` object that was passed in.
    """
    if name is None:
        return node
    node.friendly_name = name
    return node
def nameable_op(node_factory_function: Callable) -> Callable:
    """Make the node returned by a factory take its ``name`` keyword as friendly name.

    The ``name`` keyword is still forwarded to the factory unchanged. Afterwards
    the resulting node's friendly name is set to it (when it is not ``None``).
    """
    @wraps(node_factory_function)
    def wrapper(*args: Any, **kwargs: Any) -> Node:
        produced = node_factory_function(*args, **kwargs)
        return _set_node_friendly_name(produced, name=_get_name(**kwargs))

    return wrapper
def unary_op(node_factory_function: Callable) -> Callable:
    """Convert the first input value to a Constant Node if a numeric value is detected.

    The first positional argument is passed through ``as_node`` (with the same
    ``name`` keyword) before the factory is called. The factory's result then
    gets ``name`` as its friendly name.
    """
    @wraps(node_factory_function)
    def wrapper(input_value: NodeInput, *args: Any, **kwargs: Any) -> Node:
        # _get_name only reads kwargs, so one lookup serves both uses below.
        requested_name = _get_name(**kwargs)
        converted_input = as_node(input_value, name=requested_name)
        result = node_factory_function(converted_input, *args, **kwargs)
        return _set_node_friendly_name(result, name=requested_name)

    return wrapper
def binary_op(node_factory_function: Callable) -> Callable:
    """Convert the first two input values to Constant Nodes if numeric values are detected.

    The two leading positional arguments are passed together through
    ``as_node``s (with the same ``name`` keyword) before the factory is called.
    The factory's result then gets ``name`` as its friendly name.
    """
    @wraps(node_factory_function)
    def wrapper(left: NodeInput, right: NodeInput, *args: Any, **kwargs: Any) -> Node:
        # _get_name only reads kwargs, so one lookup serves both uses below.
        requested_name = _get_name(**kwargs)
        left_node, right_node = as_nodes(left, right, name=requested_name)
        result = node_factory_function(left_node, right_node, *args, **kwargs)
        return _set_node_friendly_name(result, name=requested_name)

    return wrapper
def custom_preprocess_function(custom_function: Callable) -> Callable:
    """Wrap ``custom_function`` so the Node it returns is converted to an Output.

    The conversion is done with ``Output._from_node`` on whatever
    ``custom_function`` returns for the given node.
    """
    @wraps(custom_function)
    def wrapper(node: Node) -> Output:
        produced_node = custom_function(node)
        return Output._from_node(produced_node)

    return wrapper
class MultiMethod(object):
    """Type-based dispatcher holding every overload registered under one function name.

    Overloads are stored in ``typemap``, keyed by a tuple of expected positional
    argument types. Calling the instance selects the first overload (in
    registration order) that matches the invocation and forwards the call to it.
    """

    def __init__(self, name: str):
        self.name = name
        # Expected positional types -> overload implementation. Dict order is
        # registration order, which decides precedence when several overloads match.
        self.typemap: dict[tuple, Callable] = {}

    @staticmethod
    def _is_union(origin: Any) -> bool:
        """Return True if ``origin`` (a ``get_origin`` result) denotes a union type.

        Covers ``typing.Union`` and ``typing.Optional``. On Python 3.10+ it also
        covers PEP 604 unions written as ``X | Y``, whose origin is ``types.UnionType``.
        """
        import types

        pep604_union = getattr(types, "UnionType", None)
        return origin is Union or (pep604_union is not None and origin is pep604_union)

    @staticmethod
    def _matches_single(expected_type: Any, actual_type: type) -> bool:
        """Check ``actual_type`` against one non-union annotation.

        Parameterized generics such as ``list[int]`` are compared through their
        origin (``list``), because ``issubclass`` rejects parameterized generics.
        Annotations that are still not classes after this reduction never match.
        """
        origin = get_origin(expected_type)
        if origin is not None:
            expected_type = origin
        return isinstance(expected_type, type) and issubclass(actual_type, expected_type)

    def matches_union(self, union_type, actual_type) -> bool:  # type: ignore
        """Return True if ``actual_type`` is a subclass of any member of ``union_type``.

        A ``NoneType`` member (as in ``Optional[...]``) lets ``None`` arguments match.
        """
        return any(self._matches_single(member, actual_type) for member in get_args(union_type))

    def matches_optional(self, optional_type, actual_type) -> bool:  # type: ignore
        """Return True if ``actual_type`` is ``None`` or matches a member of ``optional_type``."""
        return actual_type is None or self.matches_union(optional_type, actual_type)

    def check_invoked_types_in_overloaded_funcs(self, tuple_to_check: tuple, key_structure: tuple) -> bool:
        """Return True if every invoked argument type fits the corresponding expected type.

        :param tuple_to_check: Classes of the positional arguments of the call.
        :param key_structure: Expected types, either an overload key or the
                              overload's parameter annotations.
        """
        # NOTE: zip() stops at the shorter sequence, so the argument count is not
        # compared here; only the overlapping positions are checked.
        for actual_type, expected_type in zip(tuple_to_check, key_structure):
            # get_origin(Optional[X]) is Union, so Optional goes through the union path.
            if self._is_union(get_origin(expected_type)):
                matched = self.matches_union(expected_type, actual_type)
            else:
                matched = self._matches_single(expected_type, actual_type)
            if not matched:
                return False
        return True

    @staticmethod
    def _parameter_annotations(func: Callable) -> dict:
        """Map each parameter name of ``func`` to its annotation (``Parameter.empty`` if none)."""
        return {arg_name: param.annotation for arg_name, param in signature(func).parameters.items()}

    def __call__(self, *args, **kwargs) -> Any:  # type: ignore
        """Dispatch to the first registered overload matching the call.

        :raises TypeError: If no registered overload matches the invocation.
        """
        arg_types = tuple(arg.__class__ for arg in args)
        key_matched = None
        if args and not kwargs:
            # Positional-only call: compare argument classes with each registered key.
            for key in self.typemap:
                if self.check_invoked_types_in_overloaded_funcs(arg_types, key):
                    key_matched = key
                    break
        elif kwargs and not args:
            # Keyword-only call: pick the first overload whose parameters include every keyword.
            for key, func in self.typemap.items():
                if kwargs.keys() <= self._parameter_annotations(func).keys():
                    key_matched = key
                    break
        elif args and kwargs:
            # Mixed call: positional classes must fit the overload's annotations,
            # and the keywords must be a subset of its parameter names.
            for key, func in self.typemap.items():
                annotations = self._parameter_annotations(func)
                if self.check_invoked_types_in_overloaded_funcs(arg_types, tuple(annotations.values())) and kwargs.keys() <= annotations.keys():
                    key_matched = key
                    break
        if key_matched is None:
            raise TypeError(f"The necessary overload for {self.name} was not found")
        return self.typemap[key_matched](*args, **kwargs)  # type: ignore

    def register(self, types: tuple, function: Callable) -> None:
        """Register ``function`` as the overload for the positional types ``types``.

        :raises TypeError: If an overload is already registered for ``types``.
        """
        if types in self.typemap:
            raise TypeError("duplicate registration")
        self.typemap[types] = function
# Module-wide table of dispatchers, keyed by the bare ``__name__`` of the
# decorated overloads.
# NOTE(review): same-named functions from different modules share one
# dispatcher under this keying; confirm that is intended.
registry: dict[str, MultiMethod] = {}


def overloading(*types: tuple) -> Callable:
    """Build a decorator that registers an overload for the given positional types.

    All functions decorated under the same ``__name__`` are collected into one
    ``MultiMethod`` held in ``registry``. The decorated name is rebound to that
    ``MultiMethod``.

    :param types: Expected types of the leading positional arguments.
    :return: A decorator returning the shared ``MultiMethod``.
    """
    def decorator(function: Callable) -> MultiMethod:
        func_name = function.__name__
        dispatcher = registry.get(func_name)
        if dispatcher is None:
            dispatcher = MultiMethod(func_name)
            registry[func_name] = dispatcher
        dispatcher.register(types, function)
        return dispatcher

    return decorator