"""Train ``MoEModel`` for a fixed number of epochs, then report test accuracy.

The script runs at import time: it loads the train/test loaders via
``load_data()``, optimizes the model with Adam against a cross-entropy loss,
and finally prints the classification accuracy on the test loader.

Each batch from either loader is unpacked as
``(vision_input, audio_input, sensor_input, labels)``.
"""
import torch
import torch.nn as nn  # required by nn.CrossEntropyLoss below
import torch.optim as optim
from torch.utils.data import DataLoader  # not used directly in this script

from models.moe_model import MoEModel
from utils.data_loader import load_data

# Load data
train_loader, test_loader = load_data()

# Initialize model, loss function, and optimizer
# NOTE(review): input_dim=512 / num_experts=3 presumably match the feature
# width and the three input modalities — confirm against MoEModel.
model = MoEModel(input_dim=512, num_experts=3)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10):
    model.train()
    loss = None  # stays None if the loader yields no batches this epoch
    for vision_input, audio_input, sensor_input, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(vision_input, audio_input, sensor_input)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    # The reported value is the loss of the LAST batch of the epoch, not an
    # epoch average. Skip the report when there was no batch to measure.
    if loss is not None:
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Evaluation
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for vision_input, audio_input, sensor_input, labels in test_loader:
        outputs = model(vision_input, audio_input, sensor_input)
        # Index of the largest score along dim 1 is the predicted class.
        # `.data` is unnecessary here: gradients are already disabled.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Guard against an empty test loader, which would otherwise divide by zero.
if total:
    print(f"Accuracy: {100 * correct / total}%")
else:
    print("Accuracy: n/a (test loader yielded no samples)")