Files
ANSLibs/OpenVINO/python/openvino/utils/node_factory.py

135 lines
5.1 KiB
Python

# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from functools import singledispatchmethod
from typing import Any, Optional, Union
from pathlib import Path
from openvino._pyopenvino import NodeFactory as _NodeFactory
from openvino import Node, Output, Extension
from openvino.exceptions import UserInputError
# Opset name used by NodeFactory.__init__ (and therefore by _get_node_factory)
# when no opset_version is supplied by the caller.
DEFAULT_OPSET = "opset13"
class NodeFactory:
    """Factory front-end to create node objects.

    Thin Python wrapper around the native ``_NodeFactory`` that normalizes
    operator arguments (``Node``/``Output`` objects) and attribute dicts
    before delegating op construction to the native factory.
    """

    def __init__(self, opset_version: str = DEFAULT_OPSET) -> None:
        """Create the NodeFactory object.

        :param opset_version: The opset version the factory will use to produce ops from.
        """
        self.factory = _NodeFactory(opset_version)

    def create(
        self,
        op_type_name: str,
        arguments: Optional[list[Union[Node, Output]]] = None,
        attributes: Optional[dict[str, Any]] = None,
    ) -> Node:
        """Create node object from provided description.

        The user does not have to provide all node's attributes, but only required ones.

        :param op_type_name: The operator type name.
        :param arguments: The operator arguments. ``Output`` entries are used
            as-is; any other entry contributes all of its ``outputs()``.
        :param attributes: The operator attributes. Treated as an empty dict
            when arguments are given but attributes are not.
        :raises UserInputError: If ``attributes`` are given without ``arguments``.
        :return: Node object representing requested operator with attributes set.
        """
        if arguments is None:
            if attributes is not None:
                raise UserInputError(f'Error: cannot create "{op_type_name}" op without arguments.')
            # Neither arguments nor attributes: build the op from its type name alone.
            return self.factory.create(op_type_name)

        outputs = self._arguments_as_outputs(arguments)
        return self.factory.create(op_type_name, outputs, {} if attributes is None else attributes)

    @singledispatchmethod
    def add_extension(self, extension: Union[Path, str, Extension, list[Extension]]) -> None:
        """Register custom operations with this factory.

        Dispatches on the type of ``extension``: a ``Path`` or ``str`` is treated
        as a path to an extension library; an ``Extension`` or a ``list`` of
        Extensions is forwarded to the native factory as extension object(s).

        :param extension: Library path, a single Extension, or a list of Extensions.
        :raises TypeError: If ``extension`` is of none of the supported types.
        """
        raise TypeError(f"Unknown argument type: {type(extension)}")

    @add_extension.register(Path)
    @add_extension.register(str)
    def _(self, lib_path: Union[Path, str]) -> None:
        """Add custom operations from an extension library.

        Extends the operation types available for creation with operations
        loaded from a prebuilt C++ library. This enables instantiation of
        custom operations exposed in that library without direct use of
        operation classes. Other types of extensions in the library, e.g.
        conversion extensions, are ignored.

        If an operation type from the extension matches one registered
        earlier (from the standard OpenVINO opset or from another extension
        loaded earlier), the new operation overrides the old one.

        The operation version is ignored: an operation with a given type and
        version/opset overrides an operation with the same type but a
        different version/opset in the same NodeFactory instance. Use separate
        libraries and NodeFactory instances to differentiate versions/opsets.

        :param lib_path: A path to the library with extension.
        """
        self.factory.add_extension(lib_path)

    @add_extension.register(Extension)
    @add_extension.register(list)
    def _(self, extension: Union[Extension, list[Extension]]) -> None:
        """Add custom operations from Extension object(s).

        Extends the operation types available for creation with the
        operations provided by the given Extension object(s). This enables
        instantiation of custom operations without direct use of operation
        classes. Other types of extensions, e.g. conversion extensions, are
        ignored.

        If an operation type from the extension matches one registered
        earlier (from the standard OpenVINO opset or from another extension
        loaded earlier), the new operation overrides the old one.

        The operation version is ignored: an operation with a given type and
        version/opset overrides an operation with the same type but a
        different version/opset in the same NodeFactory instance. Use separate
        extensions and NodeFactory instances to differentiate versions/opsets.

        :param extension: A single Extension or list of Extensions.
        """
        self.factory.add_extension(extension)

    @staticmethod
    def _arguments_as_outputs(arguments: list[Union[Node, Output]]) -> list[Output]:
        """Flatten ``arguments`` into a list of ``Output`` objects.

        ``Output`` entries are kept as-is. Any other entry (expected to be a
        ``Node``) contributes all of its ``outputs()``, in order.

        :param arguments: Mix of ``Node`` and ``Output`` objects.
        :return: Flat list of outputs suitable for the native factory.
        """
        outputs: list[Output] = []
        for argument in arguments:
            if isinstance(argument, Output):
                outputs.append(argument)
            else:
                outputs.extend(argument.outputs())
        return outputs
def _get_node_factory(opset_version: Optional[str] = None) -> NodeFactory:
    """Build a fresh NodeFactory for the requested opset.

    :param opset_version: Opset name to use. ``None`` or an empty string
        means the NodeFactory constructor's own default opset is used.
    :return: A new NodeFactory instance.
    """
    # Forward the argument only when it is truthy so the constructor default applies otherwise.
    return NodeFactory(opset_version) if opset_version else NodeFactory()