# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# mypy: ignore-errors

import inspect
import logging
import typing
import torch

from openvino.frontend.pytorch.py_pytorch_frontend import (
    _FrontEndPytorchDecoder as Decoder,
    _Type as DecoderType
)
from openvino import op, PartialShape, Type as OVType, OVAny
from openvino.frontend.pytorch.utils import (
    ivalue_to_constant,
    get_value_from_getattr,
    pt_to_ov_type_map,
    prepare_example_inputs_and_model,
    convert_quantized_tensor,
    graph_has_ops,
    patch_none_example,
)
from openvino import opset11 as ops
from openvino.frontend.pytorch import quantized, patch_model
from openvino.frontend.pytorch.module_extension import ModuleExtension
from openvino.frontend.pytorch.patch_functions import FunctionsPatcher

log = logging.getLogger(__name__)


# A marker for a special type of conversion extension that is inlined in Trampoline class
class InlineConversionExtension:
    pass


class TorchScriptPythonDecoder(Decoder):
    """Decoder that exposes a TorchScript graph, or a single node of it, to the frontend.

    A decoder created without ``graph_element`` first obtains a TorchScript
    module (by tracing when ``example_input`` is given, by scripting
    otherwise) and wraps its inlined graph. Decoders created with
    ``graph_element`` wrap that node/graph directly and are produced by
    :meth:`visit_subgraph` and :meth:`get_subgraph_decoder`.
    """

    def __init__(
        self,
        pt_module,
        graph_element=None,
        example_input=None,
        alias_db=None,
        shared_memory=True,
        skip_freeze=False,
        constant_cache=None,
        module_extensions=None,
        trace_kwargs=None,
    ):
        """Create a decoder for a whole model or for one graph element.

        Args:
            pt_module: PyTorch module (or already scripted/traced module) the
                graph belongs to.
            graph_element: ``torch.Graph``/``torch.Node`` to wrap. When
                ``None`` the graph is obtained from ``pt_module``.
            example_input: Example inputs; when given the model is traced,
                otherwise it is scripted.
            alias_db: Alias database used when ``graph_element`` is given.
            shared_memory: Forwarded to constant creation helpers.
            skip_freeze: When true the TorchScript module is never frozen.
            constant_cache: Mapping ``name -> (outputs, dtype)`` shared
                between decoders so constants are created once.
            module_extensions: Module extensions applied before tracing.
            trace_kwargs: Extra keyword arguments for ``torch.jit.trace``.

        Raises:
            RuntimeError: If a TorchScript module cannot be obtained.
        """
        super().__init__()
        # We store every decoder created by this decoder so that all them are
        # not deleted until the first decoder is deleted
        self.m_decoders = []
        self._input_signature = None
        self._shared_memory = shared_memory
        self._input_is_list = False
        self.constant_cache = constant_cache if constant_cache is not None else {}
        self.module_extensions = module_extensions
        self.config = None
        self.out_debug_name_overwrites = {}
        self.cached_out_types = []
        if graph_element is None:
            if hasattr(pt_module, "config"):
                if isinstance(pt_module.config, dict):
                    self.config = pt_module.config
                elif hasattr(pt_module.config, "to_dict"):
                    self.config = pt_module.config.to_dict()
            try:
                pt_module = self._get_scripted_model(
                    pt_module, example_input, skip_freeze, trace_kwargs)
            except Exception as e:
                if example_input is not None:
                    msg = "tracing"
                    help_msg = ("Please check correctness of provided 'example_input'. "
                                "Sometimes models can be converted in scripted mode, please try running "
                                "conversion without 'example_input'.\n")
                else:
                    msg = "scripting"
                    help_msg = ("Tracing sometimes provide better results, "
                                "please provide valid 'example_input' argument.\n")
                raise RuntimeError(
                    f"Couldn't get TorchScript module by {msg}.\nException:\n{e}\n"
                    f"{help_msg} You can also provide TorchScript module that you obtained"
                    " yourself, please refer to PyTorch documentation: "
                    "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html."
                ) from e
            self.graph_element = pt_module.inlined_graph
            self.alias_db = self.graph_element.alias_db()
        else:
            self.graph_element = graph_element
            self.alias_db = alias_db
        self.pt_module = pt_module
        self.raw_inputs = list(self.graph_element.inputs())
        self.raw_outputs = list(self.graph_element.outputs())
        if self._input_signature is not None:
            if "self" in self.raw_inputs[0].debugName():
                self._input_signature.insert(0, "self")
            if 0 < len(self._input_signature) < len(self.raw_inputs):
                # last input is args input, we need to multiply that name by
                # number of extra inputs
                self._input_signature = self._input_signature[:-1]
                s_len = len(self._input_signature)
                self._input_signature.extend(
                    raw_input.debugName() for raw_input in self.raw_inputs[s_len:])

        if isinstance(self.graph_element, torch.Graph):
            self._transform_tensor_list_constants_to_listconstruct(
                self.graph_element)
            self._transform_optional_constants(self.graph_element)
            log.debug("Inlined graph:\n%s", self.graph_element)

    @staticmethod
    def _get_preserved_attributes(model) -> list:
        """Return names of submodules whose ``weight`` has a compressed dtype.

        The names are passed as ``preserved_attrs`` to ``torch.jit.freeze``.
        """
        compressed_types = (torch.int8, torch.uint8,
                            torch.float16, torch.bfloat16)
        return [
            name
            for name, module in model.named_modules()
            if hasattr(module, "weight")
            and getattr(module.weight, "dtype", None) in compressed_types
        ]

    def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False, trace_kwargs=None):
        """Return a TorchScript module for ``pt_module``.

        Plain ``torch.nn.Module`` objects are scripted (no example inputs) or
        traced (example inputs given) and optionally frozen; objects that are
        already TorchScript, or are not modules at all, are returned as is.
        Also records ``self._example_input`` and ``self._input_signature``.

        Raises:
            RuntimeError: If module extensions are requested for scripting.
        """
        freeze_by_default = False
        if isinstance(pt_module, torch.nn.Module):
            pt_module.eval()
        input_signature = None
        input_parameters = None
        if isinstance(pt_module, torch.nn.Module) and not isinstance(
            pt_module, (torch.jit._trace.TopLevelTracedModule,
                        torch.jit._script.RecursiveScriptModule)
        ):
            # input params is dictionary contains input names and their
            # signature values (type hints and default values if any)
            input_params = inspect.signature(pt_module.forward if hasattr(
                pt_module, "forward") else pt_module.__call__).parameters
            input_signature = list(input_params)

            if example_inputs is None:
                if self.module_extensions:
                    raise RuntimeError(
                        "ModuleExtension is not supported for scripting. "
                        "Please provide valid example_input argument to run tracing.")
                scripted = torch.jit.script(pt_module)
                freeze_by_default = True
            else:
                pt_module, example_inputs = patch_none_example(
                    pt_module, example_inputs)
                input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(
                    example_inputs, input_params, pt_module)

                # name of attribute in a patched module where the
                # original forward method is kept
                orig_forward_name = "_openvino_module_extension_patch_orig_forward"
                if self.module_extensions:
                    patch_model.patch_model(
                        pt_module, self.module_extensions, orig_forward_name)

                patched = False
                if quantized.detect_quantized_model(pt_module) is not None:
                    try:
                        quantized.patch_quantized(pt_module)
                        patched = True
                    except Exception as error:
                        log.warning(
                            "Failed patching of AutoGPTQ model. Error message:\n"
                            "Tracing of the model will likely be unsuccessful or incorrect",
                            exc_info=error)
                        quantized.unpatch_quantized(pt_module)
                        patched = False

                if trace_kwargs is None:
                    trace_kwargs = {}
                try:
                    with FunctionsPatcher():
                        scripted = torch.jit.trace(
                            pt_module, **input_parameters, strict=False, **trace_kwargs)
                finally:
                    # Always restore the original module, even if tracing failed.
                    if patched:
                        quantized.unpatch_quantized(pt_module)

            have_to_freeze_ops = ["prim::Uninitialized",
                                  "prim::unchecked_cast", "aten::append"]
            if not freeze_by_default and graph_has_ops(scripted.inlined_graph, have_to_freeze_ops):
                # freeze models with unsupported ops
                freeze_by_default = True
            quantized_hint_ops = ["quantized", "aten::as_strided"]
            if freeze_by_default and graph_has_ops(scripted.inlined_graph, quantized_hint_ops):
                # do not freeze quantized models and can't freeze for
                # aten::as_strided it will result in incorrect inference
                freeze_by_default = False
            if freeze_by_default and not skip_freeze:
                preserved_attrs = self._get_preserved_attributes(scripted)
                f_model = torch.jit.freeze(
                    scripted, preserved_attrs=preserved_attrs)
            else:
                f_model = scripted
            self._example_input = input_parameters["example_inputs"] if input_parameters else None
        else:
            f_model = pt_module
            self._example_input = example_inputs

        self._input_signature = input_signature
        return f_model

    def inputs(self) -> list:
        """Return the unique ids of all inputs of the wrapped graph element."""
        return [x.unique() for x in self.raw_inputs]

    def get_input_debug_name(self, index: int) -> str:
        """Return the TorchScript debug name of input ``index``."""
        return self._raw_input(index).debugName()

    def get_input_signature_name(self, index: int) -> str:
        """Return the signature name of input ``index``, or its debug name if unknown."""
        if self._input_signature is not None and index < len(self._input_signature):
            return self._input_signature[index]
        return self.get_input_debug_name(index)

    def get_input_shape(self, index: int):
        """Return the partial shape of input ``index``."""
        raw_input = self._raw_input(index)
        return self.get_shape_for_value(raw_input)

    def get_input_strides(self, index: int) -> list[int]:
        """Return strides of tensor input ``index``, or ``[]`` when unavailable."""
        raw_input = self._raw_input(index)
        if isinstance(raw_input, torch.Value):
            inp_type = raw_input.type()
            if isinstance(inp_type, torch.TensorType):
                strides = inp_type.strides()
                if strides:
                    return strides
        return []

    def get_input_type(self, index: int):
        """Return the decoder type (wrapped in ``OVAny``) of input ``index``."""
        raw_input = self._raw_input(index)
        return self.get_type_for_value(raw_input)

    def get_output_debug_name(self, index: int) -> str:
        """Return the debug name of output ``index``, honouring overwrites."""
        if index in self.out_debug_name_overwrites:
            return self.out_debug_name_overwrites[index]
        return self._raw_output(index).debugName()

    def get_output_shape(self, index: int):
        """Return the partial shape of output ``index``."""
        output = self._raw_output(index)
        return self.get_shape_for_value(output)

    def get_output_type(self, index: int):
        """Return the type of output ``index``, preferring a cached type."""
        if index < len(self.cached_out_types):
            return self.cached_out_types[index]
        output = self._raw_output(index)
        return self.get_type_for_value(output)

    def _get_known_type_for_value(self, pt_type):
        """Returns known/unknown types wrapped as OVAny."""
        # Check for simple scalar types first
        if pt_type is None:
            return OVAny(OVType.dynamic)
        # TODO: Don't use str, use native types
        if str(pt_type) in ["int", "float", "bool"]:
            return OVAny(DecoderType.PyScalar(OVAny(pt_to_ov_type_map[str(pt_type)])))
        elif isinstance(pt_type, torch.dtype) and pt_type.is_complex:
            return OVAny(DecoderType.Complex(self._get_known_type_for_value(pt_type.to_real())))
        elif str(pt_type) in pt_to_ov_type_map:
            return OVAny(pt_to_ov_type_map[str(pt_type)])
        elif isinstance(pt_type, torch.ComplexType):
            # Tensor type, parse element type
            return OVAny(DecoderType.Tensor(OVAny(DecoderType.Complex(OVAny(OVType.dynamic)))))
        elif isinstance(pt_type, torch.TensorType):
            # Tensor type, parse element type
            return OVAny(DecoderType.Tensor(self._get_known_type_for_value(pt_type.dtype())))
        elif isinstance(pt_type, torch.ListType):
            element_type = pt_type.getElementType()
            return OVAny(DecoderType.List(self._get_known_type_for_value(element_type)))
        elif isinstance(pt_type, (torch.StringType, torch.DeviceObjType)):
            return OVAny(DecoderType.Str())
        elif isinstance(pt_type, torch.NoneType):
            return OVAny(DecoderType.PyNone())
        else:
            # Not yet recognized
            return OVAny(OVType.dynamic)

    def get_shape_for_value(self, value: torch.Value):
        """Return a rank-only partial shape for complete tensors, else a dynamic shape."""
        if value.isCompleteTensor():
            # We avoid static shapes, they don't generalize on other inputs
            ps = PartialShape([-1] * len(value.type().sizes()))
            return ps
        else:
            # TODO: Recognize types that we can represent as a nested
            # constructs with objects from DecoderType If recognized,
            # return scalar instead of dynamic. Scalar means a single
            # value of that custom type. See get_type_for_value for reference
            pass
        return PartialShape.dynamic()

    def get_type_for_value(self, value: torch.Value):
        """Return the decoder type of ``value`` wrapped in ``OVAny``.

        Objects without a ``type`` attribute are described by their Python
        class name instead.
        """
        _type = value.type() if hasattr(value, "type") else type(value).__name__
        full_type = self._get_known_type_for_value(_type)
        return full_type

    def get_subgraph_size(self) -> int:
        """Return the number of subgraphs of a node, or ``1`` for a graph."""
        if isinstance(self.graph_element, torch.Node):
            return len(self.get_subgraphs())
        else:
            return 1

    def visit_subgraph(self, node_visitor) -> None:
        """Call ``node_visitor`` with a new decoder for every node of the graph."""
        # make sure topological order is satisfied
        for node in self.graph_element.nodes():
            decoder = TorchScriptPythonDecoder(
                self.pt_module,
                node,
                alias_db=self.alias_db,
                shared_memory=self._shared_memory,
                constant_cache=self.constant_cache,
                module_extensions=self.module_extensions,
            )
            # Keep the child alive for as long as this decoder lives.
            self.m_decoders.append(decoder)
            node_visitor(decoder)

    def decoder_type_name(self) -> str:
        """Return the short identifier of this decoder kind."""
        return "ts"

    def get_subgraphs(self) -> list:
        """Return the subgraphs (blocks) attached to the wrapped node."""
        if self.graph_element.kind() in ["prim::PythonOp", "prim::fork"]:
            if "Subgraph" in self.graph_element.attributeNames():
                assert isinstance(
                    self.graph_element, torch.Node), "Graph element must be of type torch.Node."
                subgraph = getattr(self.graph_element, self.graph_element.kindOf(
                    "Subgraph"))("Subgraph")
                torch._C._jit_pass_inline(subgraph)
                return [subgraph]
            else:
                # Attribute "Subgraph" is only available if Graph was
                # created using tracing.
                # TODO Find way to extract subgraph for scripted Graph.
                return []
        return list(self.graph_element.blocks())

    def get_subgraph_decoder(self, index: int):
        """Return a decoder for subgraph ``index`` of the wrapped node."""
        module = self.pt_module
        if self.graph_element.kind() == "prim::fork":
            in0 = self.raw_inputs[0]
            if in0.node().kind() == "prim::GetAttr":
                module, _ = get_value_from_getattr(in0.node(), self.pt_module)
        decoder = TorchScriptPythonDecoder(
            module,
            self.get_subgraphs()[index],
            alias_db=self.alias_db,
            shared_memory=self._shared_memory,
            module_extensions=self.module_extensions,
        )
        self.m_decoders.append(decoder)
        return decoder

    def get_op_extension(self):
        """Return ``(trampoline, target_extension)`` for a ``prim::PythonOp`` node.

        Returns ``None`` when the node is not a ``prim::PythonOp`` with a
        callable ``pyobj``. Either tuple member may itself be ``None``.
        """
        assert isinstance(
            self.graph_element, torch.Node), "Function can be called only when self.graph_element is of type torch.Node"
        if self.graph_element.kind() == "prim::PythonOp" and callable(getattr(self.graph_element, "pyobj", None)):
            pyobj = self.graph_element.pyobj()
            trampoline = getattr(pyobj, "__self__", None)
            return trampoline, getattr(trampoline, "target_extension", None)
        return None

    def get_op_type(self) -> str:
        """Return the operation type name of the wrapped node.

        For nodes produced by a ``ModuleExtension`` the extension's
        ``target_op`` (or the result of calling it with the original module)
        is returned; otherwise the TorchScript node kind.

        Raises:
            TypeError: If ``target_op`` is neither a callable nor a string.
        """
        if op_extension := self.get_op_extension():
            trampoline, target_extension = op_extension
            if isinstance(target_extension, ModuleExtension):
                target_op = target_extension.target_op
                if callable(target_op):
                    return target_op(trampoline.original_module)
                if isinstance(target_op, str):
                    return target_op
                # TODO: Support target as a callable that will play a role of
                # ConversionExtension for an entire module instead of a single
                # op. Without supporting target as a callable here,
                # ConversionExtension functionality is still possible to
                # implement by combining two extensions: ModuleExtension that
                # use temporary name as a target op and another extension of
                # type ConversionExtension that translates that particular
                # temporary name to custom graph. But providing conversion code
                # as a callable `target` is more convenient.
                raise TypeError(
                    "ModuleExtension.target_op must be a callable or a str, "
                    f"got {type(target_op).__name__}")
        return self.graph_element.kind()

    def get_schema(self) -> str:
        """Return the schema string of the wrapped node."""
        return self.graph_element.schema()

    def outputs(self) -> list:
        """Return the unique ids of all outputs of the wrapped graph element."""
        return [x.unique() for x in self.raw_outputs]

    def _raw_output(self, index: int):
        """Return the raw TorchScript output value at ``index``."""
        return self.raw_outputs[index]

    def _raw_input(self, index: int):
        """Return the raw TorchScript input value at ``index``."""
        return self.raw_inputs[index]

    def num_of_outputs(self):
        """Return the number of outputs of the wrapped graph element."""
        return len(self.raw_outputs)

    def output(self, index: int):
        """Return the unique id of output ``index``."""
        return self.outputs()[index]

    def mark_node(self, node):
        """Set a friendly name on ``node`` built from scope, op type and node type.

        Returns the same ``node`` object.
        """
        name = self.get_op_type()
        if "FrameworkNode" not in node.get_type_name():
            name += "/" + node.get_type_name()
        scope = self.graph_element.scopeName()
        if scope:
            # Only the innermost scope component is used as prefix.
            scope_name = scope.split("/")[-1]
            node.set_friendly_name(scope_name + "/" + name)
        else:
            node.set_friendly_name(name)
        return node

    def _add_name_to_const_and_cache(self, outputs, name, dtype=None):
        """Store ``(outputs, dtype)`` in the constant cache under ``name``.

        When there is exactly one output, its node is also renamed to
        ``name`` and the debug name of output 0 is overwritten.
        """
        if len(outputs) == 1:
            # set name corresponding to state_dict name
            outputs[0].get_node().set_friendly_name(name)
            self.out_debug_name_overwrites[0] = name
        self.constant_cache[name] = (outputs, dtype)

    def try_decode_get_attr(self):
        """Convert the value behind a ``prim::GetAttr`` node to constant outputs.

        Returns:
            A list of outputs: for script objects the weight, bias and, when
            available, stride/padding/dilation/groups; for other non-module
            values the constant itself; ``[]`` for script/traced modules.
        """
        pt_value, name = get_value_from_getattr(
            self.graph_element, self.pt_module)
        assert pt_value is not None, "Couldn't retrieve value from prim::GetAttr"
        if isinstance(pt_value, torch.ScriptObject):
            # We assume this is __torch__.torch.classes.quantized.Conv2dPackedParamsBase or __torch__.torch.classes.quantized.LinearPackedParamsBase
            # TODO: but can be anything. Figure a better way to distinguish
            weight, bias = pt_value.unpack()
            w_name = name + ".weight"
            if w_name in self.constant_cache:
                weight_outputs = self.constant_cache[w_name][0]
            else:
                weight_outputs = convert_quantized_tensor(
                    weight, self._shared_memory)
                self._add_name_to_const_and_cache(weight_outputs, w_name)
            # Copy before extending: ``+=`` on a list works in place, so
            # extending the cached object would leak bias/conv parameters
            # into the cache entry and duplicate them on the next cache hit.
            res = list(weight_outputs)

            if isinstance(bias, torch.Tensor):
                b_name = name + ".bias"
                if b_name in self.constant_cache:
                    res += self.constant_cache[b_name][0]
                else:
                    b_res = ivalue_to_constant(bias)
                    self._add_name_to_const_and_cache(b_res, b_name)
                    res += b_res
            else:
                # No bias: substitute a zero constant converted like the weight.
                res += ops.convert_like(ivalue_to_constant(torch.zeros(1))
                                        [0], res[0]).outputs()
            try:
                # these params exist only for conv params
                stride = pt_value.stride()
                padding = pt_value.padding()
                dilation = pt_value.dilation()
                groups = pt_value.groups()
                res += ivalue_to_constant(stride,
                                          shared_memory=self._shared_memory)
                res += ivalue_to_constant(padding,
                                          shared_memory=self._shared_memory)
                res += ivalue_to_constant(dilation,
                                          shared_memory=self._shared_memory)
                res += ivalue_to_constant(groups,
                                          shared_memory=self._shared_memory)
            except Exception as e:
                log.debug("Failed to get conv params", exc_info=e)
            return res
        elif not isinstance(pt_value, (torch.jit.ScriptModule, torch.jit.TracedModule)):
            # this tensor can be used multiple times in the model, so we have to reuse constants
            if name in self.constant_cache:
                const, dtype = self.constant_cache[name]
            else:
                dtype = self.get_type_for_value(pt_value)
                if hasattr(pt_value, "dtype") and pt_value.dtype.is_complex:
                    pt_value = torch.view_as_real(pt_value)
                const = ivalue_to_constant(
                    pt_value, shared_memory=self._shared_memory)
                self._add_name_to_const_and_cache(const, name, dtype)
            if dtype is not None:
                self.cached_out_types = [dtype]
            return const
        else:
            return []

    def as_constant(self):
        """Return constant outputs for a ``prim::Constant`` node, else ``None``."""
        if not isinstance(self.graph_element, torch.Node):
            return None
        if self.get_op_type() != "prim::Constant":
            return None
        pt_value = self._raw_output(0)
        pt_type = pt_value.type()
        if isinstance(pt_type, torch.TensorType):
            return ivalue_to_constant(pt_value.toIValue(), shared_memory=self._shared_memory)
        if isinstance(pt_type, torch.ListType):
            return self._as_constant_list(pt_value)
        if isinstance(pt_type, torch._C.Type) and pt_type.annotation_str == "Generator":
            # A generator constant is represented by its initial seed.
            gen = pt_value.toIValue()
            return ivalue_to_constant(gen.initial_seed(), shared_memory=self._shared_memory)
        const = ivalue_to_constant(
            pt_value.toIValue(), shared_memory=self._shared_memory)
        if len(const) > 0:
            # set name corresponding to state_dict name
            const[0].get_node().set_friendly_name(
                self.get_output_debug_name(0))
        return const

    def as_string(self):
        """Return the string value of a string/device constant or device node.

        Returns ``None`` when the node carries no string value.
        """
        op_type = self.get_op_type()
        if op_type == "prim::Constant":
            pt_value = self._raw_output(0)
            type_name = str(pt_value.type())
            if type_name in ["torch.StringType", "str"]:
                return pt_value.toIValue()
            elif type_name == "Device":
                return pt_value.toIValue().type
        elif op_type == "prim::device":
            return self._get_device_string()
        return None

    @staticmethod
    def _as_constant_list(pt_value: torch.Value):
        """Return a 1D constant for a list value, or ``[]`` for unknown element types."""
        # For now we treat a list as a 1D tensor; it is required by converters to avoid
        # need to massively rewrite them in that part where constant attributes are queried
        pt_element_type = str(pt_value.type().getElementType())
        ivalue = pt_value.toIValue()
        is_known_type = pt_element_type in pt_to_ov_type_map

        if is_known_type:
            ovtype = pt_to_ov_type_map[pt_element_type]
            ovshape = PartialShape([len(ivalue)])
            ov_const = op.Constant(ovtype, ovshape.get_shape(), ivalue)
            return ov_const.outputs()
        return []

    def _get_device_string(self) -> str:
        """Return the device of the input of a ``prim::device`` node, ``"cpu"`` if unknown."""
        assert self.graph_element.kind(
        ) == "prim::device", "This function can be called for prim::device node."
        value = self.raw_inputs[0]
        if value.type().isSubtypeOf(torch.TensorType.get()):
            tensor = typing.cast(torch.TensorType, value.type())
            device = tensor.device()
            if device:
                return str(device)
        # Device cannot be statically determined.
        return "cpu"

    def input_is_none(self, index: int) -> bool:
        """Return whether input ``index`` is missing or known to be ``None``."""
        # Use the raw list directly: building ``self.inputs()`` only to take
        # its length would call ``unique()`` on every input first.
        if index >= len(self.raw_inputs) or self._raw_input(index) is None:
            return True
        r_input = self._raw_input(index)
        if str(r_input.type()) in ["torch.NoneType", "NoneType"]:
            return True
        in_node = r_input.node()
        if in_node.kind() == "prim::GetAttr":
            pt_value, _ = get_value_from_getattr(in_node, self.pt_module)
            return pt_value is None
        return False

    def may_produce_alias(self, in_index: int, out_index: int) -> bool:
        """Return whether output ``out_index`` may alias input ``in_index``."""
        if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::_convolution", "aten::matmul", "aten::clone"]:
            # AliasDB::may_contain_alias sometimes return True for tensors produced by convolution or matmul, we have to workaround that
            return False
        try:
            return self.alias_db.may_contain_alias(self._raw_input(in_index), self._raw_output(out_index))
        except Exception as e:
            # Sometimes pytorch fails to get result with IndexError exception while these indexes exist in node
            log.debug("Failed to get alias information", exc_info=e)
            return False

    def is_input_inlined(self, index):
        """Inputs are never inlined for this decoder; always ``False``."""
        return False

    def get_inlined_input_decoder(self, index):
        """There are no inlined inputs; always ``None``."""
        return None

    def get_attribute(self, name):
        """Attributes are not provided by this decoder; returns ``OVAny(None)``."""
        return OVAny(None)

    def get_named_input(self, name):
        """Named inputs are not supported; always raises ``RuntimeError``."""
        raise RuntimeError("There is no named inputs in TS graph")

    def get_rt_info(self):
        """Return runtime info derived from the model config.

        Contains ``symmetric_quantization`` when the config has
        ``quantization_config["sym"]``; otherwise the dict is empty.
        """
        rt_info = {}
        if self.config is not None and "quantization_config" in self.config and "sym" in self.config["quantization_config"]:
            rt_info["symmetric_quantization"] = OVAny(
                self.config["quantization_config"]["sym"])
        return rt_info

    @staticmethod
    def _transform_tensor_list_constants_to_listconstruct(graph: torch.Graph):
        """Replace tensor-list constants in ``graph`` with ``prim::ListConstruct`` nodes."""
        # Function replaces prim::Constant containing list of Tensors with
        # prim::ListConstruct containing prim::Constant Tensors.
        assert isinstance(
            graph, torch.Graph), "Function can be called only with parameters of type torch.Graph."
        for node in graph.nodes():
            if node.kind() != "prim::Constant":
                continue
            output_type = node.output().type()
            allowed_types = [
                output_type.isSubtypeOf(torch.ListType.ofTensors()),
                output_type.isSubtypeOf(torch.ListType(
                    torch.OptionalType.ofTensor())),
            ]
            if not any(allowed_types):
                continue
            const_inputs = []
            for value in node.output().toIValue():
                const_input = graph.insertConstant(value)
                const_input.node().moveBefore(node)
                const_input.node().copyMetadata(node)
                const_inputs.append(const_input)

            replacement = graph.create("prim::ListConstruct", const_inputs)
            replacement.insertBefore(node)
            replacement.output().setType(torch.ListType.ofTensors())
            replacement.copyMetadata(node)
            node.output().replaceAllUsesWith(replacement.output())

    @staticmethod
    def _transform_optional_constants(graph: torch.Graph):
        """Replace ``Optional``-typed constants in ``graph`` with plainly typed ones."""
        # Function replaces prim::Constant containing torch.OptionalType with
        # prim::Constant containing torch.NoneType or type of IValue.
        assert isinstance(
            graph, torch.Graph), "Function can be called only with parameters of type torch.Graph."
        for node in graph.nodes():
            if node.kind() != "prim::Constant":
                continue
            output_type = node.output().type()
            if not isinstance(output_type, torch.OptionalType):
                continue
            value = node.output().toIValue()
            const_input = graph.insertConstant(value)
            const_input.node().moveBefore(node)
            const_input.node().copyMetadata(node)
            node.output().replaceAllUsesWith(const_input)

    def has_converter(self):
        """Return whether the node carries an inline conversion extension."""
        if op_extension := self.get_op_extension():
            _, target_extension = op_extension
            return isinstance(target_extension, InlineConversionExtension)
        return False

    def convert(self, node_context):
        """Run the inline custom converter attached to the node.

        Raises:
            AssertionError: If the node has no custom converter.
            Exception: Whatever the custom converter raises (logged first).
        """
        if op_extension := self.get_op_extension():
            trampoline, target_extension = op_extension
            assert isinstance(target_extension, InlineConversionExtension)
            try:
                return trampoline.convert(node_context)
            except Exception as e:
                log.error("Exception happened during calling of custom "
                          "converter for PyTorch operation. PyTorch Script "
                          "code: %s", self.graph_element, exc_info=e)
                raise
        raise AssertionError("PyTorch FrontEnd Internal Error: `converter` "
                             "method of TorchScriptPythonDecoder is called "
                             "for node that has no custom converter")