Files
ANSLibs/OpenVINO/python/openvino/frontend/pytorch/inlined_extension.py

262 lines
10 KiB
Python

# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Union, Optional
from collections.abc import Callable
import torch
import numpy as np
import openvino as ov
from openvino.frontend.pytorch.ts_decoder import InlineConversionExtension
from openvino.frontend.pytorch.utils import pt_to_ov_type_map, no_jit_trace
class ConstWrap:
"""Wraps a constant value for comparison and representation."""
def __init__(self, value: Any) -> None:
self.value = value
def __eq__(self, other: Any) -> bool:
return self.value == other
def __repr__(self) -> str:
return f"ConstWrap({str(self.value)})"
def unpack(
packed: Any,
types: Union[type, tuple[type, ...]],
index: int = 0
) -> tuple[tuple, Any, int]:
"""Unpacks nested structures into a flat tuple."""
unpacked_result: tuple = ()
packer_result: Any = None
if isinstance(packed, tuple):
packer_result = ()
for el in packed:
unpacked, packer, index = unpack(el, types, index)
unpacked_result += unpacked
packer_result += (packer,)
elif isinstance(packed, list):
packer_result = []
for el in packed:
unpacked, packer, index = unpack(el, types, index)
unpacked_result += unpacked
packer_result.append(packer)
elif isinstance(packed, dict):
packer_result = {}
for key, el in packed.items():
unpacked, packer, index = unpack(el, types, index)
unpacked_result += unpacked
packer_result[key] = packer
elif isinstance(packed, types):
unpacked_result = (packed,)
packer_result = index
index += 1
else:
packer_result = ConstWrap(packed)
return unpacked_result, packer_result, index
def pack(unpacked: tuple, packer: Any) -> Any:
"""Packs unpacked elements back into their original structure."""
if isinstance(packer, (tuple, list)):
return type(packer)(pack(unpacked, el) for el in packer)
if isinstance(packer, dict):
return {k: pack(unpacked, v) for k, v in packer.items()}
if isinstance(packer, ConstWrap):
return packer.value
return unpacked[packer]
global_counter_id = 0
def make_custom_op_class(func: Callable,
input_signature: tuple[tuple[Any, Any], ...],
output_signature: tuple[tuple[Any, Any], ...],
input_packer: Any,
output_packer: Any) -> type:
"""Creates a custom operation class from a function and signatures."""
global global_counter_id
class InlinedCustomOp(ov.Op): # type: ignore
class_type_info = ov.runtime.DiscreteTypeInfo(
"InlinedCustomOp", "extension")
def __init__(self, *args: Any) -> None:
# TODO: What about attributes?
super().__init__(self, args)
# `id` attribute distinguishes different instances of the same
# class, we need it because different instances may have different
# behavior
self.attrs = {"id": global_counter_id}
self.constructor_validate_and_infer_types()
def evaluate(self, outputs: Any, inputs: Any) -> bool:
# TODO: Check memory sharing
inputs_torch = tuple(torch.from_numpy(input.data)
for input in inputs)
args, kwargs = pack(inputs_torch, input_packer)
result = func(*args, **kwargs)
result, result_packer, _ = unpack(result, torch.Tensor)
assert result_packer == output_packer
for i, tensor in enumerate(result):
if isinstance(tensor, torch.Tensor):
ov_t = ov.Tensor(tensor.numpy(
force=True), shared_memory=True)
else:
ov_t = ov.Tensor(np.array(tensor), shared_memory=True)
# TODO: set the output tensor directly without copying
ov_t.copy_to(outputs[i])
return True
def has_evaluate(self, *args: Any) -> bool:
return True
def visit_attributes(self, visitor: Any) -> bool:
visitor.on_attributes(self.attrs)
return True
def validate_and_infer_types(self) -> None:
# TODO: Validate input signature
assert output_signature != (), (
"Operation does not produce any output. "
"It will be removed from the graph."
)
for i, output in enumerate(output_signature):
self.set_output_type(i, output[0], output[1])
global_counter_id += 1
return InlinedCustomOp
def make_signature(args: Any) -> tuple[tuple[Any, Any], ...]:
"""Generate a signature tuple for PyTorch tensors.
Maps PyTorch tensor types to OpenVINO types and partial shapes, setting
all dimensions dynamic preserving rank. Currently assumes that all input
arguments are torch.Tensor.
"""
# TODO: Extend beyond just tensors
return tuple(
(pt_to_ov_type_map.get(str(arg.dtype)), ov.PartialShape.dynamic(arg.ndim))
for arg in args
)
def make_input_signature(args: Any) -> tuple[tuple[Any, Any], ...]:
"""Generates an input signature for the provided arguments."""
return make_signature(args)
def make_output_signature(args: Any) -> tuple[tuple[Any, Any], ...]:
"""Generates an output signature for the provided arguments."""
if args is None:
args = ()
if not isinstance(args, tuple):
args = (args,)
return make_signature(args)
def make_trampoline_class(func: Callable,
op: Optional[type],
op_attrs: Any) -> type[torch.autograd.Function]:
"""Creates a trampoline class for a function."""
class Trampoline(torch.autograd.Function):
# this is a marker for this type of extension
target_extension = InlineConversionExtension()
# This function defines how the operation behaves when called as a
# part of PyTorch model code in eager execution or while jit.trace
@staticmethod
def forward(ctx: Any, *call_args: Any) -> Any:
cls = ctx._forward_cls
if not op:
input_signature = make_input_signature(call_args)
# TODO: Try to trace `func` with the hope to obtain traceable
# shapes to build more precise `validate_and_infer_types`
# automatically (unlikely possible)
packed_args, packed_kwargs = pack(
call_args, cls.input_packer)
with no_jit_trace():
result = func(*packed_args, **packed_kwargs)
result, cls.output_packer, _ = unpack(result, torch.Tensor)
if not op:
output_signature = make_output_signature(result)
cls.op = make_custom_op_class(func, input_signature,
output_signature,
cls.input_packer,
cls.output_packer)
else:
cls.op = op
return result
@classmethod
def unpack_inputs(cls, args: Any, kwargs: Any) -> tuple:
"""Unpacks inputs for the operation.
Unpack each element that is a tuple, a list, or a dict to a
tuple of their values and concatenate together. Build `packer`
function to pack the result back to the original nested data
types, save this `packer` to class member to use it later in
`forward`.
"""
unpacked, cls.input_packer, _ = unpack(
(args, kwargs), torch.Tensor)
return unpacked
@classmethod
def pack_outputs(cls, result: Any) -> Any:
"""Packs outputs for the operation."""
return pack(result, cls.output_packer)
@classmethod
def convert(cls, node_context: Any) -> tuple:
"""Defines how the operation is represented in graph."""
num_inps = node_context.get_input_size()
inputs = [node_context.get_input(i) for i in range(num_inps)]
node = cls.op(*inputs, **op_attrs)
# to trigger prim::TupleUnpack bypass in FrontEnd transformation
node.get_rt_info()["__torch_tuple_unpackable__"] = True
return node.outputs()
return Trampoline
def inlined_extension(*args: Any, **op_attrs: Any) -> Callable:
"""Creates an inlined extension for a function."""
def make_trampoline(func: Callable,
op: Optional[type] = None) -> Callable:
def trampoline(*args: Any, **kwargs: Any) -> Any:
# Keep trampoline class creation at the point when the function is
# called to make each time a new trampoline. It is required
# because `func` is fused inside Trampoline class and can have
# different behavior from call to call in PyTorch world even if
# the same op is specified to wrap multiple different functions.
# It happens due to capturing a global context and values kept in
# non-traceable part of inputs.
trampoline = make_trampoline_class(func, op, op_attrs)
# flattening inputs and unflattening outputs should be visible for
# PyTorch jit tracer, so it should be called outside of
# `trampoline.apply` to have clean Tensor-only signature for the
# custom op. Non-traceable part of inputs (and outputs) is captured
# as `trampoline` class fields.
args = trampoline.unpack_inputs(args, kwargs)
result = trampoline.apply(*args) # type: ignore
# pack to the expected nested data types here
return trampoline.pack_outputs(result)
return trampoline
if (len(args) == 1 and callable(args[0])
and not (isinstance(args[0], type)
and issubclass(args[0], ov.Op))):
func = args[0]
return make_trampoline(func)
op = args[0]
return lambda func: make_trampoline(func, op)