# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# mypy: ignore-errors

import inspect
import logging
import operator

import torch

from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
from openvino import PartialShape, Type as OVType, OVAny, Shape
from openvino.frontend.pytorch.utils import (
    make_constant,
    fetch_attr,
    pt_to_ov_type_map,
    torch_tensor_to_ov_const,
)

logger = logging.getLogger(__name__)


class BaseFXDecoder(Decoder):
    """Extends Decoder to handle FX graph decoding in PyTorch.

    Provides a common interface for all FX decoders. Subclasses are expected
    to set ``self._nodes`` (the list of FX nodes that ``self._inputs`` and
    ``self._outputs`` index into).
    """

    def __init__(self, mark_node_callback=None) -> None:
        Decoder.__init__(self)
        self.mark_node_callback = mark_node_callback
        # We store every decoder created by this decoder so that
        # all of them are not deleted until the first decoder is deleted
        self.m_decoders = []
        # Each entry is either an int index into ``self._nodes`` or an
        # ``InlinedInput`` wrapping a non-node argument.
        self._inputs = []
        # Each entry is a ``(name, node_index)`` pair.
        self._outputs = []

    @staticmethod
    def unpack_containers(arg):
        """Flatten nested tuples/lists/dicts into a list of ``(name, value)`` pairs.

        Leaf values get an empty name. When a dict entry flattens to exactly one
        leaf, that leaf is named after the dict key.
        """
        if isinstance(arg, (tuple, list)):
            res = []
            for element in arg:
                res.extend(BaseFXDecoder.unpack_containers(element))
            return res
        if isinstance(arg, dict):
            res = []
            for key, element in arg.items():
                unpacked = BaseFXDecoder.unpack_containers(element)
                if len(unpacked) == 1:
                    unpacked[0] = (key, unpacked[0][1])
                res.extend(unpacked)
            return res
        return [("", arg)]

    @staticmethod
    def arg_to_constant(arg):
        """Convert a Python list/bool/int/float/str into an OpenVINO constant.

        Returns:
            The constant node, or ``None`` if ``arg`` has an unsupported type.
        """
        if isinstance(arg, list):
            if len(arg) > 0:
                # Element type is taken from the first element only.
                return make_constant(pt_to_ov_type_map[type(arg[0]).__name__],
                                     Shape([len(arg)]), arg)
            # TODO: which type should we use if list is empty? Need a signaling value here
            return make_constant(OVType.i32, Shape([0]), [])
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(arg, bool):
            return make_constant(OVType.boolean, Shape([]), [arg])
        if isinstance(arg, int):
            return make_constant(OVType.i64, Shape([]), [arg])
        if isinstance(arg, float):
            return make_constant(OVType.f32, Shape([]), [arg])
        if isinstance(arg, str):
            # Strings are passed as their encoded bytes in a u8 tensor.
            u8_tensor = torch.frombuffer(str.encode(arg), dtype=torch.uint8)
            return torch_tensor_to_ov_const(u8_tensor, shared_memory=True)
        return None

    @staticmethod
    def get_type_for_value(value):
        """Return an ``OVAny`` type hint for an FX node or an inlined Python value."""
        if issubclass(type(value), torch.fx.Node):
            if "tensor_meta" in value.meta.keys():
                # NOTE(review): FX normally stores ``TensorMetadata`` (not a
                # ``torch.Tensor``) under "tensor_meta", so this branch may never
                # fire and nodes fall through to dynamic — confirm intent.
                if value.meta["tensor_meta"] and isinstance(value.meta["tensor_meta"], torch.Tensor):
                    pt_type = value.meta["tensor_meta"].dtype
                    if str(pt_type) in pt_to_ov_type_map:
                        ov_type = pt_to_ov_type_map[str(pt_type)]
                        return OVAny(ov_type)
            return OVAny(OVType.dynamic)
        # bool must be tested before int: bool is a subclass of int, so checking
        # int first would make the boolean branch unreachable.
        if isinstance(value, bool):
            return OVAny(DecoderType.PyScalar(OVAny(OVType.boolean)))
        if isinstance(value, int):
            return OVAny(DecoderType.PyScalar(OVAny(OVType.i64)))
        if isinstance(value, float):
            return OVAny(DecoderType.PyScalar(OVAny(OVType.f32)))
        if isinstance(value, list):
            if len(value) > 0:
                return OVAny(DecoderType.List(BaseFXDecoder.get_type_for_value(value[0])))
            return OVAny(DecoderType.List(OVAny(OVType.i32)))
        return OVAny(OVType.dynamic)

    def inputs(self):
        # Consider 0 a special case which may mean the input is inlined, but not guaranteed
        return [x if not isinstance(x, InlinedInput) else 0 for x in self._inputs]

    def output(self, index):
        return self.outputs()[index]

    def get_input_debug_name(self, index):
        return "input" + str(index)

    def is_input_inlined(self, index):
        return isinstance(self._inputs[index], InlinedInput)

    def get_inlined_input_decoder(self, index):
        """Create (and keep alive) a decoder for the inlined input at ``index``."""
        target = self._inputs[index]
        assert isinstance(target, InlinedInput), "Requested non-inlined input"
        in_decoder = InlinedInputDecoder(target, self._nodes, self.mark_node_callback)
        self.m_decoders.append(in_decoder)
        return in_decoder

    def get_input_shape(self, index):
        return PartialShape.dynamic()

    def get_input_type(self, index):
        return OVAny(OVType.dynamic)

    def get_output_type(self, index):
        return OVAny(OVType.dynamic)

    def input_is_none(self, index):
        """Return True only for an inlined input whose value is ``None``."""
        if index < len(self._inputs) and isinstance(self._inputs[index], InlinedInput):
            return self._inputs[index].data is None
        return False

    def decoder_type_name(self) -> str:
        return "fx"

    def get_schema(self):
        return "NONE"

    def mark_node(self, node):
        """Invoke the user callback (if any) on a produced OpenVINO node."""
        if self.mark_node_callback is not None:
            self.mark_node_callback(self, node)
        return node

    def get_subgraphs(self):
        return []

    def get_subgraph_size(self):
        return len(self.get_subgraphs())

    def as_string(self):
        return None

    def may_produce_alias(self, in_index: int, out_index: int) -> bool:
        return False

    def get_rt_info(self):
        rt_info = {}
        return rt_info


class TorchFXPythonDecoder(BaseFXDecoder):
    """Decoder for PyTorch FX GraphModule and Node objects to OpenVINO IR."""

    _decomp_table = None

    def __init__(self, pt_module, fx_gm=None, nodes=None,
                 mark_node_callback=None, input_shapes=None,
                 input_types=None, dynamic_shapes=False):
        """Build a decoder for a whole ``GraphModule`` or a single FX ``Node``.

        Args:
            pt_module: A ``torch.fx.GraphModule`` (whole graph) or a
                ``torch.fx.Node`` (one operation).
            fx_gm: Owning GraphModule; defaults to ``pt_module``.
            nodes: For a Node, the node list of the enclosing graph.
            mark_node_callback: Optional callback invoked from ``mark_node``.
            input_shapes: Explicit input shapes; if empty, shapes are read from
                placeholder metadata.
            input_types: Explicit torch dtypes; if empty, types are read from
                placeholder metadata.
            dynamic_shapes: If True, every found input dimension becomes -1.
        """
        super().__init__(mark_node_callback)
        self.pt_module = pt_module
        self.fx_gm = fx_gm if fx_gm is not None else pt_module
        self.input_types = [OVAny(pt_to_ov_type_map[str(t)]) for t in (input_types or [])]
        self.input_shapes = input_shapes or []
        self._input_signature = []
        self._example_input = None

        if isinstance(pt_module, torch.fx.graph_module.GraphModule):
            self._init_from_graph_module(input_shapes, input_types, dynamic_shapes)
        elif isinstance(pt_module, torch.fx.Node):
            self._init_from_node(nodes)

    def _init_from_graph_module(self, input_shapes, input_types, dynamic_shapes):
        """Collect placeholders (inputs) and the output node of a whole GraphModule."""
        self._input_is_list = None
        self._nodes = list(self.pt_module.graph.nodes)

        found_types = []
        found_shapes = []
        for i, value in enumerate(self._nodes):
            if value.op == "placeholder":
                self._inputs.append(i)
                self._input_signature.append(value.name)
                found_shapes.append(self._mask_shape(self.get_found_shape(value), dynamic_shapes))
                found_types.append(self.get_found_dtype(value))
            elif value.op == "output":
                # Instead of putting output index, refer to its target
                uargs = self.unpack_containers(value.args)
                self._outputs = [(name, self._nodes.index(target))
                                 for name, target in uargs if target is not None]

        # Metadata-derived shapes/types are used only when not given explicitly.
        if not input_shapes:
            self.input_shapes = found_shapes
        if not input_types:
            self.input_types = found_types

        if hasattr(self.pt_module, "forward"):
            input_params = inspect.signature(self.pt_module.forward).parameters
            self._input_signature = list(input_params)

    @staticmethod
    def _mask_shape(shape, dynamic_shapes):
        """Replace dims by -1 when all shapes are dynamic or a dim is a ``SymInt``."""
        if shape is None:
            return None
        return torch.Size([-1 if (dynamic_shapes or type(dim).__name__ == "SymInt") else dim
                           for dim in shape])

    def _init_from_node(self, nodes):
        """Set up a decoder for a single FX node within the outer ``nodes`` list."""
        self._nodes = nodes  # passed from outer context

        # FIXME: Quadratic complexity nodes*nodes considering the outer loop over all nodes
        self._outputs = [("", self._nodes.index(self.pt_module))]

        # One type hint per positional argument, index-aligned with self._inputs.
        self.input_types = []
        for arg in self.pt_module.args:
            if isinstance(arg, torch.fx.Node):
                self._inputs.append(self._nodes.index(arg))
            else:
                # Not a node, consider it inlined
                self._inputs.append(InlinedInput(arg))
            self.input_types.append(BaseFXDecoder.get_type_for_value(arg))

    @classmethod
    def from_exported_program(cls, exported_program: torch.export.ExportedProgram) -> "TorchFXPythonDecoder":
        """Create a TorchFXPythonDecoder instance from an exported PyTorch program.

        The program is decomposed first: with torch >= 2.6 via a cached
        ``CustomDecompTable`` minus OpenVINO-supported ops, with torch >= 2.2 via
        the OpenVINO export decomposition list. The resulting module is decoded
        with fully dynamic input shapes.
        """
        from packaging import version

        torch_version = version.parse(torch.__version__)
        if torch_version >= version.parse("2.6"):
            if cls._decomp_table is None:
                from torch.export.decomp_utils import CustomDecompTable
                from openvino.frontend.pytorch.torchdynamo.decompositions import ops_to_not_decompose

                cls._decomp_table = CustomDecompTable()
                for op in ops_to_not_decompose():
                    try:
                        cls._decomp_table.pop(op)
                    except KeyError as e:
                        logger.warning("Operation %s not found in decomp table", op, exc_info=e)
            exported_program = exported_program.run_decompositions(cls._decomp_table)
        elif torch_version >= version.parse("2.2"):
            from torch._decomp import get_decompositions
            from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list

            decomp = get_decompositions(get_export_decomposition_list())
            exported_program = exported_program.run_decompositions(decomp_table=decomp)
        gm = exported_program.module()
        logger.debug(gm.code)
        return cls(gm, dynamic_shapes=True)

    @staticmethod
    def get_found_shape(value):
        """Return the tensor shape recorded in ``value.meta``, or ``None``.

        Prefers ``meta["tensor_meta"].shape``, then ``meta["val"].shape`` if
        ``val`` is a tensor.
        """
        if hasattr(value, "meta"):
            if ("tensor_meta" in value.meta.keys()) and value.meta["tensor_meta"]:
                return value.meta["tensor_meta"].shape
            if ("val" in value.meta.keys()) and isinstance(value.meta["val"], torch.Tensor):
                return value.meta["val"].shape
        return None

    @staticmethod
    def get_found_dtype(value):
        """Return ``OVAny`` of the dtype in ``value.meta["tensor_meta"]``, or ``None``."""
        if hasattr(value, "meta") and ("tensor_meta" in value.meta.keys()) and value.meta["tensor_meta"]:
            return OVAny(pt_to_ov_type_map[str(value.meta["tensor_meta"].dtype)])
        return None

    def get_input_signature_name(self, index: int) -> str:
        if self._input_signature is not None and index < len(self._input_signature):
            return self._input_signature[index]
        return self.get_input_debug_name(index)

    def get_input_shape(self, index):
        if index < len(self.input_shapes) and self.input_shapes[index] is not None:
            return PartialShape(self.input_shapes[index])
        _input = self._raw_input(index)
        return self.get_shape_for_value(_input)

    def get_input_strides(self, index: int) -> list:
        """Return strides from ``tensor_meta`` of a node input, or ``[]``."""
        raw_input = self._raw_input(index)
        if isinstance(raw_input, torch.fx.node.Node) and hasattr(raw_input, "meta"):
            meta = raw_input.meta
            if "tensor_meta" in meta and hasattr(meta["tensor_meta"], "stride"):
                strides = list(meta["tensor_meta"].stride)
                if strides:
                    return strides
        return []

    def get_input_type(self, index):
        if index < len(self.input_types) and self.input_types[index] is not None:
            return self.input_types[index]
        _input = self._raw_input(index)
        return self.get_type_for_value(_input)

    def get_output_debug_name(self, index):
        if self._outputs is not None and index < len(self._outputs) and self._outputs[index][0]:
            return self._outputs[index][0]
        name = getattr(self.pt_module, "name", "output")
        return name + ":" + str(index)

    def get_output_shape(self, index):
        output = self._raw_output(index)
        return self.get_shape_for_value(output)

    def get_output_type(self, index):
        output = self._raw_output(index)
        return self.get_type_for_value(output)

    def get_shape_for_value(self, value):
        """Return a rank-only dynamic shape from ``tensor_meta``, else fully dynamic."""
        # ``is not None`` avoids ambiguous truthiness if ``value`` is a tensor.
        if value is not None and hasattr(value, "meta") and ("tensor_meta" in value.meta.keys()):
            if value.meta["tensor_meta"]:
                return PartialShape(len(value.meta["tensor_meta"].shape) * [-1])
        return PartialShape.dynamic()

    def get_attribute(self, name):
        """Return keyword argument ``name`` of the node wrapped in ``OVAny``.

        dtype -> OV type, device -> device type string, str -> itself, numeric
        and list values -> Constant output. Unconvertible values yield
        ``PyNone`` (so the attribute still counts as present); a missing
        keyword yields ``OVAny(None)``.
        """
        if name in self.pt_module.kwargs:
            attr = self.pt_module.kwargs[name]
            if isinstance(attr, torch.dtype):
                return OVAny(pt_to_ov_type_map[str(attr)])
            if isinstance(attr, torch.device):
                return OVAny(attr.type)
            if isinstance(attr, str):
                return OVAny(attr)
            # Numeric attrs convert to Constant
            constant = self.arg_to_constant(attr)
            if constant is not None:
                return OVAny(constant.output(0))
            # so that has_attribute return True if attribute exist
            return OVAny(DecoderType.PyNone())
        return OVAny(None)

    def get_named_input(self, name):
        """Returns id of kwargs input.

        Such input can be Node or a constant value, this function is only used
        for to return node index. If the input is constant, get_attribute should
        be used.
        """
        if name in self.pt_module.kwargs:
            arg = self.pt_module.kwargs[name]
            if isinstance(arg, torch.fx.Node):
                return self._nodes.index(arg)
        raise RuntimeError("This input is not a Node")

    def visit_subgraph(self, node_visitor):
        """Call ``node_visitor`` with a decoder for each operational node, in graph order."""
        # make sure topological order is satisfied
        for node in self._nodes:
            if node.op in {"placeholder", "output"}:
                continue  # skipping non-operational nodes
            if node.op == "call_function" and str(node.target) in ["aten._assert_async.msg"]:
                continue
            decoder = TorchFXPythonDecoder(node, self.fx_gm, self._nodes,
                                           mark_node_callback=self.mark_node_callback)
            self.m_decoders.append(decoder)
            node_visitor(decoder)

    def get_subgraph_decoder(self, index):
        decoder = TorchFXPythonDecoder(self.get_subgraphs()[index], self.fx_gm,
                                       mark_node_callback=self.mark_node_callback)
        self.m_decoders.append(decoder)
        return decoder

    def get_op_type(self):
        if self.pt_module.op == "call_function":
            if type(self.pt_module.target).__name__ == "EdgeOpOverload":
                return self.pt_module.target.__name__
            return str(self.pt_module.target)
        if self.pt_module.op == "get_attr":
            return "get_attr"  # FIXME should be aligned with get_attr from TS implementation
        return "UNKNOWN_TYPE_" + str(self.pt_module.op)

    def outputs(self):
        return [o[1] for o in self._outputs]

    def _raw_outputs(self):
        return [self._nodes[x[1]] for x in self._outputs]

    def _raw_output(self, index):
        return self._raw_outputs()[index]

    def _raw_inputs(self):
        # Node indices resolve to FX nodes; inlined inputs resolve to their raw data.
        return [self._nodes[x] if not isinstance(x, InlinedInput) and x < len(self._nodes) else x.data
                for x in self._inputs]

    def _raw_input(self, index):
        return self._raw_inputs()[index]

    def num_of_outputs(self):
        return len(self.outputs())

    def output_list_size(self):
        """Return 1 + the largest index used by ``getitem`` users of this node.

        Returns 0 when no user indexes into this node's output.
        """
        max_out_id = -1
        for user in self.pt_module.users:
            # Compare the callable itself: its string form is
            # "<built-in function getitem>", never an empty string.
            if user.target is operator.getitem and len(user.args) > 1:
                out_id = user.args[1]
                if isinstance(out_id, int) and max_out_id < out_id:
                    max_out_id = out_id
        return max_out_id + 1

    def mark_node(self, node):
        """Set friendly name ``<fx node name>/<op type>[/<OV type>]`` and run the callback."""
        name = self.get_op_type()
        if "FrameworkNode" not in node.get_type_name():
            name += "/" + node.get_type_name()
        node.set_friendly_name(self.pt_module.name + "/" + name)
        super().mark_node(node)
        return node

    def as_constant(self):
        """Return outputs of a Constant built from the ``get_attr`` target tensor."""
        assert self.pt_module.op == "get_attr", "Only get_attr is supported"
        # Extract Constant from FX module field
        ret = fetch_attr(self.fx_gm, self.pt_module.target)
        ov_const = torch_tensor_to_ov_const(ret, shared_memory=True)
        return ov_const.outputs()

    def input_is_none(self, index):
        """Return True if input ``index`` is absent or resolves to ``None``."""
        if index >= len(self._inputs) or (isinstance(self._inputs[index], InlinedInput)
                                          and self._inputs[index].data is None):
            return True
        # str(type(None)) is "<class 'NoneType'>", so a string comparison never
        # matched; check identity instead.
        return self._raw_input(index) is None

    def debug(self):
        self.pt_module.print()


class InlinedInput:
    """Represents an inlined input.

    This is a special case where the input is not a node, but a constant value.
    """

    def __init__(self, data) -> None:
        self.data = data


class InlinedInputDecoder(BaseFXDecoder):
    """Decoder for inlined inputs in PyTorch FX graphs."""

    def __init__(self, inlined_input: InlinedInput, nodes=None, mark_node_callback=None) -> None:
        super().__init__(mark_node_callback)
        self.inlined_input = inlined_input
        self._nodes = nodes
        # A list/tuple containing at least one FX node is decoded as a list
        # construction; anything else is a plain constant.
        self.is_const = not (isinstance(inlined_input.data, (list, tuple)) and any(
            isinstance(a, torch.fx.Node) for a in inlined_input.data))
        if not self.is_const:
            self._inputs = [nodes.index(x) if isinstance(x, torch.fx.Node) else InlinedInput(x)
                            for x in inlined_input.data]

    def get_op_type(self):
        # return specific type for inlined inputs
        if not self.is_const:
            return "prim::ListConstruct"
        return "inlined.constant.default"

    def outputs(self):
        return [0]

    def num_of_outputs(self):
        return 1

    def get_input_shape(self, index):
        return PartialShape.dynamic()

    def get_input_type(self, index):
        return OVAny(OVType.dynamic)

    def get_output_type(self, index):
        return OVAny(OVType.dynamic)

    def input_is_none(self, index):
        if index < len(self._inputs) and isinstance(self._inputs[index], InlinedInput):
            return self._inputs[index].data is None
        return False

    def as_constant(self):
        """Return outputs of the constant built from the inlined data, or ``[]``."""
        arg = self.inlined_input.data
        constant = BaseFXDecoder.arg_to_constant(arg)
        if constant is not None:
            return constant.outputs()
        return []

    def mark_node(self, node):
        """Set friendly name ``<op type>[/<OV type>]`` and run the callback."""
        name = self.get_op_type()
        if "FrameworkNode" not in node.get_type_name():
            name += "/" + node.get_type_name()
        node.set_friendly_name(name)
        super().mark_node(node)
        return node