Files
ANSLibs/OpenVINO/python/openvino/frontend/pytorch/torchdynamo/partition.py

146 lines
6.5 KiB
Python

# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# mypy: ignore-errors
import torch
from torch.nn import Module
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.experimental.proxy_tensor import DecompositionInterpreter
from torch._decomp import decomposition_table
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from openvino.frontend.pytorch.torchdynamo.op_support import OperatorSupport
from openvino.frontend.pytorch.torchdynamo.backend_utils import _is_testing
import typing as t
import logging
logger = logging.getLogger(__name__)
class PatternNode:
    """One vertex of an operator pattern matched against FX graph nodes.

    ``op_types`` maps an op key to the patterns expected for that op's input
    nodes. The keys used in this file are ``"get_attr"`` and
    ``"call_function:<target>"``; the value is either ``None`` or a list
    holding one pattern per input node.
    """

    # Class-level mapping. Code that reaches ``op_types`` through the class
    # itself (``PatternNode.op_types``) reads and writes this single shared
    # dict; instances do not use it because ``__init__`` shadows it.
    op_types = {}

    def __init__(self):
        # Give every instance its own, initially empty, mapping.
        self.op_types = dict()
class Partitioner:
    """Partition an FX graph into fused subgraphs of supported operators.

    Operator support is decided by ``self.supported_ops`` (an
    ``OperatorSupport`` built from ``options``). Before partitioning, the
    graph is scanned for the op patterns described in
    ``capture_gptq_patterns`` and ``capture_nncf_patterns``; the nodes of
    every match are passed to ``self.supported_ops.enable_by_name``.
    """

    def __init__(self, options):
        """Build the operator-support object from ``options``."""
        self.supported_ops = OperatorSupport(options)

    def fx_serialize(self, graph_module: GraphModule, *args, **kwargs):
        """Trace ``graph_module`` with ``make_fx`` on ``args`` and return the result.

        ``kwargs`` is accepted but is not forwarded to the traced call.
        """
        return make_fx(graph_module)(*args)

    def add_get_attr_inputs(self, partitions: list[Partition]):
        """Add the ``get_attr`` nodes that feed partition nodes to a partition.

        Each ``get_attr`` input is added to exactly one partition. When several
        partitions consume the same ``get_attr`` node, the last consumer in
        ``partitions`` order receives it.
        """
        # TODO: Find a more efficient way to include input
        # "get_attr" nodes to the partitions.
        # NOTE(review): an earlier guard compared the op string against the
        # Node keys of this dict, so it never skipped anything. The resulting
        # last-consumer-wins assignment is kept on purpose; confirm which
        # partition should own a shared "get_attr" node before changing it.
        getattr_to_merge: dict[Node, Partition] = {}
        for partition in partitions:
            for pnode in partition.nodes:
                for pnode_input in pnode.all_input_nodes:
                    if pnode_input.op == "get_attr":
                        getattr_to_merge[pnode_input] = partition
        # Nodes are added in a second pass so the partitions are not modified
        # while their nodes are being iterated.
        for getattr_node, getattr_part in getattr_to_merge.items():
            getattr_part.add_node(getattr_node)

    def check_fully_supported(self, graph_module: GraphModule) -> bool:
        """Return True when the graph is exactly one fused module plus I/O.

        That is: one ``call_module`` node whose name contains ``"fused_"``,
        and every other node is a ``placeholder`` or ``output``.
        """
        num_fused = 0
        for node in graph_module.graph.nodes:
            if node.op == "call_module" and "fused_" in node.name:
                num_fused += 1
            elif node.op not in ("placeholder", "output"):
                return False
        return num_fused == 1

    def check_pattern(self, node: torch.fx.Node, pattern: "PatternNode", enabled_ops: list) -> bool:
        """Recursively match ``node`` and its inputs against ``pattern``.

        Args:
            node: graph node to test.
            pattern: pattern vertex; its ``op_types`` maps ``"get_attr"`` or
                ``"call_function:<target>"`` to ``None`` (inputs are not
                checked) or to a list with one pattern per input node.
            enabled_ops: output list. Matched ``call_function`` nodes are
                appended, inputs before the node that consumes them.
                ``get_attr`` nodes are never appended.

        Returns:
            True when the whole subtree matches. On a failed match
            ``enabled_ops`` may still hold nodes of sub-patterns that matched
            before the failure, so callers must ignore it in that case.
        """
        if node.op == "get_attr":
            return "get_attr" in pattern.op_types
        if node.op != "call_function":
            return False

        op_key = "call_function:" + str(node.target)
        if op_key not in pattern.op_types:
            return False

        pattern_input_ops = pattern.op_types[op_key]
        if pattern_input_ops is None:
            enabled_ops.append(node)
            return True

        pt_input_nodes = node.all_input_nodes
        if len(pt_input_nodes) != len(pattern_input_ops):
            return False
        for input_node, input_pattern in zip(pt_input_nodes, pattern_input_ops):
            if not self.check_pattern(input_node, input_pattern, enabled_ops):
                return False
        enabled_ops.append(node)
        return True

    @staticmethod
    def _pattern(op_types: dict) -> "PatternNode":
        """Return a new ``PatternNode`` instance holding a copy of ``op_types``."""
        # A real instance is required: mutating ``PatternNode.op_types`` on the
        # class would merge every pattern into one shared dict.
        pattern = PatternNode()
        pattern.op_types.update(op_types)
        return pattern

    def _enable_pattern_matches(self, graph_module: GraphModule, root_target: str, root_pattern) -> None:
        """Enable the nodes of every subtree rooted at ``root_target`` that matches ``root_pattern``."""
        for node in graph_module.graph.nodes:
            if str(node.op) != "call_function" or str(node.target) != root_target:
                continue
            enabled_ops: list = []
            if self.check_pattern(node, root_pattern, enabled_ops):
                for pattern_op in enabled_ops:
                    self.supported_ops.enable_by_name(pattern_op)

    def capture_gptq_patterns(self, graph_module: GraphModule):
        """Enable the nodes of every GPTQ pattern found in ``graph_module``.

        The pattern is rooted at ``aten.bitwise_and.Scalar`` and reads, from
        root to leaves::

            bitwise_and.Scalar(
                _to_copy(shift) | add.Tensor(_to_copy(shift)))
            shift = bitwise_right_shift.Tensor(
                expand(unsqueeze(get_attr)), unsqueeze(get_attr))

        Matching is structural: each pattern vertex is its own ``PatternNode``
        instance and only accepts the ops listed for that position.
        """
        const_0_node = self._pattern({"get_attr": None})
        unsqueeze_0_node = self._pattern(
            {"call_function:aten.unsqueeze.default": [const_0_node]})
        expand_node = self._pattern(
            {"call_function:aten.expand.default": [unsqueeze_0_node]})
        const_1_node = self._pattern({"get_attr": None})
        unsqueeze_1_node = self._pattern(
            {"call_function:aten.unsqueeze.default": [const_1_node]})
        bitwise_right_shift_node = self._pattern(
            {"call_function:aten.bitwise_right_shift.Tensor": [expand_node, unsqueeze_1_node]})
        to_copy_node = self._pattern(
            {"call_function:aten._to_copy.default": [bitwise_right_shift_node]})
        # The input of bitwise_and is either the cast itself or an add applied
        # to the cast.
        add_or_to_copy_node = self._pattern({
            "call_function:aten._to_copy.default": [bitwise_right_shift_node],
            "call_function:aten.add.Tensor": [to_copy_node],
        })
        bitwise_and_node = self._pattern(
            {"call_function:aten.bitwise_and.Scalar": [add_or_to_copy_node]})

        self._enable_pattern_matches(graph_module, "aten.bitwise_and.Scalar", bitwise_and_node)

    def capture_nncf_patterns(self, graph_module: GraphModule):
        """Enable the nodes of every NNCF pattern found in ``graph_module``.

        The pattern is rooted at ``aten.stack.default`` and reads::

            stack.default(
                bitwise_and.Scalar(get_attr),
                bitwise_right_shift.Tensor_Scalar(get_attr))

        Matching is structural: each pattern vertex is its own ``PatternNode``
        instance and only accepts the ops listed for that position.
        """
        const_node = self._pattern({"get_attr": None})
        bitwise_right_shift_node = self._pattern(
            {"call_function:aten.bitwise_right_shift.Tensor_Scalar": [const_node]})
        bitwise_and_node = self._pattern(
            {"call_function:aten.bitwise_and.Scalar": [const_node]})
        stack_node = self._pattern(
            {"call_function:aten.stack.default": [bitwise_and_node, bitwise_right_shift_node]})

        # The root of the match is the stack node, so it must be compared with
        # the stack pattern (not with one of its input patterns).
        self._enable_pattern_matches(graph_module, "aten.stack.default", stack_node)

    def make_partitions(self, graph_module: GraphModule, options) -> GraphModule:
        """Fuse the supported regions of ``graph_module`` and return the result.

        Steps: enable the GPTQ/NNCF pattern nodes, let
        ``CapabilityBasedPartitioner`` propose partitions, pull ``get_attr``
        inputs into the partitions, then fuse. Single-node partitions are
        allowed only when ``_is_testing(options)`` is true.
        """
        # Lazy %-formatting: rendering a GraphModule is only paid for when
        # debug logging is actually enabled.
        logger.debug("Graph module before partitioning %s", graph_module)
        allow_single_node_partition = _is_testing(options)
        self.capture_gptq_patterns(graph_module)
        self.capture_nncf_patterns(graph_module)
        partitioner = CapabilityBasedPartitioner(
            graph_module, self.supported_ops, allows_single_node_partition=allow_single_node_partition)
        partitions = partitioner.propose_partitions()
        self.add_get_attr_inputs(partitions)
        fused_graph_module = partitioner.fuse_partitions(partitions)
        logger.debug("Graph module after partitioning %s", fused_graph_module)
        return fused_graph_module