|
import os |
|
import gradio as gr |
|
import joblib |
|
import logfire |
|
import numpy as np |
|
import pandas as pd |
|
from openai import OpenAI |
|
from pydantic import BaseModel |
|
|
|
|
|
# Send traces to Logfire, authenticating with the token stored in the
# LOGFIRE_API_KEY environment variable (None when the variable is unset).
logfire.configure(token=os.environ.get("LOGFIRE_API_KEY"))
# Turn on Logfire's Pydantic integration so model validations (such as the
# Results objects built during classification) are recorded.
logfire.instrument_pydantic()
|
|
|
|
|
# Load the trained classifier bundle. It is a mapping holding the fitted model
# under "model" and the output label names under "label_names". The path is
# relative, so it resolves against the current working directory.
# NOTE(review): joblib files are pickle-based; only ever load trusted model files.
model_data = joblib.load("model.joblib")
model, label_names = model_data["model"], model_data["label_names"]
|
|
|
class Results(BaseModel):
    """Scores for one classified piece of text.

    classify_text fills the float fields, in declaration order, from positions
    0-5 of the model's rounded output row.
    """

    # The raw input text that was classified.
    text: str

    # Model score for output position 0.
    hateful: float

    # Model score for output position 1.
    insults: float

    # Model score for output position 2.
    sexual: float

    # Model score for output position 3.
    violence: float

    # Model score for output position 4.
    self_harm: float

    # Model score for output position 5.
    # NOTE(review): the abbreviation "aom" is not defined in this file —
    # confirm its meaning against the training labels.
    aom: float
|
|
|
|
|
# Shared OpenAI client used for every embedding request. Constructed with no
# arguments, so it presumably reads credentials from the environment
# (OPENAI_API_KEY) — confirm against the deployment configuration.
client = OpenAI()


def get_embedding(text: str, embedding_model: str = "text-embedding-3-large") -> np.ndarray:
    """
    Return the OpenAI embedding vector for ``text`` as a NumPy array.

    Every newline in ``text`` is turned into a single space before the
    request is sent. ``embedding_model`` selects the OpenAI embedding model.
    """
    # Collapse the text onto a single line (each "\n" becomes one space).
    flattened = " ".join(text.split("\n"))
    reply = client.embeddings.create(input=[flattened], model=embedding_model)
    # One input string was sent, so the vector of interest is the first entry.
    vector = reply.data[0].embedding
    return np.array(vector)
|
|
|
def classify_text(text: str):
    """
    Classify ``text`` and return a Gradio update that reveals the results table.

    Steps:
    1. Embed ``text`` with ``get_embedding``.
    2. Score the embedding with the module-level ``model``.
    3. Build a ``Results`` record from the rounded scores. The object is not
       returned or stored. It exists only for its side effect: with
       ``logfire.instrument_pydantic()`` enabled at import time, the
       validation is presumably recorded in Logfire (confirm).
    4. Build a DataFrame with one row per entry in ``label_names`` and columns:
       ``Label``; ``Probability`` (the score rounded to 4 decimals); and
       ``Prediction`` (1 if the *unrounded* score is > 0.5, else 0).

    Args:
        text: Raw text from the input box.

    Returns:
        ``gr.update(value=df, visible=True)`` for the output DataFrame component.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only. Gradio shows the
            message to the user instead of an empty string being sent to the
            embeddings API (which presumably rejects it).
    """
    # Fail fast with a readable, user-facing message rather than an API error.
    if not text or not text.strip():
        raise gr.Error("Please enter some text to classify.")

    embedding = get_embedding(text)
    # The model is fed a 2-D batch containing this single sample.
    X = embedding.reshape(1, -1)

    # NOTE(review): `predict` is assumed to return per-label scores in [0, 1],
    # one row per sample, with columns in the same order as `label_names` and
    # the Results fields below — confirm against how model.joblib was trained.
    probabilities = model.predict(X)
    raw_probs = np.asarray(probabilities[0])
    rounded_probs = np.round(raw_probs, 4)

    # Constructed only for its Logfire/Pydantic instrumentation side effect;
    # the instance is intentionally discarded.
    Results(
        text=text,
        hateful=rounded_probs[0],
        insults=rounded_probs[1],
        sexual=rounded_probs[2],
        violence=rounded_probs[3],
        self_harm=rounded_probs[4],
        aom=rounded_probs[5],
    )

    df = pd.DataFrame({
        "Label": label_names,
        "Probability": rounded_probs,
        # Threshold the unrounded scores. Rounding first would turn e.g.
        # 0.50004 into 0.5 and wrongly report it as a negative prediction.
        "Prediction": (raw_probs > 0.5).astype(int),
    })

    return gr.update(value=df, visible=True)
|
|
|
# Single-page UI: a multi-line text box, a submit button, and a results table
# that stays hidden until the first classification populates it.
with gr.Blocks(title="Zoo Entry 001") as iface:
    input_text = gr.Textbox(label="Input Text", lines=5)
    submit_btn = gr.Button(value="Submit")
    # Hidden at start; classify_text returns gr.update(..., visible=True).
    output_table = gr.DataFrame(visible=False, label="Classification Results")

    # Button press: text box value in, DataFrame update out.
    submit_btn.click(classify_text, input_text, output_table)


if __name__ == "__main__":
    # Launch the Gradio server only when executed as a script, not on import.
    iface.launch()
|
|