pankajmathur commited on
Commit
406e6c0
·
verified ·
1 Parent(s): 6219f36

Upload Orca_Mini_Chatbot_1B.ipynb

Browse files
Files changed (1) hide show
  1. Orca_Mini_Chatbot_1B.ipynb +186 -0
Orca_Mini_Chatbot_1B.ipynb ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {
23
+ "id": "TFu_ibC1eYrz"
24
+ },
25
+ "outputs": [],
26
+ "source": [
27
+ "!pip install torch transformers -q"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "source": [
33
+ "import torch\n",
34
+ "from transformers import pipeline\n",
35
+ "from IPython.display import clear_output\n",
36
+ "from google.colab import output"
37
+ ],
38
+ "metadata": {
39
+ "id": "Zs7QNs0Tet6r"
40
+ },
41
+ "execution_count": null,
42
+ "outputs": []
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "source": [
47
class ChatBot:
    """Singleton-style chat wrapper around a Hugging Face text-generation pipeline.

    One shared instance is kept alive across notebook cells via
    ``get_instance``; the underlying pipeline is (re)loaded only when a
    different model slug is requested.
    """

    _instance = None        # cached shared instance
    _current_model = None   # slug of the model currently loaded, or None

    def __init__(self, model_slug=None):
        # Load (or reload) the model only when the requested slug differs
        # from the one already loaded; a None slug skips loading entirely.
        if model_slug and model_slug != ChatBot._current_model:
            self.load_model(model_slug)
            ChatBot._current_model = model_slug

        # Conversation history and default generation parameters.
        self.messages = []
        self.max_tokens = 2048
        self.temperature = 0.01
        self.top_k = 100
        self.top_p = 0.95

    @classmethod
    def get_instance(cls, model_slug=None):
        """Return the shared instance, rebuilding it if the model changed."""
        if not cls._instance or (model_slug and model_slug != cls._current_model):
            cls._instance = cls(model_slug)
        return cls._instance

    def load_model(self, model_slug):
        """Build the text-generation pipeline for *model_slug* (GPU if available)."""
        print(f"Loading model {model_slug}...")
        self.pipeline = pipeline(
            "text-generation",
            model=model_slug,
            device_map="auto",
        )
        clear_output()
        print("Model loaded successfully!")

    def reset_conversation(self, system_message):
        """Reset the conversation with a new system message"""
        self.messages = [{"role": "system", "content": system_message}]

    def get_response(self, user_input):
        """Append *user_input* as a user turn, generate a reply, and return its text.

        Raises:
            RuntimeError: if no model has been loaded yet (previously this
                surfaced as an opaque AttributeError on ``self.pipeline``).
        """
        if getattr(self, "pipeline", None) is None:
            raise RuntimeError("No model loaded; pass a model slug first.")
        self.messages.append({"role": "user", "content": user_input})
        outputs = self.pipeline(
            self.messages,
            max_new_tokens=self.max_tokens,
            do_sample=True,
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
        )
        # The chat pipeline returns the full message history; the last
        # entry is the newly generated assistant turn.
        response = outputs[0]["generated_text"][-1]
        content = response.get('content', 'No content available')
        self.messages.append({"role": "assistant", "content": content})
        return content

    def update_params(self, max_tokens=None, temperature=None, top_k=None, top_p=None):
        """Update generation parameters; a None argument leaves that value unchanged."""
        if max_tokens is not None:
            self.max_tokens = max_tokens
        if temperature is not None:
            self.temperature = temperature
        if top_k is not None:
            self.top_k = top_k
        if top_p is not None:
            self.top_p = top_p
108
+ ],
109
+ "metadata": {
110
+ "id": "v4uIN6uIeyl3"
111
+ },
112
+ "execution_count": null,
113
+ "outputs": []
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "source": [
118
def run_chatbot(
    model=None,
    system_message="You are Orca Mini, You are expert in following given instructions, Think step by step before coming up with final answer",
    max_tokens=None,
    temperature=None,
    top_k=None,
    top_p=None,
):
    """Interactive chat REPL: talk to the shared ChatBot until the user types 'quit'.

    Any non-None generation parameter overrides the bot's current setting;
    the conversation always restarts under *system_message*.
    """
    try:
        # Fetch (or lazily build) the shared bot, push parameter
        # overrides, and start a fresh conversation.
        bot = ChatBot.get_instance(model)
        bot.update_params(max_tokens, temperature, top_k, top_p)
        bot.reset_conversation(system_message)

        print("Chatbot: Hi! Type 'quit' to exit.")

        while True:
            prompt = input("You: ").strip()
            if prompt.lower() == 'quit':
                break
            try:
                print("Chatbot:", bot.get_response(prompt))
            except Exception as e:
                # Per-turn failures shouldn't kill the whole session.
                print(f"Chatbot: An error occurred: {str(e)}")
                print("Please try again.")

    except Exception as e:
        # Notebook-level boundary: surface the error instead of a traceback.
        print(f"Error in chatbot: {str(e)}")
151
+ ],
152
+ "metadata": {
153
+ "id": "H2n_6Xcue3Vn"
154
+ },
155
+ "execution_count": null,
156
+ "outputs": []
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "source": [
161
+ "run_chatbot(model=\"pankajmathur/orca_mini_v9_6_1B-Instruct\")"
162
+ ],
163
+ "metadata": {
164
+ "id": "JEqgoAH2fC6h"
165
+ },
166
+ "execution_count": null,
167
+ "outputs": []
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "source": [
172
+ "# # change system message\n",
173
+ "# run_chatbot(\n",
174
+ "# system_message=\"You are Orca Mini, You are expert in logic, Think step by step before coming up with final answer\",\n",
175
+ "# max_tokens=1024,\n",
176
+ "# temperature=0.3\n",
177
+ "# )"
178
+ ],
179
+ "metadata": {
180
+ "id": "tGW8wsfAfHDf"
181
+ },
182
+ "execution_count": null,
183
+ "outputs": []
184
+ }
185
+ ]
186
+ }