Wauplin (HF staff) committed · verified
Commit f6c0ef9 · 1 Parent(s): a874d43

Use `huggingface_hub.InferenceClient` instead of `openai` to call Sambanova


This is a suggestion to use the `huggingface_hub` client instead of the `openai` one to call the Sambanova API. There is no need to provide the Sambanova API endpoint anymore. One can also use the HF model id and easily switch between providers (currently 7 providers are listed on https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct; the user can select one based on the speed/cost trade-off).
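For illustration, here is a minimal sketch of the basic migration, assembled from the diff below (the env variable, provider, and model id come from the diff; the rest is assumed):

```py
import os

import huggingface_hub

# Same Sambanova API key as before, but the client comes from
# huggingface_hub and the model is referenced by its HF model id.
# Swapping the `provider` argument for any other provider listed on
# the model page is enough to compare speed/cost.
client = huggingface_hub.InferenceClient(
    api_key=os.environ.get("SAMBANOVA_API_KEY"),
    provider="sambanova",
)
```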

A more advanced suggestion would be to set an `HF_TOKEN` Space secret instead and instantiate the client like this:

```py
client = huggingface_hub.InferenceClient(
    provider="sambanova",
)
```

This routes requests through Hugging Face, which makes it easier to switch between providers while keeping billing in a single place.
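As a usage sketch of that setup (assuming `HF_TOKEN` is set as a Space secret; the model id matches the diff below, the prompt is illustrative):

```py
import huggingface_hub

# With HF_TOKEN set in the environment, no api_key argument is needed:
# the client picks up the token and routes the request through Hugging
# Face to the selected provider, so billing stays on the HF account.
client = huggingface_hub.InferenceClient(provider="sambanova")

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(completion.choices[0].message.content)
```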

Files changed (1)
  1. app.py +4 -4
app.py CHANGED

```diff
@@ -5,7 +5,7 @@ from pathlib import Path
 
 import gradio as gr
 import numpy as np
-import openai
+import huggingface_hub
 from dotenv import load_dotenv
 from fastapi import FastAPI
 from fastapi.responses import HTMLResponse, StreamingResponse
@@ -25,9 +25,9 @@ load_dotenv()
 curr_dir = Path(__file__).parent
 
 
-client = openai.OpenAI(
+client = huggingface_hub.InferenceClient(
     api_key=os.environ.get("SAMBANOVA_API_KEY"),
-    base_url="https://api.sambanova.ai/v1",
+    provider="sambanova",
 )
 stt_model = get_stt_model()
 
@@ -52,7 +52,7 @@ def response(
     raise WebRTCError("test")
 
     request = client.chat.completions.create(
-        model="Meta-Llama-3.2-3B-Instruct",
+        model="meta-llama/Llama-3.2-3B-Instruct",
         messages=conversation_state,  # type: ignore
         temperature=0.1,
         top_p=0.1,
```