from typing import Dict, List
import numpy as np
import torch
import tensorflow as tf
import tensorflow_hub as hub
import re
from pipeline_config import PipelineConfig
from quality_metrics import QualityMetrics
from paraphraser import Paraphraser
from back_translator import BackTranslator
import nlpaug.augmenter.word as naw
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from sklearn.metrics.pairwise import cosine_similarity


class DialogueAugmenter:
    """
    Optimized dialogue augmentation with quality control and complexity management.

    For every turn of a dialogue the augmenter generates candidate rewrites
    (paraphrase, back translation, synonym and spelling augmentation), filters
    them with cheap lexical checks and sentence-embedding similarity, and then
    combines the surviving per-turn variations into new dialogues.
    """

    # Words ignored when measuring content-word overlap in the quick check.
    _STOP_WORDS = frozenset({
        'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to',
        'for', 'is', 'are', 'that', 'this', 'will', 'can',
    })

    # Indicators that mark a text as formal/technical (no spelling noise).
    _FORMAL_INDICATORS = {
        'technical_terms': {'api', 'config', 'database', 'server', 'system'},
        'formal_phrases': {'please advise', 'regarding', 'furthermore', 'moreover'},
        'professional_context': {'meeting', 'conference', 'project', 'deadline'},
    }

    # Punctuation stripped from tokens before indicator matching.
    _TOKEN_PUNCTUATION = '.,!?;:()[]{}"\''

    def __init__(self, nlp, config: PipelineConfig):
        """
        Set up devices, models, caches and augmenters.

        Args:
            nlp: NLP pipeline object; it is only stored on the instance here.
            config: Pipeline configuration (debug flag, limits, weights).
        """
        self.nlp = nlp
        self.config = config

        # Detect hardware and set appropriate batch sizes and optimization strategy
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_gpu = torch.cuda.is_available()
        if self.config.debug:
            print(f"Using device: {self.device}")
            if self.use_gpu:
                print(f"GPU Device: {torch.cuda.get_device_name(0)}")

        # GPU memory management must happen BEFORE any TensorFlow model is
        # loaded: TensorFlow rejects memory-growth changes once the GPUs have
        # been initialized.
        self._configure_tf_gpu_memory()

        # Load base models
        self.quality_metrics = QualityMetrics(config)
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

        # Initialize augmentation models based on hardware
        self._initialize_augmentation_models()

        # Initialize caches
        self.embedding_cache = {}
        self.perplexity_cache = {}

        # Compile regex patterns
        self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')

    def _configure_tf_gpu_memory(self) -> None:
        """Enable TensorFlow memory growth on every visible GPU (best effort)."""
        if not self.use_gpu:
            return
        gpus = tf.config.list_physical_devices('GPU')
        if not gpus:
            return
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            # Raised when the devices were already initialized elsewhere;
            # augmentation can still proceed with the default allocator.
            print(e)

    def _initialize_augmentation_models(self):
        """Initialize augmentation models with appropriate device settings"""
        # Advanced augmentation techniques
        self.paraphraser = Paraphraser()
        self.back_translator = BackTranslator()

        if self.use_gpu:
            # Move models to GPU if available
            self.paraphraser.model = self.paraphraser.model.to(self.device)
            self.back_translator.model_pivot_forward = \
                self.back_translator.model_pivot_forward.to(self.device)
            self.back_translator.model_pivot_backward = \
                self.back_translator.model_pivot_backward.to(self.device)
            self.back_translator.model_backward = \
                self.back_translator.model_backward.to(self.device)

        # Basic augmentation techniques
        self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
        self.spelling_augmenter = naw.SpellingAug()

        self.augmenters = {
            'advanced': [self.paraphraser, self.back_translator],
            'basic': [
                ('synonym', self.word_augmenter),
                ('spelling', self.spelling_augmenter)
            ]
        }

    def _compute_embedding(self, text: str) -> np.ndarray:
        """
        Return the sentence embedding of ``text``, caching it per instance.

        The per-instance ``embedding_cache`` dict is the only cache: wrapping
        a method in ``functools.lru_cache`` would key on ``self`` and keep the
        whole augmenter (and its models) alive for the cache's lifetime.
        """
        cached = self.embedding_cache.get(text)
        if cached is not None:
            return cached

        embedding = self.use_model([text])[0].numpy()
        self.embedding_cache[text] = embedding
        return embedding

    def _compute_batch_embeddings(self, texts: List[str]) -> np.ndarray:
        """
        Compute embeddings for multiple texts at once.

        Texts already present in ``embedding_cache`` are not re-encoded, and
        repeated texts are sent to the encoder only once.

        Returns:
            Array with one embedding row per entry of ``texts`` (same order).
        """
        # dict.fromkeys de-duplicates while keeping first-seen order.
        uncached_texts = [
            t for t in dict.fromkeys(texts) if t not in self.embedding_cache
        ]

        if uncached_texts:
            embeddings = self.use_model(uncached_texts).numpy()
            # Update cache
            for text, embedding in zip(uncached_texts, embeddings):
                self.embedding_cache[text] = embedding

        # Return all embeddings (from cache or newly computed)
        return np.array([self.embedding_cache[t] for t in texts])

    def _quick_quality_check(self, variation: str, original: str) -> bool:
        """
        Cheap lexical pre-filter applied before any embedding is computed.

        A variation fails when it is much longer than the original (more than
        3x for originals of up to three words, more than 2x otherwise) or when
        it shares less than 30% of the original's non-stop-words. An original
        made only of stop words yields an overlap of 0 and therefore fails.
        """
        if self.config.debug:
            print(f"\nQuick check for variation: {variation}")

        # Stricter length check
        orig_len = len(original.split())
        var_len = len(variation.split())

        # For very short texts (1-3 words), still allow more variation
        if orig_len <= 3:
            if var_len > orig_len * 3:
                if self.config.debug:
                    print(f"Failed length check (short text): {var_len} vs {orig_len}")
                return False
        else:
            if var_len > orig_len * 2:
                if self.config.debug:
                    print(f"Failed length check (long text): {var_len} vs {orig_len}")
                return False

        # Content check: compare lower-cased, whitespace-split content words
        stop_words = self._STOP_WORDS
        orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
        var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)

        content_overlap = len(orig_words.intersection(var_words)) / len(orig_words) if orig_words else 0
        if content_overlap < 0.3:
            if self.config.debug:
                print(f"Failed content check: overlap {content_overlap:.2f}")
            return False

        if self.config.debug:
            print("Passed all quick checks")
        return True

    def _compute_metrics_parallel(self, original: str, candidates: List[str]) -> List[Dict[str, float]]:
        """Compute quality metrics for multiple candidates in parallel"""
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [
                executor.submit(self.quality_metrics.compute_metrics, original, candidate)
                for candidate in candidates
            ]
            return [future.result() for future in futures]

    def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
        """
        Filter variations using batched computations with detailed logging.

        Args:
            variations: Candidate rewrites of ``original_turn``.
            context: Texts of the turns preceding this one (may be empty).
            original_turn: The unmodified text of the turn.

        Returns:
            At most ``config.max_variations_per_turn`` accepted variations.
            Originals shorter than three words bypass the candidates and use
            the canned short-text variations instead.
        """
        if not variations:
            return []

        if self.config.debug:
            print(f"\nStarting filtration of {len(variations)} variations")
            print(f"Context length: {len(context)}")
            print(f"Original turn: {original_turn}")

        words = original_turn.split()
        if len(words) < 3:
            if self.config.debug:
                print("Short text detected, using predefined variations")
            short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
            return [var['text'] for var in short_text_variations]

        # If this is the first turn (no context), be more lenient
        if not context:
            preliminary_filtered = variations
            if self.config.debug:
                print("First turn - skipping preliminary filtering")
        else:
            # Quick preliminary filtering against original turn
            preliminary_filtered = []
            for var in variations:
                passed = self._quick_quality_check(var, original_turn)
                if self.config.debug:
                    print(f"\nVariation: {var}")
                    print(f"Passed quick check: {passed}")
                if passed:
                    preliminary_filtered.append(var)

        if self.config.debug:
            print(f"Variations after quick check: {len(preliminary_filtered)}")

        if not preliminary_filtered:
            return []

        # Only use last turn for coherence
        context_text = context[-1] if context else ''

        # Lenient acceptance thresholds
        min_similarity = 0.1
        min_coherence = 0.05

        max_variations = self.config.max_variations_per_turn

        if context_text:
            if self.config.debug:
                print(f"\nContext text: {context_text}")

            all_texts = [context_text] + preliminary_filtered
            all_embeddings = self._compute_batch_embeddings(all_texts)

            context_embedding = all_embeddings[0]
            variation_embeddings = all_embeddings[1:]

            # Vectorized similarity computation
            context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]

            # The context text IS the previous turn, so the coherence scores
            # are the same vector; reuse it instead of re-embedding the turn.
            response_coherence = context_similarities

            # Combined scoring with detailed logging
            filtered_variations = []
            for variation, sim, coh in zip(
                    preliminary_filtered, context_similarities, response_coherence):
                # Use absolute values for scoring
                combined_score = (
                    self.config.context_similarity_weight * abs(sim) +
                    self.config.response_coherence_weight * abs(coh)
                )

                if self.config.debug:
                    print(f"\nVariation: {variation}")
                    print(f"Context similarity: {sim:.3f}")
                    print(f"Response coherence: {coh:.3f}")
                    print(f"Combined score: {combined_score:.3f}")

                # Accept if EITHER score is good enough
                if combined_score >= min_similarity or abs(coh) >= min_coherence:
                    filtered_variations.append(variation)
                    if self.config.debug:
                        print("ACCEPTED")
                else:
                    if self.config.debug:
                        print("REJECTED")

                # If we have enough variations, stop
                if len(filtered_variations) >= max_variations:
                    break
        else:
            filtered_variations = preliminary_filtered[:max_variations]

        if self.config.debug:
            print(f"\nFinal filtered variations: {len(filtered_variations)}")

        return filtered_variations

    def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
        """
        Generate variations progressively until we have enough good ones.

        Advanced augmenters (paraphrase, back translation) are tried first;
        the basic ones (synonym, spelling) only run if fewer than ``needed``
        unique variations exist. Failures of an individual augmenter are
        reported and skipped.

        Returns:
            Unique, non-empty variations different from ``text``, in the
            order they were produced.
        """
        # A dict is used as an insertion-ordered set so the result order is
        # reproducible (a plain set would reorder strings between runs).
        variations: Dict[str, None] = {}

        def _keep(candidates) -> None:
            """Record every non-blank candidate that differs from ``text``."""
            for candidate in candidates:
                if candidate and candidate.strip() and candidate != text:
                    variations[candidate] = None

        if self.config.debug:
            print(f"\nAttempting to generate {needed} variations for text: {text}")

        # Try advanced augmenters first
        for augmenter in self.augmenters['advanced']:
            if len(variations) >= needed:
                break

            try:
                if isinstance(augmenter, Paraphraser):
                    if self.config.debug:
                        print("Trying paraphrase augmentation...")
                    new_vars = augmenter.paraphrase(text, num_return_sequences=needed-len(variations))
                    if self.config.debug:
                        print(f"Paraphraser generated {len(new_vars)} variations")
                else:
                    if self.config.debug:
                        print("Trying back translation...")
                    new_vars = [augmenter.back_translate(text)]
                    if self.config.debug:
                        print(f"Back translator generated {len(new_vars)} variations")

                _keep(new_vars)

                if self.config.debug:
                    print(f"Current unique variations: {len(variations)}")

            except Exception as e:
                print(f"Error in advanced augmentation: {str(e)}")
                continue

        # Try basic augmenters if needed
        if len(variations) < needed:
            if self.config.debug:
                print("Not enough variations, trying basic augmenters...")

            for aug_type, augmenter in self.augmenters['basic']:
                if len(variations) >= needed:
                    break

                try:
                    if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
                        if self.config.debug:
                            print("Skipping spelling augmentation for technical text")
                        continue

                    if self.config.debug:
                        print(f"Trying {aug_type} augmentation...")
                    new_vars = augmenter.augment(text, n=2)
                    # The augmenter may hand back a list or a single string.
                    if isinstance(new_vars, list):
                        _keep(new_vars)
                    else:
                        _keep([new_vars])

                    if self.config.debug:
                        print(f"After {aug_type}, total variations: {len(variations)}")

                except Exception as e:
                    print(f"Error in {aug_type} augmentation: {str(e)}")
                    continue

        variations_list = list(variations)

        if self.config.debug:
            print(f"Final number of variations generated: {len(variations_list)}")
            if not variations_list:
                print("WARNING: No variations were generated!")

        return variations_list

    def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
        """
        Create augmented versions of the dialogue with optimized processing.

        Args:
            dialogue: Dict with a ``'dialogue_id'`` and a ``'turns'`` list of
                ``{'speaker': ..., 'text': ...}`` dicts. It is not modified.

        Returns:
            The (possibly truncated) original dialogue followed by up to
            ``config.augmentation_factor`` augmented dialogues that differ
            from it.
        """
        # Early dialogue length check. Work on a local slice so the caller's
        # dialogue is never truncated in place.
        turns = dialogue['turns']
        original_length = len(turns)
        max_turns = self.config.max_turns_per_dialogue
        if original_length > max_turns:
            if self.config.debug:
                print(f"Truncating dialogue from {original_length} to {max_turns} turns")
            turns = turns[:max_turns]

        turn_variations = []
        context = []

        # Process each turn with progressive generation
        for turn in turns:
            original_text = turn['text']
            variations = self._generate_variations_progressive(
                original_text,
                self.config.max_variations_per_turn
            )

            # Batch filter variations with original text
            filtered_variations = self._filter_variations_batch(
                variations,
                context,
                original_text
            )

            # Create turn variations with speaker info
            turn_vars = [{'speaker': turn['speaker'], 'text': v} for v in filtered_variations]

            if self.config.debug:
                print(f"Turn {len(turn_variations)}: Generated {len(turn_vars)} variations")

            # A turn without any surviving variation would make it impossible
            # to build ANY combination, so keep the original turn instead.
            if not turn_vars:
                turn_vars = [{'speaker': turn['speaker'], 'text': original_text}]

            turn_variations.append(turn_vars)
            context.append(original_text)

        # Generate combinations with sampling
        augmented_dialogues = self._generate_dialogue_combinations(
            dialogue['dialogue_id'],
            turn_variations
        )

        # Add original dialogue
        original_dialogue = {
            'dialogue_id': f"{dialogue['dialogue_id']}_original",
            'turns': turns
        }
        result = [original_dialogue]

        # Add unique augmentations, dropping any that merely repeat the original
        unique_augmentations = [
            candidate for candidate in augmented_dialogues
            if not self._is_dialogue_duplicate(candidate, original_dialogue)
        ]
        result.extend(unique_augmentations[:self.config.augmentation_factor])

        if self.config.debug:
            print(f"Generated {len(result)-1} unique augmented dialogues")

        return result

    def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]]) -> List[Dict]:
        """
        Generate dialogue combinations using sampling.

        Performs a depth-first walk over the per-turn variation lists, trying
        at most ``config.max_sampled_variations`` shuffled variations per turn
        and stopping once ``config.augmentation_factor`` distinct dialogues
        have been collected. Any exception aborts the walk and yields ``[]``.
        """
        augmented_dialogues = []
        used_combinations = set()
        target_count = self.config.augmentation_factor

        def generate_dialogues(current_turns=None, turn_index=0):
            if current_turns is None:
                current_turns = []

            if len(augmented_dialogues) >= target_count:
                return

            if turn_index == len(turn_variations):
                # A full path was built; keep it only if its text is new.
                dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
                if dialogue_fingerprint not in used_combinations:
                    used_combinations.add(dialogue_fingerprint)
                    augmented_dialogues.append({
                        'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
                        'turns': current_turns.copy()
                    })
                return

            variations = list(turn_variations[turn_index])
            np.random.shuffle(variations)

            for variation in variations[:self.config.max_sampled_variations]:
                if len(augmented_dialogues) >= target_count:
                    return
                current_turns.append(variation)
                generate_dialogues(current_turns, turn_index + 1)
                current_turns.pop()

        try:
            generate_dialogues()
        except Exception as e:
            print(f"Error in dialogue generation: {str(e)}")
            return []

        return augmented_dialogues

    def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
        """
        Check if two dialogues are duplicates.

        Two dialogues are duplicates when the space-joined texts of their
        turns are equal; speakers and ids are ignored.
        """
        text1 = " ".join(turn['text'] for turn in dialogue1['turns'])
        text2 = " ".join(turn['text'] for turn in dialogue2['turns'])
        return text1 == text2

    def _augment_short_text(self, turn: Dict) -> List[Dict]:
        """
        Special handling for very short texts with predefined variations.

        Args:
            turn (Dict): Original dialogue turn

        Returns:
            List[Dict]: List of variations for the short text (at most
            ``config.augmentation_factor``); falls back to the original text
            when no variation passes the quality threshold.
        """
        text = turn['text']
        common_variations = {
            'goodbye': [
                'Bye!', 'Farewell!', 'See you!', 'Take care!',
                'Goodbye!', 'Bye for now!', 'Until next time!'
            ],
            'hello': [
                'Hi!', 'Hey!', 'Hello!', 'Greetings!',
                'Good day!', 'Hi there!', 'Hello there!'
            ],
            'yes': [
                'Yes!', 'Correct!', 'Indeed!', 'Absolutely!',
                'That\'s right!', 'Definitely!', 'Sure!'
            ],
            'no': [
                'No!', 'Nope!', 'Not at all!', 'Negative!',
                'Unfortunately not!', 'I\'m afraid not!'
            ],
            'thanks': [
                'Thank you!', 'Thanks a lot!', 'Many thanks!',
                'I appreciate it!', 'Thank you so much!'
            ],
            'ok': [
                'Okay!', 'Alright!', 'Sure!', 'Got it!',
                'Understood!', 'Fine!', 'Great!', 'Perfect!',
                'That works!', 'Sounds good!'
            ],
            'good': [
                'Great!', 'Excellent!', 'Perfect!', 'Wonderful!',
                'Fantastic!', 'Amazing!', 'Terrific!'
            ]
        }

        # Try to find matching variations
        stripped_text = text.rstrip('!.,?')
        text_lower = text.lower().rstrip('!.,?')
        variations = []

        # Check if text matches any of our predefined categories. The empty
        # string is a substring of every key, so empty / punctuation-only
        # text must be excluded or it would match ALL categories.
        # NOTE(review): matching is by substring, so e.g. "know" matches 'no'.
        if text_lower:
            for key, predefined_vars in common_variations.items():
                if key in text_lower or text_lower in key:
                    variations.extend(predefined_vars)

        # If no predefined variations found, generate simple variants
        if not variations:
            # Add punctuation variations
            variations = [
                stripped_text + '!',
                stripped_text + '.',
                stripped_text
            ]

            # Add capitalization variations
            variations.extend([
                v.capitalize()
                for v in variations
                if v.capitalize() not in variations
            ])

        # De-duplicate while keeping order so the selection below is
        # reproducible from run to run.
        unique_variations = list(dict.fromkeys(variations))
        quality_variations = []

        for var in unique_variations:
            metrics = self.quality_metrics.compute_metrics(text, var)
            # NOTE(review): these weights sum to 1.05, not 1.0 - confirm intent.
            quality_score = (
                0.35 * metrics['semantic_similarity'] +
                0.30 * (1.0 - metrics['perplexity'] / 100) +
                0.15 * (1.0 - metrics['grammar_errors'] / 10) +
                0.15 * metrics['content_preservation'] +
                0.10 * metrics['type_token_ratio']
            )

            # More lenient quality threshold for short texts
            if quality_score >= 0.5:
                quality_variations.append(var)

        # Ensure we have at least some variations
        if not quality_variations:
            quality_variations = [text]

        # Return the variations with original speaker
        return [{'speaker': turn['speaker'], 'text': v}
                for v in quality_variations[:self.config.augmentation_factor]]

    def _is_technical_or_formal_text(self, text: str) -> bool:
        """
        Check if text is formal/technical and shouldn't have spelling variations.

        Single-word indicators are matched against the text's lower-cased
        tokens with surrounding punctuation removed; multi-word indicators
        (e.g. ``'please advise'``) are matched as substrings of the
        whitespace-normalized text, since they can never equal one token.
        """
        tokens = text.lower().split()
        normalized_text = ' '.join(tokens)
        words = {token.strip(self._TOKEN_PUNCTUATION) for token in tokens}

        for category in self._FORMAL_INDICATORS.values():
            for indicator in category:
                if ' ' in indicator:
                    if indicator in normalized_text:
                        return True
                elif indicator in words:
                    return True

        return False