Files
dax-ml/tests/integration/test_database.py
2026-01-05 11:34:18 +02:00

128 lines
3.5 KiB
Python

"""Integration tests for database operations."""
import os
import tempfile
import pytest
from src.core.enums import Timeframe
from src.data.database import get_db_session, init_database
from src.data.models import OHLCVData
from src.data.repositories import OHLCVRepository
@pytest.fixture
def temp_db(monkeypatch):
    """Point DATABASE_URL at a fresh on-disk SQLite file and create the schema.

    Yields:
        The filesystem path (str) of the temporary SQLite database file.

    ``DATABASE_URL`` is set through ``monkeypatch`` so that whatever value the
    variable had before the test (or its absence) is restored afterwards,
    instead of being deleted unconditionally. The database file is removed in
    a ``finally`` block, so it is cleaned up even if schema creation fails.
    """
    # delete=False: we only want a unique path here; the file must survive
    # the ``with`` block so SQLite can open it afterwards.
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name
    try:
        monkeypatch.setenv("DATABASE_URL", f"sqlite:///{db_path}")
        # NOTE(review): assumes init_database()/get_db_session() read
        # DATABASE_URL at call time rather than reusing an engine created by
        # an earlier test — confirm.
        init_database(create_tables=True)
        yield db_path
    finally:
        if os.path.exists(db_path):
            os.unlink(db_path)
def test_create_and_retrieve_ohlcv(temp_db):
    """Persist one OHLCV bar through the repository and read it back by id."""
    from datetime import datetime

    with get_db_session() as session:
        ohlcv_repo = OHLCVRepository(session=session)

        bar = OHLCVData(
            symbol="DAX",
            timeframe=Timeframe.M1,
            timestamp=datetime(2024, 1, 1, hour=3),
            open=100.0,
            high=100.5,
            low=99.5,
            close=100.2,
            volume=1000,
        )
        saved = ohlcv_repo.create(bar)
        new_id = saved.id
        assert new_id is not None

        # Fetch the row again via the primary key assigned on creation.
        fetched = ohlcv_repo.get_by_id(new_id)
        assert fetched is not None
        assert fetched.symbol == "DAX"
        assert fetched.close == 100.2
def test_batch_create_ohlcv(temp_db):
    """Insert ten consecutive one-minute bars in one call and query them back."""
    from datetime import datetime, timedelta

    start_time = datetime(2024, 1, 1, hour=3)
    bar_count = 10

    with get_db_session() as session:
        ohlcv_repo = OHLCVRepository(session=session)

        # Prices rise by 0.1 per minute so the bars are not identical.
        bars = [
            OHLCVData(
                symbol="DAX",
                timeframe=Timeframe.M1,
                timestamp=start_time + timedelta(minutes=minute),
                open=100.0 + minute * 0.1,
                high=100.5 + minute * 0.1,
                low=99.5 + minute * 0.1,
                close=100.2 + minute * 0.1,
                volume=1000,
            )
            for minute in range(bar_count)
        ]
        inserted = ohlcv_repo.create_batch(bars)
        assert len(inserted) == bar_count

        # The query window runs from minute 0 to minute 10, which covers
        # every inserted bar (minutes 0-9).
        window_end = start_time + timedelta(minutes=bar_count)
        found = ohlcv_repo.get_by_timestamp_range(
            "DAX",
            Timeframe.M1,
            start_time,
            window_end,
        )
        assert len(found) == bar_count
def test_get_by_timestamp_range(temp_db):
    """Query a sub-window of 20 one-minute bars and check its size."""
    from datetime import datetime, timedelta

    origin = datetime(2024, 1, 1, hour=3)
    one_minute = timedelta(minutes=1)

    with get_db_session() as session:
        ohlcv_repo = OHLCVRepository(session=session)

        # Insert bars one at a time at minutes 0..19, all with the same prices.
        for offset in range(20):
            ohlcv_repo.create(
                OHLCVData(
                    symbol="DAX",
                    timeframe=Timeframe.M1,
                    timestamp=origin + offset * one_minute,
                    open=100.0,
                    high=100.5,
                    low=99.5,
                    close=100.2,
                    volume=1000,
                )
            )

        window_start = origin + 5 * one_minute
        window_end = origin + 15 * one_minute
        hits = ohlcv_repo.get_by_timestamp_range(
            "DAX", Timeframe.M1, window_start, window_end
        )
        # Minutes 5..15 with both ends included -> 11 bars
        # (a half-open window would yield 10).
        assert len(hits) == 11