{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b7e1d0a2",
   "metadata": {},
   "source": [
    "# Boston Public Schools registration assistant\n",
    "\n",
    "A small Gradio chat demo backed by `DeepSeek-R1-Distill-Qwen-1.5B`. On each turn the model replies to the user and emits a JSON update that is merged into the running `info` record of registration details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c8fa27a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "\n",
    "import gradio as gr\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4466940f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"unsloth/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
    "# R1-distill models typically reason before answering, so leave room for both.\n",
    "MAX_NEW_TOKENS = 1024\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    # float16 is meant for GPU; fall back to float32 when no CUDA device is available.\n",
    "    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,\n",
    "    device_map=\"auto\",\n",
    "    low_cpu_mem_usage=True\n",
    ")\n",
    "\n",
    "# `device_map=\"auto\"` has already placed the model, so no `device=` argument here:\n",
    "# transformers refuses to move a model that was dispatched by accelerate.\n",
    "# `max_new_tokens` bounds only the reply (`max_length` also counted the prompt), and\n",
    "# `return_full_text=False` keeps the prompt out of `generated_text`.\n",
    "chat = pipeline(\n",
    "    \"text-generation\",\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    max_new_tokens=MAX_NEW_TOKENS,\n",
    "    temperature=0.7,\n",
    "    do_sample=True,\n",
    "    return_full_text=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeb02955",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_prompt = \"\"\"You are a helpful assistant guiding a user through the Boston Public Schools registration process.\n",
    "You are given:\n",
    "1. The user's most recent message\n",
    "2. The current known registration info (`info`) — provided as a JSON object\n",
    "\n",
    "Answer in exactly this format, with nothing after the closing </update> tag:\n",
    "<response>\n",
    "Your message to the user. Ask for any registration details still missing from `info`.\n",
    "</response>\n",
    "<update>\n",
    "A JSON object containing only the `info` fields learned from the user's most recent message, using the same keys and nesting as `info`. Use {} if nothing new was learned.\n",
    "</update>\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c801c862",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_response_and_update(text):\n",
    "    \"\"\"Split raw model output into the user-facing reply and an `info` update.\n",
    "\n",
    "    Args:\n",
    "        text: Newly generated text, expected to contain `<response>...</response>`\n",
    "            and `<update>...</update>` sections, possibly preceded by a\n",
    "            `<think>...</think>` reasoning block.\n",
    "\n",
    "    Returns:\n",
    "        A `(reply, update)` tuple. `reply` is the stripped `<response>` body; if that\n",
    "        section is missing it is the remaining text with any `<update>` section\n",
    "        removed. `update` is the parsed `<update>` JSON object, or `{}` when the\n",
    "        section is missing, is not valid JSON, or is not a JSON object.\n",
    "    \"\"\"\n",
    "    # Drop the reasoning block. The chat template may already have opened `<think>`\n",
    "    # inside the prompt, so only the closing tag is relied on here.\n",
    "    visible = text.rsplit(\"</think>\", 1)[-1]\n",
    "\n",
    "    resp = re.search(r\"<response>(.*?)</response>\", visible, re.DOTALL)\n",
    "    upd = re.search(r\"<update>(.*?)</update>\", visible, re.DOTALL)\n",
    "\n",
    "    if resp:\n",
    "        out_text = resp.group(1).strip()\n",
    "    else:\n",
    "        out_text = re.sub(r\"<update>.*?</update>\", \"\", visible, flags=re.DOTALL).strip()\n",
    "\n",
    "    update = {}\n",
    "    if upd:\n",
    "        try:\n",
    "            parsed = json.loads(upd.group(1))\n",
    "        except json.JSONDecodeError:\n",
    "            parsed = None\n",
    "        # Only a JSON object can be merged into `info`.\n",
    "        if isinstance(parsed, dict):\n",
    "            update = parsed\n",
    "    return out_text, update"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d2e9f41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Registration details gathered so far; `chat_fn` merges each model update into it.\n",
    "# NOTE: this is one module-level dict, so every chat session shares it.\n",
    "# Re-run this cell to start over.\n",
    "info = {\n",
    "    \"location\": None,\n",
    "    \"school\": None,\n",
    "    \"child\": {\n",
    "        \"name\": None,\n",
    "        \"age\": None,\n",
    "        \"grade\": None,\n",
    "        \"special_needs\": None,\n",
    "        \"transferring\": None\n",
    "    },\n",
    "    \"residency_docs\": []\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3dfb0c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_info(existing, upd):\n",
    "    \"\"\"Recursively merge the dict `upd` into the dict `existing`, in place.\n",
    "\n",
    "    Nested dicts are merged key by key; any other value overwrites the existing\n",
    "    entry (including replacing a non-dict value such as `None` with a dict).\n",
    "    \"\"\"\n",
    "    for k, v in upd.items():\n",
    "        # Recurse only when both sides are dicts; recursing into `None` would crash.\n",
    "        if isinstance(v, dict) and isinstance(existing.get(k), dict):\n",
    "            merge_info(existing[k], v)\n",
    "        else:\n",
    "            existing[k] = v\n",
    "\n",
    "\n",
    "def chat_fn(user_message, chat_history):\n",
    "    \"\"\"Gradio `ChatInterface` callback: answer one user turn and update `info`.\n",
    "\n",
    "    Args:\n",
    "        user_message: The latest message typed by the user.\n",
    "        chat_history: Prior turns passed in by Gradio. `ChatInterface` keeps the\n",
    "            visible history itself and conversation state is carried in `info`,\n",
    "            so this argument is not used to build the prompt.\n",
    "\n",
    "    Returns:\n",
    "        The assistant's reply as a string, which `ChatInterface` appends to the chat.\n",
    "    \"\"\"\n",
    "    # DeepSeek-R1 models are recommended to receive all instructions in the user turn,\n",
    "    # so the instructions, the current `info`, and the message share one user message.\n",
    "    prompt = (\n",
    "        system_prompt\n",
    "        + \"\\n`info`:\\n\" + json.dumps(info, indent=2)\n",
    "        + \"\\n\\nUser message:\\n\" + user_message\n",
    "    )\n",
    "    # Passing chat messages lets the pipeline apply the model's own chat template.\n",
    "    outputs = chat([{\"role\": \"user\", \"content\": prompt}])\n",
    "    raw = outputs[0][\"generated_text\"].strip()\n",
    "    resp_text, update = extract_response_and_update(raw)\n",
    "    merge_info(info, update)\n",
    "    return resp_text or \"Sorry, I couldn't come up with a reply. Could you rephrase that?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b676221",
   "metadata": {},
   "outputs": [],
   "source": [
    "demo = gr.ChatInterface(\n",
    "    fn=chat_fn,\n",
    "    title=\"Boston School Choice\",\n",
    "    description=\"Ask me anything about Boston Public Schools registration\",\n",
    ")\n",
    "\n",
    "demo.launch(inline=True)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}