66 lines
1.7 KiB
Python
66 lines
1.7 KiB
Python
# Copyright (C) 2018-2025 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# flake8: noqa
|
|
# mypy: ignore-errors
|
|
|
|
from enum import Enum
|
|
from jax.lax import ConvDimensionNumbers
|
|
|
|
def enum_values_pass(value):
|
|
if isinstance(value, Enum):
|
|
return value.value
|
|
return value
|
|
|
|
|
|
def conv_dimension_numbers_pass(value):
|
|
if isinstance(value, ConvDimensionNumbers):
|
|
return [
|
|
list(value.lhs_spec),
|
|
list(value.rhs_spec),
|
|
list(value.out_spec)
|
|
]
|
|
return value
|
|
|
|
|
|
def filter_element(value):
|
|
passes = [enum_values_pass]
|
|
for pass_ in passes:
|
|
value = pass_(value)
|
|
return value
|
|
|
|
|
|
def filter_ivalue(value):
|
|
passes = [conv_dimension_numbers_pass]
|
|
for pass_ in passes:
|
|
value = pass_(value)
|
|
return value
|
|
|
|
|
|
def dot_general_param_pass(param_name: str, jax_eqn):
|
|
param = jax_eqn.params[param_name]
|
|
res = {}
|
|
if param_name == 'dimension_numbers':
|
|
contract_dimensions = param[0]
|
|
assert len(contract_dimensions) == 2
|
|
res['contract_dimensions'] = [list(contract_dimensions[0]), list(contract_dimensions[1])]
|
|
|
|
batch_dimensions = param[1]
|
|
assert len(batch_dimensions) == 2
|
|
lhs_length = len(batch_dimensions[0])
|
|
rhs_length = len(batch_dimensions[1])
|
|
assert lhs_length == rhs_length
|
|
if lhs_length > 0:
|
|
res['batch_dimensions'] = [list(batch_dimensions[0]), list(batch_dimensions[1])]
|
|
return res
|
|
|
|
# mapping from primitive to pass
|
|
param_passes = {
|
|
'dot_general': dot_general_param_pass,
|
|
}
|
|
|
|
def filter_param(primitive: str, param_name: str, jax_eqn):
|
|
if primitive in param_passes:
|
|
return param_passes[primitive](param_name, jax_eqn)
|
|
return {param_name: jax_eqn.params[param_name]}
|