JoeArmani
committed on
Commit
·
9b268d0
1
Parent(s):
c7c1b4e
finalize Gradio updates
Browse files- app.py +145 -0
- readme.md +11 -35
- requirements.txt +28 -26
app.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import gradio as gr
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
from chatbot_config import ChatbotConfig
|
| 7 |
+
from chatbot_model import RetrievalChatbot
|
| 8 |
+
from tf_data_pipeline import TFDataPipeline
|
| 9 |
+
from response_quality_checker import ResponseQualityChecker
|
| 10 |
+
from environment_setup import EnvironmentSetup
|
| 11 |
+
from sentence_transformers import SentenceTransformer
|
| 12 |
+
from logger_config import config_logger
|
| 13 |
+
|
| 14 |
+
# Module-wide logger obtained from the project's logger_config helper.
logger = config_logger(__name__)
|
| 15 |
+
|
| 16 |
+
def load_pipeline():
    """
    Assemble the retrieval chatbot and its response quality checker.

    Steps, in order:
      1. Read ``models/config.json`` into a ChatbotConfig when the file
         exists; otherwise fall back to a default ChatbotConfig.
      2. Run EnvironmentSetup.initialize().
      3. Build a SentenceTransformer from ``config.pretrained_model`` and
         wrap it in a TFDataPipeline pointed at the production FAISS index.
      4. If both the FAISS index file and its ``*_responses.json`` companion
         exist, load them into the pipeline and validate the index;
         otherwise only log a warning and continue.
      5. Load the RetrievalChatbot in inference mode and create a
         ResponseQualityChecker bound to the pipeline.

    Returns:
        A ``(chatbot, quality_checker)`` tuple.
    """
    model_dir = "models"
    index_dir = os.path.join(model_dir, "faiss_indices")
    index_path = os.path.join(index_dir, "faiss_index_production.index")
    # The response pool sits next to the index file with a different suffix.
    responses_path = index_path.replace(".index", "_responses.json")

    # Prefer the saved configuration; use defaults when none was saved.
    config_file = Path(model_dir) / "config.json"
    if not config_file.exists():
        config = ChatbotConfig()
    else:
        with open(config_file, "r", encoding="utf-8") as fh:
            saved_settings = json.load(fh)
        config = ChatbotConfig.from_dict(saved_settings)

    # Initialize environment
    environment = EnvironmentSetup()
    environment.initialize()

    # Load models and data
    sentence_encoder = SentenceTransformer(config.pretrained_model)
    pipeline = TFDataPipeline(
        config=config,
        tokenizer=sentence_encoder.tokenizer,
        encoder=sentence_encoder,
        response_pool=[],
        query_embeddings_cache={},
        index_type='IndexFlatIP',
        faiss_index_file_path=index_path,
    )

    # Load FAISS index and response pool (both files are required).
    index_ready = os.path.exists(index_path) and os.path.exists(responses_path)
    if not index_ready:
        logger.warning("FAISS index or responses are missing. The chatbot may not work properly.")
    else:
        pipeline.load_faiss_index(index_path)
        with open(responses_path, "r", encoding="utf-8") as fh:
            pipeline.response_pool = json.load(fh)
        pipeline.validate_faiss_index()

    chatbot = RetrievalChatbot.load_model(load_dir=model_dir, mode="inference")
    quality_checker = ResponseQualityChecker(data_pipeline=pipeline)
    return chatbot, quality_checker
|
| 63 |
+
|
| 64 |
+
# Load the chatbot and quality checker globally
|
| 65 |
+
# NOTE: this executes at import time, so importing the module runs load_pipeline().
chatbot, quality_checker = load_pipeline()
|
| 66 |
+
|
| 67 |
+
def respond(message: str, history: List[List[str]]) -> Tuple[str, List[List[str]]]:
    """Generate a chatbot reply and append the exchange to the chat history.

    Args:
        message: Text the user submitted. ``None`` or whitespace-only input
            is ignored.
        history: Conversation so far as ``[user, bot]`` pairs. ``None`` is
            treated as an empty conversation.

    Returns:
        A tuple ``("", history)``: the empty string clears the input box and
        ``history`` is the (possibly extended) conversation. If the chatbot
        raises, the error is logged and an apology is appended as the reply.
    """
    # Normalize first so neither the guard below nor the append can hit None.
    if history is None:
        history = []

    if not message or not message.strip():  # Skip blank submissions
        return "", history

    try:
        response, _, _metrics, _confidence = chatbot.chat(
            query=message,
            conversation_history=None,  # Handled internally
            quality_checker=quality_checker,
            top_k=10
        )
    except Exception as e:
        # UI boundary: report the failure in the chat rather than crashing.
        logger.error(f"Error generating response: {e}")
        response = "I apologize, but I encountered an error processing your request."

    # Store pairs as lists to match the declared List[List[str]] shape.
    history.append([message, response])
    return "", history
|
| 87 |
+
|
| 88 |
+
def main():
    """Build the Gradio Blocks interface for the chatbot demo.

    The layout is a title block, a conversation pane, a text input with a
    "Send" button, and a "Clear Conversation" button. Submitting the textbox
    or clicking "Send" both call ``respond``; the clear button resets the
    conversation pane and the input box.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(
        title="Chatbot Demo",
        css="""
        .message-wrap { max-height: 800px !important; }
        .chatbot { min-height: 600px; }
        """
    ) as demo:
        gr.Markdown(
            """
            # Retrieval-Based Chatbot Demo using Sentence Transformers + FAISS
            Knowledge areas: restaurants, movie tickets, rideshare, coffee, pizza, and auto repair.
            """
        )

        # Chat interface with custom height.
        # Named chat_display so it does not shadow the module-level
        # `chatbot` retrieval model that respond() uses.
        chat_display = gr.Chatbot(
            label="Conversation",
            container=True,
            height=600,
            show_label=True,
            elem_classes="chatbot"
        )

        # Input area with send button
        with gr.Row():
            msg = gr.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                container=False,
                scale=8
            )
            send = gr.Button(
                "Send",
                variant="primary",
                scale=1,
                min_width=100
            )

        clear = gr.Button("Clear Conversation", variant="secondary")

        # Event handlers
        msg.submit(respond, [msg, chat_display], [msg, chat_display], queue=False)
        send.click(respond, [msg, chat_display], [msg, chat_display], queue=False)
        # Reset the conversation to an empty list and the textbox to an empty
        # string (a list is not a valid value for a text field).
        clear.click(lambda: ([], ""), outputs=[chat_display, msg], queue=False)

        # Responsive interface
        # NOTE(review): this handler does nothing; presumably kept to keep the
        # UI responsive while typing -- confirm it is still needed.
        msg.change(lambda: None, None, None, queue=False)

    return demo
|
| 139 |
+
|
| 140 |
+
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on every network
    # interface at port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo = main()
    demo.launch(**launch_options)
|
readme.md
CHANGED
|
@@ -1,42 +1,18 @@
|
|
| 1 |
-
# Retrieval
|
| 2 |
|
| 3 |
-
|
| 4 |
|
| 5 |
-
##
|
| 6 |
-
|
| 7 |
-
A Python tool to generate high-quality dialog variations.
|
| 8 |
-
|
| 9 |
-
This package automatically downloads the following models during installation:
|
| 10 |
-
|
| 11 |
-
- Universal Sentence Encoder v4 (TensorFlow Hub)
|
| 12 |
-
- ChatGPT Paraphraser T5-base
|
| 13 |
-
- Helsinki-NLP translation models (en-de, de-es, es-en)
|
| 14 |
-
- spaCy en_core_web_sm, eng_core_web_md
|
| 15 |
-
- nltk wordnet and averaged_perceptron_tagger_eng models
|
| 16 |
-
|
| 17 |
-
## Install package
|
| 18 |
|
| 19 |
-
|
| 20 |
|
| 21 |
-
##
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
Two approaches are used for text augmentation: paraphrasing and back-translation. The pipeline also includes quality metrics for evaluating the augmented text.
|
| 26 |
-
Special handling is implemented for very short text such as greetings and farewells, which are predefined and filtered for quality.
|
| 27 |
-
The pipeline is designed to process a dataset of dialogues and generate multiple high-quality augmented versions of each dialogue.
|
| 28 |
-
The pipeline ensures duplicate dialogues are not generated and that the output meets quality thresholds for semantic similarity, grammar, fluency, diversity, and content preservation.
|
| 29 |
|
| 30 |
-
##
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
Helsinki-NLP. (2024). Opus-MT [Computer software]. GitHub. <https://github.com/Helsinki-NLP/Opus-MT>
|
| 36 |
-
Hugging Face. (n.d.). Transformers. Hugging Face. <https://huggingface.co/docs/transformers/en/index>
|
| 37 |
-
Humarin. (2023). ChatGPT paraphraser on T5-base [Computer software]. Hugging Face. <https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base>
|
| 38 |
-
Keita, Z. (2022). Data augmentation in NLP using back-translation with MarianMT. Towards Data Science. <https://towardsdatascience.com/data-augmentation-in-nlp-using-back-translation-with-marianmt-a8939dfea50a>
|
| 39 |
-
Memgraph. (2023). Cosine similarity in Python with scikit-learn. Memgraph. <https://memgraph.com/blog/cosine-similarity-python-scikit-learn>
|
| 40 |
-
Morris, J. (n.d.). language-tool-python (Version 2.8.1) [Computer software]. PyPI. <https://pypi.org/project/language-tool-python/>
|
| 41 |
-
TensorFlow. (n.d.). Universal sentence encoder. TensorFlow Hub. <https://www.tensorflow.org/hub/tutorials/semantic_similarity_with_tf_hub_universal_encoder>
|
| 42 |
-
Waheed, A. (2023). How to calculate ROUGE score in Python. Python Code. <https://thepythoncode.com/article/calculate-rouge-score-in-python>
|
|
|
|
| 1 |
+
# CSC252 Retrieval Chatbot
|
| 2 |
|
| 3 |
+
This is a retrieval-based chatbot using Sentence Transformers and FAISS for efficient similarity search.
|
| 4 |
|
| 5 |
+
## Description
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
The chatbot uses a pre-trained Sentence Transformer model to encode queries and a FAISS index to retrieve relevant responses from a curated response pool (Taskmaster-1 dataset).
|
| 8 |
|
| 9 |
+
## Usage
|
| 10 |
|
| 11 |
+
Simply type your question in the chat interface and the bot will retrieve the most relevant response from its knowledge base.
|
| 12 |
+
## Features
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
- Semantic search using Sentence Transformers
|
| 15 |
|
| 16 |
+
- Efficient retrieval using FAISS indexing
|
| 17 |
+
- Context-aware responses
|
| 18 |
+
- Quality checking of responses
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,27 +1,29 @@
|
|
| 1 |
-
faiss-cpu>=1.7.0
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
# Dev dependencies
|
| 24 |
-
black>=22.0.0
|
| 25 |
-
isort>=5.10.0
|
| 26 |
-
mypy>=1.0.0
|
| 27 |
-
pytest>=7.0.0
|
|
|
|
| 1 |
+
faiss-cpu>=1.7.0 # Facebook AI Similarity Search
|
| 2 |
+
gradio>=3.30.0 # Web app creation
|
| 3 |
+
h5py>=3.1.0 # For saving and loading models
|
| 4 |
+
ipython>=8.0.0 # Interactive Python
|
| 5 |
+
loguru>=0.7.0 # Enhanced logging (optional but recommended)
|
| 6 |
+
matplotlib>=3.5.0 # Validation plotting
|
| 7 |
+
nlpaug>=1.1.0 # Data augmentation for NLP
|
| 8 |
+
nltk>=3.6.0 # Natural language toolkit
|
| 9 |
+
numpy>=1.19.0 # Numerical computation
|
| 10 |
+
pandas>=1.5.0 # Data handling
|
| 11 |
+
pyyaml>=6.0.0 # Config management
|
| 12 |
+
scikit-learn>=1.0.0 # ML tools
|
| 13 |
+
sacremoses>=0.0.53 # Required for some HuggingFace pipelines
|
| 14 |
+
sentencepiece>=0.1.99 # Required for Transformers
|
| 15 |
+
sentence-transformers>=2.2.2 # Sentence embeddings
|
| 16 |
+
spacy>=3.0.0 # Text processing, tokenization
|
| 17 |
+
tensorflow>=2.13.0 # TensorFlow
|
| 18 |
+
tensorflow-hub>=0.12.0 # Pretrained model hub
|
| 19 |
+
tokenizers>=0.13.0 # HuggingFace tokenizers
|
| 20 |
+
torch>=2.0.0 # PyTorch
|
| 21 |
+
tqdm>=4.64.0 # Progress bars
|
| 22 |
+
transformers>=4.30.0 # Hugging Face Transformers
|
| 23 |
+
typing-extensions>=4.0.0
|
| 24 |
|
| 25 |
+
# Dev dependencies:
|
| 26 |
+
black>=22.0.0 # Code formatting
|
| 27 |
+
isort>=5.10.0 # Import sorting
|
| 28 |
+
mypy>=1.0.0 # Type checking
|
| 29 |
+
pytest>=7.0.0 # Testing
|