Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
""" | |
Robotic Grasp Predictor - Gradio Demo | |
This script loads (or trains, if necessary) a model to predict the robustness of a robotic grasp. | |
It provides a Gradio interface to generate random sensor values, compute a prediction, | |
and visualize the grasp with a simple robotic arm drawing. | |
Dependencies: | |
gradio | |
pandas | |
numpy | |
matplotlib | |
scikit-learn | |
joblib | |
IMPORTANT: | |
Ensure that the file 'shadow_robot_dataset.csv' is included in your repository. | |
If the pickle files (final_model.pkl and feature_names.pkl) do not exist, the app will train the model, | |
save them (using compression to reduce file size), and use them for subsequent inference. | |
""" | |
import warnings | |
warnings.filterwarnings("ignore") | |
import os | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from joblib import dump, load | |
from sklearn.model_selection import train_test_split | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.ensemble import ExtraTreesRegressor | |
from sklearn.pipeline import Pipeline | |
import gradio as gr | |
# Artifact and data locations. All three are plain relative paths, so they are
# resolved against the current working directory when the script runs.
MODEL_PATH, FEATURE_NAMES_PATH, DATASET_PATH = (
    "final_model.pkl",           # joblib dump of the fitted pipeline
    "feature_names.pkl",         # joblib dump of the feature-name list
    "shadow_robot_dataset.csv",  # training data, read only when training
)
def load_dataset(dataset_path):
    """
    Load the grasp dataset from a CSV file.

    Args:
        dataset_path: Path to the CSV file.

    Returns:
        pandas.DataFrame with the file's contents.

    Raises:
        FileNotFoundError: If ``dataset_path`` does not exist.
        IOError: If the path exists but ``pandas.read_csv`` fails (parse error,
            empty file, permission problem, ...). The original exception is
            chained as ``__cause__``.
    """
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(
            f"Dataset file '{dataset_path}' not found. "
            "Please ensure it is included in the repository with the correct permissions."
        )
    try:
        df = pd.read_csv(dataset_path)
    except Exception as e:
        # Chain explicitly so the underlying reader error is reported as the
        # direct cause rather than as an incidental exception during handling.
        raise IOError(
            f"An error occurred while reading '{dataset_path}': {e}\n"
            "Please check file permissions and format."
        ) from e
    return df
def _prepare_features(df):
    """
    Split a raw dataset into a feature frame and the 'robustness' target.

    All whitespace is removed from column names, and the identifier columns
    'experiment_number' / 'measurement_number' are dropped when present.

    Raises:
        KeyError: If no 'robustness' column exists after cleaning.
    """
    df.columns = df.columns.str.replace(r'\s+', '', regex=True)
    id_columns = [c for c in ('experiment_number', 'measurement_number') if c in df.columns]
    df = df.drop(columns=id_columns)
    y = df['robustness']
    X = df.drop('robustness', axis=1)
    return X, y


def _model_feature_order(model, fallback):
    """Return the column order ``model`` was fitted with, or ``fallback`` if it records none."""
    # Pipeline.feature_names_in_ delegates to its first step; getattr's default
    # also covers models fitted on plain arrays (attribute absent).
    fitted_names = getattr(model, "feature_names_in_", None)
    if fitted_names is not None:
        return list(fitted_names)
    return list(fallback)


def load_or_train_model():
    """
    Load the pre-trained model and its feature names, training them first if needed.

    If both MODEL_PATH and FEATURE_NAMES_PATH exist, they are loaded with joblib.
    Otherwise the dataset at DATASET_PATH is read and split 80/20 with
    random_state=42. A StandardScaler + ExtraTreesRegressor pipeline is fitted
    on the 80% training split only. The model and feature names are then saved
    as compressed joblib files.

    Returns:
        (final_model, feature_names), where ``feature_names`` lists the input
        columns in the order the model was fitted with. Callers build the
        prediction input row in this order.
    """
    if os.path.exists(MODEL_PATH) and os.path.exists(FEATURE_NAMES_PATH):
        final_model = load(MODEL_PATH)
        stored_names = load(FEATURE_NAMES_PATH)
        # A feature_names.pkl written by an earlier version of this code may be
        # alphabetically sorted, which need not match the fitted column order.
        # Prefer the order the model itself recorded during fit.
        feature_names = _model_feature_order(final_model, stored_names)
        print("Loaded pre-trained model and feature names.")
        return final_model, feature_names

    print("Pre-trained model not found. Training model; please wait...")
    X, y = _prepare_features(load_dataset(DATASET_PATH))

    # Split the data (use only the training split for model fitting).
    X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)

    # Create a pipeline without caching to avoid large file size.
    final_model = Pipeline([
        ('scaler', StandardScaler()),
        ('model', ExtraTreesRegressor(random_state=42, n_jobs=-1, max_depth=None, n_estimators=100)),
    ])
    final_model.fit(X_train, y_train)

    # Keep the training column order: sorting here would make the prediction
    # input row assign sensor values to the wrong features.
    feature_names = list(X_train.columns)
    dump(final_model, MODEL_PATH, compress=3)
    dump(feature_names, FEATURE_NAMES_PATH, compress=3)
    print("Model training complete and saved.")
    return final_model, feature_names
# Resolve the fitted pipeline and its feature list once, at import time; the
# prediction callback below reads these module-level names on every request.
(final_model, feature_names) = load_or_train_model()
############################################ | |
# Utility Functions for the Gradio Demo | |
############################################ | |
def generate_random_sensors(feature_names, lower_bound, upper_bound):
    """
    Draw one random value per feature.

    Each value comes from ``np.random.uniform(lower_bound, upper_bound)`` and is
    rounded to 4 decimal places.

    Returns:
        dict mapping every name in ``feature_names`` (in iteration order) to its
        sampled value.
    """
    return {
        name: round(np.random.uniform(lower_bound, upper_bound), 4)
        for name in feature_names
    }
def draw_robotic_arm(stability):
    """
    Draw a simple two-link robotic arm with three fingers and a ball.

    Args:
        stability: "Stable" selects the upright pose with the ball centred above
            the fingertips. Any other value selects the "unstable" pose with the
            ball off to the right of the end effector.

    Returns:
        A 6x6-inch matplotlib ``Figure`` containing a single axes.

    The figure is created with ``matplotlib.figure.Figure`` rather than pyplot,
    so it is not registered with pyplot's global figure manager. Creating a
    pyplot figure per call without closing it keeps every figure alive for the
    life of the process. Pyplot's global state is also not thread-safe.
    """
    from matplotlib.figure import Figure
    from matplotlib.lines import Line2D
    from matplotlib.patches import Circle, Patch

    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()

    # Link 1 (blue): shoulder -> elbow, identical in both poses.
    shoulder = (0, 0)
    elbow = (0.5, 0.5)
    ax.plot([shoulder[0], elbow[0]], [shoulder[1], elbow[1]], color='blue', linewidth=3, label='Link 1')
    ax.plot(shoulder[0], shoulder[1], 'o', color='blue', markersize=8)
    ax.plot(elbow[0], elbow[1], 'o', color='blue', markersize=8)

    # Pose-dependent geometry: end effector, ball position and finger spread.
    if stability == "Stable":
        end_effector = (0.8, 1.2)
        ball_center = (end_effector[0], end_effector[1] + 0.4)
        finger_offsets = [(-0.05, 0.3), (0.0, 0.3), (0.05, 0.3)]
    else:
        end_effector = (0.7, 0.8)
        ball_center = (end_effector[0] + 0.5, end_effector[1])
        finger_offsets = [(-0.2, 0.2), (0.0, 0.3), (0.2, 0.2)]

    # Link 2 (green): elbow -> end effector.
    ax.plot([elbow[0], end_effector[0]], [elbow[1], end_effector[1]], color='green', linewidth=3, label='Link 2')
    ax.plot(end_effector[0], end_effector[1], 'o', color='green', markersize=8)

    # Ball (translucent red circle plus a centre marker).
    ax.add_patch(Circle(ball_center, 0.1, color='red', alpha=0.6))
    ax.plot(ball_center[0], ball_center[1], 'o', color='red')

    # Fingers (black), all rooted at the end effector.
    base_x, base_y = end_effector
    for dx, dy in finger_offsets:
        ax.plot([base_x, base_x + dx], [base_y, base_y + dy], color='black', linewidth=3)

    ax.set_xlim(-0.5, 2.0)
    ax.set_ylim(-0.2, 2.5)
    ax.set_aspect('equal')
    ax.grid(True)

    # Custom legend: one entry per component, independent of plotted labels.
    legend_elements = [
        Line2D([0], [0], color='blue', lw=3, label='Link 1'),
        Line2D([0], [0], color='green', lw=3, label='Link 2'),
        Patch(facecolor='red', alpha=0.6, label='Ball'),
        Line2D([0], [0], color='black', lw=3, label='Fingers'),
    ]
    ax.legend(handles=legend_elements, loc='upper left')
    ax.set_title("2-Link Robot Arm with 3-Finger Grasp")
    return fig
def predict_robustness_range(lower_bound, upper_bound, threshold=100):
    """
    Generate random sensor readings, predict grasp robustness, and visualize it.

    Steps:
      1) Sample one value per feature between ``lower_bound`` and ``upper_bound``.
      2) Predict the robustness score with the module-level ``final_model``.
      3) Label the grasp "Stable" if the score is >= ``threshold``, else "Unstable".
      4) Draw the robotic arm for that label.

    Args:
        lower_bound: Lower sampling bound; ``None`` is reported as missing input.
        upper_bound: Upper sampling bound; ``None`` is reported as missing input.
        threshold: Score at or above which a grasp counts as "Stable" (default 100).

    Returns:
        (sensor_values, result_text, fig). When the bounds are missing or
        ``lower_bound >= upper_bound``, returns ({}, error message, None).
    """
    if lower_bound is None or upper_bound is None:
        return {}, "Error: Please provide both a lower and an upper bound.", None
    if lower_bound >= upper_bound:
        return {}, "Error: Lower bound must be less than upper bound.", None

    sensor_values = generate_random_sensors(feature_names, lower_bound, upper_bound)

    # Build a one-row frame. ``feature_names`` is not guaranteed to follow the
    # column order the model was fitted with. When the model records that order
    # (feature_names_in_), reorder to it so each value reaches the right feature.
    X_new = pd.DataFrame([sensor_values], columns=list(feature_names))
    fitted_order = getattr(final_model, "feature_names_in_", None)
    if fitted_order is not None:
        X_new = X_new[list(fitted_order)]

    pred_score = round(float(final_model.predict(X_new)[0]), 2)
    stability = "Stable" if pred_score >= threshold else "Unstable"
    result_text = f"Predicted Robustness Score: {pred_score}\nStability Category: {stability}"

    # Create visualization based on prediction.
    fig = draw_robotic_arm(stability)
    return sensor_values, result_text, fig
############################################ | |
# Gradio Interface | |
############################################ | |
with gr.Blocks() as demo:
    # Page header and usage notes.
    gr.Markdown("## Robotic Grasp Predictor")
    gr.Markdown(
        "Enter lower and upper bounds for random sensor values.\n"
        "The system will predict the grasp's robustness and categorize it as **Stable** "
        "(if the predicted score >= 100) or **Unstable** (if the predicted score < 100)."
    )

    with gr.Row():
        # Left column: bound inputs, trigger button, sampled sensor values.
        with gr.Column():
            lower_bound = gr.Number(value=0.0, label="Sensor Values Lower Bound")
            upper_bound = gr.Number(value=1.0, label="Sensor Values Upper Bound")
            button = gr.Button("Generate & Predict")
            sensor_json = gr.JSON(label="Generated Sensor Values")
        # Right column: textual result and the arm drawing.
        with gr.Column():
            prediction_box = gr.Textbox(label="Prediction Result")
            plot_output = gr.Plot(label="Arm Visualization")

    # Two numeric inputs in; sensor dict, result text and figure out.
    button.click(
        predict_robustness_range,
        [lower_bound, upper_bound],
        [sensor_json, prediction_box, plot_output],
    )
if __name__ == "__main__": | |
demo.launch(share=True) | |