# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# mypy: disable-error-code="no-redef"

import copy
import logging
import numbers
from abc import ABCMeta, abstractmethod
from collections.abc import Callable, Sequence
from functools import singledispatch
from typing import Any, Union

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision.transforms import InterpolationMode

import openvino as ov
import openvino.opset11 as ops
from openvino import Layout, Model, Output, Type
from openvino.preprocess import ColorFormat, PrePostProcessor, ResizeAlgorithm
from openvino.utils.decorators import custom_preprocess_function

# Mapping from Python scalar types and torch dtypes/tensor types to OpenVINO element types.
TORCHTYPE_TO_OVTYPE = {
    float: ov.Type.f32,
    int: ov.Type.i32,
    bool: ov.Type.boolean,
    torch.float16: ov.Type.f16,
    torch.float32: ov.Type.f32,
    torch.float64: ov.Type.f64,
    torch.uint8: ov.Type.u8,
    torch.int8: ov.Type.i8,
    torch.int32: ov.Type.i32,
    torch.int64: ov.Type.i64,
    torch.bool: ov.Type.boolean,
    torch.DoubleTensor: ov.Type.f64,
    torch.FloatTensor: ov.Type.f32,
    torch.IntTensor: ov.Type.i32,
    torch.LongTensor: ov.Type.i64,
    torch.BoolTensor: ov.Type.boolean,
}

@singledispatch
def _setup_size(size: Any, error_msg: str) -> Sequence[int]:
    """Normalize a torchvision ``size`` argument to an ``(h, w)`` pair."""
    raise ValueError(error_msg)


@_setup_size.register
def _setup_size_number(size: numbers.Number, error_msg: str) -> Sequence[int]:
    # A single number means a square output: (size, size).
    return int(size), int(size)  # type: ignore


@_setup_size.register
def _setup_size_sequence(size: Sequence, error_msg: str) -> Sequence[int]:
    if len(size) == 1:
        return size[0], size[0]
    elif len(size) == 2:
        return size[0], size[1]
    raise ValueError(error_msg)

def _NHWC_to_NCHW(input_shape: list) -> list:  # noqa N802
    """Reorder an NHWC shape list to NCHW, e.g. [1, 224, 224, 3] -> [1, 3, 224, 224]."""
    new_shape = copy.deepcopy(input_shape)
    new_shape[1] = input_shape[3]
    new_shape[2] = input_shape[1]
    new_shape[3] = input_shape[2]
    return new_shape

@singledispatch
def _to_list(transform: Callable) -> list:
    """Unpack a transform container into a flat list of transforms."""
    raise TypeError(f"Unsupported transform type: {type(transform)}")


@_to_list.register
def _to_list_torch_sequential(transform: torch.nn.Sequential) -> list:
    return list(transform)


@_to_list.register
def _to_list_transforms_compose(transform: transforms.Compose) -> list:
    return transform.transforms

def _get_shape_layout_from_data(input_example: Union[torch.Tensor, np.ndarray, Image.Image]) -> tuple[list, Layout]:
    """Derive a batched 4D input shape and its layout from an example input."""
    if isinstance(input_example, (torch.Tensor, np.ndarray, Image.Image)):  # PyTorch, OpenCV, numpy, PILLOW
        shape = list(np.array(input_example, copy=False).shape)
        layout = Layout("NCHW") if isinstance(input_example, torch.Tensor) else Layout("NHWC")
    else:
        raise TypeError(f"Unsupported input type: {type(input_example)}")

    # Prepend a batch dimension when the example is a single 3D image.
    if len(shape) == 3:
        shape = [1] + shape
    elif len(shape) != 4:
        raise ValueError(f"Unsupported number of input dimensions: {len(shape)}")

    return shape, layout

class TransformConverterBase(metaclass=ABCMeta):
    """Base class for converters that map one torchvision transform onto PrePostProcessor steps."""

    def __init__(self, **kwargs: Any) -> None:  # noqa B027
        pass

    @abstractmethod
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: dict) -> None:
        pass

class TransformConverterFactory:
    """Registry mapping torchvision transform class names to converter classes."""

    registry: dict[str, Callable] = {}

    @classmethod
    def register(cls: Callable, target_type: Union[Callable, None] = None) -> Callable:
        def inner_wrapper(wrapped_class: TransformConverterBase) -> Callable:
            registered_name = wrapped_class.__name__ if target_type is None else target_type.__name__
            if registered_name in cls.registry:
                logging.warning(f"Converter {registered_name} already exists. {wrapped_class.__name__} will replace it.")
            cls.registry[registered_name] = wrapped_class
            return wrapped_class  # type: ignore

        return inner_wrapper

    @classmethod
    def convert(cls: Callable, converter_type: Callable, *args: Any, **kwargs: Any) -> Callable:
        transform_name = converter_type.__name__
        if transform_name not in cls.registry:
            raise ValueError(f"{transform_name} is not supported.")

        converter = cls.registry[transform_name]()
        return converter.convert(*args, **kwargs)

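# Registration sketch: support for an additional transform would be added by
# registering a converter class. RandomInvert is used here purely as a
# hypothetical example; this module does not implement it:
#
#     @TransformConverterFactory.register(transforms.RandomInvert)
#     class _(TransformConverterBase):
#         def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: dict) -> None:
#             ...  # append matching PrePostProcessor steps here
#
# The registry is keyed by the transform's class name, which is how
# TransformConverterFactory.convert(type(tm), ...) finds the right converter.
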
@TransformConverterFactory.register(transforms.Normalize)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: dict) -> None:
        if transform.inplace:
            raise ValueError("Inplace normalization is not supported.")
        ppp.input(input_idx).preprocess().mean(transform.mean).scale(transform.std)

@TransformConverterFactory.register(transforms.ConvertImageDtype)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: dict) -> None:
        ppp.input(input_idx).preprocess().convert_element_type(TORCHTYPE_TO_OVTYPE[transform.dtype])

@TransformConverterFactory.register(transforms.Grayscale)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: dict) -> None:
        input_shape = meta["input_shape"]
        layout = meta["layout"]

        input_shape[layout.get_index_by_name("C")] = 1

        ppp.input(input_idx).preprocess().convert_color(ColorFormat.GRAY)
        if transform.num_output_channels != 1:
            input_shape[layout.get_index_by_name("C")] = transform.num_output_channels

            # Replicate the single gray channel across the requested number of channels.
            @custom_preprocess_function
            def broadcast_node(output: Output) -> Callable:  # type: ignore[name-defined]
                return ops.broadcast(  # type: ignore
                    data=output,
                    target_shape=input_shape,
                )

            ppp.input(input_idx).preprocess().custom(broadcast_node)

        meta["input_shape"] = input_shape

@TransformConverterFactory.register(transforms.Pad)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: dict) -> None:
        image_dimensions = list(meta["image_dimensions"])
        layout = meta["layout"]
        torch_padding = transform.padding
        pad_mode = transform.padding_mode

        if pad_mode == "constant" and isinstance(transform.fill, tuple):
            raise ValueError("Different fill values for R, G, B channels are not supported.")

        pads_begin = [0 for _ in meta["input_shape"]]
        pads_end = [0 for _ in meta["input_shape"]]

        # torchvision.transforms.Pad semantics: an int pads all sides equally, a
        # 2-sequence is (left/right, top/bottom), a 4-sequence is (left, top, right, bottom).
        if isinstance(torch_padding, int):
            # Padding is equal on all sides.
            image_dimensions[0] += 2 * torch_padding
            image_dimensions[1] += 2 * torch_padding

            pads_begin[layout.get_index_by_name("H")] = torch_padding
            pads_begin[layout.get_index_by_name("W")] = torch_padding
            pads_end[layout.get_index_by_name("H")] = torch_padding
            pads_end[layout.get_index_by_name("W")] = torch_padding
        elif len(torch_padding) == 2:
            # Padding differs between the horizontal and vertical axes; each axis
            # grows by twice its per-side pad.
            image_dimensions[0] += 2 * torch_padding[1]
            image_dimensions[1] += 2 * torch_padding[0]

            pads_begin[layout.get_index_by_name("H")] = torch_padding[1]
            pads_begin[layout.get_index_by_name("W")] = torch_padding[0]
            pads_end[layout.get_index_by_name("H")] = torch_padding[1]
            pads_end[layout.get_index_by_name("W")] = torch_padding[0]
        else:
            # Padding differs on the top, bottom, left, and right of the image.
            image_dimensions[0] += torch_padding[1] + torch_padding[3]
            image_dimensions[1] += torch_padding[0] + torch_padding[2]

            pads_begin[layout.get_index_by_name("H")] = torch_padding[1]
            pads_begin[layout.get_index_by_name("W")] = torch_padding[0]
            pads_end[layout.get_index_by_name("H")] = torch_padding[3]
            pads_end[layout.get_index_by_name("W")] = torch_padding[2]

        @custom_preprocess_function
        def pad_node(output: Output) -> Callable:
            return ops.pad(  # type: ignore
                output,
                pad_mode=pad_mode,
                pads_begin=pads_begin,
                pads_end=pads_end,
                arg_pad_value=np.array(transform.fill, dtype=np.uint8) if pad_mode == "constant" else None,
            )

        ppp.input(input_idx).preprocess().custom(pad_node)
        meta["image_dimensions"] = tuple(image_dimensions)

@TransformConverterFactory.register(transforms.ToTensor)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: dict) -> None:
        input_shape = meta["input_shape"]
        layout = meta["layout"]

        # ToTensor takes u8 RGB HWC input and produces f32 CHW output scaled to [0, 1].
        ppp.input(input_idx).tensor().set_element_type(Type.u8).set_layout(Layout("NHWC")).set_color_format(ColorFormat.RGB)  # noqa ECE001

        if layout == Layout("NHWC"):
            input_shape = _NHWC_to_NCHW(input_shape)
            layout = Layout("NCHW")
            ppp.input(input_idx).preprocess().convert_layout(layout)
        ppp.input(input_idx).preprocess().convert_element_type(Type.f32)
        ppp.input(input_idx).preprocess().scale(255.0)

        meta["input_shape"] = input_shape
        meta["layout"] = layout

@TransformConverterFactory.register(transforms.CenterCrop)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: dict) -> None:
        input_shape = meta["input_shape"]
        source_size = meta["image_dimensions"]
        target_size = _setup_size(transform.size, "Incorrect size type for CenterCrop operation")

        if target_size[0] > source_size[0] or target_size[1] > source_size[1]:
            raise ValueError(f"CenterCrop size={target_size} is greater than source_size={source_size}")

        # Compute the crop window centered on the source image.
        bottom_left = []
        bottom_left.append(int((source_size[0] - target_size[0]) / 2))
        bottom_left.append(int((source_size[1] - target_size[1]) / 2))

        top_right = []
        top_right.append(min(bottom_left[0] + target_size[0], source_size[0] - 1))
        top_right.append(min(bottom_left[1] + target_size[1], source_size[1] - 1))

        # Extend the 2D crop coordinates to full-rank coordinates for the current layout.
        bottom_left = [0] * len(input_shape[:-2]) + bottom_left if meta["layout"] == Layout("NCHW") else [0] + bottom_left + [0]  # noqa ECE001
        top_right = input_shape[:-2] + top_right if meta["layout"] == Layout("NCHW") else input_shape[:1] + top_right + input_shape[-1:]

        ppp.input(input_idx).preprocess().crop(bottom_left, top_right)
        meta["image_dimensions"] = (target_size[-2], target_size[-1])

@TransformConverterFactory.register(transforms.Resize)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: dict) -> None:
        resize_mode_map = {
            InterpolationMode.NEAREST: ResizeAlgorithm.RESIZE_NEAREST,
            InterpolationMode.BILINEAR: ResizeAlgorithm.RESIZE_BILINEAR_PILLOW,
            InterpolationMode.BICUBIC: ResizeAlgorithm.RESIZE_BICUBIC_PILLOW,
        }
        if transform.max_size:
            raise ValueError("Resize with max_size is not supported.")
        if transform.interpolation not in resize_mode_map:
            raise ValueError(f"Interpolation mode {transform.interpolation} is not supported.")

        target_h, target_w = _setup_size(transform.size, "Incorrect size type for Resize operation")

        if isinstance(transform.size, int):
            # A scalar size rescales the smaller image edge, preserving the aspect ratio.
            current_h, current_w = meta["image_dimensions"]
            if current_h > current_w:
                target_h = int(transform.size * (current_h / current_w))
            elif current_w > current_h:
                target_w = int(transform.size * (current_w / current_h))

        ppp.input(input_idx).tensor().set_layout(Layout("NCHW"))

        input_shape = meta["input_shape"]

        # Make the spatial dimensions dynamic so inputs of any size can be resized.
        input_shape[meta["layout"].get_index_by_name("H")] = -1
        input_shape[meta["layout"].get_index_by_name("W")] = -1

        ppp.input(input_idx).tensor().set_shape(input_shape)
        ppp.input(input_idx).preprocess().resize(resize_mode_map[transform.interpolation], target_h, target_w)
        meta["input_shape"] = input_shape
        meta["image_dimensions"] = (target_h, target_w)

def _from_torchvision(model: Model, transform: Callable, input_example: Any, input_name: Union[str, None] = None) -> Model:
    """Embed a torchvision preprocessing pipeline into an OpenVINO model via PrePostProcessor."""
    if input_name is not None:
        input_idx = next((i for i, p in enumerate(model.get_parameters()) if p.get_friendly_name() == input_name), None)
    else:
        if len(model.get_parameters()) == 1:
            input_idx = 0
        else:
            raise ValueError("Model contains multiple inputs. Please specify the name of the input to which preprocessing is added.")

    if input_idx is None:
        raise ValueError(f"Input with name {input_name} is not found")

    input_shape, layout = _get_shape_layout_from_data(input_example)

    ppp = PrePostProcessor(model)
    ppp.input(input_idx).tensor().set_layout(layout)
    ppp.input(input_idx).tensor().set_shape(input_shape)

    image_dimensions = [input_shape[layout.get_index_by_name("H")], input_shape[layout.get_index_by_name("W")]]
    global_meta = {
        "input_shape": input_shape,
        "image_dimensions": image_dimensions,
        "layout": layout,
    }

    # Convert each transform in pipeline order, threading shape/layout state through global_meta.
    for tm in _to_list(transform):
        TransformConverterFactory.convert(type(tm), input_idx, ppp, tm, global_meta)

    updated_model = ppp.build()
    return updated_model
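
# Minimal usage sketch (the model and image file names are illustrative, not
# part of this module):
#
#     import openvino as ov
#     from PIL import Image
#
#     core = ov.Core()
#     model = core.read_model("model.xml")
#     pipeline = transforms.Compose([
#         transforms.Resize(256),
#         transforms.CenterCrop(224),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])
#     example = Image.open("example.jpg")
#     model_with_preprocessing = _from_torchvision(model, pipeline, example)
#
# The returned model then takes raw u8 NHWC images and performs the resize,
# crop, scaling, and normalization steps as part of inference.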