File size: 3,840 Bytes
a524d54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c8fa27a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import json\n",
    "import re\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4466940f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"unsloth/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"auto\",\n",
    "    low_cpu_mem_usage=True\n",
    ")\n",
    "\n",
    "# Text-generation pipeline used by chat_fn.\n",
    "# - No `device=` argument: the model is already placed by device_map=\"auto\"\n",
    "#   (accelerate), and transformers raises a ValueError if a pipeline is asked\n",
    "#   to move such a model to a specific device.\n",
    "# - max_new_tokens bounds only the generated reply; max_length would also count\n",
    "#   the prompt tokens, leaving little or no room for an answer after a long prompt.\n",
    "# - return_full_text=False returns only the completion, not prompt + completion.\n",
    "chat = pipeline(\n",
    "    \"text-generation\",\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    max_new_tokens=512,\n",
    "    temperature=0.7,\n",
    "    do_sample=True,\n",
    "    return_full_text=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeb02955",
   "metadata": {},
   "outputs": [],
   "source": [
    "# System prompt for the registration assistant. The reply format requested at\n",
    "# the end must match the <response>/<update> tags parsed by\n",
    "# extract_response_and_update(); without it the model has no reason to emit\n",
    "# them, so no info update would ever be extracted.\n",
    "system_prompt = \"\"\"You are a helpful assistant guiding a user through the Boston Public Schools registration process.\n",
    "You are given:\n",
    "1. The user's most recent message\n",
    "2. The current known registration info (`info`) — provided as a JSON object\n",
    "\n",
    "Reply using exactly this format:\n",
    "<response>Your message to the user: answer their question and ask for the next missing piece of registration info.</response>\n",
    "<update>A JSON object containing only the `info` fields you learned from the user's latest message, using the same keys and nesting as `info` (for example {\"child\": {\"age\": 5}}). Use {} if nothing new was learned.</update>\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c801c862",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_response_and_update(text):\n",
    "    \"\"\"Split raw model output into the user-facing reply and an `info` update.\n",
    "\n",
    "    Looks for <response>...</response> and <update>{...JSON...}</update> in\n",
    "    the model output. Reasoning models such as DeepSeek-R1 may also emit a\n",
    "    reasoning section ending in </think>; everything up to the last </think>\n",
    "    is dropped so the reasoning is never shown to the user and draft tags\n",
    "    inside it are ignored.\n",
    "\n",
    "    Returns:\n",
    "        (out_text, update): out_text is the <response> body, or the visible\n",
    "        (post-reasoning) text when that tag is missing; update is the parsed\n",
    "        <update> JSON object, or {} when the tag is missing, is not valid\n",
    "        JSON, or does not contain a JSON object.\n",
    "    \"\"\"\n",
    "    # rsplit leaves the text unchanged when there is no \"</think>\".\n",
    "    visible = text.rsplit(\"</think>\", 1)[-1]\n",
    "    resp = re.search(r\"<response>(.*?)</response>\", visible, re.DOTALL)\n",
    "    upd = re.search(r\"<update>(.*?)</update>\", visible, re.DOTALL)\n",
    "\n",
    "    out_text = resp.group(1).strip() if resp else visible.strip()\n",
    "\n",
    "    update = {}\n",
    "    if upd:\n",
    "        try:\n",
    "            parsed = json.loads(upd.group(1))\n",
    "        except json.JSONDecodeError:\n",
    "            parsed = None\n",
    "        # Only a JSON object can be merged into `info`; a list, string or\n",
    "        # number would crash the dict merge in chat_fn.\n",
    "        if isinstance(parsed, dict):\n",
    "            update = parsed\n",
    "    return out_text, update\n",
    "\n",
    "# Registration state accumulated across turns: chat_fn merges each parsed\n",
    "# <update> object into it. Every field starts empty (None / []).\n",
    "info = {\n",
    "    \"location\": None,\n",
    "    \"school\": None,\n",
    "    \"child\": {\n",
    "        \"name\": None,\n",
    "        \"age\": None,\n",
    "        \"grade\": None,\n",
    "        \"special_needs\": None,\n",
    "        \"transferring\": None\n",
    "    },\n",
    "    \"residency_docs\": []\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3dfb0c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat_fn(user_message, chat_history):\n",
    "    \"\"\"Run one assistant turn and merge any extracted info into `info`.\n",
    "\n",
    "    The prompt is built with the tokenizer's chat template from the system\n",
    "    prompt (followed by the current `info` as JSON) and the user's message.\n",
    "    The parsed <update> object is merged into the global `info` dict.\n",
    "\n",
    "    Args:\n",
    "        user_message: The user's latest message.\n",
    "        chat_history: List of (user, assistant) pairs, or None. It is not\n",
    "            sent to the model; the new pair is appended to it.\n",
    "\n",
    "    Returns:\n",
    "        (chat_history, chat_history): the updated list, twice.\n",
    "    \"\"\"\n",
    "    # Use the model's own chat template: hand-written \"<|user|>\" / \"<|assistant|>\"\n",
    "    # markers are not this model's role tokens.\n",
    "    info_json = json.dumps(info, ensure_ascii=False)\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": f\"{system_prompt}\\ninfo = {info_json}\"},\n",
    "        {\"role\": \"user\", \"content\": user_message},\n",
    "    ]\n",
    "    prompt = tokenizer.apply_chat_template(\n",
    "        messages, tokenize=False, add_generation_prompt=True\n",
    "    )\n",
    "    # return_full_text=False: parse only the completion, never the echoed prompt.\n",
    "    raw = chat(prompt, return_full_text=False)[0][\"generated_text\"].strip()\n",
    "    resp_text, update = extract_response_and_update(raw)\n",
    "\n",
    "    def merge(existing, upd):\n",
    "        \"\"\"Recursively merge dict `upd` into dict `existing` in place.\"\"\"\n",
    "        for k, v in upd.items():\n",
    "            # Recurse only when both sides are dicts; otherwise overwrite, so a\n",
    "            # dict arriving for a key that is currently None cannot crash.\n",
    "            if isinstance(v, dict) and isinstance(existing.get(k), dict):\n",
    "                merge(existing[k], v)\n",
    "            else:\n",
    "                existing[k] = v\n",
    "\n",
    "    # Ignore malformed updates (e.g. a JSON list) instead of crashing the turn.\n",
    "    if isinstance(update, dict):\n",
    "        merge(info, update)\n",
    "\n",
    "    chat_history = chat_history or []\n",
    "    chat_history.append((user_message, resp_text))\n",
    "    return chat_history, chat_history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b676221",
   "metadata": {},
   "outputs": [],
   "source": [
    "def respond(message, history):\n",
    "    \"\"\"Adapt chat_fn to gr.ChatInterface, which expects fn(message, history) -> reply.\n",
    "\n",
    "    chat_fn returns (chat_history, chat_history) rather than the reply text,\n",
    "    so the assistant half of its newest pair is returned here. A fresh list is\n",
    "    passed instead of Gradio's `history`, whose format may vary between Gradio\n",
    "    versions and which chat_fn does not send to the model anyway.\n",
    "    \"\"\"\n",
    "    updated_history, _ = chat_fn(message, [])\n",
    "    return updated_history[-1][1]\n",
    "\n",
    "\n",
    "demo = gr.ChatInterface(\n",
    "    fn=respond,\n",
    "    title=\"Boston School Choice\",\n",
    "    description=\"Ask me anything about Boston Public Schools registration\",\n",
    ")\n",
    "\n",
    "demo.launch(inline=True)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}