# Page header captured along with the file (not program code):
# gabrielchua's picture
# Update app.py
# aa264ad verified
# raw
# history blame
# 2.28 kB
import os
import gradio as gr
import joblib
import numpy as np
import pandas as pd
from openai import OpenAI
from huggingface_hub import login
from huggingface_hub import hf_hub_download
import logfire
# Configure Logfire with the token held in the LOGFIRE_API_KEY environment
# variable (os.getenv yields None when the variable is unset).
logfire.configure(token=os.getenv("LOGFIRE_API_KEY"))

# Load your pre-trained model and label names.
# `token=True` asks huggingface_hub to authenticate with the locally available
# Hugging Face credential; it replaces the deprecated `use_auth_token=True`.
model_path = hf_hub_download(
    repo_id="govtech/zoo-entry-001",
    filename="model.joblib",
    token=True,
)

# NOTE(security): joblib.load is pickle-based, so loading this file can execute
# arbitrary code. Only keep this if the model repository above is trusted.
model_data = joblib.load(model_path)
model = model_data['model']
label_names = model_data['label_names']

# Initialize OpenAI client. No arguments are passed, so credentials come from
# the SDK's defaults (per the OpenAI SDK docs, the OPENAI_API_KEY variable).
client = OpenAI()
def get_embedding(text, embedding_model="text-embedding-3-large"):
    """
    Return the OpenAI embedding of ``text`` as a NumPy array.

    Every newline in ``text`` is replaced by a single space before the
    request is made. ``embedding_model`` is the name of the OpenAI embedding
    model to query.
    """
    # Splitting on "\n" and re-joining with " " swaps each newline for a space.
    single_line = " ".join(text.split("\n"))

    request = {"input": [single_line], "model": embedding_model}
    result = client.embeddings.create(**request)

    # Exactly one input was sent, so the vector we want is the first data item.
    vector = result.data[0].embedding
    return np.array(vector)
def classify_text(text):
    """
    Classify ``text`` and return an update for the results DataFrame component.

    The text is embedded with ``get_embedding``, the embedding is handed to
    ``model.predict`` as a one-row batch, and the first row of the output is
    tabulated against ``label_names``. The ``Prediction`` column is 1 where
    the score is greater than 0.5 and 0 otherwise.

    Empty, whitespace-only or ``None`` input is not embedded; the results
    table is hidden instead.

    Returns:
        A ``gr.update`` that either shows the table with the new results or
        hides it when there was nothing to classify.
    """
    # Nothing to classify: skip the embedding request (the embeddings endpoint
    # rejects an empty string, and None has no .replace) and hide old results.
    if not text or not text.strip():
        return gr.update(visible=False)

    embedding = get_embedding(text)

    # Add batch dimension; asarray avoids copying an existing ndarray.
    X = np.asarray(embedding)[None, :]

    # NOTE(review): the output of predict is treated as per-label
    # probabilities -- confirm the loaded model returns scores, not hard labels.
    probabilities = model.predict(X)
    scores = probabilities[0]

    # Create a DataFrame with probabilities, labels, and binary predictions
    df = pd.DataFrame({
        'Label': label_names,
        'Probability': scores,
        'Prediction': (scores > 0.5).astype(int),
    })

    logfire.info(f"{text} ({probabilities[0]})")

    # Return an update to the DataFrame component to make it visible with the results
    return gr.update(value=df, visible=True)
# Page layout: a text input, a submit button and the results table, each
# placed in its own row, in that order.
with gr.Blocks(title="Zoo Entry 001") as iface:
    with gr.Row():
        input_text = gr.Textbox(label="Input Text", lines=5)

    with gr.Row():
        submit_btn = gr.Button(value="Submit")

    # The results table starts out hidden; classify_text returns an update
    # that makes it visible.
    with gr.Row():
        output_table = gr.DataFrame(visible=False, label="Classification Results")

    # Clicking the button sends the textbox content to classify_text and
    # routes its return value to the table.
    submit_btn.click(classify_text, input_text, output_table)


# Start the app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()