import zipfile
from typing import Optional, Union
from awq import AutoAWQForCausalLM
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse, FileResponse
### FastAPI Initialization
@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
app = FastAPI(title="Huggingface Safetensor Model Converter to AWQ", version="0.1.0", lifespan=lifespan)
### -------
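# To serve the API locally you can point uvicorn at this module, e.g.
# `uvicorn app:app --host 0.0.0.0 --port 8000` (the module name `app` is an
# assumption about this file's name; adjust it to match).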
### DTO Definitions
class QuantizationConfig(BaseModel):
    zero_point: Optional[bool] = Field(True, description="Use zero-point quantization")
    q_group_size: Optional[int] = Field(128, description="Quantization group size")
    w_bit: Optional[int] = Field(4, description="Weight bit-width")
    version: Optional[str] = Field("GEMM", description="Quantization kernel version")
class ConvertRequest(BaseModel):
    hf_model_name: str
    hf_tokenizer_name: Optional[str] = Field(None, description="Hugging Face tokenizer name. Defaults to hf_model_name")
    hf_token: Optional[str] = Field(None, description="Hugging Face token for private models")
    hf_push_repo: Optional[str] = Field(None, description="Hugging Face repo to push the converted model to. If not provided, the model will only be downloaded.")
    quantization_config: QuantizationConfig = Field(default_factory=QuantizationConfig, description="Quantization configuration")
### -------
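# Illustrative /convert request body; the model name is only an example, and
# every field except hf_model_name is optional:
# {
#     "hf_model_name": "mistralai/Mistral-7B-v0.1",
#     "quantization_config": {
#         "zero_point": true,
#         "q_group_size": 128,
#         "w_bit": 4,
#         "version": "GEMM"
#     }
# }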
@app.get("/", include_in_schema=False)
def redirect_to_docs():
    return RedirectResponse(url='/docs')
### FastAPI Endpoints
@app.get("/health")
def health_check():
    return {"status": "ok"}
@app.post("/convert")
def convert(request: ConvertRequest) -> Union[FileResponse, dict]:
    # Note: request.hf_token is accepted by the DTO but not yet wired into
    # the download calls below.
    model = AutoAWQForCausalLM.from_pretrained(request.hf_model_name)
    tokenizer = AutoTokenizer.from_pretrained(request.hf_tokenizer_name or request.hf_model_name, trust_remote_code=True)

    # Build the AWQ quant_config dict from the request DTO (use .dict() instead
    # of .model_dump() on pydantic v1).
    quant_config = request.quantization_config.model_dump()
    model.quantize(tokenizer, quant_config=quant_config)

    if request.hf_push_repo:
        # Save the quantized model and tokenizer to a local directory derived
        # from the model name; pushing to hf_push_repo is not implemented yet.
        quant_path = request.hf_model_name.split("/")[-1] + "-awq"
        model.save_quantized(quant_path)
        tokenizer.save_pretrained(quant_path)
        return {
            "status": "ok",
            "message": f"Model saved to {quant_path}"
        }

    # Return a zip file with the converted model
    raise HTTPException(status_code=501, detail="Not implemented yet")
    #return FileResponse(file_location, media_type='application/octet-stream', filename=file_name)
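
# A minimal sketch of the missing download path, assuming the quantized model
# was saved to a local directory as above (the helper name and archive layout
# are illustrative, not part of the current API): pack the directory into a
# single zip that FileResponse can stream back to the client.
def zip_model_dir(model_dir: str, zip_path: str) -> str:
    from pathlib import Path  # local import keeps this sketch self-contained

    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for file_path in Path(model_dir).rglob("*"):
            if file_path.is_file():
                # Store entries relative to the model directory root.
                archive.write(file_path, arcname=file_path.relative_to(model_dir).as_posix())
    return zip_path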