Vaibhav-Singh committed
Commit 25b680f · 1 Parent(s): 495ee6e
Files changed (1)
  1. app.py +94 -41
app.py CHANGED
@@ -1,63 +1,116 @@
  from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from typing import List
  import torch

- app = FastAPI(title="Language Model API")
-
- # Model configuration
- CHECKPOINT = "HuggingFaceTB/SmolLM2-135M-Instruct"
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

- # Initialize model and tokenizer
- try:
-     tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
-     model = AutoModelForCausalLM.from_pretrained(CHECKPOINT).to(DEVICE)
- except Exception as e:
-     raise RuntimeError(f"Failed to load model: {str(e)}")

- class ChatMessage(BaseModel):
      role: str
      content: str

  class ChatRequest(BaseModel):
-     messages: List[ChatMessage]
-     max_new_tokens: int = 50
-     temperature: float = 0.2
-     top_p: float = 0.9

- @app.post("/generate")
- async def generate_response(request: ChatRequest):
      try:
-         # Convert messages to the format expected by the model
-         messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
-
-         # Prepare input
-         input_text = tokenizer.apply_chat_template(messages, tokenize=False)
-         inputs = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
-
-         # Generate response
-         outputs = model.generate(
-             inputs,
              max_new_tokens=request.max_new_tokens,
-             temperature=request.temperature,
-             top_p=request.top_p,
-             do_sample=True
          )

-         # Decode and return response
-         response_text = tokenizer.decode(outputs[0])

-         return {
-             "generated_text": response_text
-         }

      except Exception as e:
          raise HTTPException(status_code=500, detail=str(e))

  if __name__ == "__main__":
      import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=7860)
-
-
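Note on the removed SmolLM2 path: `tokenizer.decode(outputs[0])` decodes the prompt tokens together with the completion, so a client gets its own messages echoed back along with special tokens. Below is a standalone sketch of the same flow that returns only the reply; the example prompt and the `add_generation_prompt=True` flag are illustrative additions, not part of this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# Illustrative prompt, not part of the commit.
messages = [{"role": "user", "content": "What is the capital of France?"}]
input_text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
outputs = model.generate(
    inputs, max_new_tokens=50, temperature=0.2, top_p=0.9, do_sample=True
)

# Decoding outputs[0] directly (as the removed endpoint did) would include the
# prompt; slicing off the input tokens keeps only the newly generated reply.
reply = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
print(reply)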
 
+ # from fastapi import FastAPI, HTTPException
+ # from pydantic import BaseModel
+ # from transformers import AutoModelForCausalLM, AutoTokenizer
+ # from typing import List
+ # import torch
+
+ # app = FastAPI(title="Language Model API")
+
+ # # Model configuration
+ # CHECKPOINT = "HuggingFaceTB/SmolLM2-135M-Instruct"
+ # DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # # Initialize model and tokenizer
+ # try:
+ #     tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
+ #     model = AutoModelForCausalLM.from_pretrained(CHECKPOINT).to(DEVICE)
+ # except Exception as e:
+ #     raise RuntimeError(f"Failed to load model: {str(e)}")
+
+ # class ChatMessage(BaseModel):
+ #     role: str
+ #     content: str
+
+ # class ChatRequest(BaseModel):
+ #     messages: List[ChatMessage]
+ #     max_new_tokens: int = 50
+ #     temperature: float = 0.2
+ #     top_p: float = 0.9
+
+ # @app.post("/generate")
+ # async def generate_response(request: ChatRequest):
+ #     try:
+ #         # Convert messages to the format expected by the model
+ #         messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
+
+ #         # Prepare input
+ #         input_text = tokenizer.apply_chat_template(messages, tokenize=False)
+ #         inputs = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
+
+ #         # Generate response
+ #         outputs = model.generate(
+ #             inputs,
+ #             max_new_tokens=request.max_new_tokens,
+ #             temperature=request.temperature,
+ #             top_p=request.top_p,
+ #             do_sample=True
+ #         )
+
+ #         # Decode and return response
+ #         response_text = tokenizer.decode(outputs[0])
+
+ #         return {
+ #             "generated_text": response_text
+ #         }
+
+ #     except Exception as e:
+ #         raise HTTPException(status_code=500, detail=str(e))
+
+ # if __name__ == "__main__":
+ #     import uvicorn
+ #     uvicorn.run(app, host="0.0.0.0", port=7860)
+
+
+
  from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
+ from typing import List, Dict
+ import transformers
  import torch

+ app = FastAPI(title="LLaMA API")

+ # Initialize the model and pipeline at startup
+ model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+ pipeline = transformers.pipeline(
+     "text-generation",
+     model=model_id,
+     model_kwargs={"torch_dtype": torch.bfloat16},
+     device_map="auto",
+ )

+ class Message(BaseModel):
      role: str
      content: str

  class ChatRequest(BaseModel):
+     messages: List[Message]
+     max_new_tokens: int = 256
+
+ class ChatResponse(BaseModel):
+     generated_text: str

+ @app.post("/generate", response_model=ChatResponse)
+ async def chat(request: ChatRequest):
      try:
+         outputs = pipeline(
+             [{"role": msg.role, "content": msg.content} for msg in request.messages],
              max_new_tokens=request.max_new_tokens,
          )

+         # The chat pipeline returns the full message list; the last element is
+         # the newly generated assistant message, so return its "content" text.
+         generated_text = outputs[0]["generated_text"][-1]["content"]

+         return ChatResponse(generated_text=generated_text)

      except Exception as e:
          raise HTTPException(status_code=500, detail=str(e))

+ # Health check endpoint
+ @app.get("/")
+ async def health_check():
+     return {"status": "healthy"}
+
  if __name__ == "__main__":
      import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
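For completeness, a minimal client sketch against the new API, assuming the server is running locally on port 8000 as configured above; the `requests` dependency and the example messages are illustrative and not part of this commit.

import requests

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain what FastAPI is in one sentence."},
    ],
    "max_new_tokens": 128,
}

# POST /generate returns {"generated_text": "..."} per the ChatResponse model.
resp = requests.post("http://localhost:8000/generate", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["generated_text"])

# GET / is the health check added in this commit.
print(requests.get("http://localhost:8000/", timeout=10).json())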