Spaces:
Running
Running
Commit
·
b151e60
1
Parent(s):
251461c
Init files
Browse files- .gitignore +197 -0
- app.py +23 -0
- requirements.txt +5 -0
- src/agents/base_agent.py +40 -0
- src/agents/category_builder_agent.py +46 -0
- src/agents/coordinator.py +190 -0
- src/agents/evaluator_agent.py +127 -0
- src/agents/narrator_agent.py +98 -0
- src/agents/orchestrator_agent.py +85 -0
- src/agents/storybuilder_agent.py +69 -0
- src/dataclasses/category.py +13 -0
- src/dataclasses/process_state.py +17 -0
- src/dataclasses/question.py +12 -0
- src/frontend/chatbot.py +207 -0
- src/frontend/continue_story_screen.py +134 -0
- src/frontend/progress_screen.py +240 -0
- src/frontend/sidebar.py +108 -0
- src/frontend/start_screen.py +98 -0
- src/frontend/story_information_widgets.py +210 -0
- src/prompts/category_builder_agent.py +43 -0
- src/prompts/evaluator_agent.py +38 -0
- src/prompts/narrator_agent.py +97 -0
- src/prompts/orchestrator_agent.py +108 -0
- src/prompts/storybuilder_agent.py +66 -0
- src/utils/init_openai.py +80 -0
- src/utils/utils.py +38 -0
.gitignore
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# UV
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
#uv.lock
|
102 |
+
|
103 |
+
# poetry
|
104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
106 |
+
# commonly ignored for libraries.
|
107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
108 |
+
#poetry.lock
|
109 |
+
|
110 |
+
# pdm
|
111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
112 |
+
#pdm.lock
|
113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
114 |
+
# in version control.
|
115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
116 |
+
.pdm.toml
|
117 |
+
.pdm-python
|
118 |
+
.pdm-build/
|
119 |
+
|
120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
+
__pypackages__/
|
122 |
+
|
123 |
+
# Celery stuff
|
124 |
+
celerybeat-schedule
|
125 |
+
celerybeat.pid
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Environments
|
131 |
+
.env
|
132 |
+
.venv
|
133 |
+
env/
|
134 |
+
venv/
|
135 |
+
ENV/
|
136 |
+
env.bak/
|
137 |
+
venv.bak/
|
138 |
+
|
139 |
+
# Spyder project settings
|
140 |
+
.spyderproject
|
141 |
+
.spyproject
|
142 |
+
|
143 |
+
# Rope project settings
|
144 |
+
.ropeproject
|
145 |
+
|
146 |
+
# mkdocs documentation
|
147 |
+
/site
|
148 |
+
|
149 |
+
# mypy
|
150 |
+
.mypy_cache/
|
151 |
+
.dmypy.json
|
152 |
+
dmypy.json
|
153 |
+
|
154 |
+
# Pyre type checker
|
155 |
+
.pyre/
|
156 |
+
|
157 |
+
# pytype static type analyzer
|
158 |
+
.pytype/
|
159 |
+
|
160 |
+
# Cython debug symbols
|
161 |
+
cython_debug/
|
162 |
+
|
163 |
+
# PyCharm
|
164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
168 |
+
#.idea/
|
169 |
+
|
170 |
+
# Abstra
|
171 |
+
# Abstra is an AI-powered process automation framework.
|
172 |
+
# Ignore directories containing user credentials, local state, and settings.
|
173 |
+
# Learn more at https://abstra.io/docs
|
174 |
+
.abstra/
|
175 |
+
|
176 |
+
# Visual Studio Code
|
177 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
178 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
179 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
180 |
+
# you could uncomment the following to ignore the enitre vscode folder
|
181 |
+
# .vscode/
|
182 |
+
|
183 |
+
# Ruff stuff:
|
184 |
+
.ruff_cache/
|
185 |
+
|
186 |
+
# PyPI configuration file
|
187 |
+
.pypirc
|
188 |
+
|
189 |
+
# Cursor
|
190 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
191 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
192 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
193 |
+
.cursorignore
|
194 |
+
.cursorindexingignore
|
195 |
+
|
196 |
+
# ignore session dirs
|
197 |
+
sessions/*
|
app.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
from src.frontend import (
|
5 |
+
chatbot,
|
6 |
+
continue_story_screen,
|
7 |
+
progress_screen,
|
8 |
+
sidebar,
|
9 |
+
start_screen,
|
10 |
+
story_information_widgets
|
11 |
+
)
|
12 |
+
|
13 |
+
demo = gr.Blocks()
|
14 |
+
|
15 |
+
with demo:
|
16 |
+
start_screen.render()
|
17 |
+
story_information_widgets.render()
|
18 |
+
continue_story_screen.render()
|
19 |
+
sidebar.render()
|
20 |
+
chatbot.render()
|
21 |
+
progress_screen.render()
|
22 |
+
|
23 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
openai
|
2 |
+
langgraph
|
3 |
+
gradio
|
4 |
+
igraph
|
5 |
+
plotly
|
src/agents/base_agent.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from abc import ABC
|
3 |
+
from typing import Dict, List, Literal
|
4 |
+
|
5 |
+
from ..utils.init_openai import completions_with_backoff
|
6 |
+
|
7 |
+
|
8 |
+
class BaseAgent(ABC):
|
9 |
+
|
10 |
+
def format_openai_message(
|
11 |
+
self,
|
12 |
+
system_message: str,
|
13 |
+
user_message: str
|
14 |
+
) -> List[Dict[str, str]]:
|
15 |
+
return [
|
16 |
+
{
|
17 |
+
'role': 'system',
|
18 |
+
'content': system_message
|
19 |
+
},
|
20 |
+
{
|
21 |
+
'role': 'user',
|
22 |
+
'content': user_message
|
23 |
+
}
|
24 |
+
]
|
25 |
+
|
26 |
+
def execute(
|
27 |
+
self,
|
28 |
+
model: Literal['openai'],
|
29 |
+
messages: List[Dict[str, str]],
|
30 |
+
temperature: float = 0,
|
31 |
+
max_tokens: int = 8192
|
32 |
+
) -> str:
|
33 |
+
if model == 'openai':
|
34 |
+
|
35 |
+
response = completions_with_backoff(messages, temperature, max_tokens)
|
36 |
+
content = response.choices[0].message.content
|
37 |
+
|
38 |
+
return content
|
39 |
+
else:
|
40 |
+
raise ValueError(f'Model type {model} not supported (yet)')
|
src/agents/category_builder_agent.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import json
|
3 |
+
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from .base_agent import BaseAgent
|
7 |
+
from ..dataclasses.category import Category
|
8 |
+
from ..prompts.category_builder_agent import system_prompt_template, user_prompt_template
|
9 |
+
|
10 |
+
|
11 |
+
class CategoryBuilderAgent(BaseAgent):
|
12 |
+
|
13 |
+
def __init__(self):
|
14 |
+
self.model = 'openai'
|
15 |
+
self.temperature = 0.5
|
16 |
+
|
17 |
+
def get_categories(
|
18 |
+
self,
|
19 |
+
story_context: str,
|
20 |
+
category_context: str,
|
21 |
+
num_categories: int
|
22 |
+
) -> List[Category]:
|
23 |
+
prompt_arguments = {
|
24 |
+
'story_context': story_context,
|
25 |
+
'category_context': category_context,
|
26 |
+
'num_categories': num_categories
|
27 |
+
}
|
28 |
+
|
29 |
+
system_message = system_prompt_template.format(**prompt_arguments)
|
30 |
+
user_message = user_prompt_template.format(**prompt_arguments)
|
31 |
+
|
32 |
+
messages = self.format_openai_message(system_message, user_message)
|
33 |
+
|
34 |
+
response = self.execute(self.model, messages, self.temperature)
|
35 |
+
|
36 |
+
# remove json tag from response
|
37 |
+
response = response.removeprefix('```json').removeprefix('```')
|
38 |
+
response = response.removesuffix('```json').removesuffix('```')
|
39 |
+
|
40 |
+
categories = json.loads(response)
|
41 |
+
assert type(categories) == list
|
42 |
+
|
43 |
+
return [
|
44 |
+
Category(**category)
|
45 |
+
for category in categories
|
46 |
+
]
|
src/agents/coordinator.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
from langgraph.graph import StateGraph, START, END
|
6 |
+
from langgraph.graph.graph import CompiledGraph
|
7 |
+
from typing import Dict, List, Literal, Tuple
|
8 |
+
|
9 |
+
from src.agents.category_builder_agent import CategoryBuilderAgent
|
10 |
+
from src.agents.evaluator_agent import EvaluatorAgent
|
11 |
+
from src.agents.narrator_agent import NarratorAgent
|
12 |
+
from src.agents.orchestrator_agent import OrchestratorAgent
|
13 |
+
from src.agents.storybuilder_agent import StorybuilderAgent
|
14 |
+
from src.dataclasses.category import Category
|
15 |
+
from src.dataclasses.process_state import ProcessState
|
16 |
+
from src.dataclasses.question import Question
|
17 |
+
from src.utils.utils import get_session_dir, save_information, load_information
|
18 |
+
|
19 |
+
num_questions_: int = None
|
20 |
+
num_options_: int = None
|
21 |
+
story_name: str = None
|
22 |
+
states: Dict[int, Dict[str, str]] = dict() # maps question_num to {str to str}
|
23 |
+
questions: Dict[str, Question] = dict() # maps uuid to Question
|
24 |
+
state_question_map: Dict[int, Dict[str, str]] = dict() # maps question_num to {state_name to question_uuid}
|
25 |
+
categories: List[Category] = [] # list of categories
|
26 |
+
categories_seen: Dict[str, bool] = dict() # maps category name to bool if category name was seen
|
27 |
+
path_evaluations: Dict[str, List[str]] = dict() # maps a path to an evaluation, including the reason
|
28 |
+
|
29 |
+
|
30 |
+
def init_categories(
|
31 |
+
story_context: str,
|
32 |
+
category_context: str,
|
33 |
+
num_categories: int
|
34 |
+
) -> None:
|
35 |
+
global categories, categories_seen
|
36 |
+
|
37 |
+
category_builder_agent = CategoryBuilderAgent()
|
38 |
+
categories = category_builder_agent.get_categories(
|
39 |
+
story_context, category_context, num_categories
|
40 |
+
)
|
41 |
+
categories_seen = {
|
42 |
+
category.name: False
|
43 |
+
for category in categories
|
44 |
+
}
|
45 |
+
|
46 |
+
|
47 |
+
def user_node(state: ProcessState) -> ProcessState:
|
48 |
+
user_choice = state['user_choice']
|
49 |
+
question_uuid = state['foll_question_uuid']
|
50 |
+
question_num = state['question_num'] + 1
|
51 |
+
|
52 |
+
if user_choice:
|
53 |
+
foll_state_name = questions[question_uuid].options[user_choice]
|
54 |
+
foll_user_choices = [choice for choice in state['previous_user_choices']]
|
55 |
+
foll_user_choices.append(user_choice)
|
56 |
+
else:
|
57 |
+
foll_state_name = 'start'
|
58 |
+
foll_user_choices = []
|
59 |
+
|
60 |
+
return {
|
61 |
+
'question_num': question_num,
|
62 |
+
'state_name': foll_state_name,
|
63 |
+
'user_choice': user_choice,
|
64 |
+
'previous_state_names': state['previous_state_names'],
|
65 |
+
'previous_questions': state['previous_questions'],
|
66 |
+
'previous_user_choices': foll_user_choices,
|
67 |
+
'builder_instruction': '',
|
68 |
+
'foll_question_uuid': ''
|
69 |
+
}
|
70 |
+
|
71 |
+
|
72 |
+
def init_graph(
|
73 |
+
story_context: str,
|
74 |
+
category_context: str,
|
75 |
+
num_questions: int,
|
76 |
+
num_options: int,
|
77 |
+
num_categories: int
|
78 |
+
) -> CompiledGraph:
|
79 |
+
|
80 |
+
global num_questions_, num_options_
|
81 |
+
num_questions_ = num_questions
|
82 |
+
num_options_ = num_options
|
83 |
+
|
84 |
+
orchestrator_agent = OrchestratorAgent(
|
85 |
+
story_context, category_context, num_questions
|
86 |
+
)
|
87 |
+
storybuilder_agent = StorybuilderAgent(
|
88 |
+
story_context, category_context, num_questions, num_options
|
89 |
+
)
|
90 |
+
narrator_agent = NarratorAgent(
|
91 |
+
story_context, category_context, num_questions, num_options
|
92 |
+
)
|
93 |
+
|
94 |
+
init_categories(story_context, category_context, num_categories)
|
95 |
+
|
96 |
+
builder = StateGraph(ProcessState)
|
97 |
+
|
98 |
+
builder.add_node('orchestrator', orchestrator_agent.run_agent)
|
99 |
+
builder.add_node('storybuilder', storybuilder_agent.run_agent)
|
100 |
+
builder.add_node('narrator', narrator_agent.run_agent)
|
101 |
+
builder.add_node('user', user_node)
|
102 |
+
|
103 |
+
def conditional_edge(state: ProcessState) -> Literal['storybuilder', END]:
|
104 |
+
question_num = state['question_num']
|
105 |
+
state_name = state['state_name']
|
106 |
+
|
107 |
+
if state_name not in states.get(question_num, dict()) or \
|
108 |
+
states[question_num][state_name] == 'Unexplored!':
|
109 |
+
return 'storybuilder'
|
110 |
+
else:
|
111 |
+
return END
|
112 |
+
|
113 |
+
def last_action_edge(state: ProcessState) -> Literal['orchestrator', END]:
|
114 |
+
question_num = state['question_num']
|
115 |
+
|
116 |
+
if question_num > num_questions:
|
117 |
+
return END
|
118 |
+
else:
|
119 |
+
return 'orchestrator'
|
120 |
+
|
121 |
+
builder.add_edge(START, 'user')
|
122 |
+
builder.add_edge('user', 'orchestrator')
|
123 |
+
builder.add_edge('storybuilder', 'narrator')
|
124 |
+
builder.add_edge('narrator', END)
|
125 |
+
builder.add_conditional_edges('orchestrator', conditional_edge)
|
126 |
+
builder.add_conditional_edges('user', last_action_edge)
|
127 |
+
graph = builder.compile()
|
128 |
+
|
129 |
+
return graph
|
130 |
+
|
131 |
+
def evaluation(state: ProcessState) -> Tuple[str, str]:
|
132 |
+
|
133 |
+
evaluator_agent = EvaluatorAgent()
|
134 |
+
|
135 |
+
previous_state_names = state['previous_state_names']
|
136 |
+
previous_questions = state['previous_questions']
|
137 |
+
previous_user_choices = state['previous_user_choices']
|
138 |
+
|
139 |
+
previous_questions.append(state['foll_question_uuid'])
|
140 |
+
previous_user_choices.append(state['user_choice'])
|
141 |
+
|
142 |
+
selected_category, reasoning = evaluator_agent.get_evaluation(
|
143 |
+
previous_state_names, previous_questions, previous_user_choices,
|
144 |
+
states, questions, categories
|
145 |
+
)
|
146 |
+
|
147 |
+
return selected_category, reasoning
|
148 |
+
|
149 |
+
|
150 |
+
def save_coordinator() -> None:
|
151 |
+
data = [
|
152 |
+
states,
|
153 |
+
{k: v.dict() for k, v in questions.items()},
|
154 |
+
state_question_map,
|
155 |
+
[c.dict() for c in categories],
|
156 |
+
categories_seen,
|
157 |
+
path_evaluations,
|
158 |
+
]
|
159 |
+
key_names = [
|
160 |
+
'states',
|
161 |
+
'questions',
|
162 |
+
'state_question_map',
|
163 |
+
'categories',
|
164 |
+
'categories_seen',
|
165 |
+
'path_evaluations'
|
166 |
+
]
|
167 |
+
|
168 |
+
combined_data = {
|
169 |
+
key_name: data_
|
170 |
+
for key_name, data_ in zip(key_names, data)
|
171 |
+
}
|
172 |
+
filepath = os.path.join(get_session_dir(), story_name, 'data.json')
|
173 |
+
save_information(combined_data, filepath)
|
174 |
+
|
175 |
+
|
176 |
+
def load_coordinator() -> None:
|
177 |
+
global states, questions, state_question_map, categories, categories_seen, path_evaluations
|
178 |
+
|
179 |
+
filepath = os.path.join(get_session_dir(), story_name, 'data.json')
|
180 |
+
combined_data = load_information(filepath)
|
181 |
+
|
182 |
+
states = {int(k): v for k, v in combined_data['states'].items()}
|
183 |
+
questions = {k: Question(**{
|
184 |
+
'question': v['question'],
|
185 |
+
'options': json.loads(v['options'].replace('\'', '\"'))
|
186 |
+
}) for k, v in combined_data['questions'].items()}
|
187 |
+
state_question_map = {int(k): v for k, v in combined_data['state_question_map'].items()}
|
188 |
+
categories = [Category(**c) for c in combined_data['categories']]
|
189 |
+
categories_seen = combined_data['categories_seen']
|
190 |
+
path_evaluations = combined_data['path_evaluations']
|
src/agents/evaluator_agent.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import json
|
3 |
+
|
4 |
+
from typing import Dict, List, Tuple
|
5 |
+
|
6 |
+
import src.agents.coordinator as C
|
7 |
+
|
8 |
+
from .base_agent import BaseAgent
|
9 |
+
from ..dataclasses.category import Category
|
10 |
+
from ..prompts.evaluator_agent import system_prompt_template, user_prompt_template
|
11 |
+
|
12 |
+
CATEGORY_FORMAT = """
|
13 |
+
{{
|
14 |
+
'name': {category_name},
|
15 |
+
'description': {category_description},
|
16 |
+
'traits': [
|
17 |
+
{category_traits}
|
18 |
+
]
|
19 |
+
}}
|
20 |
+
""".strip()
|
21 |
+
|
22 |
+
|
23 |
+
class EvaluatorAgent(BaseAgent):
|
24 |
+
|
25 |
+
def __init__(self):
|
26 |
+
self.model = 'openai'
|
27 |
+
self.temperature = 0.2
|
28 |
+
|
29 |
+
def generate_path_key(
|
30 |
+
previous_questions: List[str],
|
31 |
+
previous_user_choices: List[str]
|
32 |
+
) -> str:
|
33 |
+
return '-'.join([
|
34 |
+
f'{question}+{choice}'
|
35 |
+
for question, choice in zip(previous_questions, previous_user_choices)
|
36 |
+
])
|
37 |
+
|
38 |
+
def is_encourage_categories(
|
39 |
+
self,
|
40 |
+
questions: Dict[str, str]
|
41 |
+
) -> bool:
|
42 |
+
'''Attempts to ensure that each category has at least one story path to it.
|
43 |
+
|
44 |
+
Computes if the number of remaining paths equals to the number of unused categories.
|
45 |
+
Assumed to be a low chance as the number of questions increases, but exists as a checker.
|
46 |
+
'''
|
47 |
+
if all([v for _, v in C.categories_seen.items()]):
|
48 |
+
return False
|
49 |
+
|
50 |
+
num_questions = len(questions)
|
51 |
+
random_question = list(questions.keys())[0]
|
52 |
+
num_options = len(C.questions[random_question].options)
|
53 |
+
|
54 |
+
rem_categories = sum([1 for _, v in C.categories_seen.items() if not v])
|
55 |
+
num_categories = len(C.categories)
|
56 |
+
return rem_categories >= num_options * num_questions - (num_categories - rem_categories)
|
57 |
+
|
58 |
+
def get_evaluation(
|
59 |
+
self,
|
60 |
+
previous_state_names: List[str],
|
61 |
+
previous_questions: List[str],
|
62 |
+
previous_user_choices: List[str],
|
63 |
+
states: Dict[int, Dict[str, str]],
|
64 |
+
questions: Dict[str, str],
|
65 |
+
categories: List[Category]
|
66 |
+
) -> List[str]:
|
67 |
+
|
68 |
+
complete_user_story = []
|
69 |
+
|
70 |
+
path_key = EvaluatorAgent.generate_path_key(previous_questions, previous_user_choices)
|
71 |
+
|
72 |
+
if path_key in C.path_evaluations:
|
73 |
+
return C.path_evaluations[path_key]
|
74 |
+
|
75 |
+
for i, (state_name, question_uuid, user_choice) in enumerate(zip(
|
76 |
+
previous_state_names, previous_questions, previous_user_choices
|
77 |
+
)):
|
78 |
+
scenario = states[i + 1][state_name]
|
79 |
+
question_str = questions[question_uuid].question
|
80 |
+
|
81 |
+
turn_content = f'`SCENARIO`: {scenario}\n' + \
|
82 |
+
f'`QUESTION`: {question_str}\n' + \
|
83 |
+
f'`CHOICE`: {user_choice}\n'
|
84 |
+
|
85 |
+
complete_user_story.append(turn_content)
|
86 |
+
|
87 |
+
complete_user_story = '\n\n'.join(complete_user_story)
|
88 |
+
|
89 |
+
is_encourage_categories = self.is_encourage_categories(questions)
|
90 |
+
category_contents = []
|
91 |
+
for i, category in enumerate(sorted(categories, key=lambda x: x.name)):
|
92 |
+
if is_encourage_categories and C.categories_seen[category.name]:
|
93 |
+
continue
|
94 |
+
|
95 |
+
category_content = CATEGORY_FORMAT.format(**{
|
96 |
+
'category_name': category.name,
|
97 |
+
'category_description': category.description,
|
98 |
+
'category_traits': category.traits
|
99 |
+
})
|
100 |
+
category_contents.append(category_content)
|
101 |
+
|
102 |
+
category_contents = '{\n' + ',\n'.join(category_contents) + '\n}'
|
103 |
+
|
104 |
+
prompt_arguments = {
|
105 |
+
'user_story': complete_user_story,
|
106 |
+
'category_contents': category_contents
|
107 |
+
}
|
108 |
+
|
109 |
+
system_message = system_prompt_template.format(**prompt_arguments)
|
110 |
+
user_message = user_prompt_template.format(**prompt_arguments)
|
111 |
+
|
112 |
+
messages = self.format_openai_message(system_message, user_message)
|
113 |
+
|
114 |
+
response = self.execute(self.model, messages, self.temperature)
|
115 |
+
|
116 |
+
# remove json tag from response
|
117 |
+
response = response.removeprefix('```json').removeprefix('```')
|
118 |
+
response = response.removesuffix('```json').removesuffix('```')
|
119 |
+
|
120 |
+
selection_info = json.loads(response)
|
121 |
+
|
122 |
+
selected_category = selection_info['category']
|
123 |
+
reason = selection_info['reason']
|
124 |
+
|
125 |
+
C.path_evaluations[path_key] = [selected_category, reason]
|
126 |
+
|
127 |
+
return (selected_category, reason)
|
src/agents/narrator_agent.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import json
|
3 |
+
|
4 |
+
from uuid import uuid4
|
5 |
+
|
6 |
+
import src.agents.coordinator as C
|
7 |
+
|
8 |
+
from .base_agent import BaseAgent
|
9 |
+
from ..dataclasses.process_state import ProcessState
|
10 |
+
from ..dataclasses.question import Question
|
11 |
+
from ..prompts.narrator_agent import system_prompt_template, user_prompt_template
|
12 |
+
|
13 |
+
class NarratorAgent(BaseAgent):
|
14 |
+
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
story_context: str,
|
18 |
+
category_context: str,
|
19 |
+
num_questions: int,
|
20 |
+
num_options: int
|
21 |
+
):
|
22 |
+
|
23 |
+
self.prompt_arguments = {
|
24 |
+
'story_context': story_context,
|
25 |
+
'category_context': category_context,
|
26 |
+
'num_questions': num_questions,
|
27 |
+
'num_options': num_options
|
28 |
+
}
|
29 |
+
|
30 |
+
self.model = 'openai'
|
31 |
+
self.temperature = 0.5
|
32 |
+
|
33 |
+
def run_agent(self, state: ProcessState) -> ProcessState:
|
34 |
+
|
35 |
+
question_num = state['question_num']
|
36 |
+
state_name = state['state_name']
|
37 |
+
|
38 |
+
if state_name in C.state_question_map.get(question_num, dict()):
|
39 |
+
question_uuid = C.state_question_map[question_num][state_name]
|
40 |
+
|
41 |
+
else:
|
42 |
+
question_uuid = str(uuid4())
|
43 |
+
if question_num not in C.state_question_map:
|
44 |
+
C.state_question_map[question_num] = dict()
|
45 |
+
C.state_question_map[question_num][state_name] = question_uuid
|
46 |
+
|
47 |
+
self.prompt_arguments['user_story'] = '\n'.join([
|
48 |
+
C.states[i + 1][state]
|
49 |
+
for i, state in enumerate(state['previous_state_names'][:-1])
|
50 |
+
])
|
51 |
+
self.prompt_arguments['scenario'] = C.states[question_num][state_name]
|
52 |
+
self.prompt_arguments['builder_instruction'] = state['builder_instruction']
|
53 |
+
if question_num in C.states:
|
54 |
+
self.prompt_arguments['existing_scenarios'] = '\n'.join([
|
55 |
+
'{\n' + f'scenario_name: {state_name}\n' + \
|
56 |
+
f'scenario: {scenario}' + '}\n'
|
57 |
+
for state_name, scenario in C.states.get(question_num + 1, dict()).items()
|
58 |
+
])
|
59 |
+
else:
|
60 |
+
self.prompt_arguments['existing_scenarios'] = ''
|
61 |
+
|
62 |
+
system_message = system_prompt_template.format(**self.prompt_arguments)
|
63 |
+
user_message = user_prompt_template.format(**self.prompt_arguments)
|
64 |
+
|
65 |
+
messages = self.format_openai_message(system_message, user_message)
|
66 |
+
|
67 |
+
response = self.execute(self.model, messages, self.temperature)
|
68 |
+
|
69 |
+
# remove json tag from response
|
70 |
+
response = response.removeprefix('```json').removeprefix('```')
|
71 |
+
response = response.removesuffix('```json').removesuffix('```')
|
72 |
+
|
73 |
+
question_info = json.loads(response)
|
74 |
+
|
75 |
+
if isinstance(question_info, list):
|
76 |
+
question_info = question_info[0]
|
77 |
+
|
78 |
+
question = Question(**question_info)
|
79 |
+
for foll_state_name in question.options.values():
|
80 |
+
if question_num + 1 not in C.states:
|
81 |
+
C.states[question_num + 1] = dict()
|
82 |
+
C.states[question_num + 1][foll_state_name] = 'Unexplored!'
|
83 |
+
|
84 |
+
C.questions[question_uuid] = question
|
85 |
+
|
86 |
+
foll_questions = [question for question in state['previous_questions']]
|
87 |
+
foll_questions.append(question_uuid)
|
88 |
+
|
89 |
+
return {
|
90 |
+
'question_num': state['question_num'],
|
91 |
+
'state_name': state['state_name'],
|
92 |
+
'user_choice': state['user_choice'],
|
93 |
+
'previous_state_names': state['previous_state_names'],
|
94 |
+
'previous_questions': foll_questions,
|
95 |
+
'previous_user_choices': state['previous_user_choices'],
|
96 |
+
'builder_instruction': state['builder_instruction'],
|
97 |
+
'foll_question_uuid': question_uuid
|
98 |
+
}
|
src/agents/orchestrator_agent.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import src.agents.coordinator as C
|
3 |
+
|
4 |
+
from .base_agent import BaseAgent
|
5 |
+
from .category_builder_agent import CategoryBuilderAgent
|
6 |
+
from ..dataclasses.process_state import ProcessState
|
7 |
+
from ..prompts.orchestrator_agent import (
|
8 |
+
start_system_prompt_template,
|
9 |
+
start_user_prompt_template,
|
10 |
+
coordinate_system_prompt_template,
|
11 |
+
coordinate_user_prompt_template,
|
12 |
+
get_phase_message
|
13 |
+
)
|
14 |
+
|
15 |
+
class OrchestratorAgent(BaseAgent):
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
story_context: str,
|
20 |
+
category_context: str,
|
21 |
+
num_questions: int
|
22 |
+
):
|
23 |
+
# build prompt argument
|
24 |
+
self.prompt_arguments = {
|
25 |
+
'story_context': story_context,
|
26 |
+
'category_context': category_context,
|
27 |
+
'num_questions': num_questions
|
28 |
+
}
|
29 |
+
|
30 |
+
# openai calls
|
31 |
+
self.model = 'openai'
|
32 |
+
self.temperature = 0.2
|
33 |
+
|
34 |
+
def run_agent(self, state: ProcessState) -> ProcessState:
|
35 |
+
|
36 |
+
question_num = state['question_num']
|
37 |
+
state_name = state['state_name']
|
38 |
+
|
39 |
+
# check which system and user prompts to build
|
40 |
+
if question_num == 1:
|
41 |
+
system_message = start_system_prompt_template.format(**self.prompt_arguments)
|
42 |
+
user_message = start_user_prompt_template.format(**self.prompt_arguments)
|
43 |
+
else:
|
44 |
+
progress = question_num / self.prompt_arguments['num_questions']
|
45 |
+
self.prompt_arguments['phase_message'] = get_phase_message(progress)
|
46 |
+
self.prompt_arguments['user_story'] = '\n'.join([
|
47 |
+
C.states[i + 1][state]
|
48 |
+
for i, state in enumerate(state['previous_state_names'])
|
49 |
+
])
|
50 |
+
system_message = coordinate_system_prompt_template.format(**self.prompt_arguments)
|
51 |
+
user_message = coordinate_user_prompt_template.format(**self.prompt_arguments)
|
52 |
+
|
53 |
+
# check if the state has been generated before
|
54 |
+
if state_name in C.states.get(question_num, dict()) and C.states[question_num][state_name] != 'Unexplored!':
|
55 |
+
foll_state_names = [name for name in state['previous_state_names']]
|
56 |
+
foll_state_names.append(state['state_name'])
|
57 |
+
|
58 |
+
builder_instruction = ''
|
59 |
+
|
60 |
+
foll_question_uuid = C.state_question_map[question_num][state_name]
|
61 |
+
previous_questions = [question_uuid for question_uuid in state['previous_questions']]
|
62 |
+
if state['foll_question_uuid'] or question_num == 1:
|
63 |
+
if state['foll_question_uuid']:
|
64 |
+
previous_questions.append(state['foll_question_uuid'])
|
65 |
+
else:
|
66 |
+
previous_questions.append(C.state_question_map[1]['start'])
|
67 |
+
|
68 |
+
else:
|
69 |
+
foll_state_names = state['previous_state_names']
|
70 |
+
|
71 |
+
messages = self.format_openai_message(system_message, user_message)
|
72 |
+
builder_instruction = self.execute(self.model, messages, self.temperature)
|
73 |
+
foll_question_uuid = state['foll_question_uuid']
|
74 |
+
previous_questions = [question_uuid for question_uuid in state['previous_questions']]
|
75 |
+
|
76 |
+
return {
|
77 |
+
'question_num': state['question_num'],
|
78 |
+
'state_name': state['state_name'],
|
79 |
+
'user_choice': state['user_choice'],
|
80 |
+
'previous_state_names': foll_state_names,
|
81 |
+
'previous_questions': previous_questions,
|
82 |
+
'previous_user_choices': state['previous_user_choices'],
|
83 |
+
'builder_instruction': builder_instruction,
|
84 |
+
'foll_question_uuid': foll_question_uuid
|
85 |
+
}
|
src/agents/storybuilder_agent.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import src.agents.coordinator as C
|
3 |
+
|
4 |
+
from .base_agent import BaseAgent
|
5 |
+
from ..dataclasses.process_state import ProcessState
|
6 |
+
from ..prompts.storybuilder_agent import system_prompt_template, user_prompt_template
|
7 |
+
|
8 |
+
class StorybuilderAgent(BaseAgent):
|
9 |
+
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
story_context: str,
|
13 |
+
category_context: str,
|
14 |
+
num_questions: int,
|
15 |
+
num_options: int
|
16 |
+
):
|
17 |
+
|
18 |
+
self.prompt_arguments = {
|
19 |
+
'story_context': story_context,
|
20 |
+
'category_context': category_context,
|
21 |
+
'num_questions': num_questions,
|
22 |
+
'num_options': num_options
|
23 |
+
}
|
24 |
+
|
25 |
+
self.model = 'openai'
|
26 |
+
self.temperature = 0.5
|
27 |
+
|
28 |
+
def run_agent(self, state: ProcessState) -> ProcessState:
|
29 |
+
|
30 |
+
question_num = state['question_num']
|
31 |
+
state_name = state['state_name']
|
32 |
+
|
33 |
+
self.prompt_arguments['builder_instruction'] = state['builder_instruction']
|
34 |
+
self.prompt_arguments['user_story'] = '\n'.join([
|
35 |
+
C.states[i + 1][state]
|
36 |
+
for i, state in enumerate(state['previous_state_names'])
|
37 |
+
])
|
38 |
+
|
39 |
+
if question_num == 1:
|
40 |
+
self.prompt_arguments['question'] = ''
|
41 |
+
self.prompt_arguments['response'] = ''
|
42 |
+
else:
|
43 |
+
self.prompt_arguments['question'] = C.questions[state['previous_questions'][-1]].question
|
44 |
+
self.prompt_arguments['response'] = state['user_choice']
|
45 |
+
|
46 |
+
system_message = system_prompt_template.format(**self.prompt_arguments)
|
47 |
+
user_message = user_prompt_template.format(**self.prompt_arguments)
|
48 |
+
|
49 |
+
messages = self.format_openai_message(system_message, user_message)
|
50 |
+
|
51 |
+
scenario = self.execute(self.model, messages, self.temperature)
|
52 |
+
|
53 |
+
if question_num not in C.states:
|
54 |
+
C.states[question_num] = dict()
|
55 |
+
C.states[question_num][state_name] = scenario
|
56 |
+
|
57 |
+
foll_state_names = [name for name in state['previous_state_names']]
|
58 |
+
foll_state_names.append(state_name)
|
59 |
+
|
60 |
+
return {
|
61 |
+
'question_num': state['question_num'],
|
62 |
+
'state_name': state['state_name'],
|
63 |
+
'user_choice': state['user_choice'],
|
64 |
+
'previous_state_names': foll_state_names,
|
65 |
+
'previous_questions': state['previous_questions'],
|
66 |
+
'previous_user_choices': state['previous_user_choices'],
|
67 |
+
'builder_instruction': state['builder_instruction'],
|
68 |
+
'foll_question_uuid': state['foll_question_uuid']
|
69 |
+
}
|
src/dataclasses/category.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from dataclasses import asdict, dataclass
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
|
7 |
+
class Category:
|
8 |
+
name: str
|
9 |
+
description: str
|
10 |
+
traits: List[str]
|
11 |
+
|
12 |
+
def dict(self):
|
13 |
+
return {k: str(v) for k, v in asdict(self).items()}
|
src/dataclasses/process_state.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import List, TypedDict
|
3 |
+
|
4 |
+
|
5 |
+
class ProcessState(TypedDict):
|
6 |
+
question_num: int
|
7 |
+
|
8 |
+
state_name: str
|
9 |
+
user_choice: str
|
10 |
+
|
11 |
+
previous_state_names: List[str]
|
12 |
+
previous_questions: List[str]
|
13 |
+
previous_user_choices: List[str]
|
14 |
+
|
15 |
+
builder_instruction: str
|
16 |
+
|
17 |
+
foll_question_uuid: str
|
src/dataclasses/question.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from dataclasses import asdict, dataclass
|
3 |
+
from typing import Dict
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
|
7 |
+
class Question:
|
8 |
+
question: str
|
9 |
+
options: Dict[str, str]
|
10 |
+
|
11 |
+
def dict(self):
|
12 |
+
return {k: str(v) for k, v in asdict(self).items()}
|
src/frontend/chatbot.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
from langgraph.graph.graph import CompiledGraph
|
5 |
+
from typing import Dict, List, Union
|
6 |
+
|
7 |
+
import src.agents.coordinator as C
|
8 |
+
|
9 |
+
from src.dataclasses.process_state import ProcessState
|
10 |
+
|
11 |
+
graph: CompiledGraph = None
|
12 |
+
MAX_OPTIONS = 16
|
13 |
+
state: ProcessState = {
|
14 |
+
'question_num': 0,
|
15 |
+
'state_name': 'start',
|
16 |
+
'user_choice': '',
|
17 |
+
'previous_state_names': [],
|
18 |
+
'previous_questions': [],
|
19 |
+
'previous_user_choices': [],
|
20 |
+
'builder_instruction': '',
|
21 |
+
'foll_question_uuid': ''
|
22 |
+
}
|
23 |
+
state_history: List[ProcessState] = []
|
24 |
+
num_questions_: int = None
|
25 |
+
|
26 |
+
chatbot = gr.Chatbot(
|
27 |
+
type='messages',
|
28 |
+
key='chatbot',
|
29 |
+
preserved_by_key='key',
|
30 |
+
visible=False
|
31 |
+
)
|
32 |
+
option_buttons = [
|
33 |
+
gr.Button(
|
34 |
+
value='',
|
35 |
+
visible=False,
|
36 |
+
render=False,
|
37 |
+
key=f'option_button_{i}',
|
38 |
+
preserved_by_key='key'
|
39 |
+
)
|
40 |
+
for i in range(1, MAX_OPTIONS + 1)
|
41 |
+
]
|
42 |
+
restart_button = gr.Button(
|
43 |
+
'Explore other storylines!',
|
44 |
+
visible=False,
|
45 |
+
interactive=True,
|
46 |
+
key='restart_button',
|
47 |
+
preserved_by_key='key'
|
48 |
+
)
|
49 |
+
button_row = gr.Row(
|
50 |
+
key='button_row',
|
51 |
+
preserved_by_key='key'
|
52 |
+
)
|
53 |
+
|
54 |
+
def init_graph(
|
55 |
+
story_context: str,
|
56 |
+
categories_context: str,
|
57 |
+
num_questions: int,
|
58 |
+
num_options: int,
|
59 |
+
num_categories: int
|
60 |
+
) -> None:
|
61 |
+
global graph, option_buttons, num_questions_
|
62 |
+
|
63 |
+
num_questions_ = num_questions
|
64 |
+
graph = C.init_graph(
|
65 |
+
story_context,
|
66 |
+
categories_context,
|
67 |
+
num_questions,
|
68 |
+
num_options,
|
69 |
+
num_categories
|
70 |
+
)
|
71 |
+
|
72 |
+
for i in range(num_options, MAX_OPTIONS):
|
73 |
+
option_buttons[i].unrender()
|
74 |
+
|
75 |
+
option_buttons = option_buttons[:num_options]
|
76 |
+
|
77 |
+
|
78 |
+
def on_user_response(
|
79 |
+
user_choice: str,
|
80 |
+
history: List[Dict[str, str]]
|
81 |
+
) -> Dict[Union[gr.Button, gr.Chatbot], Union[List[Dict[str, str]], gr.update]]:
|
82 |
+
state['user_choice'] = user_choice.replace(' ', '_')
|
83 |
+
|
84 |
+
user_message = [{'role': 'user', 'content': user_choice}]
|
85 |
+
updated_history = history + user_message
|
86 |
+
|
87 |
+
return {chatbot: updated_history} | {
|
88 |
+
button: gr.update(visible=False)
|
89 |
+
for button in option_buttons
|
90 |
+
}
|
91 |
+
|
92 |
+
|
93 |
+
def control_screen_widgets() -> List[Union[gr.Chatbot, gr.Row, gr.Button]]:
|
94 |
+
return [chatbot, button_row, restart_button]
|
95 |
+
|
96 |
+
|
97 |
+
def control_screen(
|
98 |
+
is_visible: bool
|
99 |
+
) -> Dict[Union[gr.Chatbot, gr.Row, gr.Button], Union[gr.update, gr.Row]]:
|
100 |
+
return {
|
101 |
+
chatbot: gr.update(visible=is_visible),
|
102 |
+
button_row: gr.Row(visible=is_visible),
|
103 |
+
restart_button: gr.update(visible=is_visible)
|
104 |
+
}
|
105 |
+
|
106 |
+
|
107 |
+
def on_chatbot_response(
|
108 |
+
history: List[Dict[str, str]]
|
109 |
+
) -> Dict[Union[gr.Chatbot, gr.Button], Union[List[Dict[str, str]], gr.update]]:
|
110 |
+
global state
|
111 |
+
|
112 |
+
state_history.append({key: val for key, val in state.items()})
|
113 |
+
|
114 |
+
question_num = state['question_num']
|
115 |
+
if question_num < num_questions_:
|
116 |
+
state = graph.invoke(state)
|
117 |
+
|
118 |
+
question_num = state['question_num']
|
119 |
+
state_name = state['state_name']
|
120 |
+
scenario = C.states[question_num][state_name]
|
121 |
+
question = C.questions[state['foll_question_uuid']]
|
122 |
+
question_str = question.question.replace('_', ' ').strip('\"').strip('\'')
|
123 |
+
|
124 |
+
text_to_user = scenario + '\n' + question_str
|
125 |
+
bot_message = [{'role': 'assistant', 'content': text_to_user}]
|
126 |
+
updated_history = history + bot_message
|
127 |
+
|
128 |
+
options = [
|
129 |
+
option.replace('_', ' ').strip('\"').strip('\'')
|
130 |
+
for option in question.options.keys()
|
131 |
+
]
|
132 |
+
|
133 |
+
button_updates = {
|
134 |
+
button: gr.update(value=option, visible=True)
|
135 |
+
for option, button in zip(options, option_buttons)
|
136 |
+
}
|
137 |
+
button_updates[restart_button] = gr.update(visible=False)
|
138 |
+
else:
|
139 |
+
selected_category, reason = C.evaluation(state)
|
140 |
+
description, traits = None, None
|
141 |
+
for category in C.categories:
|
142 |
+
if category.name != selected_category:
|
143 |
+
continue
|
144 |
+
description = category.description
|
145 |
+
traits = category.traits
|
146 |
+
C.categories_seen[selected_category] = True
|
147 |
+
traits_string = traits
|
148 |
+
bot_messages = [
|
149 |
+
{'role': 'assistant', 'content': f'We think that you are closest to **{selected_category}**!'},
|
150 |
+
{'role': 'assistant', 'content': f'## {selected_category}\n\n{description}\n\n{traits_string}'},
|
151 |
+
{'role': 'assistant', 'content': reason}
|
152 |
+
]
|
153 |
+
updated_history = history + bot_messages
|
154 |
+
button_updates = {
|
155 |
+
button: gr.update(visible=False)
|
156 |
+
for button in option_buttons
|
157 |
+
}
|
158 |
+
button_updates[restart_button] = gr.update(visible=True)
|
159 |
+
|
160 |
+
C.save_coordinator()
|
161 |
+
|
162 |
+
return {chatbot: gr.update(value=updated_history, visible=True)} | button_updates
|
163 |
+
|
164 |
+
|
165 |
+
def on_restart_button_click()-> Dict[
|
166 |
+
Union[gr.Chatbot, gr.Button], Union[List[Dict[str, str]], gr.update]
|
167 |
+
]:
|
168 |
+
global state
|
169 |
+
|
170 |
+
state = {
|
171 |
+
'question_num': 0,
|
172 |
+
'state_name': 'start',
|
173 |
+
'user_choice': '',
|
174 |
+
'previous_state_names': [],
|
175 |
+
'previous_questions': [],
|
176 |
+
'previous_user_choices': [],
|
177 |
+
'builder_instruction': '',
|
178 |
+
'foll_question_uuid': ''
|
179 |
+
}
|
180 |
+
|
181 |
+
return {restart_button: gr.update(visible=False)} | \
|
182 |
+
on_chatbot_response([])
|
183 |
+
|
184 |
+
|
185 |
+
def render():
|
186 |
+
chatbot.render()
|
187 |
+
|
188 |
+
button_row.render()
|
189 |
+
with button_row:
|
190 |
+
for button in option_buttons:
|
191 |
+
button.unrender()
|
192 |
+
button.render()
|
193 |
+
button.click(
|
194 |
+
fn=on_user_response,
|
195 |
+
inputs=[button, chatbot],
|
196 |
+
outputs=option_buttons + [chatbot]
|
197 |
+
).then(
|
198 |
+
fn=on_chatbot_response,
|
199 |
+
inputs=[chatbot],
|
200 |
+
outputs=option_buttons + [chatbot, restart_button]
|
201 |
+
)
|
202 |
+
restart_button.render()
|
203 |
+
restart_button.click(
|
204 |
+
fn=on_restart_button_click,
|
205 |
+
inputs=None,
|
206 |
+
outputs=option_buttons + [chatbot, restart_button]
|
207 |
+
)
|
src/frontend/continue_story_screen.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
import os
|
4 |
+
|
5 |
+
from typing import Dict, List, Tuple, Union
|
6 |
+
|
7 |
+
import src.agents.coordinator as C
|
8 |
+
|
9 |
+
from src.frontend.chatbot import (
|
10 |
+
chatbot,
|
11 |
+
init_graph,
|
12 |
+
on_chatbot_response,
|
13 |
+
option_buttons,
|
14 |
+
restart_button
|
15 |
+
)
|
16 |
+
from src.frontend import sidebar
|
17 |
+
from src.utils.utils import load_information, transform_story_name
|
18 |
+
|
19 |
+
information_text = gr.Text(
|
20 |
+
value='Select which story to continue below!',
|
21 |
+
interactive=False,
|
22 |
+
key='continue_story_user_information',
|
23 |
+
preserved_by_key='key',
|
24 |
+
visible=False,
|
25 |
+
label=None
|
26 |
+
)
|
27 |
+
session_dir = os.path.join(
|
28 |
+
os.path.dirname(__file__),
|
29 |
+
'..', '..', 'sessions'
|
30 |
+
)
|
31 |
+
story_widgets: List[Tuple[gr.Row, gr.Button, gr.Text]] = []
|
32 |
+
story_dir_mapper = dict()
|
33 |
+
for i, dirname in enumerate(sorted(os.listdir(session_dir))):
|
34 |
+
if dirname.startswith('.'):
|
35 |
+
continue
|
36 |
+
dirpath = os.path.join(session_dir, dirname)
|
37 |
+
if not os.listdir(dirpath):
|
38 |
+
continue
|
39 |
+
story_information_filepath = os.path.join(dirpath, 'story.json')
|
40 |
+
story_information_dict = load_information(story_information_filepath)
|
41 |
+
story_name = story_information_dict['story_name']
|
42 |
+
story_context = story_information_dict['story_context']
|
43 |
+
row = gr.Row(
|
44 |
+
visible=False,
|
45 |
+
key=f'continue_row_{i}',
|
46 |
+
preserved_by_key='key'
|
47 |
+
)
|
48 |
+
button = gr.Button(
|
49 |
+
value=story_name,
|
50 |
+
visible=False,
|
51 |
+
key=f'continue_button_{i}',
|
52 |
+
preserved_by_key='key',
|
53 |
+
scale=1
|
54 |
+
)
|
55 |
+
text = gr.Text(
|
56 |
+
value=story_context,
|
57 |
+
interactive=False,
|
58 |
+
visible=False,
|
59 |
+
key=f'continue_story_{i}',
|
60 |
+
preserved_by_key='key',
|
61 |
+
scale=3
|
62 |
+
)
|
63 |
+
story_widgets.append((row, button, text))
|
64 |
+
story_dir_mapper[story_name] = dirpath
|
65 |
+
|
66 |
+
|
67 |
+
def get_widgets() -> List[Union[gr.Text, gr.Button, gr.Row]]:
|
68 |
+
widgets = [information_text]
|
69 |
+
for row, button, text in story_widgets:
|
70 |
+
widgets.append(row)
|
71 |
+
widgets.append(button)
|
72 |
+
widgets.append(text)
|
73 |
+
return widgets
|
74 |
+
|
75 |
+
|
76 |
+
def get_widgets_updates(
|
77 |
+
is_visible: bool
|
78 |
+
) -> Dict[Union[gr.Slider, gr.Text, gr.Button], gr.update]:
|
79 |
+
updates = dict()
|
80 |
+
for widget in get_widgets():
|
81 |
+
if isinstance(widget, gr.Row):
|
82 |
+
updates[widget] = gr.Row(visible=is_visible)
|
83 |
+
else:
|
84 |
+
updates[widget] = gr.update(visible=is_visible)
|
85 |
+
return updates
|
86 |
+
|
87 |
+
|
88 |
+
def on_button_click(
|
89 |
+
story_name: str
|
90 |
+
) -> Dict[Union[gr.Text, gr.Slider, gr.Button, gr.Chatbot], gr.update]:
|
91 |
+
|
92 |
+
dirpath = story_dir_mapper[story_name]
|
93 |
+
story_name_ = transform_story_name(story_name)
|
94 |
+
|
95 |
+
story_information_filepath = os.path.join(dirpath, 'story.json')
|
96 |
+
story_information_dict = load_information(story_information_filepath)
|
97 |
+
story_context = story_information_dict['story_context']
|
98 |
+
categories_context = story_information_dict['categories_context']
|
99 |
+
num_questions = story_information_dict['num_questions']
|
100 |
+
num_options = story_information_dict['num_options']
|
101 |
+
num_categories = story_information_dict['num_categories']
|
102 |
+
|
103 |
+
init_graph(
|
104 |
+
story_context,
|
105 |
+
categories_context,
|
106 |
+
num_questions,
|
107 |
+
num_options,
|
108 |
+
num_categories
|
109 |
+
)
|
110 |
+
|
111 |
+
C.story_name = story_name_
|
112 |
+
C.load_coordinator()
|
113 |
+
|
114 |
+
chatbot_updates = on_chatbot_response([])
|
115 |
+
|
116 |
+
return chatbot_updates | get_widgets_updates(False) | sidebar.view_screen()
|
117 |
+
|
118 |
+
|
119 |
+
def render():
|
120 |
+
information_text.render()
|
121 |
+
|
122 |
+
for row, button, text in story_widgets:
|
123 |
+
row.render()
|
124 |
+
with row:
|
125 |
+
button.render()
|
126 |
+
text.render()
|
127 |
+
button.click(
|
128 |
+
fn=on_button_click,
|
129 |
+
inputs=[button],
|
130 |
+
outputs=get_widgets() + \
|
131 |
+
[chatbot, restart_button] + \
|
132 |
+
option_buttons + \
|
133 |
+
sidebar.get_widgets()
|
134 |
+
)
|
src/frontend/progress_screen.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
import pandas as pd
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
|
6 |
+
from igraph import Graph, EdgeSeq
|
7 |
+
from typing import Dict, List, Union
|
8 |
+
|
9 |
+
import src.agents.coordinator as C
|
10 |
+
|
11 |
+
categories_row = gr.Row(
|
12 |
+
visible=False,
|
13 |
+
key='categories_row',
|
14 |
+
preserved_by_key='key'
|
15 |
+
)
|
16 |
+
graph_row = gr.Row(
|
17 |
+
visible=False,
|
18 |
+
key='graph_row',
|
19 |
+
preserved_by_key='key'
|
20 |
+
)
|
21 |
+
|
22 |
+
category_buttons = [
|
23 |
+
gr.Button(
|
24 |
+
value='',
|
25 |
+
visible=False,
|
26 |
+
key=f'category_button_{i + 1}',
|
27 |
+
preserved_by_key='key'
|
28 |
+
)
|
29 |
+
for i in range(16)
|
30 |
+
]
|
31 |
+
category_text = gr.Text(
|
32 |
+
value='Select a category!',
|
33 |
+
visible=False,
|
34 |
+
interactive=False,
|
35 |
+
key='category_text',
|
36 |
+
preserved_by_key='key'
|
37 |
+
)
|
38 |
+
graph_plot = gr.Plot(
|
39 |
+
visible=False,
|
40 |
+
key='graph_plot',
|
41 |
+
preserved_by_key='key'
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def control_screen(is_visible: bool) -> Dict[gr.Row, gr.Row]:
|
46 |
+
return {
|
47 |
+
categories_row: gr.Row(visible=is_visible),
|
48 |
+
graph_row: gr.Row(visible=is_visible)
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
+
def update_categories() -> Dict[Union[gr.Button, gr.Text], gr.update]:
|
53 |
+
button_updates = {
|
54 |
+
button: gr.update(
|
55 |
+
value=category.name if C.categories_seen[category.name] \
|
56 |
+
else '???',
|
57 |
+
visible=True,
|
58 |
+
interactive=C.categories_seen[category.name],
|
59 |
+
variant='primary' if C.categories_seen[category.name] \
|
60 |
+
else 'secondary'
|
61 |
+
)
|
62 |
+
for button, category in zip(category_buttons, C.categories)
|
63 |
+
}
|
64 |
+
|
65 |
+
return {
|
66 |
+
category_text: gr.update(visible=True, value='Select a category!')
|
67 |
+
} | button_updates
|
68 |
+
|
69 |
+
|
70 |
+
def on_category_click(category_name: str) -> str:
|
71 |
+
for category in C.categories:
|
72 |
+
if category.name != category_name:
|
73 |
+
continue
|
74 |
+
return category.name + '\n\n' + \
|
75 |
+
category.description + '\n\n' + \
|
76 |
+
category.traits
|
77 |
+
|
78 |
+
|
79 |
+
def update_graph():
|
80 |
+
n_vertices = sum([
|
81 |
+
len(states_)
|
82 |
+
for level, states_ in C.states.items()
|
83 |
+
if level <= C.num_questions_
|
84 |
+
])
|
85 |
+
|
86 |
+
graph = Graph(directed=True)
|
87 |
+
nodes = []
|
88 |
+
node_attributes = {'description': []}
|
89 |
+
for level in range(1, C.num_questions_ + 1):
|
90 |
+
for name, description in C.states[level].items():
|
91 |
+
node_name = f'{name} (Stage {level})'
|
92 |
+
nodes.append(node_name)
|
93 |
+
node_attributes['description'].append(description)
|
94 |
+
|
95 |
+
edges = []
|
96 |
+
edge_attributes = {
|
97 |
+
'question': [],
|
98 |
+
'option': []
|
99 |
+
}
|
100 |
+
for level in range(1, C.num_questions_):
|
101 |
+
for state, question_uuid in C.state_question_map[level].items():
|
102 |
+
question = C.questions[question_uuid]
|
103 |
+
question_str = question.question
|
104 |
+
options = question.options
|
105 |
+
for option, foll_state in options.items():
|
106 |
+
edge = (
|
107 |
+
f'{state} (Stage {level})',
|
108 |
+
f'{foll_state} (Stage {level + 1})'
|
109 |
+
)
|
110 |
+
edges.append(edge)
|
111 |
+
edge_attributes['question'].append(question_str)
|
112 |
+
edge_attributes['option'].append(option)
|
113 |
+
|
114 |
+
graph.add_vertices(nodes, attributes=node_attributes)
|
115 |
+
graph.add_edges(edges, attributes=edge_attributes)
|
116 |
+
layout = graph.layout('rt')
|
117 |
+
|
118 |
+
# adapted from https://plotly.com/python/tree-plots/
|
119 |
+
position = {k: layout[k] for k in range(n_vertices)}
|
120 |
+
Y = [layout[k][1] for k in range(n_vertices)]
|
121 |
+
M = max(Y)
|
122 |
+
|
123 |
+
E = [e.tuple for e in graph.es] # list of edges
|
124 |
+
|
125 |
+
L = len(position)
|
126 |
+
Xn = [position[k][0] for k in range(L)]
|
127 |
+
Yn = [2*M-position[k][1] for k in range(L)]
|
128 |
+
Xe = []
|
129 |
+
Ye = []
|
130 |
+
|
131 |
+
# for labelling edges
|
132 |
+
X_edge_nodes = []
|
133 |
+
Y_edge_nodes = []
|
134 |
+
for edge in E:
|
135 |
+
Xe+=[position[edge[0]][0],position[edge[1]][0], None]
|
136 |
+
Ye+=[2*M-position[edge[0]][1],2*M-position[edge[1]][1], None]
|
137 |
+
|
138 |
+
X_edge_nodes.append((position[edge[0]][0] + position[edge[1]][0]) / 2)
|
139 |
+
Y_edge_nodes.append((2*M-position[edge[0]][1] + 2*M-position[edge[1]][1]) / 2)
|
140 |
+
|
141 |
+
node_labels = [
|
142 |
+
node.replace('_', ' ') + '\n\n' + \
|
143 |
+
description
|
144 |
+
for node, description in zip(nodes, node_attributes['description'])
|
145 |
+
]
|
146 |
+
node_labels = pd.DataFrame(node_labels, columns=['label'])
|
147 |
+
node_labels['label'] = node_labels['label'].str.wrap(30)\
|
148 |
+
.apply(lambda x: x.replace('\n', '<br>'))
|
149 |
+
node_labels = node_labels['label'].to_list()
|
150 |
+
edge_labels = [
|
151 |
+
question.replace('_', ' ') + '\n\n' + option.replace('_', ' ')
|
152 |
+
for question, option in zip(
|
153 |
+
edge_attributes['question'], edge_attributes['option']
|
154 |
+
)
|
155 |
+
]
|
156 |
+
edge_labels = pd.DataFrame(edge_labels, columns=['label'])
|
157 |
+
edge_labels['label'] = edge_labels['label'].str.wrap(30)\
|
158 |
+
.apply(lambda x: x.replace('\n', '<br>'))
|
159 |
+
edge_labels = edge_labels['label'].to_list()
|
160 |
+
|
161 |
+
fig = go.Figure()
|
162 |
+
fig.add_trace(go.Scatter(
|
163 |
+
x=Xe, y=Ye,
|
164 |
+
mode='lines',
|
165 |
+
line=dict(color='rgb(210,210,210)', width=1),
|
166 |
+
))
|
167 |
+
fig.add_trace(go.Scatter(
|
168 |
+
x=Xn, y=Yn,
|
169 |
+
mode='markers',
|
170 |
+
marker=dict(
|
171 |
+
symbol='circle-dot', size=18, color='#6175c1',
|
172 |
+
line=dict(color='rgb(50,50,50)', width=1)
|
173 |
+
),
|
174 |
+
text=node_labels,
|
175 |
+
hoverinfo='text',
|
176 |
+
opacity=0.8
|
177 |
+
))
|
178 |
+
|
179 |
+
fig.add_trace(go.Scatter(
|
180 |
+
x=X_edge_nodes, y=Y_edge_nodes,
|
181 |
+
mode='markers',
|
182 |
+
marker=dict(
|
183 |
+
symbol='circle-dot', size=0, color="#42c744",
|
184 |
+
line=dict(color='rgb(50,50,50)', width=0)
|
185 |
+
),
|
186 |
+
text=edge_labels,
|
187 |
+
hoverinfo='text',
|
188 |
+
opacity=0
|
189 |
+
))
|
190 |
+
|
191 |
+
axis = dict(
|
192 |
+
showline=False,
|
193 |
+
zeroline=False,
|
194 |
+
showgrid=False,
|
195 |
+
showticklabels=False,
|
196 |
+
)
|
197 |
+
|
198 |
+
fig.update_layout(
|
199 |
+
showlegend=False,
|
200 |
+
xaxis=axis,
|
201 |
+
yaxis=axis
|
202 |
+
)
|
203 |
+
|
204 |
+
return gr.Plot(fig, visible=True)
|
205 |
+
|
206 |
+
|
207 |
+
def control_screen_widgets() -> List[Union[gr.Row, gr.Text]]:
|
208 |
+
return [categories_row, category_text, graph_row, graph_plot] + \
|
209 |
+
category_buttons
|
210 |
+
|
211 |
+
|
212 |
+
def control_screen(
|
213 |
+
is_visible: bool
|
214 |
+
) -> Dict[Union[gr.Plot, gr.Row, gr.Button], Union[gr.update, gr.Row]]:
|
215 |
+
row_updates = {
|
216 |
+
categories_row: gr.Row(visible=is_visible),
|
217 |
+
graph_row: gr.Row(visible=is_visible)
|
218 |
+
}
|
219 |
+
graph_update = {
|
220 |
+
graph_plot: update_graph() if is_visible else gr.update(visible=False)
|
221 |
+
}
|
222 |
+
category_button_updates = update_categories()
|
223 |
+
|
224 |
+
return row_updates | graph_update | category_button_updates
|
225 |
+
|
226 |
+
|
227 |
+
def render():
|
228 |
+
categories_row.render()
|
229 |
+
with categories_row:
|
230 |
+
for button in category_buttons:
|
231 |
+
button.render()
|
232 |
+
button.click(
|
233 |
+
fn=on_category_click,
|
234 |
+
inputs=[button],
|
235 |
+
outputs=[category_text]
|
236 |
+
)
|
237 |
+
category_text.render()
|
238 |
+
graph_row.render()
|
239 |
+
with graph_row:
|
240 |
+
graph_plot.render()
|
src/frontend/sidebar.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
from typing import Dict, List, Union
|
5 |
+
|
6 |
+
from src.frontend import chatbot, progress_screen
|
7 |
+
|
8 |
+
sidebar = gr.Sidebar(
|
9 |
+
open=True,
|
10 |
+
visible=False,
|
11 |
+
key='sidebar',
|
12 |
+
preserved_by_key='key'
|
13 |
+
)
|
14 |
+
chatbot_button = gr.Button(
|
15 |
+
value='Viewing Chatbot!',
|
16 |
+
variant='secondary',
|
17 |
+
visible=False,
|
18 |
+
interactive=False,
|
19 |
+
key='chatbot_button',
|
20 |
+
preserved_by_key='key'
|
21 |
+
)
|
22 |
+
progress_button = gr.Button(
|
23 |
+
value='View Your Story Progress!',
|
24 |
+
variant='primary',
|
25 |
+
visible=False,
|
26 |
+
interactive=True,
|
27 |
+
key='progress_button',
|
28 |
+
preserved_by_key='key'
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
def get_widgets() -> List[Union[gr.Sidebar, gr.Button]]:
|
33 |
+
return [sidebar, chatbot_button, progress_button]
|
34 |
+
|
35 |
+
|
36 |
+
def view_screen() -> Dict[Union[gr.Sidebar, gr.Button], gr.update]:
|
37 |
+
return {
|
38 |
+
widget: gr.update(visible=True)
|
39 |
+
for widget in get_widgets()
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
def on_chatbot_button_click():
|
44 |
+
chatbot_screen_updates = chatbot.control_screen(True)
|
45 |
+
progress_screen_updates = progress_screen.control_screen(False)
|
46 |
+
|
47 |
+
chatbot_button_update = gr.update(
|
48 |
+
value='Viewing Chatbot!',
|
49 |
+
variant='secondary',
|
50 |
+
interactive=False
|
51 |
+
)
|
52 |
+
progress_button_update = gr.update(
|
53 |
+
value='View Your Story Progress!',
|
54 |
+
variant='primary',
|
55 |
+
interactive=True
|
56 |
+
)
|
57 |
+
|
58 |
+
return {
|
59 |
+
chatbot_button: chatbot_button_update,
|
60 |
+
progress_button: progress_button_update
|
61 |
+
} | \
|
62 |
+
chatbot_screen_updates | progress_screen_updates
|
63 |
+
|
64 |
+
|
65 |
+
def on_progress_button_click():
|
66 |
+
chatbot_screen_updates = chatbot.control_screen(False)
|
67 |
+
progress_screen_updates = progress_screen.control_screen(True)
|
68 |
+
|
69 |
+
chatbot_button_update = gr.update(
|
70 |
+
value='Come & Explore!',
|
71 |
+
variant='primary',
|
72 |
+
interactive=True
|
73 |
+
)
|
74 |
+
progress_button_update = gr.update(
|
75 |
+
value='Viewing Your Story Progress!',
|
76 |
+
variant='secondary',
|
77 |
+
interactive=False
|
78 |
+
)
|
79 |
+
|
80 |
+
return {
|
81 |
+
chatbot_button: chatbot_button_update,
|
82 |
+
progress_button: progress_button_update
|
83 |
+
} | \
|
84 |
+
chatbot_screen_updates | progress_screen_updates
|
85 |
+
|
86 |
+
|
87 |
+
def render():
|
88 |
+
sidebar.render()
|
89 |
+
with sidebar:
|
90 |
+
chatbot_button.render()
|
91 |
+
progress_button.render()
|
92 |
+
|
93 |
+
chatbot_button.click(
|
94 |
+
fn=on_chatbot_button_click,
|
95 |
+
inputs=[],
|
96 |
+
outputs=[chatbot_button, progress_button] + \
|
97 |
+
chatbot.control_screen_widgets() + \
|
98 |
+
progress_screen.control_screen_widgets()
|
99 |
+
)
|
100 |
+
|
101 |
+
progress_button.click(
|
102 |
+
fn=on_progress_button_click,
|
103 |
+
inputs=[],
|
104 |
+
outputs=[chatbot_button, progress_button] + \
|
105 |
+
chatbot.control_screen_widgets() + \
|
106 |
+
progress_screen.control_screen_widgets()
|
107 |
+
)
|
108 |
+
|
src/frontend/start_screen.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
from typing import Dict, List, Union
|
5 |
+
|
6 |
+
from src.frontend import (
|
7 |
+
continue_story_screen,
|
8 |
+
story_information_widgets
|
9 |
+
)
|
10 |
+
from src.utils.init_openai import init_client
|
11 |
+
|
12 |
+
button_row = gr.Row(
|
13 |
+
key='start_screen_button_row',
|
14 |
+
preserved_by_key='key'
|
15 |
+
)
|
16 |
+
new_story_button = gr.Button(
|
17 |
+
value='Begin a new adventure!',
|
18 |
+
visible=True,
|
19 |
+
interactive=True,
|
20 |
+
key='new_story_button',
|
21 |
+
preserved_by_key='key'
|
22 |
+
)
|
23 |
+
continue_story_button = gr.Button(
|
24 |
+
value='Continue an existing adventure!',
|
25 |
+
visible=True,
|
26 |
+
interactive=True,
|
27 |
+
key='continue_story_button',
|
28 |
+
preserved_by_key='key'
|
29 |
+
)
|
30 |
+
api_key_textbox = gr.Text(
|
31 |
+
type='password',
|
32 |
+
label='OpenAI API Key',
|
33 |
+
placeholder='Enter your OpenAI API key',
|
34 |
+
interactive=True,
|
35 |
+
visible=True,
|
36 |
+
key='api_key_textbox',
|
37 |
+
preserved_by_key='key'
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
def get_widgets() -> List[Union[gr.Text, gr.Button, gr.Row]]:
|
42 |
+
return [
|
43 |
+
button_row,
|
44 |
+
new_story_button,
|
45 |
+
continue_story_button,
|
46 |
+
api_key_textbox
|
47 |
+
]
|
48 |
+
|
49 |
+
|
50 |
+
def get_wigets_updates(
|
51 |
+
is_visible: bool = False
|
52 |
+
) -> Dict[Union[gr.Text, gr.Button, gr.Row], Union[gr.update, gr.Row]]:
|
53 |
+
return {
|
54 |
+
widget: gr.Row(visible=is_visible) if isinstance(widget, gr.Row) else \
|
55 |
+
gr.update(visible=is_visible)
|
56 |
+
for widget in get_widgets()
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
def on_submit_new_story(
|
61 |
+
api_key: str
|
62 |
+
) -> Dict[Union[gr.Text, gr.Slider, gr.Button], Union[gr.Row, gr.update]]:
|
63 |
+
if init_client(api_key):
|
64 |
+
return get_wigets_updates(False) | \
|
65 |
+
story_information_widgets.get_widgets_updates(True)
|
66 |
+
return get_wigets_updates(True) | \
|
67 |
+
story_information_widgets.get_widgets_updates(False)
|
68 |
+
|
69 |
+
|
70 |
+
def on_submit_continue_story(
|
71 |
+
api_key: str
|
72 |
+
) -> Dict[Union[gr.Text, gr.Slider, gr.Button], Union[gr.Row, gr.update]]:
|
73 |
+
if init_client(api_key):
|
74 |
+
return get_wigets_updates(False) | \
|
75 |
+
continue_story_screen.get_widgets_updates(True)
|
76 |
+
return get_wigets_updates(True) | \
|
77 |
+
continue_story_screen.get_widgets_updates(False)
|
78 |
+
|
79 |
+
|
80 |
+
def render():
|
81 |
+
api_key_textbox.render()
|
82 |
+
button_row.render()
|
83 |
+
with button_row:
|
84 |
+
new_story_button.render()
|
85 |
+
continue_story_button.render()
|
86 |
+
|
87 |
+
new_story_button.click(
|
88 |
+
fn=on_submit_new_story,
|
89 |
+
inputs=[api_key_textbox],
|
90 |
+
outputs=get_widgets() + \
|
91 |
+
story_information_widgets.get_widgets()
|
92 |
+
)
|
93 |
+
continue_story_button.click(
|
94 |
+
fn=on_submit_continue_story,
|
95 |
+
inputs=[api_key_textbox],
|
96 |
+
outputs=get_widgets() + \
|
97 |
+
continue_story_screen.get_widgets()
|
98 |
+
)
|
src/frontend/story_information_widgets.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
|
6 |
+
from typing import Dict, List, Union
|
7 |
+
|
8 |
+
import src.agents.coordinator as C
|
9 |
+
|
10 |
+
from src.frontend.chatbot import (
|
11 |
+
MAX_OPTIONS,
|
12 |
+
chatbot,
|
13 |
+
init_graph,
|
14 |
+
on_chatbot_response,
|
15 |
+
option_buttons,
|
16 |
+
restart_button
|
17 |
+
)
|
18 |
+
from src.frontend import sidebar
|
19 |
+
from src.utils.utils import (
|
20 |
+
get_session_dir,
|
21 |
+
save_information,
|
22 |
+
transform_story_name
|
23 |
+
)
|
24 |
+
|
25 |
+
story_name_textbox = gr.Text(
|
26 |
+
placeholder='Give your adventure a unique name!',
|
27 |
+
label='Story Name',
|
28 |
+
interactive=True,
|
29 |
+
key='story_name_textbox',
|
30 |
+
preserved_by_key='key',
|
31 |
+
visible=False
|
32 |
+
)
|
33 |
+
story_context_textbox = gr.Text(
|
34 |
+
placeholder='What kind of story would you like to adventure on?',
|
35 |
+
label='Story Context',
|
36 |
+
interactive=True,
|
37 |
+
key='story_context_textbox',
|
38 |
+
preserved_by_key='key',
|
39 |
+
visible=False
|
40 |
+
)
|
41 |
+
categories_context_textbox = gr.Text(
|
42 |
+
placeholder='What is the theme of the categories you want to have?',
|
43 |
+
label='Categories Context',
|
44 |
+
interactive=True,
|
45 |
+
key='categories_context_textbox',
|
46 |
+
preserved_by_key='key',
|
47 |
+
visible=False
|
48 |
+
)
|
49 |
+
num_questions_slider = gr.Slider(
|
50 |
+
minimum=1,
|
51 |
+
maximum=10,
|
52 |
+
value=5,
|
53 |
+
step=1,
|
54 |
+
label='How many questions would you want your journey to have?',
|
55 |
+
interactive=True,
|
56 |
+
visible=False,
|
57 |
+
key='num_questions_slider',
|
58 |
+
preserved_by_key='key'
|
59 |
+
)
|
60 |
+
num_options_slider = gr.Slider(
|
61 |
+
minimum=2,
|
62 |
+
maximum=4,
|
63 |
+
value=2,
|
64 |
+
step=1,
|
65 |
+
label='How many options would you like at each turn?',
|
66 |
+
interactive=True,
|
67 |
+
visible=False,
|
68 |
+
key='num_options_slider',
|
69 |
+
preserved_by_key='key'
|
70 |
+
)
|
71 |
+
num_categories_slider = gr.Slider(
|
72 |
+
minimum=2,
|
73 |
+
maximum=MAX_OPTIONS,
|
74 |
+
value=2,
|
75 |
+
step=1,
|
76 |
+
label='How many categories / endings would you like to have?',
|
77 |
+
interactive=True,
|
78 |
+
visible=False,
|
79 |
+
key='num_categories_slider',
|
80 |
+
preserved_by_key='key'
|
81 |
+
)
|
82 |
+
story_information_submit_button = gr.Button(
|
83 |
+
visible=False,
|
84 |
+
interactive=False,
|
85 |
+
key='story_information_submit_button',
|
86 |
+
preserved_by_key='key'
|
87 |
+
)
|
88 |
+
|
89 |
+
def get_widgets() -> List[Union[gr.Text, gr.Slider, gr.Button]]:
|
90 |
+
return [
|
91 |
+
story_name_textbox,
|
92 |
+
story_context_textbox,
|
93 |
+
categories_context_textbox,
|
94 |
+
num_questions_slider,
|
95 |
+
num_options_slider,
|
96 |
+
num_categories_slider,
|
97 |
+
story_information_submit_button
|
98 |
+
]
|
99 |
+
|
100 |
+
|
101 |
+
def get_widgets_updates(
|
102 |
+
is_visible: bool
|
103 |
+
) -> Dict[Union[gr.Slider, gr.Text, gr.Button], gr.update]:
|
104 |
+
return {
|
105 |
+
widget: gr.update(visible=is_visible)
|
106 |
+
for widget in get_widgets()
|
107 |
+
}
|
108 |
+
|
109 |
+
|
110 |
+
def check_story_name(story_name: str) -> bool:
|
111 |
+
story_name_dir = transform_story_name(story_name)
|
112 |
+
return not story_name_dir in os.listdir(get_session_dir())
|
113 |
+
|
114 |
+
|
115 |
+
def on_text_change(
|
116 |
+
story_name: str,
|
117 |
+
story_context: str,
|
118 |
+
categories_context: str
|
119 |
+
) -> gr.update:
|
120 |
+
if all([story_name, story_context, categories_context]) and \
|
121 |
+
check_story_name(story_name):
|
122 |
+
return gr.update(interactive=True)
|
123 |
+
return gr.update(interactive=False)
|
124 |
+
|
125 |
+
|
126 |
+
def save_story_information(
|
127 |
+
story_name: str,
|
128 |
+
story_context: str,
|
129 |
+
categories_context: str,
|
130 |
+
num_questions: int,
|
131 |
+
num_options: int,
|
132 |
+
num_categories: int
|
133 |
+
):
|
134 |
+
story_name_dir = transform_story_name(story_name)
|
135 |
+
story_dirpath = os.path.join(get_session_dir(), story_name_dir)
|
136 |
+
os.mkdir(story_dirpath)
|
137 |
+
|
138 |
+
story_information_filepath = os.path.join(story_dirpath, 'story.json')
|
139 |
+
story_information = {
|
140 |
+
'story_name': story_name,
|
141 |
+
'story_context': story_context,
|
142 |
+
'categories_context': categories_context,
|
143 |
+
'num_questions': num_questions,
|
144 |
+
'num_options': num_options,
|
145 |
+
'num_categories': num_categories
|
146 |
+
}
|
147 |
+
save_information(story_information, story_information_filepath)
|
148 |
+
|
149 |
+
|
150 |
+
def on_submit(
|
151 |
+
story_name: str,
|
152 |
+
story_context: str,
|
153 |
+
categories_context: str,
|
154 |
+
num_questions: int,
|
155 |
+
num_options: int,
|
156 |
+
num_categories: int
|
157 |
+
) -> Dict[Union[gr.Text, gr.Slider, gr.Button, gr.Chatbot], gr.update]:
|
158 |
+
|
159 |
+
save_story_information(
|
160 |
+
story_name,
|
161 |
+
story_context,
|
162 |
+
categories_context,
|
163 |
+
num_questions,
|
164 |
+
num_options,
|
165 |
+
num_categories
|
166 |
+
)
|
167 |
+
C.story_name = transform_story_name(story_name)
|
168 |
+
|
169 |
+
init_graph(
|
170 |
+
story_context,
|
171 |
+
categories_context,
|
172 |
+
num_questions,
|
173 |
+
num_options,
|
174 |
+
num_categories
|
175 |
+
)
|
176 |
+
|
177 |
+
chatbot_updates = on_chatbot_response([])
|
178 |
+
|
179 |
+
return chatbot_updates | get_widgets_updates(False) | sidebar.view_screen()
|
180 |
+
|
181 |
+
|
182 |
+
def render():
|
183 |
+
|
184 |
+
for widget in get_widgets():
|
185 |
+
widget.render()
|
186 |
+
|
187 |
+
if not isinstance(widget, gr.Text):
|
188 |
+
continue
|
189 |
+
widget.change(
|
190 |
+
fn=on_text_change,
|
191 |
+
inputs=[
|
192 |
+
story_name_textbox,
|
193 |
+
story_context_textbox,
|
194 |
+
categories_context_textbox
|
195 |
+
],
|
196 |
+
outputs=[story_information_submit_button]
|
197 |
+
)
|
198 |
+
|
199 |
+
story_information_submit_button.click(
|
200 |
+
fn=on_submit,
|
201 |
+
inputs=[
|
202 |
+
widget
|
203 |
+
for widget in get_widgets()
|
204 |
+
if not isinstance(widget, gr.Button)
|
205 |
+
],
|
206 |
+
outputs=get_widgets() + \
|
207 |
+
[chatbot, restart_button] + \
|
208 |
+
option_buttons + \
|
209 |
+
sidebar.get_widgets()
|
210 |
+
)
|
src/prompts/category_builder_agent.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
system_prompt_template = """
|
3 |
+
You are a creative writer.
|
4 |
+
Your task is to create categories based on a given context.
|
5 |
+
Your categories should be different from one another.
|
6 |
+
Keep the categories fun, light and interesting. Be as creative as possible.
|
7 |
+
|
8 |
+
You will be given the following:
|
9 |
+
1. `story context`, which describe the story we are about to take the user on,
|
10 |
+
2. `category context`, which describe what the general theme of the categories should be,
|
11 |
+
3. the `number of categories`
|
12 |
+
|
13 |
+
You can find the provided information later, where
|
14 |
+
1. `story context` can be found enclosed in <story> tags,
|
15 |
+
2. `category context` can be found enclosed in <category> tags,
|
16 |
+
3. `number of categories` can be found enclosed in <number> tags.
|
17 |
+
|
18 |
+
Ensure that the categories you generate are related to the story context.
|
19 |
+
When you are ready, output **all** the categories in **JSON** format.
|
20 |
+
For each category, they should have the following fields:
|
21 |
+
1. `name`, which describes the name of the category,
|
22 |
+
2. `description`: which gives a short description (at most three lines) of the category,
|
23 |
+
3. `traits`: a sequence of traits of the category. Separate the traits with a comma on the same line. Ensure that the traits have some relation to the story context.
|
24 |
+
|
25 |
+
Ensure that the focus is still on the `category context`.
|
26 |
+
A useful guideline is that the `name` of the category directly follows from the `category context`.
|
27 |
+
And then the `story context` can be used in the `description` and `traits`.
|
28 |
+
You only need to output the JSON information, no elaboration is required.
|
29 |
+
""".strip()
|
30 |
+
|
31 |
+
user_prompt_template = """
|
32 |
+
<story>
|
33 |
+
{story_context}
|
34 |
+
</story>
|
35 |
+
|
36 |
+
<category>
|
37 |
+
{category_context}
|
38 |
+
</category>
|
39 |
+
|
40 |
+
<number>
|
41 |
+
{num_categories}
|
42 |
+
</number>
|
43 |
+
""".strip()
|
src/prompts/evaluator_agent.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
system_prompt_template = """
|
3 |
+
You are an expert profiler, regardless of the theme.
|
4 |
+
We have just set our user through a round of questions, and they have provided us their responses.
|
5 |
+
Our task now is to allocate them into one category based on the scenario they were given, questions they were asked, and choices they made.
|
6 |
+
|
7 |
+
You will be given the following:
|
8 |
+
1. `user story`, where for each turn, you are given the SCENARIO, QUESTION and USER CHOICE,
|
9 |
+
2. `categories`, where for each category, you are given the name, description and traits of the category.
|
10 |
+
|
11 |
+
You can find the provided information later, where
|
12 |
+
1. `user story` can be found enclosed in <user> tags,
|
13 |
+
2. `categories` can be found enclosed in <categories> tags,
|
14 |
+
|
15 |
+
Read the user's story and determine which category best fits the user based on their choices.
|
16 |
+
|
17 |
+
When you are ready, output your answer in the below JSON format.
|
18 |
+
Follow the format strictly, as any deviation will result in errors later on.
|
19 |
+
You only need to output your answer, no elaboration is required.
|
20 |
+
`category` below refers to the category name of the selected category. Output the corresponding category name exactly.
|
21 |
+
For your reason, replace terms like \'the user\' with \'your\' or \'you\'. We will be passing your reason to the user.
|
22 |
+
```json
|
23 |
+
{{
|
24 |
+
"category":
|
25 |
+
"reason":
|
26 |
+
}}
|
27 |
+
```
|
28 |
+
""".strip()
|
29 |
+
|
30 |
+
user_prompt_template = """
|
31 |
+
<user>
|
32 |
+
{user_story}
|
33 |
+
</user>
|
34 |
+
|
35 |
+
<categories>
|
36 |
+
{category_contents}
|
37 |
+
</categories>
|
38 |
+
""".strip()
|
src/prompts/narrator_agent.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
system_prompt_template = """
|
3 |
+
You are a creative writer for a story, in charge of writing the lines users will view.
|
4 |
+
The story you are making is dependent on our user's choices. Do **not** rush to make a complete story.
|
5 |
+
|
6 |
+
We are building a turn-based call and response story, where we will give users a scenario, and they will respond with what they would choose to do.
|
7 |
+
Based on our user's response, and the context they give us, we will craft what the next scenario will be.
|
8 |
+
|
9 |
+
Your task is to craft a question for the user to answer based on the scenario provided to you, and the potential responses to the question you make.
|
10 |
+
Your question should allow the user to respond with actionable options.
|
11 |
+
Focus on creating a question that is in line with the given scenario and storyline so far.
|
12 |
+
Keep in mind the context of the categories we will assign the user to in the end.
|
13 |
+
There should be some harmony and cohesion between the story's context and the categories' context.
|
14 |
+
|
15 |
+
You will be given the following:
|
16 |
+
1. `story context`, which describes the theme of the story we want to build for our users,
|
17 |
+
2. `categories context`, which describes the theme of the categories we will assign our users to later,
|
18 |
+
3. `user story`, which is the current story shown to the user,
|
19 |
+
4. `scenario`, which is the scenario that your question should be based on,
|
20 |
+
5. `builder instruction`, which describes the suggested pace of the story,
|
21 |
+
6. `number of options`, which is the number of options the user has to respond to the question you pose. You will need to generate the options as well.
|
22 |
+
7. `existing next scenarios`, which is the current list of scenarios that follow this scenario. This will be empty if we are seeing the current scenario for the first time. You will be given each scenario in the below format.
|
23 |
+
```
|
24 |
+
{{
|
25 |
+
scenario_name:
|
26 |
+
scenario:
|
27 |
+
}}
|
28 |
+
```
|
29 |
+
|
30 |
+
You can find the provided information later, where
|
31 |
+
1. `story context` can be found enclosed in <story> tags,
|
32 |
+
2. `categories context` can be found enclosed in <category> tags.
|
33 |
+
3. `user story` can be found enclosed in <user> tags,
|
34 |
+
4. `scenario`, can be found enclosed in <scenario> tags,
|
35 |
+
5. `builder instruction`, can be found enclosed in <builder> tags,
|
36 |
+
6. `number of options`, can be found enclosed in <options> tags,
|
37 |
+
7. `existing next scenarios`, can be found enclosed in <existing> tags
|
38 |
+
|
39 |
+
First, think about a question to ask the user. This question should not be long winded.
|
40 |
+
Instead, aim for short questions that are easy for the users to respond to, but also compelling enough to keep them engaged.
|
41 |
+
Next, think about potential responses to the question.
|
42 |
+
The number of options you provide **must** be exactly the number of options provided, no more and no less.
|
43 |
+
Lastly, think about what suitable next scenarios each of the options can lead to.
|
44 |
+
Prioritise using **existing** scenarios, where possible.
|
45 |
+
Prioritise using the **same** scenario for multiple options, where possible.
|
46 |
+
If you want to use the same or existing scenario, you **must** use the exact same `scenario_name`.
|
47 |
+
If you want to create a new scenario, you **must** use a different `scenario_name` than the ones provided.
|
48 |
+
Although you do not know the final categories, ensure you provide options that are distinct enough such that the users can definitely be binned into different categories.
|
49 |
+
|
50 |
+
When you are ready, output your answer in the below JSON format.
|
51 |
+
Follow the format strictly, as any deviation will result in errors later on.
|
52 |
+
For your question, replace any spaces with the underbar character.
|
53 |
+
You only need to output your answer, no elaboration is required.
|
54 |
+
If there are angular brackets, it indicates to replace those fields with your answer.
|
55 |
+
Where there are ellipsis, it means to continue following the format.
|
56 |
+
```json
|
57 |
+
{{
|
58 |
+
"question":
|
59 |
+
"options": {{
|
60 |
+
"<option_1>": "<following scenario name>",
|
61 |
+
...
|
62 |
+
}}
|
63 |
+
}}
|
64 |
+
```
|
65 |
+
|
66 |
+
You are reminded to ensure that the number of options is exactly the number of options provided.
|
67 |
+
""".strip()
|
68 |
+
|
69 |
+
user_prompt_template = """
|
70 |
+
<story>
|
71 |
+
{story_context}
|
72 |
+
</story>
|
73 |
+
|
74 |
+
<category>
|
75 |
+
{category_context}
|
76 |
+
</category>
|
77 |
+
|
78 |
+
<user>
|
79 |
+
{user_story}
|
80 |
+
</user>
|
81 |
+
|
82 |
+
<scenario>
|
83 |
+
{scenario}
|
84 |
+
</scenario>
|
85 |
+
|
86 |
+
<builder>
|
87 |
+
{builder_instruction}
|
88 |
+
</builder>
|
89 |
+
|
90 |
+
<options>
|
91 |
+
{num_options}
|
92 |
+
</options>
|
93 |
+
|
94 |
+
<existing>
|
95 |
+
{existing_scenarios}
|
96 |
+
</existing>
|
97 |
+
""".strip()
|
src/prompts/orchestrator_agent.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
start_system_prompt_template = """
|
3 |
+
You are an orchestrator of a story board.
|
4 |
+
Your task is to provide instructions to your creatives about the flow and urgency of the story.
|
5 |
+
This means whether the creatives should start thinking about wrapping up the story, or if they have some time to explore a certain scenario.
|
6 |
+
|
7 |
+
We are building a turn-based call and response story, where we will give users a scenario, and they will respond with what they would choose to do.
|
8 |
+
Based on our user's response, and the context they give us, we will craft what the next scenario will be.
|
9 |
+
|
10 |
+
Your task is **not** to think about what to craft next.
|
11 |
+
Instead, your task is to guide the agent that is doing the creative work.
|
12 |
+
Your goal is to guide the story such that it does to feel too rushed or too drawn out, whilst ensuring a compelling and interesting story.
|
13 |
+
|
14 |
+
In this first step, we have just begun the process.
|
15 |
+
You will be given the following:
|
16 |
+
1. `story context`, which describes the theme of the story we want to build for our users,
|
17 |
+
2. `number of questions`, which describes the number of turns we have with our users. Use this as a gauge to determine if the given context is too lengthy or too short.
|
18 |
+
|
19 |
+
You can find the provided information later, where
|
20 |
+
1. `story context` can be found enclosed in <story> tags,
|
21 |
+
2. `number of questions` can be found enclosed in <number> tags.
|
22 |
+
|
23 |
+
Think about the context and the number of questions.
|
24 |
+
We both know that the creative agent will begin with some kind of introductory question.
|
25 |
+
But what we want to focus on is how fast this phase should be.
|
26 |
+
|
27 |
+
Your output is the instruction to our creative agent on the tempo of the story.
|
28 |
+
When the creative agent reads your instruction, they should know how fast to craft the scenario to give the users.
|
29 |
+
Keep the user's overall experience the focus.
|
30 |
+
|
31 |
+
Once you are ready, output your instruction to our creative agent.
|
32 |
+
No elaboration is required.
|
33 |
+
""".strip()
|
34 |
+
|
35 |
+
start_user_prompt_template = """
|
36 |
+
<story>
|
37 |
+
{story_context}
|
38 |
+
</story>
|
39 |
+
|
40 |
+
<number>
|
41 |
+
{num_questions}
|
42 |
+
</number>
|
43 |
+
""".strip()
|
44 |
+
|
45 |
+
coordinate_system_prompt_template = """
|
46 |
+
You are an orchestrator of a story board.
|
47 |
+
Your task is to provide instructions to your creatives about the flow and urgency of the story.
|
48 |
+
This means whether the creatives should start thinking about wrapping up the story, or if they have some time to explore a certain scenario.
|
49 |
+
|
50 |
+
We are building a turn-based call and response story, where we will give users a scenario, and they will respond with what they would choose to do.
|
51 |
+
Based on our user's response, and the context they give us, we will craft what the next scenario will be.
|
52 |
+
|
53 |
+
Your task is **not** to think about what to craft next.
|
54 |
+
Instead, your task is to guide the agent that is doing the creative work.
|
55 |
+
Your goal is to guide the story such that it does to feel too rushed or too drawn out, whilst ensuring a compelling and interesting story.
|
56 |
+
|
57 |
+
In this phase, we are some way into our story.
|
58 |
+
You will be given the following:
|
59 |
+
1. `story context`, which describes the theme of the story we want to build for our users,
|
60 |
+
2. `number of questions`, which describes the number of turns we have with our users. Use this as a gauge to determine if the given context is too lengthy or too short.
|
61 |
+
3. `phase message`: which is a pre-written phrase related to how far into the story the user is in. You do not need to calculate this, we will do it for you,
|
62 |
+
4. `user story`, which is the current story shown to the user.
|
63 |
+
|
64 |
+
You can find the provided information later, where
|
65 |
+
1. `story context` can be found enclosed in <story> tags,
|
66 |
+
2. `number of questions` can be found enclosed in <number> tags,
|
67 |
+
3. `phase message` can be found enclosed in <phase> tags,
|
68 |
+
4. `user story` can be found enclosed in <user> tags.
|
69 |
+
|
70 |
+
Think about the context and phase.
|
71 |
+
Our creative agent is getting ready to craft the next scenario to give to the user.
|
72 |
+
Our focus on is how fast scenario moves the story along.
|
73 |
+
|
74 |
+
Your output is the instruction to our creative agent on the tempo of the story.
|
75 |
+
When the creative agent reads your instruction, they should know how fast to craft the scenario to give the users.
|
76 |
+
Keep the user's overall experience the focus.
|
77 |
+
|
78 |
+
Once you are ready, output your instruction to our creative agent.
|
79 |
+
No elaboration is required.
|
80 |
+
""".strip()
|
81 |
+
|
82 |
+
coordinate_user_prompt_template = """
|
83 |
+
<story>
|
84 |
+
{story_context}
|
85 |
+
</story>
|
86 |
+
|
87 |
+
<number>
|
88 |
+
{num_questions}
|
89 |
+
</number>
|
90 |
+
|
91 |
+
<phase>
|
92 |
+
{phase_message}
|
93 |
+
</phase>
|
94 |
+
|
95 |
+
<user>
|
96 |
+
{user_story}
|
97 |
+
</user>
|
98 |
+
""".strip()
|
99 |
+
|
100 |
+
def get_phase_message(progress: float) -> str:
|
101 |
+
if progress < 0.2:
|
102 |
+
return 'We are still in the early phases of the story!'
|
103 |
+
if progress <= 0.5:
|
104 |
+
return 'We are approaching the halfway point, it\'s time to build towards a climax!'
|
105 |
+
if progress < 0.8:
|
106 |
+
return 'We are approaching the climax of the story! Let\'s quicken the pace.'
|
107 |
+
|
108 |
+
return 'The story is approaching the end! Let\'s wrap this up!'
|
src/prompts/storybuilder_agent.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
system_prompt_template = """
|
3 |
+
You are a creative writer for a story.
|
4 |
+
The story you are making is dependent on our user's choices. Do **not** rush to make a complete story.
|
5 |
+
|
6 |
+
We are building a turn-based call and response story, where we will give users a scenario, and they will respond with what they would choose to do.
|
7 |
+
Based on our user's response, and the context they give us, we will craft what the next scenario will be.
|
8 |
+
|
9 |
+
Your task is to craft a scenario for the user to think about.
|
10 |
+
Your task is **not** to craft the question to ask to the user.
|
11 |
+
Focus on creating a scenario that flows well with the story so far, and is in line with the story context.
|
12 |
+
Keep in mind the context of the categories we will assign the user to in the end.
|
13 |
+
There should be some harmony and cohesion between the story's context and the categories' context.
|
14 |
+
|
15 |
+
You will be given the following:
|
16 |
+
1. `story context`, which describes the theme of the story we want to build for our users,
|
17 |
+
2. `categories context`, which describes the theme of the categories we will assign our users to later,
|
18 |
+
3. `builder's instruction`, a set of instructions with respect to the tempo of the scenario you are going to create.
|
19 |
+
4. `user story`, which is the current story shown to the user,
|
20 |
+
5. `question`, which is the last question asked to the user, if we just started the process, this is empty.
|
21 |
+
6. `response`, which is the response to the question shown to the user. If we just started the process, this is empty.
|
22 |
+
|
23 |
+
You can find the provided information later, where
|
24 |
+
1. `story context` can be found enclosed in <story> tags,
|
25 |
+
2. `categories context` can be found enclosed in <category> tags.
|
26 |
+
3. `builder's instruction` can be found enclosed in <builder> tags.
|
27 |
+
4. `user story` can be found enclosed in <user> tags,
|
28 |
+
5. `question`, can be found enclosed in <question> tags,
|
29 |
+
6. `response`, can be found enclosed in <response> tags.
|
30 |
+
|
31 |
+
First, think about a fitting scenario that naturally follows the story the user is shown.
|
32 |
+
Next, think about some specifics of the scenario.
|
33 |
+
Limit the number of words in your scenario to be at most 50 words.
|
34 |
+
Feel free to use named characters and have them recur throughout this journey.
|
35 |
+
Keep the scenario light, fun and compelling for the user's to follow.
|
36 |
+
|
37 |
+
When you are ready, output **only** the scenario.
|
38 |
+
You are reminded that the scenario should **not** include the question to be asked to the user.
|
39 |
+
You do not need to provide any explanation.
|
40 |
+
""".strip()
|
41 |
+
|
42 |
+
user_prompt_template = """
|
43 |
+
<story>
|
44 |
+
{story_context}
|
45 |
+
</story>
|
46 |
+
|
47 |
+
<category>
|
48 |
+
{category_context}
|
49 |
+
</category>
|
50 |
+
|
51 |
+
<builder>
|
52 |
+
{builder_instruction}
|
53 |
+
</builder>
|
54 |
+
|
55 |
+
<user>
|
56 |
+
{user_story}
|
57 |
+
</user>
|
58 |
+
|
59 |
+
<question>
|
60 |
+
{question}
|
61 |
+
</question>
|
62 |
+
|
63 |
+
<response>
|
64 |
+
{response}
|
65 |
+
</response>
|
66 |
+
""".strip()
|
src/utils/init_openai.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
|
5 |
+
from openai import OpenAI, RateLimitError
|
6 |
+
from openai.types.chat.chat_completion import ChatCompletion
|
7 |
+
from typing import Any, Dict
|
8 |
+
|
9 |
+
client: OpenAI = None
|
10 |
+
MODEL_ = 'gpt-4o'
|
11 |
+
|
12 |
+
|
13 |
+
def init_client(api_key: str) -> bool:
|
14 |
+
global client
|
15 |
+
try:
|
16 |
+
client = OpenAI(api_key=api_key)
|
17 |
+
except:
|
18 |
+
return False
|
19 |
+
return True
|
20 |
+
|
21 |
+
|
22 |
+
# from OpenAI's website
|
23 |
+
# define a retry decorator
|
24 |
+
def __retry_with_exponential_backoff(
|
25 |
+
func,
|
26 |
+
initial_delay: float = 1,
|
27 |
+
exponential_base: float = 2,
|
28 |
+
jitter: bool = True,
|
29 |
+
max_retries: int = 10,
|
30 |
+
errors: tuple = (RateLimitError,),
|
31 |
+
):
|
32 |
+
"""Retry a function with exponential backoff."""
|
33 |
+
|
34 |
+
def wrapper(*args, **kwargs):
|
35 |
+
# Initialize variables
|
36 |
+
num_retries = 0
|
37 |
+
delay = initial_delay
|
38 |
+
|
39 |
+
# Loop until a successful response or max_retries is hit or an exception is raised
|
40 |
+
while True:
|
41 |
+
try:
|
42 |
+
return func(*args, **kwargs)
|
43 |
+
|
44 |
+
# Retry on specified errors
|
45 |
+
except errors as e:
|
46 |
+
# Increment retries
|
47 |
+
num_retries += 1
|
48 |
+
|
49 |
+
# Check if max retries has been reached
|
50 |
+
if num_retries > max_retries:
|
51 |
+
raise Exception(
|
52 |
+
f"Maximum number of retries ({max_retries}) exceeded."
|
53 |
+
)
|
54 |
+
|
55 |
+
# Increment the delay
|
56 |
+
delay *= exponential_base * (1 + jitter * random.random())
|
57 |
+
|
58 |
+
# Sleep for the delay
|
59 |
+
time.sleep(delay)
|
60 |
+
|
61 |
+
# Raise exceptions for any errors not specified
|
62 |
+
except Exception as e:
|
63 |
+
raise e
|
64 |
+
|
65 |
+
return wrapper
|
66 |
+
|
67 |
+
|
68 |
+
@__retry_with_exponential_backoff
|
69 |
+
def completions_with_backoff(
|
70 |
+
messages: Dict[str, Any],
|
71 |
+
temperature: float = 0,
|
72 |
+
max_tokens: int = 8192
|
73 |
+
) -> ChatCompletion:
|
74 |
+
kwargs = {
|
75 |
+
'model': MODEL_,
|
76 |
+
'messages': messages,
|
77 |
+
'temperature': temperature,
|
78 |
+
'max_tokens': max_tokens
|
79 |
+
}
|
80 |
+
return client.chat.completions.create(**kwargs)
|
src/utils/utils.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import uuid
|
5 |
+
|
6 |
+
from typing import Any
|
7 |
+
|
8 |
+
|
9 |
+
def get_session_dir() -> str:
|
10 |
+
return os.path.join(
|
11 |
+
os.path.dirname(__file__),
|
12 |
+
'..', '..', 'sessions'
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
def save_information(
|
17 |
+
information: Any,
|
18 |
+
filepath: str
|
19 |
+
) -> None:
|
20 |
+
|
21 |
+
with open(filepath, 'w', encoding='utf-8') as f:
|
22 |
+
json.dump(information, f, ensure_ascii=False, indent=4)
|
23 |
+
|
24 |
+
|
25 |
+
def load_information(filepath: str) -> Any:
|
26 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
27 |
+
content = json.load(f)
|
28 |
+
return content
|
29 |
+
|
30 |
+
|
31 |
+
def transform_story_name(
|
32 |
+
story_name: str,
|
33 |
+
is_inverse: bool = False
|
34 |
+
) -> str:
|
35 |
+
if is_inverse:
|
36 |
+
return ' '.join([word.capitalize() for word in story_name.split(' ')])
|
37 |
+
else:
|
38 |
+
return story_name.lower().replace(' ', '_')
|