from __future__ import annotations
import logging
import os
import sqlite3
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
from .models import (
ActionLog,
AgentSnapshot,
Analytics,
CombatLog,
DatabaseHelper,
SimulationRun,
TradeLog,
WorldSnapshot,
)
from .schema import DatabaseSchema
logger = logging.getLogger(__name__)
[docs]
class Database:
"""Main database interface for the simulation framework"""
def __init__(self, db_path: str = "simulation.db"):
self.db_path = db_path
self.connection = None
self._initialize_database()
def _initialize_database(self) -> None:
"""Initialize database and create tables if they don't exist"""
try:
# Create database directory if it doesn't exist
db_dir = os.path.dirname(self.db_path)
if db_dir: # Only create directory if path has a directory component
os.makedirs(db_dir, exist_ok=True)
with self.get_connection() as conn:
# Enable foreign keys
conn.execute("PRAGMA foreign_keys = ON")
# Set WAL mode for better concurrent access
conn.execute("PRAGMA journal_mode = WAL")
# Check current schema version
current_version = self._get_schema_version(conn)
if current_version < DatabaseSchema.SCHEMA_VERSION:
logger.info(
f"Migrating database from version {current_version} to {DatabaseSchema.SCHEMA_VERSION}"
)
self._migrate_schema(
conn, current_version, DatabaseSchema.SCHEMA_VERSION
)
logger.info(f"Database initialized at {self.db_path}")
except Exception as e:
logger.error(f"Failed to initialize database: {e}")
raise
def _get_schema_version(self, conn: sqlite3.Connection) -> int:
"""Get current schema version from database"""
try:
cursor = conn.execute("SELECT MAX(version) FROM schema_version")
result = cursor.fetchone()
return result[0] if result and result[0] is not None else 0
except sqlite3.OperationalError:
# Table doesn't exist, assume version 0
return 0
def _migrate_schema(
self, conn: sqlite3.Connection, from_version: int, to_version: int
) -> None:
"""Migrate database schema from one version to another"""
migration_sql = DatabaseSchema.get_migration_sql(from_version, to_version)
for sql in migration_sql:
if sql.strip():
conn.execute(sql)
# Update schema version
conn.execute(
"INSERT OR REPLACE INTO schema_version (version) VALUES (?)", (to_version,)
)
conn.commit()
[docs]
@contextmanager
def get_connection(self):
"""Get database connection with proper cleanup"""
conn = None
try:
conn = sqlite3.connect(self.db_path, timeout=30.0)
conn.row_factory = sqlite3.Row # Enable dict-like access to rows
yield conn
except Exception as e:
if conn:
conn.rollback()
logger.error(f"Database error: {e}")
raise
finally:
if conn:
conn.close()
# Simulation Run operations
[docs]
def create_simulation_run(self, simulation: SimulationRun) -> int:
"""Create a new simulation run record"""
with self.get_connection() as conn:
data = DatabaseHelper.dataclass_to_dict(simulation, for_insert=True)
placeholders = ", ".join(["?" for _ in data])
columns = ", ".join(data.keys())
cursor = conn.execute(
f"INSERT INTO simulation_runs ({columns}) VALUES ({placeholders})",
list(data.values()),
)
conn.commit()
return cursor.lastrowid
[docs]
def get_simulation_run(self, simulation_id: int) -> Optional[SimulationRun]:
"""Get simulation run by ID"""
with self.get_connection() as conn:
cursor = conn.execute(
"SELECT * FROM simulation_runs WHERE id = ?", (simulation_id,)
)
row = cursor.fetchone()
return DatabaseHelper.row_to_dataclass(row, SimulationRun) if row else None
[docs]
def update_simulation_run(self, simulation: SimulationRun) -> bool:
"""Update simulation run record"""
if not simulation.id:
return False
with self.get_connection() as conn:
data = DatabaseHelper.dataclass_to_dict(simulation)
data.pop("id") # Remove ID from update data
set_clause = ", ".join([f"{key} = ?" for key in data.keys()])
values = list(data.values()) + [simulation.id]
cursor = conn.execute(
f"UPDATE simulation_runs SET {set_clause} WHERE id = ?", values
)
conn.commit()
return cursor.rowcount > 0
[docs]
def list_simulation_runs(
self, limit: int = 100, offset: int = 0
) -> List[SimulationRun]:
"""List all simulation runs"""
with self.get_connection() as conn:
cursor = conn.execute(
"""
SELECT * FROM simulation_runs
ORDER BY created_at DESC
LIMIT ? OFFSET ?
""",
(limit, offset),
)
rows = cursor.fetchall()
return [DatabaseHelper.row_to_dataclass(row, SimulationRun) for row in rows]
# Agent Snapshot operations
[docs]
def save_agent_snapshot(self, snapshot: AgentSnapshot) -> int:
"""Save agent snapshot"""
with self.get_connection() as conn:
data = DatabaseHelper.dataclass_to_dict(snapshot, for_insert=True)
placeholders = ", ".join(["?" for _ in data])
columns = ", ".join(data.keys())
cursor = conn.execute(
f"INSERT INTO agent_snapshots ({columns}) VALUES ({placeholders})",
list(data.values()),
)
conn.commit()
return cursor.lastrowid
[docs]
def save_agent_snapshots_batch(self, snapshots: List[AgentSnapshot]) -> None:
"""Save multiple agent snapshots efficiently"""
if not snapshots:
return
with self.get_connection() as conn:
# Prepare data
data_list = [
DatabaseHelper.dataclass_to_dict(s, for_insert=True) for s in snapshots
]
# All snapshots should have the same fields
columns = list(data_list[0].keys())
placeholders = ", ".join(["?" for _ in columns])
columns_str = ", ".join(columns)
# Execute batch insert
values_list = [list(data.values()) for data in data_list]
conn.executemany(
f"INSERT INTO agent_snapshots ({columns_str}) VALUES ({placeholders})",
values_list,
)
conn.commit()
[docs]
def get_agent_snapshots(
self,
simulation_id: int,
agent_id: Optional[int] = None,
start_tick: Optional[int] = None,
end_tick: Optional[int] = None,
limit: int = 1000,
) -> List[AgentSnapshot]:
"""Get agent snapshots with filtering"""
with self.get_connection() as conn:
conditions = ["simulation_id = ?"]
params = [simulation_id]
if agent_id is not None:
conditions.append("agent_id = ?")
params.append(agent_id)
if start_tick is not None:
conditions.append("tick >= ?")
params.append(start_tick)
if end_tick is not None:
conditions.append("tick <= ?")
params.append(end_tick)
where_clause = " AND ".join(conditions)
params.append(limit)
cursor = conn.execute(
f"""
SELECT * FROM agent_snapshots
WHERE {where_clause}
ORDER BY tick DESC, agent_id
LIMIT ?
""",
params,
)
rows = cursor.fetchall()
return [DatabaseHelper.row_to_dataclass(row, AgentSnapshot) for row in rows]
# World Snapshot operations
[docs]
def save_world_snapshot(self, snapshot: WorldSnapshot) -> int:
"""Save world snapshot"""
with self.get_connection() as conn:
data = DatabaseHelper.dataclass_to_dict(snapshot, for_insert=True)
placeholders = ", ".join(["?" for _ in data])
columns = ", ".join(data.keys())
cursor = conn.execute(
f"INSERT INTO world_snapshots ({columns}) VALUES ({placeholders})",
list(data.values()),
)
conn.commit()
return cursor.lastrowid
[docs]
def get_world_snapshots(
self,
simulation_id: int,
start_tick: Optional[int] = None,
end_tick: Optional[int] = None,
limit: int = 1000,
) -> List[WorldSnapshot]:
"""Get world snapshots with filtering"""
with self.get_connection() as conn:
conditions = ["simulation_id = ?"]
params = [simulation_id]
if start_tick is not None:
conditions.append("tick >= ?")
params.append(start_tick)
if end_tick is not None:
conditions.append("tick <= ?")
params.append(end_tick)
where_clause = " AND ".join(conditions)
params.append(limit)
cursor = conn.execute(
f"""
SELECT * FROM world_snapshots
WHERE {where_clause}
ORDER BY tick DESC
LIMIT ?
""",
params,
)
rows = cursor.fetchall()
return [DatabaseHelper.row_to_dataclass(row, WorldSnapshot) for row in rows]
# Action Log operations
[docs]
def log_action(self, action: ActionLog) -> int:
"""Log an action"""
with self.get_connection() as conn:
data = DatabaseHelper.dataclass_to_dict(action, for_insert=True)
placeholders = ", ".join(["?" for _ in data])
columns = ", ".join(data.keys())
cursor = conn.execute(
f"INSERT INTO action_logs ({columns}) VALUES ({placeholders})",
list(data.values()),
)
conn.commit()
return cursor.lastrowid
[docs]
def get_action_logs(
self,
simulation_id: int,
agent_id: Optional[int] = None,
action_type: Optional[str] = None,
start_tick: Optional[int] = None,
end_tick: Optional[int] = None,
limit: int = 1000,
) -> List[ActionLog]:
"""Get action logs with filtering"""
with self.get_connection() as conn:
conditions = ["simulation_id = ?"]
params = [simulation_id]
if agent_id is not None:
conditions.append("agent_id = ?")
params.append(agent_id)
if action_type is not None:
conditions.append("action_type = ?")
params.append(action_type)
if start_tick is not None:
conditions.append("tick >= ?")
params.append(start_tick)
if end_tick is not None:
conditions.append("tick <= ?")
params.append(end_tick)
where_clause = " AND ".join(conditions)
params.append(limit)
cursor = conn.execute(
f"""
SELECT * FROM action_logs
WHERE {where_clause}
ORDER BY tick DESC, agent_id
LIMIT ?
""",
params,
)
rows = cursor.fetchall()
return [DatabaseHelper.row_to_dataclass(row, ActionLog) for row in rows]
# Trade Log operations
[docs]
def log_trade(self, trade: TradeLog) -> int:
"""Log a trade transaction"""
with self.get_connection() as conn:
data = DatabaseHelper.dataclass_to_dict(trade, for_insert=True)
placeholders = ", ".join(["?" for _ in data])
columns = ", ".join(data.keys())
cursor = conn.execute(
f"INSERT INTO trade_logs ({columns}) VALUES ({placeholders})",
list(data.values()),
)
conn.commit()
return cursor.lastrowid
# Combat Log operations
[docs]
def log_combat(self, combat: CombatLog) -> int:
"""Log a combat action"""
with self.get_connection() as conn:
data = DatabaseHelper.dataclass_to_dict(combat, for_insert=True)
placeholders = ", ".join(["?" for _ in data])
columns = ", ".join(data.keys())
cursor = conn.execute(
f"INSERT INTO combat_logs ({columns}) VALUES ({placeholders})",
list(data.values()),
)
conn.commit()
return cursor.lastrowid
# Analytics operations
[docs]
def save_analytics(self, analytics: Analytics) -> int:
"""Save analytics metric"""
with self.get_connection() as conn:
data = DatabaseHelper.dataclass_to_dict(analytics, for_insert=True)
placeholders = ", ".join(["?" for _ in data])
columns = ", ".join(data.keys())
cursor = conn.execute(
f"INSERT INTO analytics ({columns}) VALUES ({placeholders})",
list(data.values()),
)
conn.commit()
return cursor.lastrowid
[docs]
def get_analytics(
self,
simulation_id: int,
metric_name: Optional[str] = None,
category: Optional[str] = None,
start_tick: Optional[int] = None,
end_tick: Optional[int] = None,
limit: int = 1000,
) -> List[Analytics]:
"""Get analytics data with filtering"""
with self.get_connection() as conn:
conditions = ["simulation_id = ?"]
params = [simulation_id]
if metric_name is not None:
conditions.append("metric_name = ?")
params.append(metric_name)
if category is not None:
conditions.append("category = ?")
params.append(category)
if start_tick is not None:
conditions.append("tick >= ?")
params.append(start_tick)
if end_tick is not None:
conditions.append("tick <= ?")
params.append(end_tick)
where_clause = " AND ".join(conditions)
params.append(limit)
cursor = conn.execute(
f"""
SELECT * FROM analytics
WHERE {where_clause}
ORDER BY tick DESC, metric_name
LIMIT ?
""",
params,
)
rows = cursor.fetchall()
return [DatabaseHelper.row_to_dataclass(row, Analytics) for row in rows]
# Query operations using views
[docs]
def get_agent_summary(self, simulation_id: int) -> List[Dict[str, Any]]:
"""Get agent summary using database view"""
with self.get_connection() as conn:
cursor = conn.execute(
"SELECT * FROM agent_summary WHERE simulation_id = ?", (simulation_id,)
)
return [dict(row) for row in cursor.fetchall()]
[docs]
def get_action_summary(self, simulation_id: int) -> List[Dict[str, Any]]:
"""Get action summary using database view"""
with self.get_connection() as conn:
cursor = conn.execute(
"SELECT * FROM action_summary WHERE simulation_id = ? ORDER BY total_actions DESC",
(simulation_id,),
)
return [dict(row) for row in cursor.fetchall()]
[docs]
def get_trade_summary(self, simulation_id: int) -> Optional[Dict[str, Any]]:
"""Get trade summary using database view"""
with self.get_connection() as conn:
cursor = conn.execute(
"SELECT * FROM trade_summary WHERE simulation_id = ?", (simulation_id,)
)
row = cursor.fetchone()
return dict(row) if row else None
[docs]
def get_combat_summary(self, simulation_id: int) -> Optional[Dict[str, Any]]:
"""Get combat summary using database view"""
with self.get_connection() as conn:
cursor = conn.execute(
"SELECT * FROM combat_summary WHERE simulation_id = ?", (simulation_id,)
)
row = cursor.fetchone()
return dict(row) if row else None
# Utility operations
[docs]
def cleanup_old_data(self) -> None:
"""Clean up old data according to retention policies"""
cleanup_sql = DatabaseSchema.get_cleanup_sql()
with self.get_connection() as conn:
for sql in cleanup_sql:
if sql.strip():
conn.execute(sql)
conn.commit()
logger.info("Database cleanup completed")
[docs]
def get_database_stats(self) -> Dict[str, Any]:
"""Get database statistics"""
with self.get_connection() as conn:
stats = {}
# Table row counts
tables = [
"simulation_runs",
"agent_snapshots",
"world_snapshots",
"action_logs",
"trade_logs",
"combat_logs",
"analytics",
]
for table in tables:
cursor = conn.execute(f"SELECT COUNT(*) FROM {table}")
stats[f"{table}_count"] = cursor.fetchone()[0]
# Database size
cursor = conn.execute("PRAGMA page_count")
page_count = cursor.fetchone()[0]
cursor = conn.execute("PRAGMA page_size")
page_size = cursor.fetchone()[0]
stats["database_size_bytes"] = page_count * page_size
return stats
[docs]
def execute_custom_query(
self, query: str, params: Tuple = ()
) -> List[Dict[str, Any]]:
"""Execute custom SQL query (use with caution)"""
with self.get_connection() as conn:
cursor = conn.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
[docs]
def close(self) -> None:
"""Close database connection"""
if self.connection:
self.connection.close()
self.connection = None