Files
dax-ml/tests/unit/test_data/test_database.py

70 lines
1.7 KiB
Python

"""Tests for database connection and session management."""
import os
import tempfile
import pytest
from src.data.database import get_db_session, get_engine, init_database
@pytest.fixture
def temp_db(monkeypatch):
    """Create a temporary SQLite database file and point DATABASE_URL at it.

    Yields:
        The filesystem path of the temporary ``.db`` file.

    ``monkeypatch.setenv`` restores whatever DATABASE_URL held before the
    test (or removes it if it was unset). A pre-existing value from the
    developer's shell or CI is therefore not clobbered, as it would be by an
    unconditional ``del os.environ[...]``.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name

    # NOTE(review): if get_engine() caches its engine, changing the env var
    # here may not take effect for later tests; confirm in src.data.database.
    monkeypatch.setenv("DATABASE_URL", f"sqlite:///{db_path}")

    yield db_path

    # Teardown: remove the file; it may already be gone, which is fine.
    try:
        os.unlink(db_path)
    except FileNotFoundError:
        pass
def test_get_engine(temp_db):
    """Check that ``get_engine`` returns an engine whose URL is a SQLite URL."""
    db_engine = get_engine()
    assert db_engine is not None

    url_text = str(db_engine.url)
    assert url_text.startswith("sqlite")
def test_init_database(temp_db):
    """Test that ``init_database(create_tables=True)`` writes a schema to the file.

    The ``temp_db`` fixture already creates the file on disk, so checking
    existence alone cannot fail. Instead, read the SQLite catalog directly
    (bypassing the project's engine) to confirm tables landed in this file.
    """
    import sqlite3

    init_database(create_tables=True)

    assert os.path.exists(temp_db)

    # NOTE(review): assumes the ORM metadata defines at least one table — confirm.
    conn = sqlite3.connect(temp_db)
    try:
        tables = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table'"
        ).fetchall()
    finally:
        # sqlite3's connection context manager does not close; close explicitly.
        conn.close()
    assert tables, "init_database(create_tables=True) created no tables"
def test_get_db_session(temp_db):
    """Check that ``get_db_session`` yields a working session.

    After the schema is initialized, the yielded session must be non-None
    and able to execute a trivial ``SELECT 1`` query.
    """
    from sqlalchemy import text

    init_database(create_tables=True)

    probe = text("SELECT 1")
    with get_db_session() as db_session:
        assert db_session is not None
        # A live session can round-trip a constant query.
        scalar_value = db_session.execute(probe).scalar()
        assert scalar_value == 1
def test_session_rollback_on_error(temp_db):
    """Test that a failing statement escapes the session and leaves the DB usable.

    Swallowing every exception and ending with ``assert True`` would make this
    test unable to fail. Instead, assert that the error actually propagates
    out of ``get_db_session``, then verify that a fresh session still works
    after the failed one was cleaned up.
    """
    from sqlalchemy import text

    init_database(create_tables=True)

    # The query error must propagate; `match` pins it to the bad table so an
    # unrelated failure does not satisfy the check.
    with pytest.raises(Exception, match="nonexistent_table"):
        with get_db_session() as session:
            session.execute(text("SELECT * FROM nonexistent_table"))

    # A new session must be usable after the failed one.
    with get_db_session() as session:
        assert session.execute(text("SELECT 1")).scalar() == 1