491 lines
19 KiB
Python
491 lines
19 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
||
|
|
# Copyright (C) 2018-2025 Intel Corporation
|
||
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
|
||
|
|
# mypy: ignore-errors
|
||
|
|
|
||
|
|
import logging
|
||
|
|
import inspect
|
||
|
|
import torch
|
||
|
|
|
||
|
|
from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
|
||
|
|
from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
|
||
|
|
from openvino import PartialShape, Type as OVType, OVAny, Shape
|
||
|
|
from openvino.frontend.pytorch.utils import (
|
||
|
|
make_constant, fetch_attr, pt_to_ov_type_map, torch_tensor_to_ov_const)
|
||
|
|
|
||
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class BaseFXDecoder(Decoder):
    """Extends Decoder to handle FX graph decoding in PyTorch.

    Provides a common interface for all FX decoders: container flattening,
    conversion of Python values to constants / type descriptors, and default
    implementations of the decoder query methods that subclasses override.
    """

    def __init__(self, mark_node_callback=None) -> None:
        """Initialize the state shared by all FX decoders.

        Args:
            mark_node_callback: Optional callable invoked from ``mark_node``
                as ``mark_node_callback(decoder, node)``.
        """
        Decoder.__init__(self)
        self.mark_node_callback = mark_node_callback
        # We store every decoder created by this decoder so that
        # none of them is deleted until the first decoder is deleted
        self.m_decoders = []
        # Each entry is either an integer node index or an InlinedInput.
        self._inputs = []
        # Populated by subclasses.
        self._outputs = []

    @staticmethod
    def unpack_containers(arg):
        """Flatten nested tuples, lists and dicts into ``(name, value)`` pairs.

        Leaves are returned with an empty name. A dict entry whose value
        flattens to exactly one leaf is renamed to the dict key.

        Args:
            arg: An arbitrary value, possibly a nested container.

        Returns:
            A flat list of ``(name, value)`` tuples.
        """
        if isinstance(arg, (tuple, list)):
            res = []
            for element in arg:
                res.extend(BaseFXDecoder.unpack_containers(element))
            return res
        if isinstance(arg, dict):
            res = []
            for key, element in arg.items():
                unpacked = BaseFXDecoder.unpack_containers(element)
                if len(unpacked) == 1:
                    # Single leaf: the dict key becomes its name.
                    unpacked[0] = (key, unpacked[0][1])
                res.extend(unpacked)
            return res
        return [("", arg)]

    @staticmethod
    def arg_to_constant(arg):
        """Convert a Python list, bool, int, float or str to a constant.

        Args:
            arg: The Python value to convert.

        Returns:
            The object produced by ``make_constant`` /
            ``torch_tensor_to_ov_const``, or ``None`` when ``arg`` is of an
            unsupported type.
        """
        if isinstance(arg, list):
            if len(arg) > 0:
                # The element type is taken from the first element only.
                return make_constant(pt_to_ov_type_map[type(
                    arg[0]).__name__], Shape([len(arg)]), arg)
            # TODO: which type should we use if list is empty? Need a signaling value here
            return make_constant(OVType.i32, Shape([0]), [])
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(arg, bool):
            return make_constant(OVType.boolean, Shape([]), [arg])
        if isinstance(arg, int):
            return make_constant(OVType.i64, Shape([]), [arg])
        if isinstance(arg, float):
            return make_constant(OVType.f32, Shape([]), [arg])
        if isinstance(arg, str):
            # The string is passed on as its encoded bytes (u8 tensor).
            # NOTE(review): torch.frombuffer is documented to reject empty
            # buffers - confirm an empty string cannot reach this branch.
            u8_tensor = torch.frombuffer(str.encode(arg), dtype=torch.uint8)
            return torch_tensor_to_ov_const(u8_tensor, shared_memory=True)
        return None

    @staticmethod
    def get_type_for_value(value):
        """Return the type descriptor (wrapped in ``OVAny``) for ``value``.

        Args:
            value: A ``torch.fx.Node`` or a plain Python value.

        Returns:
            ``OVAny`` holding an element type, a ``PyScalar`` / ``List``
            decoder type, or ``OVType.dynamic`` when nothing can be derived.
        """
        if isinstance(value, torch.fx.Node):
            tensor_meta = value.meta.get("tensor_meta")
            # NOTE(review): only metadata that is a torch.Tensor instance is
            # honoured here - confirm that this is what is stored under
            # "tensor_meta"; otherwise this branch always yields dynamic.
            if tensor_meta and isinstance(tensor_meta, torch.Tensor):
                pt_type = str(tensor_meta.dtype)
                if pt_type in pt_to_ov_type_map:
                    return OVAny(pt_to_ov_type_map[pt_type])
            return OVAny(OVType.dynamic)
        # bool must be tested before int: bool is a subclass of int, so the
        # opposite order would report every Python bool as an i64 scalar.
        if isinstance(value, bool):
            return OVAny(DecoderType.PyScalar(OVAny(OVType.boolean)))
        if isinstance(value, int):
            return OVAny(DecoderType.PyScalar(OVAny(OVType.i64)))
        if isinstance(value, float):
            return OVAny(DecoderType.PyScalar(OVAny(OVType.f32)))
        if isinstance(value, list):
            if len(value) > 0:
                # The list element type is derived from the first element.
                return OVAny(DecoderType.List(BaseFXDecoder.get_type_for_value(value[0])))
            return OVAny(DecoderType.List(OVAny(OVType.i32)))
        return OVAny(OVType.dynamic)

    def inputs(self):
        """Return the input ids, reporting every inlined input as ``0``."""
        # Consider 0 a special case which may mean the input is inlined, but not guaranteed
        return [x if not isinstance(x, InlinedInput) else 0 for x in self._inputs]

    def output(self, index):
        """Return the id of the output at position ``index``."""
        return self.outputs()[index]

    def get_input_debug_name(self, index):
        """Return the generic debug name ``"input<index>"``."""
        return "input" + str(index)

    def is_input_inlined(self, index):
        """Return True when the input at ``index`` is an ``InlinedInput``."""
        return isinstance(self._inputs[index], InlinedInput)

    def get_inlined_input_decoder(self, index):
        """Create a decoder for the inlined input at ``index``.

        The new decoder is also recorded in ``self.m_decoders``.

        Raises:
            AssertionError: If the requested input is not inlined.
        """
        target = self._inputs[index]
        assert isinstance(target, InlinedInput), "Requested non-inlined input"
        # self._nodes is expected to be provided by the subclass.
        in_decoder = InlinedInputDecoder(
            target, self._nodes, self.mark_node_callback)
        self.m_decoders.append(in_decoder)
        return in_decoder

    def get_input_shape(self, index):
        """Return a fully dynamic shape; subclasses may refine this."""
        return PartialShape.dynamic()

    def get_input_type(self, index):
        """Return the dynamic type; subclasses may refine this."""
        return OVAny(OVType.dynamic)

    def get_output_type(self, index):
        """Return the dynamic type; subclasses may refine this."""
        return OVAny(OVType.dynamic)

    def input_is_none(self, index):
        """Return True only for an existing inlined input whose data is None."""
        if index < len(self._inputs) and isinstance(self._inputs[index], InlinedInput):
            return self._inputs[index].data is None
        return False

    def decoder_type_name(self) -> str:
        """Return the decoder kind identifier, ``"fx"``."""
        return "fx"

    def get_schema(self):
        """Return the placeholder schema string ``"NONE"``."""
        return "NONE"

    def mark_node(self, node):
        """Pass ``node`` to the optional callback and return it unchanged."""
        if self.mark_node_callback is not None:
            self.mark_node_callback(self, node)
        return node

    def get_subgraphs(self):
        """Return the list of subgraphs, which is empty for this decoder."""
        return []

    def get_subgraph_size(self):
        """Return the number of entries reported by ``get_subgraphs``."""
        return len(self.get_subgraphs())

    def as_string(self):
        """Return None: no string representation is provided."""
        return None

    def may_produce_alias(self, in_index: int, out_index: int) -> bool:
        """Return False for every input/output pair."""
        return False

    def get_rt_info(self):
        """Return a new, empty runtime-info dictionary."""
        return {}
|
||
|
|
|
||
|
|
|
||
|
|
class TorchFXPythonDecoder(BaseFXDecoder):
    """Decoder for PyTorch FX GraphModule and Node objects to OpenVINO IR."""

    # Decomposition table built on first use by from_exported_program
    # (torch >= 2.6 path) and then reused.
    _decomp_table = None

    def __init__(self, pt_module, fx_gm=None, nodes=None,
                 mark_node_callback=None, input_shapes=None,
                 input_types=None, dynamic_shapes=False):
        """Build a decoder for a whole graph module or for a single node.

        Args:
            pt_module: A ``torch.fx`` ``GraphModule`` (whole graph) or a
                ``torch.fx.Node`` (single operation).
            fx_gm: Graph module used to resolve ``get_attr`` targets;
                defaults to ``pt_module``.
            nodes: List of all graph nodes; used when ``pt_module`` is a
                ``Node`` to translate nodes into indices.
            mark_node_callback: Optional callback forwarded to the base class.
            input_shapes: Optional explicit input shapes; when empty, shapes
                found in placeholder metadata are used instead.
            input_types: Optional explicit torch dtypes of the inputs; when
                empty, types found in placeholder metadata are used instead.
            dynamic_shapes: When True, every dimension of a shape found in
                metadata is replaced by ``-1``.
        """
        super().__init__(mark_node_callback)
        self.pt_module = pt_module
        self.fx_gm = fx_gm if fx_gm is not None else pt_module
        self.input_types = [OVAny(pt_to_ov_type_map[str(t)])
                            for t in (input_types or [])]
        self.input_shapes = input_shapes or []

        self._input_signature = []
        self._example_input = None

        if isinstance(pt_module, torch.fx.graph_module.GraphModule):
            self._input_is_list = None
            self._nodes = list(pt_module.graph.nodes)
            found_types = []
            found_shapes = []
            for i, value in enumerate(self._nodes):
                if value.op == "placeholder":
                    self._inputs.append(i)
                    self._input_signature.append(value.name)

                    shape = self.get_found_shape(value)
                    if shape is not None:
                        # Symbolic dimensions (or all of them, when
                        # dynamic_shapes is set) are recorded as -1.
                        shape = torch.Size(
                            [-1 if (dynamic_shapes or type(dim).__name__ == "SymInt") else dim
                             for dim in shape])
                    found_shapes.append(shape)
                    found_types.append(self.get_found_dtype(value))

                elif value.op == "output":
                    # Instead of putting output index, refer to its target
                    uargs = self.unpack_containers(value.args)
                    self._outputs = [(arg[0], self._nodes.index(arg[1]))
                                     for arg in uargs if arg[1] is not None]

            # Fall back to metadata when nothing explicit was supplied.
            if not input_shapes:
                self.input_shapes = found_shapes
            if not input_types:
                self.input_types = found_types

            if hasattr(self.pt_module, "forward"):
                # Parameter names of forward() replace the placeholder names.
                input_params = inspect.signature(self.pt_module.forward).parameters
                self._input_signature = list(input_params)

        elif isinstance(pt_module, torch.fx.Node):
            self._nodes = nodes  # passed from outer context

            # FIXME: Quadratic complexity nodes*nodes considering the outer loop over all nodes
            self._outputs = [("", self._nodes.index(pt_module))]

            self.input_types = []
            for arg in pt_module.args:
                if isinstance(arg, torch.fx.Node):
                    self._inputs.append(self._nodes.index(arg))
                else:
                    # Not a node, consider it inlined
                    self._inputs.append(InlinedInput(arg))
                self.input_types.append(
                    BaseFXDecoder.get_type_for_value(arg))

    @classmethod
    def from_exported_program(cls, exported_program: torch.export.ExportedProgram) -> "TorchFXPythonDecoder":
        """Create a TorchFXPythonDecoder instance from an exported PyTorch program.

        Decompositions are run first (the mechanism depends on the installed
        torch version), then a decoder is built from the resulting module
        with ``dynamic_shapes=True``.
        """
        from packaging import version
        torch_version = version.parse(torch.__version__)
        if torch_version >= version.parse("2.6"):
            if cls._decomp_table is None:
                from torch.export.decomp_utils import CustomDecompTable
                from openvino.frontend.pytorch.torchdynamo.decompositions import ops_to_not_decompose
                cls._decomp_table = CustomDecompTable()
                for op in ops_to_not_decompose():
                    try:
                        cls._decomp_table.pop(op)
                    except KeyError as e:
                        # Use the module logger rather than the root logger.
                        logger.warning("Operation %s not found in decomp table", op, exc_info=e)
            exported_program = exported_program.run_decompositions(cls._decomp_table)
        elif torch_version >= version.parse("2.2"):
            from torch._decomp import get_decompositions
            from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
            decomp = get_decompositions(get_export_decomposition_list())
            exported_program = exported_program.run_decompositions(decomp_table=decomp)
        gm = exported_program.module()
        logger.debug(gm.code)
        return cls(gm, dynamic_shapes=True)

    @staticmethod
    def get_found_shape(value) -> "torch.Size | None":
        """Return the shape recorded in ``value.meta``, or None.

        ``meta["tensor_meta"]`` is preferred; ``meta["val"]`` is used when it
        is a ``torch.Tensor``.
        """
        # If input is a tensor, read the shape from meta data
        if hasattr(value, "meta"):
            if ("tensor_meta" in value.meta) and value.meta["tensor_meta"]:
                return value.meta["tensor_meta"].shape
            if ("val" in value.meta) and isinstance(value.meta["val"], torch.Tensor):
                return value.meta["val"].shape
        return None

    @staticmethod
    def get_found_dtype(value) -> "OVAny | None":
        """Return the mapped dtype from ``value.meta["tensor_meta"]``, or None."""
        # If input is a tensor, read the data type from meta data
        if hasattr(value, "meta") and ("tensor_meta" in value.meta) and value.meta["tensor_meta"]:
            return OVAny(pt_to_ov_type_map[str(value.meta["tensor_meta"].dtype)])
        return None

    def get_input_signature_name(self, index: int) -> str:
        """Return the signature name of input ``index``, else its debug name."""
        if self._input_signature is not None and index < len(self._input_signature):
            return self._input_signature[index]
        return self.get_input_debug_name(index)

    def get_input_shape(self, index):
        """Return the recorded shape of input ``index``, else one derived from its value."""
        if index < len(self.input_shapes) and self.input_shapes[index] is not None:
            return PartialShape(self.input_shapes[index])
        _input = self._raw_input(index)
        return self.get_shape_for_value(_input)

    def get_input_strides(self, index: int) -> list:
        """Return the strides stored in the input node's ``tensor_meta``, or ``[]``."""
        raw_input = self._raw_input(index)
        if isinstance(raw_input, torch.fx.node.Node) and hasattr(raw_input, "meta"):
            meta = raw_input.meta
            if "tensor_meta" in meta and hasattr(meta["tensor_meta"], "stride"):
                strides = list(meta["tensor_meta"].stride)
                if strides:
                    return strides
        return []

    def get_input_type(self, index):
        """Return the recorded type of input ``index``, else one derived from its value."""
        if index < len(self.input_types) and self.input_types[index] is not None:
            return self.input_types[index]
        _input = self._raw_input(index)
        return self.get_type_for_value(_input)

    def get_output_debug_name(self, index):
        """Return the stored output name, else ``"<module name>:<index>"``."""
        if self._outputs is not None and index < len(self._outputs) and self._outputs[index][0]:
            return self._outputs[index][0]
        # "output" is used when pt_module has no ``name`` attribute.
        name = getattr(self.pt_module, "name", "output")
        return name + ":" + str(index)

    def get_output_shape(self, index):
        """Return the shape derived from the node behind output ``index``."""
        output = self._raw_output(index)
        return self.get_shape_for_value(output)

    def get_output_type(self, index):
        """Return the type derived from the node behind output ``index``."""
        output = self._raw_output(index)
        return self.get_type_for_value(output)

    def get_shape_for_value(self, value):
        """Return a shape with known rank and dynamic dims, or a fully dynamic shape.

        Only the rank is taken from ``value.meta["tensor_meta"].shape``;
        every dimension is reported as ``-1``.
        """
        if value and hasattr(value, "meta") and ("tensor_meta" in value.meta):
            if value.meta["tensor_meta"]:
                return PartialShape(len(value.meta["tensor_meta"].shape) * [-1])
        return PartialShape.dynamic()

    def get_attribute(self, name):
        """Return the keyword argument ``name`` of the node wrapped in ``OVAny``.

        Returns ``OVAny(None)`` when the node has no such keyword argument.
        """
        if name in self.pt_module.kwargs:
            attr = self.pt_module.kwargs[name]
            if isinstance(attr, torch.dtype):
                return OVAny(pt_to_ov_type_map[str(attr)])
            if isinstance(attr, torch.device):
                return OVAny(attr.type)
            if isinstance(attr, str):
                return OVAny(attr)
            # Numeric attrs convert to Constant
            constant = self.arg_to_constant(attr)
            if constant is not None:
                return OVAny(constant.output(0))
            # so that has_attribute returns True if the attribute exists
            return OVAny(DecoderType.PyNone())
        return OVAny(None)

    def get_named_input(self, name):
        """Returns id of kwargs input.

        Such input can be a Node or a constant value; this function is only
        used to return the node index. If the input is constant,
        get_attribute should be used.

        Raises:
            RuntimeError: If there is no keyword argument ``name`` or it is
                not a ``torch.fx.Node``.
        """
        if name in self.pt_module.kwargs:
            arg = self.pt_module.kwargs[name]
            if isinstance(arg, torch.fx.Node):
                return self._nodes.index(arg)
        raise RuntimeError("This input is not a Node")

    def visit_subgraph(self, node_visitor):
        """Call ``node_visitor`` with a decoder for every operational node.

        Placeholder and output nodes, and ``aten._assert_async.msg`` calls,
        are skipped. Created decoders are recorded in ``self.m_decoders``.
        """
        # make sure topological order is satisfied
        for node in self._nodes:
            if node.op in {"placeholder", "output"}:
                continue  # skipping non-operational nodes
            if node.op == "call_function" and str(node.target) in ["aten._assert_async.msg"]:
                continue
            decoder = TorchFXPythonDecoder(
                node, self.fx_gm, self._nodes, mark_node_callback=self.mark_node_callback)
            self.m_decoders.append(decoder)
            node_visitor(decoder)

    def get_subgraph_decoder(self, index):
        """Create, record and return a decoder for subgraph ``index``."""
        decoder = TorchFXPythonDecoder(self.get_subgraphs()[index],
                                       self.fx_gm,
                                       mark_node_callback=self.mark_node_callback)
        self.m_decoders.append(decoder)
        return decoder

    def get_op_type(self):
        """Return the operation type string of the wrapped node."""
        if self.pt_module.op == "call_function":
            if type(self.pt_module.target).__name__ == "EdgeOpOverload":
                return self.pt_module.target.__name__
            return str(self.pt_module.target)
        if self.pt_module.op == "get_attr":
            return "get_attr"  # FIXME should be aligned with get_attr from TS implementation
        return "UNKNOWN_TYPE_" + str(self.pt_module.op)

    def outputs(self):
        """Return the node indices of all outputs."""
        return [o[1] for o in self._outputs]

    def _raw_outputs(self):
        """Return the nodes behind all outputs."""
        return [self._nodes[x[1]] for x in self._outputs]

    def _raw_output(self, index):
        """Return the node behind output ``index``."""
        return self._raw_outputs()[index]

    def _raw_inputs(self):
        """Return the node, or the inlined data, behind every input."""
        return [self._nodes[x] if not isinstance(x, InlinedInput) and x < len(self._nodes) else x.data
                for x in self._inputs]

    def _raw_input(self, index):
        """Return the node, or the inlined data, behind input ``index``."""
        return self._raw_inputs()[index]

    def num_of_outputs(self):
        """Return the number of outputs."""
        return len(self.outputs())

    def output_list_size(self):
        """Return one more than the largest ``getitem`` index applied to this node.

        Returns 0 when no user of the node is a ``getitem`` call.
        """
        max_out_id = -1
        for user in self.pt_module.users:
            if "<built-in function getitem>" == str(user.target) and max_out_id < user.args[1]:
                max_out_id = user.args[1]
        return max_out_id + 1

    def mark_node(self, node):
        """Set a friendly name on ``node``, run the base callback, return ``node``.

        The name is ``<fx node name>/<op type>`` with ``/<node type name>``
        appended unless the type name contains ``"FrameworkNode"``.
        """
        name = self.get_op_type()
        if "FrameworkNode" not in node.get_type_name():
            name += "/" + node.get_type_name()
        node.set_friendly_name(self.pt_module.name + "/" + name)
        super().mark_node(node)
        return node

    def as_constant(self):
        """Return the outputs of the constant created for a ``get_attr`` node.

        Raises:
            AssertionError: If the wrapped node is not a ``get_attr`` node.
        """
        assert self.pt_module.op == "get_attr", "Only get_attr is supported"
        # Extract Constant from FX module field
        ret = fetch_attr(self.fx_gm, self.pt_module.target)
        ov_const = torch_tensor_to_ov_const(ret, shared_memory=True)
        return ov_const.outputs()

    def input_is_none(self, index):
        """Return True when input ``index`` is absent or resolves to None."""
        if index >= len(self._inputs):
            return True
        candidate = self._inputs[index]
        if isinstance(candidate, InlinedInput) and candidate.data is None:
            return True
        # Identity test instead of comparing str(type(...)) with "NoneType":
        # str(type(None)) is "<class 'NoneType'>", so that never matched.
        return self._raw_input(index) is None

    def debug(self):
        """Print the wrapped module via its ``print`` method."""
        # NOTE(review): confirm pt_module actually exposes ``print()``.
        self.pt_module.print()
|
||
|
|
|
||
|
|
|
||
|
|
class InlinedInput:
    """Represents an inlined input.

    This is a special case where the input is not a node, but a constant value.
    """

    def __init__(self, data) -> None:
        """Store the inlined value.

        Args:
            data: The raw Python value passed in place of a graph node.
        """
        self.data = data

    def __repr__(self) -> str:
        """Return a debug representation showing the wrapped value."""
        return f"{type(self).__name__}({self.data!r})"
|
||
|
|
|
||
|
|
|
||
|
|
class InlinedInputDecoder(BaseFXDecoder):
    """FX decoder that presents one inlined (non-node) argument as an operation."""

    def __init__(self, inlined_input: InlinedInput, nodes=None, mark_node_callback=None) -> None:
        """Wrap ``inlined_input`` and classify it as constant or list-of-nodes.

        Args:
            inlined_input: The inlined value to decode.
            nodes: List of all graph nodes, used to translate nodes found
                inside the inlined list/tuple into indices.
            mark_node_callback: Optional callback forwarded to the base class.
        """
        super().__init__(mark_node_callback)
        payload = inlined_input.data
        # A list/tuple holding at least one FX node is not a plain constant.
        contains_node = isinstance(payload, (list, tuple)) and any(
            isinstance(item, torch.fx.Node) for item in payload)

        self.inlined_input = inlined_input
        self._nodes = nodes
        self.is_const = not contains_node

        if contains_node:
            resolved = []
            for item in payload:
                if isinstance(item, torch.fx.Node):
                    resolved.append(nodes.index(item))
                else:
                    resolved.append(InlinedInput(item))
            self._inputs = resolved

    def get_op_type(self):
        """Return the operation type reported for this inlined input."""
        # return specific type for inlined inputs
        return "inlined.constant.default" if self.is_const else "prim::ListConstruct"

    def outputs(self):
        """Return the single output id, ``[0]``."""
        return [0]

    def num_of_outputs(self):
        """Return 1: an inlined input has exactly one output."""
        return 1

    def get_input_shape(self, index):
        """Return a fully dynamic shape for any input."""
        return PartialShape.dynamic()

    def get_input_type(self, index):
        """Return the dynamic type for any input."""
        return OVAny(OVType.dynamic)

    def get_output_type(self, index):
        """Return the dynamic type for any output."""
        return OVAny(OVType.dynamic)

    def input_is_none(self, index):
        """Return True only for an existing inlined input whose data is None."""
        if index >= len(self._inputs):
            return False
        candidate = self._inputs[index]
        return isinstance(candidate, InlinedInput) and candidate.data is None

    def as_constant(self):
        """Return the outputs of the constant built from the data, or ``[]``."""
        constant = BaseFXDecoder.arg_to_constant(self.inlined_input.data)
        return [] if constant is None else constant.outputs()

    def mark_node(self, node):
        """Set a friendly name on ``node``, run the base callback, return ``node``."""
        parts = [self.get_op_type()]
        if "FrameworkNode" not in node.get_type_name():
            parts.append(node.get_type_name())
        node.set_friendly_name("/".join(parts))
        super().mark_node(node)
        return node
|