VeekshanArroju commited on
Commit
a5d1a00
·
verified ·
1 Parent(s): c1f1440

initial commit

Browse files
Files changed (1) hide show
  1. app.py +231 -0
app.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Robotic Grasp Predictor - Gradio Demo
4
+ This script loads (or trains, if necessary) a model to predict the robustness of a robotic grasp.
5
+ It provides a Gradio interface to generate random sensor values, compute a prediction,
6
+ and visualize the grasp with a simple robotic arm drawing.
7
+
8
+ Dependencies:
9
+ gradio
10
+ pandas
11
+ numpy
12
+ matplotlib
13
+ scikit-learn
14
+ joblib
15
+
16
+ IMPORTANT:
17
+ Ensure that the file 'shadow_robot_dataset.csv' is included in your repository.
18
+ If the pickle files (final_model.pkl and feature_names.pkl) do not exist, the app will train the model,
19
+ save them (using compression to reduce file size), and use them for subsequent inference.
20
+ """
21
+
22
+ import warnings
23
+ warnings.filterwarnings("ignore")
24
+ import os
25
+ import numpy as np
26
+ import pandas as pd
27
+ import matplotlib.pyplot as plt
28
+ from joblib import dump, load
29
+ from sklearn.model_selection import train_test_split
30
+ from sklearn.preprocessing import StandardScaler
31
+ from sklearn.ensemble import ExtraTreesRegressor
32
+ from sklearn.pipeline import Pipeline
33
+ import gradio as gr
34
+
35
+ # Define paths for the pre-trained model, feature names, and dataset
36
+ MODEL_PATH = "final_model.pkl"
37
+ FEATURE_NAMES_PATH = "feature_names.pkl"
38
+ DATASET_PATH = "shadow_robot_dataset.csv"
39
+
40
+ def load_dataset(dataset_path):
41
+ """
42
+ Attempts to load the dataset CSV file.
43
+ Raises an error if the file does not exist or cannot be read.
44
+ """
45
+ if not os.path.exists(dataset_path):
46
+ raise FileNotFoundError(
47
+ f"Dataset file '{dataset_path}' not found. "
48
+ "Please ensure it is included in the repository with the correct permissions."
49
+ )
50
+ try:
51
+ df = pd.read_csv(dataset_path)
52
+ except Exception as e:
53
+ raise IOError(
54
+ f"An error occurred while reading '{dataset_path}': {e}\n"
55
+ "Please check file permissions and format."
56
+ )
57
+ return df
58
+
59
+ def load_or_train_model():
60
+ """
61
+ Loads a pre-trained model and corresponding feature names if available.
62
+ Otherwise, trains the model from the dataset, saves the model and feature names as compressed pickle files,
63
+ and returns them.
64
+ """
65
+ if os.path.exists(MODEL_PATH) and os.path.exists(FEATURE_NAMES_PATH):
66
+ final_model = load(MODEL_PATH)
67
+ feature_names = load(FEATURE_NAMES_PATH)
68
+ print("Loaded pre-trained model and feature names.")
69
+ else:
70
+ print("Pre-trained model not found. Training model; please wait...")
71
+ # Load and preprocess data
72
+ df = load_dataset(DATASET_PATH)
73
+ # Clean up column names and drop identifier columns if present
74
+ df.columns = df.columns.str.replace(r'\s+', '', regex=True)
75
+ for col in ['experiment_number', 'measurement_number']:
76
+ if col in df.columns:
77
+ df.drop(col, axis=1, inplace=True)
78
+
79
+ # Define target and features
80
+ y = df['robustness']
81
+ X = df.drop('robustness', axis=1)
82
+
83
+ # Split the data (use only the training split for model fitting)
84
+ X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)
85
+
86
+ # Create a pipeline without caching to avoid large file size.
87
+ final_model = Pipeline([
88
+ ('scaler', StandardScaler()),
89
+ ('model', ExtraTreesRegressor(random_state=42, n_jobs=-1, max_depth=None, n_estimators=100))
90
+ ])
91
+
92
+ # Train the model on the training data
93
+ final_model.fit(X_train, y_train)
94
+
95
+ # Save the trained model and feature names with compression
96
+ feature_names = sorted(X_train.columns) # Sorted to ensure consistent feature order
97
+ dump(final_model, MODEL_PATH, compress=3)
98
+ dump(feature_names, FEATURE_NAMES_PATH, compress=3)
99
+ print("Model training complete and saved.")
100
+ return final_model, feature_names
101
+
102
+ # Load or train the model at startup
103
+ final_model, feature_names = load_or_train_model()
104
+
105
+ ############################################
106
+ # Utility Functions for the Gradio Demo
107
+ ############################################
108
+
109
+ def generate_random_sensors(feature_names, lower_bound, upper_bound):
110
+ """Generates random sensor values within [lower_bound, upper_bound] for each feature."""
111
+ sensor_values = {}
112
+ for feature in feature_names:
113
+ sensor_values[feature] = round(np.random.uniform(lower_bound, upper_bound), 4)
114
+ return sensor_values
115
+
116
+ def draw_robotic_arm(stability):
117
+ """
118
+ Draws a simple two-link robotic arm with three fingers and a ball.
119
+ The drawing changes based on the 'stability' value.
120
+ """
121
+ fig, ax = plt.subplots(figsize=(6, 6))
122
+
123
+ # Draw Link 1 (Blue)
124
+ shoulder = (0, 0)
125
+ elbow = (0.5, 0.5)
126
+ ax.plot([shoulder[0], elbow[0]], [shoulder[1], elbow[1]], color='blue', linewidth=3, label='Link 1')
127
+ ax.plot(shoulder[0], shoulder[1], 'o', color='blue', markersize=8)
128
+ ax.plot(elbow[0], elbow[1], 'o', color='blue', markersize=8)
129
+
130
+ # Draw Link 2 (Green)
131
+ if stability == "Stable":
132
+ end_effector = (0.8, 1.2)
133
+ else:
134
+ end_effector = (0.7, 0.8)
135
+ ax.plot([elbow[0], end_effector[0]], [elbow[1], end_effector[1]], color='green', linewidth=3, label='Link 2')
136
+ ax.plot(end_effector[0], end_effector[1], 'o', color='green', markersize=8)
137
+
138
+ # Draw Ball (Red Circle)
139
+ if stability == "Stable":
140
+ ball_center = (end_effector[0], end_effector[1] + 0.4)
141
+ else:
142
+ ball_center = (end_effector[0] + 0.5, end_effector[1])
143
+ ball = plt.Circle(ball_center, 0.1, color='red', alpha=0.6)
144
+ ax.add_artist(ball)
145
+ ax.plot(ball_center[0], ball_center[1], 'o', color='red')
146
+
147
+ # Draw Fingers (Black)
148
+ finger_base = end_effector
149
+ if stability == "Stable":
150
+ offsets = [(-0.05, 0.3), (0.0, 0.3), (0.05, 0.3)]
151
+ else:
152
+ offsets = [(-0.2, 0.2), (0.0, 0.3), (0.2, 0.2)]
153
+ for dx, dy in offsets:
154
+ tip_x = end_effector[0] + dx
155
+ tip_y = end_effector[1] + dy
156
+ ax.plot([finger_base[0], tip_x], [finger_base[1], tip_y], color='black', linewidth=3)
157
+
158
+ ax.set_xlim(-0.5, 2.0)
159
+ ax.set_ylim(-0.2, 2.5)
160
+ ax.set_aspect('equal')
161
+ ax.grid(True)
162
+
163
+ # Create a custom legend
164
+ from matplotlib.lines import Line2D
165
+ from matplotlib.patches import Patch
166
+ legend_elements = [
167
+ Line2D([0], [0], color='blue', lw=3, label='Link 1'),
168
+ Line2D([0], [0], color='green', lw=3, label='Link 2'),
169
+ Patch(facecolor='red', alpha=0.6, label='Ball'),
170
+ Line2D([0], [0], color='black', lw=3, label='Fingers')
171
+ ]
172
+ ax.legend(handles=legend_elements, loc='upper left')
173
+ ax.set_title("2-Link Robot Arm with 3-Finger Grasp")
174
+ return fig
175
+
176
+ def predict_robustness_range(lower_bound, upper_bound):
177
+ """
178
+ 1) Generates random sensor values within [lower_bound, upper_bound].
179
+ 2) Predicts the grasp's robustness score using the pre-trained model.
180
+ 3) Categorizes the grasp as 'Stable' or 'Unstable' (using a threshold of 100).
181
+ 4) Returns the sensor values, a summary, and a visualization of the robotic arm.
182
+ """
183
+ if lower_bound >= upper_bound:
184
+ return {}, "Error: Lower bound must be less than upper bound.", None
185
+
186
+ # Generate sensor values
187
+ sensor_values = generate_random_sensors(feature_names, lower_bound, upper_bound)
188
+
189
+ # Prepare model input in the same order as stored feature names
190
+ X_new = np.array([sensor_values[f] for f in feature_names]).reshape(1, -1)
191
+ pred_score = round(final_model.predict(X_new)[0], 2)
192
+
193
+ threshold = 100
194
+ stability = "Stable" if pred_score >= threshold else "Unstable"
195
+ result_text = f"Predicted Robustness Score: {pred_score}\nStability Category: {stability}"
196
+
197
+ # Create visualization based on prediction
198
+ fig = draw_robotic_arm(stability)
199
+
200
+ return sensor_values, result_text, fig
201
+
202
+ ############################################
203
+ # Gradio Interface
204
+ ############################################
205
+
206
+ with gr.Blocks() as demo:
207
+ gr.Markdown("## Robotic Grasp Predictor")
208
+ gr.Markdown(
209
+ "Enter lower and upper bounds for random sensor values.\n"
210
+ "The system will predict the grasp's robustness and categorize it as **Stable** "
211
+ "(if the predicted score >= 100) or **Unstable** (if the predicted score < 100)."
212
+ )
213
+
214
+ with gr.Row():
215
+ with gr.Column():
216
+ lower_bound = gr.Number(label="Sensor Values Lower Bound", value=0.0)
217
+ upper_bound = gr.Number(label="Sensor Values Upper Bound", value=1.0)
218
+ button = gr.Button("Generate & Predict")
219
+ sensor_json = gr.JSON(label="Generated Sensor Values")
220
+ with gr.Column():
221
+ prediction_box = gr.Textbox(label="Prediction Result")
222
+ plot_output = gr.Plot(label="Arm Visualization")
223
+
224
+ button.click(
225
+ fn=predict_robustness_range,
226
+ inputs=[lower_bound, upper_bound],
227
+ outputs=[sensor_json, prediction_box, plot_output]
228
+ )
229
+
230
+ if __name__ == "__main__":
231
+ demo.launch(share=True)