93 lines
3.5 KiB
Python
93 lines
3.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright (C) 2018-2025 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from typing import Optional
|
|
import torch
|
|
from openvino.frontend.pytorch import ModuleExtension, gptq
|
|
from openvino.frontend.pytorch.patch_model import patch_model, unpatch_model
|
|
|
|
|
|
def detect_quantized_model(model: torch.nn.Module) -> Optional[str]:
|
|
"""Detects the quantization method used in a given PyTorch model.
|
|
|
|
Args:
|
|
model (torch.nn.Module): The PyTorch model to check for quantization.
|
|
|
|
Returns:
|
|
str: The quantization method if available, otherwise None.
|
|
"""
|
|
if (model and getattr(model, "config", None)
|
|
and getattr(model.config, "quantization_config", None)):
|
|
return model.config.quantization_config.quant_method # type: ignore
|
|
if getattr(model, "model", None):
|
|
return detect_quantized_model(model.model) # type: ignore[arg-type]
|
|
return None
|
|
|
|
|
|
def patch_quantized(model: torch.nn.Module) -> None:
|
|
"""Patches a model based on its quantization type ("awq" or "gptq").
|
|
|
|
Args:
|
|
model (torch.nn.Module): The model to patch.
|
|
|
|
Raises:
|
|
RuntimeError: If the quantization type is unknown.
|
|
"""
|
|
def fp32_tensor(*shape: int) -> torch.Tensor:
|
|
return torch.full(shape, 0.5, dtype=torch.float32)
|
|
|
|
quant_type = detect_quantized_model(model)
|
|
extensions = {}
|
|
if quant_type == "awq":
|
|
try:
|
|
from awq.modules.linear import WQLinear_GEMM
|
|
extensions[WQLinear_GEMM] = ModuleExtension(
|
|
WQLinear_GEMM, "ov_ext::awq_gemm",
|
|
convert=lambda module, target_op, *args, **kwargs: target_op(
|
|
args[0], module.qweight, module.qzeros, module.scales,
|
|
torch.tensor(module.group_size),
|
|
torch.tensor(module.w_bit), module.bias),
|
|
evaluate=lambda module, *args, **kwargs: fp32_tensor(
|
|
*args[0].shape[:-1], module.out_features)) # type: ignore
|
|
except ImportError:
|
|
pass
|
|
elif quant_type == "bitnet":
|
|
try:
|
|
from transformers.integrations.bitnet import AutoBitLinear
|
|
extensions[AutoBitLinear] = ModuleExtension(
|
|
AutoBitLinear, "ov_ext::bit_linear",
|
|
convert=lambda module, target_op, *args, **kwargs: target_op(
|
|
module.rms_norm(
|
|
args[0]) if module.rms_norm is not None else args[0],
|
|
getattr(module, "original_weight", module.weight),
|
|
module.weight_scale,
|
|
module.bias),
|
|
evaluate=lambda module, *args, **kwargs: fp32_tensor(
|
|
*args[0].shape[:-1], module.out_features)) # type: ignore
|
|
except ImportError:
|
|
pass
|
|
elif quant_type == "gptq":
|
|
model._openvino_gptq_patched = True # type: ignore[assignment]
|
|
gptq.patch_model(model) # type: ignore
|
|
return
|
|
else:
|
|
raise RuntimeError(f"Unknown quantization type: {quant_type}.")
|
|
|
|
patch_model(model, extensions,
|
|
"_openvino_quantized_patch_orig_forward") # type: ignore
|
|
|
|
|
|
def unpatch_quantized(model: torch.nn.Module) -> None:
|
|
"""Reverts the patching applied to a quantized PyTorch model.
|
|
|
|
Args:
|
|
model (torch.nn.Module): The model to unpatch.
|
|
"""
|
|
if getattr(model, "_openvino_gptq_patched", False):
|
|
gptq.unpatch_model(model) # type: ignore
|
|
del model._openvino_gptq_patched
|
|
else:
|
|
unpatch_model(model,
|
|
"_openvino_quantized_patch_orig_forward") # type: ignore
|