Files
ANSLibs/OpenVINO/python/openvino/utils/broadcasting.py

42 lines
1.3 KiB
Python

# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Optional
from openvino import AxisSet
from openvino.utils.types import (
TensorShape,
)
log = logging.getLogger(__name__)
def get_broadcast_axes(
    output_shape: TensorShape,
    input_shape: TensorShape,
    axis: Optional[int] = None,
) -> AxisSet:
    """Generate a list of broadcast axes for openvino broadcast.

    Informally, a broadcast "adds" axes to the input tensor,
    replicating elements from the input tensor as needed to fill the new dimensions.
    This function calculates which of the output axes are added in this way.

    :param output_shape: The new shape for the output tensor.
    :param input_shape: The shape of input tensor.
    :param axis: The axis along which we want to replicate elements, i.e. the position
                 in the output shape at which the input axes start. A negative value
                 is counted from the end of the output shape. When omitted, the input
                 axes are aligned with the trailing axes of the output shape.
    :return: The indices of added axes.
    :raises IndexError: If the input axes do not fit inside the output shape.
    """
    output_rank = len(output_shape)
    input_rank = len(input_shape)

    if axis is None:
        # Default: right-align the input axes with the output axes.
        output_begin = output_rank - input_rank
    elif axis < 0:
        # Normalise a negative axis up front. Using negative positions directly
        # would select the wrong axes once the input has more than one dimension.
        output_begin = axis + output_rank
    else:
        output_begin = axis

    # Output positions occupied by the input's own axes; everything else is added.
    input_axes = range(output_begin, output_begin + input_rank)

    # An empty window (zero-dimensional input) removes nothing, so it is always valid.
    if input_axes and (input_axes[0] < 0 or input_axes[-1] >= output_rank):
        raise IndexError(
            f"input axes [{input_axes.start}, {input_axes.stop}) do not fit "
            f"into an output shape of rank {output_rank}",
        )

    return AxisSet({index for index in range(output_rank) if index not in input_axes})