# Author: nabil-tazi
# Last commit: 6fa680b "Update app.py" (verified)
import gradio as gr
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch
import logging
# Module logger; its records propagate to the root handler configured next.
logger = logging.getLogger(__name__)
# Root logging at INFO so model-loading and classification messages are emitted.
logging.basicConfig(level=logging.INFO)

# Shared SentenceTransformer instance; stays None until load_model() fills it in.
model = None
def load_model():
    """Return the shared sentence-transformer model, loading it on first use.

    The instance is cached in the module-level ``model`` global, so only the
    first successful call constructs it; later calls return the same object.

    Returns:
        The ``SentenceTransformer`` for 'nabil-tazi/autotrain-d19rl-a8u4f'.

    Raises:
        Exception: whatever ``SentenceTransformer(...)`` raises. It is logged
        together with its traceback and re-raised unchanged. ``model`` is left
        as ``None`` in that case, so the next call tries to load again.
    """
    global model
    if model is not None:
        return model

    logger.info("Loading sentence transformer model...")
    try:
        model = SentenceTransformer('nabil-tazi/autotrain-d19rl-a8u4f')
    except Exception as e:
        # logger.exception records the traceback; a bare `raise` re-raises the
        # original exception without adding an extra `raise e` frame.
        logger.exception("Failed to load model: %s", e)
        raise
    logger.info("Model loaded successfully!")
    return model
def _confidence_bar(confidence, width=20):
    """Render a similarity score as a fixed-width Markdown text bar plus a percentage.

    Cosine similarity can be negative (or exceed 1.0 by float noise), so the
    number of filled cells is clamped to [0, width]. The bar therefore always
    has exactly ``width`` cells. The percentage still shows the raw value.
    """
    filled = min(max(int(confidence * width), 0), width)
    return f"**Confidence Level:** {'█' * filled}{'░' * (width - filled)} {confidence:.1%}"


def classify_ambiance(user_input):
    """Classify a free-text lighting preference as "bright", "cozy" or "natural".

    The stripped input and the three label words are embedded with the model
    returned by ``load_model()``. The label with the highest cosine similarity
    to the input wins.

    Args:
        user_input: Text entered by the user; ``None`` or blank is rejected.

    Returns:
        A 3-tuple ``(result_markdown, scores, confidence_bar_markdown)``.
        ``scores`` maps each label to its similarity rounded to 4 decimals.
        On blank input, or if anything in the pipeline raises, it returns an
        error message, ``{}`` and ``""`` instead of raising, so the UI always
        has something to display.
    """
    if not user_input or not user_input.strip():
        return "❌ Please enter some text", {}, ""
    text = user_input.strip()

    try:
        current_model = load_model()

        # The three reference ambiances the input is compared against.
        references = ["bright", "cozy", "natural"]
        user_embedding = current_model.encode([text])
        ref_embeddings = current_model.encode(references)

        # One query row vs. len(references) columns -> take the single row.
        similarities = cos_sim(user_embedding, ref_embeddings)[0]

        # int() so later indexing uses a plain Python int, not a 0-d tensor.
        best_idx = int(similarities.argmax())
        best_ambiance = references[best_idx]
        confidence = float(similarities[best_idx])

        # Per-label scores, shown in the UI for debugging.
        all_scores = {ref: round(float(sim), 4) for ref, sim in zip(references, similarities)}

        emoji_map = {"bright": "☀️", "cozy": "🕯️", "natural": "🌿"}
        result_text = (
            f"## {emoji_map.get(best_ambiance, '💡')} **{best_ambiance.upper()}**\n"
            f"**Confidence:** {confidence:.3f}"
        )
        confidence_bar = _confidence_bar(confidence)

        logger.info("Classified %r as %r with confidence %.3f", text, best_ambiance, confidence)
        return result_text, all_scores, confidence_bar
    except Exception as e:
        # UI boundary: report the failure to the user, keep the traceback in the log.
        logger.exception("Classification error: %s", e)
        return f"❌ Error: {e}", {}, ""
# Create Gradio interface.
# This file must be saved as UTF-8: labels use emoji, and several example
# inputs are Japanese text that is passed verbatim to the model.

# Text shown in the result panel before the first classification and after Clear.
_PLACEHOLDER_RESULT = "Enter text and click classify!"

with gr.Blocks(
    title="🏠 Lighting Ambiance Classifier",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 800px !important;
        margin: auto !important;
    }
    .result-box {
        background: linear-gradient(45deg, #f0f0f0, #ffffff);
        border-radius: 10px;
        padding: 20px;
    }
    """,
) as demo:
    # Header. The blank lines matter: without them Markdown folds the last
    # line into the final list item.
    gr.Markdown(
        """
        # 🏠 Lighting Ambiance Classifier

        **Classify your lighting preferences into three categories:**

        - ☀️ **Bright**: Well-lit, luminous, clear lighting
        - 🕯️ **Cozy**: Warm, dim, soft, ambient lighting
        - 🌿 **Natural**: Daylight, sunlight, organic lighting

        **Supports both English and Japanese!** 🇺🇸🇯🇵
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            # Input section
            gr.Markdown("### 💬 Enter your lighting preference:")
            input_text = gr.Textbox(
                label="Your lighting preference",
                placeholder="e.g., 'not bright', '明るくない', 'cozy lighting', '自然な光が欲しい'",
                lines=3,
                max_lines=5,
            )
            with gr.Row():
                submit_btn = gr.Button("🔍 Classify", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        with gr.Column(scale=2):
            # Output section
            gr.Markdown("### 🎯 Classification Result:")
            result = gr.Markdown(value=_PLACEHOLDER_RESULT, elem_classes=["result-box"])
            confidence_bar = gr.Markdown(value="")

    # Detailed scores
    with gr.Row():
        scores = gr.JSON(label="📊 Detailed Similarity Scores", visible=True)

    # Example inputs (clicking one fills the textbox)
    gr.Markdown("### 💡 Try these examples:")
    with gr.Row():
        examples = gr.Examples(
            examples=[
                ["not bright"],
                ["明るくない"],
                ["I want cozy lighting"],
                ["自然な光が欲しい"],
                ["make it brighter"],
                ["暗くしたい"],
                ["romantic atmosphere"],
                ["作業しやすい明るさ"],
                ["candle light"],
                ["太陽光みたい"],
                ["harsh fluorescent"],
                ["優しい照明"],
            ],
            inputs=input_text,
            examples_per_page=6,
        )

    # Footer. The blank line keeps "Model" and "How it works" as separate paragraphs.
    gr.Markdown(
        """
        ---
        **Model:** Fine-tuned multilingual sentence transformer trained on English-Japanese lighting preference pairs.

        **How it works:** The model compares your input text with the three ambiance categories and returns the most similar one with a confidence score.
        """
    )

    # Event handlers
    def clear_all():
        """Reset the textbox, result panel, score table and confidence bar."""
        return "", _PLACEHOLDER_RESULT, {}, ""

    # Button click and Enter in the textbox run the same classification.
    classify_outputs = [result, scores, confidence_bar]
    submit_btn.click(fn=classify_ambiance, inputs=input_text, outputs=classify_outputs)
    input_text.submit(fn=classify_ambiance, inputs=input_text, outputs=classify_outputs)
    clear_btn.click(fn=clear_all, outputs=[input_text, result, scores, confidence_bar])
# Script entry point. Serves the UI on all IPv4 interfaces, port 7860, and asks
# Gradio to display handler errors to the client (show_error=True).
if __name__ == "__main__":
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_kwargs)