413 lines
15 KiB
Python
413 lines
15 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright (C) 2018-2025 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# mypy: ignore-errors
|
|
|
|
import inspect
|
|
import logging
|
|
import torch
|
|
import numpy as np
|
|
from contextlib import contextmanager
|
|
|
|
from openvino import op, Type as OVType, Shape, Tensor, OVAny
|
|
from openvino import opset11 as ops
|
|
from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
|
|
|
|
# Module-level logger used by the helpers below for debug notes, wrapper-build
# errors and warnings about removed None inputs.
log = logging.getLogger(__name__)
|
|
|
|
|
|
def make_constant(*args, **kwargs):
    """Create an OpenVINO ``Constant`` node.

    Every positional and keyword argument is forwarded unchanged to
    ``op.Constant`` and the resulting node is returned.
    """
    constant_node = op.Constant(*args, **kwargs)
    return constant_node
|
|
|
|
|
|
def fetch_attr(self_module, target: str):
    """Fetch an attribute from the ``Module`` hierarchy of ``self_module``.

    Args:
        self_module (torch.nn.Module): The module to fetch the attribute from.
        target (str): The fully-qualified, dot-separated name of the attribute
            to fetch (e.g. ``"layer1.conv.weight"``).

    Returns:
        Any: The value of the attribute.

    Raises:
        RuntimeError: If a component of ``target`` does not exist. The message
            names the path up to and including the first missing component.
    """
    target_atoms = target.split(".")
    attr_itr = self_module
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Slice through ``i`` inclusive so the missing component itself is
            # reported; ``[:i]`` would print an empty name when the very first
            # lookup fails.
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
|
|
|
|
|
|
def get_type_from_py_type(value):
    """Map a Python scalar to the OpenVINO element type used to store it.

    ``bool`` is tested before ``int`` because ``bool`` is a subclass of ``int``.
    Complex numbers map to ``f32`` because ``ivalue_to_constant`` stores them
    as a ``[real, imag]`` pair of floats. Any other value yields
    ``OVType.dynamic``, meaning "no static scalar type".
    """
    if isinstance(value, bool):
        return OVType.boolean
    if isinstance(value, int):
        return OVType.i64
    if isinstance(value, (float, complex)):
        return OVType.f32
    return OVType.dynamic
|
|
|
|
|
|
# torch 8-bit float dtypes and their OpenVINO element types. numpy has no f8
# dtypes, so torch_tensor_to_ov_const reinterprets such tensors as uint8 and
# attaches the OpenVINO element type looked up here.
F8_DTYPE_MAP = {
    torch.float8_e4m3fn: OVType.f8e4m3,
    torch.float8_e5m2: OVType.f8e5m2,
}
|
|
|
|
|
|
def torch_tensor_to_ov_const(torch_t: torch.Tensor, shared_memory=True):
    """Convert a concrete torch tensor into an OpenVINO ``Constant`` node.

    The tensor is made contiguous first. Dtypes numpy cannot represent are
    bit-reinterpreted: bfloat16 is viewed as float16 and f8 types as uint8,
    and the real OpenVINO element type is attached via ``Tensor``. Complex
    tensors are passed as ``view_as_real`` data (a trailing dimension of 2),
    and the frontend is relied on to mark the constant as complex.

    Args:
        torch_t: The tensor to convert. ``FakeTensor`` inputs are rejected.
        shared_memory: Forwarded to ``op.Constant``. NOTE(review): presumably
            lets the constant alias the numpy buffer instead of copying it.

    Raises:
        AssertionError: If ``torch_t`` is a ``FakeTensor``.
    """
    try:
        from torch._prims import FakeTensor
    except ImportError:
        log.debug("Failed to import FakeTensor")
    else:
        if isinstance(torch_t, FakeTensor):
            raise AssertionError("`FakeTensor` detected. Infer the "
                                 "model before exporting to avoid this.")

    source_dtype = torch_t.dtype
    data = torch_t.contiguous()

    if source_dtype == torch.bfloat16 or source_dtype in F8_DTYPE_MAP:
        # Reinterpret the raw bits with a same-width numpy-compatible dtype and
        # let OpenVINO carry the real element type.
        if source_dtype == torch.bfloat16:
            carrier_dtype, ov_type = torch.float16, OVType.bf16
        else:
            carrier_dtype, ov_type = torch.uint8, F8_DTYPE_MAP[source_dtype]
        bits = data.view(carrier_dtype)
        ov_tensor = Tensor(bits.numpy(force=True), bits.shape, ov_type)
        return op.Constant(ov_tensor, shared_memory=shared_memory)

    if data.is_complex():
        # We rely on the frontend to mark the constant as complex internally.
        host_array = torch.view_as_real(data).numpy(force=True)
    else:
        host_array = data.numpy(force=True)
    return op.Constant(host_array, shared_memory=shared_memory)
|
|
|
|
|
|
def ivalue_to_constant(ivalue, shared_memory=True):
    """Turn a Python value or torch tensor into the outputs of a ``Constant``.

    * Python ``float``/``bool``/``int`` become 0-d constants; a ``complex``
      becomes a shape-``[2]`` ``[real, imag]`` constant.
    * A non-empty ``list``/``tuple`` becomes a 1-D constant whose element type
      is deduced from its first element.
    * A ``torch.Tensor`` is converted with ``torch_tensor_to_ov_const``.

    Returns:
        The ``outputs()`` of the created node, or ``None`` for unsupported values.
    """
    scalar_type = get_type_from_py_type(ivalue)
    if scalar_type.is_static():
        if isinstance(ivalue, complex):
            shape, values = Shape([2]), [ivalue.real, ivalue.imag]
        else:
            shape, values = Shape([]), [ivalue]
        return op.Constant(scalar_type, shape, values).outputs()

    if isinstance(ivalue, (list, tuple)):
        assert len(ivalue) > 0, "Can't deduce type for empty list"
        element_type = get_type_from_py_type(ivalue[0])
        assert element_type.is_static(), "Can't deduce type for list"
        return op.Constant(element_type, Shape([len(ivalue)]), ivalue).outputs()

    if isinstance(ivalue, torch.Tensor):
        return torch_tensor_to_ov_const(ivalue, shared_memory=shared_memory).outputs()

    return None
|
|
|
|
|
|
def get_value_from_getattr(getattr_node, self_module):
    """Resolve a (possibly nested) ``prim::GetAttr`` node against a module.

    The chain of ``prim::GetAttr`` nodes is walked through each node's first
    input until a node of another kind (or a GetAttr without inputs) is
    reached. The attribute names are then applied outermost-first, starting
    at ``self_module``.

    Returns:
        tuple: ``(value, path)`` where ``path`` is a dotted name such as
        ``"self.sub.weight"``.
    """
    assert getattr_node.kind() == "prim::GetAttr", "Got node of kind not equal to prim::GetAttr"

    chain = []
    current = getattr_node
    while current.kind() == "prim::GetAttr":
        chain.append(current)
        producers = list(current.inputs())
        if not producers:
            break
        current = producers[0].node()

    value = self_module
    path_parts = ["self"]
    # ``chain`` was collected innermost-first; resolve it in reverse.
    for attr_node in reversed(chain):
        attr_name = attr_node.s("name")
        assert hasattr(
            value, attr_name), f'No attribute with name "{attr_name}" found in module.'
        path_parts.append(attr_name)
        value = getattr(value, attr_name)
    return value, ".".join(path_parts)
|
|
|
|
|
|
def graph_has_ops(graph, op_types: list) -> bool:
    """Return True if any node in *graph* or its nested blocks matches *op_types*.

    A node matches when one of the strings in *op_types* is a substring of the
    node's kind, so a prefix such as ``"aten::"`` matches every aten op.
    Sub-blocks of each node (e.g. the bodies of control-flow nodes) are
    searched recursively.
    """
    for node in graph.nodes():
        node_kind = node.kind()
        if any(op_type in node_kind for op_type in op_types):
            return True
        if any(graph_has_ops(sub_block, op_types) for sub_block in node.blocks()):
            return True
    return False
|
|
|
|
|
|
# Lookup table from PyTorch type names to OpenVINO element types. Keys cover
# Python scalar type names ("float", "int", "bool"), torch dtype names
# ("torch.float32", ...), legacy tensor-type names ("torch.FloatTensor", ...)
# and quantized dtypes. NOTE(review): keys look like the str() of the PyTorch
# type/dtype; confirm against the lookups that use this table.
pt_to_ov_type_map = {
    "float": OVType.f32,
    "int": OVType.i64,
    "bool": OVType.boolean,
    "torch.float8_e4m3fn": OVType.f8e4m3,
    "torch.float8_e5m2": OVType.f8e5m2,
    "torch.bfloat16": OVType.bf16,
    "torch.float16": OVType.f16,
    "torch.float32": OVType.f32,
    "torch.float64": OVType.f64,
    # Complex dtypes map to the decoder's Complex type wrapping the element
    # type of the real/imaginary parts.
    "torch.complex32": DecoderType.Complex(OVAny(OVType.f16)),
    "torch.complex64": DecoderType.Complex(OVAny(OVType.f32)),
    "torch.complex128": DecoderType.Complex(OVAny(OVType.f64)),
    "torch.uint8": OVType.u8,
    "torch.int8": OVType.i8,
    "torch.int16": OVType.i16,
    "torch.int32": OVType.i32,
    "torch.int64": OVType.i64,
    "torch.bool": OVType.boolean,
    # Legacy tensor-type names.
    "torch.DoubleTensor": OVType.f64,
    "torch.FloatTensor": OVType.f32,
    "torch.HalfTensor": OVType.f16,
    "torch.BFloat16Tensor": OVType.bf16,
    "torch.IntTensor": OVType.i32,
    "torch.LongTensor": OVType.i64,
    "torch.ShortTensor": OVType.i16,
    "torch.CharTensor": OVType.i8,
    "torch.ByteTensor": OVType.u8,
    "torch.BoolTensor": OVType.boolean,
    "torch.ComplexHalfTensor": DecoderType.Complex(OVAny(OVType.f16)),
    "torch.ComplexFloatTensor": DecoderType.Complex(OVAny(OVType.f32)),
    "torch.ComplexDoubleTensor": DecoderType.Complex(OVAny(OVType.f64)),
    # Quantized dtypes map to the integer type of the same width and sign.
    "torch.quint8": OVType.u8,
    "torch.qint8": OVType.i8,
    "torch.qint32": OVType.i32,
}
|
|
|
|
|
|
wrapper_template = """
|
|
import torch
|
|
from typing import *
|
|
|
|
class ModelWrapper(torch.nn.Module):
|
|
def __init__(self, model):
|
|
super().__init__()
|
|
self.model = model
|
|
|
|
def forward(self, {input_sign}):
|
|
return self.model({example_input})
|
|
"""
|
|
|
|
|
|
def build_wrapper(template, model):
    """Build a wrapper around *model* from the Python source in *template*.

    The template is executed with ``exec`` in a fresh namespace and must define
    a ``ModelWrapper`` class whose constructor takes the wrapped model. The
    wrapper is put into eval mode before being returned.

    Args:
        template: Python source defining ``ModelWrapper``. It is executed as
            code, so it must only be generated from trusted text.
        model: The model to wrap.

    Returns:
        The wrapper instance, or *model* itself if building the wrapper failed.
    """
    namespace = {}
    try:
        exec(template, namespace)
        wrapped_model = namespace["ModelWrapper"](model)
        wrapped_model.eval()
    except Exception as err:
        # If wrapping failed, return the original model so the user does not see
        # confusing errors about the generated wrapper. The reason is still
        # logged so the failure stays diagnosable.
        log.error("Failed to build model wrapper: %s", err)
        log.debug("Model wrapper source that failed to build:\n%s", template, exc_info=True)
        return model
    return wrapped_model
|
|
|
|
|
|
def _strip_custom_union_args(param):
    """Return *param* with user-defined classes removed from a Union annotation."""
    # The wrapper source is compiled with only ``torch`` and ``typing`` in scope,
    # so annotation members naming project classes would not resolve. A new
    # annotation is built instead of assigning ``__args__`` in place: typing
    # caches Union objects, so an in-place edit leaks into every other user of
    # the same Union.
    import types
    import typing

    annotation = param.annotation
    union_origins = (typing.Union, getattr(types, "UnionType", typing.Union))
    if typing.get_origin(annotation) not in union_origins:
        return param

    builtin_names = (str(int), str(float), str(type(None)))
    kept = []
    for arg in typing.get_args(annotation):
        arg_str = str(arg)
        is_typing = arg_str.startswith("typing.")
        is_torch = "torch." in arg_str
        is_builtin = arg_str in builtin_names
        if is_typing or is_torch or is_builtin:
            kept.append(arg)

    if not kept:
        return param.replace(annotation=inspect.Parameter.empty)
    # Union over a single type collapses to that type.
    return param.replace(annotation=typing.Union[tuple(kept)])


def process_dict_inputs(inputs, input_params, model):
    """Order dict example inputs by the model signature, wrapping if needed.

    Args:
        inputs: Mapping from parameter name to example value.
        input_params: ``inspect.Signature.parameters`` of the model's forward.
        model: The model being converted.

    Returns:
        tuple: ``({"example_inputs": [...]}, ordered_input_names, model_or_wrapper)``.
        The original *model* is returned when the supplied names form a prefix of
        the signature; otherwise a positional wrapper is built (falling back to
        *model* if that fails).
    """
    # Keep only the supplied inputs, in forward() signature order.
    ordered_inputs = [name for name in input_params if name in inputs]
    example_inputs = [inputs[name] for name in ordered_inputs]

    input_signature = list(input_params)
    if ordered_inputs == input_signature[: len(ordered_inputs)]:
        # The inputs can be passed positionally to the model as-is.
        return {"example_inputs": example_inputs}, ordered_inputs, model

    # PyTorch has some difficulties tracing models with named unordered parameters:
    # torch < 2.0.0 supports only positional arguments for tracing;
    # torch == 2.0.0 supports input kwargs tracing, but not complex nested
    # objects (e.g. a tuple of tuples of tensors).
    # As a workaround, a wrapper makes them positional.
    input_sign_str = []
    input_params_str = []
    for input_name in ordered_inputs:
        param = _strip_custom_union_args(input_params[input_name])
        input_sign_str.append(str(param).replace("NoneType", "None"))
        input_params_str.append(f"{input_name}={input_name}")

    wrapper_class = wrapper_template.format(input_sign=", ".join(input_sign_str),
                                            example_input=", ".join(input_params_str))
    wrapped_model = build_wrapper(wrapper_class, model)
    return {"example_inputs": example_inputs}, ordered_inputs, wrapped_model
|
|
|
|
|
|
def prepare_example_inputs_and_model(inputs, input_params, model):
    """Normalize example inputs and pick the model to trace.

    Dict inputs are delegated to ``process_dict_inputs``. Otherwise, a list
    holding a single tensor whose first parameter's annotation mentions
    ``list`` is replaced by that tensor unsqueezed at dim 0, and the call is
    flagged as a list input. A bare tensor is wrapped in a list.

    Returns:
        tuple: ``({"example_inputs": ...}, input_names, model, input_is_list)``.
    """
    if isinstance(inputs, dict):
        examples, ordered, wrapped = process_dict_inputs(inputs, input_params, model)
        return examples, ordered, wrapped, False

    signature_names = list(input_params)
    wrapped_in_list = False
    is_single_tensor_list = (isinstance(inputs, list) and len(inputs) == 1
                             and isinstance(inputs[0], torch.Tensor))
    if is_single_tensor_list and "list" in str(input_params[signature_names[0]].annotation):
        inputs = inputs[0].unsqueeze(0)
        wrapped_in_list = True

    if isinstance(inputs, torch.Tensor):
        inputs = [inputs]
    return {"example_inputs": inputs}, signature_names[: len(inputs)], model, wrapped_in_list
|
|
|
|
|
|
def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool):
    """Express a torch quantized tensor as an OpenVINO dequantization subgraph.

    The produced graph is
    ``Constant(int_repr) -> Convert(f32) -> Subtract(zero_point) -> Multiply(scale)``.
    For ``per_channel_affine`` the scales and zero points are reshaped so that
    they broadcast along the quantization axis. For ``per_tensor_affine`` they
    are scalars.

    Raises:
        AssertionError: For any other qscheme.
    """
    qscheme = qtensor.qscheme()
    if qscheme == torch.per_channel_affine:
        raw = qtensor.int_repr()
        scales = qtensor.q_per_channel_scales().numpy().astype(np.float32)
        zero_points = qtensor.q_per_channel_zero_points().numpy().astype(np.float32)
        channel_axis = np.int32(qtensor.q_per_channel_axis())

        # All ones except -1 on the channel axis, so the params broadcast.
        bcast_shape = np.ones(len(raw.shape), dtype=np.int32)
        bcast_shape[channel_axis] = -1
        zero_point = np.reshape(zero_points, bcast_shape)
        scale = np.reshape(scales, bcast_shape)
    elif qscheme == torch.per_tensor_affine:
        raw = qtensor.int_repr()
        scale = np.float32(qtensor.q_scale())
        zero_point = np.float32(qtensor.q_zero_point())
    else:
        raise AssertionError(f"Unsupported qscheme: {qscheme}")

    raw_const = torch_tensor_to_ov_const(raw, shared_memory=shared_memory)
    as_float = ops.convert(raw_const, np.float32)
    shifted = ops.subtract(as_float, zero_point)
    return ops.multiply(shifted, scale).outputs()
|
|
|
|
|
|
def process_individual_input(arg, arg_name):
    """Generate signature, param string, example, and wrap flag from input.

    Args:
        arg: The input value to process.
        arg_name: The name of the input.

    Returns:
        tuple: ``(signature, param, example_entry, to_wrap)``:

        * ``signature``: the name to keep in the wrapper's ``forward``
          signature, or ``None`` when the input must not appear there;
        * ``param``: the source text passed to the wrapped model for this input;
        * ``example_entry``: the value to keep in the example input, or ``None``
          when nothing is left to feed;
        * ``to_wrap``: ``True`` when ``arg`` is ``None`` or a tuple containing
          ``None``.
    """
    if arg is None:
        return None, "None", None, True
    if not isinstance(arg, tuple):
        return arg_name, arg_name, arg, False

    call_items = []
    kept = []
    has_none = False
    for value in arg:
        if value is None:
            has_none = True
            call_items.append("None")
        else:
            # Index into the None-free tuple that replaces ``arg`` in the example.
            call_items.append(f"{arg_name}[{len(kept)}]")
            kept.append(value)

    # The trailing comma keeps a single element a tuple. An empty tuple must be
    # spelled "()" because "(,)" is a syntax error in the generated wrapper.
    param = f"({', '.join(call_items)},)" if call_items else "()"
    if not kept:
        # No real values remain to feed, so keep the argument out of the wrapper
        # signature (the example will not contain it) and pass the literal.
        return None, param, None, has_none
    return arg_name, param, tuple(kept), has_none
|
|
|
|
|
|
def patch_none_example(model: torch.nn.Module, example):
    """Patch a PyTorch model to handle None values in the input example.

    ``None`` inputs (and ``None`` tuple elements) are removed from *example*. A
    generated wrapper passes them to *model* as literal ``None`` instead.

    Args:
        model: The model whose ``forward`` (or ``__call__``) signature is used.
        example: A tuple of positional inputs or a dict of named inputs. Other
            types, or more inputs than the signature has, are returned unchanged.

    Returns:
        tuple: ``(model_or_wrapper, example)``. The wrapper and the None-free
        example are returned only when a None was found and the wrapper could be
        built. Otherwise the original *model* and *example* are returned, so the
        example always matches the returned model.
    """
    callable_func = getattr(model, "forward", model.__call__)
    input_params = inspect.signature(callable_func).parameters
    input_signature = list(input_params)

    is_tuple = isinstance(example, tuple)
    if not (is_tuple or isinstance(example, dict)) or len(input_signature) < len(example):
        return model, example

    input_sign_str = []
    input_params_str = []
    to_wrap = False
    if is_tuple:
        new_example = []
        for arg_name, arg in zip(input_signature, example):
            sign, param, example_entry, needs_wrap = process_individual_input(arg, arg_name)
            to_wrap = to_wrap or needs_wrap
            if sign is not None:
                # Same NoneType normalization as process_dict_inputs: the wrapper
                # source has no ``NoneType`` name in scope.
                input_sign_str.append(str(input_params[sign]).replace("NoneType", "None"))
            input_params_str.append(param)
            if example_entry is not None:
                new_example.append(example_entry)
        new_example = tuple(new_example)
    else:
        new_example = {}
        for arg_name in [s for s in input_signature if s in example]:
            sign, param, example_entry, needs_wrap = process_individual_input(example[arg_name], arg_name)
            to_wrap = to_wrap or needs_wrap
            if sign is not None:
                input_sign_str.append(str(input_params[sign]).replace("NoneType", "None"))
            input_params_str.append(f"{arg_name}={param}")
            if example_entry is not None:
                new_example[arg_name] = example_entry

    if not to_wrap:
        return model, example

    wrapper_class = wrapper_template.format(input_sign=", ".join(input_sign_str),
                                            example_input=", ".join(input_params_str))
    wrapped_model = build_wrapper(wrapper_class, model)
    if wrapped_model is model:
        # build_wrapper failed and fell back to the original model (it already
        # logged why). Returning the None-stripped example with the unwrapped
        # model would shift inputs out of place, so keep the original example.
        return model, example
    log.warning("Model has None in the example input. The input "
                "with None will be removed from the resulting model.")
    return wrapped_model, new_example
|
|
|
|
|
|
@contextmanager
def no_jit_trace():
    """Temporarily suspend TorchScript JIT tracing.

    The current tracing state is saved and cleared for the duration of the
    ``with`` body. It is restored afterwards, even if the body raises.

    Note: using this function on large models consumes a lot of memory.
    """
    saved_state = torch._C._get_tracing_state()
    torch._C._set_tracing_state(None)
    try:
        yield
    finally:
        # Restore whatever tracer (or None) was active on entry.
        torch._C._set_tracing_state(saved_state)
|