Spaces:
Running
on
Zero
Running
on
Zero
| # Project EmbodiedGen | |
| # | |
| # Copyright (c) 2025 Horizon Robotics. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |
| # implied. See the License for the specific language governing | |
| # permissions and limitations under the License. | |
| import os | |
| import clip | |
| import numpy as np | |
| import pytorch_lightning as pl | |
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import snapshot_download | |
| from PIL import Image | |
| class AestheticPredictor: | |
| """Aesthetic Score Predictor. | |
| Checkpoints from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main | |
| Args: | |
| clip_model_dir (str): Path to the directory of the CLIP model. | |
| sac_model_path (str): Path to the pre-trained SAC model. | |
| device (str): Device to use for computation ("cuda" or "cpu"). | |
| """ | |
| def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"): | |
| self.device = device | |
| if clip_model_dir is None: | |
| model_path = snapshot_download( | |
| repo_id="xinjjj/RoboAssetGen", allow_patterns="aesthetic/*" | |
| ) | |
| suffix = "aesthetic" | |
| model_path = snapshot_download( | |
| repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*" | |
| ) | |
| clip_model_dir = os.path.join(model_path, suffix) | |
| if sac_model_path is None: | |
| model_path = snapshot_download( | |
| repo_id="xinjjj/RoboAssetGen", allow_patterns="aesthetic/*" | |
| ) | |
| suffix = "aesthetic" | |
| model_path = snapshot_download( | |
| repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*" | |
| ) | |
| sac_model_path = os.path.join( | |
| model_path, suffix, "sac+logos+ava1-l14-linearMSE.pth" | |
| ) | |
| self.clip_model, self.preprocess = self._load_clip_model( | |
| clip_model_dir | |
| ) | |
| self.sac_model = self._load_sac_model(sac_model_path, input_size=768) | |
| class MLP(pl.LightningModule): # noqa | |
| def __init__(self, input_size): | |
| super().__init__() | |
| self.layers = nn.Sequential( | |
| nn.Linear(input_size, 1024), | |
| nn.Dropout(0.2), | |
| nn.Linear(1024, 128), | |
| nn.Dropout(0.2), | |
| nn.Linear(128, 64), | |
| nn.Dropout(0.1), | |
| nn.Linear(64, 16), | |
| nn.Linear(16, 1), | |
| ) | |
| def forward(self, x): | |
| return self.layers(x) | |
| def normalized(a, axis=-1, order=2): | |
| """Normalize the array to unit norm.""" | |
| l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) | |
| l2[l2 == 0] = 1 | |
| return a / np.expand_dims(l2, axis) | |
| def _load_clip_model(self, model_dir: str, model_name: str = "ViT-L/14"): | |
| """Load the CLIP model.""" | |
| model, preprocess = clip.load( | |
| model_name, download_root=model_dir, device=self.device | |
| ) | |
| return model, preprocess | |
| def _load_sac_model(self, model_path, input_size): | |
| """Load the SAC model.""" | |
| model = self.MLP(input_size) | |
| ckpt = torch.load(model_path) | |
| model.load_state_dict(ckpt) | |
| model.to(self.device) | |
| model.eval() | |
| return model | |
| def predict(self, image_path): | |
| """Predict the aesthetic score for a given image. | |
| Args: | |
| image_path (str): Path to the image file. | |
| Returns: | |
| float: Predicted aesthetic score. | |
| """ | |
| pil_image = Image.open(image_path) | |
| image = self.preprocess(pil_image).unsqueeze(0).to(self.device) | |
| with torch.no_grad(): | |
| # Extract CLIP features | |
| image_features = self.clip_model.encode_image(image) | |
| # Normalize features | |
| normalized_features = self.normalized( | |
| image_features.cpu().detach().numpy() | |
| ) | |
| # Predict score | |
| prediction = self.sac_model( | |
| torch.from_numpy(normalized_features) | |
| .type(torch.FloatTensor) | |
| .to(self.device) | |
| ) | |
| return prediction.item() | |
| if __name__ == "__main__": | |
| # Configuration | |
| img_path = "apps/assets/example_image/sample_00.jpg" | |
| # Initialize the predictor | |
| predictor = AestheticPredictor() | |
| # Predict the aesthetic score | |
| score = predictor.predict(img_path) | |
| print("Aesthetic score predicted by the model:", score) | |