|
import os
import zipfile
from abc import ABC
from contextlib import asynccontextmanager
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Optional, Union

from awq import AutoAWQForCausalLM
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, RedirectResponse
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    No startup or shutdown work is performed yet; control is simply
    yielded for the lifetime of the application.
    """
    yield
|
|
|
app = FastAPI(title="Huggingface Safetensor Model Converter to AWQ", version="0.1.0", lifespan=lifespan) |
|
|
|
|
|
|
|
class QuantizationConfig(BaseModel):
    """Base model for per-format quantization settings.

    NOTE(review): the original also inherited ``abc.ABC`` but declared no
    abstract methods, so the ABC base neither prevented instantiation nor
    added behavior; it was dropped to avoid implying a contract that does
    not exist.
    """
|
class ConvertRequest(BaseModel):
    """Common request payload shared by all conversion endpoints.

    NOTE(review): the original inherited ``abc.ABC``, yet this model is
    used directly as the request body of /convert_gpt_q and /convert_gguf
    and defines no abstract methods — the ABC base was inert and has been
    removed.
    """

    # Hugging Face model repo id to convert (e.g. "org/model").
    hf_model_name: str
    hf_tokenizer_name: Optional[str] = Field(None, description="Hugging Face tokenizer name. Defaults to hf_model_name")
    hf_token: Optional[str] = Field(None, description="Hugging Face token for private models")
    hf_push_repo: Optional[str] = Field(None, description="Hugging Face repo to push the converted model. If not provided, the model will be downloaded only.")
|
|
|
|
|
|
|
class AWQQuantizationConfig(QuantizationConfig):
    """Settings forwarded (via ``model_dump()``) to AWQ quantization.

    Fields are required-typed with defaults rather than ``Optional``: the
    previous annotations let a client send explicit ``null`` values, which
    would pass through ``model_dump()`` into the AWQ ``quant_config`` and
    break quantization. Omitted fields still get the defaults below.
    """

    zero_point: bool = Field(True, description="Use zero point quantization")
    q_group_size: int = Field(128, description="Quantization group size")
    w_bit: int = Field(4, description="Weight bit")
    version: str = Field("GEMM", description="Quantization version")
|
|
|
class GPTQQuantizationConfig(QuantizationConfig):
    """Placeholder for GPTQ quantization settings; the endpoint is not implemented yet."""
|
|
|
class GGUFQuantizationConfig(QuantizationConfig):
    """Placeholder for GGUF quantization settings; the endpoint is not implemented yet."""
|
class AWQConvertionRequest(ConvertRequest):
    """Request body for the /convert_awq endpoint."""

    # Each request gets its own config instance carrying the AWQ defaults;
    # the class itself serves as the factory (no lambda needed).
    quantization_config: Optional[AWQQuantizationConfig] = Field(
        default_factory=AWQQuantizationConfig,
        description="AWQ quantization configuration",
    )
|
|
|
class GPTQConvertionRequest(ConvertRequest):
    """Request body for the /convert_gpt_q endpoint."""

    # The class itself serves as the default factory (no lambda needed).
    quantization_config: Optional[GPTQQuantizationConfig] = Field(
        default_factory=GPTQQuantizationConfig,
        description="GPTQ quantization configuration",
    )
|
|
|
class GGUFConvertionRequest(ConvertRequest):
    """Request body for the /convert_gguf endpoint."""

    # The class itself serves as the default factory (no lambda needed).
    quantization_config: Optional[GGUFQuantizationConfig] = Field(
        default_factory=GGUFQuantizationConfig,
        description="GGUF quantization configuration",
    )
|
|
|
|
|
@app.get("/", include_in_schema=False)
def redirect_to_docs():
    """Send visitors of the bare root URL to the interactive API docs."""
    docs_redirect = RedirectResponse(url='/docs')
    return docs_redirect
|
|
|
|
|
@app.post("/convert_awq", response_model=None)
def convert(request: AWQConvertionRequest) -> Union[FileResponse, dict]:
    """Quantize a Hugging Face model with AWQ.

    Downloads the model and tokenizer, runs AWQ quantization, then either
    writes the result to ``request.hf_push_repo`` (returning a status dict)
    or returns the quantized artifacts as a zip archive.

    Raises:
        HTTPException: 400 when the model architecture is rejected by AWQ.
    """
    # TODO(review): request.hf_token is accepted but never forwarded to the
    # download calls, so private models will fail to load — confirm the
    # intended auth mechanism and wire it through.
    model = AutoAWQForCausalLM.from_pretrained(request.hf_model_name)
    tokenizer = AutoTokenizer.from_pretrained(request.hf_tokenizer_name or request.hf_model_name, trust_remote_code=True)

    # quantization_config is Optional; fall back to defaults instead of
    # crashing on .model_dump() when the client sends an explicit null.
    quant_config = request.quantization_config or AWQQuantizationConfig()

    try:
        model.quantize(tokenizer, quant_config=quant_config.model_dump())
    except TypeError as e:
        raise HTTPException(status_code=400, detail=f"Is this model supported by AWQ Quantization? Check:https://github.com/mit-han-lab/llm-awq?tab=readme-ov-file {e}")

    if request.hf_push_repo:
        # NOTE(review): despite the field name, save_quantized() and
        # save_pretrained() write to a local path; nothing is pushed to the
        # Hub here — confirm whether an actual push was intended.
        model.save_quantized(request.hf_push_repo)
        tokenizer.save_pretrained(request.hf_push_repo)
        return {
            "status": "ok",
            "message": f"Model saved to {request.hf_push_repo}",
        }

    # Original bug: save_quantized()/save_pretrained() were handed the
    # zipfile.ZipFile object itself, but both expect a directory path.
    # Save into a temporary directory first, then zip its contents.
    with TemporaryDirectory() as save_dir:
        model.save_quantized(save_dir)
        tokenizer.save_pretrained(save_dir)

        # delete=False: FileResponse streams the file after this handler
        # returns, so the zip must outlive the context manager.
        with NamedTemporaryFile(suffix=".zip", delete=False) as temp_zip:
            zip_file_path = temp_zip.name
        with zipfile.ZipFile(zip_file_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, _dirs, files in os.walk(save_dir):
                for file_name in files:
                    file_path = os.path.join(root, file_name)
                    zipf.write(file_path, arcname=os.path.relpath(file_path, save_dir))

    # Model ids may contain "/" (e.g. "org/model"); sanitize for a filename.
    download_name = request.hf_model_name.replace("/", "--")
    return FileResponse(
        zip_file_path,
        media_type='application/zip',
        filename=f"{download_name}.zip",
    )
|
|
|
@app.post("/convert_gpt_q", response_model=None)
def convert_gpt_q(request: ConvertRequest) -> Union[FileResponse, dict]:
    """Stub endpoint for GPTQ conversion; always responds 501."""
    raise HTTPException(status_code=501, detail="Not implemented yet")
|
|
|
@app.post("/convert_gguf", response_model=None)
def convert_gguf(request: ConvertRequest) -> Union[FileResponse, dict]:
    """Stub endpoint for GGUF conversion; always responds 501."""
    raise HTTPException(status_code=501, detail="Not implemented yet")
|
|
|
@app.get("/health")
def read_root():
    """Liveness probe; always reports a healthy service."""
    health_payload = {"status": "ok"}
    return health_payload