Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import joblib | |
import logfire | |
import numpy as np | |
import pandas as pd | |
from openai import OpenAI | |
from pydantic import BaseModel | |
# Telemetry: configure Logfire with the token taken from the LOGFIRE_API_KEY
# environment variable, then switch on its Pydantic instrumentation.
logfire.configure(token=os.environ.get("LOGFIRE_API_KEY"))
logfire.instrument_pydantic()

# Classifier artefacts: "model.joblib" (resolved against the current working
# directory) holds a mapping with the model under "model" and the label
# names under "label_names".
model_data = joblib.load("model.joblib")
model, label_names = model_data["model"], model_data["label_names"]
class Results(BaseModel):
    """Record of one classification: the input text plus one score per label.

    ``classify_text`` builds an instance for every request and then discards
    it; the object is never returned. Because Logfire's Pydantic
    instrumentation is enabled at module import, constructing it is what
    gets the values logged.
    """

    # The raw text that was classified.
    text: str

    # One score per category. ``classify_text`` fills these positionally from
    # the model output (index 0..5 in this order), so the order here is
    # assumed to match the model's output order -- TODO confirm against
    # ``label_names``.
    hateful: float
    insults: float
    sexual: float
    violence: float
    self_harm: float
    # NOTE(review): presumably "all other misconduct" -- confirm.
    aom: float
# Shared OpenAI client used by get_embedding(). No arguments are passed, so
# credentials and endpoint come from the SDK's own defaults (presumably the
# OPENAI_API_KEY environment variable -- verify for the deployed SDK version).
client = OpenAI()
def get_embedding(text: str, embedding_model: str = "text-embedding-3-large") -> np.ndarray:
    """Return the OpenAI embedding of ``text`` as a NumPy array.

    Every newline in ``text`` is turned into a single space before the
    request is sent, so the API always receives one line of text.

    Args:
        text: The text to embed.
        embedding_model: Name of the OpenAI embedding model to query.

    Returns:
        The embedding of the (flattened) text as a 1-D ``np.ndarray``.
    """
    # Splitting on "\n" and re-joining with " " swaps each newline for a space.
    flattened = " ".join(text.split("\n"))

    # The input is sent as a one-element batch, so the embedding we want is
    # the first (and only) entry of the response data.
    reply = client.embeddings.create(model=embedding_model, input=[flattened])
    first_item = reply.data[0]
    return np.array(first_item.embedding)
def classify_text(text: str):
    """Classify ``text`` and return a Gradio update that shows the results table.

    The text is embedded with ``get_embedding``, the embedding is passed to
    the loaded ``model`` as a single-row batch, and the first row of the
    model output is rounded to four decimals. Those values are shown next to
    a 0/1 prediction obtained by thresholding at 0.5.

    Args:
        text: The text entered by the user.

    Returns:
        A ``gr.update(...)`` payload that sets the output component's value
        to a DataFrame with columns ``Label``, ``Probability`` and
        ``Prediction`` and makes the component visible.

    Raises:
        gr.Error: If ``text`` is empty (or ``None``). The embeddings endpoint
            rejects empty input, so the request is refused up front with a
            message Gradio can display instead of an opaque backend failure.
    """
    if not text:
        raise gr.Error("Please enter some text to classify.")

    embedding = get_embedding(text)
    # The model is called with a 2-D batch; reshape the single embedding to
    # one row and let NumPy infer the number of columns.
    X = embedding.reshape(1, -1)
    # NOTE(review): the output of ``predict`` is treated as per-label
    # probabilities here -- confirm that the stored model's ``predict``
    # returns scores rather than hard class labels.
    probabilities = model.predict(X)
    rounded_probs = np.round(probabilities[0], 4)

    # Build a Results object only so the values go through the instrumented
    # Pydantic model (logging side effect); the instance is not used further.
    # The positional mapping assumes the model emits the six scores in this
    # exact order -- TODO confirm against ``label_names``.
    Results(
        text=text,
        hateful=rounded_probs[0],
        insults=rounded_probs[1],
        sexual=rounded_probs[2],
        violence=rounded_probs[3],
        self_harm=rounded_probs[4],
        aom=rounded_probs[5],
    )

    # One row per label: rounded score plus a 0/1 flag for scores above 0.5.
    df = pd.DataFrame(
        {
            "Label": label_names,
            "Probability": rounded_probs,
            "Prediction": (rounded_probs > 0.5).astype(int),
        }
    )
    return gr.update(value=df, visible=True)
# UI layout: a multi-line text input, a submit button, and a results table
# that starts hidden and is revealed by classify_text()'s return value.
with gr.Blocks(title="Zoo Entry 001") as iface:
    input_text = gr.Textbox(
        lines=5,
        label="Input Text",
    )
    submit_btn = gr.Button("Submit")
    output_table = gr.DataFrame(
        label="Classification Results",
        visible=False,
    )

    # Clicking the button sends the textbox content to classify_text and
    # routes its result to the table.
    submit_btn.click(
        fn=classify_text,
        inputs=input_text,
        outputs=output_table,
    )

# Start the app only when this file is executed directly, not on import.
if __name__ == "__main__":
    iface.launch()