Spaces:
Running
Running
Chandima Prabhath
commited on
Commit
·
db2f80b
1
Parent(s):
cc969ee
Refactor intent handling to use Pydantic models for strict parsing; update GenerateImageIntent to inherit from BaseModel and improve route_intent function for better error handling and data extraction.
Browse files
app.py
CHANGED
|
@@ -330,7 +330,7 @@ class PollResultsIntent(BaseIntent):
|
|
| 330 |
class PollEndIntent(BaseIntent):
|
| 331 |
action: Literal["poll_end"]
|
| 332 |
|
| 333 |
-
class GenerateImageIntent(
|
| 334 |
action: Literal["generate_image"]
|
| 335 |
prompt: str
|
| 336 |
count: int = Field(default=1, ge=1)
|
|
@@ -341,7 +341,8 @@ class SendTextIntent(BaseIntent):
|
|
| 341 |
action: Literal["send_text"]
|
| 342 |
message: str
|
| 343 |
|
| 344 |
-
|
|
|
|
| 345 |
SummarizeIntent, TranslateIntent, JokeIntent, WeatherIntent,
|
| 346 |
InspireIntent, MemeIntent, PollCreateIntent, PollVoteIntent,
|
| 347 |
PollResultsIntent, PollEndIntent, GenerateImageIntent, SendTextIntent
|
|
@@ -364,7 +365,7 @@ ACTION_HANDLERS = {
|
|
| 364 |
|
| 365 |
# --- Intent Routing with Fallback ------------------------------------------
|
| 366 |
|
| 367 |
-
def route_intent(user_input: str, chat_id: str, sender: str)
|
| 368 |
history_text = get_history_text(chat_id, sender)
|
| 369 |
sys_prompt = (
|
| 370 |
"You are Eve. You can either chat or call one of these functions:\n"
|
|
@@ -384,53 +385,56 @@ def route_intent(user_input: str, chat_id: str, sender: str) -> IntentUnion:
|
|
| 384 |
" {\"action\":\"generate_image\",\"prompt\":\"a red fox\",\"count\":3,\"width\":512,\"height\":512}\n"
|
| 385 |
"Otherwise, use send_text to reply with plain chat.\n"
|
| 386 |
)
|
| 387 |
-
prompt =
|
| 388 |
-
f"{sys_prompt}\n"
|
| 389 |
-
f"Conversation so far:\n{history_text}\n\n"
|
| 390 |
-
f"User: {user_input}"
|
| 391 |
-
)
|
| 392 |
raw = generate_llm(prompt)
|
| 393 |
|
| 394 |
-
# 1)
|
| 395 |
try:
|
| 396 |
parsed = json.loads(raw)
|
| 397 |
-
intent = IntentUnion.parse_obj(parsed)
|
| 398 |
-
return intent
|
| 399 |
-
except (json.JSONDecodeError, ValidationError) as e:
|
| 400 |
-
logger.warning(f"Strict parse failed: {e}. Falling back to lenient.")
|
| 401 |
-
|
| 402 |
-
# 2) Lenient: basic JSON get + defaults
|
| 403 |
-
try:
|
| 404 |
-
data = json.loads(raw)
|
| 405 |
except json.JSONDecodeError:
|
| 406 |
return SendTextIntent(action="send_text", message=raw)
|
| 407 |
|
| 408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
if action in ACTION_HANDLERS:
|
|
|
|
| 410 |
kwargs = {}
|
| 411 |
if action == "generate_image":
|
| 412 |
-
kwargs["prompt"] = data.get("prompt",
|
| 413 |
kwargs["count"] = int(data.get("count", BotConfig.DEFAULT_IMAGE_COUNT))
|
| 414 |
kwargs["width"] = data.get("width")
|
| 415 |
kwargs["height"] = data.get("height")
|
| 416 |
elif action == "send_text":
|
| 417 |
-
kwargs["message"] = data.get("message",
|
| 418 |
elif action == "translate":
|
| 419 |
-
kwargs["lang"] = data.get("lang",
|
| 420 |
-
kwargs["text"] = data.get("text",
|
| 421 |
elif action == "summarize":
|
| 422 |
-
kwargs["text"] = data.get("text",
|
| 423 |
elif action == "weather":
|
| 424 |
-
kwargs["location"] = data.get("location",
|
| 425 |
elif action == "meme":
|
| 426 |
-
kwargs["text"] = data.get("text",
|
| 427 |
elif action == "poll_create":
|
| 428 |
-
kwargs["question"] = data.get("question",
|
| 429 |
-
kwargs["options"] = data.get("options",
|
| 430 |
elif action == "poll_vote":
|
| 431 |
kwargs["voter"] = sender
|
| 432 |
-
kwargs["choice"] = int(data.get("choice",
|
| 433 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
|
| 435 |
return SendTextIntent(action="send_text", message=raw)
|
| 436 |
|
|
@@ -531,9 +535,7 @@ async def whatsapp_webhook(request: Request):
|
|
| 531 |
if tmd.get("contextInfo", {}).get("mentionedJidList"):
|
| 532 |
return {"success": True}
|
| 533 |
|
| 534 |
-
# Handle quoted replies if needed...
|
| 535 |
effective = body
|
| 536 |
-
|
| 537 |
intent = route_intent(effective, chat_id, sender)
|
| 538 |
handler = ACTION_HANDLERS.get(intent.action)
|
| 539 |
if handler:
|
|
|
|
| 330 |
class PollEndIntent(BaseIntent):
|
| 331 |
action: Literal["poll_end"]
|
| 332 |
|
| 333 |
+
class GenerateImageIntent(BaseModel):
|
| 334 |
action: Literal["generate_image"]
|
| 335 |
prompt: str
|
| 336 |
count: int = Field(default=1, ge=1)
|
|
|
|
| 341 |
action: Literal["send_text"]
|
| 342 |
message: str
|
| 343 |
|
| 344 |
+
# list of all intent models
|
| 345 |
+
INTENT_MODELS = [
|
| 346 |
SummarizeIntent, TranslateIntent, JokeIntent, WeatherIntent,
|
| 347 |
InspireIntent, MemeIntent, PollCreateIntent, PollVoteIntent,
|
| 348 |
PollResultsIntent, PollEndIntent, GenerateImageIntent, SendTextIntent
|
|
|
|
| 365 |
|
| 366 |
# --- Intent Routing with Fallback ------------------------------------------
|
| 367 |
|
| 368 |
+
def route_intent(user_input: str, chat_id: str, sender: str):
|
| 369 |
history_text = get_history_text(chat_id, sender)
|
| 370 |
sys_prompt = (
|
| 371 |
"You are Eve. You can either chat or call one of these functions:\n"
|
|
|
|
| 385 |
" {\"action\":\"generate_image\",\"prompt\":\"a red fox\",\"count\":3,\"width\":512,\"height\":512}\n"
|
| 386 |
"Otherwise, use send_text to reply with plain chat.\n"
|
| 387 |
)
|
| 388 |
+
prompt = f"{sys_prompt}\nConversation so far:\n{history_text}\n\nUser: {user_input}"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
raw = generate_llm(prompt)
|
| 390 |
|
| 391 |
+
# 1) Strict: try each Pydantic model
|
| 392 |
try:
|
| 393 |
parsed = json.loads(raw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
except json.JSONDecodeError:
|
| 395 |
return SendTextIntent(action="send_text", message=raw)
|
| 396 |
|
| 397 |
+
for M in INTENT_MODELS:
|
| 398 |
+
try:
|
| 399 |
+
intent = M.parse_obj(parsed)
|
| 400 |
+
return intent
|
| 401 |
+
except ValidationError:
|
| 402 |
+
continue
|
| 403 |
+
|
| 404 |
+
logger.warning("Strict parse failed for all models, falling back to lenient")
|
| 405 |
+
|
| 406 |
+
# 2) Lenient JSON get
|
| 407 |
+
action = parsed.get("action")
|
| 408 |
if action in ACTION_HANDLERS:
|
| 409 |
+
data = parsed
|
| 410 |
kwargs = {}
|
| 411 |
if action == "generate_image":
|
| 412 |
+
kwargs["prompt"] = data.get("prompt","")
|
| 413 |
kwargs["count"] = int(data.get("count", BotConfig.DEFAULT_IMAGE_COUNT))
|
| 414 |
kwargs["width"] = data.get("width")
|
| 415 |
kwargs["height"] = data.get("height")
|
| 416 |
elif action == "send_text":
|
| 417 |
+
kwargs["message"] = data.get("message","")
|
| 418 |
elif action == "translate":
|
| 419 |
+
kwargs["lang"] = data.get("lang","")
|
| 420 |
+
kwargs["text"] = data.get("text","")
|
| 421 |
elif action == "summarize":
|
| 422 |
+
kwargs["text"] = data.get("text","")
|
| 423 |
elif action == "weather":
|
| 424 |
+
kwargs["location"] = data.get("location","")
|
| 425 |
elif action == "meme":
|
| 426 |
+
kwargs["text"] = data.get("text","")
|
| 427 |
elif action == "poll_create":
|
| 428 |
+
kwargs["question"] = data.get("question","")
|
| 429 |
+
kwargs["options"] = data.get("options",[])
|
| 430 |
elif action == "poll_vote":
|
| 431 |
kwargs["voter"] = sender
|
| 432 |
+
kwargs["choice"] = int(data.get("choice",0))
|
| 433 |
+
# parse into Pydantic for uniformity
|
| 434 |
+
try:
|
| 435 |
+
return next(M for M in INTENT_MODELS if getattr(M, "__fields__", {}).get("action").default == action).parse_obj({"action":action,**kwargs})
|
| 436 |
+
except Exception:
|
| 437 |
+
return SendTextIntent(action="send_text", message=raw)
|
| 438 |
|
| 439 |
return SendTextIntent(action="send_text", message=raw)
|
| 440 |
|
|
|
|
| 535 |
if tmd.get("contextInfo", {}).get("mentionedJidList"):
|
| 536 |
return {"success": True}
|
| 537 |
|
|
|
|
| 538 |
effective = body
|
|
|
|
| 539 |
intent = route_intent(effective, chat_id, sender)
|
| 540 |
handler = ACTION_HANDLERS.get(intent.action)
|
| 541 |
if handler:
|