Files
ANSLibs/OpenVINO/python/openvino/frontend/pytorch/module_extension.py

69 lines
3.0 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Union
from collections.abc import Callable
import torch
class ModuleExtension:
    """An extension that replaces a PyTorch module with a single operation.

    A module can be identified by its type (e.g., `torch.nn.Linear`), module
    instance in the model, or module name.
    """

    # The defaults are named static methods rather than lambdas so that an
    # extension relying on them stays picklable and shows readable names in
    # tracebacks.
    @staticmethod
    def _default_evaluate(module, *args, **kwargs):
        """Call the original module with the arguments it received."""
        return module(*args, **kwargs)

    @staticmethod
    def _default_convert(module, target_op, *args, **kwargs):
        """Pass the arguments straight to `target_op`; `module` is unused."""
        return target_op(*args, **kwargs)

    @staticmethod
    def _default_condition(module):
        """Apply the extension to every module."""
        return True

    def __init__(self,
                 module: Union[str, torch.nn.Module, type[torch.nn.Module]],
                 target_op: str,
                 evaluate: Optional[Callable] = None,
                 convert: Optional[Callable] = None,
                 condition: Optional[Callable] = None) -> None:
        """Create an extension that replaces a PyTorch module with a single op.

        This functionality works with PyTorch models only. A module can be
        identified by its type (e.g., `torch.nn.Linear`), module instance in
        the model, or module name.

        Args:
            module (str, torch.nn.Module, type(torch.nn.Module)): PyTorch
                module to replace.
            target_op (str): A target operation that will be used as a replacer
                for the module. It could be the name of the extension operation
                or an existing PyTorch operation (with `prim::` or `aten::`
                prefix following TorchScript syntax).
            evaluate (callable): A function with the signature
                `evaluate(module, *args, **kwargs)`. It replaces the target
                module in model execution and is responsible for producing
                valid output for the module to allow correct model tracing. By
                default, it calls the original module with the same
                arguments. The provided code will not be part of the final
                traced model; it is used only to produce valid results during
                tracing.
            convert (callable): A function with the signature
                `convert(module, target_op, *args, **kwargs)`, matching the
                default implementation. It is traced and becomes part of the
                final model instead of the target module. `module` is the
                module being replaced; the `target_op` parameter is a callable
                that appears as a single node in the graph, with the type of
                the node being the `target_op` string provided as another
                argument above. By default, it calls
                `target_op(*args, **kwargs)` and ignores `module`.
            condition (callable): A function with the signature
                `condition(module)`. It returns a boolean indicating whether
                the extension applies to the given module. By default, it
                always returns `True`.
        """
        self.module = module
        self.target_op = target_op
        # Only `None` selects a default; any other value is stored as given.
        self.evaluate = self._default_evaluate if evaluate is None else evaluate
        self.convert = self._default_convert if convert is None else convert
        self.condition = self._default_condition if condition is None else condition