|
from transformers import TFAutoModel, AutoTokenizer |
|
import tensorflow as tf |
|
import numpy as np |
|
from typing import List, Tuple, Dict, Optional, Union, Any |
|
import math |
|
from dataclasses import dataclass |
|
import json |
|
from tqdm import tqdm |
|
from pathlib import Path |
|
import datetime |
|
import faiss |
|
from response_quality_checker import ResponseQualityChecker |
|
from cross_encoder_reranker import CrossEncoderReranker |
|
from conversation_summarizer import DeviceAwareModel, Summarizer |
|
from logger_config import config_logger |
|
logger = config_logger(__name__) |
|
|
|
@dataclass |
|
class ChatbotConfig: |
|
"""Configuration for the RetrievalChatbot.""" |
|
vocab_size: int = 30526 |
|
max_context_token_limit: int = 512 |
|
embedding_dim: int = 512 |
|
encoder_units: int = 256 |
|
num_attention_heads: int = 8 |
|
dropout_rate: float = 0.2 |
|
l2_reg_weight: float = 0.001 |
|
margin: float = 0.3 |
|
learning_rate: float = 0.001 |
|
min_text_length: int = 3 |
|
max_context_turns: int = 5 |
|
warmup_steps: int = 200 |
|
pretrained_model: str = 'distilbert-base-uncased' |
|
dtype: str = 'float32' |
|
freeze_embeddings: bool = False |
|
|
|
|
|
def to_dict(self) -> dict: |
|
"""Convert config to dictionary.""" |
|
return {k: str(v) if isinstance(v, Path) else v |
|
for k, v in self.__dict__.items()} |
|
|
|
@classmethod |
|
def from_dict(cls, config_dict: dict) -> 'ChatbotConfig': |
|
"""Create config from dictionary.""" |
|
return cls(**{k: v for k, v in config_dict.items() |
|
if k in cls.__dataclass_fields__}) |
|
|
|
class EncoderModel(tf.keras.Model): |
|
"""Dual encoder model with pretrained embeddings.""" |
|
def __init__( |
|
self, |
|
config: ChatbotConfig, |
|
name: str = "encoder", |
|
shared_weights: bool = False, |
|
**kwargs |
|
): |
|
super().__init__(name=name, **kwargs) |
|
self.config = config |
|
self.shared_weights = shared_weights |
|
|
|
|
|
self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model) |
|
|
|
|
|
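        # Fine-tuning strategy: freeze the token embeddings and the first transformer
        # block; the remaining blocks and the projection head stay trainable.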
self.pretrained.distilbert.embeddings.trainable = False |
|
for i, layer_module in enumerate(self.pretrained.distilbert.transformer.layer): |
|
if i < 1: |
|
layer_module.trainable = False |
|
else: |
|
layer_module.trainable = True |
|
|
|
|
|
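        # Mean-pool over the sequence dimension. Note: no attention mask is passed,
        # so padding tokens are included in the average.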
self.pooler = tf.keras.layers.GlobalAveragePooling1D() |
|
|
|
|
|
self.projection = tf.keras.layers.Dense( |
|
config.embedding_dim, |
|
activation='tanh', |
|
name="projection" |
|
) |
|
|
|
|
|
self.dropout = tf.keras.layers.Dropout(config.dropout_rate) |
|
self.normalize = tf.keras.layers.Lambda( |
|
lambda x: tf.nn.l2_normalize(x, axis=1) |
|
) |
|
|
|
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor: |
|
"""Forward pass.""" |
|
|
|
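        # Backbone -> mean pooling -> dense projection -> dropout -> L2 normalization.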
pretrained_outputs = self.pretrained(inputs, training=training) |
|
x = pretrained_outputs.last_hidden_state |
|
|
|
|
|
x = self.pooler(x) |
|
x = self.projection(x) |
|
x = self.dropout(x, training=training) |
|
x = self.normalize(x) |
|
|
|
return x |
|
|
|
def get_config(self) -> dict: |
|
"""Return the config of the model.""" |
|
config = super().get_config() |
|
config.update({ |
|
"config": self.config.to_dict(), |
|
"shared_weights": self.shared_weights, |
|
"name": self.name |
|
}) |
|
return config |
|
|
|
class RetrievalChatbot(DeviceAwareModel): |
|
"""Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search.""" |
|
    def __init__(
        self,
        config: ChatbotConfig,
        dialogues: Optional[List[dict]] = None,
        device: Optional[str] = None,
        strategy=None,
        reranker: Optional[CrossEncoderReranker] = None,
        summarizer: Optional[Summarizer] = None
    ):
|
self.config = config |
|
self.strategy = strategy |
|
self.setup_device(device) |
|
|
|
if reranker is None: |
|
logger.info("Creating default CrossEncoderReranker...") |
|
reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2") |
|
self.reranker = reranker |
|
|
|
if summarizer is None: |
|
logger.info("Creating default Summarizer...") |
|
summarizer = Summarizer(device=self.device) |
|
self.summarizer = summarizer |
|
|
|
|
|
if self.device in ["GPU", "TPU"]: |
|
tf.config.optimizer.set_jit(True) |
|
logger.info(f"XLA compilation enabled for {self.device}") |
|
|
|
|
|
if self.device != "CPU": |
|
policy = tf.keras.mixed_precision.Policy('mixed_float16') |
|
tf.keras.mixed_precision.set_global_policy(policy) |
|
logger.info("Mixed precision training enabled (float16)") |
|
|
|
|
|
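        # Role/context marker tokens. They are added to the tokenizer below, which is
        # why _build_models() resizes the backbone's token embeddings.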
self.special_tokens = { |
|
"user": "<USER>", |
|
"assistant": "<ASSISTANT>", |
|
"context": "<CONTEXT>", |
|
"sep": "<SEP>" |
|
} |
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model) |
|
self.tokenizer.add_special_tokens( |
|
{'additional_special_tokens': list(self.special_tokens.values())} |
|
) |
|
|
|
|
|
if self.strategy: |
|
with self.strategy.scope(): |
|
self._build_models() |
|
else: |
|
self._build_models() |
|
|
|
|
|
self._initialize_faiss() |
|
|
|
|
|
        self._precompute_and_index_responses(dialogues or [])
|
|
|
|
|
self.history = { |
|
"train_loss": [], |
|
"val_loss": [], |
|
"train_metrics": {}, |
|
"val_metrics": {} |
|
} |
|
|
|
def _build_models(self): |
|
"""Initialize the shared encoder.""" |
|
logger.info("Building encoder model...") |
|
|
|
|
|
self.encoder = EncoderModel( |
|
self.config, |
|
name="shared_encoder", |
|
) |
|
|
|
|
|
new_vocab_size = len(self.tokenizer) |
|
self.encoder.pretrained.resize_token_embeddings(new_vocab_size) |
|
logger.info(f"Token embeddings resized to: {new_vocab_size}") |
|
|
|
|
|
logger.info("Inspecting embeddings attributes:") |
|
for attr in dir(self.encoder.pretrained.distilbert.embeddings): |
|
if not attr.startswith('_'): |
|
logger.info(f" {attr}") |
|
|
|
|
|
try: |
|
|
|
embedding_dim = self.encoder.pretrained.config.dim |
|
logger.info("Got embedding dim from config") |
|
except AttributeError: |
|
try: |
|
|
|
embedding_dim = self.encoder.pretrained.distilbert.embeddings.word_embeddings.embedding_dim |
|
logger.info("Got embedding dim from word embeddings") |
|
except AttributeError: |
|
try: |
|
|
|
embedding_dim = self.encoder.pretrained.distilbert.embeddings.embedding_dim |
|
logger.info("Got embedding dim from embeddings module") |
|
except AttributeError: |
|
|
|
embedding_dim = self.config.embedding_dim |
|
logger.info("Using config embedding dim") |
|
|
|
vocab_size = len(self.tokenizer) |
|
|
|
logger.info(f"Encoder Embedding Dimension: {embedding_dim}") |
|
logger.info(f"Encoder Embedding Vocabulary Size: {vocab_size}") |
|
if vocab_size >= embedding_dim: |
|
logger.info("Encoder model built and embeddings resized successfully.") |
|
else: |
|
logger.error("Vocabulary size is less than embedding dimension.") |
|
raise ValueError("Vocabulary size is less than embedding dimension.") |
|
|
|
def _initialize_faiss(self): |
|
"""Initialize FAISS index based on available resources.""" |
|
logger.info("Initializing FAISS index...") |
|
|
|
try: |
|
res = faiss.StandardGpuResources() |
|
self.faiss_gpu = True |
|
logger.info("FAISS GPU resources initialized.") |
|
        except Exception as e:
            self.faiss_gpu = False
            logger.info(f"FAISS GPU resources not available ({e}). Falling back to FAISS CPU.")
|
|
|
|
|
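        # IndexFlatIP performs exact inner-product search; with L2-normalized
        # embeddings this is equivalent to cosine similarity.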
if self.faiss_gpu: |
|
self.index = faiss.IndexFlatIP(self.config.embedding_dim) |
|
self.index = faiss.index_cpu_to_gpu(res, 0, self.index) |
|
else: |
|
self.index = faiss.IndexFlatIP(self.config.embedding_dim) |
|
logger.info("FAISS index initialized.") |
|
|
|
def verify_faiss_index(self): |
|
"""Verify that FAISS index matches the response pool.""" |
|
indexed_size = self.index.ntotal |
|
pool_size = len(self.response_pool) |
|
logger.info(f"FAISS index size: {indexed_size}") |
|
logger.info(f"Response pool size: {pool_size}") |
|
if indexed_size != pool_size: |
|
logger.warning("Mismatch between FAISS index size and response pool size.") |
|
else: |
|
logger.info("FAISS index correctly matches the response pool.") |
|
|
|
|
|
def _precompute_and_index_responses(self, dialogues: List[dict]): |
|
"""Precompute embeddings for all responses and index them using FAISS.""" |
|
logger.info("Precomputing response embeddings and indexing with FAISS...") |
|
|
|
|
|
responses = [] |
|
for dialogue in tqdm(dialogues, desc="Collecting assistant responses"): |
|
turns = dialogue.get('turns', []) |
|
for turn in turns: |
|
if turn.get('speaker') == 'assistant' and 'text' in turn: |
|
responses.append(turn['text'].strip()) |
|
|
|
|
|
        unique_responses = list(set(responses))
        logger.info(f"Found {len(unique_responses)} unique responses.")

        if not unique_responses:
            logger.warning("No assistant responses found; skipping FAISS indexing.")
            self.response_pool = []
            self.response_embeddings = np.empty((0, self.config.embedding_dim), dtype='float32')
            return
|
|
|
|
|
logger.info("Encoding unique responses") |
|
response_embeddings = self.encode_responses(unique_responses) |
|
response_embeddings = response_embeddings.numpy() |
|
|
|
|
|
if response_embeddings.dtype != np.float32: |
|
response_embeddings = response_embeddings.astype('float32') |
|
|
|
|
|
if not response_embeddings.flags['C_CONTIGUOUS']: |
|
logger.info("Making embeddings contiguous in memory.") |
|
response_embeddings = np.ascontiguousarray(response_embeddings) |
|
|
|
|
|
logger.info("Normalizing embeddings with FAISS.") |
|
faiss.normalize_L2(response_embeddings) |
|
|
|
|
|
logger.info("Adding embeddings to FAISS index...") |
|
self.index.add(response_embeddings) |
|
logger.info(f"Indexed {self.index.ntotal} responses.") |
|
|
|
|
|
self.response_pool = unique_responses |
|
self.response_embeddings = response_embeddings |
|
logger.info("Precomputation and indexing completed.") |
|
|
|
def encode_responses( |
|
self, |
|
responses: List[str], |
|
batch_size: int = 64 |
|
) -> tf.Tensor: |
|
""" |
|
Encodes a list of responses into embeddings, using chunked/batched processing |
|
to avoid running out of memory when there are many responses. |
|
|
|
Args: |
|
responses (List[str]): The list of response texts to encode. |
|
batch_size (int): How many responses to encode per chunk. |
|
Adjust based on available GPU/CPU memory. |
|
|
|
Returns: |
|
tf.Tensor: Tensor of shape (N, emb_dim) with all response embeddings. |
|
""" |
|
|
|
all_embeddings = [] |
|
|
|
|
|
for start_idx in range(0, len(responses), batch_size): |
|
end_idx = start_idx + batch_size |
|
batch_texts = responses[start_idx:end_idx] |
|
|
|
|
|
encodings = self.tokenizer( |
|
batch_texts, |
|
padding='max_length', |
|
truncation=True, |
|
max_length=self.config.max_context_token_limit, |
|
return_tensors='tf', |
|
) |
|
|
|
|
|
input_ids = encodings['input_ids'] |
|
embeddings_batch = self.encoder(input_ids, training=False) |
|
|
|
|
|
if embeddings_batch.dtype != tf.float32: |
|
embeddings_batch = tf.cast(embeddings_batch, tf.float32) |
|
|
|
|
|
all_embeddings.append(embeddings_batch) |
|
|
|
|
|
if len(all_embeddings) == 1: |
|
|
|
final_embeddings = all_embeddings[0] |
|
else: |
|
|
|
final_embeddings = tf.concat(all_embeddings, axis=0) |
|
|
|
return final_embeddings |
|
|
|
def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor: |
|
"""Encode a query with optional conversation context.""" |
|
|
|
if context: |
|
context_str = ' '.join([ |
|
f"{self.special_tokens['user']} {q} " |
|
f"{self.special_tokens['assistant']} {r}" |
|
for q, r in context[-self.config.max_context_turns:] |
|
]) |
|
query = f"{context_str} {self.special_tokens['user']} {query}" |
|
else: |
|
query = f"{self.special_tokens['user']} {query}" |
|
|
|
|
|
encodings = self.tokenizer( |
|
[query], |
|
padding='max_length', |
|
truncation=True, |
|
max_length=self.config.max_context_token_limit, |
|
return_tensors='tf' |
|
) |
|
input_ids = encodings['input_ids'] |
|
|
|
|
|
max_id = tf.reduce_max(input_ids).numpy() |
|
new_vocab_size = len(self.tokenizer) |
|
|
|
if max_id >= new_vocab_size: |
|
logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.") |
|
raise ValueError("Token ID exceeds vocabulary size.") |
|
|
|
|
|
return self.encoder(input_ids, training=False) |
|
|
|
def retrieve_responses_cross_encoder( |
|
self, |
|
query: str, |
|
top_k: int, |
|
reranker: Optional[CrossEncoderReranker] = None, |
|
summarizer: Optional[Summarizer] = None, |
|
summarize_threshold: int = 512 |
|
) -> List[Tuple[str, float]]: |
|
""" |
|
Retrieve top-k from FAISS, then re-rank them with a cross-encoder. |
|
Optionally summarize the user query if it's too long. |
|
""" |
|
if reranker is None: |
|
reranker = self.reranker |
|
if summarizer is None: |
|
summarizer = self.summarizer |
|
|
|
|
|
if summarizer and len(query.split()) > summarize_threshold: |
|
logger.info(f"Query is long. Summarizing before cross-encoder. Original length: {len(query.split())}") |
|
query = summarizer.summarize_text(query) |
|
logger.info(f"Summarized query: {query}") |
|
|
|
|
|
dense_topk = self.retrieve_responses_faiss(query, top_k=top_k) |
|
|
|
if not dense_topk: |
|
return [] |
|
|
|
|
|
candidate_texts = [pair[0] for pair in dense_topk] |
|
cross_scores = reranker.rerank(query, candidate_texts, max_length=256) |
|
|
|
|
|
combined = [(text, score) for (text, _), score in zip(dense_topk, cross_scores)] |
|
|
|
combined.sort(key=lambda x: x[1], reverse=True) |
|
|
|
return combined |
|
|
|
def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: |
|
"""Retrieve top-k responses using FAISS.""" |
|
|
|
q_emb = self.encode_query(query) |
|
q_emb_np = q_emb.numpy().astype('float32') |
|
|
|
|
|
faiss.normalize_L2(q_emb_np) |
|
|
|
|
|
distances, indices = self.index.search(q_emb_np, top_k) |
|
|
|
|
|
top_responses = [] |
|
for i, idx in enumerate(indices[0]): |
|
if idx < len(self.response_pool): |
|
top_responses.append((self.response_pool[idx], float(distances[0][i]))) |
|
else: |
|
logger.warning(f"FAISS returned invalid index {idx}. Skipping.") |
|
|
|
return top_responses |
|
|
|
def save_models(self, save_dir: Union[str, Path]): |
|
"""Save models and configuration.""" |
|
save_dir = Path(save_dir) |
|
save_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
with open(save_dir / "config.json", "w") as f: |
|
json.dump(self.config.to_dict(), f, indent=2) |
|
|
|
|
|
self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder") |
|
|
|
|
|
self.tokenizer.save_pretrained(save_dir / "tokenizer") |
|
|
|
logger.info(f"Models and tokenizer saved to {save_dir}.") |
|
|
|
@classmethod |
|
def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot': |
|
"""Load saved models and configuration.""" |
|
load_dir = Path(load_dir) |
|
|
|
|
|
with open(load_dir / "config.json", "r") as f: |
|
config = ChatbotConfig.from_dict(json.load(f)) |
|
|
|
|
|
chatbot = cls(config) |
|
|
|
|
|
        chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
            str(load_dir / "shared_encoder")
        )
|
|
|
|
|
chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer") |
|
|
|
logger.info(f"Models and tokenizer loaded from {load_dir}.") |
|
return chatbot |
|
|
|
@staticmethod |
|
def load_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]: |
|
""" |
|
Load training data from a JSON file. |
|
|
|
Args: |
|
data_path (Union[str, Path]): Path to the JSON file containing dialogues. |
|
debug_samples (Optional[int]): Number of samples to load for debugging. |
|
|
|
Returns: |
|
List[dict]: List of dialogue dictionaries. |
|
""" |
|
logger.info(f"Loading training data from {data_path}...") |
|
data_path = Path(data_path) |
|
if not data_path.exists(): |
|
logger.error(f"Data file {data_path} does not exist.") |
|
return [] |
|
|
|
with open(data_path, 'r', encoding='utf-8') as f: |
|
dialogues = json.load(f) |
|
|
|
if debug_samples is not None: |
|
dialogues = dialogues[:debug_samples] |
|
logger.info(f"Debug mode: Limited to {debug_samples} dialogues") |
|
|
|
logger.info(f"Loaded {len(dialogues)} dialogues.") |
|
return dialogues |
|
|
|
def prepare_dataset( |
|
self, |
|
dialogues: List[dict], |
|
neg_samples: int = 1, |
|
debug_samples: int = None |
|
) -> Tuple[tf.Tensor, tf.Tensor]: |
|
""" |
|
Prepares dataset for multiple-negatives ranking, |
|
but also appends 'hard negative' pairs for each query. |
|
|
|
        We generate:
          - (query, positive) pairs as usual
          - (query, negative) pairs for each query, mining up to `neg_samples`
            approximate hard negatives via FAISS.
        In-batch training then treats each negative as another "positive" for the
        same query, forcing the model to discriminate between them.
|
""" |
|
|
|
logger.info("Preparing in-batch dataset with hard negatives...") |
|
|
|
queries, positives = [], [] |
|
|
|
|
|
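        # Collect (user turn, following assistant turn) pairs as (query, positive) examples.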
for dialogue in dialogues: |
|
turns = dialogue.get('turns', []) |
|
for i in range(len(turns) - 1): |
|
current_turn = turns[i] |
|
next_turn = turns[i+1] |
|
|
|
if (current_turn.get('speaker') == 'user' |
|
and next_turn.get('speaker') == 'assistant' |
|
and 'text' in current_turn |
|
and 'text' in next_turn): |
|
|
|
query_text = current_turn['text'].strip() |
|
pos_text = next_turn['text'].strip() |
|
|
|
queries.append(query_text) |
|
positives.append(pos_text) |
|
|
|
|
|
if debug_samples is not None: |
|
queries = queries[:debug_samples] |
|
positives = positives[:debug_samples] |
|
logger.info(f"Debug mode: limited to {debug_samples} pairs.") |
|
|
|
logger.info(f"Prepared {len(queries)} (query, positive) pairs initially.") |
|
|
|
|
|
|
|
augmented_queries = [] |
|
augmented_positives = [] |
|
|
|
for q_text, p_text in zip(queries, positives): |
|
neg_texts = self._find_hard_negative(q_text, p_text, top_k=5, neg_samples=neg_samples) |
|
for neg_text in neg_texts: |
|
augmented_queries.append(q_text) |
|
augmented_positives.append(neg_text) |
|
|
|
logger.info(f"Found hard negatives for {len(augmented_queries)} queries.") |
|
|
|
|
|
final_queries = queries + augmented_queries |
|
final_positives = positives + augmented_positives |
|
logger.info(f"Total dataset size after adding hard neg: {len(final_queries)}") |
|
|
|
|
|
encoded_queries = self.tokenizer( |
|
final_queries, |
|
padding='max_length', |
|
truncation=True, |
|
max_length=self.config.max_context_token_limit, |
|
return_tensors='tf' |
|
) |
|
encoded_positives = self.tokenizer( |
|
final_positives, |
|
padding='max_length', |
|
truncation=True, |
|
max_length=self.config.max_context_token_limit, |
|
return_tensors='tf' |
|
) |
|
|
|
q_tensor = encoded_queries['input_ids'] |
|
p_tensor = encoded_positives['input_ids'] |
|
|
|
logger.info("Tokenized and padded sequences for in-batch training + hard negatives.") |
|
return q_tensor, p_tensor |
|
|
|
def _find_hard_negative( |
|
self, |
|
query_text: str, |
|
positive_text: str, |
|
top_k: int = 5, |
|
neg_samples: int = 1 |
|
) -> List[str]: |
|
""" |
|
Return up to `neg_samples` unique negatives from top_k FAISS results, |
|
excluding the known positive_text. |
|
""" |
|
|
|
query_emb = self.encode_query(query_text) |
|
q_emb_np = query_emb.numpy().astype('float32') |
|
|
|
|
|
faiss.normalize_L2(q_emb_np) |
|
|
|
|
|
distances, indices = self.index.search(q_emb_np, top_k) |
|
|
|
|
|
hard_negatives = [] |
|
for idx in indices[0]: |
|
if idx < len(self.response_pool): |
|
candidate = self.response_pool[idx].strip() |
|
if candidate != positive_text.strip(): |
|
hard_negatives.append(candidate) |
|
if len(hard_negatives) == neg_samples: |
|
break |
|
|
|
return hard_negatives |
|
|
|
def train( |
|
self, |
|
q_pad: tf.Tensor, |
|
p_pad: tf.Tensor, |
|
epochs: int = 20, |
|
batch_size: int = 16, |
|
validation_split: float = 0.2, |
|
checkpoint_dir: str = "checkpoints/", |
|
use_lr_schedule: bool = True, |
|
peak_lr: float = 2e-5, |
|
warmup_steps_ratio: float = 0.1, |
|
early_stopping_patience: int = 3, |
|
min_delta: float = 1e-4, |
|
accum_steps: int = 2 |
|
): |
|
dataset_size = tf.shape(q_pad)[0].numpy() |
|
val_size = int(dataset_size * validation_split) |
|
train_size = dataset_size - val_size |
|
|
|
logger.info(f"Total samples: {dataset_size}") |
|
logger.info(f"Training samples: {train_size}") |
|
logger.info(f"Validation samples: {val_size}") |
|
|
|
steps_per_epoch = train_size // batch_size |
|
if train_size % batch_size != 0: |
|
steps_per_epoch += 1 |
|
total_steps = steps_per_epoch * epochs |
|
logger.info(f"Total training steps (approx): {total_steps}") |
|
|
|
|
|
if use_lr_schedule: |
|
warmup_steps = int(total_steps * warmup_steps_ratio) |
|
lr_schedule = self._get_lr_schedule( |
|
total_steps=total_steps, |
|
peak_lr=peak_lr, |
|
warmup_steps=warmup_steps |
|
) |
|
self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule) |
|
logger.info("Using custom learning rate schedule.") |
|
else: |
|
self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr) |
|
logger.info("Using fixed learning rate.") |
|
|
|
|
|
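        # Positional split: the first train_size rows become the training set.
        # Shuffle q_pad/p_pad beforehand if the pair order is not already random.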
train_q = q_pad[:train_size] |
|
train_p = p_pad[:train_size] |
|
val_q = q_pad[train_size:] |
|
val_p = p_pad[train_size:] |
|
|
|
train_dataset = (tf.data.Dataset.from_tensor_slices((train_q, train_p)) |
|
.shuffle(4096) |
|
.batch(batch_size) |
|
.prefetch(tf.data.AUTOTUNE)) |
|
|
|
val_dataset = (tf.data.Dataset.from_tensor_slices((val_q, val_p)) |
|
.batch(batch_size) |
|
.prefetch(tf.data.AUTOTUNE)) |
|
|
|
|
|
checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder) |
|
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3) |
|
|
|
|
|
log_dir = Path(checkpoint_dir) / "tensorboard_logs" |
|
log_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") |
|
train_log_dir = str(log_dir / f"train_{current_time}") |
|
val_log_dir = str(log_dir / f"val_{current_time}") |
|
|
|
train_summary_writer = tf.summary.create_file_writer(train_log_dir) |
|
val_summary_writer = tf.summary.create_file_writer(val_log_dir) |
|
|
|
logger.info(f"TensorBoard logs will be saved in {log_dir}") |
|
|
|
|
|
best_val_loss = float("inf") |
|
epochs_no_improve = 0 |
|
|
|
logger.info("Beginning training loop...") |
|
global_step = 0 |
|
|
|
|
|
|
|
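        # Gradient accumulation: sum gradients in float32 buffers and apply them
        # every `accum_steps` micro-batches. Only the pretrained backbone's
        # variables are updated here; the projection head is excluded.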
train_vars = self.encoder.pretrained.trainable_variables |
|
accum_grads = [tf.zeros_like(var, dtype=tf.float32) for var in train_vars] |
|
|
|
|
for epoch in range(1, epochs + 1): |
|
logger.info(f"\n=== Epoch {epoch}/{epochs} ===") |
|
epoch_loss_avg = tf.keras.metrics.Mean() |
|
|
|
step_in_epoch = 0 |
|
with tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}") as pbar: |
|
for (q_batch, p_batch) in train_dataset: |
|
step_in_epoch += 1 |
|
global_step += 1 |
|
|
|
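                    # In-batch negatives: the similarity matrix pairs every query with
                    # every positive in the batch; the diagonal entries are the true
                    # pairs, so the loss is sparse softmax cross-entropy with labels
                    # 0..batch_size-1.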
with tf.GradientTape() as tape: |
|
q_enc = self.encoder(q_batch, training=True) |
|
p_enc = self.encoder(p_batch, training=True) |
|
|
|
sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True) |
|
bsz = tf.shape(q_enc)[0] |
|
labels = tf.range(bsz, dtype=tf.int32) |
|
loss_value = tf.nn.sparse_softmax_cross_entropy_with_logits( |
|
labels=labels, logits=sim_matrix |
|
) |
|
loss_value = tf.reduce_mean(loss_value) |
|
|
|
gradients = tape.gradient(loss_value, train_vars) |
|
|
|
|
|
for i, grad in enumerate(gradients): |
|
if grad is not None: |
|
accum_grads[i] += tf.cast(grad, tf.float32) |
|
|
|
epoch_loss_avg(loss_value) |
|
|
|
|
|
if (step_in_epoch % accum_steps) == 0: |
|
|
|
|
|
for i in range(len(accum_grads)): |
|
accum_grads[i] /= accum_steps |
|
|
|
self.optimizer.apply_gradients( |
|
[(accum_grads[i], train_vars[i]) for i in range(len(accum_grads))] |
|
) |
|
|
|
accum_grads = [tf.zeros_like(var, dtype=tf.float32) for var in train_vars] |
|
|
|
|
|
if use_lr_schedule: |
|
|
|
lr = self.optimizer.learning_rate |
|
if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule): |
|
current_step = tf.cast(self.optimizer.iterations, tf.float32) |
|
current_lr = lr(current_step) |
|
else: |
|
current_lr = lr |
|
current_lr_value = float(current_lr.numpy()) |
|
else: |
|
current_lr_value = float(self.optimizer.learning_rate.numpy()) |
|
|
|
pbar.update(1) |
|
pbar.set_postfix({ |
|
"loss": f"{loss_value.numpy():.4f}", |
|
"lr": f"{current_lr_value:.2e}" |
|
}) |
|
|
|
|
|
|
|
|
|
            # Flush gradients accumulated from the last, partial group of micro-batches.
            leftover = step_in_epoch % accum_steps
            if leftover != 0:
                logger.info(f"Applying leftover accum_grads for partial batch group (size={leftover}).")

                # Average over the number of micro-batches actually accumulated.
                for i in range(len(accum_grads)):
                    accum_grads[i] /= float(leftover)

                self.optimizer.apply_gradients(
                    [(accum_grads[i], train_vars[i]) for i in range(len(accum_grads))]
                )
                accum_grads = [tf.zeros_like(var, dtype=tf.float32) for var in train_vars]
|
|
|
|
|
val_loss_avg = tf.keras.metrics.Mean() |
|
for q_val, p_val in val_dataset: |
|
q_enc = self.encoder(q_val, training=False) |
|
p_enc = self.encoder(p_val, training=False) |
|
sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True) |
|
bs_val = tf.shape(q_enc)[0] |
|
labels_val = tf.range(bs_val, dtype=tf.int32) |
|
loss_val = tf.nn.sparse_softmax_cross_entropy_with_logits( |
|
labels=labels_val, |
|
logits=sim_matrix |
|
) |
|
val_loss_avg(tf.reduce_mean(loss_val)) |
|
|
|
train_loss = epoch_loss_avg.result().numpy() |
|
val_loss = val_loss_avg.result().numpy() |
|
|
|
logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}") |
|
|
|
|
|
with val_summary_writer.as_default(): |
|
tf.summary.scalar("val_loss", val_loss, step=epoch) |
|
|
|
|
|
manager.save() |
|
|
|
|
|
self.history['train_loss'].append(train_loss) |
|
self.history['val_loss'].append(val_loss) |
|
self.history.setdefault('learning_rate', []).append(float(current_lr_value)) |
|
|
|
|
|
if val_loss < best_val_loss - min_delta: |
|
best_val_loss = val_loss |
|
epochs_no_improve = 0 |
|
logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.") |
|
else: |
|
epochs_no_improve += 1 |
|
logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}") |
|
if epochs_no_improve >= early_stopping_patience: |
|
logger.info("Early stopping triggered.") |
|
break |
|
|
|
logger.info("In-batch training completed!") |
|
|
|
def _get_lr_schedule( |
|
self, |
|
total_steps: int, |
|
peak_lr: float, |
|
warmup_steps: int |
|
) -> tf.keras.optimizers.schedules.LearningRateSchedule: |
|
"""Create a custom learning rate schedule with warmup and cosine decay.""" |
|
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): |
|
def __init__( |
|
self, |
|
total_steps: int, |
|
peak_lr: float, |
|
warmup_steps: int |
|
): |
|
super().__init__() |
|
self.total_steps = tf.cast(total_steps, tf.float32) |
|
self.peak_lr = tf.cast(peak_lr, tf.float32) |
|
|
|
|
|
adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10)) |
|
self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32) |
|
|
|
|
|
self.initial_lr = self.peak_lr * 0.1 |
|
self.min_lr = self.peak_lr * 0.01 |
|
|
|
                logger.info("Learning rate schedule initialized:")
|
logger.info(f" Initial LR: {float(self.initial_lr):.6f}") |
|
logger.info(f" Peak LR: {float(self.peak_lr):.6f}") |
|
logger.info(f" Min LR: {float(self.min_lr):.6f}") |
|
logger.info(f" Warmup steps: {int(self.warmup_steps)}") |
|
logger.info(f" Total steps: {int(self.total_steps)}") |
|
|
|
def __call__(self, step): |
|
step = tf.cast(step, tf.float32) |
|
|
|
|
|
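                # Linear warmup from initial_lr to peak_lr over warmup_steps,
                # then cosine decay from peak_lr down to min_lr.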
warmup_factor = tf.minimum(1.0, step / self.warmup_steps) |
|
warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor |
|
|
|
|
|
decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps) |
|
decay_factor = (step - self.warmup_steps) / decay_steps |
|
decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) |
|
|
|
cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(math.pi) * decay_factor)) |
|
decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay |
|
|
|
|
|
final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr) |
|
|
|
|
|
final_lr = tf.maximum(self.min_lr, final_lr) |
|
final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr) |
|
|
|
return final_lr |
|
|
|
def get_config(self): |
|
return { |
|
"total_steps": self.total_steps, |
|
"peak_lr": self.peak_lr, |
|
"warmup_steps": self.warmup_steps, |
|
} |
|
|
|
return CustomSchedule(total_steps, peak_lr, warmup_steps) |
|
|
|
def _cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> np.ndarray: |
|
"""Compute cosine similarity between two numpy arrays.""" |
|
normalized_emb1 = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True) |
|
normalized_emb2 = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True) |
|
return np.dot(normalized_emb1, normalized_emb2.T) |
|
|
|
def chat( |
|
self, |
|
query: str, |
|
conversation_history: Optional[List[Tuple[str, str]]] = None, |
|
quality_checker: Optional['ResponseQualityChecker'] = None, |
|
top_k: int = 5, |
|
) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]: |
|
""" |
|
        Generate a reply for `query`: retrieve FAISS candidates, re-rank them with
        the cross-encoder, and optionally gate the answer with a quality checker.
|
""" |
|
@self.run_on_device |
|
def get_response(self_arg, query_arg): |
|
|
|
conversation_str = self_arg._build_conversation_context(query_arg, conversation_history) |
|
|
|
|
|
results = self_arg.retrieve_responses_cross_encoder( |
|
query=conversation_str, |
|
top_k=top_k, |
|
reranker=self_arg.reranker, |
|
summarizer=self_arg.summarizer, |
|
summarize_threshold=512 |
|
) |
|
|
|
|
|
if not results: |
|
return ( |
|
"I'm sorry, but I couldn't find a relevant response.", |
|
[], |
|
{} |
|
) |
|
|
|
if quality_checker: |
|
metrics = quality_checker.check_response_quality(query_arg, results) |
|
if not metrics.get('is_confident', False): |
|
return ( |
|
"I need more information to provide a good answer. Could you please clarify?", |
|
results, |
|
metrics |
|
) |
|
return results[0][0], results, metrics |
|
|
|
return results[0][0], results, {} |
|
|
|
return get_response(self, query) |
|
|
|
def _build_conversation_context( |
|
self, |
|
query: str, |
|
conversation_history: Optional[List[Tuple[str, str]]] |
|
) -> str: |
|
"""Build conversation context with better memory management.""" |
|
if not conversation_history: |
|
return f"{self.special_tokens['user']} {query}" |
|
|
|
conversation_parts = [] |
|
for user_txt, assistant_txt in conversation_history: |
|
conversation_parts.extend([ |
|
f"{self.special_tokens['user']} {user_txt}", |
|
f"{self.special_tokens['assistant']} {assistant_txt}" |
|
]) |
|
|
|
conversation_parts.append(f"{self.special_tokens['user']} {query}") |
|
return "\n".join(conversation_parts) |
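

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only).
# Assumes a dialogues JSON file at "data/dialogues.json" with the structure
# expected by load_training_data(); the path and sample counts are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = ChatbotConfig()
    dialogues = RetrievalChatbot.load_training_data("data/dialogues.json", debug_samples=100)

    chatbot = RetrievalChatbot(config, dialogues=dialogues)

    # Optional fine-tuning with in-batch negatives plus FAISS-mined hard negatives.
    q_pad, p_pad = chatbot.prepare_dataset(dialogues, neg_samples=1)
    chatbot.train(q_pad, p_pad, epochs=3, batch_size=16)

    # Single-turn query; pass conversation_history for multi-turn context.
    response, candidates, metrics = chatbot.chat("How do I reset my password?")
    print(response)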