Files
ANSLibs/OpenVINO/python/openvino/frontend/pytorch/torchdynamo/decompositions.py

319 lines
10 KiB
Python

# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# mypy: ignore-errors
import torch
from torch._decomp.decompositions import aten, pw_cast_for_opmath
from torch._decomp import register_decomposition, get_decompositions
@register_decomposition(aten.convolution_backward)
@pw_cast_for_opmath
def convolution_backward(
grad_output,
inp,
weight,
bias,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_mask,
):
if stride == [2, 2]:
output_padding = [1, 1]
# Compute the gradient of the input tensor
grad_input = torch.nn.functional.conv_transpose2d(
grad_output, weight, stride=stride, padding=padding, dilation=dilation, groups=groups, output_padding=output_padding
)
# Compute the gradient of the weight tensor
grad_weight = torch.nn.functional.conv_transpose2d(
inp, weight.transpose(0, 1), stride=stride, padding=padding, dilation=dilation, groups=groups, output_padding=output_padding
)
# Compute the gradient of the bias tensor
if bias is not None:
grad_bias = grad_output.sum([0, 2, 3], keepdim=True)
else:
grad_bias = None
return grad_input, grad_weight, grad_bias
if len(get_decompositions([aten._scaled_dot_product_flash_attention.default])) == 0:
@register_decomposition(aten._scaled_dot_product_flash_attention.default)
def scaled_dot_product_flash_attention(
query,
key,
value,
dropout_p=0.0,
is_causal=False,
*,
return_debug_mask=False,
scale=None,
):
batch_size, num_head, q_size, head_size = (
query.shape[0],
query.shape[1],
query.shape[2],
query.shape[3],
)
logsumexp = torch.empty([batch_size, q_size, num_head, head_size], dtype=torch.float)
cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty(
[], dtype=torch.long
)
max_q, max_k = 0, 0
philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty(
[], dtype=torch.long
)
debug_attn_mask = torch.empty(
[],
dtype=query.dtype,
device=query.device,
requires_grad=query.requires_grad,
)
output, _ = aten._scaled_dot_product_attention_math.default(
query, key, value, None, dropout_p, is_causal, None, scale=scale
)
scores = torch.matmul(query, key.transpose(-2, -1)) / (key.size(-1) ** 0.5)
logsumexp = torch.logsumexp(scores, dim=-1)
output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
return (
output.transpose(1, 2),
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
philox_seed,
philox_offset,
debug_attn_mask,
)
def get_aot_decomposition_list():
return [
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten._softmax.default,
torch.ops.aten._softmax_backward_data.default,
torch.ops.aten.convolution_backward.default,
torch.ops.aten.gelu_backward.default,
torch.ops.aten.native_group_norm.default,
torch.ops.aten.native_group_norm_backward.default,
torch.ops.aten.native_layer_norm.default,
torch.ops.aten.native_layer_norm_backward.default,
torch.ops.aten.slice_backward.default,
]
def get_inf_decomposition_list():
return [torch.ops.aten.nll_loss_forward.default]
def get_export_decomposition_list():
# list of decompositions from torch._decomp.core_aten_decompositions
# removed _backward ops and ops supported without decomposition
decomp = [
torch.ops.aten.addcdiv,
torch.ops.aten.addcdiv_,
torch.ops.aten.addcmul,
torch.ops.aten.addcmul_,
torch.ops.aten.addr,
torch.ops.aten.affine_grid_generator,
torch.ops.aten.all,
torch.ops.aten.aminmax,
torch.ops.aten.arange.default,
torch.ops.aten.arange.start,
torch.ops.aten.baddbmm,
torch.ops.aten.binary_cross_entropy,
torch.ops.aten.binary_cross_entropy_with_logits,
torch.ops.aten.block_diag,
torch.ops.aten.celu,
torch.ops.aten.celu_,
torch.ops.aten.clamp_max,
torch.ops.aten.clamp_min,
torch.ops.aten.count_nonzero,
torch.ops.aten.linalg_cross,
torch.ops.aten.cudnn_batch_norm,
torch.ops.aten.deg2rad,
torch.ops.aten.deg2rad_,
torch.ops.aten.detach,
torch.ops.aten.diag_embed,
torch.ops.aten.dot,
torch.ops.aten.vdot,
torch.ops.aten.elu,
torch.ops.aten.elu_,
torch.ops.aten._embedding_bag,
torch.ops.aten.empty_like,
torch.ops.aten._euclidean_dist.default,
torch.ops.aten.expand_as,
torch.ops.aten.eye,
torch.ops.aten.fill,
torch.ops.aten.fill_,
torch.ops.aten.floor_divide,
torch.ops.aten.frac,
torch.ops.aten.frac_,
torch.ops.aten._fused_moving_avg_obs_fq_helper,
torch.ops.aten.gelu_,
torch.ops.aten.glu,
torch.ops.aten.hardshrink,
torch.ops.aten.hardsigmoid,
torch.ops.aten.hardsigmoid_,
torch.ops.aten.hardswish,
torch.ops.aten.hardswish_,
torch.ops.aten.hardtanh_,
torch.ops.aten.heaviside,
torch.ops.aten.heaviside_,
torch.ops.aten.huber_loss,
torch.ops.aten.im2col,
torch.ops.aten.index_add,
torch.ops.aten.index_add_,
torch.ops.aten.index_copy,
torch.ops.aten.index_copy_,
torch.ops.aten.index_fill,
torch.ops.aten.index_fill_,
torch.ops.aten.isin,
torch.ops.aten.isneginf,
torch.ops.aten.isposinf,
torch.ops.aten.l1_loss,
torch.ops.aten.leaky_relu_,
torch.ops.aten.lerp,
torch.ops.aten.lerp_,
torch.ops.aten.linspace,
torch.ops.aten.logaddexp,
torch.ops.aten.logaddexp2,
torch.ops.aten.logit,
torch.ops.aten.logit_,
torch.ops.aten.log_sigmoid_forward,
torch.ops.aten.logspace,
torch.ops.aten.logsumexp.default,
torch.ops.aten.masked_fill,
torch.ops.aten.masked_fill_,
torch.ops.aten.mish,
torch.ops.aten.mish_,
torch.ops.aten.mse_loss,
torch.ops.aten.multi_margin_loss,
torch.ops.aten.multilabel_margin_loss_forward,
torch.ops.aten.mv,
torch.ops.aten.mvlgamma,
torch.ops.aten.mvlgamma_,
torch.ops.aten.nansum,
torch.ops.aten.nan_to_num,
torch.ops.aten.nan_to_num_,
torch.ops.aten.narrow,
torch.ops.aten.new_empty,
torch.ops.aten.new_full,
torch.ops.aten.new_ones,
torch.ops.aten.new_zeros,
torch.ops.aten.nll_loss_forward,
torch.ops.aten.norm,
torch.ops.aten.ones,
torch.ops.aten.ones_like,
torch.ops.aten._prelu_kernel,
torch.ops.aten._reshape_alias,
torch.ops.aten.rad2deg,
torch.ops.aten.rad2deg_,
torch.ops.aten.reflection_pad1d,
torch.ops.aten.reflection_pad2d,
torch.ops.aten.reflection_pad3d,
torch.ops.aten.replication_pad1d,
torch.ops.aten.replication_pad2d,
torch.ops.aten.replication_pad3d,
torch.ops.aten.renorm,
torch.ops.aten.renorm_,
torch.ops.aten.resize_as,
torch.ops.aten.roll,
torch.ops.aten.rot90,
torch.ops.aten.rrelu_with_noise,
torch.ops.aten.rrelu_with_noise_,
torch.ops.aten.rsub,
torch.ops.aten.select_scatter,
torch.ops.aten.sgn,
torch.ops.aten.sgn_,
torch.ops.aten.silu,
torch.ops.aten.silu_,
torch.ops.aten.sinc,
torch.ops.aten.sinc_,
torch.ops.aten.smooth_l1_loss,
torch.ops.aten.soft_margin_loss,
torch.ops.aten.softplus,
torch.ops.aten.softshrink,
torch.ops.aten.special_entr,
torch.ops.aten.special_log_ndtr,
torch.ops.aten.special_xlog1py,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes_copy,
torch.ops.aten.squeeze.default,
torch.ops.aten.squeeze.dim,
torch.ops.aten.std,
torch.ops.aten.std_mean,
torch.ops.aten.stack,
torch.ops.aten.sum.default,
torch.ops.aten.sum.out,
torch.ops.aten.t,
torch.ops.aten.take,
torch.ops.aten.threshold,
torch.ops.aten.threshold_,
torch.ops.aten.trace,
torch.ops.aten.transpose.int,
torch.ops.aten.tril,
torch.ops.aten.tril_,
torch.ops.aten.triu,
torch.ops.aten.triu_,
torch.ops.aten.unbind,
torch.ops.aten.unfold_copy,
torch.ops.aten._unsafe_index,
torch.ops.aten.unsafe_split.Tensor,
torch.ops.aten.unsafe_split_with_sizes,
torch.ops.aten._unsafe_view,
torch.ops.aten.view_as_complex,
torch.ops.aten.xlogy,
torch.ops.aten.xlogy_,
torch.ops.aten.zero,
torch.ops.aten.zero_,
torch.ops.aten.zeros,
torch.ops.aten.zeros_like,
torch.ops.aten._weight_norm_interface,
]
try:
from packaging import version
if version.parse(torch.__version__) >= version.parse("2.3"):
decomp += [
torch.ops.aten._lazy_clone,
torch.ops.aten._test_parallel_materialize,
torch.ops.aten._chunk_cat,
]
except ImportError:
pass
return decomp
def ops_to_not_decompose():
# list of operations that shouldn't be decomposed
return [
torch.ops.aten.col2im.default,
torch.ops.aten.linear.default,
torch.ops.aten.rms_norm.default,
torch.ops.aten.upsample_nearest1d.default,
torch.ops.aten.upsample_nearest1d.vec,
torch.ops.aten.upsample_nearest2d.default,
torch.ops.aten.upsample_nearest2d.vec,
torch.ops.aten.upsample_nearest3d.default,
torch.ops.aten.upsample_nearest3d.vec,
torch.ops.aten.upsample_linear1d.vec,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_trilinear3d.vec,
torch.ops.aten.upsample_bicubic2d.vec,
torch.ops.aten.scaled_dot_product_attention.default,
]