import os

import gradio as gr
import joblib
import numpy as np
import pandas as pd
from openai import OpenAI
from huggingface_hub import login
from huggingface_hub import hf_hub_download
import logfire

# Configure Logfire with the token taken from the LOGFIRE_API_KEY environment variable.
logfire.configure(token=os.getenv("LOGFIRE_API_KEY"))

# Score above which a label is reported as predicted (Prediction == 1).
PREDICTION_THRESHOLD = 0.5

# Load your pre-trained model and label names.
# use_auth_token=True makes hf_hub_download send the locally stored Hugging Face
# token, which is what allows the download if the repo is private or gated.
model_path = hf_hub_download(
    repo_id="govtech/zoo-entry-001",
    filename="model.joblib",
    use_auth_token=True,
)
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load artifacts from a repository you trust.
model_data = joblib.load(model_path)
model = model_data['model']
label_names = model_data['label_names']

# Initialize OpenAI client (credentials are taken from the environment).
client = OpenAI()


def get_embedding(text, embedding_model="text-embedding-3-large"):
    """
    Get the embedding vector for `text` from the OpenAI embeddings API.

    Newlines in the text are replaced with spaces before the request is sent.

    Args:
        text: Input string to embed.
        embedding_model: Name of the OpenAI embedding model to request.

    Returns:
        A 1-D numpy array holding the embedding vector.
    """
    text = text.replace("\n", " ")
    response = client.embeddings.create(
        input=[text],
        model=embedding_model,
    )
    # Exactly one input was sent, so take the first (only) embedding.
    return np.array(response.data[0].embedding)


def classify_text(text):
    """
    Classify `text` and return a Gradio update that shows the results table.

    The text is embedded with `get_embedding`, scored with the loaded model,
    and summarised as one row per label with its score ("Probability") and a
    binary "Prediction" (1 when the score exceeds PREDICTION_THRESHOLD).

    Args:
        text: Text entered by the user in the input textbox.

    Returns:
        A `gr.update` that sets the DataFrame component's value to the results
        and makes it visible.

    Raises:
        gr.Error: If `text` is empty or contains only whitespace.
    """
    if not text or not text.strip():
        # Nothing to classify: fail fast with a readable message instead of
        # sending an empty request to the embeddings API.
        raise gr.Error("Please enter some text to classify.")

    embedding = get_embedding(text)
    # Add a batch dimension: shape (n_features,) -> (1, n_features).
    X = embedding[None, :]

    # NOTE(review): the output of `predict` is treated as per-label
    # probabilities (first row = this single sample). For a standard
    # scikit-learn classifier `predict` returns labels and `predict_proba`
    # returns probabilities — confirm which this model provides.
    # np.asarray keeps the thresholding below working even if a list is returned.
    scores = np.asarray(model.predict(X))[0]

    # One row per label with its score and the thresholded binary prediction.
    df = pd.DataFrame({
        'Label': label_names,
        'Probability': scores,
        'Prediction': (scores > PREDICTION_THRESHOLD).astype(int),
    })

    # Use a constant message template with structured attributes so braces in
    # user-supplied text are never interpreted as template placeholders.
    logfire.info(
        "{text} ({probabilities})",
        text=text,
        probabilities=scores.tolist(),
    )
    # Return an update to the DataFrame component to make it visible with the results.
    return gr.update(value=df, visible=True)


with gr.Blocks(title="Zoo Entry 001") as iface:
    with gr.Row():
        input_text = gr.Textbox(lines=5, label="Input Text")
    with gr.Row():
        submit_btn = gr.Button("Submit")
    # Initialize the table as hidden; classify_text makes it visible.
    with gr.Row():
        output_table = gr.DataFrame(label="Classification Results", visible=False)
    submit_btn.click(fn=classify_text, inputs=input_text, outputs=output_table)


if __name__ == "__main__":
    iface.launch()