48 lines
1.5 KiB
Python
48 lines
1.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright (C) 2018-2025 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import builtins
|
|
import torch
|
|
from typing import Any, Optional
|
|
|
|
originals_map = {
|
|
"divmod": builtins.divmod
|
|
}
|
|
|
|
|
|
def patched_divmod(lhs: Any, rhs: Any) -> tuple[Any, Any]:
|
|
# TorchScript tracing outputs torch.Tensor for `x.shape[-1]` instead of scalar.
|
|
# For example, it leads to TypeError issue during tracing of `divmod(x.shape[-1], 5)`
|
|
# since builtin `divmod()` expects both scalar inputs.
|
|
# So the patch for `divmod` is required to adjust operands to scalar type
|
|
if isinstance(lhs, torch.Tensor):
|
|
lhs = lhs.item()
|
|
if isinstance(rhs, torch.Tensor):
|
|
rhs = rhs.item()
|
|
return originals_map["divmod"](lhs, rhs)
|
|
|
|
|
|
patches_map = {
|
|
"divmod": patched_divmod,
|
|
}
|
|
|
|
|
|
class FunctionsPatcher:
|
|
"""Patch different Python functions including built-in routines such as divmod().
|
|
|
|
For example, TorchScript is unable to trace divmod() with torch.Tensor input type.
|
|
"""
|
|
def __enter__(self) -> None:
|
|
for name, new_fn in patches_map.items():
|
|
setattr(builtins, name, new_fn)
|
|
|
|
def __exit__(self,
|
|
exc_type: Optional[type[BaseException]],
|
|
exc_val: Optional[BaseException],
|
|
exc_tb: Optional[Any]
|
|
) -> None:
|
|
for name, _ in patches_map.items():
|
|
original_fn = originals_map[name]
|
|
setattr(builtins, name, original_fn)
|