Spaces:
Sleeping
Sleeping
initial commit
Browse files
app.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Robotic Grasp Predictor - Gradio Demo
|
4 |
+
This script loads (or trains, if necessary) a model to predict the robustness of a robotic grasp.
|
5 |
+
It provides a Gradio interface to generate random sensor values, compute a prediction,
|
6 |
+
and visualize the grasp with a simple robotic arm drawing.
|
7 |
+
|
8 |
+
Dependencies:
|
9 |
+
gradio
|
10 |
+
pandas
|
11 |
+
numpy
|
12 |
+
matplotlib
|
13 |
+
scikit-learn
|
14 |
+
joblib
|
15 |
+
|
16 |
+
IMPORTANT:
|
17 |
+
Ensure that the file 'shadow_robot_dataset.csv' is included in your repository.
|
18 |
+
If the pickle files (final_model.pkl and feature_names.pkl) do not exist, the app will train the model,
|
19 |
+
save them (using compression to reduce file size), and use them for subsequent inference.
|
20 |
+
"""
|
21 |
+
|
22 |
+
import warnings
|
23 |
+
warnings.filterwarnings("ignore")
|
24 |
+
import os
|
25 |
+
import numpy as np
|
26 |
+
import pandas as pd
|
27 |
+
import matplotlib.pyplot as plt
|
28 |
+
from joblib import dump, load
|
29 |
+
from sklearn.model_selection import train_test_split
|
30 |
+
from sklearn.preprocessing import StandardScaler
|
31 |
+
from sklearn.ensemble import ExtraTreesRegressor
|
32 |
+
from sklearn.pipeline import Pipeline
|
33 |
+
import gradio as gr
|
34 |
+
|
35 |
+
# Define paths for the pre-trained model, feature names, and dataset
|
36 |
+
MODEL_PATH = "final_model.pkl"
|
37 |
+
FEATURE_NAMES_PATH = "feature_names.pkl"
|
38 |
+
DATASET_PATH = "shadow_robot_dataset.csv"
|
39 |
+
|
40 |
+
def load_dataset(dataset_path):
|
41 |
+
"""
|
42 |
+
Attempts to load the dataset CSV file.
|
43 |
+
Raises an error if the file does not exist or cannot be read.
|
44 |
+
"""
|
45 |
+
if not os.path.exists(dataset_path):
|
46 |
+
raise FileNotFoundError(
|
47 |
+
f"Dataset file '{dataset_path}' not found. "
|
48 |
+
"Please ensure it is included in the repository with the correct permissions."
|
49 |
+
)
|
50 |
+
try:
|
51 |
+
df = pd.read_csv(dataset_path)
|
52 |
+
except Exception as e:
|
53 |
+
raise IOError(
|
54 |
+
f"An error occurred while reading '{dataset_path}': {e}\n"
|
55 |
+
"Please check file permissions and format."
|
56 |
+
)
|
57 |
+
return df
|
58 |
+
|
59 |
+
def load_or_train_model():
|
60 |
+
"""
|
61 |
+
Loads a pre-trained model and corresponding feature names if available.
|
62 |
+
Otherwise, trains the model from the dataset, saves the model and feature names as compressed pickle files,
|
63 |
+
and returns them.
|
64 |
+
"""
|
65 |
+
if os.path.exists(MODEL_PATH) and os.path.exists(FEATURE_NAMES_PATH):
|
66 |
+
final_model = load(MODEL_PATH)
|
67 |
+
feature_names = load(FEATURE_NAMES_PATH)
|
68 |
+
print("Loaded pre-trained model and feature names.")
|
69 |
+
else:
|
70 |
+
print("Pre-trained model not found. Training model; please wait...")
|
71 |
+
# Load and preprocess data
|
72 |
+
df = load_dataset(DATASET_PATH)
|
73 |
+
# Clean up column names and drop identifier columns if present
|
74 |
+
df.columns = df.columns.str.replace(r'\s+', '', regex=True)
|
75 |
+
for col in ['experiment_number', 'measurement_number']:
|
76 |
+
if col in df.columns:
|
77 |
+
df.drop(col, axis=1, inplace=True)
|
78 |
+
|
79 |
+
# Define target and features
|
80 |
+
y = df['robustness']
|
81 |
+
X = df.drop('robustness', axis=1)
|
82 |
+
|
83 |
+
# Split the data (use only the training split for model fitting)
|
84 |
+
X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)
|
85 |
+
|
86 |
+
# Create a pipeline without caching to avoid large file size.
|
87 |
+
final_model = Pipeline([
|
88 |
+
('scaler', StandardScaler()),
|
89 |
+
('model', ExtraTreesRegressor(random_state=42, n_jobs=-1, max_depth=None, n_estimators=100))
|
90 |
+
])
|
91 |
+
|
92 |
+
# Train the model on the training data
|
93 |
+
final_model.fit(X_train, y_train)
|
94 |
+
|
95 |
+
# Save the trained model and feature names with compression
|
96 |
+
feature_names = sorted(X_train.columns) # Sorted to ensure consistent feature order
|
97 |
+
dump(final_model, MODEL_PATH, compress=3)
|
98 |
+
dump(feature_names, FEATURE_NAMES_PATH, compress=3)
|
99 |
+
print("Model training complete and saved.")
|
100 |
+
return final_model, feature_names
|
101 |
+
|
102 |
+
# Load or train the model at startup
|
103 |
+
final_model, feature_names = load_or_train_model()
|
104 |
+
|
105 |
+
############################################
|
106 |
+
# Utility Functions for the Gradio Demo
|
107 |
+
############################################
|
108 |
+
|
109 |
+
def generate_random_sensors(feature_names, lower_bound, upper_bound):
|
110 |
+
"""Generates random sensor values within [lower_bound, upper_bound] for each feature."""
|
111 |
+
sensor_values = {}
|
112 |
+
for feature in feature_names:
|
113 |
+
sensor_values[feature] = round(np.random.uniform(lower_bound, upper_bound), 4)
|
114 |
+
return sensor_values
|
115 |
+
|
116 |
+
def draw_robotic_arm(stability):
|
117 |
+
"""
|
118 |
+
Draws a simple two-link robotic arm with three fingers and a ball.
|
119 |
+
The drawing changes based on the 'stability' value.
|
120 |
+
"""
|
121 |
+
fig, ax = plt.subplots(figsize=(6, 6))
|
122 |
+
|
123 |
+
# Draw Link 1 (Blue)
|
124 |
+
shoulder = (0, 0)
|
125 |
+
elbow = (0.5, 0.5)
|
126 |
+
ax.plot([shoulder[0], elbow[0]], [shoulder[1], elbow[1]], color='blue', linewidth=3, label='Link 1')
|
127 |
+
ax.plot(shoulder[0], shoulder[1], 'o', color='blue', markersize=8)
|
128 |
+
ax.plot(elbow[0], elbow[1], 'o', color='blue', markersize=8)
|
129 |
+
|
130 |
+
# Draw Link 2 (Green)
|
131 |
+
if stability == "Stable":
|
132 |
+
end_effector = (0.8, 1.2)
|
133 |
+
else:
|
134 |
+
end_effector = (0.7, 0.8)
|
135 |
+
ax.plot([elbow[0], end_effector[0]], [elbow[1], end_effector[1]], color='green', linewidth=3, label='Link 2')
|
136 |
+
ax.plot(end_effector[0], end_effector[1], 'o', color='green', markersize=8)
|
137 |
+
|
138 |
+
# Draw Ball (Red Circle)
|
139 |
+
if stability == "Stable":
|
140 |
+
ball_center = (end_effector[0], end_effector[1] + 0.4)
|
141 |
+
else:
|
142 |
+
ball_center = (end_effector[0] + 0.5, end_effector[1])
|
143 |
+
ball = plt.Circle(ball_center, 0.1, color='red', alpha=0.6)
|
144 |
+
ax.add_artist(ball)
|
145 |
+
ax.plot(ball_center[0], ball_center[1], 'o', color='red')
|
146 |
+
|
147 |
+
# Draw Fingers (Black)
|
148 |
+
finger_base = end_effector
|
149 |
+
if stability == "Stable":
|
150 |
+
offsets = [(-0.05, 0.3), (0.0, 0.3), (0.05, 0.3)]
|
151 |
+
else:
|
152 |
+
offsets = [(-0.2, 0.2), (0.0, 0.3), (0.2, 0.2)]
|
153 |
+
for dx, dy in offsets:
|
154 |
+
tip_x = end_effector[0] + dx
|
155 |
+
tip_y = end_effector[1] + dy
|
156 |
+
ax.plot([finger_base[0], tip_x], [finger_base[1], tip_y], color='black', linewidth=3)
|
157 |
+
|
158 |
+
ax.set_xlim(-0.5, 2.0)
|
159 |
+
ax.set_ylim(-0.2, 2.5)
|
160 |
+
ax.set_aspect('equal')
|
161 |
+
ax.grid(True)
|
162 |
+
|
163 |
+
# Create a custom legend
|
164 |
+
from matplotlib.lines import Line2D
|
165 |
+
from matplotlib.patches import Patch
|
166 |
+
legend_elements = [
|
167 |
+
Line2D([0], [0], color='blue', lw=3, label='Link 1'),
|
168 |
+
Line2D([0], [0], color='green', lw=3, label='Link 2'),
|
169 |
+
Patch(facecolor='red', alpha=0.6, label='Ball'),
|
170 |
+
Line2D([0], [0], color='black', lw=3, label='Fingers')
|
171 |
+
]
|
172 |
+
ax.legend(handles=legend_elements, loc='upper left')
|
173 |
+
ax.set_title("2-Link Robot Arm with 3-Finger Grasp")
|
174 |
+
return fig
|
175 |
+
|
176 |
+
def predict_robustness_range(lower_bound, upper_bound):
|
177 |
+
"""
|
178 |
+
1) Generates random sensor values within [lower_bound, upper_bound].
|
179 |
+
2) Predicts the grasp's robustness score using the pre-trained model.
|
180 |
+
3) Categorizes the grasp as 'Stable' or 'Unstable' (using a threshold of 100).
|
181 |
+
4) Returns the sensor values, a summary, and a visualization of the robotic arm.
|
182 |
+
"""
|
183 |
+
if lower_bound >= upper_bound:
|
184 |
+
return {}, "Error: Lower bound must be less than upper bound.", None
|
185 |
+
|
186 |
+
# Generate sensor values
|
187 |
+
sensor_values = generate_random_sensors(feature_names, lower_bound, upper_bound)
|
188 |
+
|
189 |
+
# Prepare model input in the same order as stored feature names
|
190 |
+
X_new = np.array([sensor_values[f] for f in feature_names]).reshape(1, -1)
|
191 |
+
pred_score = round(final_model.predict(X_new)[0], 2)
|
192 |
+
|
193 |
+
threshold = 100
|
194 |
+
stability = "Stable" if pred_score >= threshold else "Unstable"
|
195 |
+
result_text = f"Predicted Robustness Score: {pred_score}\nStability Category: {stability}"
|
196 |
+
|
197 |
+
# Create visualization based on prediction
|
198 |
+
fig = draw_robotic_arm(stability)
|
199 |
+
|
200 |
+
return sensor_values, result_text, fig
|
201 |
+
|
202 |
+
############################################
|
203 |
+
# Gradio Interface
|
204 |
+
############################################
|
205 |
+
|
206 |
+
with gr.Blocks() as demo:
|
207 |
+
gr.Markdown("## Robotic Grasp Predictor")
|
208 |
+
gr.Markdown(
|
209 |
+
"Enter lower and upper bounds for random sensor values.\n"
|
210 |
+
"The system will predict the grasp's robustness and categorize it as **Stable** "
|
211 |
+
"(if the predicted score >= 100) or **Unstable** (if the predicted score < 100)."
|
212 |
+
)
|
213 |
+
|
214 |
+
with gr.Row():
|
215 |
+
with gr.Column():
|
216 |
+
lower_bound = gr.Number(label="Sensor Values Lower Bound", value=0.0)
|
217 |
+
upper_bound = gr.Number(label="Sensor Values Upper Bound", value=1.0)
|
218 |
+
button = gr.Button("Generate & Predict")
|
219 |
+
sensor_json = gr.JSON(label="Generated Sensor Values")
|
220 |
+
with gr.Column():
|
221 |
+
prediction_box = gr.Textbox(label="Prediction Result")
|
222 |
+
plot_output = gr.Plot(label="Arm Visualization")
|
223 |
+
|
224 |
+
button.click(
|
225 |
+
fn=predict_robustness_range,
|
226 |
+
inputs=[lower_bound, upper_bound],
|
227 |
+
outputs=[sensor_json, prediction_box, plot_output]
|
228 |
+
)
|
229 |
+
|
230 |
+
if __name__ == "__main__":
|
231 |
+
demo.launch(share=True)
|