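"""Benchmark approved Ollama models.

Loads model names from approved_models.json, pulls any that are missing
(after asking for confirmation), runs a generation test against each model
on the configured Ollama server, and writes detailed and summary JSON
results to the model_test_results directory.
"""
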
import asyncio
import json
import logging
import os
import time
from datetime import datetime
from typing import Any, Dict, List

from ollama import AsyncClient, ResponseError
from tqdm import tqdm
# Configure logging with a detailed format, writing to both a file and the console
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('model_testing.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Ollama client configuration
OLLAMA_HOST = "http://20.7.14.25:11434"
client = AsyncClient(host=OLLAMA_HOST)

# Create the results directory if it doesn't exist
RESULTS_DIR = "model_test_results"
os.makedirs(RESULTS_DIR, exist_ok=True)

async def check_model_exists(model_name: str) -> bool:
    """Check whether a model has already been pulled."""
    try:
        # client.show() succeeds only if the model is present locally
        await client.show(model_name)
        return True
    except ResponseError:
        return False
    except Exception as e:
        logger.error(f"Error checking if model {model_name} exists: {e}")
        return False
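
# Expected approved_models.json layout (inferred from the loader below;
# each entry is a list whose first element is the model name):
#   {"approved_models": [["<model-name>", ...], ...]}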
async def load_approved_models() -> List[str]:
    """Load approved model names from the JSON file."""
    try:
        with open('approved_models.json', 'r') as f:
            data = json.load(f)
        models = [model[0] for model in data['approved_models']]
        logger.info(f"Successfully loaded {len(models)} models from approved_models.json")
        return models
    except Exception as e:
        logger.error(f"Error loading approved models: {e}")
        return []

async def pull_model(model_name: str) -> bool:
    """Pull a model from Ollama."""
    try:
        logger.info(f"Starting to pull model: {model_name}")
        start_time = time.time()
        await client.pull(model_name)
        end_time = time.time()
        logger.info(f"Successfully pulled {model_name} in {end_time - start_time:.2f} seconds")
        return True
    except ResponseError as e:
        logger.error(f"Error pulling model {model_name}: {e}")
        return False
    except Exception as e:
        logger.error(f"Unexpected error while pulling {model_name}: {e}")
        return False

async def check_loaded_models():
    """Log any models currently loaded in memory (does not unload them)."""
    try:
        # Use the equivalent of `ollama ps` to list loaded models
        ps_response = await client.ps()
        if ps_response and getattr(ps_response, 'models', None):
            logger.warning("Found loaded models in memory. Waiting for keep_alive to unload them...")
            # Just log the loaded models; they will be unloaded once their keep_alive expires
            for model in ps_response.models:
                if model.name:
                    logger.info(f"Model currently loaded in memory: {model.name}")
                    logger.debug(
                        f"Model details: size={model.size}, vram={model.size_vram}, "
                        f"params={model.details.parameter_size}"
                    )
    except Exception as e:
        logger.error(f"Error checking loaded models: {e}")

async def test_model(model_name: str) -> Dict[str, Any]:
    """Test a model and collect performance stats."""
    stats = {
        "model_name": model_name,
        "timestamp": datetime.now().isoformat(),
        "success": False,
        "error": None,
        "performance": {},
        "model_info": {}
    }
    try:
        logger.info(f"Starting performance test for model: {model_name}")
        # Test with a comprehensive prompt that should generate a longer response
        prompt = """You are a creative writing assistant. Write a short story about a futuristic city where:
1. The city is powered by a mysterious energy source
2. The inhabitants have developed unique abilities
3. There's a hidden conflict between different factions
4. The protagonist discovers a shocking truth about the city's origins
Make the story engaging and include vivid descriptions of the city's architecture and technology."""

        # First, generate a tiny response to make sure the model is loaded.
        # Ollama expects `num_predict` (not `max_tokens`) in options, and
        # `keep_alive` is a top-level argument rather than an option.
        await client.generate(
            model=model_name,
            prompt="test",
            stream=False,
            options={"num_predict": 1},
            keep_alive=30  # keep the model loaded long enough for the ps() check below
        )

        # Get model info while it's loaded
        ps_response = await client.ps()
        if ps_response and getattr(ps_response, 'models', None):
            model_found = False
            for model in ps_response.models:
                if model.name == model_name:
                    model_found = True
                    stats["model_info"] = {
                        "size": model.size,
                        "size_vram": model.size_vram,
                        "parameter_size": model.details.parameter_size,
                        "quantization_level": model.details.quantization_level,
                        "format": model.details.format,
                        "family": model.details.family
                    }
                    logger.info(f"Found model info for {model_name}: {stats['model_info']}")
                    break
            if not model_found:
                logger.warning(
                    f"Model {model_name} not found in ps response. "
                    f"Available models: {[m.name for m in ps_response.models]}"
                )
        else:
            logger.warning("No models found in ps response")

        start_time = time.time()
        # Now generate the full response
        response = await client.generate(
            model=model_name,
            prompt=prompt,
            stream=False,
            options={
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 40,
                "num_predict": 1000,    # Ollama's name for the max-token limit
                "repeat_penalty": 1.0,  # Ollama's name for the repetition penalty
                "seed": 42
            },
            keep_alive=0  # unload the model immediately after generation
        )
        end_time = time.time()

        # Calculate performance metrics. Word-split counts are only a rough
        # approximation of token counts; the eval_* fields returned by Ollama
        # are recorded below for precise server-side measurements.
        total_tokens = len(response.get("response", "").split())
        total_time = end_time - start_time
        tokens_per_second = total_tokens / total_time if total_time > 0 else 0
        prompt_tokens = len(prompt.split())
        # The response text contains only generated output (not the prompt),
        # so it is itself the generation count
        generation_tokens = total_tokens

        stats["performance"] = {
            "response_time": total_time,
            "total_tokens": total_tokens,
            "tokens_per_second": tokens_per_second,
            "prompt_tokens": prompt_tokens,
            "generation_tokens": generation_tokens,
            "generation_tokens_per_second": generation_tokens / total_time if total_time > 0 else 0,
            "response": response.get("response", ""),
            "eval_count": response.get("eval_count", 0),
            "eval_duration": response.get("eval_duration", 0),
            "prompt_eval_duration": response.get("prompt_eval_duration", 0),
            "total_duration": response.get("total_duration", 0),
        }
        stats["success"] = True
        logger.info(f"Successfully tested {model_name}: {tokens_per_second:.2f} tokens/second")
    except Exception as e:
        stats["error"] = str(e)
        logger.error(f"Error testing model {model_name}: {e}")
    return stats

async def save_results(results: List[Dict[str, Any]], timestamp: str):
    """Save results in multiple formats."""
    # Save detailed results
    detailed_path = os.path.join(RESULTS_DIR, f"model_stats_{timestamp}.json")
    with open(detailed_path, 'w') as f:
        json.dump(results, f, indent=2)
    logger.info(f"Saved detailed results to {detailed_path}")

    # Save summary results
    summary = []
    for result in results:
        if result["success"]:
            perf = result["performance"]
            model_info = result["model_info"]
            summary.append({
                "model_name": result["model_name"],
                "model_size": model_info.get("size", 0),
                "vram_size": model_info.get("size_vram", 0),
                "parameter_size": model_info.get("parameter_size", ""),
                "quantization": model_info.get("quantization_level", ""),
                "tokens_per_second": perf["tokens_per_second"],
                "generation_tokens_per_second": perf["generation_tokens_per_second"],
                "total_tokens": perf["total_tokens"],
                "response_time": perf["response_time"],
                "success": result["success"]
            })
    summary_path = os.path.join(RESULTS_DIR, f"model_stats_summary_{timestamp}.json")
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    logger.info(f"Saved summary results to {summary_path}")

    # Log top performers
    successful_results = [r for r in results if r["success"]]
    if successful_results:
        top_performers = sorted(
            successful_results,
            key=lambda x: x["performance"]["tokens_per_second"],
            reverse=True
        )[:5]
        logger.info("\nTop 5 performers by tokens per second:")
        for r in top_performers:
            model_info = r["model_info"]
            logger.info(f"{r['model_name']}:")
            logger.info(f" Tokens/second: {r['performance']['tokens_per_second']:.2f}")
            logger.info(f" VRAM Usage: {model_info.get('size_vram', 0) / 1024 / 1024 / 1024:.2f} GB")
            logger.info(f" Parameter Size: {model_info.get('parameter_size', 'N/A')}")
            logger.info(f" Quantization: {model_info.get('quantization_level', 'N/A')}")
            logger.info(" " + "-" * 30)

async def main():
    """Main function to run the model testing process."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    logger.info("Starting model testing process")

    # Load approved models
    models = await load_approved_models()
    if not models:
        logger.error("No models loaded. Exiting.")
        return

    # Log any models that are already loaded in memory (they are unloaded
    # automatically once their keep_alive window expires)
    await check_loaded_models()

    # Check which models need to be pulled
    models_to_pull = []
    for model in models:
        if not await check_model_exists(model):
            models_to_pull.append(model)

    if models_to_pull:
        logger.info(f"Found {len(models_to_pull)} models that need to be pulled:")
        for model in models_to_pull:
            logger.info(f"- {model}")

        # Ask the user whether to pull the missing models
        while True:
            response = input("\nDo you want to pull the missing models? (yes/no): ").lower()
            if response in ['yes', 'no']:
                break
            print("Please answer 'yes' or 'no'")

        if response == 'yes':
            # Pull missing models with a progress bar
            logger.info("Starting model pulling phase")
            for model in tqdm(models_to_pull, desc="Pulling models"):
                await pull_model(model)
        else:
            logger.warning("Skipping model pulling. Some models may not be available for testing.")
            # Filter out models that weren't pulled
            models = [model for model in models if model not in models_to_pull]
            if not models:
                logger.error("No models available for testing. Exiting.")
                return
    else:
        logger.info("All models are already pulled. Skipping pulling phase.")

    # Test all models with a progress bar
    logger.info("Starting model testing phase")
    results = []
    for model in tqdm(models, desc="Testing models"):
        # Check for any loaded models before testing
        await check_loaded_models()
        stats = await test_model(model)
        results.append(stats)
        # Save intermediate results after each model
        await save_results(results, timestamp)
        # Sleep between model tests to allow the server to clean up
        logger.info("Waiting 3 seconds before next model test...")
        await asyncio.sleep(3)

    # Save final results
    await save_results(results, timestamp)

    # Log summary
    successful_tests = sum(1 for r in results if r["success"])
    logger.info(f"Model testing completed. {successful_tests}/{len(models)} models tested successfully")


if __name__ == "__main__":
    asyncio.run(main())