# ANSLibs/OpenVINO/python/openvino/frontend/pytorch/utils.py
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# mypy: ignore-errors
import inspect
import logging
import torch
import numpy as np
from contextlib import contextmanager
from openvino import op, Type as OVType, Shape, Tensor, OVAny
from openvino import opset11 as ops
from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
log = logging.getLogger(__name__)


def make_constant(*args, **kwargs):
    return op.Constant(*args, **kwargs)


def fetch_attr(self_module, target: str):
    """Fetch an attribute from the ``Module`` hierarchy of ``self_module``.

    Args:
        self_module (torch.nn.Module): The module to fetch the attribute from.
        target (str): The fully-qualified name of the attribute to fetch.

    Returns:
        Any: The value of the attribute.
    """
    target_atoms = target.split(".")
    attr_itr = self_module
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
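
# Usage sketch (illustrative, not part of the original module): fetch a
# nested parameter by its dotted path.
#     model = torch.nn.Sequential(torch.nn.Linear(2, 2))
#     weight = fetch_attr(model, "0.weight")  # same object as model[0].weight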


def get_type_from_py_type(value):
    # Note: bool must be checked before int because bool is a subclass of int.
    if isinstance(value, float):
        return OVType.f32
    if isinstance(value, bool):
        return OVType.boolean
    if isinstance(value, int):
        return OVType.i64
    if isinstance(value, complex):
        # Complex scalars are stored as pairs of f32 values (real, imaginary).
        return OVType.f32
    return OVType.dynamic
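
# Usage sketch (illustrative): scalar Python types map to static OV types,
# anything else is reported as dynamic.
#     get_type_from_py_type(1.0)   # OVType.f32
#     get_type_from_py_type(True)  # OVType.boolean
#     get_type_from_py_type(1)     # OVType.i64
#     get_type_from_py_type("s")   # OVType.dynamic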


F8_DTYPE_MAP = {
    torch.float8_e4m3fn: OVType.f8e4m3,
    torch.float8_e5m2: OVType.f8e5m2,
}


def torch_tensor_to_ov_const(torch_t: torch.Tensor, shared_memory=True):
    try:
        from torch._prims import FakeTensor
        if isinstance(torch_t, FakeTensor):
            raise AssertionError("`FakeTensor` detected. Infer the "
                                 "model before exporting to avoid this.")
    except ImportError:
        log.debug("Failed to import FakeTensor")
    dtype = torch_t.dtype
    torch_t = torch_t.contiguous()
    if dtype == torch.bfloat16:
        # reinterpret bfloat16 data as float16 to allow conversion to numpy
        torch_t = torch_t.view(torch.float16)
        narr = torch_t.numpy(force=True)
        tensor = Tensor(narr, torch_t.shape, OVType.bf16)
        ov_const = op.Constant(tensor, shared_memory=shared_memory)
    elif dtype in F8_DTYPE_MAP:
        # reinterpret f8 data as u8 to allow conversion to numpy
        torch_t = torch_t.view(torch.uint8)
        narr = torch_t.numpy(force=True)
        tensor = Tensor(narr, torch_t.shape, F8_DTYPE_MAP[dtype])
        ov_const = op.Constant(tensor, shared_memory=shared_memory)
    elif torch_t.is_complex():
        narr = torch.view_as_real(torch_t).numpy(force=True)
        # we rely on the frontend to mark the constant as complex internally
        ov_const = op.Constant(narr, shared_memory=shared_memory)
    else:
        narr = torch_t.numpy(force=True)
        ov_const = op.Constant(narr, shared_memory=shared_memory)
    return ov_const
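
# Usage sketch (illustrative): bfloat16 and float8 tensors are viewed as a
# same-width dtype first, then wrapped without changing the underlying bytes.
#     t = torch.randn(2, 3).to(torch.bfloat16)
#     const = torch_tensor_to_ov_const(t)  # op.Constant with bf16 element type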


def ivalue_to_constant(ivalue, shared_memory=True):
    ov_type = get_type_from_py_type(ivalue)
    if ov_type.is_static():
        if isinstance(ivalue, complex):
            # Store a complex scalar as two f32 values: [real, imag].
            return op.Constant(ov_type, Shape([2]), [ivalue.real, ivalue.imag]).outputs()
        else:
            return op.Constant(ov_type, Shape([]), [ivalue]).outputs()
    if isinstance(ivalue, (list, tuple)):
        assert len(ivalue) > 0, "Can't deduce type for empty list"
        ov_type = get_type_from_py_type(ivalue[0])
        assert ov_type.is_static(), "Can't deduce type for list"
        return op.Constant(ov_type, Shape([len(ivalue)]), ivalue).outputs()
    if isinstance(ivalue, torch.Tensor):
        return torch_tensor_to_ov_const(ivalue, shared_memory=shared_memory).outputs()
    return None
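
# Usage sketch (illustrative): scalars become rank-0 constants, homogeneous
# lists become 1-D constants, tensors go through torch_tensor_to_ov_const,
# and unsupported values yield None.
#     ivalue_to_constant(3.5)        # outputs of a rank-0 f32 Constant
#     ivalue_to_constant([1, 2, 3])  # outputs of a 1-D i64 Constant
#     ivalue_to_constant(object())   # None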


def get_value_from_getattr(getattr_node, self_module):
    assert getattr_node.kind() == "prim::GetAttr", "Got node of kind not equal to prim::GetAttr"
    # GetAttr nodes can be nested
    stack = []
    while getattr_node.kind() == "prim::GetAttr":
        stack.append(getattr_node)
        inputs = list(getattr_node.inputs())
        if len(inputs) == 0:
            break
        getattr_node = inputs[0].node()
    module = self_module
    path_name = "self"
    while len(stack) > 0:
        node = stack.pop()
        attr_name = node.s("name")
        assert hasattr(
            module, attr_name), f'No attribute with name "{attr_name}" found in module.'
        path_name = ".".join([path_name, attr_name])
        module = getattr(module, attr_name)
    return module, path_name
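
# Usage sketch (illustrative): for a prim::GetAttr node taken from a scripted
# module's graph, this returns the resolved attribute value together with its
# dotted path from "self", e.g. (<weight tensor>, "self.model.weight").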


def graph_has_ops(graph, op_types: list) -> bool:
    res = False
    for node in graph.nodes():
        if any(kind in node.kind() for kind in op_types):
            return True
        for block in node.blocks():
            res = graph_has_ops(block, op_types)
            if res:
                return res
    return res
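
# Usage sketch (illustrative): check a TorchScript graph (including nested
# blocks) for operator kinds by substring match.
#     scripted = torch.jit.script(torch.nn.ReLU())
#     graph_has_ops(scripted.graph, ["aten::relu"])  # True for this model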


pt_to_ov_type_map = {
    "float": OVType.f32,
    "int": OVType.i64,
    "bool": OVType.boolean,
    "torch.float8_e4m3fn": OVType.f8e4m3,
    "torch.float8_e5m2": OVType.f8e5m2,
    "torch.bfloat16": OVType.bf16,
    "torch.float16": OVType.f16,
    "torch.float32": OVType.f32,
    "torch.float64": OVType.f64,
    "torch.complex32": DecoderType.Complex(OVAny(OVType.f16)),
    "torch.complex64": DecoderType.Complex(OVAny(OVType.f32)),
    "torch.complex128": DecoderType.Complex(OVAny(OVType.f64)),
    "torch.uint8": OVType.u8,
    "torch.int8": OVType.i8,
    "torch.int16": OVType.i16,
    "torch.int32": OVType.i32,
    "torch.int64": OVType.i64,
    "torch.bool": OVType.boolean,
    "torch.DoubleTensor": OVType.f64,
    "torch.FloatTensor": OVType.f32,
    "torch.HalfTensor": OVType.f16,
    "torch.BFloat16Tensor": OVType.bf16,
    "torch.IntTensor": OVType.i32,
    "torch.LongTensor": OVType.i64,
    "torch.ShortTensor": OVType.i16,
    "torch.CharTensor": OVType.i8,
    "torch.ByteTensor": OVType.u8,
    "torch.BoolTensor": OVType.boolean,
    "torch.ComplexHalfTensor": DecoderType.Complex(OVAny(OVType.f16)),
    "torch.ComplexFloatTensor": DecoderType.Complex(OVAny(OVType.f32)),
    "torch.ComplexDoubleTensor": DecoderType.Complex(OVAny(OVType.f64)),
    "torch.quint8": OVType.u8,
    "torch.qint8": OVType.i8,
    "torch.qint32": OVType.i32,
}
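
# Usage sketch (illustrative): look up the OV type for a tensor's dtype string.
#     pt_to_ov_type_map[str(torch.zeros(1).dtype)]  # OVType.f32 for "torch.float32"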


wrapper_template = """
import torch
from typing import *


class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, {input_sign}):
        return self.model({example_input})
"""


def build_wrapper(template, model):
    """Build a wrapper around the given model using the provided template."""
    result = {}
    try:
        exec(template, result)
        wrapped_model = result["ModelWrapper"](model)
        wrapped_model.eval()
    except Exception:
        # If wrapping failed, return the original model to avoid confusing
        # the user with an unrelated error message.
        log.error("Failed to build model wrapper.")
        wrapped_model = model
    return wrapped_model
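
# Usage sketch (illustrative): instantiate the template for a model whose
# forward takes (x, y), then call the wrapper positionally. `some_model`,
# `x_tensor`, and `y_tensor` are hypothetical.
#     src = wrapper_template.format(input_sign="x, y", example_input="x=x, y=y")
#     wrapped = build_wrapper(src, some_model)
#     wrapped(x_tensor, y_tensor)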


def process_dict_inputs(inputs, input_params, model):
    ordered_inputs = []
    for input_name in input_params:
        if input_name in inputs:
            ordered_inputs.append(input_name)
    input_signature = list(input_params)
    if ordered_inputs == input_signature[: len(ordered_inputs)]:
        example_inputs = [inputs[input_name] for input_name in ordered_inputs]
        return {"example_inputs": example_inputs}, ordered_inputs, model
    # PyTorch has difficulties tracing models with named, unordered parameters:
    # torch < 2.0.0 supports only positional arguments for tracing;
    # torch == 2.0.0 supports tracing with input kwargs, but does not support
    # complex nested objects (e.g. a tuple of tuples of tensors).
    # As a workaround, we use a wrapper to make the arguments positional.
    input_sign_str = []
    input_params_str = []
    for input_name in ordered_inputs:
        if str(input_params[input_name].annotation).startswith("typing.Union"):
            filter_custom_args = []
            for arg in input_params[input_name].annotation.__args__:
                str_arg = str(arg)
                is_typing = str_arg.startswith("typing.")
                is_torch = "torch." in str_arg
                is_builtin = str_arg in (str(int), str(float), str(type(None)))
                if not (is_typing or is_torch or is_builtin):
                    continue
                filter_custom_args.append(arg)
            input_params[input_name].annotation.__args__ = tuple(
                filter_custom_args)
        input_sign_str.append(
            str(input_params[input_name]).replace("NoneType", "None"))
        input_params_str.append(f"{input_name}={input_name}")
    wrapper_class = wrapper_template.format(input_sign=", ".join(
        input_sign_str), example_input=", ".join(input_params_str))
    wrapped_model = build_wrapper(wrapper_class, model)
    return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model


def prepare_example_inputs_and_model(inputs, input_params, model):
    input_is_list = False
    input_signature = list(input_params)
    if isinstance(inputs, dict):
        examples, ordered, wrapped = process_dict_inputs(
            inputs, input_params, model)
        return examples, ordered, wrapped, input_is_list
    if isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
        if "list" in str(input_params[input_signature[0]].annotation):
            inputs = inputs[0].unsqueeze(0)
            input_is_list = True
    if isinstance(inputs, torch.Tensor):
        inputs = [inputs]
    input_signature = input_signature[: len(inputs)]
    return {"example_inputs": inputs}, input_signature, model, input_is_list


def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool):
    # represents a torch quantized tensor as
    # Constant(u8) -> Convert(f32) -> Subtract(zero_point) -> Multiply(scale)
    qscheme = qtensor.qscheme()
    if qscheme == torch.per_channel_affine:
        int8_tensor = qtensor.int_repr()
        scale = qtensor.q_per_channel_scales().numpy().astype(np.float32)
        zero_point = qtensor.q_per_channel_zero_points().numpy().astype(np.float32)
        axis = np.int32(qtensor.q_per_channel_axis())
        new_shape = np.ones(len(int8_tensor.shape), dtype=np.int32)
        new_shape[axis] = -1
        zero_point_bc = np.reshape(zero_point, new_shape)
        scale_bc = np.reshape(scale, new_shape)
        int8_const = torch_tensor_to_ov_const(
            int8_tensor, shared_memory=shared_memory)
        convert = ops.convert(int8_const, np.float32)
        sub = ops.subtract(convert, zero_point_bc)
        return ops.multiply(sub, scale_bc).outputs()
    elif qscheme == torch.per_tensor_affine:
        int8_tensor = qtensor.int_repr()
        scale = np.float32(qtensor.q_scale())
        zero_point = np.float32(qtensor.q_zero_point())
        int8_const = torch_tensor_to_ov_const(
            int8_tensor, shared_memory=shared_memory)
        convert = ops.convert(int8_const, np.float32)
        sub = ops.subtract(convert, zero_point)
        return ops.multiply(sub, scale).outputs()
    raise AssertionError(f"Unsupported qscheme: {qscheme}")
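
# Usage sketch (illustrative): decompose a per-tensor quantized tensor into
# an OV dequantization subgraph.
#     q = torch.quantize_per_tensor(
#         torch.randn(2, 2), scale=0.1, zero_point=128, dtype=torch.quint8)
#     outputs = convert_quantized_tensor(q, shared_memory=False)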


def process_individual_input(arg, arg_name):
    """Generate signature, param string, example, and wrap flag from input.

    Args:
        arg: The input value to process.
        arg_name: The name of the input.

    Returns:
        tuple: (signature, param string, example entry, wrap flag).
    """
    sign = None
    param = None
    example_entry = None
    to_wrap = False
    if isinstance(arg, tuple):
        internal_input = []
        new_tuple = []
        index = 0
        for value in arg:
            if value is None:
                to_wrap = True
                internal_input.append("None")
            else:
                # index tracks positions in the filtered tuple, so it only
                # advances for non-None values
                internal_input.append(f"{arg_name}[{index}]")
                new_tuple.append(value)
                index += 1
        param = f"({', '.join(internal_input)},)"
        if len(new_tuple) > 0:
            example_entry = tuple(new_tuple)
            sign = arg_name
    elif arg is None:
        to_wrap = True
        param = "None"
    else:
        sign = arg_name
        param = arg_name
        example_entry = arg
    return sign, param, example_entry, to_wrap
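
# Usage sketch (illustrative): a tuple containing None is rewritten so the
# None is hard-coded in the wrapper call and dropped from the example.
# `t1` and `t2` are hypothetical tensors.
#     process_individual_input((t1, None, t2), "x")
#     # -> ("x", "(x[0], None, x[1],)", (t1, t2), True)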


def patch_none_example(model: torch.nn.Module, example):
    """Patch a PyTorch model to handle None values in the input example."""
    callable_func = getattr(model, "forward", model.__call__)
    input_params = inspect.signature(callable_func).parameters
    input_signature = list(input_params)
    input_sign_str = []
    input_params_str = []
    to_wrap = False
    if isinstance(example, tuple) and len(input_signature) >= len(example):
        new_example = []
        for i, arg in enumerate(example):
            arg_name = input_signature[i]
            sign, param, example_entry, _to_wrap = process_individual_input(arg, arg_name)
            to_wrap = to_wrap or _to_wrap
            if sign is not None:
                input_sign_str.append(str(input_params[sign]))
            input_params_str.append(param)
            if example_entry is not None:
                new_example.append(example_entry)
        if to_wrap:
            wrapper_class = wrapper_template.format(input_sign=", ".join(input_sign_str),
                                                    example_input=", ".join(input_params_str))
            wrapped_model = build_wrapper(wrapper_class, model)
            log.warning("Model has None in the example input. The input "
                        "with None will be removed from the resulting model.")
            return wrapped_model, tuple(new_example)
    elif isinstance(example, dict) and len(input_signature) >= len(example):
        new_example = {}
        input_signature = [s for s in input_signature if s in example]
        for arg_name in input_signature:
            arg = example[arg_name]
            sign, param, example_entry, _to_wrap = process_individual_input(arg, arg_name)
            to_wrap = to_wrap or _to_wrap
            if sign is not None:
                input_sign_str.append(str(input_params[sign]))
            input_params_str.append(f"{arg_name}={param}")
            if example_entry is not None:
                new_example[arg_name] = example_entry
        if to_wrap:
            wrapper_class = wrapper_template.format(input_sign=", ".join(input_sign_str),
                                                    example_input=", ".join(input_params_str))
            wrapped_model = build_wrapper(wrapper_class, model)
            log.warning("Model has None in the example input. The input "
                        "with None will be removed from the resulting model.")
            return wrapped_model, new_example
    return model, example
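
# Usage sketch (illustrative): given forward(self, x, y) and an example of
# (tensor, None), the model is wrapped so that y is dropped from the wrapper
# signature and hard-coded to None in the inner call. `model` and `t` are
# hypothetical.
#     patched_model, patched_example = patch_none_example(model, (t, None))
#     patched_example  # (t,)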


@contextmanager
def no_jit_trace():
    """Context manager to disable JIT tracing.

    Note: using this function on large models may consume a lot of memory.
    """
    state = torch._C._get_tracing_state()
    torch._C._set_tracing_state(None)
    try:
        yield
    finally:
        torch._C._set_tracing_state(state)
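
# Usage sketch (illustrative): run side computations during tracing without
# recording them in the traced graph. `tensor` is hypothetical.
#     with no_jit_trace():
#         shape = tensor.shape  # not captured by the tracer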