# Gradio app: time-series anomaly detection via MOMENT reconstruction error.
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from io import StringIO
from momentfm import MOMENTPipeline
from datetime import datetime
# Load the pretrained MOMENT pipeline configured for anomaly detection.
# NOTE(review): loading happens at import time — the app blocks until the
# model weights are downloaded/loaded; verify this is acceptable at startup.
model = MOMENTPipeline.from_pretrained(
    "AutonLab/MOMENT-1-large",
    model_kwargs={"task_name": "anomaly_detection"},  # task selects the reconstruction head
)
model.init()  # finalize model setup (required by momentfm before inference)
def detect_anomalies(data_input, threshold=3.0):
    """Detect anomalies in a univariate time series via MOMENT reconstruction error.

    Parameters
    ----------
    data_input : str
        CSV text containing 'timestamp' and 'value' columns.
    threshold : float
        Z-score multiplier used to derive the dynamic anomaly threshold.

    Returns
    -------
    tuple
        (matplotlib Figure, stats dict, list of JSON-serializable row dicts);
        on any failure, (None, {"error": message}, None).
    """
    try:
        df = pd.read_csv(StringIO(data_input))

        # Validate input before any processing so the user gets a clear message.
        if 'timestamp' not in df.columns or 'value' not in df.columns:
            return None, {"error": "CSV must contain 'timestamp' and 'value' columns"}, None
        if df.empty:
            # Guard: an empty frame would otherwise reach the model and fail opaquely.
            return None, {"error": "CSV contains no data rows"}, None

        df['timestamp'] = pd.to_datetime(df['timestamp'])
        # reset_index so the records output is not carrying a stale pre-sort index.
        df = df.sort_values('timestamp').reset_index(drop=True)

        # MOMENT expects a 3D float array: [samples, timesteps, features].
        values = df['value'].values.astype(np.float32)
        values_3d = values.reshape(1, -1, 1)

        reconstruction = model.reconstruct(X=values_3d)

        # Per-point reconstruction error, flattened back to 1D.
        errors = np.abs(values - reconstruction[0, :, 0])

        # Dynamic threshold: mean + z-score * std of the reconstruction errors.
        threshold_value = np.mean(errors) + threshold * np.std(errors)
        df['anomaly_score'] = errors
        df['is_anomaly'] = errors > threshold_value

        fig, ax = plt.subplots(figsize=(12, 5))
        ax.plot(df['timestamp'], df['value'], 'b-', label='Value')
        ax.scatter(
            df.loc[df['is_anomaly'], 'timestamp'],
            df.loc[df['is_anomaly'], 'value'],
            color='red', s=100, label='Anomaly'
        )
        ax.set_title(f'Anomaly Detection (Threshold: {threshold_value:.2f})')
        ax.legend()

        stats = {
            "data_points": len(df),
            "anomalies": int(df['is_anomaly'].sum()),
            "threshold": float(threshold_value),
            "max_score": float(np.max(errors))
        }

        # Convert numpy/pandas scalar types to plain Python so gr.JSON can
        # serialize the records (Timestamp/float32/bool_ are not JSON-safe).
        records_df = df.copy()
        records_df['timestamp'] = records_df['timestamp'].dt.strftime('%Y-%m-%d %H:%M:%S')
        records_df['value'] = records_df['value'].astype(float)
        records_df['anomaly_score'] = records_df['anomaly_score'].astype(float)
        records_df['is_anomaly'] = records_df['is_anomaly'].astype(bool)
        return fig, stats, records_df.to_dict('records')
    except Exception as e:
        # Top-level boundary for the Gradio callback: surface the error in the UI
        # rather than crashing the interface.
        return None, {"error": str(e)}, None
# ---- Gradio interface ----------------------------------------------------
# Sample data shown in the input box on load; one deliberate outlier at 06:00.
_DEFAULT_CSV = """timestamp,value
2025-04-01 00:00:00,100
2025-04-01 01:00:00,102
2025-04-01 02:00:00,98
2025-04-01 03:00:00,105
2025-04-01 04:00:00,103
2025-04-01 05:00:00,107
2025-04-01 06:00:00,200
2025-04-01 07:00:00,108
2025-04-01 08:00:00,110
2025-04-01 09:00:00,98
2025-04-01 10:00:00,99
2025-04-01 11:00:00,102
2025-04-01 12:00:00,101"""

with gr.Blocks() as demo:
    gr.Markdown("# 🚨 Time-Series Anomaly Detection")
    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            csv_box = gr.Textbox(label="Paste CSV Data", value=_DEFAULT_CSV, lines=15)
            z_slider = gr.Slider(1.0, 5.0, value=3.0, step=0.1, label="Z-Score Threshold")
            analyze_btn = gr.Button("Analyze", variant="primary")
        # Right column: plot, summary stats, and the raw per-row results.
        with gr.Column():
            plot_out = gr.Plot()
            stats_out = gr.JSON()
            rows_out = gr.JSON()
    analyze_btn.click(
        detect_anomalies,
        inputs=[csv_box, z_slider],
        outputs=[plot_out, stats_out, rows_out]
    )
if __name__ == "__main__":
demo.launch()