# -*- coding: utf-8 -*-
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch

from openvino.frontend.pytorch import ModuleExtension, gptq
from openvino.frontend.pytorch.patch_model import patch_model, unpatch_model

# Attribute name passed to patch_model/unpatch_model on the non-GPTQ path.
_ORIG_FORWARD_ATTR = "_openvino_quantized_patch_orig_forward"


def detect_quantized_model(model: torch.nn.Module) -> Optional[str]:
    """Look up the quantization method recorded on a model's config.

    The model's own ``config.quantization_config`` is consulted first. When
    it is missing or falsy, the search continues recursively in the nested
    ``model.model`` attribute, if there is one.

    Args:
        model (torch.nn.Module): The PyTorch model to inspect.

    Returns:
        Optional[str]: The ``quant_method`` value of the first truthy
        ``quantization_config`` found, or None when none is found.
    """
    config = getattr(model, "config", None) if model else None
    quantization_config = (
        getattr(config, "quantization_config", None) if config else None)
    if quantization_config:
        return quantization_config.quant_method  # type: ignore

    # Nothing at this level: descend into the wrapped model when present.
    wrapped = getattr(model, "model", None)
    if wrapped:
        return detect_quantized_model(wrapped)
    return None


def patch_quantized(model: torch.nn.Module) -> None:
    """Patch a quantized model according to its detected quantization type.

    Supported types, as reported by ``detect_quantized_model``:

    * ``"gptq"``: marks the model with ``_openvino_gptq_patched`` and
      delegates to ``gptq.patch_model``.
    * ``"awq"``: registers a ``ModuleExtension`` mapping ``WQLinear_GEMM``
      to ``ov_ext::awq_gemm``.
    * ``"bitnet"``: registers a ``ModuleExtension`` mapping
      ``AutoBitLinear`` to ``ov_ext::bit_linear``.

    For ``"awq"`` and ``"bitnet"`` the extension is only registered when the
    corresponding module class can be imported; ``patch_model`` is called
    either way, with an empty mapping if the import failed.

    Args:
        model (torch.nn.Module): The model to patch.

    Raises:
        RuntimeError: If the detected quantization type is anything else,
            including None (no quantization detected).
    """

    def placeholder_output(module, *args, **kwargs) -> torch.Tensor:
        # float32 tensor filled with 0.5: the input's leading dimensions
        # followed by the layer's out_features.
        # NOTE(review): presumably only shape/dtype matter here - confirm.
        out_shape = (*args[0].shape[:-1], module.out_features)
        return torch.full(out_shape, 0.5, dtype=torch.float32)

    def convert_awq(module, target_op, *args, **kwargs):
        return target_op(
            args[0],
            module.qweight,
            module.qzeros,
            module.scales,
            torch.tensor(module.group_size),
            torch.tensor(module.w_bit),
            module.bias,
        )

    def convert_bitnet(module, target_op, *args, **kwargs):
        # Apply the layer's RMS norm to the input when it has one.
        norm = module.rms_norm
        hidden = args[0] if norm is None else norm(args[0])
        # Use "original_weight" when the module has it, else "weight".
        weight = getattr(module, "original_weight", module.weight)
        return target_op(hidden, weight, module.weight_scale, module.bias)

    quant_type = detect_quantized_model(model)

    if quant_type == "gptq":
        # GPTQ has its own patcher; the flag lets unpatch_quantized pick
        # the matching unpatch routine later.
        model._openvino_gptq_patched = True  # type: ignore[assignment]
        gptq.patch_model(model)  # type: ignore
        return

    extensions = {}
    if quant_type == "awq":
        try:
            from awq.modules.linear import WQLinear_GEMM
            extensions[WQLinear_GEMM] = ModuleExtension(
                WQLinear_GEMM, "ov_ext::awq_gemm",
                convert=convert_awq,
                evaluate=placeholder_output)  # type: ignore
        except ImportError:
            # Best effort: without the class there is nothing to register.
            pass
    elif quant_type == "bitnet":
        try:
            from transformers.integrations.bitnet import AutoBitLinear
            extensions[AutoBitLinear] = ModuleExtension(
                AutoBitLinear, "ov_ext::bit_linear",
                convert=convert_bitnet,
                evaluate=placeholder_output)  # type: ignore
        except ImportError:
            # Best effort: without the class there is nothing to register.
            pass
    else:
        raise RuntimeError(f"Unknown quantization type: {quant_type}.")

    patch_model(model, extensions, _ORIG_FORWARD_ATTR)  # type: ignore


def unpatch_quantized(model: torch.nn.Module) -> None:
    """Undo the patching applied by ``patch_quantized``.

    Models carrying the ``_openvino_gptq_patched`` flag are reverted with
    ``gptq.unpatch_model`` and the flag is removed; all other models are
    reverted with the generic ``unpatch_model``.

    Args:
        model (torch.nn.Module): The model to unpatch.
    """
    if not getattr(model, "_openvino_gptq_patched", False):
        unpatch_model(model, _ORIG_FORWARD_ATTR)  # type: ignore
        return

    gptq.unpatch_model(model)  # type: ignore
    del model._openvino_gptq_patched