"""Tests for database connection and session management."""

import os
import tempfile

import pytest
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError

from src.data.database import get_db_session, get_engine, init_database


@pytest.fixture
def temp_db(tmp_path, monkeypatch):
    """Point ``DATABASE_URL`` at a fresh, not-yet-created SQLite file.

    The database file is deliberately *not* created here. ``tmp_path`` is a
    unique empty directory, so any file that later appears at the returned
    path was created by the code under test. ``monkeypatch.setenv`` restores
    any pre-existing ``DATABASE_URL`` on teardown instead of deleting it, and
    pytest removes ``tmp_path`` itself.

    Returns:
        Absolute filesystem path (``str``) of the SQLite database file.
    """
    db_path = str(tmp_path / "test.db")
    monkeypatch.setenv("DATABASE_URL", f"sqlite:///{db_path}")
    # NOTE(review): assumes get_engine()/init_database() read DATABASE_URL
    # on each call rather than caching an engine across tests — confirm.
    return db_path


def test_get_engine(temp_db):
    """get_engine() builds a SQLite engine for the configured DATABASE_URL."""
    engine = get_engine()
    assert engine is not None
    assert str(engine.url).startswith("sqlite")
    # The engine must target the fixture's file, not some default database.
    assert engine.url.database == temp_db


def test_init_database(temp_db):
    """init_database() creates the SQLite database file."""
    # Guard against a vacuous pass: the fixture must not pre-create the file.
    assert not os.path.exists(temp_db)
    init_database(create_tables=True)
    assert os.path.exists(temp_db)


def test_get_db_session(temp_db):
    """get_db_session() yields a usable session."""
    init_database(create_tables=True)
    with get_db_session() as session:
        assert session is not None
        result = session.execute(text("SELECT 1")).scalar()
        assert result == 1


def test_session_rollback_on_error(temp_db):
    """A failing statement propagates and leaves no open transaction behind."""
    init_database(create_tables=True)

    # The error must surface to the caller rather than being swallowed.
    with pytest.raises(SQLAlchemyError), get_db_session() as session:
        session.execute(text("SELECT * FROM nonexistent_table"))

    # The failed transaction must have been rolled back (or the session closed).
    assert not session.in_transaction()

    # The failure must not poison later sessions.
    with get_db_session() as fresh_session:
        assert fresh_session.execute(text("SELECT 1")).scalar() == 1