import os

import gradio as gr
import joblib
import logfire
import numpy as np
import pandas as pd
from openai import OpenAI
from pydantic import BaseModel

# Configure logging: send traces to Logfire and instrument pydantic validation.
logfire.configure(token=os.getenv("LOGFIRE_API_KEY"))
logfire.instrument_pydantic()

# Load pre-trained model and label names.
# NOTE(review): joblib.load unpickles the file -- only ever load a trusted model.joblib.
model_data = joblib.load("model.joblib")
model = model_data["model"]
label_names = model_data["label_names"]

# Probability strictly above which a label is reported as a positive prediction.
PREDICTION_THRESHOLD = 0.5


class Results(BaseModel):
    """Per-category probabilities for one input text.

    Instances are built only so the validation is recorded by logfire's
    pydantic instrumentation; the object is not used otherwise.
    """

    text: str
    hateful: float
    insults: float
    sexual: float
    violence: float
    self_harm: float
    aom: float


# Initialize OpenAI client.
client = OpenAI()


def get_embedding(text: str, embedding_model: str = "text-embedding-3-large") -> np.ndarray:
    """Return the OpenAI embedding of *text* as a 1-D NumPy array.

    Newlines are replaced with spaces before calling the API.

    Args:
        text: Input text to embed.
        embedding_model: Name of the OpenAI embedding model to request.

    Returns:
        The embedding vector of the single input.
    """
    text = text.replace("\n", " ")
    response = client.embeddings.create(input=[text], model=embedding_model)
    return np.array(response.data[0].embedding)


def classify_text(text: str):
    """Classify *text* and return a Gradio update that shows the results table.

    The text is embedded with OpenAI and scored by the pre-trained model. The
    returned DataFrame has one row per label with the probability rounded to
    4 decimals and a binary prediction (1 if the unrounded probability exceeds
    ``PREDICTION_THRESHOLD``).

    Raises:
        gr.Error: If *text* is empty or whitespace-only.
    """
    if not text or not text.strip():
        # Nothing to classify: show a message instead of spending an API call.
        raise gr.Error("Please enter some text to classify.")

    X = get_embedding(text).reshape(1, -1)  # single-sample batch for the model
    # NOTE(review): assumes model.predict returns per-label probabilities with
    # shape (1, len(label_names)) -- confirm against the training code.
    probabilities = np.asarray(model.predict(X))[0]
    rounded_probs = np.round(probabilities, 4)

    # Optionally log the results: the validation of this model is recorded by
    # logfire.instrument_pydantic(); it does not affect the output.
    # NOTE(review): positional mapping assumes label_names is ordered
    # hateful, insults, sexual, violence, self_harm, aom -- confirm.
    Results(
        text=text,
        hateful=float(rounded_probs[0]),
        insults=float(rounded_probs[1]),
        sexual=float(rounded_probs[2]),
        violence=float(rounded_probs[3]),
        self_harm=float(rounded_probs[4]),
        aom=float(rounded_probs[5]),
    )

    df = pd.DataFrame({
        "Label": label_names,
        "Probability": rounded_probs,
        # Threshold the *unrounded* probabilities so display rounding cannot
        # flip a prediction (e.g. 0.50003 rounds to 0.5, which is not > 0.5).
        "Prediction": (probabilities > PREDICTION_THRESHOLD).astype(int),
    })
    return gr.update(value=df, visible=True)


with gr.Blocks(title="Zoo Entry 001") as iface:
    input_text = gr.Textbox(lines=5, label="Input Text")
    submit_btn = gr.Button("Submit")
    # Hidden until the first successful classification makes it visible.
    output_table = gr.DataFrame(label="Classification Results", visible=False)
    submit_btn.click(fn=classify_text, inputs=input_text, outputs=output_table)

if __name__ == "__main__":
    iface.launch()