import os
import torch
import numpy as np
import lightning.pytorch as pl
import gradio as gr
import imageio
import random
import matplotlib.pyplot as plt
import cv2
import skdim

from torch.utils.data import Dataset, DataLoader

from PIL import Image
from matplotlib import cm
from safetensors.torch import save_file, load_file
from sklearn.cluster import AgglomerativeClustering
from sklearn.manifold import TSNE
from sklearn.neighbors import KDTree
from sklearn.preprocessing import StandardScaler

from minimal_script import EmbeddingNetwork, closest_interval, adj_size, PLModule


class PredictDataset(Dataset):
    """Dataset of all images in a directory, optionally sub-sampled, for inference."""

    def __init__(self, data_dir, sample=None):
        self.image_paths = []
        extensions = ('jpg', 'jpeg', 'png', 'tif', 'webp')
        for fname in sorted(os.listdir(data_dir)):
            if any(fname.lower().endswith(ext) for ext in extensions):
                self.image_paths.append(os.path.join(data_dir, fname))
        if sample:
            # Clamp to the folder size so random.sample does not raise when
            # fewer images are available than requested.
            self.image_paths = random.sample(self.image_paths, min(sample, len(self.image_paths)))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        image = imageio.v3.imread(path).copy()
        image = torch.from_numpy(image).permute(2, 0, 1)
        # Resize/crop to the model's expected resolution, then scale pixels to [-1, 1].
        processed = closest_interval(adj_size(image, 1024))
        processed = 2 * (processed / 255) - 1
        return processed.detach(), path

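# Quick sanity check (a sketch, not executed; the folder path is a placeholder):
#
#     ds = PredictDataset('/path/to/images', sample=4)
#     tensor, path = ds[0]
#     print(path, tuple(tensor.shape), float(tensor.min()), float(tensor.max()))
#
# Pixel values should land in [-1, 1]; the spatial size is whatever
# adj_size/closest_interval produce for the 1024 target.
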
def explore_embedding_space(embeddings, image_paths, model):
    """
    Create a Gradio interface for exploring N-dimensional image embeddings.

    Args:
        embeddings: NumPy array of shape [B, N]
        image_paths: List of B image file paths
        model: Embedding model, used to compute input-gradient heatmaps
    """
    assert len(embeddings) == len(image_paths), "Mismatch between embeddings and image paths"
    assert embeddings.ndim == 2, "Embeddings should be 2-dimensional"

    # Per-dimension ranges drive the slider bounds and step sizes.
    min_vals = embeddings.min(axis=0)
    max_vals = embeddings.max(axis=0)
    ranges = max_vals - min_vals

    # KD-tree for fast nearest-neighbour lookups in embedding space.
    tree = KDTree(embeddings)

    # Start exploration at the centroid of the embedding cloud.
    initial_point = embeddings.mean(axis=0).tolist()

    # One slider per embedding dimension; rendered later inside the Blocks layout.
    sliders = []
    for i in range(embeddings.shape[1]):
        slider = gr.Slider(
            float(min_vals[i]),
            float(max_vals[i]),
            value=float(initial_point[i]),
            step=float(ranges[i]) / 100,
            label=f"Dimension {i + 1}"
        )
        sliders.append(slider)
    def compute_gradient_heatmap(image_path):
        """Saliency map: gradient of the embedding norm w.r.t. the input pixels."""
        img = imageio.v3.imread(image_path).copy()
        img = torch.from_numpy(img).permute(2, 0, 1)
        img_tensor = closest_interval(adj_size(img, 1024)).unsqueeze(0)
        img_tensor = 2 * (img_tensor / 255) - 1

        # Match the model's device/dtype (it is moved to CUDA float16 before launch),
        # then track gradients on the resulting input tensor.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        img_tensor = img_tensor.to(device).to(torch.float16)
        img_tensor.requires_grad_(True)

        with torch.enable_grad():
            embd = model(img_tensor)
            norm = embd.norm(p=2, dim=1).sum()
            grad = torch.autograd.grad(norm, img_tensor, retain_graph=False)[0]

        # Per-pixel gradient magnitude across the channel dimension.
        grad_mag = grad.squeeze(0).norm(dim=0).detach().cpu().numpy()

        # Normalise to [0, 1] for colour mapping; guard against a constant map.
        grad_min, grad_max = grad_mag.min(), grad_mag.max()
        if grad_max > grad_min:
            grad_norm = (grad_mag - grad_min) / (grad_max - grad_min)
        else:
            grad_norm = grad_mag * 0

        # Drop the alpha channel from the jet colormap output.
        heatmap = cm.jet(grad_norm)[..., :3]
        return heatmap
    def overlay_heatmap(original_img, heatmap, alpha=0.4):
        """Blend the heatmap over the original image."""
        heatmap_img = Image.fromarray((heatmap * 255).astype(np.uint8))
        # Resize to the original image so the two can be blended pixel-for-pixel.
        heatmap_img = heatmap_img.resize(original_img.size)

        blended = Image.blend(original_img, heatmap_img, alpha)
        return blended
    def get_overlay_image(image_path):
        """Return the image with its gradient heatmap blended on top."""
        img = Image.open(image_path).convert('RGB')
        heatmap = compute_gradient_heatmap(image_path)
        return overlay_heatmap(img, heatmap)
    def add_caption_to_image(image, caption):
        """Add a text caption on a black bar below the image."""
        if isinstance(image, Image.Image):
            img = np.array(image)
        else:
            img = image.copy()

        # Extend the image downwards with a black bar to hold the caption.
        bar_height = 30
        img = cv2.copyMakeBorder(img, 0, bar_height, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])

        # Centre the caption horizontally inside the bar.
        font = cv2.FONT_HERSHEY_SIMPLEX
        text_size = cv2.getTextSize(caption, font, 0.5, 1)[0]
        text_x = (img.shape[1] - text_size[0]) // 2
        text_y = img.shape[0] - 10
        cv2.putText(img, caption, (text_x, text_y), font, 0.5, (255, 255, 255), 1)

        return Image.fromarray(img)
    def find_nearby_images(*point):
        point = np.array(point).reshape(1, -1)
        # Query the KD-tree for the embeddings closest to the slider position.
        k = min(8, len(image_paths))
        distances, indices = tree.query(point, k=k)
        indices = indices[0]
        distances = distances[0]

        paths = [image_paths[i] for i in indices]
        images_with_gradients = [get_overlay_image(p) for p in paths]

        # Caption each thumbnail with its distance from the query point.
        final_images = []
        for img, dist in zip(images_with_gradients, distances):
            caption = f"Dist: {dist:.2f}"
            final_img = add_caption_to_image(img, caption)
            final_images.append(final_img)

        warning = ""
        if distances[0] > 5.0:
            warning = "⚠️ Nearest image is far (distance={:.2f}). Consider adjusting sliders.".format(distances[0])

        return final_images, warning
    with gr.Blocks() as demo:
        gr.Markdown("## N-Dimensional Embedding Space Explorer")
        gr.Markdown("Adjust sliders to navigate. Images show the gradient of the embedding norm w.r.t. the input.")

        warning = gr.Textbox(label="Status", interactive=False)

        gallery = gr.Gallery(
            label="Nearest Images (Distance Ordered)",
            columns=4,
            object_fit="contain",
            height="auto",
            show_label=True,
        )

        # The sliders were created outside the Blocks context, so render them here.
        with gr.Row():
            for slider in sliders:
                slider.render()

        # Refresh the gallery whenever any slider moves.
        for slider in sliders:
            slider.change(
                find_nearby_images,
                inputs=sliders,
                outputs=[gallery, warning]
            )

        # Populate the gallery once on page load.
        demo.load(
            find_nearby_images,
            inputs=sliders,
            outputs=[gallery, warning]
        )

    return demo


def generate_embeddings(image_folder, mode, model):
    predict_dataset = PredictDataset(image_folder, 5000)
    predict_loader = DataLoader(predict_dataset, batch_size=1, num_workers=5, pin_memory=True)
    trainer = pl.Trainer(accelerator="gpu", logger=False, enable_checkpointing=False, precision="16-mixed")
    # predict_step is expected to return (embedding, path) per batch; stack the
    # embeddings into one [B, N] array and flatten the path batches into a list.
    predictions_0 = trainer.predict(model, predict_loader)
    predictions = torch.cat([pred[0] for pred in predictions_0], dim=0).numpy()
    paths = []
    for pred in predictions_0:
        for i in pred[1]:
            paths.append(i)
    if mode == 'Grouping':
        # Estimate the intrinsic dimensionality of the embedding cloud with
        # several estimators from scikit-dimension.
        estimators = [skdim.id.TwoNN(), skdim.id.CorrInt(), skdim.id.DANCo()]
        results = {}

        for est in estimators:
            est.fit(predictions)
            results[type(est).__name__] = est.dimension_

        print("Intrinsic Dimension Estimates:")
        for name, dim in results.items():
            print(f"{name}: {dim:.2f}")
        labels = cluster_embeddings(predictions)

        # Bar chart: mean absolute activation of each embedding dimension.
        row_norms = np.linalg.norm(predictions, axis=1)
        average_norms = np.mean(np.abs(predictions), axis=0)
        plt.figure(figsize=(8, 5))
        plt.bar(range(predictions.shape[1]), average_norms, color='skyblue')
        plt.xlabel('Feature Index (C)')
        plt.ylabel('Average Norm')
        plt.title('Average Norm for Each Feature (Column)')
        plt.xticks(range(predictions.shape[1]))

        plt.savefig('Norms.png')

        # 2-D t-SNE projection of the embeddings, coloured by each row's L2 norm.
        plt.figure(figsize=(8, 6))
        tsne = TSNE(n_components=2, random_state=42)
        reduced_data = tsne.fit_transform(predictions)
        plt.scatter(reduced_data[:, 0], reduced_data[:, 1], c=row_norms, cmap='viridis', s=50, edgecolor='k', label="Data Points")
        plt.colorbar(label='Norm Value')
        plt.xlabel('t-SNE Dimension 1')
        plt.ylabel('t-SNE Dimension 2')
        plt.title('t-SNE Projection of Embeddings, Coloured by Norm')
        plt.legend()
        plt.grid(True)
        plt.axis('equal')

        plt.savefig('Groups.png')
        unique_clusters = np.unique(labels)

        with gr.Blocks() as demo:
            gr.Markdown("## Explore Image Clusters by Style")

            cluster_selector = gr.Dropdown(choices=unique_clusters.tolist(), label="Select Cluster to Explore")

            image_gallery = gr.Gallery(label="Sample Images from Selected Cluster")

            def explore_clusters(cluster_idx):
                # Gather the paths assigned to the chosen cluster and show up to 50 of them.
                cluster_images = [paths[i] for i in range(len(labels)) if labels[i] == cluster_idx]
                images = [Image.open(img_path) for img_path in cluster_images[:50]]
                return images

            cluster_selector.change(fn=explore_clusters, inputs=cluster_selector, outputs=image_gallery)

        demo.launch()
    elif mode == 'Explore':
        demo = explore_embedding_space(predictions, paths, model.to('cuda').to(torch.float16))
        demo.launch()


def cluster_embeddings(predictions, distance_threshold=32.0):
    # Ward-linkage agglomerative clustering; the number of clusters is set by the
    # distance threshold rather than fixed up front.
    agg_clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=distance_threshold,
        linkage='ward'
    )
    labels = agg_clustering.fit_predict(predictions)
    return labels

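# Optional sketch (not called anywhere in this script): the 32.0 Ward threshold
# above is a heuristic. One way to sanity-check it, assuming scikit-learn >= 0.24
# for compute_distances, is to build the full merge tree once and inspect the
# largest merge distances; a big jump between consecutive values suggests a
# natural place to cut.
def inspect_merge_distances(embeddings, top=20):
    full_tree = AgglomerativeClustering(n_clusters=1, linkage='ward', compute_distances=True)
    full_tree.fit(embeddings)
    # distances_ holds one Ward merge distance per merge, in merge order.
    return np.sort(full_tree.distances_)[-top:]

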
if __name__ == '__main__':
    folder = 'Enter Images folder name here'

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = PLModule()
    # Load the trained embedder weights into the underlying network.
    state_dict = load_file("Style_Embedder_v3.safetensors")
    model.network.load_state_dict(state_dict)

    generate_embeddings(folder, 'Grouping', model)
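# Usage note (a sketch; the checkpoint name mirrors the block above and the folder
# path is a placeholder): passing mode='Explore' instead of 'Grouping' launches the
# slider-based embedding-space explorer rather than the cluster browser, e.g.
#
#     model = PLModule()
#     model.network.load_state_dict(load_file("Style_Embedder_v3.safetensors"))
#     generate_embeddings('/path/to/images', 'Explore', model)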
|