# pharmacy-mcp / caching.py
# Chris McMaster
# Updates, improvements, new ADR features
# f32824f
import hashlib
import inspect
import json
import logging
import time
from datetime import datetime, timedelta
from functools import wraps
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)


class SimpleCache:
    """Process-local, in-memory key/value cache with per-entry expiry.

    Entries live in ``self.cache`` as ``{"data": value, "expires": deadline}``
    where ``deadline`` is a ``time.monotonic()`` timestamp. The monotonic
    clock is used so expiry is not shifted by wall-clock changes (DST
    transitions, manual or NTP clock adjustments). Expired entries are
    evicted lazily by ``get`` or in bulk by ``clear_expired``.

    ``self.stats`` counts lookups: every ``get`` is either a hit or a miss,
    and ``expired`` additionally counts those misses caused by an entry that
    had outlived its TTL.

    NOTE(review): there is no locking; concurrent use from several threads
    is not guarded.
    """

    def __init__(self, default_ttl: int = 3600):
        """Create an empty cache.

        Args:
            default_ttl: Lifetime in seconds for entries stored without an
                explicit ``ttl``.
        """
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.default_ttl = default_ttl
        self.stats = {"hits": 0, "misses": 0, "expired": 0}

    def _is_expired(self, entry: Dict[str, Any]) -> bool:
        """Return True once the monotonic clock has reached the entry's deadline."""
        return time.monotonic() >= entry['expires']

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or None on a miss.

        An expired entry is deleted and counted both under ``expired`` and as
        a miss. Because None signals a miss, callers cannot distinguish a
        stored None value from an absent key.
        """
        entry = self.cache.get(key)
        if entry is not None:
            if not self._is_expired(entry):
                self.stats["hits"] += 1
                logger.debug("Cache hit for key: %s...", key[:20])
                return entry['data']
            del self.cache[key]
            self.stats["expired"] += 1
            logger.debug("Cache expired for key: %s...", key[:20])
        self.stats["misses"] += 1
        return None

    def set(self, key: str, data: Any, ttl: Optional[int] = None) -> None:
        """Store ``data`` under ``key`` for ``ttl`` seconds.

        Args:
            key: Cache key.
            data: Value to store; it is kept by reference, not copied.
            ttl: Lifetime in seconds. None means ``default_ttl``. An explicit
                0 (or negative) value is honoured and yields an entry that is
                already expired, instead of silently falling back to the
                default as a truthiness test would.
        """
        if ttl is None:
            ttl = self.default_ttl
        self.cache[key] = {
            'data': data,
            'expires': time.monotonic() + ttl,
        }
        logger.debug("Cached data for key: %s... (TTL: %ss)", key[:20], ttl)

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics for monitoring.

        ``hit_rate`` is hits / (hits + misses), rounded to 3 places. The
        ``expired`` counter is deliberately excluded from the denominator:
        each expired lookup is already counted as a miss, so adding it again
        would count that lookup twice. ``cache_size`` includes entries that
        have expired but have not yet been evicted.
        """
        lookups = self.stats["hits"] + self.stats["misses"]
        hit_rate = self.stats["hits"] / lookups if lookups else 0
        return {
            **self.stats,
            "hit_rate": round(hit_rate, 3),
            "cache_size": len(self.cache),
        }

    def clear_expired(self) -> int:
        """Manually clear expired entries and return count cleared.

        Entries removed here are not added to ``stats["expired"]``, which only
        counts expirations discovered during ``get``.
        """
        # Collect first: deleting while iterating the dict would raise.
        expired_keys = [
            key for key, entry in self.cache.items()
            if self._is_expired(entry)
        ]
        for key in expired_keys:
            del self.cache[key]
        return len(expired_keys)


# Module-wide shared cache instance (default TTL of 3600 seconds).
api_cache = SimpleCache()
def generate_cache_key(*args: Any, **kwargs: Any) -> str:
    """Build a deterministic MD5 hex digest identifying a set of call arguments.

    Positional arguments that are JSON scalars (str, int, float, bool, None)
    are serialized unchanged; any other positional argument is first turned
    into its ``str()`` form so arbitrary objects cannot break serialization.
    Keyword arguments are sorted by name, so their order does not affect
    the key.

    If JSON serialization fails for any reason, the digest is instead taken
    over the plain f-string rendering of ``args`` and ``kwargs``.
    """
    scalar_types = (str, int, float, bool, type(None))
    try:
        normalized = [a if isinstance(a, scalar_types) else str(a) for a in args]
        payload = json.dumps(
            [normalized, sorted(kwargs.items())], sort_keys=True, default=str
        )
    except Exception:
        # Best-effort fallback so key generation never blocks a call.
        payload = f"{args}_{kwargs}"
    return hashlib.md5(payload.encode()).hexdigest()
def with_caching(ttl: int = 3600):
    """Decorator factory that memoizes a function's results in ``api_cache``.

    Both regular functions and ``async def`` coroutine functions are
    supported. For coroutine functions the *awaited* result is cached;
    caching the coroutine object itself would hand the next caller an
    already-awaited coroutine.

    Cache keys combine the function's module and qualified name with a hash
    of its arguments from ``generate_cache_key``, so same-named functions in
    different modules or classes do not share entries.

    Results that are None are not stored: ``api_cache.get`` returns None on
    a miss, so a stored None could never be served as a hit and would only
    inflate the hit statistics. Such calls simply run again next time.

    NOTE(review): cached values are returned by reference; a caller that
    mutates a returned object also mutates the cached value.

    Args:
        ttl: Lifetime of each cached result, in seconds.

    Returns:
        A decorator that wraps a function with result caching.
    """
    def decorator(func):
        key_prefix = f"{func.__module__}.{getattr(func, '__qualname__', func.__name__)}"

        def make_key(args, kwargs):
            """Build the cache key for one call of ``func``."""
            return f"{key_prefix}:{generate_cache_key(*args, **kwargs)}"

        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                cache_key = make_key(args, kwargs)
                cached_result = api_cache.get(cache_key)
                if cached_result is not None:
                    return cached_result
                result = await func(*args, **kwargs)
                if result is not None:
                    api_cache.set(cache_key, result, ttl)
                return result
            return async_wrapper

        @wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = make_key(args, kwargs)
            cached_result = api_cache.get(cache_key)
            if cached_result is not None:
                return cached_result
            result = func(*args, **kwargs)
            if result is not None:
                api_cache.set(cache_key, result, ttl)
            return result
        return wrapper
    return decorator