import os
import gradio as gr
import joblib
import logfire
import numpy as np
import pandas as pd
from openai import OpenAI
from pydantic import BaseModel

# Configure Logfire (expects LOGFIRE_API_KEY in the environment) and instrument
# pydantic so that validations of the Results model below are logged automatically
logfire.configure(token=os.getenv("LOGFIRE_API_KEY"))
logfire.instrument_pydantic()

# Load pre-trained model and label names
model_data = joblib.load("model.joblib")
model = model_data["model"]
label_names = model_data["label_names"]
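# NOTE: model.joblib is assumed to have been saved at training time as a dict, e.g.
#   joblib.dump({"model": fitted_estimator, "label_names": [...]}, "model.joblib")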

class Results(BaseModel):
    text: str
    hateful: float
    insults: float
    sexual: float
    violence: float
    self_harm: float
    aom: float

# Initialize the OpenAI client (reads OPENAI_API_KEY from the environment)
client = OpenAI()

def get_embedding(text: str, embedding_model: str = "text-embedding-3-large") -> np.ndarray:
    """
    Get embedding for the input text from OpenAI.
    Replaces newlines with spaces before calling the API.
    """
    text = text.replace("\n", " ")
    response = client.embeddings.create(input=[text], model=embedding_model)
    embedding = response.data[0].embedding
    return np.array(embedding)

def classify_text(text: str):
    """
    Get the OpenAI embedding for the provided text, classify it with the loaded model,
    and return a DataFrame with the rounded probabilities and binary predictions.
    """
    embedding = get_embedding(text)
    X = embedding.reshape(1, -1)
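    # Assumes the saved estimator's predict() returns per-label scores/probabilities
    # (e.g. a multi-output regressor); a classifier would need predict_proba() instead.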
    probabilities = model.predict(X)
    rounded_probs = np.round(probabilities[0], 4)
    
    # Constructing the pydantic model triggers Logfire's pydantic instrumentation,
    # which logs the validated results; it does not affect the returned output
    Results(
        text=text,
        hateful=rounded_probs[0],
        insults=rounded_probs[1],
        sexual=rounded_probs[2],
        violence=rounded_probs[3],
        self_harm=rounded_probs[4],
        aom=rounded_probs[5],
    )
    
    # Create DataFrame with rounded probabilities and binary predictions
    df = pd.DataFrame({
        "Label": label_names,
        "Probability": rounded_probs,
        "Prediction": (rounded_probs > 0.5).astype(int)
    })
    
    return gr.update(value=df, visible=True)

with gr.Blocks(title="Zoo Entry 001") as iface:
    input_text = gr.Textbox(lines=5, label="Input Text")
    submit_btn = gr.Button("Submit")
    output_table = gr.DataFrame(label="Classification Results", visible=False)
    
    submit_btn.click(fn=classify_text, inputs=input_text, outputs=output_table)

if __name__ == "__main__":
    iface.launch()