geethareddy commited on
Commit
76f8605
·
verified ·
1 Parent(s): bfd0001

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +295 -83
app.py CHANGED
@@ -1,45 +1,150 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
- import datetime
 
 
5
 
6
- # Initialize model and tokenizer (preloading them for quicker response)
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  model_name = "distilgpt2"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(model_name)
10
 
11
- # Set pad_token_id to eos_token_id to avoid warnings
12
- tokenizer.pad_token = tokenizer.eos_token
13
- model.config.pad_token_id = tokenizer.eos_token_id
14
-
15
- # Define a more contextual prompt template
16
- PROMPT_TEMPLATE = """You are an AI coach for construction supervisors. Based on the following inputs, generate a daily checklist, focus suggestions, and a motivational quote. Format your response with clear labels as follows:
17
 
 
 
 
 
 
 
 
 
18
  Checklist:
19
  - {milestones_list}
20
-
21
  Suggestions:
22
  - {suggestions_list}
23
-
24
  Quote:
25
  - Your motivational quote here
26
-
27
- Inputs:
28
- Role: {role}
29
- Project: {project_id}
30
- Milestones: {milestones}
31
- Reflection: {reflection}
32
  """
33
 
34
- # Function to generate outputs based on inputs
35
- def generate_outputs(role, project_id, milestones, reflection):
36
- # Validate inputs to ensure no missing fields
37
- if not all([role, project_id, milestones, reflection]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  return "Error: All fields are required.", "", ""
39
-
40
- # Create prompt from template
41
  milestones_list = "\n- ".join([m.strip() for m in milestones.split(",")])
42
-
43
  suggestions_list = ""
44
  if "delays" in reflection.lower():
45
  suggestions_list = "- Consider adjusting timelines to accommodate delays.\n- Communicate delays to all relevant stakeholders."
@@ -48,7 +153,7 @@ def generate_outputs(role, project_id, milestones, reflection):
48
  elif "equipment" in reflection.lower():
49
  suggestions_list = "- Inspect all equipment to ensure no malfunctions.\n- Schedule maintenance if necessary."
50
 
51
- # Create final prompt
52
  prompt = PROMPT_TEMPLATE.format(
53
  role=role,
54
  project_id=project_id,
@@ -58,81 +163,188 @@ def generate_outputs(role, project_id, milestones, reflection):
58
  suggestions_list=suggestions_list
59
  )
60
 
61
- # Tokenize inputs for model processing
62
  inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True, padding=True)
63
-
64
- # Generate response from the model
65
- with torch.no_grad():
66
- outputs = model.generate(
67
- inputs['input_ids'],
68
- max_length=512,
69
- num_return_sequences=1,
70
- no_repeat_ngram_size=2,
71
- do_sample=True,
72
- top_p=0.9,
73
- temperature=0.8,
74
- pad_token_id=tokenizer.eos_token_id
75
- )
76
-
77
- # Decode the response
78
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
79
-
80
- # Parse the output and ensure it is structured
81
- checklist = "No checklist generated."
82
- suggestions = "No suggestions generated."
83
- quote = "No quote generated."
84
-
85
- if "Checklist:" in generated_text:
86
- checklist_start = generated_text.find("Checklist:") + len("Checklist:")
87
- suggestions_start = generated_text.find("Suggestions:")
88
- checklist = generated_text[checklist_start:suggestions_start].strip()
89
-
90
- if "Suggestions:" in generated_text:
91
- suggestions_start = generated_text.find("Suggestions:") + len("Suggestions:")
92
- quote_start = generated_text.find("Quote:")
93
- suggestions = generated_text[suggestions_start:quote_start].strip()
94
-
95
- if "Quote:" in generated_text:
96
- quote_start = generated_text.find("Quote:") + len("Quote:")
97
- quote = generated_text[quote_start:].strip()
98
-
99
- # Return structured outputs
100
  return checklist, suggestions, quote
101
 
102
- # Gradio interface for fast user interaction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def create_interface():
104
- with gr.Blocks() as demo:
105
- gr.Markdown("# Construction Supervisor AI Coach")
 
 
 
 
106
  gr.Markdown("Enter details to generate a daily checklist, focus suggestions, and a motivational quote.")
107
-
108
  with gr.Row():
109
- role = gr.Dropdown(choices=["Supervisor", "Foreman", "Project Manager"], label="Role")
110
- project_id = gr.Textbox(label="Project ID")
111
-
 
112
  milestones = gr.Textbox(label="Milestones (comma-separated KPIs)")
113
- reflection = gr.Textbox(label="Reflection Log", lines=5)
114
-
115
  with gr.Row():
116
- submit = gr.Button("Generate")
117
  clear = gr.Button("Clear")
118
-
119
- checklist_output = gr.Textbox(label="Daily Checklist")
120
- suggestions_output = gr.Textbox(label="Focus Suggestions")
121
- quote_output = gr.Textbox(label="Motivational Quote")
122
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  submit.click(
124
  fn=generate_outputs,
125
- inputs=[role, project_id, milestones, reflection],
126
  outputs=[checklist_output, suggestions_output, quote_output]
127
  )
 
128
  clear.click(
129
- fn=lambda: ("", "", "", ""),
130
  inputs=None,
131
- outputs=[role, project_id, milestones, reflection]
 
 
 
 
 
132
  )
133
-
134
  return demo
135
 
 
136
  if __name__ == "__main__":
137
- demo = create_interface()
138
- demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from simple_salesforce import Salesforce
5
+ import os
6
+ from dotenv import load_dotenv
7
 
8
+ # Load environment variables
9
+ load_dotenv()
10
+
11
+ # Check if required environment variables are set
12
+ required_env_vars = ['SF_USERNAME', 'SF_PASSWORD', 'SF_SECURITY_TOKEN']
13
+ missing_vars = [var for var in required_env_vars if not os.getenv(var)]
14
+ if missing_vars:
15
+ raise EnvironmentError(f"Missing required environment variables: {missing_vars}")
16
+
17
+ # Get configurable values for KPI_Flag__c and Engagement_Score__c
18
+ KPI_FLAG_DEFAULT = os.getenv('KPI_FLAG', 'True') == 'True' # Default to True if not set
19
+ ENGAGEMENT_SCORE_DEFAULT = float(os.getenv('ENGAGEMENT_SCORE', '85.0')) # Default to 85.0
20
+
21
+ # Initialize model and tokenizer
22
  model_name = "distilgpt2"
23
  tokenizer = AutoTokenizer.from_pretrained(model_name)
24
  model = AutoModelForCausalLM.from_pretrained(model_name)
25
 
26
+ # Avoid warnings by setting pad token
27
+ if tokenizer.pad_token is None:
28
+ tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "[PAD]"
29
+ tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
30
+ model.config.pad_token_id = tokenizer.pad_token_id
 
31
 
32
+ # Prompt template for generating structured output
33
+ PROMPT_TEMPLATE = """You are an AI coach for construction supervisors. Based on the following inputs, generate a daily checklist, focus suggestions, and a motivational quote.
34
+ Inputs:
35
+ Role: {role}
36
+ Project: {project_id}
37
+ Milestones: {milestones}
38
+ Reflection: {reflection}
39
+ Format your response clearly like this:
40
  Checklist:
41
  - {milestones_list}
 
42
  Suggestions:
43
  - {suggestions_list}
 
44
  Quote:
45
  - Your motivational quote here
 
 
 
 
 
 
46
  """
47
 
48
+ # Function to get all roles from Salesforce
49
+ def get_roles_from_salesforce():
50
+ try:
51
+ sf = Salesforce(
52
+ username=os.getenv('SF_USERNAME'),
53
+ password=os.getenv('SF_PASSWORD'),
54
+ security_token=os.getenv('SF_SECURITY_TOKEN'),
55
+ domain=os.getenv('SF_DOMAIN', 'login')
56
+ )
57
+
58
+ # Query distinct Role__c values
59
+ result = sf.query("SELECT Role__c FROM Supervisor__c WHERE Role__c != NULL")
60
+
61
+ # Extract roles and remove duplicates
62
+ roles = list(set(record['Role__c'] for record in result.get('records', [])))
63
+
64
+ print(f"✅ Fetched {len(roles)} unique roles from Salesforce")
65
+ return roles
66
+
67
+ except Exception as e:
68
+ print(f"⚠️ Error fetching roles from Salesforce: {e}")
69
+ print("Using fallback roles...")
70
+ return ["Site Manager", "Safety Officer", "Project Lead"] # Match actual active roles
71
+
72
+
73
+ # Function to get supervisor's Name (Auto Number) by role
74
+ def get_supervisor_name_by_role(role):
75
+ try:
76
+ sf = Salesforce(
77
+ username=os.getenv('SF_USERNAME'),
78
+ password=os.getenv('SF_PASSWORD'),
79
+ security_token=os.getenv('SF_SECURITY_TOKEN'),
80
+ domain=os.getenv('SF_DOMAIN', 'login')
81
+ )
82
+
83
+ # Escape single quotes in the role to prevent SOQL injection
84
+ role = role.replace("'", "\\'")
85
+
86
+ # Query all supervisors for the selected role
87
+ result = sf.query(f"SELECT Name FROM Supervisor__c WHERE Role__c = '{role}'")
88
+ if result['totalSize'] == 0:
89
+ print("❌ No matching supervisors found.")
90
+ return []
91
+
92
+ # Extract all supervisor names
93
+ supervisor_names = [record['Name'] for record in result['records']]
94
+ print(f"✅ Found supervisors: {supervisor_names} for role: {role}")
95
+ return supervisor_names
96
+
97
+ except Exception as e:
98
+ print(f"⚠️ Error fetching supervisor names: {e}")
99
+ return []
100
+
101
+
102
+ # Function to get project IDs and names assigned to selected supervisor
103
+ def get_projects_for_supervisor(supervisor_name):
104
+ try:
105
+ # Use the selected supervisor name to fetch the associated project
106
+ sf = Salesforce(
107
+ username=os.getenv('SF_USERNAME'),
108
+ password=os.getenv('SF_PASSWORD'),
109
+ security_token=os.getenv('SF_SECURITY_TOKEN'),
110
+ domain=os.getenv('SF_DOMAIN', 'login')
111
+ )
112
+
113
+ # Escape single quotes in the supervisor_name
114
+ supervisor_name = supervisor_name.replace("'", "\\'")
115
+
116
+ # Step 1: Get the Salesforce record ID of the supervisor based on the Name
117
+ supervisor_result = sf.query(f"SELECT Id FROM Supervisor__c WHERE Name = '{supervisor_name}' LIMIT 1")
118
+ if supervisor_result['totalSize'] == 0:
119
+ print("❌ No supervisor found with the given name.")
120
+ return ""
121
+
122
+ supervisor_id = supervisor_result['records'][0]['Id']
123
+
124
+ # Step 2: Query Project__c records where Supervisor_ID__c matches the supervisor's record ID
125
+ project_result = sf.query(f"SELECT Name FROM Project__c WHERE Supervisor_ID__c = '{supervisor_id}' LIMIT 1")
126
+
127
+ if project_result['totalSize'] == 0:
128
+ print("❌ No project found for supervisor.")
129
+ return ""
130
+
131
+ project_name = project_result['records'][0]['Name']
132
+ print(f"✅ Found project: {project_name} for supervisor: {supervisor_name}")
133
+ return project_name
134
+
135
+ except Exception as e:
136
+ print(f"⚠️ Error fetching project for supervisor: {e}")
137
+ return ""
138
+
139
+
140
+ # Function to generate AI-based coaching output
141
+ def generate_outputs(role, supervisor_name, project_id, milestones, reflection):
142
+ if not all([role, supervisor_name, project_id, milestones, reflection]):
143
  return "Error: All fields are required.", "", ""
144
+
145
+ # Format the prompt
146
  milestones_list = "\n- ".join([m.strip() for m in milestones.split(",")])
147
+
148
  suggestions_list = ""
149
  if "delays" in reflection.lower():
150
  suggestions_list = "- Consider adjusting timelines to accommodate delays.\n- Communicate delays to all relevant stakeholders."
 
153
  elif "equipment" in reflection.lower():
154
  suggestions_list = "- Inspect all equipment to ensure no malfunctions.\n- Schedule maintenance if necessary."
155
 
156
+ # Fill in the prompt template
157
  prompt = PROMPT_TEMPLATE.format(
158
  role=role,
159
  project_id=project_id,
 
163
  suggestions_list=suggestions_list
164
  )
165
 
166
+ # Tokenize input
167
  inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True, padding=True)
168
+
169
+ # Generate response
170
+ try:
171
+ with torch.no_grad():
172
+ outputs = model.generate(
173
+ inputs['input_ids'],
174
+ max_length=1024, # Increased to allow for longer outputs
175
+ num_return_sequences=1,
176
+ no_repeat_ngram_size=2,
177
+ do_sample=True,
178
+ top_p=0.9,
179
+ temperature=0.8,
180
+ pad_token_id=tokenizer.pad_token_id
181
+ )
182
+
183
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
184
+
185
+ except Exception as e:
186
+ print(f"⚠️ Error during model generation: {e}")
187
+ return "Error: Failed to generate outputs.", "", ""
188
+
189
+ # Parse sections
190
+ def extract_section(text, start_marker, end_marker):
191
+ start = text.find(start_marker)
192
+ if start == -1:
193
+ return "Not found"
194
+ start += len(start_marker)
195
+ end = text.find(end_marker, start) if end_marker else len(text)
196
+ return text[start:end].strip()
197
+
198
+ checklist = extract_section(generated_text, "Checklist:\n", "Suggestions:")
199
+ suggestions = extract_section(generated_text, "Suggestions:\n", "Quote:")
200
+ quote = extract_section(generated_text, "Quote:\n", None)
201
+
202
+ # Save to Salesforce
203
+ save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, quote, supervisor_name)
204
+
205
  return checklist, suggestions, quote
206
 
207
+
208
+ # Function to check if a field exists in a Salesforce object
209
+ def field_exists(sf, object_name, field_name):
210
+ try:
211
+ # Describe the object to get its fields
212
+ obj_desc = getattr(sf, object_name).describe()
213
+ fields = [field['name'] for field in obj_desc['fields']]
214
+ return field_name in fields
215
+ except Exception as e:
216
+ print(f"⚠️ Error checking if field {field_name} exists in {object_name}: {e}")
217
+ return False
218
+
219
+
220
+ # Function to create a record in Salesforce
221
+ def save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, quote, supervisor_name):
222
+ try:
223
+ sf = Salesforce(
224
+ username=os.getenv('SF_USERNAME'),
225
+ password=os.getenv('SF_PASSWORD'),
226
+ security_token=os.getenv('SF_SECURITY_TOKEN'),
227
+ domain=os.getenv('SF_DOMAIN', 'login')
228
+ )
229
+
230
+ # Escape single quotes in supervisor_name and project_id
231
+ supervisor_name = supervisor_name.replace("'", "\\'")
232
+ project_id = project_id.replace("'", "\\'")
233
+
234
+ # Step 1: Get the Salesforce record ID for the supervisor
235
+ supervisor_result = sf.query(f"SELECT Id FROM Supervisor__c WHERE Name = '{supervisor_name}' LIMIT 1")
236
+ if supervisor_result['totalSize'] == 0:
237
+ print(f"❌ No supervisor found with Name: {supervisor_name}")
238
+ return
239
+
240
+ supervisor_id = supervisor_result['records'][0]['Id']
241
+
242
+ # Step 2: Get the Salesforce record ID for the project
243
+ project_result = sf.query(f"SELECT Id FROM Project__c WHERE Name = '{project_id}' LIMIT 1")
244
+ if project_result['totalSize'] == 0:
245
+ print(f"❌ No project found with Name: {project_id}")
246
+ return
247
+
248
+ project_record_id = project_result['records'][0]['Id']
249
+
250
+ # Truncate text fields to avoid exceeding Salesforce field length limits (assuming 255 characters for simplicity)
251
+ MAX_TEXT_LENGTH = 255
252
+ checklist = checklist[:MAX_TEXT_LENGTH] if checklist else ""
253
+ suggestions = suggestions[:MAX_TEXT_LENGTH] if suggestions else ""
254
+ reflection = reflection[:MAX_TEXT_LENGTH] if reflection else ""
255
+
256
+ # Prepare data for Salesforce with explicit mapping
257
+ data = {
258
+ 'Supervisor_ID__c': supervisor_id, # Lookup field expects the record ID of Supervisor__c
259
+ 'Project_ID__c': project_record_id, # Lookup field expects the record ID of Project__c
260
+ 'Daily_Checklist__c': checklist, # Maps to the generated Daily Checklist
261
+ 'Suggested_Tips__c': suggestions, # Maps to the generated Focus Suggestions
262
+ 'Reflection_Log__c': reflection, # Maps to the Reflection Log input
263
+ 'KPI_Flag__c': KPI_FLAG_DEFAULT, # Configurable via .env
264
+ 'Engagement_Score__c': ENGAGEMENT_SCORE_DEFAULT # Configurable via .env
265
+ }
266
+
267
+ # Check if Milestones_KPIs__c field exists before mapping
268
+ if field_exists(sf, 'Supervisor_AI_Coaching__c', 'Milestones_KPIs__c'):
269
+ # Truncate milestones as well if the field exists
270
+ milestones = milestones[:MAX_TEXT_LENGTH] if milestones else ""
271
+ data['Milestones_KPIs__c'] = milestones
272
+ else:
273
+ print("⚠️ Milestones_KPIs__c field does not exist in Supervisor_AI_Coaching__c. Skipping mapping.")
274
+
275
+ # Create record
276
+ response = sf.Supervisor_AI_Coaching__c.create(data)
277
+ print("✅ Record created successfully in Salesforce.")
278
+ print("Record ID:", response['id'])
279
+
280
+ except Exception as e:
281
+ print(f"❌ Error saving to Salesforce: {e}")
282
+ print("Data being sent:", data)
283
+ if hasattr(e, 'content'):
284
+ print("Salesforce API response:", e.content)
285
+
286
+
287
+ # Gradio Interface
288
  def create_interface():
289
+ # Fetch roles from Salesforce
290
+ roles = get_roles_from_salesforce()
291
+ print(f"Fetched Roles: {roles}")
292
+
293
+ with gr.Blocks(theme="soft") as demo:
294
+ gr.Markdown("# 🏗️ Construction Supervisor AI Coach")
295
  gr.Markdown("Enter details to generate a daily checklist, focus suggestions, and a motivational quote.")
296
+
297
  with gr.Row():
298
+ role = gr.Dropdown(choices=roles, label="Role")
299
+ supervisor_name = gr.Dropdown(choices=[], label="Supervisor Name")
300
+ project_id = gr.Textbox(label="Project ID", interactive=False)
301
+
302
  milestones = gr.Textbox(label="Milestones (comma-separated KPIs)")
303
+ reflection = gr.Textbox(label="Reflection Log", lines=4)
304
+
305
  with gr.Row():
306
+ submit = gr.Button("Generate", variant="primary")
307
  clear = gr.Button("Clear")
308
+ refresh_btn = gr.Button("🔄 Refresh Roles")
309
+
310
+ checklist_output = gr.Textbox(label=" Daily Checklist")
311
+ suggestions_output = gr.Textbox(label="💡 Focus Suggestions")
312
+ quote_output = gr.Textbox(label="✨ Motivational Quote")
313
+
314
+ # Event: When role changes, update supervisor name dropdown
315
+ role.change(
316
+ fn=lambda r: gr.update(choices=get_supervisor_name_by_role(r)),
317
+ inputs=[role],
318
+ outputs=[supervisor_name]
319
+ )
320
+
321
+ # Event: When supervisor name changes, update project ID
322
+ supervisor_name.change(
323
+ fn=get_projects_for_supervisor,
324
+ inputs=[supervisor_name],
325
+ outputs=[project_id]
326
+ )
327
+
328
  submit.click(
329
  fn=generate_outputs,
330
+ inputs=[role, supervisor_name, project_id, milestones, reflection],
331
  outputs=[checklist_output, suggestions_output, quote_output]
332
  )
333
+
334
  clear.click(
335
+ fn=lambda: ("", "", "", "", ""),
336
  inputs=None,
337
+ outputs=[role, supervisor_name, project_id, milestones, reflection]
338
+ )
339
+
340
+ refresh_btn.click(
341
+ fn=lambda: gr.update(choices=get_roles_from_salesforce()),
342
+ outputs=role
343
  )
344
+
345
  return demo
346
 
347
+
348
  if __name__ == "__main__":
349
+ app = create_interface()
350
+ app.launch()