"""Integration tests for database operations."""

import os
import tempfile
from datetime import datetime, timedelta

import pytest

from src.core.enums import Timeframe
from src.data.database import get_db_session, init_database
from src.data.models import OHLCVData
from src.data.repositories import OHLCVRepository


@pytest.fixture
def temp_db(monkeypatch):
    """Point the application at a fresh temporary SQLite database file.

    ``DATABASE_URL`` is set through ``monkeypatch``, so whatever value (or
    absence) it had before the test is restored afterwards. Tables are
    created with ``init_database(create_tables=True)``.

    Yields:
        Filesystem path of the temporary ``.db`` file.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name

    monkeypatch.setenv("DATABASE_URL", f"sqlite:///{db_path}")
    try:
        init_database(create_tables=True)
        yield db_path
    finally:
        # Also runs when init_database() raises, so the temp file never leaks.
        if os.path.exists(db_path):
            os.unlink(db_path)


def test_create_and_retrieve_ohlcv(temp_db):
    """Create a single OHLCV row and read it back by primary key."""
    with get_db_session() as session:
        repo = OHLCVRepository(session=session)

        # Each test writes into its own hour (here 02:00) to avoid timestamp
        # collisions. NOTE(review): presumably needed because rows can be
        # visible across tests; confirm how init_database manages its engine.
        record = OHLCVData(
            symbol="DAX",
            timeframe=Timeframe.M1,
            timestamp=datetime(2024, 1, 1, 2, 0, 0),
            open=100.0,
            high=100.5,
            low=99.5,
            close=100.2,
            volume=1000,
        )
        created = repo.create(record)
        assert created.id is not None

        retrieved = repo.get_by_id(created.id)
        assert retrieved is not None
        assert retrieved.symbol == "DAX"
        assert retrieved.close == 100.2


def test_batch_create_ohlcv(temp_db):
    """Insert ten consecutive one-minute bars in one batch and query them back."""
    with get_db_session() as session:
        repo = OHLCVRepository(session=session)

        base_time = datetime(2024, 1, 1, 3, 0, 0)  # 03:00 hour, unique to this test
        records = [
            OHLCVData(
                symbol="DAX",
                timeframe=Timeframe.M1,
                timestamp=base_time + timedelta(minutes=i),
                open=100.0 + i * 0.1,
                high=100.5 + i * 0.1,
                low=99.5 + i * 0.1,
                close=100.2 + i * 0.1,
                volume=1000,
            )
            for i in range(10)
        ]

        created = repo.create_batch(records)
        assert len(created) == 10

        # 03:00 through 03:09 covers every bar created above (i = 0..9).
        retrieved = repo.get_by_timestamp_range(
            "DAX",
            Timeframe.M1,
            base_time,
            base_time + timedelta(minutes=9),
        )
        assert len(retrieved) == 10


def test_get_by_timestamp_range(temp_db):
    """Query a sub-window of twenty one-minute bars; both bounds are inclusive."""
    with get_db_session() as session:
        repo = OHLCVRepository(session=session)

        base_time = datetime(2024, 1, 1, 4, 0, 0)  # 04:00 hour, unique to this test
        for i in range(20):
            repo.create(
                OHLCVData(
                    symbol="DAX",
                    timeframe=Timeframe.M1,
                    timestamp=base_time + timedelta(minutes=i),
                    open=100.0,
                    high=100.5,
                    low=99.5,
                    close=100.2,
                    volume=1000,
                )
            )

        # Minutes 5..15 inclusive -> 11 bars.
        start = base_time + timedelta(minutes=5)
        end = base_time + timedelta(minutes=15)
        records = repo.get_by_timestamp_range("DAX", Timeframe.M1, start, end)
        assert len(records) == 11