import streamlit as st
import pandas as pd
from datasets import load_dataset, Dataset
from random import sample
from utils.pairwise_comparison import one_regard_computation
import matplotlib.pyplot as plt
import os

# Set up the Streamlit interface
st.title('Gender Bias Analysis in Text Generation')


def check_password():
    """Gate the demo behind a password read from the PASSWORD environment variable."""

    def password_entered():
        # Compare the submitted password with the PASSWORD environment variable.
        if password_input == os.getenv('PASSWORD'):
            st.session_state['password_correct'] = True
        else:
            st.error("Incorrect Password, please try again.")

    password_input = st.text_input("Enter Password:", type="password")
    submit_button = st.button("Submit", on_click=password_entered)

    if submit_button and not st.session_state.get('password_correct', False):
        st.error("Please enter a valid password to access the demo.")


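# Show the main demo only after the password has been verified in this session.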
if not st.session_state.get('password_correct', False):
    check_password()
else:
    st.sidebar.success("Password Verified. Proceed with the demo.")

    if 'data_size' not in st.session_state:
        st.session_state['data_size'] = 10
    if 'bold' not in st.session_state:
        # Load BOLD once and expand it so every (prompt, wikipedia) pair gets its own row,
        # keeping the domain, name and category of the original entry.
        bold_raw = pd.DataFrame(load_dataset("AlexaAI/bold", split="train"))
        expanded_rows = []
        for _, row in bold_raw.iterrows():
            for bold_prompt, bold_wikipedia in zip(row['prompts'], row['wikipedia']):
                expanded_rows.append(
                    {'domain': row['domain'], 'name': row['name'], 'category': row['category'],
                     'prompts': bold_prompt, 'wikipedia': bold_wikipedia})
        st.session_state['bold'] = Dataset.from_pandas(pd.DataFrame(expanded_rows))

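    # Offer the BOLD domains available in the expanded dataset.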
    domain = st.selectbox(
        "Select the domain",
        pd.DataFrame(st.session_state['bold'])['domain'].unique())
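    # Keep only the BOLD entries that belong to the selected domain.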
    domain_limited = [p for p in st.session_state['bold'] if p['domain'] == domain]

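    # Slider controlling how many prompts are sampled from each category.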
    st.session_state['sample_size'] = st.slider('Select number of samples per category:', min_value=1, max_value=50,
                                                value=st.session_state['data_size'])

    if st.button('Compute'):
        # For every demographic category in the selected domain, compute the regard
        # scores and keep the net (positive - negative) regard with and without
        # the Wikipedia reference adjustment.
        answer_dict = {}
        category_list = pd.DataFrame(domain_limited)['category'].unique().tolist()
        ref_list = {}
        no_ref_list = {}
        for o_one in category_list:
            with st.spinner(f'Computing regard results for {o_one.replace("_", " ")}'):
                st.session_state['rmr'] = one_regard_computation(o_one, st.session_state['bold'],
                                                                 st.session_state['sample_size'])
                answer_dict[o_one] = st.session_state['rmr']
                st.write(f'Regard results for {o_one.replace("_", " ")} computed successfully.')
                ref_list[o_one] = st.session_state['rmr']['ref_diff_mean']['positive'] \
                                  - st.session_state['rmr']['ref_diff_mean']['negative']
                no_ref_list[o_one] = st.session_state['rmr']['no_ref_diff_mean']['positive'] \
                                     - st.session_state['rmr']['no_ref_diff_mean']['negative']

                # Plotting: stacked bars comparing the regard distribution of the GPT2
                # generations against the Wikipedia reference continuations.
                categories = ['GPT2', 'Wiki']
                mp_gpt = st.session_state['rmr']['no_ref_diff_mean']['positive']
                mn_gpt = st.session_state['rmr']['no_ref_diff_mean']['negative']
                mo_gpt = 1 - (mp_gpt + mn_gpt)

                # The Wikipedia proportions are recovered by subtracting the
                # reference-adjusted differences from the GPT2 proportions.
                mp_wiki = mp_gpt - st.session_state['rmr']['ref_diff_mean']['positive']
                mn_wiki = mn_gpt - st.session_state['rmr']['ref_diff_mean']['negative']
                mo_wiki = 1 - (mn_wiki + mp_wiki)

                positive_m = [mp_gpt, mp_wiki]
                other_m = [mo_gpt, mo_wiki]
                negative_m = [mn_gpt, mn_wiki]

                # Stacked bar chart: negative at the bottom, 'other' in the middle, positive on top.
                fig_a, ax_a = plt.subplots()
                ax_a.bar(categories, negative_m, label='Negative', color='blue')
                ax_a.bar(categories, other_m, bottom=negative_m, label='Other', color='orange')
                ax_a.bar(categories, positive_m, bottom=[negative_m[i] + other_m[i] for i in range(len(negative_m))],
                         label='Positive', color='green')

                ax_a.set_ylabel('Proportion')
                ax_a.set_title(f'GPT2 vs Wiki on {o_one.replace("_", " ")} regard')
                ax_a.legend()

                st.pyplot(fig_a)


        # Summarise the per-category net regard scores across the whole domain.
        st.subheader(f'The comparison of absolute regard value in {domain.replace("_", " ")} by GPT2')
        st.bar_chart(no_ref_list)
        st.write(f'***Max difference of absolute regard values in the {domain.replace("_", " ")}:***')
        key_with_max_no_ref = max(no_ref_list, key=no_ref_list.get)
        key_with_min_no_ref = min(no_ref_list, key=no_ref_list.get)
        st.write(f'     {key_with_max_no_ref.replace("_", " ")} regard - {key_with_min_no_ref.replace("_", " ")} regard = '
                 f'{max(no_ref_list.values()) - min(no_ref_list.values())}')

        st.subheader(f'The comparison of regard value in {domain.replace("_", " ")} with references to Wikipedia by GPT2')
        st.bar_chart(ref_list)
        st.write(f'***Max difference of regard values in the {domain.replace("_", " ")} with references to Wikipedia:***')
        key_with_max_ref = max(ref_list, key=ref_list.get)
        key_with_min_ref = min(ref_list, key=ref_list.get)
        st.write(f'     {key_with_max_ref.replace("_", " ")} regard - {key_with_min_ref.replace("_", " ")} regard = '
                 f'{max(ref_list.values()) - min(ref_list.values())}')