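# Hugging Face Hub file-view header (kept as comments so the module parses):
# author: zainulabedin949 | commit message: "Update app.py" | commit 6646cc2 (verified) | size 3.26 kB
# (page links: raw · history · blame)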
from io import StringIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from momentfm import MOMENTPipeline
# Initialize model: the pretrained MOMENT-1-large checkpoint with its
# reconstruction head, loaded once at import time and shared by every request.
model = MOMENTPipeline.from_pretrained(
    "AutonLab/MOMENT-1-large",
    model_kwargs={"task_name": "reconstruction"},
)
model.init()
# The model is only used for inference here; eval mode disables dropout so the
# same input always yields the same reconstruction (as in MOMENT's tutorials).
model.eval()
# MOMENT consumes fixed-length windows of 512 time steps (its context length).
_MOMENT_SEQ_LEN = 512


def _error_result(message):
    """Return the (plot, stats, records) triple used when a request fails."""
    # The handler feeds three Gradio outputs (Plot, JSON, JSON). gr.Plot cannot
    # render a bare string, so the message is surfaced in the statistics panel.
    return None, {"error": message}, None


def _moment_reconstruct(values, seq_len=_MOMENT_SEQ_LEN):
    """Return MOMENT's reconstruction of a 1-D series, aligned point-for-point.

    The series is left-padded with zeros to a multiple of ``seq_len`` and cut
    into consecutive windows of shape ``(n_windows, 1, seq_len)``. Padded
    positions are marked 0 in ``input_mask`` so the model ignores them, and
    they are stripped from the returned array.

    Args:
        values: 1-D float array of observations (non-empty).
        seq_len: Window length expected by the model.

    Returns:
        1-D float array with the same length as ``values``.
    """
    n_points = len(values)
    n_windows = -(-n_points // seq_len)  # ceiling division
    pad = n_windows * seq_len - n_points

    series = np.zeros(n_windows * seq_len, dtype=np.float32)
    series[pad:] = values
    observed = np.zeros(n_windows * seq_len, dtype=np.int64)
    observed[pad:] = 1

    # Send inputs to wherever the model's weights live (CPU unless moved).
    device = next(model.parameters()).device
    x_enc = torch.from_numpy(series).reshape(n_windows, 1, seq_len).to(device)
    input_mask = torch.from_numpy(observed).reshape(n_windows, seq_len).to(device)

    with torch.no_grad():
        output = model(x_enc=x_enc, input_mask=input_mask)

    reconstruction = output.reconstruction.detach().cpu().numpy().reshape(-1)
    return reconstruction[pad:].astype(float)


def _plot_anomalies(df):
    """Plot ``value`` over ``timestamp`` with rows flagged ``is_anomaly`` in red."""
    # Build the figure without pyplot: pyplot keeps every figure it creates in
    # a global registry, which leaks memory in a long-running web server.
    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    ax.plot(df['timestamp'], df['value'], label='Value', color='blue')
    flagged = df[df['is_anomaly']]
    ax.scatter(flagged['timestamp'], flagged['value'], color='red', label='Anomaly')
    ax.set_title('Sensor Data with Anomalies')
    ax.legend()
    return fig


def detect_anomalies(data_input, threshold=0.1):
    """Flag points of a CSV time series that MOMENT reconstructs poorly.

    Args:
        data_input: CSV text with ``timestamp`` and ``value`` columns.
        threshold: Fraction of the largest reconstruction error above which a
            point is flagged (relative threshold, e.g. 0.1 = 10% of the max).

    Returns:
        ``(figure, stats, records)`` on success. ``stats`` is a dict with
        ``total_points``, ``anomalies_detected``, ``max_anomaly_score`` and
        ``threshold_used``. ``records`` is one dict per row, sorted by
        timestamp, with ISO-like string timestamps. On failure, returns
        ``(None, {"error": message}, None)``.
    """
    if not isinstance(data_input, str):
        return _error_result("Error: Please provide CSV data")
    try:
        df = pd.read_csv(StringIO(data_input))
        if 'timestamp' not in df.columns or 'value' not in df.columns:
            return _error_result("Error: CSV must contain 'timestamp' and 'value' columns")
        if df.empty:
            return _error_result("Error: CSV contains no data rows")

        df['timestamp'] = pd.to_datetime(df['timestamp'])
        df['value'] = pd.to_numeric(df['value'], errors='coerce')
        # Missing or non-numeric cells would otherwise reach the model as NaN.
        if df['value'].isna().any():
            return _error_result("Error: 'value' column must be numeric with no missing entries")
        df = df.sort_values('timestamp', kind='stable').reset_index(drop=True)
        values = df['value'].to_numpy(dtype=float)

        reconstruction = _moment_reconstruct(values)
        errors = np.abs(values - reconstruction)
        max_error = float(errors.max())
        threshold_value = threshold * max_error
        df['anomaly_score'] = errors
        df['is_anomaly'] = errors > threshold_value

        fig = _plot_anomalies(df)
        stats = {
            "total_points": len(df),
            "anomalies_detected": int(df['is_anomaly'].sum()),
            "max_anomaly_score": max_error,
            "threshold_used": float(threshold_value),
        }
        # Stringify timestamps so the records are plain JSON-serialisable values.
        records = df.assign(timestamp=df['timestamp'].astype(str)).to_dict('records')
        return fig, stats, records
    except Exception as e:  # UI boundary: report any failure to the user.
        return _error_result(f"Error: {e}")
# Example series pre-filled in the input box; it contains an obvious spike
# (200 at 06:00) against values that otherwise stay between 98 and 110.
_SAMPLE_CSV = """timestamp,value
2025-04-01 00:00:00,100
2025-04-01 01:00:00,102
2025-04-01 02:00:00,98
2025-04-01 03:00:00,105
2025-04-01 04:00:00,103
2025-04-01 05:00:00,107
2025-04-01 06:00:00,200
2025-04-01 07:00:00,108
2025-04-01 08:00:00,110
2025-04-01 09:00:00,98
2025-04-01 10:00:00,99
2025-04-01 11:00:00,102
2025-04-01 12:00:00,101"""

# Gradio interface: inputs on the left, plot and JSON results on the right.
with gr.Blocks() as demo:
    gr.Markdown("## 🛠️ Equipment Anomaly Detection")
    with gr.Row():
        with gr.Column():
            csv_input = gr.Textbox(
                label="Paste CSV data (timestamp,value)",
                value=_SAMPLE_CSV,
                lines=10,
            )
            threshold_slider = gr.Slider(0.01, 0.5, value=0.1, label="Anomaly Threshold")
            submit_btn = gr.Button("Detect Anomalies")
        with gr.Column():
            plot_output = gr.Plot()
            stats_output = gr.JSON(label="Statistics")
            data_output = gr.JSON(label="Detailed Results")
    # The handler returns (figure, stats, records), matching these outputs in order.
    submit_btn.click(
        detect_anomalies,
        inputs=[csv_input, threshold_slider],
        outputs=[plot_output, stats_output, data_output],
    )

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()