JoeArmani committed · Commit c7c1b4e
Parent(s): 64e7c31

chat refinements

Files changed:
- chatbot_config.py +8 -4
- chatbot_model.py +28 -118
- cross_encoder_reranker.py +2 -1
- run_chatbot_chat.py +47 -23
- run_chatbot_validation.py +7 -16
- tf_data_pipeline.py +28 -38
chatbot_config.py
CHANGED
@@ -4,19 +4,23 @@ from typing import Dict
 
 @dataclass
 class ChatbotConfig:
-    """
-
-
+    """
+    All config params for the chatbot
+    """
+    max_context_length: int = 512
+    embedding_dim: int = 384  # Sentence Transformer dim
     learning_rate: float = 0.0005
     min_text_length: int = 3
-    max_context_turns: int =
+    max_context_turns: int = 24
     pretrained_model: str = 'sentence-transformers/all-MiniLM-L6-v2'
     cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
     summarizer_model: str = 't5-small'
     embedding_batch_size: int = 64
     search_batch_size: int = 64
     max_batch_size: int = 64
+    neg_samples: int = 10
     max_retries: int = 3
+    nlist: int = 100
 
     def to_dict(self) -> Dict:
         """Convert config to dictionary."""
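
Note: the new fields centralize values that TFDataPipeline previously took as constructor defaults (see tf_data_pipeline.py below). A minimal usage sketch, assuming only the fields visible in this diff:

    from chatbot_config import ChatbotConfig

    config = ChatbotConfig()
    assert config.max_context_length == 512 and config.neg_samples == 10
    print(config.to_dict())  # plain dict of all fields, e.g. for writing config.json
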
chatbot_model.py
CHANGED
@@ -22,6 +22,9 @@ from tqdm.auto import tqdm
 
 absl.logging.set_verbosity(absl.logging.WARNING)
 logger = config_logger(__name__)
+logger.setLevel("WARNING")
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tqdm(disable=True)
 
 class RetrievalChatbot(DeviceAwareModel):
     """
@@ -59,7 +62,6 @@ class RetrievalChatbot(DeviceAwareModel):
             tokenizer=self.tokenizer,
             encoder=self.encoder,
             response_pool=[],
-            max_length=self.config.max_context_token_limit,
             query_embeddings_cache={},
         )
 
@@ -96,7 +98,7 @@ class RetrievalChatbot(DeviceAwareModel):
         return Summarizer(
             tokenizer=self.tokenizer,
             model_name=self.config.summarizer_model,
-            max_summary_length=self.config.
+            max_summary_length=self.config.max_context_length // 4,
             device=self.device,
             max_summary_rounds=2
         )
@@ -218,7 +220,6 @@ class RetrievalChatbot(DeviceAwareModel):
     ) -> List[Tuple[str, float]]:
         """
         Retrieve top-k responses using FAISS and cross-encoder re-ranking.
-
         Args:
             query: The user's input text.
             top_k: Number of responses to return.
@@ -226,7 +227,6 @@ class RetrievalChatbot(DeviceAwareModel):
             summarizer: Optional summarizer for long queries.
             summarize_threshold: Threshold to summarize long queries.
             boost_factor: Factor to boost scores for keyword matches.
-
         Returns:
             List of (response_text, final_score).
         """
@@ -241,18 +241,27 @@ class RetrievalChatbot(DeviceAwareModel):
 
         # Detect domain for query
         detected_domain = self.detect_domain_from_query(query)
+        #logger.info(f"Detected domain: {detected_domain}")
 
-        #
-        logger.info("Retrieving initial candidates from FAISS...")
+        # Retrieve candidates from FAISS
+        #logger.info("Retrieving initial candidates from FAISS...")
         faiss_candidates = self.data_pipeline.retrieve_responses(query, top_k=top_k * 10)
 
         if not faiss_candidates:
             logger.warning("No candidates retrieved from FAISS.")
             return []
 
-        #
-
-
+        # Filter out-of-domain responses
+        if detected_domain != 'other':
+            in_domain_candidates = [c for c in faiss_candidates if c[0]["domain"] == detected_domain]
+            if in_domain_candidates:
+                faiss_candidates = in_domain_candidates
+            else:
+                logger.info(f"No in-domain responses found for '{query}'. Using all candidates.")
+
+        # Re-rank candidates using Cross-Encoder
+        #logger.info("Re-ranking candidates using Cross-Encoder...")
+        texts = [item[0]["text"] for item in faiss_candidates]  # Extract response texts
         faiss_scores = [item[1] for item in faiss_candidates]
 
         if reranker is None:
@@ -277,9 +286,10 @@ class RetrievalChatbot(DeviceAwareModel):
 
             final_candidates.append((resp_text, length_adjusted_score))
 
-        #
+        # Sort and return top-k results
         final_candidates.sort(key=lambda x: x[1], reverse=True)
-        logger.info(f"Returning top-{top_k} re-ranked responses.")
+        #logger.info(f"Returning top-{top_k} re-ranked responses.")
+
         return final_candidates[:top_k]
 
     def extract_keywords(self, query: str) -> List[str]:
@@ -323,7 +333,7 @@ class RetrievalChatbot(DeviceAwareModel):
 
     def detect_domain_from_query(self, query: str) -> str:
         """
-        Detect the domain of the query based on keywords. Used for
+        Detect the domain of the query based on keywords. Used for filtering FAISS search.
         """
         domain_patterns = {
             'restaurant': r'\b(restaurant|restaurants?|dining|food|foods?|dine|reservation|reservations?|table|tables?|menu|menus?|cuisine|cuisines?|eat|eats?|place\s?to\s?eat|places\s?to\s?eat|hungry|chef|chefs?|dish|dishes?|meal|meals?|fork|forks?|knife|knives?|spoon|spoons?|brunch|bistro|buffet|buffets?|catering|caterings?|gourmet|fast\s?food|fine\s?dining|takeaway|takeaways?|delivery|deliveries|restaurant\s?booking)\b',
@@ -348,85 +358,6 @@ class RetrievalChatbot(DeviceAwareModel):
         pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
         return bool(re.match(pattern, text.strip()))
 
-    def faiss_search(
-        self,
-        query: str,
-        domain: str = 'other',
-        top_k: int = 10,
-        boost_factor: float = 1.15
-    ) -> List[Tuple[str, float]]:
-        """
-        Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
-        Args:
-            query (str): The user input text.
-            domain (str): The detected domain from possible domains: ['restaurant', 'movie', 'ride_share', 'coffee', 'pizza', 'auto', 'other']
-            top_k (int): Number of top results to return.
-            boost_factor (float, optional): Factor to boost scores for keyword matches.
-        Returns:
-            List[Tuple[str, float]]: List of (response_text, similarity) sorted by descending similarity.
-        """
-        # Encode the query
-        q_emb = self.data_pipeline.encode_query(query)
-        q_emb_np = q_emb.reshape(1, -1).astype('float32')
-
-        # Search the index
-        distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10)
-
-        # IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors).
-        candidates = []
-        for rank, idx in enumerate(indices[0]):
-            if idx < 0:
-                continue
-            text_dict = self.data_pipeline.response_pool[idx]
-            text = text_dict.get('text', '').strip()
-            cand_domain = text_dict.get('domain', 'other')
-            score = distances[0][rank]
-
-            # Skip purely numeric or extremely short text (fewer than 3 words):
-            words = text.split()
-            if len(words) < 4:
-                continue
-            if self.is_numeric_response(text):
-                continue
-
-            candidates.append((text, cand_domain, score))
-
-        if not candidates:
-            logger.warning("No valid candidates found after initial numeric/length filtering.")
-            return []
-
-        # Sort candidates by score descending
-        candidates.sort(key=lambda x: x[2], reverse=True)
-
-        # Filter in-domain responses
-        in_domain = [c for c in candidates if c[1] == domain]
-        if not in_domain:
-            logger.info(f"No in-domain responses found for '{domain}'. Using all candidates.")
-            in_domain = candidates
-
-        # Boost responses containing query keywords
-        query_keywords = self.extract_keywords(query)
-        boosted = []
-        for (resp_text, resp_domain, score) in in_domain:
-            new_score = score
-            # If the domain is known AND the response text shares any query keywords, boost it
-            if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
-                new_score *= boost_factor
-
-            # Apply length penalty/bonus
-            new_score = self.length_adjust_score(resp_text, new_score)
-
-            boosted.append((resp_text, new_score))
-
-        # Sort boosted responses
-        boosted.sort(key=lambda x: x[1], reverse=True)
-
-        # Debug logging (see FAISS responses)
-        # for resp, score in boosted[:100]:
-        #     logger.debug(f"Candidate: '{resp}' with score {score}")
-
-        return boosted[:top_k]
-
     def introduction_message(self) -> None:
         """Print an introduction message to introduce the chatbot."""
         print(
@@ -453,7 +384,7 @@ class RetrievalChatbot(DeviceAwareModel):
                 print("\nAssistant: Goodbye!")
                 break
 
-            response, candidates, metrics = self.chat(
+            response, candidates, metrics, top_response_score = self.chat(
                 query=user_input,
                 conversation_history=None,
                 quality_checker=quality_checker,
@@ -466,7 +397,7 @@ class RetrievalChatbot(DeviceAwareModel):
                 print("\n Alternative responses:")
                 for resp, score in candidates[1:4]:
                     print(f"    Score: {score:.4f} - {resp}")
-
+            elif top_response_score < 0.7:
                 print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
 
     def chat(
@@ -504,10 +435,10 @@ class RetrievalChatbot(DeviceAwareModel):
 
             # if uncertain, ask for clarification
             if not is_confident or top_response_score < 0.5:
-                return ("I need more information to provide a good answer. Could you please clarify?", responses, metrics)
+                return ("I need more information to provide a good answer. Could you please clarify?", responses, metrics, top_response_score)
 
             # Return the top response
-            return responses[0][0], responses, metrics
+            return responses[0][0], responses, metrics, top_response_score
 
         return get_response(self, query)
 
@@ -535,27 +466,6 @@ class RetrievalChatbot(DeviceAwareModel):
         conversation_parts.append(f"{USER_TOKEN} {query}")
         return "\n".join(conversation_parts)
 
-    # def _build_conversation_context(
-    #     self,
-    #     query: str,
-    #     conversation_history: Optional[List[Tuple[str, str]]]
-    # ) -> str:
-    #     """
-    #     Build conversation context string from conversation history.
-    #     """
-    #     if not conversation_history:
-    #         return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
-
-    #     conversation_parts = []
-    #     for user_txt, assistant_txt in conversation_history:
-    #         conversation_parts.extend([
-    #             f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {user_txt}",
-    #             f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {assistant_txt}"
-    #         ])
-
-    #     conversation_parts.append(f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}")
-    #     return "\n".join(conversation_parts)
-
     def train_model(
         self,
         tfrecord_file_path: str,
@@ -633,7 +543,7 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info("Using fixed learning rate.")
 
         # Dummy step to force initialization
-        dummy_input = tf.zeros((1, self.config.
+        dummy_input = tf.zeros((1, self.config.max_context_length), dtype=tf.int32)
         with tf.GradientTape() as tape:
             dummy_output = self.encoder(dummy_input)
             dummy_loss = tf.cast(tf.reduce_mean(dummy_output), tf.float32)
@@ -747,7 +657,7 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info(f"New validation pairs: {val_size}")
 
         dataset = dataset.map(
-            lambda x: parse_tfrecord_fn(x, self.config.
+            lambda x: parse_tfrecord_fn(x, self.config.max_context_length, self.data_pipeline.neg_samples),
             num_parallel_calls=tf.data.AUTOTUNE
         )
 
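
Note: the in-domain filtering added to retrieve_with_reranking can be read in isolation. A minimal sketch of the same logic, assuming candidates arrive as ({"text": ..., "domain": ...}, score) pairs as returned by TFDataPipeline.retrieve_responses; filter_by_domain is a hypothetical helper name:

    from typing import Dict, List, Tuple

    def filter_by_domain(
        candidates: List[Tuple[Dict, float]], detected_domain: str
    ) -> List[Tuple[Dict, float]]:
        # Keep candidates matching the detected domain; fall back to the
        # full list when the domain is 'other' or nothing matches.
        if detected_domain == 'other':
            return candidates
        in_domain = [c for c in candidates if c[0]["domain"] == detected_domain]
        return in_domain or candidates
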
cross_encoder_reranker.py
CHANGED
@@ -42,7 +42,8 @@ class CrossEncoderReranker:
             padding=True,
             truncation=True,
             max_length=max_length,
-            return_tensors="tf"
+            return_tensors="tf",
+            verbose=False
         )
 
         # Forward pass, logits shape [batch_size, 1]
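
Note: verbose=False suppresses the tokenizer's per-call warnings during batch re-ranking; it is a standard kwarg on Hugging Face tokenizers. A standalone sketch of an equivalent call, assuming the cross-encoder model named in chatbot_config.py (the max_length value here is illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
    inputs = tokenizer(
        ["any good pizza nearby?"],                # queries
        ["There is a pizzeria two blocks away."],  # candidate responses
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="tf",
        verbose=False,
    )
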
run_chatbot_chat.py
CHANGED
@@ -1,12 +1,19 @@
 import os
 import json
-from
+from tqdm.auto import tqdm
 from chatbot_config import ChatbotConfig
+from chatbot_model import RetrievalChatbot
+from sentence_transformers import SentenceTransformer
+from tf_data_pipeline import TFDataPipeline
 from response_quality_checker import ResponseQualityChecker
 from environment_setup import EnvironmentSetup
 from logger_config import config_logger
 
 logger = config_logger(__name__)
+logger.setLevel("WARNING")
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tqdm(disable=True)
 
 def run_chatbot_chat():
     env = EnvironmentSetup()
@@ -37,38 +44,55 @@ def run_chatbot_chat():
         config = ChatbotConfig()
         logger.warning("No config.json found. Using default ChatbotConfig.")
 
-    #
+    # Init SentenceTransformer
     try:
-
+        encoder = SentenceTransformer(config.pretrained_model)
+        logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")
     except Exception as e:
-        logger.error(f"Failed to load
+        logger.error(f"Failed to load SentenceTransformer: {e}")
         return
 
-    # Confirm FAISS index & response pool exist
-    if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
-        logger.error("FAISS index or response pool file is missing.")
-        return
-
     # Load FAISS index and response pool
     try:
+        # Initialize TFDataPipeline
+        data_pipeline = TFDataPipeline(
+            config=config,
+            tokenizer=encoder.tokenizer,
+            encoder=encoder,
+            response_pool=[],
+            query_embeddings_cache={},
+            index_type='IndexFlatIP',
+            faiss_index_file_path=FAISS_INDEX_PATH
+        )
+
+        if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
+            logger.error("FAISS index or response pool file is missing.")
+            return
+
+        data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
+        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
+
         with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
-
-
+            data_pipeline.response_pool = json.load(f)
+        logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
+        logger.info(f"Total responses in pool: {len(data_pipeline.response_pool)}")
+
         # Validate dimension consistency
-
-
+        data_pipeline.validate_faiss_index()
+        logger.info("FAISS index and response pool validated successfully.")
     except Exception as e:
         logger.error(f"Failed to load or validate FAISS index: {e}")
         return
 
-    #
-
-
-
-
-
+    # Run interactive chat
+    try:
+        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
+        quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
+
+        logger.info("\nStarting interactive chat session...")
+        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=False)
+    except Exception as e:
+        logger.error(f"Interactive chat session failed: {e}")
 
 if __name__ == "__main__":
     run_chatbot_chat()
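
Note: the script now leans on data_pipeline.validate_faiss_index() for dimension checks. A hypothetical standalone equivalent, assuming the default index path from tf_data_pipeline.py and an adjacent response-pool JSON (both paths are illustrative):

    import json
    import faiss

    index = faiss.read_index("models/faiss_indices/faiss_index_production.index")
    with open("models/faiss_indices/response_pool.json", encoding="utf-8") as f:
        response_pool = json.load(f)

    assert index.d == 384, "index dim should match config.embedding_dim"
    assert index.ntotal == len(response_pool), "index size should match response pool"
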
run_chatbot_validation.py
CHANGED
@@ -44,9 +44,8 @@ def run_chatbot_validation():
 
     # Init SentenceTransformer
     try:
-
-
-        logger.info(f"Loaded SentenceTransformer model: {model_name}")
+        encoder = SentenceTransformer(config.pretrained_model)
+        logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")
     except Exception as e:
         logger.error(f"Failed to load SentenceTransformer: {e}")
         return
@@ -108,18 +107,10 @@ def run_chatbot_validation():
     # Run interactive chat loop
     try:
         logger.info("\nStarting interactive chat session...")
-
-
-
-
-
-
-        responses = data_pipeline.retrieve_responses(user_input, top_k=3)
-        print("Top Responses:")
-        for i, (response, score) in enumerate(responses, start=1):
-            print(f"{i}. {response} (Score: {score:.4f})")
-    except KeyboardInterrupt:
-        logger.info("Interactive chat session interrupted by user.")
+        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=True)
+    except Exception as e:
+        logger.error(f"Interactive chat session failed: {e}")
 
 
 if __name__ == "__main__":
     run_chatbot_validation()
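
Note: the removed manual loop printed retrieve_responses results directly. With this commit the pool entries are dicts, so an equivalent spot-check would need the 'text' key. A sketch, assuming a data_pipeline constructed as in run_chatbot_chat.py:

    responses = data_pipeline.retrieve_responses("Any table for two tonight?", top_k=3)
    print("Top Responses:")
    for i, (response, score) in enumerate(responses, start=1):
        print(f"{i}. {response['text']} (Score: {score:.4f})")
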
tf_data_pipeline.py
CHANGED
@@ -6,7 +6,7 @@ import h5py
 import math
 import random
 import gc
-from tqdm import tqdm
+from tqdm.auto import tqdm
 import json
 from pathlib import Path
 from typing import Union, Optional, Dict, List, Tuple, Generator
@@ -28,31 +28,25 @@ class TFDataPipeline:
         encoder: SentenceTransformer,
         response_pool: List[str],
         query_embeddings_cache: dict,
-        model_name: str = 'sentence-transformers/all-MiniLM-L6-v2',
-        max_length: int = 512,
-        neg_samples: int = 10,
         index_type: str = 'IndexFlatIP',
         faiss_index_file_path: str = 'models/faiss_indices/faiss_index_production.index',
-        dimension: int = 384,
-        nlist: int = 100,
-        max_retries: int = 3
     ):
         self.config = config
         self.tokenizer = tokenizer
         self.encoder = encoder
-        self.model = SentenceTransformer(
+        self.model = SentenceTransformer(config.pretrained_model)
         self.faiss_index_file_path = faiss_index_file_path
         self.response_pool = response_pool
-        self.max_length = max_length
-        self.neg_samples = neg_samples
         self.query_embeddings_cache = query_embeddings_cache  # In-memory cache for embeddings
-        self.dimension = config.embedding_dim
         self.index_type = index_type
-        self.
-        self.
-        self.
-        self.
-        self.
+        self.neg_samples = config.neg_samples
+        self.nlist = config.nlist
+        self.dimension = config.embedding_dim
+        self.max_context_length = config.max_context_length
+        self.embedding_batch_size = config.embedding_batch_size
+        self.search_batch_size = config.search_batch_size
+        self.max_batch_size = config.max_batch_size
+        self.max_retries = config.max_retries
 
         # Build text -> domain map for O(1) domain lookups (hard negative sampling)
         self._text_domain_map = {}
@@ -159,7 +153,7 @@ class TFDataPipeline:
             speaker = turn.get('speaker')
             text = turn.get('text', '').strip()
             if speaker == 'assistant' and text:
-                if len(text) <= self.
+                if len(text) <= self.max_context_length:
                     # Use tuple as set key to ensure uniqueness
                     key = (domain, text)
                     if key not in response_set:
@@ -388,7 +382,7 @@ class TFDataPipeline:
                 # f"Collision detected: text '{stripped_text}' found with domains "
                 # f"'{existing_domain}' and '{domain}'. Keeping the first."
                 # )
-                # By default, keep the first domain or overwrite.
+                # By default, keep the first domain or overwrite. Skip overwriting:
                 continue
             else:
                 # Insert into the dict
@@ -434,7 +428,7 @@ class TFDataPipeline:
             prepared,
             padding='max_length',
             truncation=True,
-            max_length=self.
+            max_length=self.max_context_length,
             return_tensors='np'
         )
         input_ids = encodings['input_ids']
@@ -454,23 +448,19 @@ class TFDataPipeline:
     def retrieve_responses(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
         """
         Retrieve top-k responses for a query using FAISS.
-
-        Args:
-            query: User's query text.
-            top_k: Number of responses to return.
-
-        Returns:
-            List of tuples (response text, similarity score).
         """
         query_embedding = self.encode_query(query).reshape(1, -1).astype("float32")
         distances, indices = self.index.search(query_embedding, top_k)
 
         results = []
-        for idx, dist in
+        for idx, dist in tqdm(
+            zip(indices[0], distances[0]),
+            disable=True  # Silence tqdm
+        ):
             if idx < 0:
                 continue
             response = self.response_pool[idx]
-            results.append((response
+            results.append((response, dist))
 
         return results
 
@@ -496,7 +486,7 @@ class TFDataPipeline:
         for dialogue in batch_dialogues:
             pairs = self._extract_pairs_from_dialogue(dialogue)
             for query, positive in pairs:
-                if len(query) <= self.
+                if len(query) <= self.max_context_length and len(positive) <= self.max_context_length:
                     queries.append(query)
                     positives.append(positive)
 
@@ -524,14 +514,14 @@ class TFDataPipeline:
         try:
             encoded_queries = self.tokenizer.batch_encode_plus(
                 queries,
-                max_length=self.config.
+                max_length=self.config.max_context_length,
                 truncation=True,
                 padding='max_length',
                 return_tensors='tf'
             )
             encoded_positives = self.tokenizer.batch_encode_plus(
                 positives,
-                max_length=self.config.
+                max_length=self.config.max_context_length,
                 truncation=True,
                 padding='max_length',
                 return_tensors='tf'
@@ -547,7 +537,7 @@ class TFDataPipeline:
             flattened_negatives = [neg for sublist in hard_negatives for neg in sublist]
             encoded_negatives = self.tokenizer.batch_encode_plus(
                 flattened_negatives,
-                max_length=self.config.
+                max_length=self.config.max_context_length,
                 truncation=True,
                 padding='max_length',
                 return_tensors='tf'
@@ -555,7 +545,7 @@ class TFDataPipeline:
 
             # Reshape to [num_queries, num_negatives, max_length]
             num_negatives = self.config.neg_samples
-            reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(-1, num_negatives, self.config.
+            reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(-1, num_negatives, self.config.max_context_length)
         except Exception as e:
             logger.error(f"Error during negatives tokenization: {e}")
             pbar.update(1)
@@ -600,7 +590,7 @@ class TFDataPipeline:
             batch_queries,
             padding=True,
             truncation=True,
-            max_length=self.
+            max_length=self.max_context_length,
             return_tensors='tf'
         )
         batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
@@ -667,14 +657,14 @@ class TFDataPipeline:
         # Use tf.py_function, limit parallelism
         q_ids, p_ids, n_ids = tf.py_function(
             func=self._tokenize_triple_py,
-            inp=[q, p, n, tf.constant(self.
+            inp=[q, p, n, tf.constant(self.max_context_length), tf.constant(self.neg_samples)],
             Tout=[tf.int32, tf.int32, tf.int32]
         )
 
         # Set shape info for the output tensors
-        q_ids.set_shape([None, self.
-        p_ids.set_shape([None, self.
-        n_ids.set_shape([None, self.neg_samples, self.
+        q_ids.set_shape([None, self.max_context_length])  # [batch_size, max_length]
+        p_ids.set_shape([None, self.max_context_length])  # [batch_size, max_length]
+        n_ids.set_shape([None, self.neg_samples, self.max_context_length])  # [batch_size, neg_samples, max_length]
 
         return q_ids, p_ids, n_ids
 
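
Note: the negatives reshape assumes the flattened batch holds exactly num_queries * neg_samples rows. A small NumPy sketch of that shape contract, with illustrative sizes matching the config defaults:

    import numpy as np

    num_queries, neg_samples, max_context_length = 8, 10, 512
    flat = np.zeros((num_queries * neg_samples, max_context_length), dtype=np.int32)
    reshaped = flat.reshape(-1, neg_samples, max_context_length)
    assert reshaped.shape == (num_queries, neg_samples, max_context_length)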