zainulabedin949 committed on
Commit
bd1a142
·
verified ·
1 Parent(s): 3165f49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -40
app.py CHANGED
@@ -4,41 +4,69 @@ import numpy as np
4
  import matplotlib.pyplot as plt
5
  from io import StringIO
6
  from momentfm import MOMENTPipeline
7
- from datetime import datetime
8
 
9
- # Initialize model correctly
10
- model = MOMENTPipeline.from_pretrained(
11
- "AutonLab/MOMENT-1-large",
12
- model_kwargs={"task_name": "anomaly_detection"}, # Changed task name
13
- )
14
- model.init()
15
 
16
- def detect_anomalies(data_input, threshold=3.0): # Changed default threshold
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  try:
18
- # Process input data
19
- df = pd.read_csv(StringIO(data_input))
 
 
20
 
21
- # Validate data
22
- if 'timestamp' not in df.columns or 'value' not in df.columns:
23
- return None, {"error": "CSV must contain 'timestamp' and 'value' columns"}, None
24
 
25
- df['timestamp'] = pd.to_datetime(df['timestamp'])
26
- df = df.sort_values('timestamp')
 
 
 
 
 
27
 
28
- # Prepare input for MOMENT (must be 3D array: [samples, timesteps, features])
 
 
 
 
 
 
 
 
 
29
  values = df['value'].values.astype(np.float32)
30
- values_3d = values.reshape(1, -1, 1) # Reshape to 3D
31
 
32
- # Correct reconstruction call
33
- reconstruction = model.reconstruct(X=values_3d) # Using named parameter
 
 
 
34
 
35
- # Calculate errors (flatten back to 1D)
36
- errors = np.abs(values - reconstruction[0,:,0])
37
 
38
- # Dynamic threshold (using z-score)
39
- threshold_value = np.mean(errors) + threshold * np.std(errors)
40
  df['anomaly_score'] = errors
41
- df['is_anomaly'] = errors > threshold_value
42
 
43
  # Create plot
44
  fig, ax = plt.subplots(figsize=(12, 5))
@@ -48,30 +76,31 @@ def detect_anomalies(data_input, threshold=3.0): # Changed default threshold
48
  df.loc[df['is_anomaly'], 'value'],
49
  color='red', s=100, label='Anomaly'
50
  )
51
- ax.set_title(f'Anomaly Detection (Threshold: {threshold_value:.2f})')
52
  ax.legend()
53
 
54
  # Prepare outputs
55
  stats = {
56
  "data_points": len(df),
57
- "anomalies": int(df['is_anomaly'].sum()),
58
- "threshold": float(threshold_value),
59
- "max_score": float(np.max(errors))
60
  }
61
 
62
  return fig, stats, df.to_dict('records')
63
-
64
  except Exception as e:
 
65
  return None, {"error": str(e)}, None
66
 
67
- # Gradio interface
68
- with gr.Blocks() as demo:
69
- gr.Markdown("# 🚨 Time-Series Anomaly Detection")
70
 
71
  with gr.Row():
72
  with gr.Column():
73
  data_input = gr.Textbox(
74
- label="Paste CSV Data",
75
  value="""timestamp,value
76
  2025-04-01 00:00:00,100
77
  2025-04-01 01:00:00,102
@@ -88,19 +117,28 @@ with gr.Blocks() as demo:
88
  2025-04-01 12:00:00,101""",
89
  lines=15
90
  )
91
- threshold = gr.Slider(1.0, 5.0, value=3.0, step=0.1, label="Z-Score Threshold")
92
- submit_btn = gr.Button("Analyze", variant="primary")
 
 
 
 
 
 
93
 
94
  with gr.Column():
95
- plot_output = gr.Plot()
96
- stats_output = gr.JSON()
97
- data_output = gr.JSON()
 
 
 
98
 
99
  submit_btn.click(
100
  detect_anomalies,
101
- inputs=[data_input, threshold],
102
  outputs=[plot_output, stats_output, data_output]
103
  )
104
 
105
  if __name__ == "__main__":
106
- demo.launch()
 
4
  import matplotlib.pyplot as plt
5
  from io import StringIO
6
  from momentfm import MOMENTPipeline
7
+ import logging
8
 
9
+ # Configure logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
 
 
 
12
 
13
+ # Initialize model with reconstruction task
14
+ try:
15
+ model = MOMENTPipeline.from_pretrained(
16
+ "AutonLab/MOMENT-1-large",
17
+ model_kwargs={"task_name": "reconstruction"}, # Correct task name
18
+ )
19
+ model.init()
20
+ logger.info("Model loaded successfully")
21
+ except Exception as e:
22
+ logger.error(f"Model loading failed: {str(e)}")
23
+ raise
24
+
25
+ def validate_data(data_input):
26
+ """Validate and process input data"""
27
  try:
28
+ if isinstance(data_input, str):
29
+ df = pd.read_csv(StringIO(data_input))
30
+ else:
31
+ raise ValueError("Input must be CSV text")
32
 
33
+ # Validate columns
34
+ if not all(col in df.columns for col in ['timestamp', 'value']):
35
+ raise ValueError("CSV must contain 'timestamp' and 'value' columns")
36
 
37
+ # Convert timestamps
38
+ df['timestamp'] = pd.to_datetime(df['timestamp'], errors='coerce')
39
+ if df['timestamp'].isnull().any():
40
+ raise ValueError("Invalid timestamp format")
41
+
42
+ # Convert values to numeric
43
+ df['value'] = pd.to_numeric(df['value'], errors='raise')
44
 
45
+ return df.sort_values('timestamp')
46
+
47
+ except Exception as e:
48
+ logger.error(f"Data validation error: {str(e)}")
49
+ raise
50
+
51
+ def detect_anomalies(data_input, sensitivity=3.0):
52
+ """Perform reconstruction-based anomaly detection"""
53
+ try:
54
+ df = validate_data(data_input)
55
  values = df['value'].values.astype(np.float32)
 
56
 
57
+ # Reshape to 3D format (batch, sequence, features)
58
+ values_3d = values.reshape(1, -1, 1)
59
+
60
+ # Get reconstruction
61
+ reconstructed = model.reconstruct(values_3d)
62
 
63
+ # Calculate reconstruction error (MAE)
64
+ errors = np.abs(values - reconstructed[0,:,0])
65
 
66
+ # Dynamic threshold (z-score based)
67
+ threshold = np.mean(errors) + sensitivity * np.std(errors)
68
  df['anomaly_score'] = errors
69
+ df['is_anomaly'] = errors > threshold
70
 
71
  # Create plot
72
  fig, ax = plt.subplots(figsize=(12, 5))
 
76
  df.loc[df['is_anomaly'], 'value'],
77
  color='red', s=100, label='Anomaly'
78
  )
79
+ ax.set_title(f'Anomaly Detection (Threshold: {threshold:.2f})')
80
  ax.legend()
81
 
82
  # Prepare outputs
83
  stats = {
84
  "data_points": len(df),
85
+ "anomalous_points": int(df['is_anomaly'].sum()),
86
+ "detection_threshold": float(threshold),
87
+ "max_error": float(np.max(errors))
88
  }
89
 
90
  return fig, stats, df.to_dict('records')
91
+
92
  except Exception as e:
93
+ logger.error(f"Detection error: {str(e)}")
94
  return None, {"error": str(e)}, None
95
 
96
+ # Gradio Interface
97
+ with gr.Blocks(title="MOMENT Anomaly Detector") as demo:
98
+ gr.Markdown("## 🔍 Equipment Anomaly Detection using MOMENT")
99
 
100
  with gr.Row():
101
  with gr.Column():
102
  data_input = gr.Textbox(
103
+ label="Paste time-series data (CSV format)",
104
  value="""timestamp,value
105
  2025-04-01 00:00:00,100
106
  2025-04-01 01:00:00,102
 
117
  2025-04-01 12:00:00,101""",
118
  lines=15
119
  )
120
+ sensitivity = gr.Slider(
121
+ minimum=1.0,
122
+ maximum=5.0,
123
+ value=3.0,
124
+ step=0.1,
125
+ label="Detection Sensitivity (Z-Score)"
126
+ )
127
+ submit_btn = gr.Button("Analyze Data", variant="primary")
128
 
129
  with gr.Column():
130
+ plot_output = gr.Plot(label="Anomaly Detection Results")
131
+ stats_output = gr.JSON(label="Detection Statistics")
132
+ data_output = gr.JSON(
133
+ label="Processed Data",
134
+ max_lines=15
135
+ )
136
 
137
  submit_btn.click(
138
  detect_anomalies,
139
+ inputs=[data_input, sensitivity],
140
  outputs=[plot_output, stats_output, data_output]
141
  )
142
 
143
  if __name__ == "__main__":
144
+ demo.launch(server_name="0.0.0.0", server_port=7860)