Spaces:
Runtime error
Runtime error
More responsive
Browse files- app.py +59 -339
- assets/styles.css +124 -0
- components.py +184 -0
- data_utils.py +261 -0
app.py
CHANGED
|
@@ -1,112 +1,73 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import re
|
| 3 |
-
|
| 4 |
import crystal_toolkit.components as ctc
|
| 5 |
import dash
|
| 6 |
import dash_mp_components as dmp
|
| 7 |
import numpy as np
|
| 8 |
-
import pandas as pd
|
| 9 |
import periodictable
|
| 10 |
from crystal_toolkit.settings import SETTINGS
|
| 11 |
from dash import dcc, html
|
| 12 |
from dash.dependencies import Input, Output, State
|
| 13 |
from dash_breakpoints import WindowBreakpoints
|
| 14 |
-
from datasets import concatenate_datasets, load_dataset
|
| 15 |
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
|
| 16 |
from pymatgen.core import Structure
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
|
|
|
|
| 24 |
|
| 25 |
-
|
| 26 |
-
for subset in subsets:
|
| 27 |
-
dataset = load_dataset(
|
| 28 |
-
"LeMaterial/leMat-Bulk",
|
| 29 |
-
subset,
|
| 30 |
-
token=HF_TOKEN,
|
| 31 |
-
columns=[
|
| 32 |
-
"lattice_vectors",
|
| 33 |
-
"species_at_sites",
|
| 34 |
-
"cartesian_site_positions",
|
| 35 |
-
"energy",
|
| 36 |
-
# "energy_corrected", # not yet available in LeMat-Bulk
|
| 37 |
-
"immutable_id",
|
| 38 |
-
"elements",
|
| 39 |
-
"functional",
|
| 40 |
-
"stress_tensor",
|
| 41 |
-
"magnetic_moments",
|
| 42 |
-
"forces",
|
| 43 |
-
# "band_gap_direct", #future release
|
| 44 |
-
# "band_gap_indirect", #future release
|
| 45 |
-
"dos_ef",
|
| 46 |
-
# "charges", #future release
|
| 47 |
-
"functional",
|
| 48 |
-
"chemical_formula_reduced",
|
| 49 |
-
"chemical_formula_descriptive",
|
| 50 |
-
"total_magnetization",
|
| 51 |
-
"entalpic_fingerprint"
|
| 52 |
-
],
|
| 53 |
-
)
|
| 54 |
-
datasets.append(dataset["train"])
|
| 55 |
|
| 56 |
-
|
| 57 |
"chemical_formula_descriptive",
|
| 58 |
"functional",
|
| 59 |
"immutable_id",
|
| 60 |
"energy",
|
| 61 |
]
|
| 62 |
-
|
| 63 |
"chemical_formula_descriptive": "Formula",
|
| 64 |
"functional": "Functional",
|
| 65 |
"immutable_id": "Material ID",
|
| 66 |
"energy": "Energy (eV)",
|
| 67 |
}
|
| 68 |
|
|
|
|
| 69 |
mapping_table_idx_dataset_idx = {}
|
|
|
|
| 70 |
|
| 71 |
map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
|
| 72 |
-
n_elements = len(map_periodic_table)
|
| 73 |
|
| 74 |
-
#
|
| 75 |
-
|
| 76 |
-
dataset =
|
| 77 |
-
train_df = dataset.select_columns(["chemical_formula_descriptive"]).to_pandas()
|
| 78 |
-
|
| 79 |
-
pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
|
| 80 |
-
extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
|
| 81 |
-
extracted["count"] = extracted["count"].replace("", "1").astype(int)
|
| 82 |
-
|
| 83 |
-
wide_df = extracted.reset_index().pivot_table( # Move index to columns for pivoting
|
| 84 |
-
index="level_0", # original row index
|
| 85 |
-
columns="element",
|
| 86 |
-
values="count",
|
| 87 |
-
aggfunc="sum",
|
| 88 |
-
fill_value=0,
|
| 89 |
)
|
| 90 |
|
| 91 |
-
all_elements = [el.symbol for el in periodictable.elements] # full element list
|
| 92 |
-
wide_df = wide_df.reindex(columns=all_elements, fill_value=0)
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
dataset_index = wide_df.values
|
| 96 |
-
|
| 97 |
-
dataset_index = dataset_index / np.sum(dataset_index, axis=1)[:, None]
|
| 98 |
-
dataset_index = (
|
| 99 |
-
dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]
|
| 100 |
-
) # Normalize vectors
|
| 101 |
-
|
| 102 |
-
del train_df, extracted, wide_df
|
| 103 |
-
|
| 104 |
# Initialize the Dash app
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
server = app.server # Expose the server for deployment
|
| 107 |
|
| 108 |
# Define the app layout
|
| 109 |
-
layout = html.Div(
|
| 110 |
[
|
| 111 |
WindowBreakpoints(
|
| 112 |
id="breakpoints",
|
|
@@ -119,178 +80,26 @@ layout = html.Div(
|
|
| 119 |
),
|
| 120 |
html.Div(
|
| 121 |
[
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
),
|
| 128 |
-
],
|
| 129 |
-
id="structure-container",
|
| 130 |
-
style={
|
| 131 |
-
"width": "44%",
|
| 132 |
-
"verticalAlign": "top",
|
| 133 |
-
"boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
|
| 134 |
-
"borderRadius": "10px",
|
| 135 |
-
"backgroundColor": "#f9f9f9",
|
| 136 |
-
"padding": "20px",
|
| 137 |
-
"textAlign": "center",
|
| 138 |
-
"display": "flex",
|
| 139 |
-
"justifyContent": "center",
|
| 140 |
-
"alignItems": "center",
|
| 141 |
-
},
|
| 142 |
-
),
|
| 143 |
-
html.Div(
|
| 144 |
-
id="properties-container",
|
| 145 |
-
style={
|
| 146 |
-
"width": "55%",
|
| 147 |
-
"paddingLeft": "4%",
|
| 148 |
-
"verticalAlign": "top",
|
| 149 |
-
"boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
|
| 150 |
-
"borderRadius": "10px",
|
| 151 |
-
"backgroundColor": "#f9f9f9",
|
| 152 |
-
"padding": "20px",
|
| 153 |
-
"overflow": "auto",
|
| 154 |
-
"maxHeight": "600px",
|
| 155 |
-
"display": "flex",
|
| 156 |
-
"justifyContent": "center",
|
| 157 |
-
"wordWrap": "break-word",
|
| 158 |
-
},
|
| 159 |
-
children=[
|
| 160 |
-
html.Div(
|
| 161 |
-
"Properties will be displayed here",
|
| 162 |
-
style={"textAlign": "center"},
|
| 163 |
-
),
|
| 164 |
-
],
|
| 165 |
-
),
|
| 166 |
],
|
| 167 |
-
|
| 168 |
-
"marginTop": "20px",
|
| 169 |
-
"display": "flex",
|
| 170 |
-
"justifyContent": "space-between", # Ensure the two sections are responsive
|
| 171 |
-
"flexWrap": "wrap",
|
| 172 |
-
},
|
| 173 |
),
|
| 174 |
html.Div(
|
| 175 |
[
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
[
|
| 183 |
-
dmp.MaterialsInput(
|
| 184 |
-
allowedInputTypes=["elements", "formula"],
|
| 185 |
-
hidePeriodicTable=False,
|
| 186 |
-
periodicTableMode="toggle",
|
| 187 |
-
hideWildcardButton=True,
|
| 188 |
-
showSubmitButton=True,
|
| 189 |
-
submitButtonText="Search",
|
| 190 |
-
type="elements",
|
| 191 |
-
id="materials-input",
|
| 192 |
-
),
|
| 193 |
-
],
|
| 194 |
-
id="materials-input-container",
|
| 195 |
-
style={
|
| 196 |
-
"width": "100%",
|
| 197 |
-
},
|
| 198 |
-
),
|
| 199 |
-
],
|
| 200 |
-
style={
|
| 201 |
-
"display": "flex",
|
| 202 |
-
"justifyContent": "center",
|
| 203 |
-
"width": "100%",
|
| 204 |
-
},
|
| 205 |
-
),
|
| 206 |
-
],
|
| 207 |
-
style={
|
| 208 |
-
"width": "48%",
|
| 209 |
-
"verticalAlign": "top",
|
| 210 |
-
},
|
| 211 |
-
),
|
| 212 |
-
html.Div(
|
| 213 |
-
[
|
| 214 |
-
html.Label(
|
| 215 |
-
"Select a row to display the material's structure and properties",
|
| 216 |
-
style={"margin-bottom": "20px"},
|
| 217 |
-
),
|
| 218 |
-
# dcc.Dropdown(
|
| 219 |
-
# id="material-dropdown",
|
| 220 |
-
# options=[], # Empty options initially
|
| 221 |
-
# value=None,
|
| 222 |
-
# ),
|
| 223 |
-
dash.dash_table.DataTable(
|
| 224 |
-
id="table",
|
| 225 |
-
columns=[
|
| 226 |
-
(
|
| 227 |
-
{"name": display_names[col], "id": col}
|
| 228 |
-
if col != "energy"
|
| 229 |
-
else {
|
| 230 |
-
"name": display_names[col],
|
| 231 |
-
"id": col,
|
| 232 |
-
"type": "numeric",
|
| 233 |
-
"format": {"specifier": ".2f"},
|
| 234 |
-
}
|
| 235 |
-
)
|
| 236 |
-
for col in display_columns
|
| 237 |
-
],
|
| 238 |
-
data=[{}],
|
| 239 |
-
style_cell={
|
| 240 |
-
"fontFamily": "Arial",
|
| 241 |
-
"padding": "10px",
|
| 242 |
-
"border": "1px solid #ddd", # Subtle border for elegance
|
| 243 |
-
"textAlign": "left",
|
| 244 |
-
"fontSize": "14px",
|
| 245 |
-
},
|
| 246 |
-
style_header={
|
| 247 |
-
"backgroundColor": "#f5f5f5", # Light grey header
|
| 248 |
-
"fontWeight": "bold",
|
| 249 |
-
"textAlign": "left",
|
| 250 |
-
"borderBottom": "2px solid #ddd",
|
| 251 |
-
},
|
| 252 |
-
style_data={
|
| 253 |
-
"backgroundColor": "#ffffff",
|
| 254 |
-
"color": "#333333",
|
| 255 |
-
"borderBottom": "1px solid #ddd",
|
| 256 |
-
},
|
| 257 |
-
style_data_conditional=[
|
| 258 |
-
{
|
| 259 |
-
"if": {"state": "active"},
|
| 260 |
-
"backgroundColor": "#e6f7ff",
|
| 261 |
-
"border": "1px solid #1890ff",
|
| 262 |
-
},
|
| 263 |
-
],
|
| 264 |
-
style_table={
|
| 265 |
-
"maxHeight": "400px",
|
| 266 |
-
"overflowX": "auto",
|
| 267 |
-
"overflowY": "auto",
|
| 268 |
-
},
|
| 269 |
-
style_as_list_view=True,
|
| 270 |
-
row_selectable="single",
|
| 271 |
-
selected_rows=[],
|
| 272 |
-
),
|
| 273 |
-
],
|
| 274 |
-
style={
|
| 275 |
-
"width": "48%",
|
| 276 |
-
# "maxWidth": "800px",
|
| 277 |
-
"margin": "0 auto",
|
| 278 |
-
"padding": "20px",
|
| 279 |
-
"backgroundColor": "#ffffff",
|
| 280 |
-
"borderRadius": "10px",
|
| 281 |
-
"boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
|
| 282 |
-
},
|
| 283 |
),
|
| 284 |
],
|
| 285 |
-
|
| 286 |
-
"margin-top": "20px",
|
| 287 |
-
"margin-bottom": "20px",
|
| 288 |
-
"display": "flex",
|
| 289 |
-
"flexDirection": "row",
|
| 290 |
-
"alignItems": "center",
|
| 291 |
-
},
|
| 292 |
),
|
| 293 |
-
# acknowledgements to mp dash components and crystal toolkit
|
| 294 |
html.Footer(
|
| 295 |
[
|
| 296 |
html.P(
|
|
@@ -308,16 +117,6 @@ layout = html.Div(
|
|
| 308 |
style={"textAlign": "center"},
|
| 309 |
)
|
| 310 |
],
|
| 311 |
-
style={
|
| 312 |
-
"display": "flex",
|
| 313 |
-
"justifyContent": "center",
|
| 314 |
-
"alignItems": "center",
|
| 315 |
-
"flexWrap": "wrap",
|
| 316 |
-
"padding": "1rem 0",
|
| 317 |
-
"backgroundColor": "#f1f1f1", # Optional: light gray footer background
|
| 318 |
-
"borderTop": "1px solid #ddd", # Optional: subtle border at the top
|
| 319 |
-
"width": "100%",
|
| 320 |
-
},
|
| 321 |
),
|
| 322 |
],
|
| 323 |
style={
|
|
@@ -327,34 +126,6 @@ layout = html.Div(
|
|
| 327 |
)
|
| 328 |
|
| 329 |
|
| 330 |
-
def search_materials(query):
|
| 331 |
-
query_vector = np.zeros(n_elements)
|
| 332 |
-
|
| 333 |
-
if "," in query:
|
| 334 |
-
element_list = [el.strip() for el in query.split(",")]
|
| 335 |
-
for el in element_list:
|
| 336 |
-
query_vector[map_periodic_table[el]] = 1
|
| 337 |
-
else:
|
| 338 |
-
# Formula
|
| 339 |
-
import re
|
| 340 |
-
|
| 341 |
-
matches = re.findall(r"([A-Z][a-z]{0,2})(\d*)", query)
|
| 342 |
-
for el, numb in matches:
|
| 343 |
-
numb = int(numb) if numb else 1
|
| 344 |
-
query_vector[map_periodic_table[el]] = numb
|
| 345 |
-
|
| 346 |
-
similarity = np.dot(dataset_index, query_vector) / (np.linalg.norm(query_vector))
|
| 347 |
-
indices = np.argsort(similarity)[::-1][:top_k]
|
| 348 |
-
|
| 349 |
-
options = [dataset[int(i)] for i in indices]
|
| 350 |
-
|
| 351 |
-
mapping_table_idx_dataset_idx.clear()
|
| 352 |
-
for i, idx in enumerate(indices):
|
| 353 |
-
mapping_table_idx_dataset_idx[int(i)] = int(idx)
|
| 354 |
-
|
| 355 |
-
return options
|
| 356 |
-
|
| 357 |
-
|
| 358 |
# Callback to update the table based on search
|
| 359 |
@app.callback(
|
| 360 |
Output("table", "data"),
|
|
@@ -365,9 +136,11 @@ def on_submit_materials_input(n_clicks, query):
|
|
| 365 |
if n_clicks is None or not query:
|
| 366 |
return []
|
| 367 |
|
| 368 |
-
entries = search_materials(
|
|
|
|
|
|
|
| 369 |
|
| 370 |
-
return [{col: entry[col] for col in
|
| 371 |
|
| 372 |
|
| 373 |
# Callback to display the selected material
|
|
@@ -376,7 +149,6 @@ def on_submit_materials_input(n_clicks, query):
|
|
| 376 |
Output("structure-container", "children"),
|
| 377 |
Output("properties-container", "children"),
|
| 378 |
],
|
| 379 |
-
# Input("display-button", "n_clicks"),
|
| 380 |
Input("table", "active_cell"),
|
| 381 |
Input("table", "derived_virtual_selected_rows"),
|
| 382 |
)
|
|
@@ -408,69 +180,17 @@ def display_material(active_cell, selected_rows):
|
|
| 408 |
if row["magnetic_moments"]:
|
| 409 |
structure.add_site_property("magmom", row["magnetic_moments"])
|
| 410 |
|
| 411 |
-
sga =
|
| 412 |
-
|
| 413 |
-
# Create the StructureMoleculeComponent
|
| 414 |
-
structure_component = ctc.StructureMoleculeComponent(structure)
|
| 415 |
|
| 416 |
# Extract key properties
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
"Formula": row["chemical_formula_descriptive"],
|
| 420 |
-
"Energy per atom (eV/atom)": round(
|
| 421 |
-
row["energy"] / len(row["species_at_sites"]), 3
|
| 422 |
-
),
|
| 423 |
-
# "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"], #future release
|
| 424 |
-
"Total Magnetization (μB)": round(row["total_magnetization"], 3) if row['total_magnetization'] is not None else None,
|
| 425 |
-
"Density (g/cm^3)": round(structure.density, 3),
|
| 426 |
-
"Fermi energy level (eV)": round(row["dos_ef"],3) if row['dos_ef'] is not None else None,
|
| 427 |
-
"Crystal system": sga.get_crystal_system(),
|
| 428 |
-
"International Spacegroup": sga.get_symmetry_dataset().international,
|
| 429 |
-
"Magnetic moments (μB)": np.round(row["magnetic_moments"], 3),
|
| 430 |
-
"Stress tensor (kB)": np.round(row["stress_tensor"], 3),
|
| 431 |
-
"Forces on atoms (eV/A)": np.round(row["forces"], 3),
|
| 432 |
-
# "Bader charges (e-)": np.round(row["charges"], 3), # future release
|
| 433 |
-
"DFT Functional": row["functional"],
|
| 434 |
-
"Entalpic fingerprint": row['entalpic_fingerprint'],
|
| 435 |
-
}
|
| 436 |
-
|
| 437 |
-
# Format properties as an HTML table
|
| 438 |
-
properties_html = html.Table(
|
| 439 |
-
[
|
| 440 |
-
html.Tbody(
|
| 441 |
-
[
|
| 442 |
-
html.Tr(
|
| 443 |
-
[
|
| 444 |
-
html.Th(
|
| 445 |
-
key,
|
| 446 |
-
style={
|
| 447 |
-
"padding": "10px",
|
| 448 |
-
"verticalAlign": "middle",
|
| 449 |
-
},
|
| 450 |
-
),
|
| 451 |
-
html.Td(
|
| 452 |
-
str(value),
|
| 453 |
-
style={
|
| 454 |
-
"padding": "10px",
|
| 455 |
-
"borderBottom": "1px solid #ddd",
|
| 456 |
-
},
|
| 457 |
-
),
|
| 458 |
-
],
|
| 459 |
-
)
|
| 460 |
-
for key, value in properties.items()
|
| 461 |
-
],
|
| 462 |
-
)
|
| 463 |
-
],
|
| 464 |
-
style={
|
| 465 |
-
"width": "100%",
|
| 466 |
-
"borderCollapse": "collapse",
|
| 467 |
-
"fontFamily": "'Arial', sans-serif",
|
| 468 |
-
"fontSize": "14px",
|
| 469 |
-
"color": "#333333",
|
| 470 |
-
},
|
| 471 |
)
|
| 472 |
|
| 473 |
-
return
|
|
|
|
|
|
|
|
|
|
| 474 |
|
| 475 |
|
| 476 |
@app.callback(
|
|
@@ -505,7 +225,7 @@ def update_materials_input_layout(breakpoint_name, width):
|
|
| 505 |
|
| 506 |
|
| 507 |
# Register crystal toolkit with the app
|
| 508 |
-
ctc.register_crystal_toolkit(app, layout)
|
| 509 |
|
| 510 |
if __name__ == "__main__":
|
| 511 |
app.run_server(debug=True, port=7860, host="0.0.0.0")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import crystal_toolkit.components as ctc
|
| 2 |
import dash
|
| 3 |
import dash_mp_components as dmp
|
| 4 |
import numpy as np
|
|
|
|
| 5 |
import periodictable
|
| 6 |
from crystal_toolkit.settings import SETTINGS
|
| 7 |
from dash import dcc, html
|
| 8 |
from dash.dependencies import Input, Output, State
|
| 9 |
from dash_breakpoints import WindowBreakpoints
|
|
|
|
| 10 |
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
|
| 11 |
from pymatgen.core import Structure
|
| 12 |
|
| 13 |
+
from components import (
|
| 14 |
+
get_display_table,
|
| 15 |
+
get_dropdown,
|
| 16 |
+
get_materials_display,
|
| 17 |
+
get_periodic_table,
|
| 18 |
+
get_upload_div,
|
| 19 |
+
)
|
| 20 |
+
from data_utils import (
|
| 21 |
+
build_embeddings_index,
|
| 22 |
+
build_formula_index,
|
| 23 |
+
get_crystal_plot,
|
| 24 |
+
get_dataset,
|
| 25 |
+
get_properties_table,
|
| 26 |
+
search_materials,
|
| 27 |
+
)
|
| 28 |
|
| 29 |
+
EMPTY_DATA = False
|
| 30 |
+
CACHE_PATH = None
|
| 31 |
|
| 32 |
+
dataset = get_dataset()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
+
display_columns_query = [
|
| 35 |
"chemical_formula_descriptive",
|
| 36 |
"functional",
|
| 37 |
"immutable_id",
|
| 38 |
"energy",
|
| 39 |
]
|
| 40 |
+
display_names_query = {
|
| 41 |
"chemical_formula_descriptive": "Formula",
|
| 42 |
"functional": "Functional",
|
| 43 |
"immutable_id": "Material ID",
|
| 44 |
"energy": "Energy (eV)",
|
| 45 |
}
|
| 46 |
|
| 47 |
+
|
| 48 |
mapping_table_idx_dataset_idx = {}
|
| 49 |
+
available_similar_materials = []
|
| 50 |
|
| 51 |
map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
|
|
|
|
| 52 |
|
| 53 |
+
# dataset_index, immutable_id_to_idx = build_formula_index(dataset, cache_path=None)
|
| 54 |
+
dataset_index, immutable_id_to_idx = build_formula_index(
|
| 55 |
+
dataset, cache_path=CACHE_PATH, empty_data=EMPTY_DATA
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
)
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
# Initialize the Dash app
|
| 59 |
+
external_stylesheets = [
|
| 60 |
+
"/assets/styles.css",
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
app = dash.Dash(
|
| 64 |
+
__name__,
|
| 65 |
+
external_stylesheets=external_stylesheets,
|
| 66 |
+
)
|
| 67 |
server = app.server # Expose the server for deployment
|
| 68 |
|
| 69 |
# Define the app layout
|
| 70 |
+
app.layout = html.Div(
|
| 71 |
[
|
| 72 |
WindowBreakpoints(
|
| 73 |
id="breakpoints",
|
|
|
|
| 80 |
),
|
| 81 |
html.Div(
|
| 82 |
[
|
| 83 |
+
get_materials_display(
|
| 84 |
+
"",
|
| 85 |
+
"Structure will be displayed here",
|
| 86 |
+
"Properties will be displayed here",
|
| 87 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
],
|
| 89 |
+
className="container-row",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
),
|
| 91 |
html.Div(
|
| 92 |
[
|
| 93 |
+
get_periodic_table("materials-input", {}),
|
| 94 |
+
get_display_table(
|
| 95 |
+
"table",
|
| 96 |
+
display_names_query,
|
| 97 |
+
display_columns_query,
|
| 98 |
+
"Select a row to display the material's structure and properties",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
),
|
| 100 |
],
|
| 101 |
+
className="container-row-periodic",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
),
|
|
|
|
| 103 |
html.Footer(
|
| 104 |
[
|
| 105 |
html.P(
|
|
|
|
| 117 |
style={"textAlign": "center"},
|
| 118 |
)
|
| 119 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
),
|
| 121 |
],
|
| 122 |
style={
|
|
|
|
| 126 |
)
|
| 127 |
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
# Callback to update the table based on search
|
| 130 |
@app.callback(
|
| 131 |
Output("table", "data"),
|
|
|
|
| 136 |
if n_clicks is None or not query:
|
| 137 |
return []
|
| 138 |
|
| 139 |
+
entries = search_materials(
|
| 140 |
+
query, dataset, dataset_index, mapping_table_idx_dataset_idx, map_periodic_table
|
| 141 |
+
)
|
| 142 |
|
| 143 |
+
return [{col: entry[col] for col in display_columns_query} for entry in entries]
|
| 144 |
|
| 145 |
|
| 146 |
# Callback to display the selected material
|
|
|
|
| 149 |
Output("structure-container", "children"),
|
| 150 |
Output("properties-container", "children"),
|
| 151 |
],
|
|
|
|
| 152 |
Input("table", "active_cell"),
|
| 153 |
Input("table", "derived_virtual_selected_rows"),
|
| 154 |
)
|
|
|
|
| 180 |
if row["magnetic_moments"]:
|
| 181 |
structure.add_site_property("magmom", row["magnetic_moments"])
|
| 182 |
|
| 183 |
+
structure_layout, sga = get_crystal_plot(structure)
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
# Extract key properties
|
| 186 |
+
properties_html = get_properties_table(
|
| 187 |
+
row, structure, sga, [None, None], container_type="results"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
)
|
| 189 |
|
| 190 |
+
return (
|
| 191 |
+
structure_layout,
|
| 192 |
+
properties_html,
|
| 193 |
+
)
|
| 194 |
|
| 195 |
|
| 196 |
@app.callback(
|
|
|
|
| 225 |
|
| 226 |
|
| 227 |
# Register crystal toolkit with the app
|
| 228 |
+
ctc.register_crystal_toolkit(app, app.layout)
|
| 229 |
|
| 230 |
if __name__ == "__main__":
|
| 231 |
app.run_server(debug=True, port=7860, host="0.0.0.0")
|
assets/styles.css
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Page heading. */
h1 {
    font-size: 24px;
    font-weight: 700;
    color: #333;
}

/* NOTE(review): this is a class selector (elements with class="body"),
   not the <body> element, which is styled further down. Confirm intent. */
.body {
    background-color: #4a4a4a;
}

.header-container {
    display: flex;
    flex-direction: row;
    justify-content: space-between;
    align-items: center;
    padding: 20px;
    margin-bottom: 20px;
}

/* Card look shared by every panel. */
.container {
    box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
    border-radius: 10px;
    background-color: rgb(249, 249, 249);
    padding: 20px;
    margin-left: 10px;
    margin-right: 10px;
    max-height: 600px;
    justify-content: center;
}

/* Structure-viewer panel. */
.container-visu {
    width: 45%;
    align-items: center;
}

/* Properties/table panel. */
.container-table {
    width: 50%;
    align-items: center;
    overflow: auto;
}

/* remove background in periodical table */
.periodic-table {
    background-color: transparent;
    box-shadow: none;
}

.container-row {
    width: 100%;
    display: flex;
    flex-direction: row;
    justify-content: space-between;
    padding: 10px;
    margin-bottom: 10px;
    margin-top: 10px;
}

.container-row-periodic {
    width: 100%;
    display: flex;
    flex-direction: row;
    align-items: center;
}

.container-col {
    width: 100%;
    display: flex;
    flex-direction: column;
    justify-content: space-between;
    padding: 10px;
    margin-bottom: 10px;
    margin-top: 10px;
}

body {
    font-family: "Arial", sans-serif;
    font-size: 16px;
}

/* Narrow screens: stack the rows and let every panel take the full width. */
@media (max-width: 800px) {
    .container {
        width: 100%;
        margin: 5px;
        margin-top: 10px;
        margin-bottom: 10px;
    }

    .container-row {
        flex-direction: column;
    }

    .container-row-periodic {
        flex-direction: column;
    }

    .container-visu {
        width: 100%;
    }

    .container-table {
        width: 100%;
    }
}

/* Medium screens: give the structure viewer more room than the table. */
@media (max-width: 1000px) and (min-width: 800px) {
    .container-visu {
        width: 60%;
    }

    .container-table {
        width: 39%;
    }
}

footer {
    display: flex;
    justify-content: center;
    align-items: center;
    flex-wrap: wrap;
    margin-top: 40px;
    background-color: #ffffff;
    /* The value must be unquoted: a quoted string is not a valid border
       value and the whole declaration would be ignored. */
    border-top: 1px solid #ddd;
    width: 100%;
}
|
components.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dash
|
| 2 |
+
import dash_mp_components as dmp
|
| 3 |
+
from dash import dcc, html
|
| 4 |
+
|
| 5 |
+
# Header text for each table column, keyed by column id. The insertion order
# of this mapping is also the left-to-right order of the columns.
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}

# Column ids in display order (the keys of ``display_names``).
display_columns = list(display_names)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_periodic_table(id, table_kwargs, **style_kwargs):
    """Build the materials search panel (heading plus ``MaterialsInput``).

    Args:
        id: Component id given to the ``dmp.MaterialsInput`` widget.
        table_kwargs: Extra keyword arguments forwarded to
            ``dmp.MaterialsInput``.
        **style_kwargs: Accepted but not used by this function.

    Returns:
        An ``html.Div`` with classes ``container periodic-table`` holding the
        heading and the input, the latter wrapped in a full-width div whose
        id is ``materials-input-container``.
    """
    heading = html.H3("Search Materials (eg. 'Ac,Cd,Ge' or 'Ac2CdGe3')")

    search_input = dmp.MaterialsInput(
        allowedInputTypes=["elements", "formula"],
        hidePeriodicTable=False,
        periodicTableMode="toggle",
        hideWildcardButton=True,
        showSubmitButton=True,
        submitButtonText="Search",
        type="elements",
        **table_kwargs,
        id=id,
    )

    input_wrapper = html.Div(
        [search_input],
        id="materials-input-container",
        style={"width": "100%"},
    )

    return html.Div(
        [heading, input_wrapper],
        className="container periodic-table",
    )
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_dropdown(id, options, **style_kwargs):
    """Build a non-clearable dropdown with the placeholder "Embedder".

    Args:
        id: Component id of the dropdown.
        options: Options passed straight to ``dcc.Dropdown``.
        **style_kwargs: Collected into a dict and used as the dropdown's
            inline ``style``.

    Returns:
        A ``dcc.Dropdown`` with no initial selection.
    """
    dropdown_props = {
        "id": id,
        "options": options,
        "placeholder": "Embedder",
        "value": None,
        "clearable": False,
        "style": style_kwargs,
    }
    return dcc.Dropdown(**dropdown_props)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_upload_div(id, **style_kwargs):
    """Build the "Upload a CIF file" panel.

    Args:
        id: Component id given to the ``dcc.Upload`` widget.
        **style_kwargs: Accepted but not used by this function.

    Returns:
        An ``html.Div`` with class ``container`` holding a heading and a
        single-file drag-and-drop upload area.
    """
    title = html.H3("Upload a CIF file")

    prompt = html.Div(
        [
            "Drag and Drop or ",
            html.A("Select a CIF file"),
        ]
    )

    # Dashed drop-zone box.
    drop_zone_style = {
        "width": "100%",
        "height": "60px",
        "lineHeight": "60px",
        "borderWidth": "1px",
        "borderStyle": "dashed",
        "borderRadius": "5px",
        "textAlign": "center",
        "margin": "10px",
    }

    uploader = dcc.Upload(
        id=id,
        children=prompt,
        style=drop_zone_style,
        multiple=False,
    )

    return html.Div([title, uploader], className="container")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_display_table(id, display_names, display_columns, text, **style_kwargs):
    """Build a captioned results table with single-row selection.

    Args:
        id: Component id given to the ``DataTable``.
        display_names: Mapping from column id to header text.
        display_columns: Column ids, in display order.
        text: Caption rendered in the label above the table.
        **style_kwargs: Accepted but not used by this function.

    Returns:
        An ``html.Div`` with class ``container`` holding the label and the
        ``DataTable``.

    Raises:
        KeyError: If a column id in ``display_columns`` is missing from
            ``display_names``.
    """

    def _column_spec(col):
        # "energy" is the only numeric column; show it with two decimals.
        spec = {"name": display_names[col], "id": col}
        if col == "energy":
            spec["type"] = "numeric"
            spec["format"] = {"specifier": ".2f"}
        return spec

    caption = html.Label(
        text,
        # Dash style dictionaries take camelCased keys, not CSS-hyphenated ones.
        style={"marginBottom": "20px"},
    )

    table = dash.dash_table.DataTable(
        id=id,
        columns=[_column_spec(col) for col in display_columns],
        # The table starts with one empty placeholder row.
        data=[{}],
        style_cell={
            "fontFamily": "Arial",
            "padding": "10px",
            "border": "1px solid #ddd",  # Subtle border for elegance
            "textAlign": "left",
            "fontSize": "14px",
        },
        style_header={
            "backgroundColor": "#f5f5f5",  # Light grey header
            "fontWeight": "bold",
            "textAlign": "left",
            "borderBottom": "2px solid #ddd",
        },
        style_data={
            "backgroundColor": "#ffffff",
            "color": "#333333",
            "borderBottom": "1px solid #ddd",
        },
        style_data_conditional=[
            # Highlight the active cell.
            {
                "if": {"state": "active"},
                "backgroundColor": "#e6f7ff",
                "border": "1px solid #1890ff",
            },
        ],
        style_table={
            "maxHeight": "400px",
            "overflowX": "auto",
            "overflowY": "auto",
        },
        style_as_list_view=True,
        row_selectable="single",
        selected_rows=[],
    )

    return html.Div([caption, table], className="container")
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_materials_display(id, text_materials_div, text_table_div, **style_kwargs):
    """Build the side-by-side structure and properties panels.

    Args:
        id: Suffix appended to the ids ``structure-container`` and
            ``properties-container``.
        text_materials_div: Placeholder content for the structure panel.
        text_table_div: Placeholder content for the properties panel.
        **style_kwargs: Accepted but not used by this function.

    Returns:
        An ``html.Div`` with class ``container-row`` holding both panels,
        structure panel first.
    """
    centered = {"textAlign": "center"}

    structure_panel = html.Div(
        [html.Div(text_materials_div, style=centered)],
        id=f"structure-container{id}",
        className="container container-visu",
    )

    properties_panel = html.Div(
        id=f"properties-container{id}",
        className="container container-table",
        style={"width": "100%"},
        children=[html.Div(text_table_div, style={"textAlign": "center"})],
    )

    return html.Div(
        [structure_panel, properties_panel],
        className="container-row",
    )
|
data_utils.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import crystal_toolkit.components as ctc
|
| 5 |
+
import numpy as np
|
| 6 |
+
import periodictable
|
| 7 |
+
from dash import dcc, html
|
| 8 |
+
from datasets import concatenate_datasets, load_dataset
|
| 9 |
+
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
|
| 10 |
+
|
| 11 |
+
# Hugging Face access token taken from the environment; None when the
# HF_TOKEN variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# Upper bound on the number of hits search_materials keeps per query.
top_k = 500
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_dataset():
    """Load and concatenate the ``train`` split of every LeMat-Bulk subset.

    Each subset of ``LeMaterial/leMat-Bulk`` is fetched with ``load_dataset``
    using the module-level ``HF_TOKEN`` and restricted to the columns listed
    below. The ``train`` splits are then joined with ``concatenate_datasets``
    in the order the subsets are listed.

    Returns:
        The object returned by ``concatenate_datasets`` for the four splits.
    """
    subsets = [
        "compatible_pbe",
        "compatible_pbesol",
        "compatible_scan",
        "non_compatible",
    ]

    # Every column is requested exactly once ("functional" used to be listed
    # twice). Columns marked below are not requested yet.
    columns = [
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        # "energy_corrected",  # not yet available in LeMat-Bulk
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        # "band_gap_direct",  # future release
        # "band_gap_indirect",  # future release
        "dos_ef",
        # "charges",  # future release
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
        "entalpic_fingerprint",
    ]

    # Load only the train split of each subset.
    datasets = [
        load_dataset(
            "LeMaterial/leMat-Bulk",
            subset,
            token=HF_TOKEN,
            columns=columns,
        )["train"]
        for subset in subsets
    ]

    return concatenate_datasets(datasets)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Dataset columns selected for display, in order.
display_columns = ["chemical_formula_descriptive", "functional", "immutable_id", "energy"]

# Human-readable label for each entry of display_columns (same order).
display_names = dict(
    zip(
        display_columns,
        ["Formula", "Functional", "Material ID", "Energy (eV)"],
    )
)

# Global shared variables
# NOTE(review): presumably the dict handed to search_materials, which fills it
# with {position in the result list: dataset row index} — confirm in the caller.
mapping_table_idx_dataset_idx = dict()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
    """Build a row-normalised element-count matrix and an id -> row lookup.

    Args:
        dataset: Object providing ``select(...)`` and
            ``select_columns(...).to_pandas()``. Only used when ``cache_path``
            is None and ``empty_data`` is false.
        index_range: Optional row selection passed to ``dataset.select``
            before the index is computed.
        cache_path: Optional directory containing ``train_df.pkl`` and
            ``dataset_index.pkl``. When given, both are unpickled instead of
            being recomputed from ``dataset``.
        empty_data: When true, return a placeholder ``(np.zeros((1, 1)), {})``
            without touching ``dataset`` or the cache.

    Returns:
        ``(dataset_index, immutable_id_to_idx)``. ``dataset_index`` holds one
        row per row of the formula table, each divided by its sum and then by
        its Euclidean norm. ``immutable_id_to_idx`` maps each ``immutable_id``
        to the index label of its row in the formula table.
    """
    if empty_data:
        return np.zeros((1, 1)), {}

    # Preprocessing step to create an index for the dataset
    if cache_path is not None:
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # cache_path at files this application produced itself.
        with open(f"{cache_path}/train_df.pkl", "rb") as fh:
            train_df = pickle.load(fh)
        with open(f"{cache_path}/dataset_index.pkl", "rb") as fh:
            dataset_index = pickle.load(fh)
    else:
        use_dataset = dataset if index_range is None else dataset.select(index_range)
        train_df = use_dataset.select_columns(
            ["chemical_formula_descriptive", "immutable_id"]
        ).to_pandas()

        # One match per "<Element><optional count>" token; a missing count is 1.
        pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
        extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
        extracted["count"] = extracted["count"].replace("", "1").astype(int)

        # reset_index() exposes the original row label as "level_0" so the
        # per-row element counts can be pivoted into a wide table.
        wide_df = extracted.reset_index().pivot_table(
            index="level_0",  # original row index
            columns="element",
            values="count",
            aggfunc="sum",
            fill_value=0,
        )

        all_elements = [el.symbol for el in periodictable.elements]  # full element list
        # Reindex rows as well as columns: a formula that yields no match is
        # absent from the pivot, and dropping it would shift every later row
        # out of step with train_df.
        wide_df = wide_df.reindex(
            index=train_df.index, columns=all_elements, fill_value=0
        )

        dataset_index = wide_df.values

    # Normalise each row (sum, then Euclidean norm). Applied to cached data as
    # well: for rows that are already normalised and non-negative this leaves
    # them unchanged up to rounding. All-zero rows are kept as zeros instead of
    # becoming NaN through 0/0.
    row_sums = np.sum(dataset_index, axis=1, keepdims=True)
    dataset_index = dataset_index / np.where(row_sums == 0, 1, row_sums)
    row_norms = np.linalg.norm(dataset_index, axis=1, keepdims=True)
    dataset_index = dataset_index / np.where(row_norms == 0, 1, row_norms)

    # Invert {row label: immutable_id} into {immutable_id: row label}.
    idx_to_immutable_id = train_df["immutable_id"].to_dict()
    immutable_id_to_idx = {v: k for k, v in idx_to_immutable_id.items()}

    return dataset_index, immutable_id_to_idx
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
import pickle
|
| 124 |
+
from pathlib import Path
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# TODO: Just load the index from a file
|
| 128 |
+
def build_embeddings_index(empty_data=False):
    """Load pickled feature vectors and add them to a new ``FAISSIndex``.

    Args:
        empty_data: When true, skip all loading and return ``(None, {}, {})``.

    Returns:
        ``(index, features_dict, idx_to_immutable_id)`` where ``features_dict``
        is the object unpickled from ``features_dict.pkl``, ``index`` is a
        ``FAISSIndex`` that received every vector in the dict's iteration
        order, and ``idx_to_immutable_id`` maps each insertion position to the
        corresponding dict key.
    """
    if empty_data:
        return None, {}, {}

    # NOTE(review): pickle.load can execute arbitrary code; features_dict.pkl
    # must come from a trusted source.
    with open("features_dict.pkl", "rb") as fh:
        features_dict = pickle.load(fh)

    # Function-scope import, as in the original code.
    from indexer import FAISSIndex

    index = FAISSIndex()
    # Each vector is added as a single-row 2-D array.
    for vector in features_dict.values():
        index.index.add(vector.reshape(1, -1))

    # Insertion position -> key, matching the order the vectors were added in.
    idx_to_immutable_id = dict(enumerate(features_dict))

    # Alternative noted by the author: index = FAISSIndex.from_store("index.faiss")

    return index, features_dict, idx_to_immutable_id
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def search_materials(
    query, dataset, dataset_index, mapping_table_idx_dataset_idx, map_periodic_table
):
    """Rank dataset rows by similarity between their composition and *query*.

    The query is turned into a vector with one slot per entry of
    ``map_periodic_table``:

    * if it contains a comma it is read as a list of element symbols, each
      given weight 1 (empty items such as a trailing comma are ignored);
    * otherwise it is read as a formula, and the counts of an element that
      appears several times are added together (``"CH3CH2OH"`` -> C2 H6 O1).

    Rows are scored by the dot product of ``dataset_index`` with the
    unit-length query vector, and the ``top_k`` best rows are returned.

    Args:
        query: Search string typed by the user.
        dataset: Indexable collection; ``dataset[int(i)]`` is returned for
            each selected row index.
        dataset_index: 2-D array with one row per dataset entry and one column
            per element slot.
        mapping_table_idx_dataset_idx: Dict that is cleared and refilled with
            ``{position in the returned list: dataset row index}``.
        map_periodic_table: Mapping from element symbol to column index.

    Returns:
        List of the selected dataset rows, best match first. The list is empty
        (and the mapping is cleared) when the query names no element.

    Raises:
        KeyError: If the query contains a symbol absent from
            ``map_periodic_table``.
    """
    query_vector = np.zeros(len(map_periodic_table))

    if "," in query:
        for el in (item.strip() for item in query.split(",")):
            if el:  # skip blanks left by leading/trailing/double commas
                query_vector[map_periodic_table[el]] = 1
    else:
        # Formula
        for el, count in re.findall(r"([A-Z][a-z]{0,2})(\d*)", query):
            # Accumulate so repeated elements agree with the summed counts
            # used when the formula index is built.
            query_vector[map_periodic_table[el]] += int(count) if count else 1

    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        # Nothing to search for; dividing by the zero norm would produce NaN
        # scores and an arbitrary ranking.
        mapping_table_idx_dataset_idx.clear()
        return []

    similarity = np.dot(dataset_index, query_vector) / query_norm
    indices = np.argsort(similarity)[::-1][:top_k]

    options = [dataset[int(i)] for i in indices]

    mapping_table_idx_dataset_idx.clear()
    for position, idx in enumerate(indices):
        mapping_table_idx_dataset_idx[int(position)] = int(idx)

    return options
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def get_properties_table(
    row, structure, sga, properties_container_update, container_type="query"
):
    """Build an HTML table of the properties of one material.

    Args:
        row: Mapping holding the dataset fields read below (``immutable_id``,
            ``chemical_formula_descriptive``, ``energy``, ``species_at_sites``,
            ``total_magnetization``, ``dos_ef``, ``magnetic_moments``,
            ``stress_tensor``, ``forces``, ``functional``,
            ``entalpic_fingerprint``).
        structure: Object whose ``density`` attribute is displayed.
        sga: Object providing ``get_crystal_system()`` and
            ``get_symmetry_dataset().international``.
        properties_container_update: Mutable two-slot sequence; slot 0 is
            overwritten with the properties dict when ``container_type`` is
            ``"query"``, slot 1 otherwise.
        container_type: Selects which slot of ``properties_container_update``
            is written.

    Returns:
        An ``html.Table`` with one row per property: the name in a header cell
        and ``str(value)`` in a data cell.
    """
    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": round(
            row["energy"] / len(row["species_at_sites"]), 3
        ),
        # "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"], #future release
        "Total Magnetization (μB)": (
            round(row["total_magnetization"], 3)
            if row["total_magnetization"] is not None
            else None
        ),
        "Density (g/cm^3)": round(structure.density, 3),
        "Fermi energy level (eV)": (
            round(row["dos_ef"], 3) if row["dos_ef"] is not None else None
        ),
        "Crystal system": sga.get_crystal_system(),
        "International Spacegroup": sga.get_symmetry_dataset().international,
        # Array-valued fields get the same None guard as the scalar fields
        # above, so a missing value shows as "None" instead of raising.
        "Magnetic moments (μB)": _rounded_array_or_none(row["magnetic_moments"]),
        "Stress tensor (kB)": _rounded_array_or_none(row["stress_tensor"]),
        "Forces on atoms (eV/A)": _rounded_array_or_none(row["forces"]),
        # "Bader charges (e-)": np.round(row["charges"], 3), # future release
        "DFT Functional": row["functional"],
        "Entalpic fingerprint": row["entalpic_fingerprint"],
    }

    # Shared style for every value cell.
    style = {
        "padding": "10px",
        "borderBottom": "1px solid #ddd",
    }

    # Publish the raw properties to the caller-owned container.
    slot = 0 if container_type == "query" else 1
    properties_container_update[slot] = properties

    header_style = {
        "padding": "10px",
        "verticalAlign": "middle",
    }

    # Format properties as an HTML table
    table_rows = [
        html.Tr(
            [
                html.Th(key, style=dict(header_style)),
                html.Td(str(value), style=style),
            ],
        )
        for key, value in properties.items()
    ]

    properties_html = html.Table(
        [html.Tbody(table_rows)],
        style={
            "width": "100%",
            "borderCollapse": "collapse",
            "fontFamily": "'Arial', sans-serif",
            "fontSize": "14px",
            "color": "#333333",
        },
    )

    return properties_html


def _rounded_array_or_none(values, decimals=3):
    """Round array-like *values* with ``np.round``; pass None through unchanged."""
    if values is None:
        return None
    return np.round(values, decimals)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def get_crystal_plot(structure):
    """Return the crystal-toolkit layout for *structure* and its symmetry analyzer.

    Args:
        structure: Structure passed to both ``SpacegroupAnalyzer`` and
            ``ctc.StructureMoleculeComponent``.

    Returns:
        ``(layout, analyzer)``: the value of the component's ``layout()`` call
        and the ``SpacegroupAnalyzer`` built for the same structure.
    """
    # The analyzer is constructed first, then the viewer component.
    analyzer = SpacegroupAnalyzer(structure)
    viewer = ctc.StructureMoleculeComponent(structure)
    layout = viewer.layout()
    return layout, analyzer
|