# Hosting-page residue captured with this file (kept as comments so the module parses):
# Samee-ur's picture
# Create app.py
# cc950c6 verified
# app.py
import gradio as gr
import json
import itertools
import numpy as np
import matplotlib.pyplot as plt
from rcwa import Material, Layer, LayerStack, Source, Solver
import openai
import logging
import random
import os
# --- Logging ---
# Configure the root handler at INFO; this module logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# --- API Key ---
# Taken from the environment; stays None when OPENAI_API_KEY is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY")
# --- Constants ---
# Wavelength sweep, in micrometres (the plot axis is labelled in µm and layer
# thicknesses are converted nm -> µm before simulation).
start_wl = 0.32
stop_wl = 0.80
step_wl = 0.01
# Inclusive grid start_wl..stop_wl. np.linspace with an explicit point count is
# used because np.arange with a float step (and `stop_wl + step_wl` as the
# exclusive bound) can, through rounding, emit one extra sample past stop_wl.
wavelengths = np.linspace(start_wl, stop_wl, round((stop_wl - start_wl) / step_wl) + 1)
# Candidate layer materials; each name is passed to rcwa.Material.
materials = ['Si', 'Si3N4', 'SiO2', 'AlN']
# --- Spectrum Simulation ---
def simulate_spectrum(layer_order, thickness_nm=100):
    """Simulate the total transmission of a thin-film stack over ``wavelengths``.

    One rcwa ``Layer`` is built per entry of *layer_order*, each
    ``thickness_nm`` thick. The stack sits between an incident medium with
    n = 1.0 and a Si transmission (substrate) layer, and is solved over the
    module-level ``wavelengths`` grid.

    Args:
        layer_order: iterable of material names, top (incident side) first —
            presumably; the order is passed straight to ``LayerStack``.
        thickness_nm: thickness of every layer, in nanometres.

    Returns:
        The solver's ``'TTot'`` result converted to a plain Python list
        (presumably one value per wavelength), or ``None`` if any step of
        building or solving the stack raised.
    """
    try:
        # All construction lives inside the try so that *any* rcwa failure
        # (e.g. an unknown material name) honours the "return None" contract.
        source = Source(wavelength=start_wl)
        incident_layer = Layer(n=1.0)
        substrate = Layer(material=Material("Si"))
        thickness_um = thickness_nm * 1e-3  # nm -> µm, the wavelength unit
        layers = [Layer(material=Material(m), thickness=thickness_um) for m in layer_order]
        stack = LayerStack(*layers, incident_layer=incident_layer, transmission_layer=substrate)
        solver = Solver(stack, source, (1, 1))
        result = solver.solve(wavelength=wavelengths)
        return np.array(result['TTot']).tolist()
    except Exception as e:
        # Best-effort: callers skip None results, so log and carry on.
        logger.warning("Simulation failed for %s: %s", layer_order, e)
        return None
def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between two equal-length vectors.

    Args:
        vec1: first vector (any array-like).
        vec2: second vector, same length as *vec1*.

    Returns:
        A float in [-1, 1]. If either vector has zero norm, the cosine is
        undefined; 0.0 is returned instead of dividing 0 by 0 (which
        would yield NaN and a RuntimeWarning).

    Raises:
        ValueError: if the vectors' lengths differ (raised by ``np.dot``).
    """
    a = np.asarray(vec1)
    b = np.asarray(vec2)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)
def find_best_permutation(materials, target_spectrum, num_layers=4):
    """Brute-force the layer order whose simulated spectrum best matches a target.

    Every ordered selection of ``num_layers`` entries from *materials*
    (``itertools.permutations``) is simulated with ``simulate_spectrum``.
    Each result is scored against *target_spectrum* with ``cosine_similarity``,
    and the highest score wins. Orders whose simulation fails are skipped
    (``simulate_spectrum`` returns None), and NaN scores never win because NaN
    comparisons are False.

    This function is also exposed to the LLM as a tool, so *materials* and
    *target_spectrum* may come from model-generated arguments.

    Args:
        materials: candidate material names.
        target_spectrum: values to match, on the same grid that
            ``simulate_spectrum`` returns.
        num_layers: number of layers in each candidate stack (default 4).

    Returns:
        ``{"best_order": [...material names...], "cosine_score": float}``.

    Raises:
        ValueError: if no candidate order could be simulated and scored (e.g.
            fewer than ``num_layers`` materials, or every simulation failed).
    """
    best_score, best_order = float("-inf"), None
    for order in itertools.permutations(materials, num_layers):
        spectrum = simulate_spectrum(order)
        if spectrum is None:
            continue
        score = cosine_similarity(spectrum, target_spectrum)
        if score > best_score:
            best_score, best_order = score, order
    if best_order is None:
        raise ValueError(
            f"No {num_layers}-layer order could be simulated from materials {list(materials)!r}"
        )
    return {
        "best_order": list(best_order),
        "cosine_score": float(best_score),
    }
def run_agent_with_spectrum(target_spectrum):
    """Ask the LLM agent to recover the layer order that produced *target_spectrum*.

    The model is forced (via ``tool_choice``) to call ``find_best_permutation``.
    That tool is then executed locally with the model-supplied arguments, and
    the winning order is re-simulated so it can be plotted against the target.

    If anything in that path raises (API error, missing tool call, bad
    arguments, tool failure), the search is instead run directly on the
    module-level ``materials`` and the given target. Its winning order is also
    re-simulated, and the error text is reported in ``raw_response``.

    Args:
        target_spectrum: the spectrum to match, as produced by
            ``simulate_spectrum`` (a list of numbers).

    Returns:
        dict with keys ``true_target``, ``predicted_spectrum`` (list or None
        if re-simulation failed), ``result`` (``find_best_permutation`` output),
        ``tool_call`` (function name + parsed arguments, or None on fallback),
        ``raw_response`` (the API response dump, or ``{"error": ...}``),
        ``system_prompt`` and ``user_prompt``.

    Raises:
        Whatever ``find_best_permutation`` raises in the fallback path.
    """
    tools = [{
        "type": "function",
        "function": {
            "name": "find_best_permutation",
            "description": "Find best layer order to match a transmission spectrum",
            "parameters": {
                "type": "object",
                "properties": {
                    "materials": {"type": "array", "items": {"type": "string"}},
                    "target_spectrum": {"type": "array", "items": {"type": "number"}},
                },
                "required": ["materials", "target_spectrum"],
            },
        },
    }]
    system_prompt = "You are a simulation agent that finds the best optical stack to match a target spectrum."
    # Material list derived from the shared constant so the prompt cannot drift from it.
    user_prompt = (
        f"Match this transmission spectrum with a 4-layer stack of {', '.join(materials)}:\n"
        f"{target_spectrum}"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        response = openai.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            tools=tools,
            # Force the tool call so the reply is always structured arguments.
            tool_choice={"type": "function", "function": {"name": "find_best_permutation"}},
        )
        tool_calls = response.choices[0].message.tool_calls
        if not tool_calls:
            raise RuntimeError("model response contained no tool call")
        tool_call = tool_calls[0]
        # NOTE(review): the model re-types the whole spectrum into the arguments;
        # if it drops values the length mismatch makes the tool raise (and the
        # fallback below runs), but rounded values would silently skew scoring.
        args = json.loads(tool_call.function.arguments)
        result = find_best_permutation(**args)
        predicted_spectrum = simulate_spectrum(result["best_order"])
        tool_call_info = {
            "function": tool_call.function.name,
            "arguments": args,
        }
        raw_response = response.model_dump()
    except Exception as e:
        logger.exception("Agent tool call failed; running the search directly")
        result = find_best_permutation(materials, target_spectrum)
        # Re-simulate the fallback's winner too, so the UI can still plot it.
        predicted_spectrum = simulate_spectrum(result["best_order"])
        tool_call_info = None
        raw_response = {"error": str(e)}
    return {
        "true_target": target_spectrum,
        "predicted_spectrum": predicted_spectrum,
        "result": result,
        "tool_call": tool_call_info,
        "raw_response": raw_response,
        "system_prompt": system_prompt,
        "user_prompt": user_prompt,
    }
def plot_spectra(wavelengths, target, predicted):
    """Plot the target transmission spectrum and, if given, the predicted one.

    The figure is a standalone ``matplotlib.figure.Figure`` rather than one
    created through ``pyplot``. pyplot keeps every figure it creates
    registered in global state until it is explicitly closed, so a server that
    plots on every button click would accumulate figures.

    Args:
        wavelengths: x-axis values, in µm.
        target: y-values of the target spectrum (same length as *wavelengths*).
        predicted: y-values of the predicted spectrum, or None / empty to
            plot the target only. Lists and NumPy arrays are both accepted.

    Returns:
        The ``Figure`` containing the plot.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    ax.plot(wavelengths, target, label="Target Spectrum", color="blue")
    # Explicit None/length check: bool() of a multi-element NumPy array raises.
    if predicted is not None and len(predicted) > 0:
        ax.plot(wavelengths, predicted, label="Predicted Spectrum", color="red", linestyle="--")
    ax.set_xlabel("Wavelength (µm)")
    ax.set_ylabel("Transmission")
    ax.set_title("Spectrum Comparison")
    ax.grid(True)
    ax.legend()
    return fig
with gr.Blocks(title="Optical Thin-Film Stack AI Agent") as demo:
    # In Markdown, a text line immediately followed by `---` renders as a
    # setext H2 heading, so the horizontal rules below are set off by blank lines.
    gr.Markdown("""
# 🧠 Optical Thin-Film Stack AI Agent

This interactive demo shows an **AI agent using a physics-based simulator (RCWA)** to solve an inverse optics problem.
The AI agent calls RCWA to **recover the correct material ordering** of a thin-film stack by matching its optical transmission spectrum to a given input.

---

### 🛡️ Materials in the Stack: **Si** (high-index semiconductor), **Si₃N₄** (medium-index dielectric), **SiO₂** (low-index glass), **AlN** (wide-bandgap insulating ceramic)

---
""")
    gr.Markdown("""
## 🔍 What's Happening Under the Hood

1. A **random 4-layer material stack** is generated from the materials above, where each material has a thickness of 100nm.
2. We simulate its **transmission spectrum** using **RCWA (Rigorous Coupled-Wave Analysis)** — a gold-standard method in computational optics.
3. The AI agent receives this spectrum and is asked: _"What material order would produce this response?"_
4. The AI agent invokes the tool `find_best_permutation(...)` — triggering a brute-force search using RCWA over all possible material orders.
5. The best match is returned, and we show both spectra and a cosine similarity score.

> 🧠 This isn't prompt-tuning. This is **agentic AI**, invoking a **verifiable physical simulator** as a tool.
""")
    run_btn = gr.Button("🎲 Generate & Run")
    true_order_box = gr.Textbox(label="True Layer Order For the 4 materials")
    system_box = gr.Textbox(label="System Message to AI Agent", lines=2)
    prompt_box = gr.Textbox(label="User Prompt to AI Agent", lines=4)
    pred_order_box = gr.Textbox(label="AI Agent Predicted Layer Order")
    score_box = gr.Textbox(label="Cosine Similarity")
    plot_output = gr.Plot(label="Target vs Predicted Spectrum")
    tool_output = gr.Textbox(label="Tool Call", lines=6)
    raw_output = gr.Textbox(label="Raw GPT Response", lines=10)

    def random_run():
        """Generate a random stack, simulate it, and ask the agent to recover it.

        Returns an 8-tuple in the order of the ``outputs`` list wired to
        ``run_btn`` below: true order, system prompt, user prompt, predicted
        order, cosine score, plot, tool call JSON, raw response JSON.
        """
        true_order = random.sample(materials, 4)
        spectrum = simulate_spectrum(true_order)
        if spectrum is None:
            return "Simulation failed", "", "", "", "", None, "", ""
        try:
            result = run_agent_with_spectrum(spectrum)
        except Exception as e:
            # run_agent_with_spectrum can still raise (e.g. when no candidate
            # stack can be simulated); report it in the UI instead of crashing.
            logger.exception("Agent run failed for true order %s", true_order)
            return (
                ", ".join(true_order),
                "", "", "", "",
                None,
                "",
                json.dumps({"error": str(e)}, indent=2),
            )
        plot = plot_spectra(wavelengths, spectrum, result["predicted_spectrum"])
        return (
            ", ".join(true_order),
            result["system_prompt"],
            result["user_prompt"],
            ", ".join(result["result"]["best_order"]),
            round(result["result"]["cosine_score"], 5),
            plot,
            json.dumps(result["tool_call"], indent=2),
            json.dumps(result["raw_response"], indent=2),
        )

    run_btn.click(fn=random_run, inputs=[], outputs=[
        true_order_box,
        system_box,
        prompt_box,
        pred_order_box,
        score_box,
        plot_output,
        tool_output,
        raw_output,
    ])
def _main():
    """Start the Gradio server for the ``demo`` app."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _main()