import os
import torch
from gptqmodel import GPTQModel, QuantizeConfig
from datasets import load_dataset

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

model_path = "G:/AI_Study/Ai_model/HyperCLOVAX-SEED-Think-14B"
output_path = "G:/AI_Study/Ai_model/HyperCLOVAX-SEED-Think-14B-GPTQ"

quant_config = QuantizeConfig(bits=4, group_size=128, desc_act=False)

print("모델 로드 시작...")
model = GPTQModel.from_pretrained(
    model_path, quantize_config=quant_config, device_map="auto", trust_remote_code=True
)

# [로그 강화] 모든 선형 레이어의 차원을 강제로 출력하여 검증
print("\n--- [로그] 모든 레이어 차원 전수 검사 시작 ---")
for name, module in model.model.named_modules():
    if hasattr(module, "in_features"):
        before = module.in_features
        # 어떤 차원인지 모든 레이어의 이름을 찍어 확인합니다.
        print(f"[검사] 레이어: {name} | 감지된 in_features: {before}")

        if before == 14334:
            print(f"  >>> [보정 대상 발견] {name} 레이어를 14336으로 수정합니다.")
            module.in_features = 14336

            if hasattr(module, "g_idx") and module.g_idx is not None:
                padding = torch.full(
                    (2,),
                    module.g_idx[-1],
                    device=module.g_idx.device,
                    dtype=module.g_idx.dtype,
                )
                module.g_idx = torch.cat([module.g_idx, padding])
                print(f"    └─ 성공: {name} g_idx 보정 완료.")

print("--- [로그] 전수 검사 종료 ---\n")

print("보정 데이터셋 준비 중...")
traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
dataset = [text for text in traindata["text"] if len(text) > 0][:300]

print("양자화 시작...")
try:
    model.quantize(dataset, batch_size=1)
    print("[성공] 양자화 연산이 차원 오류 없이 완료되었습니다.")
except Exception as e:
    print(f"[오류] 양자화 도중 차원 정렬 실패: {e}")
    raise e

print("모델 저장 중...")
model.save_quantized(output_path)
print("양자화 완료!")
