135 lines
4.7 KiB
Python
135 lines
4.7 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright (C) 2018-2025 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
"""Helper functions for validating user input."""
|
|
|
|
import logging
|
|
from typing import Any, Optional
|
|
from collections.abc import Callable, Iterable
|
|
|
|
import numpy as np
|
|
|
|
from openvino.exceptions import UserInputError
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
|
|
|
|
|
|
def assert_list_of_ints(value_list: Iterable[int], message: str) -> None:
    """Ensure that ``value_list`` can be iterated and holds only ``int`` instances.

    :param value_list: The object expected to be an iterable of integers.
    :param message: Text that is logged as a warning and passed to the raised error.

    :raises UserInputError: If ``value_list`` is not iterable or any element is not an ``int``.
    """
    try:
        # A non-iterable argument raises TypeError on its own; a non-int element
        # is funnelled into the same handler so both failures are reported alike.
        if not all(isinstance(item, int) for item in value_list):
            raise TypeError
    except TypeError:
        log.warning(message)
        raise UserInputError(message, value_list)
|
|
|
|
|
|
def _check_value(
    op_name: str,
    attr_key: str,
    value: Any,
    val_type: type,
    cond: Optional[Callable[[Any], bool]] = None,
) -> bool:
    """Check whether provided value satisfies specified criteria.

    :param op_name: The operator name which attributes are checked.
    :param attr_key: The attribute name.
    :param value: The value to check.
    :param val_type: Required value type; ``type(value)`` must be a sub-dtype of it
                     according to ``np.issubdtype``.
    :param cond: The optional function running additional checks.

    :raises UserInputError: If the value has a wrong type or ``cond`` rejects it.

    :returns: True if the value satisfies all criteria. A failed check raises
              instead of returning False.
    """
    if not np.issubdtype(type(value), val_type):
        raise UserInputError(
            f'{op_name} operator attribute "{attr_key}" value must be of type {val_type}.',
        )
    if cond is not None and not cond(value):
        raise UserInputError(
            f'{op_name} operator attribute "{attr_key}" value does not satisfy provided condition.',
        )
    return True
|
|
|
|
|
|
def check_valid_attribute(
    op_name: str,
    attr_dict: dict,
    attr_key: str,
    val_type: type,
    cond: Optional[Callable[[Any], bool]] = None,
    required: Optional[bool] = False,
) -> bool:
    """Check whether specified attribute satisfies given criteria.

    :param op_name: The operator name which attributes are checked.
    :param attr_dict: Dictionary containing key-value attributes to check.
    :param attr_key: Key value for validated attribute.
    :param val_type: Value type for validated attribute.
    :param cond: Any callable which accepts attribute value and returns True or False.
    :param required: Whether the attribute key is required. When False the key may be
                     missing from the provided dictionary and the check passes.

    :raises UserInputError: If a required attribute is missing, or if the attribute
                            value (or any of its elements) fails the type or
                            condition check.

    :returns: True if the attribute is absent but optional, or if it satisfies all
              criteria. A failed check raises instead of returning False.
    """
    if attr_key not in attr_dict:
        if required:
            raise UserInputError(
                f'Provided dictionary is missing {op_name} operator required attribute "{attr_key}"',
            )
        # Optional attribute that was not provided: nothing to validate.
        return True

    attr_value = attr_dict[attr_key]

    # np.isscalar() is False for zero-dimensional arrays, yet they cannot be
    # iterated either; unwrap them so they are validated as a single value.
    if isinstance(attr_value, np.ndarray) and attr_value.ndim == 0:
        attr_value = attr_value[()]

    # Scalars are checked directly, anything else is checked element by element.
    values = [attr_value] if np.isscalar(attr_value) else attr_value
    for value in values:
        _check_value(op_name, attr_key, value, val_type, cond)

    return True
|
|
|
|
|
|
def check_valid_attributes(
    op_name: str,
    attributes: dict[str, Any],
    requirements: list[tuple[str, bool, type, Optional[Callable]]],
) -> bool:
    """Validate user attributes against a list of type and value requirements.

    :param op_name: The operator name which attributes are checked.
    :param attributes: The dictionary with user provided attributes to check.
    :param requirements: The list of tuples describing attributes' requirements.
                         Each tuple is unpacked in this order:
                         (attr_name: str,
                          is_required: bool,
                          value_type: type,
                          value_condition: Callable)

    :raises UserInputError: Propagated from the per-attribute check.

    :returns: True once every requirement has been checked without an error being raised.
    """
    for requirement in requirements:
        attr_name, is_required, value_type, value_condition = requirement
        check_valid_attribute(
            op_name,
            attributes,
            attr_name,
            value_type,
            cond=value_condition,
            required=is_required,
        )
    return True
|
|
|
|
|
|
def is_positive_value(value: Any) -> bool:
    """Tell whether ``value`` is strictly greater than zero.

    :param value: The value to compare against zero.

    :returns: The result of ``value > 0``.
    """
    return value > 0
|
|
|
|
|
|
def is_non_negative_value(value: Any) -> bool:
    """Tell whether ``value`` is greater than or equal to zero.

    :param value: The value to compare against zero.

    :returns: The result of ``value >= 0``.
    """
    return value >= 0
|