Files

650 lines
29 KiB
Python

# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# mypy: ignore-errors
import inspect
import logging
import typing
import torch
from openvino.frontend.pytorch.py_pytorch_frontend import (
_FrontEndPytorchDecoder as Decoder,
_Type as DecoderType
)
from openvino import op, PartialShape, Type as OVType, OVAny
from openvino.frontend.pytorch.utils import (
ivalue_to_constant,
get_value_from_getattr,
pt_to_ov_type_map,
prepare_example_inputs_and_model,
convert_quantized_tensor,
graph_has_ops,
patch_none_example,
)
from openvino import opset11 as ops
from openvino.frontend.pytorch import quantized, patch_model
from openvino.frontend.pytorch.module_extension import ModuleExtension
from openvino.frontend.pytorch.patch_functions import FunctionsPatcher
# Module-level logger; the decoder below reports debug/warning/error messages through it.
log = logging.getLogger(__name__)
# Marker type for a special kind of conversion extension whose converter is
# inlined in the Trampoline class.
class InlineConversionExtension:
    """Marker for conversion extensions that are inlined in the Trampoline class.

    Carries no state and no behavior; code in this file only checks whether an
    extension object is an instance of this type.
    """
class TorchScriptPythonDecoder(Decoder):
    """Decoder exposing a TorchScript graph (or one of its elements) to the OpenVINO PyTorch frontend.

    A root decoder is built from a user model. The model is traced (when
    ``example_input`` is given) or scripted, optionally frozen, and its
    inlined graph becomes ``graph_element``. Child decoders wrap single
    ``torch.Node`` objects (``visit_subgraph``) or nested blocks/subgraphs
    (``get_subgraph_decoder``) and reuse the parent's ``pt_module`` and
    ``alias_db``.
    """

    def __init__(
        self,
        pt_module,
        graph_element=None,
        example_input=None,
        alias_db=None,
        shared_memory=True,
        skip_freeze=False,
        constant_cache=None,
        module_extensions=None,
        trace_kwargs=None,
    ):
        """Create a decoder for a whole model or for a single graph element.

        :param pt_module: user model when ``graph_element`` is None; otherwise
            the module used to resolve ``prim::GetAttr`` values.
        :param graph_element: graph/node/block to decode; None means the graph
            is obtained from ``pt_module`` by tracing or scripting.
        :param example_input: inputs for tracing; None selects scripting.
        :param alias_db: alias database of the enclosing graph (child decoders).
        :param shared_memory: forwarded to constant creation.
        :param skip_freeze: never call ``torch.jit.freeze`` when True.
        :param constant_cache: dict of already created constants keyed by
            attribute name; a new dict is created when None.
        :param module_extensions: ModuleExtension configuration used to patch
            modules before tracing.
        :param trace_kwargs: extra keyword arguments for ``torch.jit.trace``.
        :raises RuntimeError: if no TorchScript module could be obtained.
        """
        super().__init__()
        # We store every decoder created by this decoder so that all of them
        # are not deleted until the first decoder is deleted.
        self.m_decoders = []
        self._input_signature = None
        self._shared_memory = shared_memory
        self._input_is_list = False
        self.constant_cache = constant_cache if constant_cache is not None else dict()  # noqa: C408
        self.module_extensions = module_extensions
        self.config = None
        # output index -> debug name overriding the TorchScript debug name
        self.out_debug_name_overwrites = {}
        # output types known up front (set by try_decode_get_attr)
        self.cached_out_types = []
        if graph_element is None:
            if hasattr(pt_module, "config"):
                if isinstance(pt_module.config, dict):
                    self.config = pt_module.config
                elif hasattr(pt_module.config, "to_dict"):
                    self.config = pt_module.config.to_dict()
            try:
                pt_module = self._get_scripted_model(
                    pt_module, example_input, skip_freeze, trace_kwargs)
            except Exception as e:
                if example_input is not None:
                    msg = "tracing"
                    help_msg = ("Please check correctness of provided 'example_input'. "
                                "Sometimes models can be converted in scripted mode, please try running "
                                "conversion without 'example_input'.\n")
                else:
                    msg = "scripting"
                    help_msg = ("Tracing sometimes provide better results, "
                                "please provide valid 'example_input' argument.\n")
                raise RuntimeError(
                    f"Couldn't get TorchScript module by {msg}.\nException:\n{e}\n"
                    f"{help_msg} You can also provide TorchScript module that you obtained"
                    " yourself, please refer to PyTorch documentation: "
                    "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html."
                ) from e
            self.graph_element = pt_module.inlined_graph
            self.alias_db = self.graph_element.alias_db()
        else:
            self.graph_element = graph_element
            self.alias_db = alias_db
        self.pt_module = pt_module
        self.raw_inputs = list(self.graph_element.inputs())
        self.raw_outputs = list(self.graph_element.outputs())
        if self._input_signature is not None:
            # The graph may start with a "self" input that forward()'s
            # signature list does not contain.
            if self.raw_inputs and "self" in self.raw_inputs[0].debugName():
                self._input_signature.insert(0, "self")
            if 0 < len(self._input_signature) < len(self.raw_inputs):
                # The last signature entry presumably stands for a variadic
                # (*args) parameter: drop it and name the remaining graph
                # inputs by their debug names instead.
                self._input_signature = self._input_signature[:-1]
                s_len = len(self._input_signature)
                self._input_signature.extend(
                    value.debugName() for value in self.raw_inputs[s_len:])
        if isinstance(self.graph_element, torch.Graph):
            self._transform_tensor_list_constants_to_listconstruct(
                self.graph_element)
            self._transform_optional_constants(self.graph_element)
            log.debug("Inlined graph:\n%s", self.graph_element)

    @staticmethod
    def _get_preserved_attributes(model) -> list:
        """Return names of submodules whose ``weight`` has a compressed dtype.

        The result is passed to ``torch.jit.freeze`` as ``preserved_attrs``.
        """
        compressed_types = (torch.int8, torch.uint8,
                            torch.float16, torch.bfloat16)
        return [
            name for name, module in model.named_modules()
            if hasattr(module, "weight")
            and getattr(module.weight, "dtype", None) in compressed_types
        ]

    def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False, trace_kwargs=None):
        """Obtain a (possibly frozen) TorchScript module for ``pt_module``.

        A plain ``torch.nn.Module`` is scripted when ``example_inputs`` is None
        and traced otherwise. A module that is already traced or scripted is
        returned as is. Sets ``self._example_input`` and
        ``self._input_signature`` as side effects.

        :raises RuntimeError: when module extensions are requested without
            example inputs (scripting cannot apply them).
        """
        freeze_by_default = False
        if isinstance(pt_module, torch.nn.Module):
            pt_module.eval()
        input_signature = None
        input_parameters = None
        if isinstance(pt_module, torch.nn.Module) and not isinstance(
            pt_module, (torch.jit._trace.TopLevelTracedModule,
                        torch.jit._script.RecursiveScriptModule)
        ):
            # input params is dictionary contains input names and their
            # signature values (type hints and default values if any)
            input_params = inspect.signature(pt_module.forward if hasattr(
                pt_module, "forward") else pt_module.__call__).parameters
            input_signature = list(input_params)
            if example_inputs is None:
                if self.module_extensions:
                    raise RuntimeError(
                        "ModuleExtension is not supported for scripting. "
                        "Please provide valid example_input argument to run tracing.")
                scripted = torch.jit.script(pt_module)
                freeze_by_default = True
            else:
                pt_module, example_inputs = patch_none_example(pt_module, example_inputs)
                input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(
                    example_inputs, input_params, pt_module)
                # name of attribute in a patched module where the
                # original forward method is kept
                orig_forward_name = "_openvino_module_extension_patch_orig_forward"
                if self.module_extensions:
                    patch_model.patch_model(
                        pt_module, self.module_extensions, orig_forward_name)
                patched = False
                if quantized.detect_quantized_model(pt_module) is not None:
                    try:
                        quantized.patch_quantized(pt_module)
                        patched = True
                    except Exception as error:
                        log.warning(
                            "Failed patching of AutoGPTQ model. Error message:\n"
                            "Tracing of the model will likely be unsuccessful or incorrect",
                            exc_info=error)
                        quantized.unpatch_quantized(pt_module)
                        patched = False
                if trace_kwargs is None:
                    trace_kwargs = {}
                try:
                    with FunctionsPatcher():
                        scripted = torch.jit.trace(
                            pt_module, **input_parameters, strict=False, **trace_kwargs)
                finally:
                    # Undo the patches in reverse order of application so the
                    # user's module does not keep patched forward methods.
                    if patched:
                        quantized.unpatch_quantized(pt_module)
                    if self.module_extensions:
                        patch_model.unpatch_model(pt_module, orig_forward_name)
            have_to_freeze_ops = ["prim::Uninitialized",
                                  "prim::unchecked_cast", "aten::append"]
            if not freeze_by_default and graph_has_ops(scripted.inlined_graph, have_to_freeze_ops):
                # freeze models with unsupported ops
                freeze_by_default = True
            quantized_hint_ops = ["quantized", "aten::as_strided"]
            if freeze_by_default and graph_has_ops(scripted.inlined_graph, quantized_hint_ops):
                # do not freeze quantized models and can't freeze for
                # aten::as_strided it will result in incorrect inference
                freeze_by_default = False
            if freeze_by_default and not skip_freeze:
                preserved_attrs = self._get_preserved_attributes(scripted)
                f_model = torch.jit.freeze(
                    scripted, preserved_attrs=preserved_attrs)
            else:
                f_model = scripted
            self._example_input = input_parameters["example_inputs"] if input_parameters else None
        else:
            f_model = pt_module
            self._example_input = example_inputs
        self._input_signature = input_signature
        return f_model

    def inputs(self) -> list:
        """Return unique ids of the graph element inputs."""
        return [x.unique() for x in self.raw_inputs]

    def get_input_debug_name(self, index: int) -> str:
        """Return the TorchScript debug name of input ``index``."""
        return self._raw_input(index).debugName()

    def get_input_signature_name(self, index: int) -> str:
        """Return the forward() signature name of input ``index``, or its debug name if unknown."""
        if self._input_signature is not None and index < len(self._input_signature):
            return self._input_signature[index]
        return self.get_input_debug_name(index)

    def get_input_shape(self, index: int):
        """Return the PartialShape of input ``index``."""
        raw_input = self._raw_input(index)
        return self.get_shape_for_value(raw_input)

    def get_input_strides(self, index: int) -> list[int]:
        """Return strides recorded on a tensor-typed input, or an empty list."""
        raw_input = self._raw_input(index)
        if isinstance(raw_input, torch.Value):
            inp_type = raw_input.type()
            if isinstance(inp_type, torch.TensorType):
                strides = inp_type.strides()
                if strides:
                    return strides
        return []

    def get_input_type(self, index: int):
        """Return the decoded type (OVAny) of input ``index``."""
        raw_input = self._raw_input(index)
        return self.get_type_for_value(raw_input)

    def get_output_debug_name(self, index: int) -> str:
        """Return the debug name of output ``index``, honoring name overrides."""
        if index in self.out_debug_name_overwrites:
            return self.out_debug_name_overwrites[index]
        return self._raw_output(index).debugName()

    def get_output_shape(self, index: int):
        """Return the PartialShape of output ``index``."""
        output = self._raw_output(index)
        return self.get_shape_for_value(output)

    def get_output_type(self, index: int):
        """Return the decoded type of output ``index``, preferring cached types."""
        if index < len(self.cached_out_types):
            return self.cached_out_types[index]
        output = self._raw_output(index)
        return self.get_type_for_value(output)

    def _get_known_type_for_value(self, pt_type):
        """Returns known/unknown types wrapped as OVAny."""
        # Check for simple scalar types first
        if pt_type is None:
            return OVAny(OVType.dynamic)
        # TODO: Don't use str, use native types
        if str(pt_type) in ["int", "float", "bool"]:
            return OVAny(DecoderType.PyScalar(OVAny(pt_to_ov_type_map[str(pt_type)])))
        elif isinstance(pt_type, torch.dtype) and pt_type.is_complex:
            return OVAny(DecoderType.Complex(self._get_known_type_for_value(pt_type.to_real())))
        elif str(pt_type) in pt_to_ov_type_map:
            return OVAny(pt_to_ov_type_map[str(pt_type)])
        elif isinstance(pt_type, torch.ComplexType):
            # Tensor type, parse element type
            return OVAny(DecoderType.Tensor(OVAny(DecoderType.Complex(OVAny(OVType.dynamic)))))
        elif isinstance(pt_type, torch.TensorType):
            # Tensor type, parse element type
            return OVAny(DecoderType.Tensor(self._get_known_type_for_value(pt_type.dtype())))
        elif isinstance(pt_type, torch.ListType):
            element_type = pt_type.getElementType()
            return OVAny(DecoderType.List(self._get_known_type_for_value(element_type)))
        elif isinstance(pt_type, (torch.StringType, torch.DeviceObjType)):
            return OVAny(DecoderType.Str())
        elif isinstance(pt_type, torch.NoneType):
            return OVAny(DecoderType.PyNone())
        else:
            # Not yet recognized
            return OVAny(OVType.dynamic)

    def get_shape_for_value(self, value: torch.Value):
        """Return a rank-only PartialShape for complete tensors, fully dynamic otherwise."""
        if value.isCompleteTensor():
            # We avoid static shapes, they don't generalize on other inputs
            return PartialShape([-1] * len(value.type().sizes()))
        # TODO: Recognize types that we can represent as a nested
        # constructs with objects from DecoderType If recognized,
        # return scalar instead of dynamic. Scalar means a single
        # value of that custom type. See get_type_for_value for reference
        return PartialShape.dynamic()

    def get_type_for_value(self, value: torch.Value):
        """Return the decoded type (OVAny) of a TorchScript value or Python object."""
        _type = value.type() if hasattr(value, "type") else type(value).__name__
        return self._get_known_type_for_value(_type)

    def get_subgraph_size(self) -> int:
        """Return the number of subgraphs of a node; 1 for non-node elements."""
        if isinstance(self.graph_element, torch.Node):
            return len(self.get_subgraphs())
        return 1

    def visit_subgraph(self, node_visitor) -> None:
        """Call ``node_visitor`` with a child decoder for every node of this graph element."""
        # make sure topological order is satisfied
        for node in self.graph_element.nodes():
            decoder = TorchScriptPythonDecoder(
                self.pt_module,
                node,
                alias_db=self.alias_db,
                shared_memory=self._shared_memory,
                constant_cache=self.constant_cache,
                module_extensions=self.module_extensions,
            )
            self.m_decoders.append(decoder)
            node_visitor(decoder)

    def decoder_type_name(self) -> str:
        """Return the decoder kind identifier ("ts")."""
        return "ts"

    def get_subgraphs(self) -> list:
        """Return nested graphs of this node.

        For ``prim::PythonOp``/``prim::fork`` the inlined ``Subgraph``
        attribute is returned when present. Otherwise the node's blocks are
        returned.
        """
        if self.graph_element.kind() in ["prim::PythonOp", "prim::fork"]:
            if "Subgraph" in self.graph_element.attributeNames():
                assert isinstance(
                    self.graph_element, torch.Node), "Graph element must be of type torch.Node."
                subgraph = getattr(self.graph_element, self.graph_element.kindOf("Subgraph"))("Subgraph")
                torch._C._jit_pass_inline(subgraph)
                return [subgraph]
            # Attribute "Subgraph" is only available if Graph was
            # created using tracing.
            # TODO Find way to extract subgraph for scripted Graph.
            return []
        return list(self.graph_element.blocks())

    def get_subgraph_decoder(self, index: int):
        """Create (and keep alive) a decoder for subgraph ``index`` of this node."""
        module = self.pt_module
        if self.graph_element.kind() == "prim::fork":
            # A forked call operates on the submodule fetched by its first input.
            in0 = self.raw_inputs[0]
            if in0.node().kind() == "prim::GetAttr":
                module, _ = get_value_from_getattr(in0.node(), self.pt_module)
        # NOTE(review): constant_cache is not forwarded, so the subgraph decoder
        # starts with an empty cache — presumably intentional; confirm.
        decoder = TorchScriptPythonDecoder(module,
                                           self.get_subgraphs()[index],
                                           alias_db=self.alias_db,
                                           shared_memory=self._shared_memory,
                                           module_extensions=self.module_extensions
                                           )
        self.m_decoders.append(decoder)
        return decoder

    def get_op_extension(self):
        """Return ``(trampoline, target_extension)`` for a patched PythonOp node, else None."""
        assert isinstance(
            self.graph_element, torch.Node), "Function can be called only when self.graph_element is of type torch.Node"
        if self.graph_element.kind() == "prim::PythonOp" and callable(getattr(self.graph_element, "pyobj", None)):
            pyobj = self.graph_element.pyobj()
            trampoline = getattr(pyobj, "__self__", None)
            return trampoline, getattr(trampoline, "target_extension", None)
        return None

    def get_op_type(self) -> str:
        """Return the operation type of this node.

        For a node produced by a ModuleExtension trampoline this is the
        extension's ``target_op`` (a string, or a callable applied to the
        original module). Otherwise it is the TorchScript node kind.

        :raises TypeError: if ``target_op`` is neither callable nor a string.
        """
        if op_extension := self.get_op_extension():
            trampoline, target_extension = op_extension
            if isinstance(target_extension, ModuleExtension):
                target_op = target_extension.target_op
                # TODO: Support target as a callable that will play a role of
                # ConversionExtension for an entire module instead of a single
                # op. Without supporting target as a callable here,
                # ConversionExtension functionality is still possible to
                # implement by combining two extensions: ModuleExtension that
                # use temporary name as a target op and another extension of
                # type ConversionExtension that translates that particular
                # temporary name to custom graph. But providing conversion code
                # as a callable `target` is more convenient.
                if callable(target_op):
                    return target_op(trampoline.original_module)
                if isinstance(target_op, str):
                    return target_op
                raise TypeError(
                    "ModuleExtension.target_op must be a callable or a string, "
                    f"got {type(target_op).__name__}")
        return self.graph_element.kind()

    def get_schema(self) -> str:
        """Return the schema of the wrapped node."""
        return self.graph_element.schema()

    def outputs(self) -> list:
        """Return unique ids of the graph element outputs."""
        return [x.unique() for x in self.raw_outputs]

    def _raw_output(self, index: int):
        return self.raw_outputs[index]

    def _raw_input(self, index: int):
        return self.raw_inputs[index]

    def num_of_outputs(self):
        """Return the number of outputs of the graph element."""
        return len(self.raw_outputs)

    def output(self, index: int):
        """Return the unique id of output ``index``."""
        return self.outputs()[index]

    def mark_node(self, node):
        """Give an OpenVINO ``node`` a friendly name derived from op type and scope; return it."""
        name = self.get_op_type()
        if "FrameworkNode" not in node.get_type_name():
            name += "/" + node.get_type_name()
        scope_name = self.graph_element.scopeName()
        if scope_name:
            name = scope_name.split("/")[-1] + "/" + name
        node.set_friendly_name(name)
        return node

    def _add_name_to_const_and_cache(self, outputs, name, dtype=None):
        """Store ``(outputs, dtype)`` in the constant cache under ``name``.

        A single-output constant is also renamed to ``name`` and output 0 of
        this decoder is given that debug name.
        """
        if len(outputs) == 1:
            # set name corresponding to state_dict name
            outputs[0].get_node().set_friendly_name(name)
            self.out_debug_name_overwrites[0] = name
        self.constant_cache[name] = (outputs, dtype)

    def try_decode_get_attr(self):
        """Convert the value read by this ``prim::GetAttr`` node into OpenVINO outputs.

        Packed quantized params produce weight, bias (a zero constant converted
        like the weight when absent) and, if available, conv stride/padding/
        dilation/groups. Other non-module values produce a constant. Submodules
        produce an empty list. Constants are reused through ``constant_cache``.
        """
        pt_value, name = get_value_from_getattr(
            self.graph_element, self.pt_module)
        assert pt_value is not None, "Couldn't retrieve value from prim::GetAttr"
        if isinstance(pt_value, torch.ScriptObject):
            # We assume this is __torch__.torch.classes.quantized.Conv2dPackedParamsBase or __torch__.torch.classes.quantized.LinearPackedParamsBase
            # TODO: but can be anything. Figure a better way to distinguish
            weight, bias = pt_value.unpack()
            w_name = name + ".weight"
            if w_name in self.constant_cache:
                weight_outputs = self.constant_cache[w_name][0]
            else:
                weight_outputs = convert_quantized_tensor(weight, self._shared_memory)
                self._add_name_to_const_and_cache(weight_outputs, w_name)
            # Copy before extending: the list held by constant_cache must keep
            # only the weight outputs, otherwise later lookups of w_name would
            # also return the bias and conv params appended below.
            res = list(weight_outputs)
            if isinstance(bias, torch.Tensor):
                b_name = name + ".bias"
                if b_name in self.constant_cache:
                    res += self.constant_cache[b_name][0]
                else:
                    b_res = ivalue_to_constant(bias, shared_memory=self._shared_memory)
                    self._add_name_to_const_and_cache(b_res, b_name)
                    res += b_res
            else:
                res += ops.convert_like(ivalue_to_constant(torch.zeros(1))
                                        [0], res[0]).outputs()
            try:
                # these params exist only for conv params
                stride = pt_value.stride()
                padding = pt_value.padding()
                dilation = pt_value.dilation()
                groups = pt_value.groups()
                res += ivalue_to_constant(stride,
                                          shared_memory=self._shared_memory)
                res += ivalue_to_constant(padding,
                                          shared_memory=self._shared_memory)
                res += ivalue_to_constant(dilation,
                                          shared_memory=self._shared_memory)
                res += ivalue_to_constant(groups,
                                          shared_memory=self._shared_memory)
            except Exception as e:
                log.debug("Failed to get conv params", exc_info=e)
            return res
        elif not isinstance(pt_value, (torch.jit.ScriptModule, torch.jit.TracedModule)):
            # this tensor can be used multiple times in the model, so we have to reuse constants
            if name in self.constant_cache:
                const, dtype = self.constant_cache[name]
            else:
                dtype = self.get_type_for_value(pt_value)
                if hasattr(pt_value, "dtype") and pt_value.dtype.is_complex:
                    pt_value = torch.view_as_real(pt_value)
                const = ivalue_to_constant(
                    pt_value, shared_memory=self._shared_memory)
                self._add_name_to_const_and_cache(const, name, dtype)
            if dtype is not None:
                self.cached_out_types = [dtype]
            return const
        else:
            return []

    def as_constant(self):
        """Return outputs of an OpenVINO constant for a ``prim::Constant`` node, else None."""
        if not isinstance(self.graph_element, torch.Node):
            return None
        if self.get_op_type() != "prim::Constant":
            return None
        pt_value = self._raw_output(0)
        pt_type = pt_value.type()
        if isinstance(pt_type, torch.TensorType):
            return ivalue_to_constant(pt_value.toIValue(), shared_memory=self._shared_memory)
        if isinstance(pt_type, torch.ListType):
            return self._as_constant_list(pt_value)
        if isinstance(pt_type, torch._C.Type) and pt_type.annotation_str == "Generator":
            # A generator constant is represented by its initial seed.
            gen = pt_value.toIValue()
            return ivalue_to_constant(gen.initial_seed(), shared_memory=self._shared_memory)
        const = ivalue_to_constant(
            pt_value.toIValue(), shared_memory=self._shared_memory)
        if len(const) > 0:
            # set name corresponding to state_dict name
            const[0].get_node().set_friendly_name(
                self.get_output_debug_name(0))
        return const

    def as_string(self):
        """Return the string for a string/device constant or ``prim::device`` node, else None."""
        if self.get_op_type() == "prim::Constant":
            pt_value = self._raw_output(0)
            if str(pt_value.type()) in ["torch.StringType", "str"]:
                return pt_value.toIValue()
            elif str(pt_value.type()) == "Device":
                return pt_value.toIValue().type
        elif self.get_op_type() == "prim::device":
            return self._get_device_string()
        return None

    @staticmethod
    def _as_constant_list(pt_value: torch.Value):
        """Return a 1D constant for a list constant of a known element type, else []."""
        # For now we treat a list as a 1D tensor; it is required by converters to avoid
        # need to massively rewrite them in that part where constant attributes are queried
        pt_element_type = str(pt_value.type().getElementType())
        ivalue = pt_value.toIValue()
        is_known_type = pt_element_type in pt_to_ov_type_map
        if is_known_type:
            ovtype = pt_to_ov_type_map[pt_element_type]
            ovshape = PartialShape([len(ivalue)])
            ov_const = op.Constant(ovtype, ovshape.get_shape(), ivalue)
            return ov_const.outputs()
        return []

    def _get_device_string(self) -> str:
        """Return the device of a ``prim::device`` input tensor, or "cpu" if unknown."""
        assert self.graph_element.kind(
        ) == "prim::device", "This function can be called for prim::device node."
        value = self.raw_inputs[0]
        if value.type().isSubtypeOf(torch.TensorType.get()):
            tensor = typing.cast(torch.TensorType, value.type())
            device = tensor.device()
            if device:
                return str(device)
        # Device cannot be statically determined.
        return "cpu"

    def input_is_none(self, index: int) -> bool:
        """Return True if input ``index`` is missing, of NoneType, or a GetAttr resolving to None."""
        if index >= len(self.raw_inputs) or self._raw_input(index) is None:
            return True
        r_input = self._raw_input(index)
        if str(r_input.type()) in ["torch.NoneType", "NoneType"]:
            return True
        in_node = r_input.node()
        if in_node.kind() == "prim::GetAttr":
            pt_value, _ = get_value_from_getattr(
                in_node, self.pt_module)
            return pt_value is None
        return False

    def may_produce_alias(self, in_index: int, out_index: int) -> bool:
        """Return whether output ``out_index`` may alias input ``in_index``."""
        if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::_convolution", "aten::matmul", "aten::clone"]:
            # AliasDB::may_contain_alias sometimes return True for tensors produced by convolution or matmul, we have to workaround that
            return False
        try:
            return self.alias_db.may_contain_alias(self._raw_input(in_index), self._raw_output(out_index))
        except Exception as e:
            # Sometimes pytorch fails to get result with IndexError exception while these indexes exist in node
            log.debug("Failed to get alias information", exc_info=e)
            return False

    def is_input_inlined(self, index):
        """TorchScript inputs are never inlined."""
        return False

    def get_inlined_input_decoder(self, index):
        """TorchScript inputs are never inlined, so there is no decoder."""
        return None

    def get_attribute(self, name):
        """Node attributes are not exposed by this decoder; returns OVAny(None)."""
        return OVAny(None)

    def get_named_input(self, name):
        """Named inputs are not supported for TorchScript graphs."""
        raise RuntimeError("There is no named inputs in TS graph")

    def get_rt_info(self):
        """Return runtime info; currently only ``symmetric_quantization`` from the model config."""
        rt_info = {}
        if self.config is not None and "quantization_config" in self.config and "sym" in self.config["quantization_config"]:
            rt_info["symmetric_quantization"] = OVAny(
                self.config["quantization_config"]["sym"])
        return rt_info

    @staticmethod
    def _transform_tensor_list_constants_to_listconstruct(graph: torch.Graph):
        """Replace list-of-tensor constants with ``prim::ListConstruct`` of tensor constants."""
        # Function replaces prim::Constant containing list of Tensors with
        # prim::ListConstruct containing prim::Constant Tensors.
        assert isinstance(
            graph, torch.Graph), "Function can be called only with parameters of type torch.Graph."
        for node in graph.nodes():
            if node.kind() != "prim::Constant":
                continue
            output_type = node.output().type()
            allowed_types = [
                output_type.isSubtypeOf(torch.ListType.ofTensors()),
                output_type.isSubtypeOf(torch.ListType(
                    torch.OptionalType.ofTensor())),
            ]
            if not any(allowed_types):
                continue
            const_inputs = []
            for value in node.output().toIValue():
                const_input = graph.insertConstant(value)
                const_input.node().moveBefore(node)
                const_input.node().copyMetadata(node)
                const_inputs.append(const_input)
            replacement = graph.create("prim::ListConstruct", const_inputs)
            replacement.insertBefore(node)
            replacement.output().setType(torch.ListType.ofTensors())
            replacement.copyMetadata(node)
            node.output().replaceAllUsesWith(replacement.output())

    @staticmethod
    def _transform_optional_constants(graph: torch.Graph):
        """Replace Optional-typed constants with constants typed by their actual value."""
        # Function replaces prim::Constant containing torch.OptionalType with
        # prim::Constant containing torch.NoneType or type of IValue.
        assert isinstance(
            graph, torch.Graph), "Function can be called only with parameters of type torch.Graph."
        for node in graph.nodes():
            if node.kind() != "prim::Constant":
                continue
            output_type = node.output().type()
            if not isinstance(output_type, torch.OptionalType):
                continue
            value = node.output().toIValue()
            const_input = graph.insertConstant(value)
            const_input.node().moveBefore(node)
            const_input.node().copyMetadata(node)
            node.output().replaceAllUsesWith(const_input)

    def has_converter(self):
        """Return True if this node carries an inline conversion extension."""
        if op_extension := self.get_op_extension():
            _, target_extension = op_extension
            return isinstance(target_extension, InlineConversionExtension)
        return False

    def convert(self, node_context):
        """Run the inline custom converter attached to this node.

        :raises AssertionError: if the node has no custom converter.
        """
        if op_extension := self.get_op_extension():
            trampoline, target_extension = op_extension
            assert isinstance(target_extension, InlineConversionExtension)
            try:
                return trampoline.convert(node_context)
            except Exception as e:
                log.error("Exception happened during calling of custom "
                          "converter for PyTorch operation. PyTorch Script "
                          "code: %s", self.graph_element, exc_info=e)
                raise
        raise AssertionError("PyTorch FrontEnd Internal Error: `converter` "
                             "method of TorchScriptPythonDecoder is called "
                             "for node that has no custom converter")