yiiilonggg committed on
Commit b151e60 · 1 Parent(s): 251461c

Init files

.gitignore ADDED
@@ -0,0 +1,197 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Abstra
171
+ # Abstra is an AI-powered process automation framework.
172
+ # Ignore directories containing user credentials, local state, and settings.
173
+ # Learn more at https://abstra.io/docs
174
+ .abstra/
175
+
176
+ # Visual Studio Code
177
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
178
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
179
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
180
+ # you could uncomment the following to ignore the entire vscode folder
181
+ # .vscode/
182
+
183
+ # Ruff stuff:
184
+ .ruff_cache/
185
+
186
+ # PyPI configuration file
187
+ .pypirc
188
+
189
+ # Cursor
190
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
191
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
192
+ # refer to https://docs.cursor.com/context/ignore-files
193
+ .cursorignore
194
+ .cursorindexingignore
195
+
196
+ # ignore session dirs
197
+ sessions/*
app.py ADDED
@@ -0,0 +1,23 @@
1
+
2
+ import gradio as gr
3
+
4
+ from src.frontend import (
5
+ chatbot,
6
+ continue_story_screen,
7
+ progress_screen,
8
+ sidebar,
9
+ start_screen,
10
+ story_information_widgets
11
+ )
12
+
13
+ demo = gr.Blocks()
14
+
15
+ with demo:
16
+ start_screen.render()
17
+ story_information_widgets.render()
18
+ continue_story_screen.render()
19
+ sidebar.render()
20
+ chatbot.render()
21
+ progress_screen.render()
22
+
23
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ openai
2
+ langgraph
3
+ gradio
4
+ igraph
5
+ plotly
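+ # note: pandas (imported by src/frontend/progress_screen.py) is assumed to be installed transitively via gradio; add it here explicitly if that assumption does not hold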
src/agents/base_agent.py ADDED
@@ -0,0 +1,40 @@
1
+
2
+ from abc import ABC
3
+ from typing import Dict, List, Literal
4
+
5
+ from ..utils.init_openai import completions_with_backoff
6
+
7
+
8
+ class BaseAgent(ABC):
9
+
10
+ def format_openai_message(
11
+ self,
12
+ system_message: str,
13
+ user_message: str
14
+ ) -> List[Dict[str, str]]:
15
+ return [
16
+ {
17
+ 'role': 'system',
18
+ 'content': system_message
19
+ },
20
+ {
21
+ 'role': 'user',
22
+ 'content': user_message
23
+ }
24
+ ]
25
+
26
+ def execute(
27
+ self,
28
+ model: Literal['openai'],
29
+ messages: List[Dict[str, str]],
30
+ temperature: float = 0,
31
+ max_tokens: int = 8192
32
+ ) -> str:
33
+ if model == 'openai':
34
+
35
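+ # completions_with_backoff comes from src/utils/init_openai.py and is assumed to wrap the OpenAI chat-completions call with retry/backoff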
+ response = completions_with_backoff(messages, temperature, max_tokens)
36
+ content = response.choices[0].message.content
37
+
38
+ return content
39
+ else:
40
+ raise ValueError(f'Model type {model} not supported (yet)')
src/agents/category_builder_agent.py ADDED
@@ -0,0 +1,46 @@
1
+
2
+ import json
3
+
4
+ from typing import List
5
+
6
+ from .base_agent import BaseAgent
7
+ from ..dataclasses.category import Category
8
+ from ..prompts.category_builder_agent import system_prompt_template, user_prompt_template
9
+
10
+
11
+ class CategoryBuilderAgent(BaseAgent):
12
+
13
+ def __init__(self):
14
+ self.model = 'openai'
15
+ self.temperature = 0.5
16
+
17
+ def get_categories(
18
+ self,
19
+ story_context: str,
20
+ category_context: str,
21
+ num_categories: int
22
+ ) -> List[Category]:
23
+ prompt_arguments = {
24
+ 'story_context': story_context,
25
+ 'category_context': category_context,
26
+ 'num_categories': num_categories
27
+ }
28
+
29
+ system_message = system_prompt_template.format(**prompt_arguments)
30
+ user_message = user_prompt_template.format(**prompt_arguments)
31
+
32
+ messages = self.format_openai_message(system_message, user_message)
33
+
34
+ response = self.execute(self.model, messages, self.temperature)
35
+
36
+ # remove json tag from response
37
+ response = response.removeprefix('```json').removeprefix('```')
38
+ response = response.removesuffix('```json').removesuffix('```')
39
+
40
+ categories = json.loads(response)
41
+ assert isinstance(categories, list)
42
+
43
+ return [
44
+ Category(**category)
45
+ for category in categories
46
+ ]
src/agents/coordinator.py ADDED
@@ -0,0 +1,190 @@
1
+
2
+ import json
3
+ import os
4
+
5
+ from langgraph.graph import StateGraph, START, END
6
+ from langgraph.graph.graph import CompiledGraph
7
+ from typing import Dict, List, Literal, Tuple
8
+
9
+ from src.agents.category_builder_agent import CategoryBuilderAgent
10
+ from src.agents.evaluator_agent import EvaluatorAgent
11
+ from src.agents.narrator_agent import NarratorAgent
12
+ from src.agents.orchestrator_agent import OrchestratorAgent
13
+ from src.agents.storybuilder_agent import StorybuilderAgent
14
+ from src.dataclasses.category import Category
15
+ from src.dataclasses.process_state import ProcessState
16
+ from src.dataclasses.question import Question
17
+ from src.utils.utils import get_session_dir, save_information, load_information
18
+
19
+ num_questions_: int = None
20
+ num_options_: int = None
21
+ story_name: str = None
22
+ states: Dict[int, Dict[str, str]] = dict() # maps question_num to {state_name: scenario text}
23
+ questions: Dict[str, Question] = dict() # maps uuid to Question
24
+ state_question_map: Dict[int, Dict[str, str]] = dict() # maps question_num to {state_name to question_uuid}
25
+ categories: List[Category] = [] # list of categories
26
+ categories_seen: Dict[str, bool] = dict() # maps category name to whether that category has already been seen
27
+ path_evaluations: Dict[str, List[str]] = dict() # maps a path to an evaluation, including the reason
28
+
29
+
30
+ def init_categories(
31
+ story_context: str,
32
+ category_context: str,
33
+ num_categories: int
34
+ ) -> None:
35
+ global categories, categories_seen
36
+
37
+ category_builder_agent = CategoryBuilderAgent()
38
+ categories = category_builder_agent.get_categories(
39
+ story_context, category_context, num_categories
40
+ )
41
+ categories_seen = {
42
+ category.name: False
43
+ for category in categories
44
+ }
45
+
46
+
47
+ def user_node(state: ProcessState) -> ProcessState:
48
+ user_choice = state['user_choice']
49
+ question_uuid = state['foll_question_uuid']
50
+ question_num = state['question_num'] + 1
51
+
52
+ if user_choice:
53
+ foll_state_name = questions[question_uuid].options[user_choice]
54
+ foll_user_choices = [choice for choice in state['previous_user_choices']]
55
+ foll_user_choices.append(user_choice)
56
+ else:
57
+ foll_state_name = 'start'
58
+ foll_user_choices = []
59
+
60
+ return {
61
+ 'question_num': question_num,
62
+ 'state_name': foll_state_name,
63
+ 'user_choice': user_choice,
64
+ 'previous_state_names': state['previous_state_names'],
65
+ 'previous_questions': state['previous_questions'],
66
+ 'previous_user_choices': foll_user_choices,
67
+ 'builder_instruction': '',
68
+ 'foll_question_uuid': ''
69
+ }
70
+
71
+
72
+ def init_graph(
73
+ story_context: str,
74
+ category_context: str,
75
+ num_questions: int,
76
+ num_options: int,
77
+ num_categories: int
78
+ ) -> CompiledGraph:
79
+
80
+ global num_questions_, num_options_
81
+ num_questions_ = num_questions
82
+ num_options_ = num_options
83
+
84
+ orchestrator_agent = OrchestratorAgent(
85
+ story_context, category_context, num_questions
86
+ )
87
+ storybuilder_agent = StorybuilderAgent(
88
+ story_context, category_context, num_questions, num_options
89
+ )
90
+ narrator_agent = NarratorAgent(
91
+ story_context, category_context, num_questions, num_options
92
+ )
93
+
94
+ init_categories(story_context, category_context, num_categories)
95
+
96
+ builder = StateGraph(ProcessState)
97
+
98
+ builder.add_node('orchestrator', orchestrator_agent.run_agent)
99
+ builder.add_node('storybuilder', storybuilder_agent.run_agent)
100
+ builder.add_node('narrator', narrator_agent.run_agent)
101
+ builder.add_node('user', user_node)
102
+
103
+ def conditional_edge(state: ProcessState) -> Literal['storybuilder', END]:
104
+ question_num = state['question_num']
105
+ state_name = state['state_name']
106
+
107
+ if state_name not in states.get(question_num, dict()) or \
108
+ states[question_num][state_name] == 'Unexplored!':
109
+ return 'storybuilder'
110
+ else:
111
+ return END
112
+
113
+ def last_action_edge(state: ProcessState) -> Literal['orchestrator', END]:
114
+ question_num = state['question_num']
115
+
116
+ if question_num > num_questions:
117
+ return END
118
+ else:
119
+ return 'orchestrator'
120
+
121
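+ # wiring: START -> user -> orchestrator; conditional_edge sends unexplored states to storybuilder (then narrator -> END), otherwise END;
+ # last_action_edge ends the run once question_num exceeds num_questions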
+ builder.add_edge(START, 'user')
122
+ builder.add_edge('user', 'orchestrator')
123
+ builder.add_edge('storybuilder', 'narrator')
124
+ builder.add_edge('narrator', END)
125
+ builder.add_conditional_edges('orchestrator', conditional_edge)
126
+ builder.add_conditional_edges('user', last_action_edge)
127
+ graph = builder.compile()
128
+
129
+ return graph
130
+
131
+ def evaluation(state: ProcessState) -> Tuple[str, str]:
132
+
133
+ evaluator_agent = EvaluatorAgent()
134
+
135
+ previous_state_names = state['previous_state_names']
136
+ previous_questions = state['previous_questions']
137
+ previous_user_choices = state['previous_user_choices']
138
+
139
+ previous_questions.append(state['foll_question_uuid'])
140
+ previous_user_choices.append(state['user_choice'])
141
+
142
+ selected_category, reasoning = evaluator_agent.get_evaluation(
143
+ previous_state_names, previous_questions, previous_user_choices,
144
+ states, questions, categories
145
+ )
146
+
147
+ return selected_category, reasoning
148
+
149
+
150
+ def save_coordinator() -> None:
151
+ data = [
152
+ states,
153
+ {k: v.dict() for k, v in questions.items()},
154
+ state_question_map,
155
+ [c.dict() for c in categories],
156
+ categories_seen,
157
+ path_evaluations,
158
+ ]
159
+ key_names = [
160
+ 'states',
161
+ 'questions',
162
+ 'state_question_map',
163
+ 'categories',
164
+ 'categories_seen',
165
+ 'path_evaluations'
166
+ ]
167
+
168
+ combined_data = {
169
+ key_name: data_
170
+ for key_name, data_ in zip(key_names, data)
171
+ }
172
+ filepath = os.path.join(get_session_dir(), story_name, 'data.json')
173
+ save_information(combined_data, filepath)
174
+
175
+
176
+ def load_coordinator() -> None:
177
+ global states, questions, state_question_map, categories, categories_seen, path_evaluations
178
+
179
+ filepath = os.path.join(get_session_dir(), story_name, 'data.json')
180
+ combined_data = load_information(filepath)
181
+
182
+ states = {int(k): v for k, v in combined_data['states'].items()}
183
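+ # Question.dict() stores options via str(dict), so single quotes are swapped for double quotes before json.loads; this assumes option text itself contains no quote characters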
+ questions = {k: Question(**{
184
+ 'question': v['question'],
185
+ 'options': json.loads(v['options'].replace('\'', '\"'))
186
+ }) for k, v in combined_data['questions'].items()}
187
+ state_question_map = {int(k): v for k, v in combined_data['state_question_map'].items()}
188
+ categories = [Category(**c) for c in combined_data['categories']]
189
+ categories_seen = combined_data['categories_seen']
190
+ path_evaluations = combined_data['path_evaluations']
src/agents/evaluator_agent.py ADDED
@@ -0,0 +1,127 @@
1
+
2
+ import json
3
+
4
+ from typing import Dict, List, Tuple
5
+
6
+ import src.agents.coordinator as C
7
+
8
+ from .base_agent import BaseAgent
9
+ from ..dataclasses.category import Category
10
+ from ..prompts.evaluator_agent import system_prompt_template, user_prompt_template
11
+
12
+ CATEGORY_FORMAT = """
13
+ {{
14
+ 'name': {category_name},
15
+ 'description': {category_description},
16
+ 'traits': [
17
+ {category_traits}
18
+ ]
19
+ }}
20
+ """.strip()
21
+
22
+
23
+ class EvaluatorAgent(BaseAgent):
24
+
25
+ def __init__(self):
26
+ self.model = 'openai'
27
+ self.temperature = 0.2
28
+
29
+ def generate_path_key(
30
+ previous_questions: List[str],
31
+ previous_user_choices: List[str]
32
+ ) -> str:
33
+ return '-'.join([
34
+ f'{question}+{choice}'
35
+ for question, choice in zip(previous_questions, previous_user_choices)
36
+ ])
37
+
38
+ def is_encourage_categories(
39
+ self,
40
+ questions: Dict[str, str]
41
+ ) -> bool:
42
+ '''Attempts to ensure that each category has at least one story path to it.
43
+
44
+ Computes whether the number of remaining paths equals the number of unused categories.
45
+ This is assumed to be unlikely as the number of questions increases, but the check exists as a safeguard.
46
+ '''
47
+ if all([v for _, v in C.categories_seen.items()]):
48
+ return False
49
+
50
+ num_questions = len(questions)
51
+ random_question = list(questions.keys())[0]
52
+ num_options = len(C.questions[random_question].options)
53
+
54
+ rem_categories = sum([1 for _, v in C.categories_seen.items() if not v])
55
+ num_categories = len(C.categories)
56
+ return rem_categories >= num_options * num_questions - (num_categories - rem_categories)
57
+
58
+ def get_evaluation(
59
+ self,
60
+ previous_state_names: List[str],
61
+ previous_questions: List[str],
62
+ previous_user_choices: List[str],
63
+ states: Dict[int, Dict[str, str]],
64
+ questions: Dict[str, str],
65
+ categories: List[Category]
66
+ ) -> Tuple[str, str]:
67
+
68
+ complete_user_story = []
69
+
70
+ path_key = EvaluatorAgent.generate_path_key(previous_questions, previous_user_choices)
71
+
72
+ if path_key in C.path_evaluations:
73
+ return C.path_evaluations[path_key]
74
+
75
+ for i, (state_name, question_uuid, user_choice) in enumerate(zip(
76
+ previous_state_names, previous_questions, previous_user_choices
77
+ )):
78
+ scenario = states[i + 1][state_name]
79
+ question_str = questions[question_uuid].question
80
+
81
+ turn_content = f'`SCENARIO`: {scenario}\n' + \
82
+ f'`QUESTION`: {question_str}\n' + \
83
+ f'`CHOICE`: {user_choice}\n'
84
+
85
+ complete_user_story.append(turn_content)
86
+
87
+ complete_user_story = '\n\n'.join(complete_user_story)
88
+
89
+ is_encourage_categories = self.is_encourage_categories(questions)
90
+ category_contents = []
91
+ for i, category in enumerate(sorted(categories, key=lambda x: x.name)):
92
+ if is_encourage_categories and C.categories_seen[category.name]:
93
+ continue
94
+
95
+ category_content = CATEGORY_FORMAT.format(**{
96
+ 'category_name': category.name,
97
+ 'category_description': category.description,
98
+ 'category_traits': category.traits
99
+ })
100
+ category_contents.append(category_content)
101
+
102
+ category_contents = '{\n' + ',\n'.join(category_contents) + '\n}'
103
+
104
+ prompt_arguments = {
105
+ 'user_story': complete_user_story,
106
+ 'category_contents': category_contents
107
+ }
108
+
109
+ system_message = system_prompt_template.format(**prompt_arguments)
110
+ user_message = user_prompt_template.format(**prompt_arguments)
111
+
112
+ messages = self.format_openai_message(system_message, user_message)
113
+
114
+ response = self.execute(self.model, messages, self.temperature)
115
+
116
+ # remove json tag from response
117
+ response = response.removeprefix('```json').removeprefix('```')
118
+ response = response.removesuffix('```json').removesuffix('```')
119
+
120
+ selection_info = json.loads(response)
121
+
122
+ selected_category = selection_info['category']
123
+ reason = selection_info['reason']
124
+
125
+ C.path_evaluations[path_key] = [selected_category, reason]
126
+
127
+ return (selected_category, reason)
src/agents/narrator_agent.py ADDED
@@ -0,0 +1,98 @@
1
+
2
+ import json
3
+
4
+ from uuid import uuid4
5
+
6
+ import src.agents.coordinator as C
7
+
8
+ from .base_agent import BaseAgent
9
+ from ..dataclasses.process_state import ProcessState
10
+ from ..dataclasses.question import Question
11
+ from ..prompts.narrator_agent import system_prompt_template, user_prompt_template
12
+
13
+ class NarratorAgent(BaseAgent):
14
+
15
+ def __init__(
16
+ self,
17
+ story_context: str,
18
+ category_context: str,
19
+ num_questions: int,
20
+ num_options: int
21
+ ):
22
+
23
+ self.prompt_arguments = {
24
+ 'story_context': story_context,
25
+ 'category_context': category_context,
26
+ 'num_questions': num_questions,
27
+ 'num_options': num_options
28
+ }
29
+
30
+ self.model = 'openai'
31
+ self.temperature = 0.5
32
+
33
+ def run_agent(self, state: ProcessState) -> ProcessState:
34
+
35
+ question_num = state['question_num']
36
+ state_name = state['state_name']
37
+
38
+ if state_name in C.state_question_map.get(question_num, dict()):
39
+ question_uuid = C.state_question_map[question_num][state_name]
40
+
41
+ else:
42
+ question_uuid = str(uuid4())
43
+ if question_num not in C.state_question_map:
44
+ C.state_question_map[question_num] = dict()
45
+ C.state_question_map[question_num][state_name] = question_uuid
46
+
47
+ self.prompt_arguments['user_story'] = '\n'.join([
48
+ C.states[i + 1][state]
49
+ for i, state in enumerate(state['previous_state_names'][:-1])
50
+ ])
51
+ self.prompt_arguments['scenario'] = C.states[question_num][state_name]
52
+ self.prompt_arguments['builder_instruction'] = state['builder_instruction']
53
+ if question_num in C.states:
54
+ self.prompt_arguments['existing_scenarios'] = '\n'.join([
55
+ '{\n' + f'scenario_name: {state_name}\n' + \
56
+ f'scenario: {scenario}' + '}\n'
57
+ for state_name, scenario in C.states.get(question_num + 1, dict()).items()
58
+ ])
59
+ else:
60
+ self.prompt_arguments['existing_scenarios'] = ''
61
+
62
+ system_message = system_prompt_template.format(**self.prompt_arguments)
63
+ user_message = user_prompt_template.format(**self.prompt_arguments)
64
+
65
+ messages = self.format_openai_message(system_message, user_message)
66
+
67
+ response = self.execute(self.model, messages, self.temperature)
68
+
69
+ # remove json tag from response
70
+ response = response.removeprefix('```json').removeprefix('```')
71
+ response = response.removesuffix('```json').removesuffix('```')
72
+
73
+ question_info = json.loads(response)
74
+
75
+ if isinstance(question_info, list):
76
+ question_info = question_info[0]
77
+
78
+ question = Question(**question_info)
79
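+ # mark every follow-up state as 'Unexplored!' so the coordinator's conditional_edge routes it through the storybuilder on first visit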
+ for foll_state_name in question.options.values():
80
+ if question_num + 1 not in C.states:
81
+ C.states[question_num + 1] = dict()
82
+ C.states[question_num + 1][foll_state_name] = 'Unexplored!'
83
+
84
+ C.questions[question_uuid] = question
85
+
86
+ foll_questions = [question for question in state['previous_questions']]
87
+ foll_questions.append(question_uuid)
88
+
89
+ return {
90
+ 'question_num': state['question_num'],
91
+ 'state_name': state['state_name'],
92
+ 'user_choice': state['user_choice'],
93
+ 'previous_state_names': state['previous_state_names'],
94
+ 'previous_questions': foll_questions,
95
+ 'previous_user_choices': state['previous_user_choices'],
96
+ 'builder_instruction': state['builder_instruction'],
97
+ 'foll_question_uuid': question_uuid
98
+ }
src/agents/orchestrator_agent.py ADDED
@@ -0,0 +1,85 @@
1
+
2
+ import src.agents.coordinator as C
3
+
4
+ from .base_agent import BaseAgent
5
+ from .category_builder_agent import CategoryBuilderAgent
6
+ from ..dataclasses.process_state import ProcessState
7
+ from ..prompts.orchestrator_agent import (
8
+ start_system_prompt_template,
9
+ start_user_prompt_template,
10
+ coordinate_system_prompt_template,
11
+ coordinate_user_prompt_template,
12
+ get_phase_message
13
+ )
14
+
15
+ class OrchestratorAgent(BaseAgent):
16
+
17
+ def __init__(
18
+ self,
19
+ story_context: str,
20
+ category_context: str,
21
+ num_questions: int
22
+ ):
23
+ # build prompt argument
24
+ self.prompt_arguments = {
25
+ 'story_context': story_context,
26
+ 'category_context': category_context,
27
+ 'num_questions': num_questions
28
+ }
29
+
30
+ # openai calls
31
+ self.model = 'openai'
32
+ self.temperature = 0.2
33
+
34
+ def run_agent(self, state: ProcessState) -> ProcessState:
35
+
36
+ question_num = state['question_num']
37
+ state_name = state['state_name']
38
+
39
+ # check which system and user prompts to build
40
+ if question_num == 1:
41
+ system_message = start_system_prompt_template.format(**self.prompt_arguments)
42
+ user_message = start_user_prompt_template.format(**self.prompt_arguments)
43
+ else:
44
+ progress = question_num / self.prompt_arguments['num_questions']
45
+ self.prompt_arguments['phase_message'] = get_phase_message(progress)
46
+ self.prompt_arguments['user_story'] = '\n'.join([
47
+ C.states[i + 1][state]
48
+ for i, state in enumerate(state['previous_state_names'])
49
+ ])
50
+ system_message = coordinate_system_prompt_template.format(**self.prompt_arguments)
51
+ user_message = coordinate_user_prompt_template.format(**self.prompt_arguments)
52
+
53
+ # check if the state has been generated before
54
+ if state_name in C.states.get(question_num, dict()) and C.states[question_num][state_name] != 'Unexplored!':
55
+ foll_state_names = [name for name in state['previous_state_names']]
56
+ foll_state_names.append(state['state_name'])
57
+
58
+ builder_instruction = ''
59
+
60
+ foll_question_uuid = C.state_question_map[question_num][state_name]
61
+ previous_questions = [question_uuid for question_uuid in state['previous_questions']]
62
+ if state['foll_question_uuid'] or question_num == 1:
63
+ if state['foll_question_uuid']:
64
+ previous_questions.append(state['foll_question_uuid'])
65
+ else:
66
+ previous_questions.append(C.state_question_map[1]['start'])
67
+
68
+ else:
69
+ foll_state_names = state['previous_state_names']
70
+
71
+ messages = self.format_openai_message(system_message, user_message)
72
+ builder_instruction = self.execute(self.model, messages, self.temperature)
73
+ foll_question_uuid = state['foll_question_uuid']
74
+ previous_questions = [question_uuid for question_uuid in state['previous_questions']]
75
+
76
+ return {
77
+ 'question_num': state['question_num'],
78
+ 'state_name': state['state_name'],
79
+ 'user_choice': state['user_choice'],
80
+ 'previous_state_names': foll_state_names,
81
+ 'previous_questions': previous_questions,
82
+ 'previous_user_choices': state['previous_user_choices'],
83
+ 'builder_instruction': builder_instruction,
84
+ 'foll_question_uuid': foll_question_uuid
85
+ }
src/agents/storybuilder_agent.py ADDED
@@ -0,0 +1,69 @@
1
+
2
+ import src.agents.coordinator as C
3
+
4
+ from .base_agent import BaseAgent
5
+ from ..dataclasses.process_state import ProcessState
6
+ from ..prompts.storybuilder_agent import system_prompt_template, user_prompt_template
7
+
8
+ class StorybuilderAgent(BaseAgent):
9
+
10
+ def __init__(
11
+ self,
12
+ story_context: str,
13
+ category_context: str,
14
+ num_questions: int,
15
+ num_options: int
16
+ ):
17
+
18
+ self.prompt_arguments = {
19
+ 'story_context': story_context,
20
+ 'category_context': category_context,
21
+ 'num_questions': num_questions,
22
+ 'num_options': num_options
23
+ }
24
+
25
+ self.model = 'openai'
26
+ self.temperature = 0.5
27
+
28
+ def run_agent(self, state: ProcessState) -> ProcessState:
29
+
30
+ question_num = state['question_num']
31
+ state_name = state['state_name']
32
+
33
+ self.prompt_arguments['builder_instruction'] = state['builder_instruction']
34
+ self.prompt_arguments['user_story'] = '\n'.join([
35
+ C.states[i + 1][state]
36
+ for i, state in enumerate(state['previous_state_names'])
37
+ ])
38
+
39
+ if question_num == 1:
40
+ self.prompt_arguments['question'] = ''
41
+ self.prompt_arguments['response'] = ''
42
+ else:
43
+ self.prompt_arguments['question'] = C.questions[state['previous_questions'][-1]].question
44
+ self.prompt_arguments['response'] = state['user_choice']
45
+
46
+ system_message = system_prompt_template.format(**self.prompt_arguments)
47
+ user_message = user_prompt_template.format(**self.prompt_arguments)
48
+
49
+ messages = self.format_openai_message(system_message, user_message)
50
+
51
+ scenario = self.execute(self.model, messages, self.temperature)
52
+
53
+ if question_num not in C.states:
54
+ C.states[question_num] = dict()
55
+ C.states[question_num][state_name] = scenario
56
+
57
+ foll_state_names = [name for name in state['previous_state_names']]
58
+ foll_state_names.append(state_name)
59
+
60
+ return {
61
+ 'question_num': state['question_num'],
62
+ 'state_name': state['state_name'],
63
+ 'user_choice': state['user_choice'],
64
+ 'previous_state_names': foll_state_names,
65
+ 'previous_questions': state['previous_questions'],
66
+ 'previous_user_choices': state['previous_user_choices'],
67
+ 'builder_instruction': state['builder_instruction'],
68
+ 'foll_question_uuid': state['foll_question_uuid']
69
+ }
src/dataclasses/category.py ADDED
@@ -0,0 +1,13 @@
1
+
2
+ from dataclasses import asdict, dataclass
3
+ from typing import List
4
+
5
+
6
+ @dataclass
7
+ class Category:
8
+ name: str
9
+ description: str
10
+ traits: List[str]
11
+
12
+ def dict(self):
13
+ return {k: str(v) for k, v in asdict(self).items()}
src/dataclasses/process_state.py ADDED
@@ -0,0 +1,17 @@
1
+
2
+ from typing import List, TypedDict
3
+
4
+
5
+ class ProcessState(TypedDict):
6
+ question_num: int
7
+
8
+ state_name: str
9
+ user_choice: str
10
+
11
+ previous_state_names: List[str]
12
+ previous_questions: List[str]
13
+ previous_user_choices: List[str]
14
+
15
+ builder_instruction: str
16
+
17
+ foll_question_uuid: str
src/dataclasses/question.py ADDED
@@ -0,0 +1,12 @@
1
+
2
+ from dataclasses import asdict, dataclass
3
+ from typing import Dict
4
+
5
+
6
+ @dataclass
7
+ class Question:
8
+ question: str
9
+ options: Dict[str, str]
10
+
11
+ def dict(self):
12
+ return {k: str(v) for k, v in asdict(self).items()}
src/frontend/chatbot.py ADDED
@@ -0,0 +1,207 @@
1
+
2
+ import gradio as gr
3
+
4
+ from langgraph.graph.graph import CompiledGraph
5
+ from typing import Dict, List, Union
6
+
7
+ import src.agents.coordinator as C
8
+
9
+ from src.dataclasses.process_state import ProcessState
10
+
11
+ graph: CompiledGraph = None
12
+ MAX_OPTIONS = 16
13
+ state: ProcessState = {
14
+ 'question_num': 0,
15
+ 'state_name': 'start',
16
+ 'user_choice': '',
17
+ 'previous_state_names': [],
18
+ 'previous_questions': [],
19
+ 'previous_user_choices': [],
20
+ 'builder_instruction': '',
21
+ 'foll_question_uuid': ''
22
+ }
23
+ state_history: List[ProcessState] = []
24
+ num_questions_: int = None
25
+
26
+ chatbot = gr.Chatbot(
27
+ type='messages',
28
+ key='chatbot',
29
+ preserved_by_key='key',
30
+ visible=False
31
+ )
32
+ option_buttons = [
33
+ gr.Button(
34
+ value='',
35
+ visible=False,
36
+ render=False,
37
+ key=f'option_button_{i}',
38
+ preserved_by_key='key'
39
+ )
40
+ for i in range(1, MAX_OPTIONS + 1)
41
+ ]
42
+ restart_button = gr.Button(
43
+ 'Explore other storylines!',
44
+ visible=False,
45
+ interactive=True,
46
+ key='restart_button',
47
+ preserved_by_key='key'
48
+ )
49
+ button_row = gr.Row(
50
+ key='button_row',
51
+ preserved_by_key='key'
52
+ )
53
+
54
+ def init_graph(
55
+ story_context: str,
56
+ categories_context: str,
57
+ num_questions: int,
58
+ num_options: int,
59
+ num_categories: int
60
+ ) -> None:
61
+ global graph, option_buttons, num_questions_
62
+
63
+ num_questions_ = num_questions
64
+ graph = C.init_graph(
65
+ story_context,
66
+ categories_context,
67
+ num_questions,
68
+ num_options,
69
+ num_categories
70
+ )
71
+
72
+ for i in range(num_options, MAX_OPTIONS):
73
+ option_buttons[i].unrender()
74
+
75
+ option_buttons = option_buttons[:num_options]
76
+
77
+
78
+ def on_user_response(
79
+ user_choice: str,
80
+ history: List[Dict[str, str]]
81
+ ) -> Dict[Union[gr.Button, gr.Chatbot], Union[List[Dict[str, str]], gr.update]]:
82
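+ # button labels are shown with spaces in the UI; convert back to the underscore form used as keys in Question.options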
+ state['user_choice'] = user_choice.replace(' ', '_')
83
+
84
+ user_message = [{'role': 'user', 'content': user_choice}]
85
+ updated_history = history + user_message
86
+
87
+ return {chatbot: updated_history} | {
88
+ button: gr.update(visible=False)
89
+ for button in option_buttons
90
+ }
91
+
92
+
93
+ def control_screen_widgets() -> List[Union[gr.Chatbot, gr.Row, gr.Button]]:
94
+ return [chatbot, button_row, restart_button]
95
+
96
+
97
+ def control_screen(
98
+ is_visible: bool
99
+ ) -> Dict[Union[gr.Chatbot, gr.Row, gr.Button], Union[gr.update, gr.Row]]:
100
+ return {
101
+ chatbot: gr.update(visible=is_visible),
102
+ button_row: gr.Row(visible=is_visible),
103
+ restart_button: gr.update(visible=is_visible)
104
+ }
105
+
106
+
107
+ def on_chatbot_response(
108
+ history: List[Dict[str, str]]
109
+ ) -> Dict[Union[gr.Chatbot, gr.Button], Union[List[Dict[str, str]], gr.update]]:
110
+ global state
111
+
112
+ state_history.append({key: val for key, val in state.items()})
113
+
114
+ question_num = state['question_num']
115
+ if question_num < num_questions_:
116
+ state = graph.invoke(state)
117
+
118
+ question_num = state['question_num']
119
+ state_name = state['state_name']
120
+ scenario = C.states[question_num][state_name]
121
+ question = C.questions[state['foll_question_uuid']]
122
+ question_str = question.question.replace('_', ' ').strip('\"').strip('\'')
123
+
124
+ text_to_user = scenario + '\n' + question_str
125
+ bot_message = [{'role': 'assistant', 'content': text_to_user}]
126
+ updated_history = history + bot_message
127
+
128
+ options = [
129
+ option.replace('_', ' ').strip('\"').strip('\'')
130
+ for option in question.options.keys()
131
+ ]
132
+
133
+ button_updates = {
134
+ button: gr.update(value=option, visible=True)
135
+ for option, button in zip(options, option_buttons)
136
+ }
137
+ button_updates[restart_button] = gr.update(visible=False)
138
+ else:
139
+ selected_category, reason = C.evaluation(state)
140
+ description, traits = None, None
141
+ for category in C.categories:
142
+ if category.name != selected_category:
143
+ continue
144
+ description = category.description
145
+ traits = category.traits
146
+ C.categories_seen[selected_category] = True
147
+ traits_string = traits
148
+ bot_messages = [
149
+ {'role': 'assistant', 'content': f'We think that you are closest to **{selected_category}**!'},
150
+ {'role': 'assistant', 'content': f'## {selected_category}\n\n{description}\n\n{traits_string}'},
151
+ {'role': 'assistant', 'content': reason}
152
+ ]
153
+ updated_history = history + bot_messages
154
+ button_updates = {
155
+ button: gr.update(visible=False)
156
+ for button in option_buttons
157
+ }
158
+ button_updates[restart_button] = gr.update(visible=True)
159
+
160
+ C.save_coordinator()
161
+
162
+ return {chatbot: gr.update(value=updated_history, visible=True)} | button_updates
163
+
164
+
165
+ def on_restart_button_click() -> Dict[
166
+ Union[gr.Chatbot, gr.Button], Union[List[Dict[str, str]], gr.update]
167
+ ]:
168
+ global state
169
+
170
+ state = {
171
+ 'question_num': 0,
172
+ 'state_name': 'start',
173
+ 'user_choice': '',
174
+ 'previous_state_names': [],
175
+ 'previous_questions': [],
176
+ 'previous_user_choices': [],
177
+ 'builder_instruction': '',
178
+ 'foll_question_uuid': ''
179
+ }
180
+
181
+ return {restart_button: gr.update(visible=False)} | \
182
+ on_chatbot_response([])
183
+
184
+
185
+ def render():
186
+ chatbot.render()
187
+
188
+ button_row.render()
189
+ with button_row:
190
+ for button in option_buttons:
191
+ button.unrender()
192
+ button.render()
193
+ button.click(
194
+ fn=on_user_response,
195
+ inputs=[button, chatbot],
196
+ outputs=option_buttons + [chatbot]
197
+ ).then(
198
+ fn=on_chatbot_response,
199
+ inputs=[chatbot],
200
+ outputs=option_buttons + [chatbot, restart_button]
201
+ )
202
+ restart_button.render()
203
+ restart_button.click(
204
+ fn=on_restart_button_click,
205
+ inputs=None,
206
+ outputs=option_buttons + [chatbot, restart_button]
207
+ )
src/frontend/continue_story_screen.py ADDED
@@ -0,0 +1,134 @@
1
+
2
+ import gradio as gr
3
+ import os
4
+
5
+ from typing import Dict, List, Tuple, Union
6
+
7
+ import src.agents.coordinator as C
8
+
9
+ from src.frontend.chatbot import (
10
+ chatbot,
11
+ init_graph,
12
+ on_chatbot_response,
13
+ option_buttons,
14
+ restart_button
15
+ )
16
+ from src.frontend import sidebar
17
+ from src.utils.utils import load_information, transform_story_name
18
+
19
+ information_text = gr.Text(
20
+ value='Select which story to continue below!',
21
+ interactive=False,
22
+ key='continue_story_user_information',
23
+ preserved_by_key='key',
24
+ visible=False,
25
+ label=None
26
+ )
27
+ session_dir = os.path.join(
28
+ os.path.dirname(__file__),
29
+ '..', '..', 'sessions'
30
+ )
31
+ story_widgets: List[Tuple[gr.Row, gr.Button, gr.Text]] = []
32
+ story_dir_mapper = dict()
33
+ for i, dirname in enumerate(sorted(os.listdir(session_dir))):
34
+ if dirname.startswith('.'):
35
+ continue
36
+ dirpath = os.path.join(session_dir, dirname)
37
+ if not os.listdir(dirpath):
38
+ continue
39
+ story_information_filepath = os.path.join(dirpath, 'story.json')
40
+ story_information_dict = load_information(story_information_filepath)
41
+ story_name = story_information_dict['story_name']
42
+ story_context = story_information_dict['story_context']
43
+ row = gr.Row(
44
+ visible=False,
45
+ key=f'continue_row_{i}',
46
+ preserved_by_key='key'
47
+ )
48
+ button = gr.Button(
49
+ value=story_name,
50
+ visible=False,
51
+ key=f'continue_button_{i}',
52
+ preserved_by_key='key',
53
+ scale=1
54
+ )
55
+ text = gr.Text(
56
+ value=story_context,
57
+ interactive=False,
58
+ visible=False,
59
+ key=f'continue_story_{i}',
60
+ preserved_by_key='key',
61
+ scale=3
62
+ )
63
+ story_widgets.append((row, button, text))
64
+ story_dir_mapper[story_name] = dirpath
65
+
66
+
67
+ def get_widgets() -> List[Union[gr.Text, gr.Button, gr.Row]]:
68
+ widgets = [information_text]
69
+ for row, button, text in story_widgets:
70
+ widgets.append(row)
71
+ widgets.append(button)
72
+ widgets.append(text)
73
+ return widgets
74
+
75
+
76
+ def get_widgets_updates(
77
+ is_visible: bool
78
+ ) -> Dict[Union[gr.Slider, gr.Text, gr.Button], gr.update]:
79
+ updates = dict()
80
+ for widget in get_widgets():
81
+ if isinstance(widget, gr.Row):
82
+ updates[widget] = gr.Row(visible=is_visible)
83
+ else:
84
+ updates[widget] = gr.update(visible=is_visible)
85
+ return updates
86
+
87
+
88
+ def on_button_click(
89
+ story_name: str
90
+ ) -> Dict[Union[gr.Text, gr.Slider, gr.Button, gr.Chatbot], gr.update]:
91
+
92
+ dirpath = story_dir_mapper[story_name]
93
+ story_name_ = transform_story_name(story_name)
94
+
95
+ story_information_filepath = os.path.join(dirpath, 'story.json')
96
+ story_information_dict = load_information(story_information_filepath)
97
+ story_context = story_information_dict['story_context']
98
+ categories_context = story_information_dict['categories_context']
99
+ num_questions = story_information_dict['num_questions']
100
+ num_options = story_information_dict['num_options']
101
+ num_categories = story_information_dict['num_categories']
102
+
103
+ init_graph(
104
+ story_context,
105
+ categories_context,
106
+ num_questions,
107
+ num_options,
108
+ num_categories
109
+ )
110
+
111
+ C.story_name = story_name_
112
+ C.load_coordinator()
113
+
114
+ chatbot_updates = on_chatbot_response([])
115
+
116
+ return chatbot_updates | get_widgets_updates(False) | sidebar.view_screen()
117
+
118
+
119
+ def render():
120
+ information_text.render()
121
+
122
+ for row, button, text in story_widgets:
123
+ row.render()
124
+ with row:
125
+ button.render()
126
+ text.render()
127
+ button.click(
128
+ fn=on_button_click,
129
+ inputs=[button],
130
+ outputs=get_widgets() + \
131
+ [chatbot, restart_button] + \
132
+ option_buttons + \
133
+ sidebar.get_widgets()
134
+ )
src/frontend/progress_screen.py ADDED
@@ -0,0 +1,240 @@
1
+
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import plotly.graph_objects as go
5
+
6
+ from igraph import Graph, EdgeSeq
7
+ from typing import Dict, List, Union
8
+
9
+ import src.agents.coordinator as C
10
+
11
+ categories_row = gr.Row(
12
+ visible=False,
13
+ key='categories_row',
14
+ preserved_by_key='key'
15
+ )
16
+ graph_row = gr.Row(
17
+ visible=False,
18
+ key='graph_row',
19
+ preserved_by_key='key'
20
+ )
21
+
22
+ category_buttons = [
23
+ gr.Button(
24
+ value='',
25
+ visible=False,
26
+ key=f'category_button_{i + 1}',
27
+ preserved_by_key='key'
28
+ )
29
+ for i in range(16)
30
+ ]
31
+ category_text = gr.Text(
32
+ value='Select a category!',
33
+ visible=False,
34
+ interactive=False,
35
+ key='category_text',
36
+ preserved_by_key='key'
37
+ )
38
+ graph_plot = gr.Plot(
39
+ visible=False,
40
+ key='graph_plot',
41
+ preserved_by_key='key'
42
+ )
43
+
44
+
45
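+ # NOTE: this control_screen is shadowed by the fuller definition near the end of this module, which is the one other modules end up calling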
+ def control_screen(is_visible: bool) -> Dict[gr.Row, gr.Row]:
46
+ return {
47
+ categories_row: gr.Row(visible=is_visible),
48
+ graph_row: gr.Row(visible=is_visible)
49
+ }
50
+
51
+
52
+ def update_categories() -> Dict[Union[gr.Button, gr.Text], gr.update]:
53
+ button_updates = {
54
+ button: gr.update(
55
+ value=category.name if C.categories_seen[category.name] \
56
+ else '???',
57
+ visible=True,
58
+ interactive=C.categories_seen[category.name],
59
+ variant='primary' if C.categories_seen[category.name] \
60
+ else 'secondary'
61
+ )
62
+ for button, category in zip(category_buttons, C.categories)
63
+ }
64
+
65
+ return {
66
+ category_text: gr.update(visible=True, value='Select a category!')
67
+ } | button_updates
68
+
69
+
70
+ def on_category_click(category_name: str) -> str:
71
+ for category in C.categories:
72
+ if category.name != category_name:
73
+ continue
74
+ return category.name + '\n\n' + \
75
+ category.description + '\n\n' + \
76
+ category.traits
77
+
78
+
79
+ def update_graph():
80
+ n_vertices = sum([
81
+ len(states_)
82
+ for level, states_ in C.states.items()
83
+ if level <= C.num_questions_
84
+ ])
85
+
86
+ graph = Graph(directed=True)
87
+ nodes = []
88
+ node_attributes = {'description': []}
89
+ for level in range(1, C.num_questions_ + 1):
90
+ for name, description in C.states[level].items():
91
+ node_name = f'{name} (Stage {level})'
92
+ nodes.append(node_name)
93
+ node_attributes['description'].append(description)
94
+
95
+ edges = []
96
+ edge_attributes = {
97
+ 'question': [],
98
+ 'option': []
99
+ }
100
+ for level in range(1, C.num_questions_):
101
+ for state, question_uuid in C.state_question_map[level].items():
102
+ question = C.questions[question_uuid]
103
+ question_str = question.question
104
+ options = question.options
105
+ for option, foll_state in options.items():
106
+ edge = (
107
+ f'{state} (Stage {level})',
108
+ f'{foll_state} (Stage {level + 1})'
109
+ )
110
+ edges.append(edge)
111
+ edge_attributes['question'].append(question_str)
112
+ edge_attributes['option'].append(option)
113
+
114
+ graph.add_vertices(nodes, attributes=node_attributes)
115
+ graph.add_edges(edges, attributes=edge_attributes)
116
+ layout = graph.layout('rt')
117
+
118
+ # adapted from https://plotly.com/python/tree-plots/
119
+ position = {k: layout[k] for k in range(n_vertices)}
120
+ Y = [layout[k][1] for k in range(n_vertices)]
121
+ M = max(Y)
122
+
123
+ E = [e.tuple for e in graph.es] # list of edges
124
+
125
+ L = len(position)
126
+ Xn = [position[k][0] for k in range(L)]
127
+ Yn = [2*M-position[k][1] for k in range(L)]
128
+ Xe = []
129
+ Ye = []
130
+
131
+ # for labelling edges
132
+ X_edge_nodes = []
133
+ Y_edge_nodes = []
134
+ for edge in E:
135
+ Xe+=[position[edge[0]][0],position[edge[1]][0], None]
136
+ Ye+=[2*M-position[edge[0]][1],2*M-position[edge[1]][1], None]
137
+
138
+ X_edge_nodes.append((position[edge[0]][0] + position[edge[1]][0]) / 2)
139
+ Y_edge_nodes.append((2*M-position[edge[0]][1] + 2*M-position[edge[1]][1]) / 2)
140
+
141
+ node_labels = [
142
+ node.replace('_', ' ') + '\n\n' + \
143
+ description
144
+ for node, description in zip(nodes, node_attributes['description'])
145
+ ]
146
+ node_labels = pd.DataFrame(node_labels, columns=['label'])
147
+ node_labels['label'] = node_labels['label'].str.wrap(30)\
148
+ .apply(lambda x: x.replace('\n', '<br>'))
149
+ node_labels = node_labels['label'].to_list()
150
+ edge_labels = [
151
+ question.replace('_', ' ') + '\n\n' + option.replace('_', ' ')
152
+ for question, option in zip(
153
+ edge_attributes['question'], edge_attributes['option']
154
+ )
155
+ ]
156
+ edge_labels = pd.DataFrame(edge_labels, columns=['label'])
157
+ edge_labels['label'] = edge_labels['label'].str.wrap(30)\
158
+ .apply(lambda x: x.replace('\n', '<br>'))
159
+ edge_labels = edge_labels['label'].to_list()
160
+
161
+ fig = go.Figure()
162
+ fig.add_trace(go.Scatter(
163
+ x=Xe, y=Ye,
164
+ mode='lines',
165
+ line=dict(color='rgb(210,210,210)', width=1),
166
+ ))
167
+ fig.add_trace(go.Scatter(
168
+ x=Xn, y=Yn,
169
+ mode='markers',
170
+ marker=dict(
171
+ symbol='circle-dot', size=18, color='#6175c1',
172
+ line=dict(color='rgb(50,50,50)', width=1)
173
+ ),
174
+ text=node_labels,
175
+ hoverinfo='text',
176
+ opacity=0.8
177
+ ))
178
+
179
+ fig.add_trace(go.Scatter(
180
+ x=X_edge_nodes, y=Y_edge_nodes,
181
+ mode='markers',
182
+ marker=dict(
183
+ symbol='circle-dot', size=0, color="#42c744",
184
+ line=dict(color='rgb(50,50,50)', width=0)
185
+ ),
186
+ text=edge_labels,
187
+ hoverinfo='text',
188
+ opacity=0
189
+ ))
190
+
191
+ axis = dict(
192
+ showline=False,
193
+ zeroline=False,
194
+ showgrid=False,
195
+ showticklabels=False,
196
+ )
197
+
198
+ fig.update_layout(
199
+ showlegend=False,
200
+ xaxis=axis,
201
+ yaxis=axis
202
+ )
203
+
204
+ return gr.Plot(fig, visible=True)
205
+
206
+
207
+ def control_screen_widgets() -> List[Union[gr.Row, gr.Text]]:
208
+ return [categories_row, category_text, graph_row, graph_plot] + \
209
+ category_buttons
210
+
211
+
212
+ def control_screen(
213
+ is_visible: bool
214
+ ) -> Dict[Union[gr.Plot, gr.Row, gr.Button], Union[gr.update, gr.Row]]:
215
+ row_updates = {
216
+ categories_row: gr.Row(visible=is_visible),
217
+ graph_row: gr.Row(visible=is_visible)
218
+ }
219
+ graph_update = {
220
+ graph_plot: update_graph() if is_visible else gr.update(visible=False)
221
+ }
222
+ category_button_updates = update_categories()
223
+
224
+ return row_updates | graph_update | category_button_updates
225
+
226
+
227
+ def render():
228
+ categories_row.render()
229
+ with categories_row:
230
+ for button in category_buttons:
231
+ button.render()
232
+ button.click(
233
+ fn=on_category_click,
234
+ inputs=[button],
235
+ outputs=[category_text]
236
+ )
237
+ category_text.render()
238
+ graph_row.render()
239
+ with graph_row:
240
+ graph_plot.render()
src/frontend/sidebar.py ADDED
@@ -0,0 +1,108 @@
1
+
2
+ import gradio as gr
3
+
4
+ from typing import Dict, List, Union
5
+
6
+ from src.frontend import chatbot, progress_screen
7
+
8
+ sidebar = gr.Sidebar(
9
+ open=True,
10
+ visible=False,
11
+ key='sidebar',
12
+ preserved_by_key='key'
13
+ )
14
+ chatbot_button = gr.Button(
15
+ value='Viewing Chatbot!',
16
+ variant='secondary',
17
+ visible=False,
18
+ interactive=False,
19
+ key='chatbot_button',
20
+ preserved_by_key='key'
21
+ )
22
+ progress_button = gr.Button(
23
+ value='View Your Story Progress!',
24
+ variant='primary',
25
+ visible=False,
26
+ interactive=True,
27
+ key='progress_button',
28
+ preserved_by_key='key'
29
+ )
30
+
31
+
32
+ def get_widgets() -> List[Union[gr.Sidebar, gr.Button]]:
33
+ return [sidebar, chatbot_button, progress_button]
34
+
35
+
36
+ def view_screen() -> Dict[Union[gr.Sidebar, gr.Button], gr.update]:
37
+ return {
38
+ widget: gr.update(visible=True)
39
+ for widget in get_widgets()
40
+ }
41
+
42
+
43
+ def on_chatbot_button_click():
44
+ chatbot_screen_updates = chatbot.control_screen(True)
45
+ progress_screen_updates = progress_screen.control_screen(False)
46
+
47
+ chatbot_button_update = gr.update(
48
+ value='Viewing Chatbot!',
49
+ variant='secondary',
50
+ interactive=False
51
+ )
52
+ progress_button_update = gr.update(
53
+ value='View Your Story Progress!',
54
+ variant='primary',
55
+ interactive=True
56
+ )
57
+
58
+ return {
59
+ chatbot_button: chatbot_button_update,
60
+ progress_button: progress_button_update
61
+ } | \
62
+ chatbot_screen_updates | progress_screen_updates
63
+
64
+
65
+ def on_progress_button_click():
66
+ chatbot_screen_updates = chatbot.control_screen(False)
67
+ progress_screen_updates = progress_screen.control_screen(True)
68
+
69
+ chatbot_button_update = gr.update(
70
+ value='Come & Explore!',
71
+ variant='primary',
72
+ interactive=True
73
+ )
74
+ progress_button_update = gr.update(
75
+ value='Viewing Your Story Progress!',
76
+ variant='secondary',
77
+ interactive=False
78
+ )
79
+
80
+ return {
81
+ chatbot_button: chatbot_button_update,
82
+ progress_button: progress_button_update
83
+ } | \
84
+ chatbot_screen_updates | progress_screen_updates
85
+
86
+
87
+ def render():
88
+ sidebar.render()
89
+ with sidebar:
90
+ chatbot_button.render()
91
+ progress_button.render()
92
+
93
+ chatbot_button.click(
94
+ fn=on_chatbot_button_click,
95
+ inputs=[],
96
+ outputs=[chatbot_button, progress_button] + \
97
+ chatbot.control_screen_widgets() + \
98
+ progress_screen.control_screen_widgets()
99
+ )
100
+
101
+ progress_button.click(
102
+ fn=on_progress_button_click,
103
+ inputs=[],
104
+ outputs=[chatbot_button, progress_button] + \
105
+ chatbot.control_screen_widgets() + \
106
+ progress_screen.control_screen_widgets()
107
+ )
108
+
src/frontend/start_screen.py ADDED
@@ -0,0 +1,98 @@
1
+
2
+ import gradio as gr
3
+
4
+ from typing import Dict, List, Union
5
+
6
+ from src.frontend import (
7
+ continue_story_screen,
8
+ story_information_widgets
9
+ )
10
+ from src.utils.init_openai import init_client
11
+
12
+ button_row = gr.Row(
13
+ key='start_screen_button_row',
14
+ preserved_by_key='key'
15
+ )
16
+ new_story_button = gr.Button(
17
+ value='Begin a new adventure!',
18
+ visible=True,
19
+ interactive=True,
20
+ key='new_story_button',
21
+ preserved_by_key='key'
22
+ )
23
+ continue_story_button = gr.Button(
24
+ value='Continue an existing adventure!',
25
+ visible=True,
26
+ interactive=True,
27
+ key='continue_story_button',
28
+ preserved_by_key='key'
29
+ )
30
+ api_key_textbox = gr.Text(
31
+ type='password',
32
+ label='OpenAI API Key',
33
+ placeholder='Enter your OpenAI API key',
34
+ interactive=True,
35
+ visible=True,
36
+ key='api_key_textbox',
37
+ preserved_by_key='key'
38
+ )
39
+
40
+
41
+ def get_widgets() -> List[Union[gr.Text, gr.Button, gr.Row]]:
42
+ return [
43
+ button_row,
44
+ new_story_button,
45
+ continue_story_button,
46
+ api_key_textbox
47
+ ]
48
+
49
+
50
+ def get_widgets_updates(
51
+ is_visible: bool = False
52
+ ) -> Dict[Union[gr.Text, gr.Button, gr.Row], Union[gr.update, gr.Row]]:
53
+ return {
54
+ widget: gr.Row(visible=is_visible) if isinstance(widget, gr.Row) else \
55
+ gr.update(visible=is_visible)
56
+ for widget in get_widgets()
57
+ }
58
+
59
+
60
+ def on_submit_new_story(
61
+ api_key: str
62
+ ) -> Dict[Union[gr.Text, gr.Slider, gr.Button], Union[gr.Row, gr.update]]:
63
+ if init_client(api_key):
64
+ return get_widgets_updates(False) | \
65
+ story_information_widgets.get_widgets_updates(True)
66
+ return get_widgets_updates(True) | \
67
+ story_information_widgets.get_widgets_updates(False)
68
+
69
+
70
+ def on_submit_continue_story(
71
+ api_key: str
72
+ ) -> Dict[Union[gr.Text, gr.Slider, gr.Button], Union[gr.Row, gr.update]]:
73
+ if init_client(api_key):
74
+ return get_widgets_updates(False) | \
75
+ continue_story_screen.get_widgets_updates(True)
76
+ return get_widgets_updates(True) | \
77
+ continue_story_screen.get_widgets_updates(False)
78
+
79
+
80
+ def render():
81
+ api_key_textbox.render()
82
+ button_row.render()
83
+ with button_row:
84
+ new_story_button.render()
85
+ continue_story_button.render()
86
+
87
+ new_story_button.click(
88
+ fn=on_submit_new_story,
89
+ inputs=[api_key_textbox],
90
+ outputs=get_widgets() + \
91
+ story_information_widgets.get_widgets()
92
+ )
93
+ continue_story_button.click(
94
+ fn=on_submit_continue_story,
95
+ inputs=[api_key_textbox],
96
+ outputs=get_widgets() + \
97
+ continue_story_screen.get_widgets()
98
+ )
src/frontend/story_information_widgets.py ADDED
@@ -0,0 +1,210 @@
1
+
2
+ import gradio as gr
3
+ import json
4
+ import os
5
+
6
+ from typing import Dict, List, Union
7
+
8
+ import src.agents.coordinator as C
9
+
10
+ from src.frontend.chatbot import (
11
+ MAX_OPTIONS,
12
+ chatbot,
13
+ init_graph,
14
+ on_chatbot_response,
15
+ option_buttons,
16
+ restart_button
17
+ )
18
+ from src.frontend import sidebar
19
+ from src.utils.utils import (
20
+ get_session_dir,
21
+ save_information,
22
+ transform_story_name
23
+ )
24
+
25
+ story_name_textbox = gr.Text(
26
+ placeholder='Give your adventure a unique name!',
27
+ label='Story Name',
28
+ interactive=True,
29
+ key='story_name_textbox',
30
+ preserved_by_key='key',
31
+ visible=False
32
+ )
33
+ story_context_textbox = gr.Text(
34
+ placeholder='What kind of story would you like to adventure on?',
35
+ label='Story Context',
36
+ interactive=True,
37
+ key='story_context_textbox',
38
+ preserved_by_key='key',
39
+ visible=False
40
+ )
41
+ categories_context_textbox = gr.Text(
42
+ placeholder='What is the theme of the categories you want to have?',
43
+ label='Categories Context',
44
+ interactive=True,
45
+ key='categories_context_textbox',
46
+ preserved_by_key='key',
47
+ visible=False
48
+ )
49
+ num_questions_slider = gr.Slider(
50
+ minimum=1,
51
+ maximum=10,
52
+ value=5,
53
+ step=1,
54
+ label='How many questions would you want your journey to have?',
55
+ interactive=True,
56
+ visible=False,
57
+ key='num_questions_slider',
58
+ preserved_by_key='key'
59
+ )
60
+ num_options_slider = gr.Slider(
61
+ minimum=2,
62
+ maximum=4,
63
+ value=2,
64
+ step=1,
65
+ label='How many options would you like at each turn?',
66
+ interactive=True,
67
+ visible=False,
68
+ key='num_options_slider',
69
+ preserved_by_key='key'
70
+ )
71
+ num_categories_slider = gr.Slider(
72
+ minimum=2,
73
+ maximum=MAX_OPTIONS,
74
+ value=2,
75
+ step=1,
76
+ label='How many categories / endings would you like to have?',
77
+ interactive=True,
78
+ visible=False,
79
+ key='num_categories_slider',
80
+ preserved_by_key='key'
81
+ )
82
+ story_information_submit_button = gr.Button(
83
+ visible=False,
84
+ interactive=False,
85
+ key='story_information_submit_button',
86
+ preserved_by_key='key'
87
+ )
88
+
89
+ def get_widgets() -> List[Union[gr.Text, gr.Slider, gr.Button]]:
90
+ return [
91
+ story_name_textbox,
92
+ story_context_textbox,
93
+ categories_context_textbox,
94
+ num_questions_slider,
95
+ num_options_slider,
96
+ num_categories_slider,
97
+ story_information_submit_button
98
+ ]
99
+
100
+
101
+ def get_widgets_updates(
102
+ is_visible: bool
103
+ ) -> Dict[Union[gr.Slider, gr.Text, gr.Button], gr.update]:
104
+ return {
105
+ widget: gr.update(visible=is_visible)
106
+ for widget in get_widgets()
107
+ }
108
+
109
+
110
+ def check_story_name(story_name: str) -> bool:
111
+ story_name_dir = transform_story_name(story_name)
112
+ return story_name_dir not in os.listdir(get_session_dir())
113
+
114
+
115
+ def on_text_change(
116
+ story_name: str,
117
+ story_context: str,
118
+ categories_context: str
119
+ ) -> gr.update:
120
+ if all([story_name, story_context, categories_context]) and \
121
+ check_story_name(story_name):
122
+ return gr.update(interactive=True)
123
+ return gr.update(interactive=False)
124
+
125
+
126
+ def save_story_information(
127
+ story_name: str,
128
+ story_context: str,
129
+ categories_context: str,
130
+ num_questions: int,
131
+ num_options: int,
132
+ num_categories: int
133
+ ):
134
+ story_name_dir = transform_story_name(story_name)
135
+ story_dirpath = os.path.join(get_session_dir(), story_name_dir)
136
+ os.mkdir(story_dirpath)
137
+
138
+ story_information_filepath = os.path.join(story_dirpath, 'story.json')
139
+ story_information = {
140
+ 'story_name': story_name,
141
+ 'story_context': story_context,
142
+ 'categories_context': categories_context,
143
+ 'num_questions': num_questions,
144
+ 'num_options': num_options,
145
+ 'num_categories': num_categories
146
+ }
147
+ save_information(story_information, story_information_filepath)
148
+
149
+
150
+ def on_submit(
151
+ story_name: str,
152
+ story_context: str,
153
+ categories_context: str,
154
+ num_questions: int,
155
+ num_options: int,
156
+ num_categories: int
157
+ ) -> Dict[Union[gr.Text, gr.Slider, gr.Button, gr.Chatbot], gr.update]:
158
+
159
+ save_story_information(
160
+ story_name,
161
+ story_context,
162
+ categories_context,
163
+ num_questions,
164
+ num_options,
165
+ num_categories
166
+ )
167
+ C.story_name = transform_story_name(story_name)
168
+
169
+ init_graph(
170
+ story_context,
171
+ categories_context,
172
+ num_questions,
173
+ num_options,
174
+ num_categories
175
+ )
176
+
177
+ chatbot_updates = on_chatbot_response([])
178
+
179
+ return chatbot_updates | get_widgets_updates(False) | sidebar.view_screen()
180
+
181
+
182
+ def render():
183
+
184
+ for widget in get_widgets():
185
+ widget.render()
186
+
187
+ if not isinstance(widget, gr.Text):
188
+ continue
189
+ widget.change(
190
+ fn=on_text_change,
191
+ inputs=[
192
+ story_name_textbox,
193
+ story_context_textbox,
194
+ categories_context_textbox
195
+ ],
196
+ outputs=[story_information_submit_button]
197
+ )
198
+
199
+ story_information_submit_button.click(
200
+ fn=on_submit,
201
+ inputs=[
202
+ widget
203
+ for widget in get_widgets()
204
+ if not isinstance(widget, gr.Button)
205
+ ],
206
+ outputs=get_widgets() + \
207
+ [chatbot, restart_button] + \
208
+ option_buttons + \
209
+ sidebar.get_widgets()
210
+ )
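A hypothetical call to save_story_information with invented values, showing what ends up on disk (the directory name comes from transform_story_name):

    # Illustrative values only; assumes the top-level sessions/ directory exists.
    save_story_information(
        story_name='Space Bakery',
        story_context='A rookie baker opens a bakery on a space station.',
        categories_context='types of pastries',
        num_questions=5,
        num_options=2,
        num_categories=3
    )
    # Creates sessions/space_bakery/ and writes story.json with these six fields.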
src/prompts/category_builder_agent.py ADDED
@@ -0,0 +1,43 @@
1
+
2
+ system_prompt_template = """
3
+ You are a creative writer.
4
+ Your task is to create categories based on a given context.
5
+ Your categories should be different from one another.
6
+ Keep the categories fun, light and interesting. Be as creative as possible.
7
+
8
+ You will be given the following:
9
+ 1. `story context`, which describes the story we are about to take the user on,
10
+ 2. `category context`, which describes what the general theme of the categories should be,
11
+ 3. the `number of categories`
12
+
13
+ You can find the provided information later, where
14
+ 1. `story context` can be found enclosed in <story> tags,
15
+ 2. `category context` can be found enclosed in <category> tags,
16
+ 3. `number of categories` can be found enclosed in <number> tags.
17
+
18
+ Ensure that the categories you generate are related to the story context.
19
+ When you are ready, output **all** the categories in **JSON** format.
20
+ For each category, they should have the following fields:
21
+ 1. `name`, which describes the name of the category,
22
+ 2. `description`, which gives a short description (at most three lines) of the category,
23
+ 3. `traits`: a sequence of traits of the category. Separate the traits with a comma on the same line. Ensure that the traits have some relation to the story context.
24
+
25
+ Ensure that the focus is still on the `category context`.
26
+ A useful guideline is that the `name` of the category directly follows from the `category context`.
27
+ The `story context` can then be used in the `description` and `traits`.
28
+ You only need to output the JSON information, no elaboration is required.
29
+ """.strip()
30
+
31
+ user_prompt_template = """
32
+ <story>
33
+ {story_context}
34
+ </story>
35
+
36
+ <category>
37
+ {category_context}
38
+ </category>
39
+
40
+ <number>
41
+ {num_categories}
42
+ </number>
43
+ """.strip()
src/prompts/evaluator_agent.py ADDED
@@ -0,0 +1,38 @@
1
+
2
+ system_prompt_template = """
3
+ You are an expert profiler, regardless of the theme.
4
+ We have just set our user through a round of questions, and they have provided us their responses.
5
+ Our task now is to allocate them into one category based on the scenario they were given, questions they were asked, and choices they made.
6
+
7
+ You will be given the following:
8
+ 1. `user story`, where for each turn, you are given the SCENARIO, QUESTION and USER CHOICE,
9
+ 2. `categories`, where for each category, you are given the name, description and traits of the category.
10
+
11
+ You can find the provided information later, where
12
+ 1. `user story` can be found enclosed in <user> tags,
13
+ 2. `categories` can be found enclosed in <categories> tags,
14
+
15
+ Read the user's story and determine which category best fits the user based on their choices.
16
+
17
+ When you are ready, output your answer in the below JSON format.
18
+ Follow the format strictly, as any deviation will result in errors later on.
19
+ You only need to output your answer, no elaboration is required.
20
+ `category` below refers to the category name of the selected category. Output the corresponding category name exactly.
21
+ For your reason, replace terms like \'the user\' with \'your\' or \'you\'. We will be passing your reason to the user.
22
+ ```json
23
+ {{
24
+ "category":
25
+ "reason":
26
+ }}
27
+ ```
28
+ """.strip()
29
+
30
+ user_prompt_template = """
31
+ <user>
32
+ {user_story}
33
+ </user>
34
+
35
+ <categories>
36
+ {category_contents}
37
+ </categories>
38
+ """.strip()
src/prompts/narrator_agent.py ADDED
@@ -0,0 +1,97 @@
1
+
2
+ system_prompt_template = """
3
+ You are a creative writer for a story, in charge of writing the lines users will view.
4
+ The story you are making is dependent on our user's choices. Do **not** rush to make a complete story.
5
+
6
+ We are building a turn-based call and response story, where we will give users a scenario, and they will respond with what they would choose to do.
7
+ Based on our user's response, and the context they give us, we will craft what the next scenario will be.
8
+
9
+ Your task is to craft a question for the user to answer based on the scenario provided to you, and the potential responses to the question you make.
10
+ Your question should allow the user to respond with actionable options.
11
+ Focus on creating a question that is in line with the given scenario and storyline so far.
12
+ Keep in mind the context of the categories we will assign the user to in the end.
13
+ There should be some harmony and cohesion between the story's context and the categories' context.
14
+
15
+ You will be given the following:
16
+ 1. `story context`, which describes the theme of the story we want to build for our users,
17
+ 2. `categories context`, which describes the theme of the categories we will assign our users to later,
18
+ 3. `user story`, which is the current story shown to the user,
19
+ 4. `scenario`, which is the scenario that your question should be based on,
20
+ 5. `builder instruction`, which describes the suggested pace of the story,
21
+ 6. `number of options`, which is the number of options the user has to respond to the question you pose. You will need to generate the options as well.
22
+ 7. `existing next scenarios`, which is the current list of scenarios that follow this scenario. This will be empty if we are seeing the current scenario for the first time. You will be given each scenario in the below format.
23
+ ```
24
+ {{
25
+ scenario_name:
26
+ scenario:
27
+ }}
28
+ ```
29
+
30
+ You can find the provided information later, where
31
+ 1. `story context` can be found enclosed in <story> tags,
32
+ 2. `categories context` can be found enclosed in <category> tags.
33
+ 3. `user story` can be found enclosed in <user> tags,
34
+ 4. `scenario`, can be found enclosed in <scenario> tags,
35
+ 5. `builder instruction`, can be found enclosed in <builder> tags,
36
+ 6. `number of options`, can be found enclosed in <options> tags,
37
+ 7. `existing next scenarios`, can be found enclosed in <existing> tags
38
+
39
+ First, think about a question to ask the user. This question should not be long-winded.
40
+ Instead, aim for short questions that are easy for the users to respond to, but also compelling enough to keep them engaged.
41
+ Next, think about potential responses to the question.
42
+ The number of options you provide **must** exactly match the given `number of options`, no more and no less.
43
+ Lastly, think about what suitable next scenarios each of the options can lead to.
44
+ Prioritise using **existing** scenarios, where possible.
45
+ Prioritise using the **same** scenario for multiple options, where possible.
46
+ If you want to use the same or existing scenario, you **must** use the exact same `scenario_name`.
47
+ If you want to create a new scenario, you **must** use a different `scenario_name` than the ones provided.
48
+ Although you do not know the final categories, ensure you provide options that are distinct enough such that the users can definitely be binned into different categories.
49
+
50
+ When you are ready, output your answer in the below JSON format.
51
+ Follow the format strictly, as any deviation will result in errors later on.
52
+ For your question, replace any spaces with the underscore character.
53
+ You only need to output your answer, no elaboration is required.
54
+ Angle brackets indicate fields that you should replace with your answer.
55
+ An ellipsis means you should continue following the same format.
56
+ ```json
57
+ {{
58
+ "question":
59
+ "options": {{
60
+ "<option_1>": "<following scenario name>",
61
+ ...
62
+ }}
63
+ }}
64
+ ```
65
+
66
+ You are reminded to ensure that the number of options is exactly the number of options provided.
67
+ """.strip()
68
+
69
+ user_prompt_template = """
70
+ <story>
71
+ {story_context}
72
+ </story>
73
+
74
+ <category>
75
+ {category_context}
76
+ </category>
77
+
78
+ <user>
79
+ {user_story}
80
+ </user>
81
+
82
+ <scenario>
83
+ {scenario}
84
+ </scenario>
85
+
86
+ <builder>
87
+ {builder_instruction}
88
+ </builder>
89
+
90
+ <options>
91
+ {num_options}
92
+ </options>
93
+
94
+ <existing>
95
+ {existing_scenarios}
96
+ </existing>
97
+ """.strip()
src/prompts/orchestrator_agent.py ADDED
@@ -0,0 +1,108 @@
1
+
2
+ start_system_prompt_template = """
3
+ You are an orchestrator of a story board.
4
+ Your task is to provide instructions to your creatives about the flow and urgency of the story.
5
+ This means whether the creatives should start thinking about wrapping up the story, or if they have some time to explore a certain scenario.
6
+
7
+ We are building a turn-based call and response story, where we will give users a scenario, and they will respond with what they would choose to do.
8
+ Based on our user's response, and the context they give us, we will craft what the next scenario will be.
9
+
10
+ Your task is **not** to think about what to craft next.
11
+ Instead, your task is to guide the agent that is doing the creative work.
12
+ Your goal is to guide the story such that it does not feel too rushed or too drawn out, whilst ensuring a compelling and interesting story.
13
+
14
+ In this first step, we have just begun the process.
15
+ You will be given the following:
16
+ 1. `story context`, which describes the theme of the story we want to build for our users,
17
+ 2. `number of questions`, which describes the number of turns we have with our users. Use this as a gauge to determine if the given context is too lengthy or too short.
18
+
19
+ You can find the provided information later, where
20
+ 1. `story context` can be found enclosed in <story> tags,
21
+ 2. `number of questions` can be found enclosed in <number> tags.
22
+
23
+ Think about the context and the number of questions.
24
+ We both know that the creative agent will begin with some kind of introductory question.
25
+ What we want to focus on is how quickly this opening phase should move.
26
+
27
+ Your output is the instruction to our creative agent on the tempo of the story.
28
+ When the creative agent reads your instruction, they should know how fast to craft the scenario to give the users.
29
+ Keep the user's overall experience the focus.
30
+
31
+ Once you are ready, output your instruction to our creative agent.
32
+ No elaboration is required.
33
+ """.strip()
34
+
35
+ start_user_prompt_template = """
36
+ <story>
37
+ {story_context}
38
+ </story>
39
+
40
+ <number>
41
+ {num_questions}
42
+ </number>
43
+ """.strip()
44
+
45
+ coordinate_system_prompt_template = """
46
+ You are an orchestrator of a story board.
47
+ Your task is to provide instructions to your creatives about the flow and urgency of the story.
48
+ This means whether the creatives should start thinking about wrapping up the story, or if they have some time to explore a certain scenario.
49
+
50
+ We are building a turn-based call and response story, where we will give users a scenario, and they will respond with what they would choose to do.
51
+ Based on our user's response, and the context they give us, we will craft what the next scenario will be.
52
+
53
+ Your task is **not** to think about what to craft next.
54
+ Instead, your task is to guide the agent that is doing the creative work.
55
+ Your goal is to guide the story such that it does not feel too rushed or too drawn out, whilst ensuring a compelling and interesting story.
56
+
57
+ In this phase, we are some way into our story.
58
+ You will be given the following:
59
+ 1. `story context`, which describes the theme of the story we want to build for our users,
60
+ 2. `number of questions`, which describes the number of turns we have with our users. Use this as a gauge to determine if the given context is too lengthy or too short.
61
+ 3. `phase message`, which is a pre-written phrase describing how far into the story the user is. You do not need to calculate this; we will do it for you,
62
+ 4. `user story`, which is the current story shown to the user.
63
+
64
+ You can find the provided information later, where
65
+ 1. `story context` can be found enclosed in <story> tags,
66
+ 2. `number of questions` can be found enclosed in <number> tags,
67
+ 3. `phase message` can be found enclosed in <phase> tags,
68
+ 4. `user story` can be found enclosed in <user> tags.
69
+
70
+ Think about the context and phase.
71
+ Our creative agent is getting ready to craft the next scenario to give to the user.
72
+ Our focus is on how fast the scenario moves the story along.
73
+
74
+ Your output is the instruction to our creative agent on the tempo of the story.
75
+ When the creative agent reads your instruction, they should know how fast to craft the scenario to give the users.
76
+ Keep the user's overall experience the focus.
77
+
78
+ Once you are ready, output your instruction to our creative agent.
79
+ No elaboration is required.
80
+ """.strip()
81
+
82
+ coordinate_user_prompt_template = """
83
+ <story>
84
+ {story_context}
85
+ </story>
86
+
87
+ <number>
88
+ {num_questions}
89
+ </number>
90
+
91
+ <phase>
92
+ {phase_message}
93
+ </phase>
94
+
95
+ <user>
96
+ {user_story}
97
+ </user>
98
+ """.strip()
99
+
100
+ def get_phase_message(progress: float) -> str:
101
+ if progress < 0.2:
102
+ return 'We are still in the early phases of the story!'
103
+ if progress <= 0.5:
104
+ return 'We are approaching the halfway point, it\'s time to build towards a climax!'
105
+ if progress < 0.8:
106
+ return 'We are approaching the climax of the story! Let\'s quicken the pace.'
107
+
108
+ return 'The story is approaching the end! Let\'s wrap this up!'
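A quick illustration of get_phase_message; the caller is not part of this commit, but progress is presumably the fraction of questions already asked:

    # Boundary behaviour of the thresholds above (illustrative).
    get_phase_message(0.1)   # 'We are still in the early phases of the story!'
    get_phase_message(0.5)   # halfway message: time to build towards a climax
    get_phase_message(0.7)   # approaching the climax: quicken the pace
    get_phase_message(0.9)   # 'The story is approaching the end! Let's wrap this up!'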
src/prompts/storybuilder_agent.py ADDED
@@ -0,0 +1,66 @@
1
+
2
+ system_prompt_template = """
3
+ You are a creative writer for a story.
4
+ The story you are making is dependent on our user's choices. Do **not** rush to make a complete story.
5
+
6
+ We are building a turn-based call and response story, where we will give users a scenario, and they will respond with what they would choose to do.
7
+ Based on our user's response, and the context they give us, we will craft what the next scenario will be.
8
+
9
+ Your task is to craft a scenario for the user to think about.
10
+ Your task is **not** to craft the question to ask to the user.
11
+ Focus on creating a scenario that flows well with the story so far, and is in line with the story context.
12
+ Keep in mind the context of the categories we will assign the user to in the end.
13
+ There should be some harmony and cohesion between the story's context and the categories' context.
14
+
15
+ You will be given the following:
16
+ 1. `story context`, which describes the theme of the story we want to build for our users,
17
+ 2. `categories context`, which describes the theme of the categories we will assign our users to later,
18
+ 3. `builder's instruction`, a set of instructions on the tempo of the scenario you are going to create,
19
+ 4. `user story`, which is the current story shown to the user,
20
+ 5. `question`, which is the last question asked to the user; if we just started the process, this is empty.
21
+ 6. `response`, which is the response to the question shown to the user. If we just started the process, this is empty.
22
+
23
+ You can find the provided information later, where
24
+ 1. `story context` can be found enclosed in <story> tags,
25
+ 2. `categories context` can be found enclosed in <category> tags,
26
+ 3. `builder's instruction` can be found enclosed in <builder> tags,
27
+ 4. `user story` can be found enclosed in <user> tags,
28
+ 5. `question` can be found enclosed in <question> tags,
29
+ 6. `response` can be found enclosed in <response> tags.
30
+
31
+ First, think about a fitting scenario that naturally follows the story the user is shown.
32
+ Next, think about some specifics of the scenario.
33
+ Limit your scenario to at most 50 words.
34
+ Feel free to use named characters and have them recur throughout this journey.
35
+ Keep the scenario light, fun and compelling for the user to follow.
36
+
37
+ When you are ready, output **only** the scenario.
38
+ You are reminded that the scenario should **not** include the question to be asked to the user.
39
+ You do not need to provide any explanation.
40
+ """.strip()
41
+
42
+ user_prompt_template = """
43
+ <story>
44
+ {story_context}
45
+ </story>
46
+
47
+ <category>
48
+ {category_context}
49
+ </category>
50
+
51
+ <builder>
52
+ {builder_instruction}
53
+ </builder>
54
+
55
+ <user>
56
+ {user_story}
57
+ </user>
58
+
59
+ <question>
60
+ {question}
61
+ </question>
62
+
63
+ <response>
64
+ {response}
65
+ </response>
66
+ """.strip()
src/utils/init_openai.py ADDED
@@ -0,0 +1,80 @@
1
+
2
+ import random
3
+ import time
4
+
5
+ from openai import OpenAI, RateLimitError
6
+ from openai.types.chat.chat_completion import ChatCompletion
7
+ from typing import Any, Dict, Optional
8
+
9
+ client: Optional[OpenAI] = None
10
+ MODEL_ = 'gpt-4o'
11
+
12
+
13
+ def init_client(api_key: str) -> bool:
14
+ global client
15
+ try:
16
+ client = OpenAI(api_key=api_key)
17
+ except Exception:
18
+ return False
19
+ return True
20
+
21
+
22
+ # from OpenAI's website
23
+ # define a retry decorator
24
+ def __retry_with_exponential_backoff(
25
+ func,
26
+ initial_delay: float = 1,
27
+ exponential_base: float = 2,
28
+ jitter: bool = True,
29
+ max_retries: int = 10,
30
+ errors: tuple = (RateLimitError,),
31
+ ):
32
+ """Retry a function with exponential backoff."""
33
+
34
+ def wrapper(*args, **kwargs):
35
+ # Initialize variables
36
+ num_retries = 0
37
+ delay = initial_delay
38
+
39
+ # Loop until a successful response or max_retries is hit or an exception is raised
40
+ while True:
41
+ try:
42
+ return func(*args, **kwargs)
43
+
44
+ # Retry on specified errors
45
+ except errors as e:
46
+ # Increment retries
47
+ num_retries += 1
48
+
49
+ # Check if max retries has been reached
50
+ if num_retries > max_retries:
51
+ raise Exception(
52
+ f"Maximum number of retries ({max_retries}) exceeded."
53
+ )
54
+
55
+ # Increment the delay
56
+ delay *= exponential_base * (1 + jitter * random.random())
57
+
58
+ # Sleep for the delay
59
+ time.sleep(delay)
60
+
61
+ # Raise exceptions for any errors not specified
62
+ except Exception as e:
63
+ raise e
64
+
65
+ return wrapper
66
+
67
+
68
+ @__retry_with_exponential_backoff
69
+ def completions_with_backoff(
70
+ messages: Dict[str, Any],
71
+ temperature: float = 0,
72
+ max_tokens: int = 8192
73
+ ) -> ChatCompletion:
74
+ kwargs = {
75
+ 'model': MODEL_,
76
+ 'messages': messages,
77
+ 'temperature': temperature,
78
+ 'max_tokens': max_tokens
79
+ }
80
+ return client.chat.completions.create(**kwargs)
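A hypothetical usage sketch of the two helpers above; the key and messages are placeholders, and completions_with_backoff retries automatically on rate limits:

    if init_client('sk-...'):
        response = completions_with_backoff(
            messages=[
                {'role': 'system', 'content': 'You are a creative writer.'},
                {'role': 'user', 'content': 'Name three pastry-themed categories.'}
            ],
            temperature=0.7
        )
        print(response.choices[0].message.content)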
src/utils/utils.py ADDED
@@ -0,0 +1,38 @@
1
+
2
+ import json
3
+ import os
4
+ import uuid
5
+
6
+ from typing import Any
7
+
8
+
9
+ def get_session_dir() -> str:
10
+ return os.path.join(
11
+ os.path.dirname(__file__),
12
+ '..', '..', 'sessions'
13
+ )
14
+
15
+
16
+ def save_information(
17
+ information: Any,
18
+ filepath: str
19
+ ) -> None:
20
+
21
+ with open(filepath, 'w', encoding='utf-8') as f:
22
+ json.dump(information, f, ensure_ascii=False, indent=4)
23
+
24
+
25
+ def load_information(filepath: str) -> Any:
26
+ with open(filepath, 'r', encoding='utf-8') as f:
27
+ content = json.load(f)
28
+ return content
29
+
30
+
31
+ def transform_story_name(
32
+ story_name: str,
33
+ is_inverse: bool = False
34
+ ) -> str:
35
+ if is_inverse:
36
+ return ' '.join([word.capitalize() for word in story_name.split('_')])
37
+ else:
38
+ return story_name.lower().replace(' ', '_')
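An illustrative round trip through transform_story_name, matching how the story directory name is derived and displayed:

    transform_story_name('Space Bakery')                     # -> 'space_bakery'
    transform_story_name('space_bakery', is_inverse=True)    # -> 'Space Bakery'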