Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import mysql.connector | |
| import os | |
| # Use a pipeline as a high-level helper | |
| from transformers import pipeline | |
# Zero-shot classifier used to suggest message categories for incoming email.
classifier_model = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/deberta-v3-large-zeroshot-v1",
)

# Database settings are read from the environment (None when a var is unset).
db_host, db_user, db_pass, db_name = (
    os.environ.get(key) for key in ("DB_HOST", "DB_USER", "DB_PASS", "DB_NAME")
)

# Module-level connection and cursor, opened once when the module is loaded.
db_connection = mysql.connector.connect(
    host=db_host, user=db_user, password=db_pass, database=db_name
)
db_cursor = db_connection.cursor()

# Organisation id stored on the message rows written by save_data().
ORG_ID = 731
# Cached list of category names; refreshed by get_potential_labels().
potential_labels = []


def get_potential_labels():
    """Fetch every message category name from the database.

    Refreshes the module-level ``potential_labels`` cache as a side effect.

    Returns:
        list: The ``message_category_name`` value of each row in
        ``radmap_frog12.message_categorys``.
    """
    global potential_labels

    # The module-level connection is opened once at startup and may have
    # idled out since; ping() with reconnect=True re-establishes it if needed.
    db_connection.ping(reconnect=True, attempts=3, delay=1)

    # Use a short-lived cursor so a stale or half-read module-level cursor
    # can never interfere with this query.
    cursor = db_connection.cursor()
    try:
        cursor.execute(
            "SELECT message_category_name FROM radmap_frog12.message_categorys"
        )
        rows = cursor.fetchall()
    finally:
        cursor.close()

    # Connector/Python leaves autocommit off by default, so the SELECT opened
    # a transaction; end it so later calls read fresh data instead of the
    # same snapshot.
    db_connection.commit()

    potential_labels = [row[0] for row in rows]
    return potential_labels


# Load the labels once at startup so the UI can render the category library.
potential_labels = get_potential_labels()
# Function to handle the classification
def classify_email(constituent_email):
    """Suggest message categories for an incoming email.

    Runs the zero-shot classifier against the category names currently stored
    in the database (multi-label mode).

    Args:
        constituent_email: Text of the incoming message.

    Returns:
        str: Every label scoring above 0.95, joined with ``", "``. If no label
        clears that bar, the single highest-scoring label. An empty string
        when the message is blank or no categories are defined.
    """
    # Nothing to classify: skip the database round-trip and the model call.
    if not constituent_email or not constituent_email.strip():
        return ""

    candidate_labels = get_potential_labels()
    # The pipeline cannot classify without at least one candidate label.
    if not candidate_labels:
        return ""

    print("classifying email")
    model_out = classifier_model(
        constituent_email, candidate_labels, multi_label=True
    )
    print("classification complete")

    labels, scores = model_out["labels"], model_out["scores"]
    top_labels = [label for label, score in zip(labels, scores) if score > 0.95]
    if top_labels:
        return ", ".join(top_labels)

    # No confident match: fall back to the best-scoring label (first on ties).
    best_index = max(range(len(scores)), key=scores.__getitem__)
    return labels[best_index]
def remove_spaces_after_comma(s):
    """Trim whitespace around each comma-separated piece of *s*.

    The string is split on ``,``, every piece is stripped of leading and
    trailing whitespace, and the pieces are re-joined with a bare ``,``.
    """
    return ",".join(piece.strip() for piece in s.split(","))
# Person ids for the users offered in the "Current User" dropdown.
_PERSON_IDS = {
    "Sheryl Springer": 11021,
    "Diane Taylor": 11023,
    "Ann E. Belyea": 11025,
    "Marcelo Mejia": 11027,
    "Rishi Vasudeva": 11029,
}

# One INSERT shared by all messages of a thread; only org_id, person_id,
# body and previous_message_id vary between them.
_INSERT_MESSAGE_SQL = (
    "INSERT INTO radmap_frog12.messages (app_id, org_id, person_id, communication_method_id, status_id, subject, body, send_date, message_type, previous_message_id) VALUES (345678, %s, %s, 1, 1, 'Email Classification and Response Tracking', %s, NOW(), 'Email Classification and Response Tracking', %s)"
)


# Function to handle saving data
def save_data(orig_user_email, constituent_email, labels, user_response, current_user):
    """Persist one email thread and its category labels to the database.

    Up to three rows are inserted into ``radmap_frog12.messages``, chained
    through ``previous_message_id``:

    1. the user's original email (skipped when empty),
    2. the constituent's incoming message (stored with ``person_id`` 0),
    3. the user's response.

    For every label that exists in ``radmap_frog12.message_categorys``, a row
    linking it to the constituent message is inserted into
    ``radmap_frog12.message_category_associations``. Everything is committed
    as one transaction and rolled back on failure.

    Args:
        orig_user_email: Original email sent by the user; may be empty.
        constituent_email: The incoming message being categorised.
        labels: Comma-separated category names.
        user_response: The user's reply to the constituent.
        current_user: Display name selected in the UI; must be a known user.

    Returns:
        str: A status message for display in the UI.
    """
    # Validate before touching the database: an unselected or unknown user
    # has no person_id to record.
    person_id = _PERSON_IDS.get(current_user)
    if person_id is None:
        return "Please select a current user before saving"

    connection = None
    cursor = None
    try:
        # Each save uses its own connection, opened inside the try so that a
        # connect failure is reported like any other save error.
        connection = mysql.connector.connect(
            host=db_host,
            user=db_user,
            password=db_pass,
            database=db_name,
        )
        cursor = connection.cursor()

        # 0 marks "no previous message" when there is no original email.
        previous_id = 0
        if orig_user_email:
            cursor.execute(
                _INSERT_MESSAGE_SQL,
                (ORG_ID, person_id, orig_user_email, previous_id),
            )
            previous_id = cursor.lastrowid

        cursor.execute(
            _INSERT_MESSAGE_SQL,
            (ORG_ID, 0, constituent_email, previous_id),
        )
        # Remember this id: the response and the category links both refer to
        # exactly this row (looking it up by body could hit an older duplicate).
        constituent_message_id = cursor.lastrowid

        cursor.execute(
            _INSERT_MESSAGE_SQL,
            (ORG_ID, person_id, user_response, constituent_message_id),
        )

        for label in remove_spaces_after_comma(labels or "").split(","):
            if not label:
                continue  # stray comma or empty label box
            cursor.execute(
                "SELECT * FROM radmap_frog12.message_categorys WHERE message_category_name = %s",
                (label,),
            )
            category_rows = cursor.fetchall()
            # Labels that are not in the category table are silently skipped.
            if category_rows:
                cursor.execute(
                    "INSERT INTO radmap_frog12.message_category_associations (message_id, message_category_id) VALUES (%s, %s)",
                    # Assumes the first column of message_categorys is its id.
                    (constituent_message_id, category_rows[0][0]),
                )

        connection.commit()
        return "Response successfully saved to database"
    except Exception as e:
        # UI boundary: report any failure as a status string instead of
        # crashing the Gradio handler.
        print(e)
        if connection is not None and connection.is_connected():
            connection.rollback()
        return "Error saving data to database"
    finally:
        # Always release the per-call cursor and connection.
        if cursor is not None:
            cursor.close()
        if connection is not None:
            connection.close()
# Login credentials are read from the environment (None when a var is unset).
auth_username, auth_password = (
    os.environ.get(name) for name in ("AUTH_USERNAME", "AUTH_PASSWORD")
)

# Username/password pairs accepted by the app; currently a single pair.
auth = [(auth_username, auth_password)]
# Build the Gradio interface: a title row above a two-column layout.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    with gr.Row():
        gr.Markdown("## Campaign Messaging Assistant")

    with gr.Row():
        # Left column: who is working, the category library, and the
        # incoming conversation.
        with gr.Column():
            current_user = gr.Dropdown(
                label="Current User",
                choices=[
                    "Sheryl Springer",
                    "Ann E. Belyea",
                    "Marcelo Mejia",
                    "Rishi Vasudeva",
                    "Diane Taylor",
                ],
            )
            email_labels_input = gr.Markdown(
                "## Message Category Library\n ### " + ", ".join(potential_labels),
            )
            original_email_input = gr.TextArea(
                label="Your Original Email (if any)",
                placeholder="Enter the original email sent by you",
            )
            spacer1 = gr.Label(visible=False)
            constituent_response_input = gr.TextArea(
                label="Incoming Message (may be a response to original email)",
                placeholder="Enter the incoming message",
                lines=15,
            )
            classify_button = gr.Button("Process Message", variant="primary")

        # Right column: editable category suggestions, the reply, and the
        # save status.
        with gr.Column():
            classification_output = gr.TextArea(
                label="Suggested Message Categories (modify as needed). Separate categories with commas",
                lines=1,
                interactive=True,
            )
            spacer2 = gr.Label(visible=False)
            user_response_input = gr.TextArea(
                label="Suggested Response (modify as needed)",
                placeholder="Enter your response to the constituent",
                lines=25,
            )
            save_button = gr.Button("Save Response", variant="primary")
            save_output = gr.Label(label="Backend Response")

    # Wire the buttons to their handlers.
    classify_button.click(
        classify_email,
        constituent_response_input,
        classification_output,
    )

    # Order matches the positional parameters of save_data().
    save_inputs = [
        original_email_input,
        constituent_response_input,
        classification_output,
        user_response_input,
        current_user,
    ]
    save_button.click(save_data, save_inputs, save_output)

# Launch the app behind the configured login.
app.launch(auth=auth, debug=True)