Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,16 +1,16 @@
|
|
| 1 |
import streamlit as st
|
|
|
|
| 2 |
import requests
|
| 3 |
import pandas as pd
|
| 4 |
import json
|
|
|
|
| 5 |
|
| 6 |
# Function to create the payload for the model
|
| 7 |
def create_tf_serving_json(data):
|
| 8 |
return {'inputs': {name: data[name].tolist() for name in data.keys()} if isinstance(data, dict) else data.tolist()}
|
| 9 |
|
| 10 |
# Function to send a request to the model endpoint
|
| 11 |
-
def score_model(dataset):
|
| 12 |
-
url = 'Add_Model_URL'
|
| 13 |
-
token = 'Add_Databricks_Token'
|
| 14 |
headers = {'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'}
|
| 15 |
ds_dict = {'dataframe_split': dataset.to_dict(orient='split')} if isinstance(dataset, pd.DataFrame) else create_tf_serving_json(dataset)
|
| 16 |
data_json = json.dumps(ds_dict, allow_nan=True)
|
|
@@ -19,6 +19,18 @@ def score_model(dataset):
|
|
| 19 |
raise Exception(f'Request failed with status {response.status_code}, {response.text}')
|
| 20 |
return response.json()
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
# Streamlit app UI
|
| 23 |
st.title('Employee Churn Prediction')
|
| 24 |
|
|
@@ -37,7 +49,7 @@ with st.form(key='churn_form'):
|
|
| 37 |
salaryVec_0, salaryVec_1 = 0.0, 0.0
|
| 38 |
if salary == 'Low':
|
| 39 |
salaryVec_0 = 1.0
|
| 40 |
-
elif salary
|
| 41 |
salaryVec_1 = 1.0
|
| 42 |
|
| 43 |
# Submit button
|
|
@@ -58,10 +70,20 @@ if submit_button:
|
|
| 58 |
'salaryVec_1': [salaryVec_1]
|
| 59 |
})
|
| 60 |
|
| 61 |
-
#
|
|
|
|
|
|
|
|
|
|
| 62 |
try:
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
if churn_prediction == 1:
|
| 66 |
st.write('The employee is likely to churn.')
|
| 67 |
else:
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
import os
|
| 3 |
import requests
|
| 4 |
import pandas as pd
|
| 5 |
import json
|
| 6 |
+
import joblib # Import joblib to load the local model
|
| 7 |
|
| 8 |
# Function to create the payload for the model
|
| 9 |
def create_tf_serving_json(data):
|
| 10 |
return {'inputs': {name: data[name].tolist() for name in data.keys()} if isinstance(data, dict) else data.tolist()}
|
| 11 |
|
| 12 |
# Function to send a request to the model endpoint
|
| 13 |
+
def score_model(dataset, url, token):
|
|
|
|
|
|
|
| 14 |
headers = {'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'}
|
| 15 |
ds_dict = {'dataframe_split': dataset.to_dict(orient='split')} if isinstance(dataset, pd.DataFrame) else create_tf_serving_json(dataset)
|
| 16 |
data_json = json.dumps(ds_dict, allow_nan=True)
|
|
|
|
| 19 |
raise Exception(f'Request failed with status {response.status_code}, {response.text}')
|
| 20 |
return response.json()
|
| 21 |
|
| 22 |
+
# Load local model function
|
| 23 |
+
def load_local_model():
|
| 24 |
+
return joblib.load('Random Forest - Grid Search_model.pkl')
|
| 25 |
+
|
| 26 |
+
# Function to make predictions with local model
|
| 27 |
+
def predict_with_local_model(model, input_data):
|
| 28 |
+
input_data = input_data[['satisfaction_level', 'last_evaluation', 'number_project',
|
| 29 |
+
'average_montly_hours', 'time_spend_company', 'Work_accident',
|
| 30 |
+
'promotion_last_5years', 'salaryVec_0', 'salaryVec_1']]
|
| 31 |
+
prediction = model.predict(input_data)
|
| 32 |
+
return prediction[0]
|
| 33 |
+
|
| 34 |
# Streamlit app UI
|
| 35 |
st.title('Employee Churn Prediction')
|
| 36 |
|
|
|
|
| 49 |
salaryVec_0, salaryVec_1 = 0.0, 0.0
|
| 50 |
if salary == 'Low':
|
| 51 |
salaryVec_0 = 1.0
|
| 52 |
+
elif salary is 'Medium':
|
| 53 |
salaryVec_1 = 1.0
|
| 54 |
|
| 55 |
# Submit button
|
|
|
|
| 70 |
'salaryVec_1': [salaryVec_1]
|
| 71 |
})
|
| 72 |
|
| 73 |
+
# Check if URL and token are specified
|
| 74 |
+
url = 'Add_Model_URL'
|
| 75 |
+
token = 'Add_Databricks_Token'
|
| 76 |
+
|
| 77 |
try:
|
| 78 |
+
if url != 'Add_Model_URL' and token != 'Add_Databricks_Token':
|
| 79 |
+
# Use the API if URL and token are specified
|
| 80 |
+
prediction = score_model(input_data, url, token)
|
| 81 |
+
churn_prediction = prediction['predictions'][0]
|
| 82 |
+
else:
|
| 83 |
+
# Use the local model if URL and token are not specified
|
| 84 |
+
local_model = load_local_model()
|
| 85 |
+
churn_prediction = predict_with_local_model(local_model, input_data)
|
| 86 |
+
|
| 87 |
if churn_prediction == 1:
|
| 88 |
st.write('The employee is likely to churn.')
|
| 89 |
else:
|