# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# flake8: noqa
# mypy: ignore-errors

import jax
from packaging import version

# `jex` is the namespace that `core.Var` / `core.Literal` are read from below:
# plain `jax` before 0.6.0, `jax.extend` from 0.6.0 onwards.
if version.parse(jax.__version__) < version.parse("0.6.0"):
    import jax as jex
    import jax.core
else:
    import jax.extend as jex
import jax.numpy as jnp
import numpy as np

from openvino.frontend.jax.passes import filter_element, filter_ivalue, filter_param
from openvino import op, Type as OVType, Shape, OVAny

# NumPy scalar types (plus Python bool and bfloat16) -> OpenVINO element type.
# Iteration order matters: get_type_from_np_type() returns the first
# isinstance() match.
numpy_to_ov_type_map = {
    np.float32: OVType.f32,
    bool: OVType.boolean,
    jax.dtypes.bfloat16: OVType.bf16,  # TODO: check this
    np.float16: OVType.f16,
    np.float64: OVType.f64,
    np.uint8: OVType.u8,
    np.int8: OVType.i8,
    np.uint16: OVType.u16,
    np.int16: OVType.i16,
    np.uint32: OVType.u32,
    np.int32: OVType.i32,
    np.uint64: OVType.u64,
    np.int64: OVType.i64,
}

# jax.numpy scalar types -> OpenVINO element type.
jax_to_ov_type_map = {
    jnp.float32: OVType.f32,
    jnp.bfloat16: OVType.bf16,  # TODO: check this
    jnp.float16: OVType.f16,
    jnp.float64: OVType.f64,
    jnp.uint8: OVType.u8,
    jnp.int8: OVType.i8,
    jnp.uint16: OVType.u16,
    jnp.int16: OVType.i16,
    jnp.uint32: OVType.u32,
    jnp.int32: OVType.i32,
    jnp.uint64: OVType.u64,
    jnp.int64: OVType.i64,
}

# `jnp.bool` is registered on a best-effort basis: if this JAX build does not
# expose the attribute, the entry is simply skipped.
try:
    jax_to_ov_type_map[jnp.bool] = OVType.boolean
except AttributeError:
    pass

# Python builtin scalar types -> OpenVINO element type.
# `bool` must come before `int`: bool is a subclass of int, so the isinstance()
# scans over this map would otherwise classify booleans as i64.
basic_to_ov_type_map = {
    bool: OVType.boolean,
    int: OVType.i64,
    float: OVType.f32,
}

# OpenVINO element type -> integer code emitted by ivalue_to_constant() when
# the value being converted is itself a dtype.
ov_type_to_int_map = {
    OVType.u8: 0,
    OVType.i8: 1,
    OVType.i16: 2,
    OVType.i32: 3,
    OVType.i64: 4,
    OVType.f16: 5,
    OVType.f32: 6,
    OVType.f64: 7,
    OVType.u16: 8,
    OVType.u32: 9,
    OVType.u64: 10,
    OVType.boolean: 11,
    OVType.bf16: 15,
}


def get_type_from_py_type(value):
    """Return the OpenVINO type for a Python scalar.

    float -> f32, bool -> boolean, int -> i64; anything else yields
    ``OVType.dynamic``. ``bool`` is tested before ``int`` because bool is a
    subclass of int.
    """
    if isinstance(value, float):
        return OVType.f32
    if isinstance(value, bool):
        return OVType.boolean
    if isinstance(value, int):
        return OVType.i64
    return OVType.dynamic


def get_type_from_np_type(value):
    """Return the OpenVINO type of the first ``numpy_to_ov_type_map`` key that
    ``value`` is an instance of, or ``None`` when nothing matches."""
    for np_dtype, ov_type in numpy_to_ov_type_map.items():
        if isinstance(value, np_dtype):
            return ov_type
    return None


def _get_ov_type_from_value(value):
    """Deduce the OpenVINO type of a scalar value.

    NumPy scalar types are tried first, then Python builtins. The result is
    ``OVType.dynamic`` when neither lookup recognises the value.
    """
    ov_type = get_type_from_np_type(value)
    if ov_type is None:
        ov_type = get_type_from_py_type(value)
    return ov_type


def _lookup_ov_type(dtype):
    """Resolve ``dtype`` to a raw ``OVType`` or return ``None``.

    Lookup order: direct key in ``jax_to_ov_type_map``, then equality against
    the keys of ``numpy_to_ov_type_map``, then ``isinstance`` against the keys
    of ``basic_to_ov_type_map``.
    """
    if dtype in jax_to_ov_type_map:
        return jax_to_ov_type_map[dtype]
    for known_dtype, ov_type in numpy_to_ov_type_map.items():
        if dtype == known_dtype:
            return ov_type
    for py_type, ov_type in basic_to_ov_type_map.items():
        if isinstance(dtype, py_type):
            return ov_type
    return None


def get_ov_type_for_value(value):
    """Return the OpenVINO type of ``value`` wrapped in ``OVAny``.

    * For a ``jex.core.Var`` / ``jex.core.Literal`` the type is resolved from
      ``value.aval.dtype``; ``None`` is returned if that dtype is unknown.
    * For a Python ``int`` / ``float`` / ``bool`` the builtin mapping is used.

    Raises:
        NotImplementedError: if ``value`` is none of the above.
    """
    if isinstance(value, (jex.core.Var, jex.core.Literal)):
        ov_type = _lookup_ov_type(value.aval.dtype)
        return None if ov_type is None else OVAny(ov_type)
    if isinstance(value, (int, float, bool)):
        # jax_to_ov_type_map is keyed by jnp scalar types, so indexing it with
        # type(value) raised KeyError for Python scalars; use the builtin
        # mapping instead.
        return OVAny(get_type_from_py_type(value))
    raise NotImplementedError(f"dtype for {value} of type {type(value)} has not been supported yet.")


def get_ov_type_from_jax_type(dtype):
    """Return the OpenVINO type for ``dtype`` wrapped in ``OVAny``, or ``None``
    when the dtype is not present in any of the module's type maps."""
    ov_type = _lookup_ov_type(dtype)
    return None if ov_type is None else OVAny(ov_type)


def jax_array_to_ov_const(arr: np.ndarray, shared_memory=True):
    """Build an ``op.Constant`` from a NumPy array or a JAX array.

    A JAX array is first fetched with ``jax.device_get`` and converted with
    ``np.array``. ``shared_memory`` is forwarded to ``op.Constant`` unchanged.

    Raises:
        ValueError: if ``arr`` is neither ``np.ndarray`` nor ``jax.Array``.
    """
    # TODO: deal with bfloat16 dtype here.
    if isinstance(arr, np.ndarray):
        return op.Constant(arr, shared_memory=shared_memory)
    elif isinstance(arr, jax.Array):
        return op.Constant(np.array(jax.device_get(arr)), shared_memory=shared_memory)
    else:
        raise ValueError(f"Constant is expected to be a numpy array or jax array but got {type(arr)}")


def _nested_list_to_constant(ivalue):
    """Convert a list/tuple whose first element is a list/tuple into a 2-D
    constant of shape ``[len(ivalue), len(ivalue[0])]``.

    Raises on anything it cannot handle; the caller treats any exception as
    "fall back to the placeholder constant".
    """
    second_len = len(ivalue[0])
    flattened_ivalue = []
    for value in ivalue:
        assert isinstance(value, (list, tuple)), "Can't deduce type for a list with both list and basic types."
        assert len(value) == second_len or len(
            value) == 0, "Can't deduce type for nested list with different lengths."
        flattened_ivalue.extend(filter_element(item) for item in value)
    # The filtered elements are used as-is. Previously they were overwritten
    # by an unfiltered re-flatten of `ivalue`, which discarded filter_element.
    ov_type = _get_ov_type_from_value(flattened_ivalue[0])
    assert ov_type.is_static(), f"Can't deduce type {flattened_ivalue[0].__class__} for list"
    return op.Constant(ov_type, Shape([len(ivalue), second_len]), flattened_ivalue).outputs()


def ivalue_to_constant(ivalue, shared_memory=True):
    """Convert a python object to an openvino constant.

    ``ivalue`` is first passed through ``filter_ivalue`` (imported from the
    frontend passes module). Conversion is then attempted in this order:

    1. scalar with a deducible static type -> 0-d constant;
    2. list / tuple -> 1-D constant, or 2-D when the first element is itself a
       list / tuple (an empty list becomes an empty i64 constant);
    3. ``jax.Array`` / ``np.ndarray`` -> constant via ``jax_array_to_ov_const``;
    4. a dtype known to the module's type maps -> scalar i64 constant holding
       its ``ov_type_to_int_map`` code.

    Returns:
        The ``outputs()`` of the created constant, or ``None`` if the value
        matches none of the cases above.
    """
    ivalue = filter_ivalue(ivalue)
    ov_type = _get_ov_type_from_value(ivalue)
    if ov_type.is_static():
        return op.Constant(ov_type, Shape([]), [ivalue]).outputs()

    if isinstance(ivalue, (list, tuple)):
        # TODO 150596: remove this workaround
        if len(ivalue) == 0:
            return op.Constant(OVType.i64, Shape([0]), []).outputs()
        try:
            if isinstance(ivalue[0], (list, tuple)):
                return _nested_list_to_constant(ivalue)
            ivalue = [filter_element(item) for item in ivalue]
            ov_type = _get_ov_type_from_value(ivalue[0])
            assert ov_type.is_static(), f"Can't deduce type {ivalue[0].__class__} for list"
        except Exception:
            # TODO 150596: remove this workaround
            # Best-effort fallback: any failure above yields a single-element
            # f32 placeholder instead of propagating the error.
            ivalue = [0]
            ov_type = OVType.f32
        return op.Constant(ov_type, Shape([len(ivalue)]), ivalue).outputs()

    if isinstance(ivalue, (jax.Array, np.ndarray)):
        return jax_array_to_ov_const(ivalue, shared_memory=shared_memory).outputs()

    # Look up with the raw OVType (the key type of ov_type_to_int_map) rather
    # than an OVAny wrapper, so the lookup does not depend on OVAny hashing.
    ov_dtype_value = _lookup_ov_type(ivalue)
    if ov_dtype_value is not None:
        return op.Constant(OVType.i64, Shape([]), [ov_type_to_int_map[ov_dtype_value]]).outputs()
    return None


def param_to_constants(primitive: str, param_name: str, jaxpr, shared_memory=True):
    """Convert the parameters selected by ``filter_param`` into constants.

    Every value of the mapping returned by
    ``filter_param(primitive, param_name, jaxpr)`` is replaced in place with
    the result of ``ivalue_to_constant``; the same mapping is returned.
    """
    processed_params = filter_param(primitive, param_name, jaxpr)
    # Only values are reassigned (the key set is unchanged), so updating the
    # dict while iterating over its items is safe.
    for k, v in processed_params.items():
        processed_params[k] = ivalue_to_constant(v, shared_memory=shared_memory)
    return processed_params