# File: ANSLibs/OpenVINO/python/openvino/frontend/pytorch/gptq.py (181 lines, 7.4 KiB)
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# mypy: ignore-errors
from functools import partial
import logging
import torch
log = logging.getLogger(__name__)
class KeepWeight(torch.nn.Module):
    """Module wrapper around a single tensor, shielding it from jit.freeze.

    Whether the tensor actually survives freezing depends on its dtype;
    the decoder code defines which dtypes are preserved.
    """

    def __init__(self, weight):
        super().__init__()
        # requires_grad=False: this is a frozen constant, never trained.
        self.weight = torch.nn.Parameter(weight, requires_grad=False)

    def forward(self):
        """Return the wrapped tensor unchanged."""
        return self.weight
def decompression_pattern(weights):
    """Split each packed uint8 byte into its two u4 nibbles along a new last axis.

    Produces the exact op pattern (bitwise_and / bitwise_right_shift / stack)
    that the frontend later captures and folds into a single u4 constant node.

    :param weights: uint8 tensor of packed 4-bit values, two per byte.
    :return: tensor with one extra trailing dim of size 2: [low nibble, high nibble].
    """
    low_mask = torch.tensor(15, dtype=torch.uint8).to(weights.device)
    low_nibble = torch.bitwise_and(weights, low_mask)
    high_nibble = torch.bitwise_right_shift(weights, 4)
    return torch.stack((low_nibble, high_nibble), dim=-1)
def patched_forward(self, *args, **kwargs):
    """Replacement forward for an asymmetric GPTQ quantized linear module.

    Unpacks the u4-packed qweight/qzeros through decompression_pattern (so the
    frontend can later fold them into u4 constants), dequantizes with the
    per-group zero points and scales, then applies the linear transform.

    :param args: positional inputs; args[0] is the activation tensor.
    :return: activation @ dequantized_weights (+ bias if present), reshaped so
        the last dim becomes self.width.
    """
    # Respect HF accelerate hooks when the module is dispatched/offloaded.
    if hasattr(self, "_hf_hook"):
        args, kwargs = self._hf_hook.pre_forward(self, *args, **kwargs)
    inp = args[0]
    inp_dtype = inp.dtype
    out_shape = inp.shape[:-1] + (self.width,)
    flat_inp = inp.contiguous().view(-1, inp.shape[-1])

    num_groups = self.qzeros.shape[0]
    packed_height = self.qweight.shape[0]

    # Unpack weights: each byte yields two nibbles; 8 nibbles per original int32.
    w_u4 = decompression_pattern(self._openvino_u4_compression_submodule_qweights())
    w_u4 = w_u4.contiguous().view(packed_height, -1, 8)
    w_u4 = torch.transpose(w_u4, 1, 2).contiguous().view(-1, self.group_size, self.width)

    # Unpack per-group zero points, broadcastable over each group of rows.
    zp_u4 = decompression_pattern(self._openvino_u4_compression_submodule_qzeros())
    zp_u4 = zp_u4.contiguous().view(num_groups, 1, -1)

    # Dequantize: (q - zero_point) * scale, in the activation dtype.
    weights = (w_u4.to(inp_dtype) - zp_u4) * self.scales
    weights = weights.view(-1, self.width)

    result = torch.matmul(flat_inp, weights)
    result = result.view(out_shape)
    if self.bias is not None:
        result.add_(self.bias)
    if hasattr(self, "_hf_hook"):
        result = self._hf_hook.post_forward(self, result)
    return result
def patched_forward_sym(self, *args, **kwargs):
    """Replacement forward for a symmetric GPTQ quantized linear module.

    Like patched_forward, but the zero point is the fixed value 8 for every
    group, so no qzeros unpacking is needed; the PT frontend transformation
    later repacks the result as signed i4.

    :param args: positional inputs; args[0] is the activation tensor.
    :return: activation @ dequantized_weights (+ bias if present), reshaped so
        the last dim becomes self.width.
    """
    # Respect HF accelerate hooks when the module is dispatched/offloaded.
    if hasattr(self, "_hf_hook"):
        args, kwargs = self._hf_hook.pre_forward(self, *args, **kwargs)
    inp = args[0]
    inp_dtype = inp.dtype
    out_shape = inp.shape[:-1] + (self.width,)
    flat_inp = inp.contiguous().view(-1, inp.shape[-1])

    packed_height = self.qweight.shape[0]

    # Unpack weights: each byte yields two nibbles; 8 nibbles per original int32.
    w_u4 = decompression_pattern(self._openvino_u4_compression_submodule_qweights())
    w_u4 = w_u4.contiguous().view(packed_height, -1, 8)
    w_u4 = torch.transpose(w_u4, 1, 2).contiguous().view(-1, self.group_size, self.width)

    # all zp is 8 for symmetrical, will repack to i4 in pt fe transformation
    centered = w_u4.to(torch.int8) - torch.tensor(8, dtype=torch.int8)
    weights = centered.to(inp_dtype) * self.scales
    weights = weights.view(-1, self.width)

    result = torch.matmul(flat_inp, weights)
    result = result.view(out_shape)
    if self.bias is not None:
        result.add_(self.bias)
    if hasattr(self, "_hf_hook"):
        result = self._hf_hook.post_forward(self, result)
    return result
# All the following AutoGPTQ's quant types are supposed to have the same weights packing schema
supported_quant_types = ["triton", "exllama", "exllamav2", "cuda-old"]
def patch_model(model):
    """Patch every AutoGPTQ quantized linear module in `model` for OpenVINO tracing.

    For each module carrying a QUANT_TYPE, replaces `forward` with
    patched_forward / patched_forward_sym, reinterprets qweight/qzeros as
    uint8, and stashes the originals so unpatch_model can restore them.
    Idempotent: already-patched modules are skipped.

    :param model: a torch model (possibly a wrapper exposing `.model`)
        whose submodules may be AutoGPTQ quantized layers.
    :raises ValueError: on unsupported QUANT_TYPE or bits != 4.
    """
    is_symmetrical = False
    config = None
    if hasattr(model, "config"):
        config = model.config
    elif hasattr(model, "model") and hasattr(model.model, "config"):
        # original model was wrapped
        config = model.model.config
    # Symmetric vs asymmetric quantization is advertised by the HF quantization config.
    if config is not None and hasattr(config, "quantization_config") and hasattr(config.quantization_config, "sym"):
        is_symmetrical = config.quantization_config.sym
    for name, module in model.named_modules():
        if hasattr(module, "_openvino_patch_orig_forward"):
            # already patched, skipping
            continue
        # TODO: Check module type
        is_quantized = getattr(module, "is_quantized", None)
        if is_quantized is not None:
            module.is_quantized = False
        module.float()  # enables tracing on CPU, applied for all modules
        if hasattr(module, "QUANT_TYPE"):
            if module.QUANT_TYPE not in supported_quant_types:
                raise ValueError(f"Unsupported QUANT_TYPE == {module.QUANT_TYPE} is discovered for "
                                 "AutoGPTQ model, only the following types are supported: "
                                 f"{supported_quant_types}")
            if module.bits != 4:
                raise ValueError(f"Unsupported bits == {module.bits} is discovered in module {name} "
                                 "in AutoGPTQ model, only bits == 4 is supported.")
            # Eight u4 values are packed into every int32 of qweight.
            int4_in_int32 = 8
            groups = module.qzeros.shape[0]
            module.width = module.qweight.shape[1]
            assert module.group_size == module.qweight.shape[0] * int4_in_int32 // groups
            module._openvino_patch_orig_forward = module.forward
            if is_symmetrical:
                module.forward = partial(patched_forward_sym, module)
            else:
                module.forward = partial(patched_forward, module)
            # Keep original field properties to be used when model is returned back to its original state
            module._openvino_patch_orig_qweights_type = module.qweight.dtype
            module._openvino_patch_orig_qzeros_type = module.qzeros.dtype
            module._openvino_patch_orig_scale_shape = module.scales.shape
            # Reinterpret (not copy) the packed int tensors as raw bytes.
            module.qweight = module.qweight.view(dtype=torch.uint8)
            module.qzeros = module.qzeros.view(dtype=torch.uint8)
            # TODO: Redundant tensor copy? Try to remove module.qweight and module.qzeros after keeping modified values as submodules
            module.add_module(
                "_openvino_u4_compression_submodule_qweights", KeepWeight(module.qweight))
            # Adding 17 (0x11: +1 to each nibble) to move zp+1 step from after unpacking to before to have correct decompression pattern. Can it overflow?
            module.add_module("_openvino_u4_compression_submodule_qzeros",
                              KeepWeight(module.qzeros + torch.tensor(17, dtype=torch.uint8)))
            # One scale row per group, broadcastable against (groups, group_size, width).
            module.scales = module.scales.view(-1, 1, module.width)
def unpatch_model(model):
    """Revert the changes made by patch_model, restoring each module's original state.

    Restores the saved forward, reinterprets qweight/qzeros back to their
    original dtypes, reshapes scales, and removes the helper submodules.
    Failures are logged (best effort) rather than raised.

    :param model: a torch model previously processed by patch_model.
    """
    for _, module in model.named_modules():
        if not hasattr(module, "_openvino_patch_orig_forward"):
            continue
        try:
            module.forward = module._openvino_patch_orig_forward
            del module._openvino_patch_orig_forward

            # Reinterpret the raw-byte views back to the saved packed dtypes.
            module.qweight = module.qweight.view(
                dtype=module._openvino_patch_orig_qweights_type)
            del module._openvino_patch_orig_qweights_type

            module.qzeros = module.qzeros.view(
                dtype=module._openvino_patch_orig_qzeros_type)
            del module._openvino_patch_orig_qzeros_type

            module.scales = module.scales.view(module._openvino_patch_orig_scale_shape)
            del module._openvino_patch_orig_scale_shape

            # Drop the KeepWeight submodules added for u4 constant capture.
            del module._openvino_u4_compression_submodule_qweights
            del module._openvino_u4_compression_submodule_qzeros
        except Exception as error:
            log.warning("Exception raised during GPTQ model unpatching. "
                        "Depending on the exact issue it may lead to broken "
                        "original model.\n%s", error)