Files
ANSLibs/OpenVINO/python/openvino/frontend/jax/jaxpr_decoder.py

304 lines
12 KiB
Python

# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
# mypy: ignore-errors
import jax
from packaging import version
if version.parse(jax.__version__) < version.parse("0.6.0"):
import jax as jex
import jax.core
else:
import jax.extend as jex
from openvino.frontend.jax.py_jax_frontend import _FrontEndJaxDecoder as Decoder
from openvino import PartialShape, Type as OVType, OVAny
from openvino.frontend.jax.utils import jax_array_to_ov_const, get_ov_type_for_value, \
ivalue_to_constant, param_to_constants
import numpy as np
import logging
# Module-level logger named after this module; its level is pinned to WARNING below.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
class JaxprPythonDecoder(Decoder):
    '''
    The jaxpr decoder uses Jaxpr to get graph information from a jax module.
    It makes use of the following parts.
    - `ClosedJaxpr`: the jaxpr object that contains the jaxpr and literals.
        - `Jaxpr`: the jaxpr object that contains the invars, outvars, and eqns.
            - `JaxEqns`: A list of jaxpr equations, which contains the information of the operation.
                - `Primitive`: the operation that is used in the equation.
                - `invars`: the input variables of the equation.
                    - `aval`: the abstract value.
                - `outvars`: the output variables of the equation.
                    - `aval`: the abstract value.
                - `params`: the named params of this equation.
            - `invars`: the inputs of the model (traced graph).
                - `aval`: the abstract value.
            - `outvars`: the outputs of the model (traced graph).
                - `aval`: the abstract value.
            - `constvars`: the constant variables used in this model.
                - `aval`: the abstract value.
        - `Literal`: the literal object that contains the value of the constants.
    '''

    def __init__(self, jaxpr, name=None, literals=None):
        '''
        Inputs:
            - jaxpr: for users, `ClosedJaxpr` is expected here. See https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L197
              A bare `Jaxpr` or a single `JaxprEqn` is accepted as well.
            - name: the name for the model. Defaults to "jax_module" when `None`.
            - literals: the literals (constants) that are used in the model. When given,
              they replace the literals taken from a `ClosedJaxpr`.

        Raises:
            - ValueError: if `jaxpr` is not a `JaxprEqn`, `Jaxpr` or `ClosedJaxpr`.
        '''
        Decoder.__init__(self)

        # Always define `literals`, so that a bare `Jaxpr`/`JaxprEqn` passed without
        # explicit literals does not leave the instance with a missing attribute.
        self.literals = []
        if isinstance(jaxpr, (jex.core.JaxprEqn, jex.core.Jaxpr)):
            self.jaxpr = jaxpr
        elif isinstance(jaxpr, jex.core.ClosedJaxpr):
            # Take the `Jaxpr` from `ClosedJaxpr`, see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L85
            self.jaxpr = jaxpr.jaxpr
            # Literal should be a `Jax.core.Var`, see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L85
            self.literals = jaxpr.literals
        else:
            raise ValueError(f"Unexpected type of jaxpr: {type(jaxpr)}")

        self.name = "jax_module" if name is None else name
        # Explicitly passed literals take precedence over the ones found above.
        if literals is not None:
            self.literals = literals

        # Named params of the wrapped object, each converted into a constant decoder.
        self.params = {}
        if hasattr(self.jaxpr, 'params') and isinstance(self.jaxpr.params, dict):
            for k in self.jaxpr.params:
                converted = self.convert_param_to_constant_node(self.jaxpr, k)
                if converted is not None:
                    self.params.update(converted)

        # Keeps a reference to every decoder handed to the visitor in `visit_subgraph`.
        # TODO: this implementation may lead to memory increasing. Any better solution?
        self.m_decoders = []

    def inputs(self) -> list[int]:
        '''
        Get the ids of all inputs.

        For a `JaxprEqn`, an input that is a `Literal` is represented by the output id of
        the matching entry in `self.literals` (the n-th `Literal` input maps to
        `self.literals[n]`); every other input is represented by `id()` of the variable.
        '''
        if not isinstance(self.jaxpr, jex.core.JaxprEqn):
            return [id(v) for v in self.jaxpr.invars]

        res = []
        literal_idx = 0
        for inp in self.jaxpr.invars:
            if isinstance(inp, jex.core.Literal):
                res.append(self.literals[literal_idx].output(0))
                literal_idx += 1
            else:
                res.append(id(inp))
        return res

    def input(self, idx: int) -> int:
        '''
        Get the id of the input at position `idx`.

        Delegates to `inputs()` so that both accessors report the same id for `Literal`
        inputs of an equation.
        '''
        return self.inputs()[idx]

    def get_input_shape(self, index):
        '''
        Get the shape of the input at `index` as a `PartialShape`, built from its `aval.shape`.
        '''
        return PartialShape(self.jaxpr.invars[index].aval.shape)

    def get_input_signature_name(self, index) -> str:
        '''
        Get the signature name of the input at `index`: "jaxpr_invar_" followed by the index.
        '''
        return "jaxpr_invar_" + str(index)

    def get_input_type(self, index) -> OVType:
        '''
        Get the type of the input at `index`, as computed by `get_ov_type_for_value`.
        '''
        return get_ov_type_for_value(self.jaxpr.invars[index])

    def get_named_param(self, name):
        '''
        Get the object id of the named parameter by the name.
        '''
        return self.params[name].output(0)

    def get_named_param_as_constant(self, name):
        '''
        The named parameter in JAX is a python object but we want to use its value in cpp.
        Therefore this API is used to get the named parameter as a constant, which can be used
        to extract the value of it in cpp-level.
        '''
        return self.params[name].as_constant()

    def get_param_names(self):
        '''
        In JAX, the named parameters may exist in `params` attribute of `JaxEqn`.
        For example, a concatenation carries the dimension to concatenate along as a named parameter.
        Here we return the names of all the named params that were converted to constants
        for the current `JaxEqn`.
        '''
        return list(self.params.keys())

    def get_output_type(self, index) -> OVType:
        '''
        Get the type of the output at `index`, as computed by `get_ov_type_for_value`.
        '''
        return get_ov_type_for_value(self.jaxpr.outvars[index])

    def get_output_name(self, index) -> str:
        '''
        Get the name of the output at `index`: "jaxpr_outvar_" followed by the index.
        '''
        return "jaxpr_outvar_" + str(index)

    def get_output_shape(self, index):
        '''
        Get the shape of the output at `index` as a `PartialShape`, built from its `aval.shape`.
        '''
        return PartialShape(self.jaxpr.outvars[index].aval.shape)

    def visit_subgraph(self, node_visitor) -> None:
        '''
        Call `node_visitor` with a decoder for every node of the wrapped jaxpr.

        Nothing is visited when this decoder wraps a single `JaxprEqn`. Otherwise the
        visiting order is: the named-param constants, one constant per `constvars` entry
        (taken from `self.literals` at the same position), then for every equation its
        `Literal` inputs followed by the equation itself.
        '''
        if isinstance(self.jaxpr, jex.core.JaxprEqn):
            return

        for decoder in self.params.values():
            self.m_decoders.append(decoder)
            node_visitor(decoder)

        for idx, node in enumerate(self.jaxpr.constvars):
            # The constant is published under the id of its constvar.
            decoder = self.convert_literal_to_constant_node(
                literal=self.literals[idx],
                name=self.name + "/" + f"const({id(node)})",
                output_id=id(node)
            )
            self.m_decoders.append(decoder)
            node_visitor(decoder)

        # Visit every `JaxEqn` in the jaxpr, see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L285
        for node in self.jaxpr.eqns:
            # Literal inputs are visited as constants first; the equation decoder keeps them
            # as its `literals`, in the same order `inputs()` consumes them.
            literal_decoders = []
            for inp in node.invars:
                if isinstance(inp, jex.core.Literal):
                    literal_decoder = self.convert_literal_to_constant_node(inp)
                    literal_decoders.append(literal_decoder)
                    node_visitor(literal_decoder)
            decoder = JaxprPythonDecoder(node, name=self.name + "/" + node.primitive.name, literals=literal_decoders)
            self.m_decoders.append(decoder)
            node_visitor(decoder)

    def get_op_type(self) -> str:
        '''
        Get the operation type: the primitive name for a `JaxprEqn`, "root" otherwise.
        '''
        if isinstance(self.jaxpr, jex.core.JaxprEqn):
            return self.jaxpr.primitive.name
        return "root"

    def outputs(self) -> list[int]:
        '''
        Get the ids of all outputs (`id()` of each output variable).
        '''
        return [id(v) for v in self.jaxpr.outvars]

    def output(self, idx: int) -> int:
        '''
        Get the id of the output at position `idx`.
        '''
        return id(self.jaxpr.outvars[idx])

    def num_inputs(self) -> int:
        '''
        Get the number of inputs.
        '''
        return len(self.jaxpr.invars)

    def num_outputs(self) -> int:
        '''
        Get the number of outputs.
        '''
        return len(self.jaxpr.outvars)

    def as_constant(self):
        '''
        Convert `self.literals` to an OpenVINO constant and return its outputs.

        Raises:
            - ValueError: if the op type of this decoder is not 'constant'.
        '''
        if self.get_op_type() != 'constant':
            raise ValueError("This is not a constant node so it cannot be converted to a constant.")
        value = self.literals
        # TODO: dig out how to share the memory.
        # Currently, using shared_memory will raise `ValueError: array is not writeable``
        ov_const = jax_array_to_ov_const(value, shared_memory=False)
        return ov_const.outputs()

    @staticmethod
    def convert_param_to_constant_node(jaxpr, param) -> dict:
        '''
        Convert the named param `param` of `jaxpr` into constant decoders.

        When `jaxpr` has a `primitive`, the conversion is done by `param_to_constants`, and
        entries whose constant is `None` are dropped. Otherwise the raw param value is
        converted with `ivalue_to_constant`.

        Returns:
            - a dict mapping param names to `_JaxprPythonConstantDecoder`; empty if nothing
              could be converted.
        '''
        assert hasattr(jaxpr, 'params'), "The jaxpr does not have params."
        if hasattr(jaxpr, 'primitive'):
            param_map = param_to_constants(jaxpr.primitive.name, param, jaxpr, shared_memory=False)
            res = {}
            for name, constant in param_map.items():
                if constant is not None:
                    res[name] = _JaxprPythonConstantDecoder(constant=constant)
        else:
            constant = ivalue_to_constant(jaxpr.params[param], shared_memory=False)
            res = {param: _JaxprPythonConstantDecoder(constant=constant)} if constant is not None else {}
        return res

    @staticmethod
    def convert_literal_to_constant_node(literal, name=None, output_id=None):
        '''
        Wrap a `Literal` (its `val`), a jax array or a numpy array into a constant decoder.

        Inputs:
            - literal: the value to convert.
            - name: the name given to the created decoder.
            - output_id: the output id given to the created decoder.

        Raises:
            - TypeError: if `literal` is none of the supported types.
        '''
        if isinstance(literal, jex.core.Literal):
            constant = ivalue_to_constant(literal.val, shared_memory=False)
        elif isinstance(literal, (jax.Array, np.ndarray)):
            constant = ivalue_to_constant(literal, shared_memory=False)
        else:
            raise TypeError(f"The input should be a literal or jax array, but got {type(literal)}.")
        return _JaxprPythonConstantDecoder(constant=constant, name=name, output_id=output_id)
class _JaxprPythonConstantDecoder(Decoder):
    '''
    A decoder specially for constants and named parameters.
    It has no inputs and no named params, and exposes exactly one output.
    '''

    def __init__(self, name=None, constant=None, output_id=None):
        '''
        Inputs:
            - name: the name for this constant node.
            - constant: the constant held by this decoder; returned as-is by `as_constant`.
            - output_id: the id specified for this decoder's output. If none, use `id(self.constant)`.
        '''
        Decoder.__init__(self)
        self.name = name
        self.constant = constant
        if output_id is None:
            output_id = id(self.constant)
        self.output_id = output_id

    def inputs(self) -> list[int]:
        '''
        A constant node has no inputs, so the list is always empty.
        '''
        return []

    def input(self, idx: int) -> int:
        '''
        Always raises `ValueError`: a constant node has no input.
        '''
        raise ValueError("This is a constant node so it does not have input.")

    def get_input_shape(self, index):
        '''
        Always raises `ValueError`: a constant node has no input shape.
        '''
        raise ValueError("This is a constant node so it does not have input shape.")

    def get_input_signature_name(self, index) -> str:
        '''
        Always raises `ValueError`: a constant node has no input signature name.
        '''
        raise ValueError("This is a constant node so it does not have input signature name.")

    def get_input_type(self, index) -> OVType:
        '''
        Always raises `ValueError`: a constant node has no input type.
        '''
        raise ValueError("This is a constant node so it does not have input type.")

    def get_named_param(self, name):
        '''
        Always raises `ValueError`: a constant node has no named param.
        '''
        raise ValueError("This is a constant node so it does not have named param.")

    def get_named_param_as_constant(self, name):
        '''
        Always raises `ValueError`: a constant node has no named param.
        '''
        raise ValueError("This is a constant node so it does not have named param.")

    def get_param_names(self):
        '''
        In JAX, the named parameters may exist in `params` attribute of `JaxEqn`.
        However, `_JaxprPythonConstantDecoder` is already a named param or a constant.
        So it will never have a named param.
        '''
        return []

    def get_output_type(self, index) -> OVType:
        '''
        Get the element type of the single entry of `self.constant`, wrapped in `OVAny`.
        `index` is not used.
        '''
        assert len(self.constant) == 1
        sole_entry = self.constant[0]
        return OVAny(sole_entry.element_type)

    def get_output_name(self, index) -> str:
        '''
        Get the output name: "jaxpr_outvar_" followed by the index.
        '''
        return f"jaxpr_outvar_{index!s}"

    def get_output_shape(self, index):
        '''
        Get the shape of the single entry of `self.constant` as a `PartialShape`.
        `index` is not used.
        '''
        assert len(self.constant) == 1
        sole_entry = self.constant[0]
        return PartialShape(sole_entry.shape)

    def visit_subgraph(self, node_visitor) -> None:
        '''
        A constant node has no subgraph, so `node_visitor` is never called.
        '''
        return None

    def get_op_type(self) -> str:
        '''
        The op type of this decoder is always "constant".
        '''
        return "constant"

    def outputs(self) -> list[int]:
        '''
        Get the ids of all outputs: a one-element list holding `self.output_id`.
        '''
        return [self.output_id]

    def output(self, idx: int) -> int:
        '''
        Get the id of the output; `idx` is not used since there is only one output.
        '''
        return self.output_id

    def num_inputs(self) -> int:
        '''
        A constant node has zero inputs.
        '''
        return 0

    def num_outputs(self) -> int:
        '''
        A constant node has exactly one output.
        '''
        return 1

    def as_constant(self):
        '''
        Get the constant held by this decoder.
        '''
        return self.constant