zainulabedin949's picture
Update app.py
309ef52 verified
raw
history blame
3.44 kB
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from io import StringIO
from momentfm import MOMENTPipeline
from datetime import datetime
# Load the pretrained MOMENT-1-large checkpoint once, at import time, so that
# every call to detect_anomalies reuses this single model instance.
# NOTE(review): momentfm's published examples configure anomaly detection via
# task_name="reconstruction" — confirm "anomaly_detection" is a task name the
# installed momentfm version accepts.
model = MOMENTPipeline.from_pretrained(
    "AutonLab/MOMENT-1-large",
    model_kwargs=dict(task_name="anomaly_detection"),
)
# Finalize the pipeline for the configured task before first use.
model.init()
def _load_series(data_input):
    """Parse CSV text into a timestamp-sorted DataFrame and its float32 values.

    Raises:
        ValueError: with a user-facing message when the 'timestamp' or 'value'
            column is missing, when there are no data rows, or when any value
            is missing / non-numeric / non-finite.
    """
    df = pd.read_csv(StringIO(data_input))
    if 'timestamp' not in df.columns or 'value' not in df.columns:
        raise ValueError("CSV must contain 'timestamp' and 'value' columns")
    df['timestamp'] = pd.to_datetime(df['timestamp'])
    df = df.sort_values('timestamp')
    if df.empty:
        # Without this, np.max on the empty error array fails with an
        # unhelpful "zero-size array" message.
        raise ValueError("CSV contains no data rows")
    # Coerce blanks and stray text to NaN so they are rejected here, instead of
    # silently yielding a NaN threshold that flags nothing.
    values = pd.to_numeric(df['value'], errors='coerce').to_numpy(dtype=np.float32)
    if not np.isfinite(values).all():
        raise ValueError("'value' column must contain only finite numbers")
    return df, values


def _build_plot(df, threshold_value):
    """Plot the series with flagged points highlighted in red; return the Figure."""
    # Build the Figure directly instead of via pyplot: pyplot registers every
    # figure it creates until plt.close() is called, so a long-running server
    # would accumulate one per request, and that registry is shared global
    # state across concurrent requests.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 5))
    ax = fig.subplots()
    ax.plot(df['timestamp'], df['value'], 'b-', label='Value')
    flagged = df['is_anomaly']
    ax.scatter(
        df.loc[flagged, 'timestamp'],
        df.loc[flagged, 'value'],
        color='red', s=100, label='Anomaly',
    )
    ax.set_title(f'Anomaly Detection (Threshold: {threshold_value:.2f})')
    ax.legend()
    return fig


def detect_anomalies(data_input: str, threshold: float = 3.0):
    """Flag anomalous points in a CSV time series using MOMENT reconstruction error.

    Args:
        data_input: CSV text containing at least 'timestamp' and 'value' columns.
        threshold: Number of standard deviations above the mean absolute
            reconstruction error at which a point is flagged as anomalous.

    Returns:
        A ``(figure, stats, records)`` tuple for the three Gradio outputs:
        the plot, a summary dict (data_points, anomalies, threshold, max_score),
        and the sorted rows with added 'anomaly_score' and 'is_anomaly' columns.
        On any failure, returns ``(None, {"error": message}, None)`` instead of
        raising, so the message is shown in the UI.
    """
    try:
        df, values = _load_series(data_input)

        # Reshape to (1, n_points, 1).
        # NOTE(review): momentfm examples pass torch tensors shaped
        # (batch, n_channels, seq_len) with a fixed context length — confirm
        # that this call and layout match the installed momentfm API.
        values_3d = values.reshape(1, -1, 1)
        reconstruction = model.reconstruct(X=values_3d)
        # Assumes the result indexes like the input (1, n_points, 1) — TODO confirm.
        errors = np.abs(values - reconstruction[0, :, 0])

        # Flag points whose error exceeds mean + threshold * std of all errors.
        threshold_value = np.mean(errors) + threshold * np.std(errors)
        df['anomaly_score'] = errors
        df['is_anomaly'] = errors > threshold_value

        fig = _build_plot(df, threshold_value)
        stats = {
            "data_points": len(df),
            "anomalies": int(df['is_anomaly'].sum()),
            "threshold": float(threshold_value),
            "max_score": float(np.max(errors)),
        }
        return fig, stats, df.to_dict('records')
    except Exception as e:
        # UI boundary: surface the problem to the user rather than fail the request.
        return None, {"error": str(e)}, None
# Gradio UI: CSV text and threshold slider on the left; plot, summary stats
# and per-row results on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🚨 Time-Series Anomaly Detection")
    with gr.Row():
        # Inputs column.
        with gr.Column():
            csv_box = gr.Textbox(
                label="Paste CSV Data",
                value="""timestamp,value
2025-04-01 00:00:00,100
2025-04-01 01:00:00,102
2025-04-01 02:00:00,98
2025-04-01 03:00:00,105
2025-04-01 04:00:00,103
2025-04-01 05:00:00,107
2025-04-01 06:00:00,200
2025-04-01 07:00:00,108
2025-04-01 08:00:00,110
2025-04-01 09:00:00,98
2025-04-01 10:00:00,99
2025-04-01 11:00:00,102
2025-04-01 12:00:00,101""",
                lines=15,
            )
            threshold_slider = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=3.0,
                step=0.1,
                label="Z-Score Threshold",
            )
            analyze_button = gr.Button("Analyze", variant="primary")
        # Outputs column, in the same order as detect_anomalies' return tuple.
        with gr.Column():
            plot_view = gr.Plot()
            stats_view = gr.JSON()
            records_view = gr.JSON()
    analyze_button.click(
        fn=detect_anomalies,
        inputs=[csv_box, threshold_slider],
        outputs=[plot_view, stats_view, records_view],
    )

if __name__ == "__main__":
    demo.launch()