128 lines
3.5 KiB
Python
128 lines
3.5 KiB
Python
"""Integration tests for database operations."""
|
|
|
|
import os
|
|
import tempfile
|
|
|
|
import pytest
|
|
|
|
from src.core.enums import Timeframe
|
|
from src.data.database import get_db_session, init_database
|
|
from src.data.models import OHLCVData
|
|
from src.data.repositories import OHLCVRepository
|
|
|
|
|
|
@pytest.fixture
def temp_db():
    """Provide a throwaway SQLite database file for a single test.

    Creates an empty temporary ``.db`` file, points the ``DATABASE_URL``
    environment variable at it, and calls ``init_database(create_tables=True)``.
    NOTE(review): presumably ``init_database`` reads ``DATABASE_URL`` to pick
    the target database; confirm against ``src.data.database``.

    Yields:
        str: Filesystem path of the temporary database file.

    Teardown always runs, including when setup itself fails: the previous
    ``DATABASE_URL`` value is restored (or the variable is removed if it was
    not set before) and the temporary file is deleted.
    """
    # Remember any pre-existing value so teardown can restore it instead of
    # unconditionally deleting a variable the surrounding process relied on.
    previous_url = os.environ.get("DATABASE_URL")

    # delete=False: only the path is needed; the file must outlive this
    # `with` block so the database layer can open it by name.
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name

    try:
        os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"

        # Initialize database
        init_database(create_tables=True)

        yield db_path
    finally:
        # Restore the environment first so a failing unlink below cannot
        # leave DATABASE_URL pointing at this test's database.
        if previous_url is None:
            os.environ.pop("DATABASE_URL", None)
        else:
            os.environ["DATABASE_URL"] = previous_url

        if os.path.exists(db_path):
            os.unlink(db_path)
|
|
|
|
|
|
def test_create_and_retrieve_ohlcv(temp_db):
    """Insert one OHLCV row through the repository and read it back by id."""
    from datetime import datetime

    # Field values for the single bar under test.
    bar_fields = {
        "symbol": "DAX",
        "timeframe": Timeframe.M1,
        "timestamp": datetime(2024, 1, 1, 3, 0, 0),
        "open": 100.0,
        "high": 100.5,
        "low": 99.5,
        "close": 100.2,
        "volume": 1000,
    }

    with get_db_session() as session:
        repo = OHLCVRepository(session=session)

        # Persist the bar; the returned object must carry an id.
        stored = repo.create(OHLCVData(**bar_fields))
        stored_id = stored.id
        assert stored_id is not None

        # Look the same bar up again by that id.
        fetched = repo.get_by_id(stored_id)
        assert fetched is not None
        assert fetched.symbol == "DAX"
        assert fetched.close == 100.2
|
|
|
|
|
|
def test_batch_create_ohlcv(temp_db):
    """Insert ten consecutive one-minute bars in one batch and count them."""
    from datetime import datetime, timedelta

    first_ts = datetime(2024, 1, 1, 3, 0, 0)

    def make_bar(minute):
        # Each bar is one minute later and 0.1 higher than the previous one.
        drift = minute * 0.1
        return OHLCVData(
            symbol="DAX",
            timeframe=Timeframe.M1,
            timestamp=first_ts + timedelta(minutes=minute),
            open=100.0 + drift,
            high=100.5 + drift,
            low=99.5 + drift,
            close=100.2 + drift,
            volume=1000,
        )

    with get_db_session() as session:
        repo = OHLCVRepository(session=session)

        # Build and store the whole batch in a single repository call.
        batch = [make_bar(minute) for minute in range(10)]
        inserted = repo.create_batch(batch)
        assert len(inserted) == 10

        # Every bar of the batch must be returned by a range query that
        # spans the batch's timestamps.
        window_end = first_ts + timedelta(minutes=10)
        found = repo.get_by_timestamp_range(
            "DAX", Timeframe.M1, first_ts, window_end
        )
        assert len(found) == 10
|
|
|
|
|
|
def test_get_by_timestamp_range(temp_db):
    """Query a sub-window of stored bars and check how many come back."""
    from datetime import datetime, timedelta

    first_ts = datetime(2024, 1, 1, 3, 0, 0)

    with get_db_session() as session:
        repo = OHLCVRepository(session=session)

        # Seed twenty identical-priced bars, one per minute from first_ts.
        for minute in range(20):
            repo.create(
                OHLCVData(
                    symbol="DAX",
                    timeframe=Timeframe.M1,
                    timestamp=first_ts + timedelta(minutes=minute),
                    open=100.0,
                    high=100.5,
                    low=99.5,
                    close=100.2,
                    volume=1000,
                )
            )

        # Ask only for minutes 5 through 15 of the seeded data.
        window_start = first_ts + timedelta(minutes=5)
        window_end = first_ts + timedelta(minutes=15)
        in_window = repo.get_by_timestamp_range(
            "DAX", Timeframe.M1, window_start, window_end
        )

        # Both window edges count, so minutes 5..15 give eleven bars.
        assert len(in_window) == 11
|