Rúben Almeida committed
Commit 0735f93 · Parent: 6af49e3

Add exception handling for incompatible models

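In effect, a client can now tell an unsupported model (HTTP 400, with a pointer to the AWQ-supported architectures) apart from a genuine conversion failure (HTTP 500). A minimal client-side sketch, assuming a local deployment — the endpoint URL and model name are placeholders, not part of this commit:

import requests

# Placeholder URL; adjust to wherever the service runs.
ENDPOINT = "http://localhost:8000/convert_awq"

response = requests.post(ENDPOINT, json={"hf_model_name": "gpt2"})

if response.status_code == 400:
    # New in this commit: incompatible models are rejected up front,
    # with a link to the AWQ-supported architectures in the detail.
    print("Model not supported by AWQ:", response.json()["detail"])
else:
    response.raise_for_status()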
.vscode/settings.json ADDED
@@ -0,0 +1,7 @@
+{
+    "python.testing.pytestArgs": [
+        "tests"
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true
+}
main.py CHANGED
@@ -1,4 +1,5 @@
 import zipfile
+from abc import ABC
 from typing import Optional, Union
 from awq import AutoAWQForCausalLM
 from pydantic import BaseModel, Field
@@ -17,37 +18,61 @@ app = FastAPI(title="Huggingface Safetensor Model Converter to AWQ", version="0.
 ### -------
 
 ### DTO Definitions
-class QuantizationConfig(BaseModel):
-    zero_point: Optional[bool] = Field(True, description="Use zero point quantization")
-    q_group_size: Optional[int] = Field(128, description="Quantization group size")
-    w_bit: Optional[int] = Field(4, description="Weight bit")
-    version: Optional[str] = Field("GEMM", description="Quantization version")
-
-class ConvertRequest(BaseModel):
+class QuantizationConfig(ABC, BaseModel):
+    pass
+class ConvertRequest(ABC, BaseModel):
     hf_model_name: str
     hf_tokenizer_name: Optional[str] = Field(None, description="Hugging Face tokenizer name. Defaults to hf_model_name")
     hf_token: Optional[str] = Field(None, description="Hugging Face token for private models")
     hf_push_repo: Optional[str] = Field(None, description="Hugging Face repo to push the converted model. If not provided, the model will be downloaded only.")
-    quantization_config: QuantizationConfig = Field(QuantizationConfig(), description="Quantization configuration")
 ### -------
 
+### Quantization Configurations
+class AWQQuantizationConfig(QuantizationConfig):
+    zero_point: Optional[bool] = Field(True, description="Use zero point quantization")
+    q_group_size: Optional[int] = Field(128, description="Quantization group size")
+    w_bit: Optional[int] = Field(4, description="Weight bit")
+    version: Optional[str] = Field("GEMM", description="Quantization version")
+
+class GPTQQuantizationConfig(QuantizationConfig):
+    pass
+
+class GGUFQuantizationConfig(QuantizationConfig):
+    pass
+class AWQConvertionRequest(ConvertRequest):
+    quantization_config: Optional[AWQQuantizationConfig] = Field(
+        default_factory=lambda: AWQQuantizationConfig(),
+        description="AWQ quantization configuration"
+    )
+
+class GPTQConvertionRequest(ConvertRequest):
+    quantization_config: Optional[GPTQQuantizationConfig] = Field(
+        default_factory=lambda: GPTQQuantizationConfig(),
+        description="GPTQ quantization configuration"
+    )
+
+class GGUFConvertionRequest(ConvertRequest):
+    quantization_config: Optional[GGUFQuantizationConfig] = Field(
+        default_factory=lambda: GGUFQuantizationConfig(),
+        description="GGUF quantization configuration"
+    )
+### -------
 
 @app.get("/", include_in_schema=False)
 def redirect_to_docs():
     return RedirectResponse(url='/docs')
 
 ### FastAPI Endpoints
-@app.get("/health")
-def read_root():
-    return {"status": "ok"}
-
-@app.post("/convert", response_model=None)
-def convert(request: ConvertRequest) -> Union[FileResponse, dict]:
+@app.post("/convert_awq", response_model=None)
+def convert(request: AWQConvertionRequest) -> Union[FileResponse, dict]:
     model = AutoAWQForCausalLM.from_pretrained(request.hf_model_name)
     tokenizer = AutoTokenizer.from_pretrained(request.hf_tokenizer_name or request.hf_model_name, trust_remote_code=True)
-
-    model.quantize(tokenizer, quant_config=request.quantization_config.model_dump())
-
+
+    try:
+        model.quantize(tokenizer, quant_config=request.quantization_config.model_dump())
+    except TypeError as e:
+        raise HTTPException(status_code=400, detail=f"Is this model supported by AWQ Quantization? Check: https://github.com/mit-han-lab/llm-awq?tab=readme-ov-file {e}")
+
     if request.hf_push_repo:
         model.save_quantized(request.hf_push_repo)
         tokenizer.save_pretrained(request.hf_push_repo)
@@ -72,4 +97,16 @@ def convert(request: ConvertRequest) -> Union[FileResponse, dict]:
     )
 
 
-    raise HTTPException(status_code=500, detail="Failed to convert model")
+    raise HTTPException(status_code=500, detail="Failed to convert model")
+
+@app.post("/convert_gpt_q", response_model=None)
+def convert_gpt_q(request: ConvertRequest) -> Union[FileResponse, dict]:
+    raise HTTPException(status_code=501, detail="Not implemented yet")
+
+@app.post("/convert_gguf", response_model=None)
+def convert_gguf(request: ConvertRequest) -> Union[FileResponse, dict]:
+    raise HTTPException(status_code=501, detail="Not implemented yet")
+
+@app.get("/health")
+def read_root():
+    return {"status": "ok"}
requirements.txt CHANGED
@@ -9,4 +9,7 @@ fastapi[standard]
 transformers
 huggingface_hub
 autoawq[kernels]
-starlette>=0.46.2
+starlette>=0.46.2
+pytest
+requests
+environs
tests/.env.example ADDED
@@ -0,0 +1,2 @@
+ENDPOINT=
+HF_TOKEN=
tests/__init__.py ADDED
(empty file)
tests/test_convertion.py ADDED
@@ -0,0 +1,31 @@
+import pytest
+import requests
+from environs import Env
+from huggingface_hub import login
+
+env = Env()
+env.read_env(override=True)
+
+@pytest.mark.parametrize("model_name", [
+    "gpt2",
+])
+def test_convert_download(model_name):
+    if env.str("HF_TOKEN"):
+        login(token=env("HF_TOKEN"))
+
+    response = requests.post(
+        env.str("ENDPOINT"),
+        json={
+            "hf_model_name": model_name,
+            "hf_tokenizer_name": model_name,
+            "hf_push_repo": None,
+        }
+    )
+
+    response.raise_for_status()
+
+    assert response.headers.get("Content-Type") == "application/zip"  # requests.Response has no .content_type attribute
+
+
+def test_convert_push():
+    pass
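test_convert_push is left as a stub in this commit. One plausible shape for it, assuming the endpoint answers with JSON rather than a zip when hf_push_repo is set — the repo name below is a placeholder:

def test_convert_push():
    if env.str("HF_TOKEN"):
        login(token=env("HF_TOKEN"))

    response = requests.post(
        env.str("ENDPOINT"),
        json={
            "hf_model_name": "gpt2",
            "hf_push_repo": "my-user/gpt2-awq",  # placeholder target repo
        },
    )

    response.raise_for_status()
    # A push is expected to return metadata rather than a file download.
    assert "application/json" in response.headers.get("Content-Type", "")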