File size: 2,281 Bytes
2499f75 7246298 9eb8e79 7246298 9eb8e79 7246298 aa264ad 7246298 9eb8e79 7246298 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import os
import gradio as gr
import joblib
import numpy as np
import pandas as pd
from openai import OpenAI
from huggingface_hub import login
from huggingface_hub import hf_hub_download
import logfire
# Configure Logfire using the value of the LOGFIRE_API_KEY environment variable
# (os.getenv yields None when the variable is unset).
logfire.configure(token=os.getenv("LOGFIRE_API_KEY"))
# Load your pre-trained model and label names.
# Download model.joblib from the govtech/zoo-entry-001 Hugging Face repo.
# use_auth_token=True asks huggingface_hub to use the locally configured HF token.
# NOTE(review): use_auth_token is deprecated in newer huggingface_hub releases in
# favour of token=True — confirm against the pinned version before changing.
model_path = hf_hub_download(repo_id="govtech/zoo-entry-001", filename="model.joblib", use_auth_token=True)
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load artifacts from a repository you trust.
model_data = joblib.load(model_path)
# The loaded artifact is indexed by string keys: the model object used for
# prediction and the label names shown in the results table.
model = model_data['model']
label_names = model_data['label_names']
# Initialize OpenAI client
# OpenAI() is constructed without explicit credentials, so it relies on its
# environment configuration (presumably OPENAI_API_KEY — confirm in deployment).
client = OpenAI()
def get_embedding(text, embedding_model="text-embedding-3-large"):
    """Return the OpenAI embedding of ``text`` as a NumPy array.

    Every newline in ``text`` is replaced by a single space before the
    request is made. The text is sent as a one-element batch.

    Args:
        text: The string to embed.
        embedding_model: Name of the OpenAI embedding model to request.

    Returns:
        numpy.ndarray: The embedding vector for ``text``.
    """
    flattened = text.replace("\n", " ")
    result = client.embeddings.create(model=embedding_model, input=[flattened])
    # Only one input was sent, so its vector is the first (and only) entry.
    return np.array(result.data[0].embedding)
def classify_text(text):
    """Classify ``text`` and return a Gradio update that shows the results table.

    The text is embedded with ``get_embedding`` and passed to the pre-loaded
    ``model`` as a batch of one. Each per-label output is tabulated next to
    ``label_names``. An output strictly greater than 0.5 is reported as a
    positive prediction (1); anything else is reported as 0.

    Args:
        text: The user's input from the Gradio textbox.

    Returns:
        A ``gr.update`` that fills the results DataFrame and makes it visible.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only. The OpenAI
            embeddings endpoint rejects empty input, so fail early with a
            readable message instead of surfacing an opaque API error.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to classify.")

    embedding = get_embedding(text)
    # The model expects a 2-D batch: add a leading batch axis of size 1.
    X = np.asarray(embedding)[None, :]
    # NOTE(review): outputs are treated as per-label probabilities in [0, 1];
    # confirm model.predict (rather than predict_proba) returns probabilities.
    probabilities = model.predict(X)
    scores = probabilities[0]

    # One row per label: the raw score and its thresholded 0/1 prediction.
    df = pd.DataFrame({
        'Label': label_names,
        'Probability': scores,
        'Prediction': (scores > 0.5).astype(int),
    })

    # Pass user text as a structured attribute instead of baking it into the
    # message template, where literal braces would be read as placeholders.
    logfire.info("{text} ({probabilities})", text=text, probabilities=scores)
    # Return an update to the DataFrame component to make it visible with the results.
    return gr.update(value=df, visible=True)
# Build the UI: a text input, a submit button, and a results table that stays
# hidden until the first classification.
with gr.Blocks(title="Zoo Entry 001") as iface:
    with gr.Row():
        input_text = gr.Textbox(lines=5, label="Input Text")
    with gr.Row():
        submit_btn = gr.Button("Submit")
    # Initialize the table as hidden; classify_text returns an update with
    # visible=True once results are available.
    with gr.Row():
        output_table = gr.DataFrame(label="Classification Results", visible=False)
    # Clicking Submit passes the textbox value to classify_text and applies the
    # returned update to the results table.
    submit_btn.click(fn=classify_text, inputs=input_text, outputs=output_table)
# Launch the app only when this file is executed directly, not on import.
if __name__ == "__main__":
    iface.launch()
|