# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# mypy: ignore-errors

from functools import partial
import logging
import torch

log = logging.getLogger(__name__)


# Wraps a single tensor in a module to prevent it from being folded by jit.freeze.
# Whether a tensor is preserved during freezing depends on its dtype;
# refer to the decoder code to learn which types are preserved.
class KeepWeight(torch.nn.Module):

    def __init__(self, weight):
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)

    def forward(self):
        return self.weight
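
# A minimal illustration (not part of the original module): calling the wrapper
# returns the wrapped tensor unchanged, so the packed payload stays reachable
# as a submodule call in the traced graph, e.g.
#
#   kept = KeepWeight(torch.arange(4, dtype=torch.uint8))
#   assert torch.equal(kept(), torch.arange(4, dtype=torch.uint8))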


# Produces a pattern that can be captured later and represented as a single u4 constant node
def decompression_pattern(weights):
    mask = torch.tensor(15, dtype=torch.uint8).to(weights.device)
    return torch.stack((torch.bitwise_and(weights, mask), torch.bitwise_right_shift(weights, 4)), dim=-1)
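
# For illustration (not from the original file): each packed byte splits into
# (low nibble, high nibble) along a new trailing axis, e.g.
#
#   packed = torch.tensor([0xAB], dtype=torch.uint8)
#   decompression_pattern(packed).tolist()  # -> [[0x0B, 0x0A]], i.e. [[11, 10]]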


def patched_forward(self, *args, **kwargs):
    if hasattr(self, "_hf_hook"):
        args, kwargs = self._hf_hook.pre_forward(self, *args, **kwargs)

    arg = args[0]
    dtype = arg.dtype
    outshape = arg.shape[:-1] + (self.width,)
    arg = arg.contiguous().view(-1, arg.shape[-1])
    groups = self.qzeros.shape[0]
    height = self.qweight.shape[0]

    # Unpack u4 weights: [height, 4 * width] u8 -> [height, width, 8],
    # then regroup to [groups, group_size, width]
    unpacked_weights = decompression_pattern(
        self._openvino_u4_compression_submodule_qweights()).contiguous().view(height, -1, 8)
    unpacked_weights = torch.transpose(
        unpacked_weights, 1, 2).contiguous().view(-1, self.group_size, self.width)
    unpacked_zp = decompression_pattern(
        self._openvino_u4_compression_submodule_qzeros()).contiguous().view(groups, 1, -1)

    # Per-group dequantization: w = (q - zp) * scale
    unpacked_weights = (unpacked_weights.to(dtype) - unpacked_zp) * self.scales
    unpacked_weights = unpacked_weights.view(-1, self.width)

    out = arg @ unpacked_weights

    out = out.view(outshape)
    if self.bias is not None:
        out.add_(self.bias)

    if hasattr(self, "_hf_hook"):
        out = self._hf_hook.post_forward(self, out)
    return out
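
# Worked shape example (illustrative, hypothetical sizes): for a layer with
# in_features = 4096, width = 11008 and group_size = 128, qweight is
# [512, 11008] int32, viewed as [512, 44032] u8; unpacking gives
# [512, 11008, 8] -> [32, 128, 11008], with scales broadcast as [32, 1, 11008].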


def patched_forward_sym(self, *args, **kwargs):
    if hasattr(self, "_hf_hook"):
        args, kwargs = self._hf_hook.pre_forward(self, *args, **kwargs)

    arg = args[0]
    dtype = arg.dtype
    outshape = arg.shape[:-1] + (self.width,)
    arg = arg.contiguous().view(-1, arg.shape[-1])
    height = self.qweight.shape[0]

    # Unpack u4 weights exactly as in the asymmetric path
    unpacked_weights = decompression_pattern(
        self._openvino_u4_compression_submodule_qweights()).contiguous().view(height, -1, 8)
    unpacked_weights = torch.transpose(
        unpacked_weights, 1, 2).contiguous().view(-1, self.group_size, self.width)

    # All zero points equal 8 in the symmetric case; the subtraction will be
    # repacked to an i4 constant in the PT frontend transformation
    unpacked_weights = (unpacked_weights.to(torch.int8) - torch.tensor(8, dtype=torch.int8))
    unpacked_weights = unpacked_weights.to(dtype) * self.scales
    unpacked_weights = unpacked_weights.view(-1, self.width)

    out = arg @ unpacked_weights

    out = out.view(outshape)
    if self.bias is not None:
        out.add_(self.bias)

    if hasattr(self, "_hf_hook"):
        out = self._hf_hook.post_forward(self, out)
    return out
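
# Note for illustration (not from the original file): with the zero point fixed
# at 8, a stored u4 value q dequantizes to (q - 8) * scale, so q = 0..15 covers
# the symmetric range -8 * scale .. 7 * scale.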


# All the following AutoGPTQ quant types are supposed to share the same weight packing schema
supported_quant_types = ["triton", "exllama", "exllamav2", "cuda-old"]


def patch_model(model):
    is_symmetrical = False
    config = None
    if hasattr(model, "config"):
        config = model.config
    elif hasattr(model, "model") and hasattr(model.model, "config"):
        # original model was wrapped
        config = model.model.config
    if config is not None and hasattr(config, "quantization_config") and hasattr(config.quantization_config, "sym"):
        is_symmetrical = config.quantization_config.sym
    for name, module in model.named_modules():
        if hasattr(module, "_openvino_patch_orig_forward"):
            # already patched, skipping
            continue
        # TODO: Check module type
        is_quantized = getattr(module, "is_quantized", None)
        if is_quantized is not None:
            module.is_quantized = False
        module.float()  # enables tracing on CPU, applied for all modules
        if hasattr(module, "QUANT_TYPE"):
            if module.QUANT_TYPE not in supported_quant_types:
                raise ValueError(f"Unsupported QUANT_TYPE == {module.QUANT_TYPE} discovered for "
                                 "an AutoGPTQ model; only the following types are supported: "
                                 f"{supported_quant_types}")
            if module.bits != 4:
                raise ValueError(f"Unsupported bits == {module.bits} discovered in module {name} "
                                 "in an AutoGPTQ model; only bits == 4 is supported.")

            int4_in_int32 = 8
            groups = module.qzeros.shape[0]
            module.width = module.qweight.shape[1]
            assert module.group_size == module.qweight.shape[0] * int4_in_int32 // groups

            module._openvino_patch_orig_forward = module.forward
            if is_symmetrical:
                module.forward = partial(patched_forward_sym, module)
            else:
                module.forward = partial(patched_forward, module)

            # Keep original field properties so the model can be restored to its original state
            module._openvino_patch_orig_qweights_type = module.qweight.dtype
            module._openvino_patch_orig_qzeros_type = module.qzeros.dtype
            module._openvino_patch_orig_scale_shape = module.scales.shape

            module.qweight = module.qweight.view(dtype=torch.uint8)
            module.qzeros = module.qzeros.view(dtype=torch.uint8)
            # TODO: Redundant tensor copy? Try to remove module.qweight and
            # module.qzeros after keeping modified values as submodules
            module.add_module(
                "_openvino_u4_compression_submodule_qweights", KeepWeight(module.qweight))
            # Add 17 (0x11) to move the zp+1 step from after unpacking to before it,
            # so the decompression pattern stays correct. Can it overflow?
            # (a nibble equal to 15 would carry into its neighbor)
            module.add_module("_openvino_u4_compression_submodule_qzeros",
                              KeepWeight(module.qzeros + torch.tensor(17, dtype=torch.uint8)))
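            # Worked example (illustrative, not from the original file): a packed
            # zp byte 0x32 holds nibbles (2, 3); adding 0x11 yields 0x43 with
            # nibbles (3, 4), i.e. zp+1 applied to both packed values.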

            module.scales = module.scales.view(-1, 1, module.width)


def unpatch_model(model):
    for _, module in model.named_modules():
        if hasattr(module, "_openvino_patch_orig_forward"):
            try:
                module.forward = module._openvino_patch_orig_forward
                del module._openvino_patch_orig_forward

                module.qweight = module.qweight.view(
                    dtype=module._openvino_patch_orig_qweights_type)
                del module._openvino_patch_orig_qweights_type

                module.qzeros = module.qzeros.view(
                    dtype=module._openvino_patch_orig_qzeros_type)
                del module._openvino_patch_orig_qzeros_type

                module.scales = module.scales.view(module._openvino_patch_orig_scale_shape)
                del module._openvino_patch_orig_scale_shape

                del module._openvino_u4_compression_submodule_qweights
                del module._openvino_u4_compression_submodule_qzeros
            except Exception as error:
                log.warning("Exception raised during GPTQ model unpatching. "
                            "Depending on the exact issue, it may leave the "
                            "original model broken.\n%s", error)
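
# Hedged usage sketch (not part of the original file); model and input names
# are placeholders for an AutoGPTQ int4 model loaded elsewhere:
#
#   patch_model(model)                    # swap in tracing-friendly forwards
#   traced = torch.jit.trace(model, example_input)
#   unpatch_model(model)                  # restore original forwards and dtypes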