42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright (C) 2018-2025 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import logging
|
|
from typing import Optional
|
|
|
|
from openvino import AxisSet
|
|
from openvino.utils.types import (
|
|
TensorShape,
|
|
)
|
|
|
|
# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
def get_broadcast_axes(
    output_shape: TensorShape,
    input_shape: TensorShape,
    axis: Optional[int] = None,
) -> AxisSet:
    """Generate the set of broadcast axes for an OpenVINO broadcast.

    Informally, a broadcast "adds" axes to the input tensor, replicating
    elements from the input tensor as needed to fill the new dimensions.
    This function calculates which of the output axes are added in this way,
    i.e. every output axis that is not covered by the input tensor.

    :param output_shape: The new shape for the output tensor.
    :param input_shape: The shape of the input tensor.
    :param axis: The output axis with which the first dimension of the input
                 tensor is aligned. Negative values count from the end of
                 ``output_shape``. When ``None``, the input is aligned with the
                 trailing (right-most) axes of the output.
    :raises IndexError: If the input tensor, placed at the requested position,
                        does not fit inside the rank of ``output_shape``.
    :return: The indices of the added axes.
    """
    output_rank = len(output_shape)
    input_rank = len(input_shape)

    if axis is None:
        # Align the input with the trailing axes of the output.
        output_begin = output_rank - input_rank
    elif axis < 0:
        # Normalize a negative axis up front so the covered range below is
        # always expressed in non-negative output indices.
        output_begin = axis + output_rank
    else:
        output_begin = axis

    output_end = output_begin + input_rank
    if output_begin < 0 or output_end > output_rank:
        # IndexError (rather than ValueError) keeps the exception type that
        # callers of this function may already be handling.
        raise IndexError(
            f"Input shape {input_shape} (rank {input_rank}) does not fit into "
            f"output shape {output_shape} (rank {output_rank}) with axis={axis}."
        )

    # Output axes occupied by the input tensor; everything else is broadcast.
    covered_axes = range(output_begin, output_end)
    return AxisSet({index for index in range(output_rank) if index not in covered_axes})
|