IITJ_Project / app.py
VeekshanArroju's picture
initial commit
a5d1a00 verified
raw
history blame
8.98 kB
# -*- coding: utf-8 -*-
"""
Robotic Grasp Predictor - Gradio Demo
This script loads (or trains, if necessary) a model to predict the robustness of a robotic grasp.
It provides a Gradio interface to generate random sensor values, compute a prediction,
and visualize the grasp with a simple robotic arm drawing.
Dependencies:
gradio
pandas
numpy
matplotlib
scikit-learn
joblib
IMPORTANT:
Ensure that the file 'shadow_robot_dataset.csv' is included in your repository.
If the pickle files (final_model.pkl and feature_names.pkl) do not exist, the app will train the model,
save them (using compression to reduce file size), and use them for subsequent inference.
"""
import warnings
warnings.filterwarnings("ignore")
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from joblib import dump, load
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.pipeline import Pipeline
import gradio as gr
# Define paths for the pre-trained model, feature names, and dataset.
# These are relative paths, so they resolve against the current working
# directory. The two .pkl files are written with joblib.dump by
# load_or_train_model() on first run and loaded from disk on later startups.
MODEL_PATH = "final_model.pkl"
FEATURE_NAMES_PATH = "feature_names.pkl"
DATASET_PATH = "shadow_robot_dataset.csv"
def load_dataset(dataset_path):
    """
    Load the dataset CSV file into a pandas DataFrame.

    Parameters
    ----------
    dataset_path : str or path-like
        Location of the CSV file.

    Returns
    -------
    pandas.DataFrame
        The parsed CSV contents.

    Raises
    ------
    FileNotFoundError
        If ``dataset_path`` does not exist.
    IOError
        If the path exists but pandas cannot read or parse it. The original
        exception is attached as ``__cause__`` so the root error is kept.
    """
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(
            f"Dataset file '{dataset_path}' not found. "
            "Please ensure it is included in the repository with the correct permissions."
        )
    try:
        return pd.read_csv(dataset_path)
    except Exception as e:
        # Chain explicitly so the traceback shows the real failure, e.g. a
        # permission error vs. an empty or malformed CSV.
        raise IOError(
            f"An error occurred while reading '{dataset_path}': {e}\n"
            "Please check file permissions and format."
        ) from e
def _train_and_save_model():
    """Train the ExtraTrees pipeline on DATASET_PATH, persist it, and return (model, feature_names)."""
    df = load_dataset(DATASET_PATH)
    # Remove all whitespace from column names, then drop identifier columns
    # when they are present.
    df.columns = df.columns.str.replace(r'\s+', '', regex=True)
    df = df.drop(columns=['experiment_number', 'measurement_number'], errors='ignore')
    # Define target and features.
    y = df['robustness']
    X = df.drop(columns='robustness')
    # Only the training split is used for fitting. The 20% test split is
    # discarded here.
    X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)
    # The pipeline is built without caching to keep the saved file small.
    final_model = Pipeline([
        ('scaler', StandardScaler()),
        ('model', ExtraTreesRegressor(random_state=42, n_jobs=-1, max_depth=None, n_estimators=100))
    ])
    final_model.fit(X_train, y_train)
    # Keep the exact column order used for fitting. Inference builds its
    # input row in this order, so sorting the names (as earlier versions of
    # this script did) silently misassigns features.
    feature_names = list(X_train.columns)
    dump(final_model, MODEL_PATH, compress=3)
    dump(feature_names, FEATURE_NAMES_PATH, compress=3)
    return final_model, feature_names


def load_or_train_model():
    """
    Load the persisted model and feature names, training them first if absent.

    If both MODEL_PATH and FEATURE_NAMES_PATH exist, they are loaded with
    joblib. Otherwise the pipeline is trained on DATASET_PATH, and both
    artifacts are saved with compression before being returned.

    Returns
    -------
    tuple
        ``(final_model, feature_names)``. ``feature_names`` is a list of
        column names in the order the fitted pipeline expects its inputs.
    """
    if os.path.exists(MODEL_PATH) and os.path.exists(FEATURE_NAMES_PATH):
        # joblib.load unpickles arbitrary objects. Only load artifacts you
        # trust, such as the ones this script writes itself.
        final_model = load(MODEL_PATH)
        feature_names = list(load(FEATURE_NAMES_PATH))
        # Artifacts saved by earlier versions stored the names sorted
        # alphabetically rather than in training order. When the fitted
        # pipeline records its input column order, that order wins.
        fitted_order = getattr(final_model, "feature_names_in_", None)
        if fitted_order is not None:
            feature_names = list(fitted_order)
        print("Loaded pre-trained model and feature names.")
    else:
        print("Pre-trained model not found. Training model; please wait...")
        final_model, feature_names = _train_and_save_model()
        print("Model training complete and saved.")
    return final_model, feature_names
# Load or train the model at startup. This is a synchronous call at import
# time, so on a first run without the .pkl files, startup waits until
# training finishes. Every prediction request below reuses these two
# module-level objects.
final_model, feature_names = load_or_train_model()
############################################
# Utility Functions for the Gradio Demo
############################################
def generate_random_sensors(feature_names, lower_bound, upper_bound):
    """
    Draw one random value per feature from NumPy's global random state.

    Each value comes from ``np.random.uniform(lower_bound, upper_bound)``
    and is rounded to four decimal places. Features are sampled in the
    iteration order of ``feature_names``, so seeding ``np.random``
    beforehand makes the result reproducible.

    Returns
    -------
    dict
        Mapping of each feature name to its sampled value.
    """
    return {
        name: round(np.random.uniform(lower_bound, upper_bound), 4)
        for name in feature_names
    }
def draw_robotic_arm(stability):
    """
    Draw a simple two-link robotic arm with three fingers and a ball.

    When ``stability == "Stable"``, the end effector is raised, the fingers
    are nearly parallel, and the ball rests on the fingertips. Any other
    value draws a lower end effector, splayed fingers, and the ball beside
    the hand.

    The figure is built with ``matplotlib.figure.Figure`` rather than
    ``pyplot``. This function runs on every button click, and pyplot keeps a
    global reference to each figure it creates until it is explicitly
    closed, so pyplot figures would accumulate for the server's lifetime.
    A bare Figure is released once the caller drops it.

    Parameters
    ----------
    stability : str
        ``"Stable"`` or ``"Unstable"``. Anything other than ``"Stable"`` is
        drawn as unstable.

    Returns
    -------
    matplotlib.figure.Figure
    """
    from matplotlib.figure import Figure
    from matplotlib.lines import Line2D
    from matplotlib.patches import Circle, Patch

    stable = stability == "Stable"

    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()

    # Link 1 (blue): fixed, from the shoulder at the origin to the elbow.
    shoulder = (0, 0)
    elbow = (0.5, 0.5)
    ax.plot([shoulder[0], elbow[0]], [shoulder[1], elbow[1]], color='blue', linewidth=3, label='Link 1')
    ax.plot(shoulder[0], shoulder[1], 'o', color='blue', markersize=8)
    ax.plot(elbow[0], elbow[1], 'o', color='blue', markersize=8)

    # Link 2 (green): the end effector sits higher for a stable grasp.
    end_effector = (0.8, 1.2) if stable else (0.7, 0.8)
    ax.plot([elbow[0], end_effector[0]], [elbow[1], end_effector[1]], color='green', linewidth=3, label='Link 2')
    ax.plot(end_effector[0], end_effector[1], 'o', color='green', markersize=8)

    # Ball (red): resting on the fingertips when stable, beside the hand otherwise.
    if stable:
        ball_center = (end_effector[0], end_effector[1] + 0.4)
    else:
        ball_center = (end_effector[0] + 0.5, end_effector[1])
    ax.add_patch(Circle(ball_center, 0.1, color='red', alpha=0.6))
    ax.plot(ball_center[0], ball_center[1], 'o', color='red')

    # Fingers (black): three segments from the end effector.
    if stable:
        offsets = [(-0.05, 0.3), (0.0, 0.3), (0.05, 0.3)]
    else:
        offsets = [(-0.2, 0.2), (0.0, 0.3), (0.2, 0.2)]
    base_x, base_y = end_effector
    for dx, dy in offsets:
        ax.plot([base_x, base_x + dx], [base_y, base_y + dy], color='black', linewidth=3)

    ax.set_xlim(-0.5, 2.0)
    ax.set_ylim(-0.2, 2.5)
    ax.set_aspect('equal')
    ax.grid(True)

    # Custom legend with one proxy entry per part, so the many individual
    # plot calls above do not each get their own legend row.
    legend_elements = [
        Line2D([0], [0], color='blue', lw=3, label='Link 1'),
        Line2D([0], [0], color='green', lw=3, label='Link 2'),
        Patch(facecolor='red', alpha=0.6, label='Ball'),
        Line2D([0], [0], color='black', lw=3, label='Fingers')
    ]
    ax.legend(handles=legend_elements, loc='upper left')
    ax.set_title("2-Link Robot Arm with 3-Finger Grasp")
    return fig
def predict_robustness_range(lower_bound, upper_bound, *, threshold=100):
    """
    Generate random sensor values, predict grasp robustness, and draw the arm.

    Steps:
      1) Validate the bounds. Both must be finite numbers with
         lower_bound < upper_bound. Missing (None) or non-finite values are
         reported as an error instead of raising. gr.Number may pass None
         when a field is cleared.
      2) Sample one value per feature between the bounds with
         generate_random_sensors.
      3) Predict the robustness score with the module-level final_model.
      4) Label the grasp 'Stable' if score >= threshold, otherwise 'Unstable'.

    Parameters
    ----------
    lower_bound, upper_bound : float or None
        Sampling range for every sensor value.
    threshold : float, keyword-only, default 100
        Score at or above which the grasp counts as stable.

    Returns
    -------
    tuple
        ``(sensor_values, result_text, figure)``. On invalid bounds this is
        ``({}, error_message, None)``.
    """
    if lower_bound is None or upper_bound is None:
        return {}, "Error: Lower and upper bounds must be finite numbers.", None
    if not (np.isfinite(lower_bound) and np.isfinite(upper_bound)):
        # NaN would pass the ordering check below, since NaN comparisons
        # are False, and then feed NaN features to the model.
        return {}, "Error: Lower and upper bounds must be finite numbers.", None
    if lower_bound >= upper_bound:
        return {}, "Error: Lower bound must be less than upper bound.", None

    sensor_values = generate_random_sensors(feature_names, lower_bound, upper_bound)

    # Build the model input in the column order the pipeline was fitted with.
    # A pipeline fitted on a DataFrame records that order in
    # feature_names_in_. The saved feature-name list may be in a different
    # order, and a positional numpy row would then misassign features.
    fitted_columns = getattr(final_model, "feature_names_in_", None)
    if fitted_columns is not None:
        columns = list(fitted_columns)
        X_new = pd.DataFrame([[sensor_values[c] for c in columns]], columns=columns)
    else:
        X_new = np.array([sensor_values[f] for f in feature_names]).reshape(1, -1)

    pred_score = round(float(final_model.predict(X_new)[0]), 2)
    stability = "Stable" if pred_score >= threshold else "Unstable"
    result_text = f"Predicted Robustness Score: {pred_score}\nStability Category: {stability}"

    # The drawing depends only on the stability label.
    fig = draw_robotic_arm(stability)
    return sensor_values, result_text, fig
############################################
# Gradio Interface
############################################
# Build the UI. Components render in the order they are created inside each
# context manager: a heading and description, then one row holding an input
# column and an output column.
with gr.Blocks() as demo:
    gr.Markdown("## Robotic Grasp Predictor")
    # The 100 in this text must stay in sync with the stability threshold
    # used in predict_robustness_range.
    gr.Markdown(
        "Enter lower and upper bounds for random sensor values.\n"
        "The system will predict the grasp's robustness and categorize it as **Stable** "
        "(if the predicted score >= 100) or **Unstable** (if the predicted score < 100)."
    )
    with gr.Row():
        # Left column: inputs, the trigger button, and the generated values.
        with gr.Column():
            # These module-level component names are separate from the
            # same-named parameters of predict_robustness_range.
            lower_bound = gr.Number(label="Sensor Values Lower Bound", value=0.0)
            upper_bound = gr.Number(label="Sensor Values Upper Bound", value=1.0)
            button = gr.Button("Generate & Predict")
            sensor_json = gr.JSON(label="Generated Sensor Values")
        # Right column: prediction text and arm drawing.
        with gr.Column():
            prediction_box = gr.Textbox(label="Prediction Result")
            plot_output = gr.Plot(label="Arm Visualization")
    # On click, the two number values are passed positionally. The
    # 3-tuple returned by predict_robustness_range maps, in order, onto the
    # JSON, textbox, and plot outputs.
    button.click(
        fn=predict_robustness_range,
        inputs=[lower_bound, upper_bound],
        outputs=[sensor_json, prediction_box, plot_output]
    )
if __name__ == "__main__":
    # share=True additionally requests a temporary public Gradio share link.
    # NOTE(review): if this runs on Hugging Face Spaces (as the file's origin
    # suggests), the Space already serves the app publicly, so share=True is
    # likely unnecessary. Confirm before relying on it.
    demo.launch(share=True)