58 lines
2.0 KiB
Python
58 lines
2.0 KiB
Python
# Copyright (C) 2022-2023 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import argparse
import os
from pathlib import Path
from typing import Optional

from openvino.tools.ovc.error import Error
|
|
|
|
|
|
def default_path():
    """Return the absolute path of the default extensions directory.

    The default extensions directory is the current working directory ('.').

    :return: absolute, normalized path string.
    """
    ext_dir_name = '.'
    # Use os.path.join (path composition), not str.join (string interleaving);
    # abspath then normalizes away the trailing '.' component.
    return os.path.abspath(os.path.join(os.getcwd(), ext_dir_name))
|
|
|
|
|
|
def any_extensions_used(argv: argparse.Namespace) -> bool:
    """Check whether any user-provided extensions are present in ``argv``.

    Allowed extension values are strings containing a path to a legacy extension
    directory or to a new extension .so file, or objects (e.g. classes inherited
    from BaseExtension).

    Side effect: a single non-list/tuple ``argv.extension`` value is wrapped into
    a one-element list in place, so later readers always see a sequence.

    :param argv: parsed arguments; the optional ``extension`` attribute is read.
    :return: True if at least one entry is a non-string object, or a non-empty
        string that differs from ``default_path()``; False otherwise, including
        when ``extension`` is missing, None, or an empty sequence.
    """
    if getattr(argv, 'extension', None) is None:
        return False

    if not isinstance(argv.extension, (list, tuple)):
        argv.extension = [argv.extension]

    for ext in argv.extension:
        # Any non-string entry (e.g. an extension object) counts as a used extension.
        if not isinstance(ext, str):
            return True
        # Empty strings and the default extensions directory do not count.
        if len(ext) != 0 and ext != default_path():
            return True

    # Also covers an empty list/tuple: no extensions were requested.
    return False
|
|
|
|
|
|
def get_transformations_config_path(argv: argparse.Namespace) -> Optional[Path]:
    """Return the legacy transformations config path if ``argv`` names an existing file.

    :param argv: parsed arguments; the optional ``transformations_config``
        attribute is read.
    :return: ``Path`` to the config when ``transformations_config`` is a non-empty
        string pointing to an existing regular file, otherwise ``None``.
    """
    config = getattr(argv, 'transformations_config', None)
    # Check the type before the emptiness test: calling len() first raised
    # TypeError for values without __len__ (e.g. a pathlib.Path).
    if not isinstance(config, str) or not config:
        return None
    path = Path(config)
    return path if path.is_file() else None
|
|
|
|
|
|
def legacy_transformations_config_used(argv: argparse.Namespace) -> bool:
    """Return True if ``argv`` provides a legacy transformations config file.

    :param argv: parsed arguments, forwarded to ``get_transformations_config_path``.
    :return: True when that helper resolves an existing config file, else False.
    """
    # Identity check against None, as PEP 8 recommends, rather than `!= None`.
    return get_transformations_config_path(argv) is not None
|
|
|
|
|
|
def tensorflow_custom_operations_config_update_used(argv: argparse.Namespace):
    """Report whether ``argv`` carries a TensorFlow custom operations config update.

    :param argv: parsed arguments; the optional
        ``tensorflow_custom_operations_config_update`` attribute is read.
    :return: True if the attribute exists and is not None, False otherwise.
    """
    # A missing attribute is treated the same as an explicit None.
    config_update = getattr(argv, 'tensorflow_custom_operations_config_update', None)
    return config_update is not None
|