Luisgust committed · verified
Commit 3a2f68b · Parent(s): c0885e3

Create main.py

Files changed (1): main.py (+67 -0)
main.py ADDED
@@ -0,0 +1,67 @@
+ import json
+ import random
+ from typing import List, Tuple
+ from fastapi import FastAPI, Form, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from huggingface_hub import InferenceClient
+
+ # Initialize FastAPI app
+ app = FastAPI()
+
+ # Allow CORS for the frontend application
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Change this to your frontend's URL in production
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Initialize the Hugging Face Inference Client
+ client = InferenceClient("mistralai/Mistral-Nemo-Instruct-2407")
+
+ def format_prompt(message: str, history: List[Tuple[str, str]]) -> str:
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+ @app.post("/generate/")
+ async def generate(
+     prompt: str = Form(...),
+     history: str = Form(...),
+     temperature: float = Form(0.9),
+     max_new_tokens: int = Form(512),
+     top_p: float = Form(0.95),
+     repetition_penalty: float = Form(1.0),
+ ):
+     try:
+         # Parse the history form field (a JSON string of [user, bot] pairs); json.loads avoids the code-execution risk of eval
+         chat_history = json.loads(history)
+
+         # Format the prompt
+         formatted_prompt = format_prompt(prompt, chat_history)
+
+         generate_kwargs = dict(
+             temperature=temperature,
+             max_new_tokens=max_new_tokens,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+             do_sample=True,
+             seed=random.randint(0, 10**7),
+         )
+
+         # Stream tokens from the model and accumulate them into a single response
+         stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+         output = ""
+
+         for response in stream:
+             output += response.token.text
+
+         return JSONResponse(content={"response": output})
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
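
For reference, a minimal client sketch for exercising the /generate/ endpoint (assumptions: the app is served locally, e.g. with uvicorn main:app on the default port, and the requests library is available; the URL, port, and example values are illustrative and not part of this commit). Note that the history field must be sent as a JSON string, since the endpoint decodes it with json.loads:

import json
import requests  # illustrative client dependency, not part of main.py

# Prior (user, bot) turns, JSON-encoded because the endpoint expects a string form field
history = [["Hi", "Hello! How can I help?"]]

resp = requests.post(
    "http://localhost:8000/generate/",  # assumes uvicorn main:app on the default port
    data={
        "prompt": "Summarize our conversation so far.",
        "history": json.dumps(history),
        "temperature": 0.7,
    },
)
print(resp.json()["response"])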