zainulabedin949's picture
Update app.py
6f850cd verified
raw
history blame
4.68 kB
import gradio as gr
import numpy as np
import pandas as pd
from momentfm import MOMENTPipeline
import matplotlib.pyplot as plt
from io import StringIO
# Initialize the MOMENT model with its "reconstruction" task head; anomaly
# scores below are the gap between the input series and its reconstruction.
model = MOMENTPipeline.from_pretrained(
    "AutonLab/MOMENT-1-large",
    model_kwargs={"task_name": "reconstruction"},
)
model.init()
# The model is only used for inference: switch off training-time layer
# behaviour (e.g. dropout) so reconstructions -- and therefore anomaly
# scores -- are deterministic across requests.
model.eval()
# MOMENT-1 models consume fixed-length windows of 512 time steps.
_MOMENT_SEQ_LEN = 512


def _parse_input(data_input):
    """Convert pasted CSV/JSON text (or a dict of columns) into a DataFrame.

    Raises:
        gr.Error: if the input type is unsupported or the text parses as
            neither CSV nor JSON.
    """
    if isinstance(data_input, dict):
        return pd.DataFrame(data_input)
    if not isinstance(data_input, str):
        raise gr.Error("Error: Unsupported input format")

    text = data_input.strip()
    # read_csv rarely raises on JSON text -- it just splits it on commas into
    # junk columns -- so JSON must be tried first when the text looks like it.
    if text.startswith(("{", "[")):
        readers = (pd.read_json, pd.read_csv)
    else:
        readers = (pd.read_csv, pd.read_json)
    for reader in readers:
        try:
            return reader(StringIO(text))
        except ValueError:  # also covers pandas ParserError / EmptyDataError
            continue
    raise gr.Error(
        "Error: Could not parse input data. Please provide valid CSV or JSON."
    )


def _reconstruct(series):
    """Return the MOMENT reconstruction of a 1-D float array (same length).

    The series is cut into consecutive windows of ``_MOMENT_SEQ_LEN`` steps.
    A window shorter than that (the last one, or the only one for a short
    series) is left-padded with zeros, and the padded positions are marked
    unobserved (0) in ``input_mask``.
    """
    import torch  # installed as a dependency of momentfm

    device = next(model.parameters()).device
    recon = np.empty(len(series), dtype=float)
    for start in range(0, len(series), _MOMENT_SEQ_LEN):
        chunk = series[start:start + _MOMENT_SEQ_LEN]
        pad = _MOMENT_SEQ_LEN - len(chunk)
        # x_enc is built as (batch=1, channels=1, seq_len);
        # input_mask as (batch=1, seq_len) with 1 = observed, 0 = padding.
        x_enc = torch.zeros(
            (1, 1, _MOMENT_SEQ_LEN), dtype=torch.float32, device=device
        )
        x_enc[0, 0, pad:] = torch.as_tensor(
            chunk, dtype=torch.float32, device=device
        )
        input_mask = torch.zeros((1, _MOMENT_SEQ_LEN), device=device)
        input_mask[0, pad:] = 1
        with torch.no_grad():
            output = model(x_enc=x_enc, input_mask=input_mask)
        recon[start:start + len(chunk)] = (
            output.reconstruction[0, 0, pad:].cpu().numpy()
        )
    return recon


def _plot_anomalies(df):
    """Return a Figure of df['value'] over time with anomalies in red."""
    fig, ax = plt.subplots(figsize=(12, 6))
    anomalies = df[df["is_anomaly"]]
    ax.plot(df["timestamp"], df["value"], label="Original", color="blue")
    ax.scatter(
        anomalies["timestamp"],
        anomalies["value"],
        color="red",
        label="Anomaly",
    )
    ax.set_title("Time Series with Anomalies Detected")
    ax.set_xlabel("Timestamp")
    ax.set_ylabel("Value")
    ax.legend()
    ax.grid(True)
    # Drop the figure from pyplot's global registry so repeated requests do
    # not accumulate open figures; the Figure object itself stays renderable.
    plt.close(fig)
    return fig


def detect_anomalies(data_input, threshold=0.05):
    """
    Process time-series data and detect anomalies using MOMENT model.

    Each point's anomaly score is the absolute difference between the input
    value and the model's reconstruction; a point is flagged when its score
    exceeds ``threshold * max(score)``.

    Args:
        data_input: CSV or JSON text, or a dict of columns, containing
            'timestamp' and 'value' columns.
        threshold: fraction of the largest reconstruction error above which
            a point is flagged (lower = more points flagged).

    Returns:
        (figure, stats dict, list of per-row dicts with 'anomaly_score' and
        'is_anomaly' added), matching the three Gradio outputs.

    Raises:
        gr.Error: on unparseable/invalid input or any processing failure, so
            the message is shown in the UI instead of breaking the outputs.
    """
    try:
        df = _parse_input(data_input)
        if "timestamp" not in df.columns or "value" not in df.columns:
            raise gr.Error(
                "Error: Data must contain 'timestamp' and 'value' columns"
            )

        df["timestamp"] = pd.to_datetime(df["timestamp"])
        df["value"] = df["value"].astype(float)
        # Missing samples would turn a whole model window into NaN; drop them.
        df = df.dropna(subset=["timestamp", "value"]).sort_values(
            "timestamp", ignore_index=True
        )
        if df.empty:
            raise gr.Error("Error: No valid data points to analyse")

        time_series = df["value"].to_numpy(dtype=float)
        error = np.abs(time_series - _reconstruct(time_series))
        max_error = float(error.max())

        df["anomaly_score"] = error
        df["is_anomaly"] = error > threshold * max_error

        fig = _plot_anomalies(df)

        n_anomalies = int(df["is_anomaly"].sum())
        stats = {
            "total_points": len(df),
            "anomalies_detected": n_anomalies,
            "anomaly_percentage": f"{100 * n_anomalies / len(df):.2f}%",
            "max_anomaly_score": max_error,
            "threshold_used": threshold,
        }
        return fig, stats, df.to_dict(orient="records")
    except gr.Error:
        raise
    except Exception as e:
        # UI boundary: surface any unexpected failure as a readable message.
        raise gr.Error(f"Error processing data: {e}") from e
# Create Gradio interface
with gr.Blocks(title="Equipment Anomaly Detection") as demo:
    gr.Markdown("# 🛠️ Equipment Sensor Anomaly Detection")
    gr.Markdown("""
**Detect anomalies in equipment sensor data using the MOMENT-1-large model**
- Upload CSV/JSON data with 'timestamp' and 'value' columns
- Adjust the sensitivity threshold as needed
- Get visual and statistical results
""")
    with gr.Row():
        with gr.Column():
            input_data = gr.Textbox(
                label="Paste your time-series data (CSV/JSON)",
                placeholder="timestamp,value\n2023-01-01,1.2\n2023-01-02,1.5...",
                lines=5,
            )
            file_upload = gr.File(label="Or upload a file")
            threshold = gr.Slider(
                minimum=0.01,
                maximum=0.2,
                value=0.05,
                step=0.01,
                label="Anomaly Detection Sensitivity (lower = more sensitive)",
            )
            submit_btn = gr.Button("Detect Anomalies", variant="primary")
        with gr.Column():
            plot_output = gr.Plot(label="Anomaly Detection Results")
            stats_output = gr.JSON(label="Detection Statistics")
            data_output = gr.JSON(label="Processed Data with Anomaly Scores")

    def process_file(file):
        """Return the uploaded file's text so it fills the input textbox.

        Returns an empty string when the upload is cleared (``file`` falsy).
        """
        if not file:
            return ""
        # Gradio 4+ passes a filepath string; older versions passed a
        # tempfile wrapper exposing the path as ``.name``.
        path = file if isinstance(file, str) else file.name
        with open(path, "r", encoding="utf-8") as f:
            return f.read()

    # Handle file upload
    file_upload.change(process_file, inputs=file_upload, outputs=input_data)
    submit_btn.click(
        detect_anomalies,
        inputs=[input_data, threshold],
        outputs=[plot_output, stats_output, data_output],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)