| |
|
| | import os, json, numpy as np, torch, torch.nn as nn
|
| |
|
# Supported objectives. score_wings maps an objective name to its position in
# this list (OBJECTIVES.index), and that position is the obj_id that
# MLPSelector.forward one-hot encodes -- so the order must match training.
OBJECTIVES = [
    "min_cd",
    "max_cl",
    "max_ld",
]
|
| |
|
class MLPSelector(nn.Module):
    """Score feature rows for a given objective and airfoil.

    Each input row is concatenated with a one-hot encoding of its objective id
    and a learned airfoil embedding, then passed through a ReLU MLP with two
    hidden layers that emits a single logit per row.

    Args:
        in_dim: Number of input features per row.
        n_airfoils: Number of distinct airfoil ids (embedding table size).
        obj_dim: Number of objectives, i.e. the width of the one-hot encoding.
        af_embed_dim: Width of the airfoil embedding.
        hidden: Width of both hidden layers.
    """

    def __init__(self, in_dim: int, n_airfoils: int, obj_dim: int = 3, af_embed_dim: int = 8, hidden: int = 128):
        super().__init__()
        # Plain int attribute (not a parameter/buffer), so the state_dict keys
        # are unchanged and existing checkpoints still load.
        self.obj_dim = obj_dim
        self.af_emb = nn.Embedding(n_airfoils, af_embed_dim)
        self.net = nn.Sequential(
            nn.Linear(in_dim + obj_dim + af_embed_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x, obj_id, af_id):
        """Return one raw (pre-sigmoid) score per row of *x*.

        Args:
            x: 2-D feature tensor of shape (B, in_dim).
            obj_id: Objective index per row (length-B integer tensor), or a
                single int applied to every row; values must be < obj_dim.
            af_id: Airfoil index per row (length-B integer tensor).

        Returns:
            Tensor of shape (B,).
        """
        batch = x.size(0)
        # Size the one-hot from obj_dim (previously a literal 3) so it always
        # matches the first Linear layer. Build it with x's dtype and device so
        # torch.cat does not mix dtypes or devices.
        obj_oh = x.new_zeros((batch, self.obj_dim))
        obj_oh[torch.arange(batch, device=x.device), obj_id] = 1.0
        af_e = self.af_emb(af_id)
        z = torch.cat([x, obj_oh, af_e], dim=1)
        return self.net(z).squeeze(1)
|
| |
|
def load_selector(local_dir=".", device="cpu"):
    """Load a trained MLPSelector and its feature-standardization statistics.

    ``best.pt`` in *local_dir* is preferred; ``last.pt`` is used when it is
    absent.

    Args:
        local_dir: Directory holding the checkpoint (str or path-like).
        device: Target device, used both as ``map_location`` for
            ``torch.load`` and for ``model.to``.

    Returns:
        ``(model, cfg)``: the selector in eval mode on *device*, and a dict
        with ``"in_dim"``, ``"n_airfoils"`` and ``"feat_stats"``
        (``"means"`` / ``"stds"`` as float32 NumPy arrays).

    Raises:
        FileNotFoundError: if neither ``best.pt`` nor ``last.pt`` exists.
    """
    ckpt_path = None
    for fname in ("best.pt", "last.pt"):
        candidate = os.path.join(local_dir, fname)
        if os.path.exists(candidate):
            ckpt_path = candidate
            break
    if ckpt_path is None:
        # str() so a pathlib.Path local_dir does not turn this into a TypeError.
        raise FileNotFoundError("best.pt/last.pt not found in " + str(local_dir))

    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints. On PyTorch versions where weights_only defaults to True,
    # non-tensor entries may be rejected; confirm against the installed version.
    ckpt = torch.load(ckpt_path, map_location=device)
    in_dim = int(ckpt["in_dim"])
    n_airfoils = int(ckpt["n_airfoils"])
    cfg = {
        "in_dim": in_dim,
        "n_airfoils": n_airfoils,
        "feat_stats": {
            "means": np.array(ckpt["feat_stats"]["means"], dtype=np.float32),
            "stds": np.array(ckpt["feat_stats"]["stds"], dtype=np.float32),
        },
    }

    state = ckpt["model"]
    # Rebuild with the architecture actually stored in the weights rather than
    # MLPSelector's defaults, so checkpoints trained with a non-default
    # obj_dim / af_embed_dim / hidden still load. If the expected keys are
    # missing or inconsistent, fall back to the defaults and let
    # load_state_dict report the mismatch as before.
    arch = {}
    if "af_emb.weight" in state and "net.0.weight" in state:
        af_embed_dim = int(state["af_emb.weight"].shape[1])
        hidden, first_in = (int(d) for d in state["net.0.weight"].shape)
        obj_dim = first_in - in_dim - af_embed_dim
        if obj_dim > 0:
            arch = {"obj_dim": obj_dim, "af_embed_dim": af_embed_dim, "hidden": hidden}

    model = MLPSelector(in_dim, n_airfoils, **arch)
    model.load_state_dict(state)
    model.to(device).eval()
    return model, cfg
|
| |
|
def standardize(X_raw: np.ndarray, means: np.ndarray, stds: np.ndarray) -> np.ndarray:
    """Impute non-finite entries with *means*, then center and scale.

    NaN and +/-inf entries of *X_raw* are replaced by the corresponding value
    of *means*, so they standardize to 0. Wherever *stds* is exactly 0, the
    divisor is 1 instead, which avoids division by zero.

    Args:
        X_raw: Raw feature array.
        means: Centering values; must broadcast against *X_raw* (presumably
            one per feature column -- confirm against callers).
        stds: Scaling values, broadcast the same way as *means*.

    Returns:
        ``(X_imputed - means) / safe_stds``.
    """
    finite_mask = np.isfinite(X_raw)
    imputed = np.where(finite_mask, X_raw, means)
    safe_scale = np.where(stds == 0, 1.0, stds)
    centered = imputed - means
    return centered / safe_scale
|
| |
|
def score_wings(model, X_std: np.ndarray, airfoil_id: int, objective: str, device="cpu"):
    """Return sigmoid scores from *model* for every row of *X_std*.

    All rows are scored with the same objective and the same airfoil id.

    Args:
        model: Module called as ``model(x, obj_ids, af_ids)``.
        X_std: Already-standardized feature rows (converted to float32).
        airfoil_id: Airfoil index shared by every row.
        objective: One of ``OBJECTIVES``; an unknown name raises ValueError
            (from ``list.index``).
        device: Device on which the input tensors are created.

    Returns:
        NumPy array with one score in (0, 1) per row.
    """
    objective_index = OBJECTIVES.index(objective)
    features = torch.tensor(X_std, dtype=torch.float32, device=device)
    n_rows = features.size(0)

    def _constant_ids(value):
        # Same integer id repeated for every row.
        return torch.full((n_rows,), value, dtype=torch.long, device=device)

    obj_ids = _constant_ids(objective_index)
    af_ids = _constant_ids(airfoil_id)
    with torch.no_grad():
        logits = model(features, obj_ids, af_ids)
        scores = torch.sigmoid(logits)
    return scores.cpu().numpy()
|
| |
|