70 lines
1.7 KiB
Python
70 lines
1.7 KiB
Python
"""Tests for database connection and session management."""
|
|
|
|
import os
|
|
import tempfile
|
|
|
|
import pytest
|
|
|
|
from src.data.database import get_db_session, get_engine, init_database
|
|
|
|
|
|
@pytest.fixture
def temp_db(monkeypatch):
    """Provide a throwaway SQLite database file and point DATABASE_URL at it.

    Creates an empty ``.db`` file in the system temp directory and sets the
    ``DATABASE_URL`` environment variable to a ``sqlite:///`` URL for that
    file. Yields the file path as a ``str``.

    On teardown the file is removed. ``DATABASE_URL`` is restored to the value
    it had before the test, or unset again if it was not previously set.
    """
    # delete=False: the handle is closed when the ``with`` exits, so SQLite
    # can open the file on its own. We remove the file ourselves in teardown.
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name

    # monkeypatch undoes this on teardown (restoring any pre-existing value)
    # instead of unconditionally deleting a variable the developer's
    # environment may rely on.
    monkeypatch.setenv("DATABASE_URL", f"sqlite:///{db_path}")

    yield db_path

    # NOTE(review): if get_engine() caches an engine or keeps pooled
    # connections open, the file may still be held open here (unlink fails
    # on Windows). Confirm, and dispose the engine first if so.
    try:
        os.unlink(db_path)
    except FileNotFoundError:
        pass  # Already gone (e.g. removed by the test); nothing to clean up.
|
|
|
|
|
|
def test_get_engine(temp_db):
    """get_engine() should return an engine whose URL uses the sqlite dialect."""
    created = get_engine()
    assert created is not None
    # temp_db points DATABASE_URL at a sqlite:/// file, so the URL must match.
    url_text = str(created.url)
    assert url_text.startswith("sqlite")
|
|
|
|
|
|
def test_init_database(temp_db):
    """init_database(create_tables=True) should run and leave the DB file present."""
    init_database(create_tables=True)
    # NOTE(review): temp_db already creates this file before the test runs,
    # so this check passes even if init_database never touches it. Consider
    # a stronger assertion (e.g. inspecting the created tables).
    db_file_present = os.path.exists(temp_db)
    assert db_file_present
|
|
|
|
|
|
def test_get_db_session(temp_db):
    """get_db_session() should yield a session that can execute a trivial query."""
    from sqlalchemy import text as sql_text

    init_database(create_tables=True)

    with get_db_session() as db_session:
        assert db_session is not None
        # Round-trip a constant query to show the session can actually execute SQL.
        probe = sql_text("SELECT 1")
        scalar_value = db_session.execute(probe).scalar()
        assert scalar_value == 1
|
|
|
|
|
|
def test_session_rollback_on_error(temp_db):
    """A failing statement inside get_db_session() must surface and not break the DB.

    Checks two things:
    1. The SQL error raised inside the ``with`` block propagates out of the
       context manager rather than being swallowed.
    2. A fresh session opened afterwards can still execute queries.
    """
    from sqlalchemy import text
    from sqlalchemy.exc import SQLAlchemyError

    init_database(create_tables=True)

    # Querying a missing table raises inside the session. The error must
    # escape the context manager so the caller sees the failure. (Previously
    # any outcome passed: exceptions were swallowed and the test ended in
    # ``assert True``.)
    with pytest.raises(SQLAlchemyError):
        with get_db_session() as session:
            session.execute(text("SELECT * FROM nonexistent_table"))

    # After the failed session has been cleaned up, a new session must still
    # work, i.e. no aborted transaction or broken connection was left behind.
    with get_db_session() as session:
        assert session.execute(text("SELECT 1")).scalar() == 1
|