Spaces:
Running
Running
import os
import re
import sqlite3
from contextlib import closing
def utils_extract_db_schema_as_string(
    db_id, base_path, normalize=False, sql: str | None = None, get_insert_into: bool = False
):
    """
    Extract the schema of an SQLite database into a single string.

    :param db_id: Database identifier. Not used by this function; it is kept so
        existing callers that pass it positionally keep working.
    :param base_path: Path to the SQLite database file; it is passed directly to
        ``sqlite3.connect``. Note that ``sqlite3.connect`` creates an empty
        database file when the path does not exist, which yields an empty schema.
    :param normalize: Whether to apply the module's schema normalization to the
        collected statements before joining them.
    :param sql: Optional SQL query; when given, only tables referenced by it are
        included.
    :param get_insert_into: Whether to also include sample ``INSERT INTO``
        statements built from rows of each included table.
    :return: The collected statements joined with newlines.
    """
    # closing() guarantees the connection is released even if a query raises.
    # Using the connection itself as a context manager would only commit or
    # roll back the transaction; it would not close the connection.
    with closing(sqlite3.connect(base_path)) as connection:
        cursor = connection.cursor()
        # Collect CREATE TABLE (and optionally INSERT INTO) statements.
        schema_entries = _get_schema_entries(cursor, sql, get_insert_into)
    # Joining/normalizing is pure string work; no open connection is needed.
    return _combine_schema_entries(schema_entries, normalize)
def _get_schema_entries(cursor, sql=None, get_insert_into=False): | |
""" | |
Retrieves schema entries and optionally data entries from the SQLite database. | |
:param cursor: SQLite cursor object. | |
:param sql: Optional SQL query to filter specific tables. | |
:param get_insert_into: Boolean flag to include INSERT INTO statements. | |
:return: List of schema and optionally data entries. | |
""" | |
entries = [] | |
if sql: | |
# Extract table names from the provided SQL query | |
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") | |
tables = [tbl[0] for tbl in cursor.fetchall() if tbl[0].lower() in sql.lower()] | |
else: | |
# Retrieve all table names | |
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") | |
tables = [tbl[0] for tbl in cursor.fetchall()] | |
for table in tables: | |
# Retrieve the CREATE TABLE statement for each table | |
cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}' AND sql IS NOT NULL;") | |
create_table_stmt = cursor.fetchone() | |
if create_table_stmt: | |
entries.append(create_table_stmt[0]) | |
if get_insert_into: | |
# Retrieve all data from the table | |
cursor.execute(f"SELECT * FROM {table};") | |
rows = cursor.fetchall() | |
column_names = [description[0] for description in cursor.description] | |
# Generate INSERT INTO statements for each row | |
# TODO now hardcoded to first 3 | |
for row in rows[:3]: | |
values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row) | |
insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});" | |
entries.append(insert_stmt) | |
return entries | |
def _combine_schema_entries(schema_entries, normalize): | |
""" | |
Combines schema entries into a single string. | |
:param schema_entries: List of schema entries. | |
:param normalize: Whether to normalize the schema string. | |
:return: Combined schema string. | |
""" | |
if not normalize: | |
return "\n".join(entry for entry in schema_entries) | |
return "\n".join( | |
re.sub( | |
r"\s*\)", | |
")", | |
re.sub( | |
r"\(\s*", | |
"(", | |
re.sub( | |
r"(`\w+`)\s+\(", | |
r"\1(", | |
re.sub( | |
r"^\s*([^\s(]+)", | |
r"`\1`", | |
re.sub( | |
r"\s+", | |
" ", | |
entry.replace("CREATE TABLE", "").replace("\t", " "), | |
).strip(), | |
), | |
), | |
), | |
) | |
for entry in schema_entries | |
) | |