# Page header captured along with this file (not part of the program):
#   gabrielchua's picture
#   Update app.py
#   71b879a verified
import os
import gradio as gr
import joblib
import logfire
import numpy as np
import pandas as pd
from openai import OpenAI
from pydantic import BaseModel
# Configure Logfire with the token from the environment and switch on its
# Pydantic instrumentation.
logfire.configure(token=os.environ.get("LOGFIRE_API_KEY"))
logfire.instrument_pydantic()

# Load the serialized bundle and pull out the model and its label names.
model_data = joblib.load("model.joblib")
model, label_names = model_data["model"], model_data["label_names"]
class Results(BaseModel):
    """Record of a single classification: the input text and six scores.

    ``classify_text`` builds one instance per request from the rounded model
    output; the instance itself is not returned to the caller.
    """

    # The text that was classified.
    text: str

    # One score per category, filled from model output positions 0-5 in this
    # order by ``classify_text``.
    hateful: float
    insults: float
    sexual: float
    violence: float
    self_harm: float
    # NOTE(review): the meaning of "aom" is not evident from this file.
    aom: float
# Create the module-level OpenAI client used by get_embedding.
# NOTE(review): no API key is passed here, so the client presumably reads its
# credentials from the environment — confirm this is configured on deploy.
client = OpenAI()
def get_embedding(text: str, embedding_model: str = "text-embedding-3-large") -> np.ndarray:
    """
    Fetch the OpenAI embedding of ``text`` and return it as a NumPy array.

    Newline characters are swapped for spaces before the request is sent.
    """
    single_line = " ".join(text.split("\n"))
    api_response = client.embeddings.create(model=embedding_model, input=[single_line])
    # Only one input was sent, so the first result holds its embedding.
    first_item = api_response.data[0]
    return np.array(first_item.embedding)
def classify_text(text: str):
    """
    Classify ``text`` and return a Gradio update that shows the results table.

    The text is embedded with ``get_embedding``, the embedding is passed to
    the loaded model as a single-row matrix, and the first row of the model
    output is rounded to four decimals.

    NOTE(review): the model output is treated as per-label probabilities;
    confirm that ``model.predict`` returns probabilities rather than labels.

    Args:
        text: Text entered by the user.

    Returns:
        A ``gr.update`` that makes the output table visible and fills it with
        a DataFrame of three columns: ``Label``, ``Probability`` and
        ``Prediction`` (1 when the rounded probability exceeds 0.5, else 0).

    Raises:
        gr.Error: If ``text`` is empty or contains only whitespace.
    """
    # Reject blank input up front: there is nothing to classify, and the
    # embeddings request does not accept an empty string. Raising gr.Error
    # shows the message to the user instead of an opaque failure.
    if not text or not text.strip():
        raise gr.Error("Please enter some text to classify.")

    embedding = get_embedding(text)
    # The model is called with a 2-D input: one row holding this embedding.
    X = embedding.reshape(1, -1)
    probabilities = model.predict(X)
    rounded_probs = np.round(probabilities[0], 4)

    # Build a Results instance so the scores pass through Pydantic (which is
    # instrumented by Logfire at import time). The instance is intentionally
    # discarded; it does not affect the returned value.
    Results(
        text=text,
        hateful=rounded_probs[0],
        insults=rounded_probs[1],
        sexual=rounded_probs[2],
        violence=rounded_probs[3],
        self_harm=rounded_probs[4],
        aom=rounded_probs[5],
    )

    # Create DataFrame with rounded probabilities and binary predictions.
    df = pd.DataFrame({
        "Label": label_names,
        "Probability": rounded_probs,
        "Prediction": (rounded_probs > 0.5).astype(int),
    })
    return gr.update(value=df, visible=True)
# Build the UI: a text box, a submit button, and a results table that is
# created hidden.
with gr.Blocks(title="Zoo Entry 001") as iface:
    input_text = gr.Textbox(label="Input Text", lines=5)
    submit_btn = gr.Button("Submit")
    output_table = gr.DataFrame(visible=False, label="Classification Results")

    # Clicking the button runs the classifier on the text box contents and
    # routes its return value to the table.
    submit_btn.click(
        fn=classify_text,
        inputs=input_text,
        outputs=output_table,
    )

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()