"""Multimodal reasoning strategy implementation.""" | |
import logging | |
from typing import Dict, Any, List, Optional, Set, Union, Type, Tuple | |
import json | |
from dataclasses import dataclass, field | |
from enum import Enum | |
from datetime import datetime | |
import numpy as np | |
from collections import defaultdict | |
from .base import ReasoningStrategy, StrategyResult | |


class ModalityType(Enum):
    """Types of modalities supported."""
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    STRUCTURED = "structured"


@dataclass
class ModalityFeatures:
    """Features extracted from different modalities."""
    text: List[Dict[str, Any]]
    image: Optional[List[Dict[str, Any]]] = None
    audio: Optional[List[Dict[str, Any]]] = None
    video: Optional[List[Dict[str, Any]]] = None
    structured: Optional[List[Dict[str, Any]]] = None
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())


@dataclass
class ModalityAlignment:
    """Alignment between different modalities."""
    modality1: ModalityType
    modality2: ModalityType
    features1: Dict[str, Any]
    features2: Dict[str, Any]
    similarity: float
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())


class MultimodalStrategy(ReasoningStrategy):
    """Advanced multimodal reasoning that:

    1. Processes multiple input types
    2. Combines modalities
    3. Extracts cross-modal patterns
    4. Generates multimodal insights
    5. Validates consistency
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize multimodal reasoning."""
        super().__init__()
        self.config = config or {}

        # Standard reasoning parameters
        self.min_confidence = self.config.get('min_confidence', 0.7)
        self.min_similarity = self.config.get('min_similarity', 0.7)

        # Configure modality weights
        self.weights = self.config.get('modality_weights', {
            ModalityType.TEXT.value: 0.4,
            ModalityType.IMAGE.value: 0.3,
            ModalityType.AUDIO.value: 0.1,
            ModalityType.VIDEO.value: 0.1,
            ModalityType.STRUCTURED.value: 0.1
        })

        # Performance metrics
        self.performance_metrics = {
            'processed_modalities': defaultdict(int),
            'alignments_found': 0,
            'successful_alignments': 0,
            'failed_alignments': 0,
            'avg_similarity': 0.0,
            'modality_distribution': defaultdict(int),
            'total_features_extracted': 0,
            'total_alignments_created': 0,
            'total_integrations': 0
        }

    async def reason(
        self,
        query: str,
        context: Dict[str, Any]
    ) -> StrategyResult:
        """
        Apply multimodal reasoning to process and integrate different types of information.

        Args:
            query: The input query to reason about
            context: Additional context and parameters

        Returns:
            StrategyResult containing the reasoning output and metadata
        """
        try:
            # Process across modalities
            modalities = await self._process_modalities(query, context)
            self.performance_metrics['total_features_extracted'] = sum(
                len(features) for features in modalities.values()
            )

            # Update modality distribution
            for modality, features in modalities.items():
                self.performance_metrics['modality_distribution'][modality] += len(features)

            # Align cross-modal information
            alignments = await self._cross_modal_alignment(modalities, context)
            self.performance_metrics['total_alignments_created'] = len(alignments)

            # Integrate aligned information
            integration = await self._integrated_analysis(alignments, context)
            self.performance_metrics['total_integrations'] = len(integration)

            # Generate final response
            response = await self._generate_response(integration, context)

            # Build reasoning trace
            reasoning_trace = self._build_reasoning_trace(
                modalities, alignments, integration, response
            )

            # Calculate final confidence
            confidence = self._calculate_confidence(integration)

            if confidence >= self.min_confidence:
                return StrategyResult(
                    strategy_type="multimodal",
                    success=True,
                    answer=response.get('text'),
                    confidence=confidence,
                    reasoning_trace=reasoning_trace,
                    metadata={
                        'modalities': list(modalities.keys()),
                        'alignments': len(alignments),
                        'integration_size': len(integration)
                    },
                    performance_metrics=self.performance_metrics
                )

            return StrategyResult(
                strategy_type="multimodal",
                success=False,
                answer=None,
                confidence=confidence,
                reasoning_trace=reasoning_trace,
                metadata={'error': 'Insufficient confidence in results'},
                performance_metrics=self.performance_metrics
            )

        except Exception as e:
            logging.error(f"Multimodal reasoning error: {e}")
            return StrategyResult(
                strategy_type="multimodal",
                success=False,
                answer=None,
                confidence=0.0,
                reasoning_trace=[{
                    'step': 'error',
                    'error': str(e),
                    'timestamp': datetime.now().isoformat()
                }],
                metadata={'error': str(e)},
                performance_metrics=self.performance_metrics
            )

    async def _process_modalities(
        self,
        query: str,
        context: Dict[str, Any]
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Process query across different modalities."""
        modalities = {}

        # Process text
        if 'text' in context:
            modalities[ModalityType.TEXT.value] = self._process_text(context['text'])
            self.performance_metrics['processed_modalities'][ModalityType.TEXT.value] += 1

        # Process images
        if 'images' in context:
            modalities[ModalityType.IMAGE.value] = self._process_images(context['images'])
            self.performance_metrics['processed_modalities'][ModalityType.IMAGE.value] += 1

        # Process audio
        if 'audio' in context:
            modalities[ModalityType.AUDIO.value] = self._process_audio(context['audio'])
            self.performance_metrics['processed_modalities'][ModalityType.AUDIO.value] += 1

        # Process video
        if 'video' in context:
            modalities[ModalityType.VIDEO.value] = self._process_video(context['video'])
            self.performance_metrics['processed_modalities'][ModalityType.VIDEO.value] += 1

        # Process structured data
        if 'structured' in context:
            modalities[ModalityType.STRUCTURED.value] = self._process_structured(context['structured'])
            self.performance_metrics['processed_modalities'][ModalityType.STRUCTURED.value] += 1

        return modalities

    async def _cross_modal_alignment(
        self,
        modalities: Dict[str, List[Dict[str, Any]]],
        context: Dict[str, Any]
    ) -> List[ModalityAlignment]:
        """Align information across different modalities."""
        alignments = []

        # Get all unordered modality pairs
        modality_pairs = [
            (m1, m2) for i, m1 in enumerate(modalities.keys())
            for m2 in list(modalities.keys())[i + 1:]
        ]

        # Align each pair
        for mod1, mod2 in modality_pairs:
            items1 = modalities[mod1]
            items2 = modalities[mod2]

            # Calculate cross-modal similarities
            for item1 in items1:
                for item2 in items2:
                    similarity = self._calculate_similarity(item1, item2)
                    self.performance_metrics['alignments_found'] += 1

                    if similarity >= self.min_similarity:
                        self.performance_metrics['successful_alignments'] += 1
                        alignments.append(ModalityAlignment(
                            modality1=ModalityType(mod1),
                            modality2=ModalityType(mod2),
                            features1=item1,
                            features2=item2,
                            similarity=similarity
                        ))
                    else:
                        self.performance_metrics['failed_alignments'] += 1

        # Update average similarity
        if alignments:
            self.performance_metrics['avg_similarity'] = (
                sum(a.similarity for a in alignments) / len(alignments)
            )

        return alignments

    def _calculate_similarity(
        self,
        item1: Dict[str, Any],
        item2: Dict[str, Any]
    ) -> float:
        """Calculate similarity between two items from different modalities."""
        # Placeholder metric: Jaccard overlap of stringified feature values
        features1 = set(str(v) for v in item1.values())
        features2 = set(str(v) for v in item2.values())

        if not features1 or not features2:
            return 0.0

        overlap = len(features1.intersection(features2))
        total = len(features1.union(features2))

        return overlap / total if total > 0 else 0.0
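
    # Worked example of the Jaccard overlap above (hypothetical values):
    #   item1 = {'text': 'red car', 'tag': 'vehicle'}
    #   item2 = {'caption': 'red car', 'tag': 'beach'}
    #   features1 = {'red car', 'vehicle'}, features2 = {'red car', 'beach'}
    #   overlap = 1, union = 3  ->  similarity = 1/3 ~= 0.33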

    async def _integrated_analysis(
        self,
        alignments: List[ModalityAlignment],
        context: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Perform integrated analysis of aligned information."""
        integrated = []

        # Group alignments by similarity
        similarity_groups = defaultdict(list)
        for align in alignments:
            similarity_groups[align.similarity].append(align)

        # Process groups in descending order of similarity
        for similarity, group in sorted(
            similarity_groups.items(),
            key=lambda x: x[0],
            reverse=True
        ):
            # Combine aligned features
            for align in group:
                integrated.append({
                    'features': {
                        **align.features1,
                        **align.features2
                    },
                    'modalities': [
                        align.modality1.value,
                        align.modality2.value
                    ],
                    'confidence': align.similarity,
                    'timestamp': align.timestamp
                })

        return integrated

    async def _generate_response(
        self,
        integration: List[Dict[str, Any]],
        context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Generate a coherent response from the integrated analysis."""
        if not integration:
            return {'text': '', 'confidence': 0.0}

        # Combine all integrated features
        all_features = {}
        for item in integration:
            all_features.update(item['features'])

        # Generate response text
        response_text = []

        # Add main findings
        response_text.append("Main findings across modalities:")
        for feature, value in all_features.items():
            response_text.append(f"- {feature}: {value}")

        # Add confidence
        confidence = sum(item['confidence'] for item in integration) / len(integration)
        response_text.append(f"\nOverall confidence: {confidence:.2f}")

        return {
            'text': "\n".join(response_text),
            'confidence': confidence,
            'timestamp': datetime.now().isoformat()
        }

    def _calculate_confidence(self, integration: List[Dict[str, Any]]) -> float:
        """Calculate overall confidence score."""
        if not integration:
            return 0.0

        # Base confidence
        confidence = 0.5

        # Bonus for each distinct modality involved, capped at 0.3
        unique_modalities = set()
        for item in integration:
            unique_modalities.update(item['modalities'])

        modality_bonus = len(unique_modalities) * 0.1
        confidence += min(modality_bonus, 0.3)

        # Adjust based on integration quality
        avg_similarity = sum(
            item['confidence'] for item in integration
        ) / len(integration)
        confidence += avg_similarity * 0.2

        return min(confidence, 1.0)
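
    # Worked example of the confidence formula above (hypothetical values):
    #   three distinct modalities   -> modality_bonus = min(3 * 0.1, 0.3) = 0.3
    #   average item confidence 0.8 -> similarity term = 0.8 * 0.2 = 0.16
    #   confidence = min(0.5 + 0.3 + 0.16, 1.0) = 0.96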

    def _build_reasoning_trace(
        self,
        modalities: Dict[str, List[Dict[str, Any]]],
        alignments: List[ModalityAlignment],
        integration: List[Dict[str, Any]],
        response: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Build the reasoning trace for multimodal processing."""
        trace = []

        # Modality processing step
        trace.append({
            'step': 'modality_processing',
            'modalities': {
                mod: len(features)
                for mod, features in modalities.items()
            },
            'timestamp': datetime.now().isoformat()
        })

        # Alignment step
        trace.append({
            'step': 'cross_modal_alignment',
            'alignments': [
                {
                    'modalities': [a.modality1.value, a.modality2.value],
                    'similarity': a.similarity
                }
                for a in alignments
            ],
            'timestamp': datetime.now().isoformat()
        })

        # Integration step
        trace.append({
            'step': 'integration',
            'integrated_items': len(integration),
            'timestamp': datetime.now().isoformat()
        })

        # Response generation step
        trace.append({
            'step': 'response_generation',
            'response': response,
            'timestamp': datetime.now().isoformat()
        })

        return trace

    def _process_text(self, text: str) -> List[Dict[str, Any]]:
        """Process text modality."""
        # Simple text processing for now
        return [{'text': text, 'timestamp': datetime.now().isoformat()}]

    def _process_images(self, images: List[str]) -> List[Dict[str, Any]]:
        """Process image modality."""
        # Simple image processing for now
        return [{
            'image': image,
            'timestamp': datetime.now().isoformat()
        } for image in images]

    def _process_audio(self, audio: List[str]) -> List[Dict[str, Any]]:
        """Process audio modality."""
        # Simple audio processing for now
        return [{
            'audio': audio_file,
            'timestamp': datetime.now().isoformat()
        } for audio_file in audio]

    def _process_video(self, video: List[str]) -> List[Dict[str, Any]]:
        """Process video modality."""
        # Simple video processing for now
        return [{
            'video': video_file,
            'timestamp': datetime.now().isoformat()
        } for video_file in video]

    def _process_structured(self, structured: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process structured data modality."""
        # Simple structured data processing for now
        return [{
            'structured': structured,
            'timestamp': datetime.now().isoformat()
        }]
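

# Minimal usage sketch (illustrative only). It assumes the ReasoningStrategy /
# StrategyResult interfaces from .base behave as used above; the context keys
# ('text', 'images', 'structured') follow _process_modalities, and the file
# path is hypothetical. Because the placeholder per-modality processors share
# few feature values, cross-modal similarity (and therefore confidence) is
# expected to be low here. Run as a module (python -m <package>.<module>) so
# the relative import resolves.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        strategy = MultimodalStrategy(config={'min_similarity': 0.2})
        result = await strategy.reason(
            query="Summarize the scene",
            context={
                'text': "A red car parked by the beach",
                'images': ["beach_scene.jpg"],  # hypothetical path
                'structured': {'location': 'beach', 'object': 'car'}
            }
        )
        print(result.success, result.confidence)
        print(result.answer)

    asyncio.run(_demo())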