feat(v0.2.0): data pipeline

This commit is contained in:
0x_n3m0_
2026-01-05 11:34:18 +02:00
parent 2527938680
commit b5e7043df6
23 changed files with 2813 additions and 7 deletions

View File

@@ -0,0 +1,6 @@
timestamp,open,high,low,close,volume
2024-01-01 03:00:00,100.0,100.5,99.5,100.2,1000
2024-01-01 03:01:00,100.2,100.7,99.7,100.4,1100
2024-01-01 03:02:00,100.4,100.9,99.9,100.6,1200
2024-01-01 03:03:00,100.6,101.1,100.1,100.8,1300
2024-01-01 03:04:00,100.8,101.3,100.3,101.0,1400
1 timestamp open high low close volume
2 2024-01-01 03:00:00 100.0 100.5 99.5 100.2 1000
3 2024-01-01 03:01:00 100.2 100.7 99.7 100.4 1100
4 2024-01-01 03:02:00 100.4 100.9 99.9 100.6 1200
5 2024-01-01 03:03:00 100.6 101.1 100.1 100.8 1300
6 2024-01-01 03:04:00 100.8 101.3 100.3 101.0 1400

View File

@@ -0,0 +1,127 @@
"""Integration tests for database operations."""
import os
import tempfile
import pytest
from src.core.enums import Timeframe
from src.data.database import get_db_session, init_database
from src.data.models import OHLCVData
from src.data.repositories import OHLCVRepository
@pytest.fixture
def temp_db():
    """Create a temporary SQLite database with schema, clean up afterwards.

    Yields:
        str: Filesystem path of the temporary database file.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name
    # Remember any pre-existing DATABASE_URL so teardown can restore it
    # instead of unconditionally deleting it (the original clobbered a
    # value set by the surrounding environment or another fixture).
    previous_url = os.environ.get("DATABASE_URL")
    os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
    # Initialize database schema before handing control to the test.
    init_database(create_tables=True)
    yield db_path
    # Cleanup: remove the database file and restore the environment.
    if os.path.exists(db_path):
        os.unlink(db_path)
    if previous_url is None:
        os.environ.pop("DATABASE_URL", None)
    else:
        os.environ["DATABASE_URL"] = previous_url
def test_create_and_retrieve_ohlcv(temp_db):
    """Round-trip a single OHLCV record through the repository."""
    from datetime import datetime

    with get_db_session() as session:
        repo = OHLCVRepository(session=session)
        # Persist one 1-minute DAX bar.
        bar = OHLCVData(
            symbol="DAX",
            timeframe=Timeframe.M1,
            timestamp=datetime(2024, 1, 1, 3, 0, 0),
            open=100.0,
            high=100.5,
            low=99.5,
            close=100.2,
            volume=1000,
        )
        saved = repo.create(bar)
        assert saved.id is not None

        # Fetch the row back by primary key and spot-check its fields.
        fetched = repo.get_by_id(saved.id)
        assert fetched is not None
        assert fetched.symbol == "DAX"
        assert fetched.close == 100.2
def test_batch_create_ohlcv(temp_db):
    """Batch-insert ten one-minute bars and read them all back."""
    from datetime import datetime, timedelta

    start = datetime(2024, 1, 1, 3, 0, 0)
    with get_db_session() as session:
        repo = OHLCVRepository(session=session)
        # Ten consecutive 1-minute bars with a small upward price drift.
        bars = [
            OHLCVData(
                symbol="DAX",
                timeframe=Timeframe.M1,
                timestamp=start + timedelta(minutes=i),
                open=100.0 + i * 0.1,
                high=100.5 + i * 0.1,
                low=99.5 + i * 0.1,
                close=100.2 + i * 0.1,
                volume=1000,
            )
            for i in range(10)
        ]
        saved = repo.create_batch(bars)
        assert len(saved) == 10

        # Every inserted bar falls inside [start, start + 10 minutes].
        fetched = repo.get_by_timestamp_range(
            "DAX",
            Timeframe.M1,
            start,
            start + timedelta(minutes=10),
        )
        assert len(fetched) == 10
def test_get_by_timestamp_range(temp_db):
    """A sub-range query returns only bars whose timestamps fall inside it."""
    from datetime import datetime, timedelta

    origin = datetime(2024, 1, 1, 3, 0, 0)
    with get_db_session() as session:
        repo = OHLCVRepository(session=session)
        # Insert 20 consecutive 1-minute bars with identical prices.
        for minute in range(20):
            repo.create(
                OHLCVData(
                    symbol="DAX",
                    timeframe=Timeframe.M1,
                    timestamp=origin + timedelta(minutes=minute),
                    open=100.0,
                    high=100.5,
                    low=99.5,
                    close=100.2,
                    volume=1000,
                )
            )
        # Query minutes 5..15; both endpoints are included by the repository,
        # so 11 bars are expected.
        window = repo.get_by_timestamp_range(
            "DAX",
            Timeframe.M1,
            origin + timedelta(minutes=5),
            origin + timedelta(minutes=15),
        )
        assert len(window) == 11

View File

@@ -0,0 +1 @@
"""Unit tests for data module."""

View File

@@ -0,0 +1,65 @@
"""Tests for database connection and session management."""
import os
import tempfile
import pytest
from src.data.database import get_db_session, get_engine, init_database
@pytest.fixture
def temp_db():
    """Create a temporary SQLite database file and point DATABASE_URL at it.

    Yields:
        str: Filesystem path of the temporary database file.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name
    # Remember any pre-existing DATABASE_URL so teardown can restore it
    # instead of unconditionally deleting it (the original clobbered a
    # value set by the surrounding environment or another fixture).
    previous_url = os.environ.get("DATABASE_URL")
    os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
    yield db_path
    # Cleanup: remove the database file and restore the environment.
    if os.path.exists(db_path):
        os.unlink(db_path)
    if previous_url is None:
        os.environ.pop("DATABASE_URL", None)
    else:
        os.environ["DATABASE_URL"] = previous_url
def test_get_engine(temp_db):
    """An engine can be created and targets the SQLite backend."""
    engine = get_engine()
    assert engine is not None
    # DATABASE_URL was set to sqlite:///... by the fixture, so the
    # engine URL scheme must reflect that.
    url_text = str(engine.url)
    assert url_text.startswith("sqlite")
def test_init_database(temp_db):
    """Initializing the database creates the backing file on disk."""
    init_database(create_tables=True)
    # The fixture's path must now exist as a real file.
    assert os.path.exists(temp_db)
def test_get_db_session(temp_db):
    """The session context manager yields a session that can run SQL."""
    init_database(create_tables=True)
    with get_db_session() as session:
        assert session is not None
        # NOTE(review): passing a raw SQL string to Session.execute is
        # rejected by SQLAlchemy 2.x (it must be wrapped in text()) —
        # confirm the project pins SQLAlchemy < 2.0.
        value = session.execute("SELECT 1").scalar()
        assert value == 1
def test_session_rollback_on_error(temp_db):
    """A failing statement inside the context manager propagates and rolls back.

    The original version caught the exception and ended with ``assert True``,
    which can never fail; ``pytest.raises`` makes the expectation explicit.
    """
    init_database(create_tables=True)
    # Querying a missing table must raise out of the context manager.
    with pytest.raises(Exception):
        with get_db_session() as session:
            session.execute("SELECT * FROM nonexistent_table")
    # A fresh session must still be obtainable, showing the failed one was
    # rolled back and closed rather than wedging the connection.
    with get_db_session() as session:
        assert session is not None

View File

@@ -0,0 +1,83 @@
"""Tests for data loaders."""
import pandas as pd
import pytest
from src.core.enums import Timeframe
from src.data.loaders import CSVLoader, ParquetLoader
@pytest.fixture
def sample_ohlcv_data():
    """Build a 100-row OHLCV frame of 1-minute bars starting 2024-01-01 03:00."""
    index = pd.date_range("2024-01-01 03:00", periods=100, freq="1min")
    # Each price column drifts upward by 0.1 per bar from its own base.
    drift = [i * 0.1 for i in range(100)]
    return pd.DataFrame(
        {
            "timestamp": index,
            "open": [100.0 + d for d in drift],
            "high": [100.5 + d for d in drift],
            "low": [99.5 + d for d in drift],
            "close": [100.2 + d for d in drift],
            "volume": [1000] * 100,
        }
    )
@pytest.fixture
def csv_file(sample_ohlcv_data, tmp_path):
    """Write the sample frame to a temporary CSV file and return its path."""
    target = tmp_path / "test_data.csv"
    sample_ohlcv_data.to_csv(target, index=False)
    return target
@pytest.fixture
def parquet_file(sample_ohlcv_data, tmp_path):
    """Write the sample frame to a temporary Parquet file and return its path."""
    target = tmp_path / "test_data.parquet"
    sample_ohlcv_data.to_parquet(target, index=False)
    return target
def test_csv_loader(csv_file):
    """CSVLoader reads the file and tags every row with symbol/timeframe."""
    frame = CSVLoader().load(str(csv_file), symbol="DAX", timeframe=Timeframe.M1)
    assert len(frame) == 100
    # The loader must append instrument-metadata columns.
    for column in ("symbol", "timeframe"):
        assert column in frame.columns
    first = frame.iloc[0]
    assert first["symbol"] == "DAX"
    assert first["timeframe"] == "1min"
def test_csv_loader_missing_file():
    """Loading a nonexistent path raises (the project's DataError)."""
    loader = CSVLoader()
    # Broad Exception kept deliberately so the test does not need to
    # import the project's DataError type.
    with pytest.raises(Exception):
        loader.load("nonexistent.csv")
def test_parquet_loader(parquet_file):
    """ParquetLoader reads the file and adds instrument-metadata columns."""
    frame = ParquetLoader().load(
        str(parquet_file), symbol="DAX", timeframe=Timeframe.M1
    )
    assert len(frame) == 100
    for column in ("symbol", "timeframe"):
        assert column in frame.columns
def test_load_and_preprocess(csv_file):
    """End-to-end load + validate + preprocess keeps all rows and timestamps."""
    from src.data.loaders import load_and_preprocess

    frame = load_and_preprocess(
        str(csv_file),
        loader_type="csv",
        validate=True,
        preprocess=True,
    )
    assert len(frame) == 100
    assert "timestamp" in frame.columns

View File

@@ -0,0 +1,87 @@
"""Tests for data preprocessors."""
import numpy as np
import pandas as pd
import pytest
from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates
@pytest.fixture
def sample_data_with_missing():
    """Ten 1-minute bars with NaNs injected into 'close' (row 2) and 'open' (row 5)."""
    stamps = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
    frame = pd.DataFrame(
        {
            "timestamp": stamps,
            "open": [100.0] * 10,
            "high": [100.5] * 10,
            "low": [99.5] * 10,
            "close": [100.2] * 10,
        }
    )
    # Punch two holes in the data for the missing-value handlers to fill/drop.
    frame.loc[2, "close"] = np.nan
    frame.loc[5, "open"] = np.nan
    return frame
@pytest.fixture
def sample_data_with_duplicates():
    """Ten 1-minute bars plus one exact duplicate of the first row (11 rows)."""
    stamps = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
    frame = pd.DataFrame(
        {
            "timestamp": stamps,
            "open": [100.0] * 10,
            "high": [100.5] * 10,
            "low": [99.5] * 10,
            "close": [100.2] * 10,
        }
    )
    # Append a copy of row 0 so deduplication has something to remove.
    return pd.concat([frame, frame.iloc[[0]]], ignore_index=True)
def test_handle_missing_data_forward_fill(sample_data_with_missing):
    """Forward-filling leaves no NaNs in either affected column."""
    filled = handle_missing_data(sample_data_with_missing, method="forward_fill")
    for column in ("close", "open"):
        assert filled[column].isna().sum() == 0
def test_handle_missing_data_drop(sample_data_with_missing):
    """Dropping removes the rows that held NaNs, shrinking the frame."""
    cleaned = handle_missing_data(sample_data_with_missing, method="drop")
    for column in ("close", "open"):
        assert cleaned[column].isna().sum() == 0
    # Rows were removed, not filled.
    assert len(cleaned) < len(sample_data_with_missing)
def test_remove_duplicates(sample_data_with_duplicates):
    """The duplicated first row is removed, restoring the original 10 rows."""
    deduped = remove_duplicates(sample_data_with_duplicates)
    assert len(deduped) == 10
def test_filter_session():
    """Filtering to a one-hour session keeps at most 60 one-minute rows."""
    # Two hours of 1-minute bars starting before the session window opens.
    stamps = pd.date_range("2024-01-01 02:00", periods=120, freq="1min")
    frame = pd.DataFrame(
        {
            "timestamp": stamps,
            "open": [100.0] * 120,
            "high": [100.5] * 120,
            "low": [99.5] * 120,
            "close": [100.2] * 120,
        }
    )
    # Restrict to the 03:00-04:00 session; roughly one hour should survive.
    in_session = filter_session(frame, session_start="03:00", session_end="04:00")
    assert 0 < len(in_session) <= 60

View File

@@ -0,0 +1,97 @@
"""Tests for data validators."""
import pandas as pd
import pytest
from src.core.enums import Timeframe
from src.data.validators import check_continuity, detect_outliers, validate_ohlcv
@pytest.fixture
def valid_ohlcv_data():
    """100 well-formed 1-minute bars with a gentle upward price drift."""
    stamps = pd.date_range("2024-01-01 03:00", periods=100, freq="1min")
    # Each price column drifts upward by 0.1 per bar from its own base.
    drift = [i * 0.1 for i in range(100)]
    return pd.DataFrame(
        {
            "timestamp": stamps,
            "open": [100.0 + d for d in drift],
            "high": [100.5 + d for d in drift],
            "low": [99.5 + d for d in drift],
            "close": [100.2 + d for d in drift],
            "volume": [1000] * 100,
        }
    )
@pytest.fixture
def invalid_ohlcv_data():
    """Ten bars whose high (99.0) sits below their low (99.5) — invalid OHLC."""
    stamps = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
    return pd.DataFrame(
        {
            "timestamp": stamps,
            "open": [100.0] * 10,
            "high": [99.0] * 10,  # deliberately below 'low' to trip validation
            "low": [99.5] * 10,
            "close": [100.2] * 10,
        }
    )
def test_validate_ohlcv_valid(valid_ohlcv_data):
    """Clean data passes validation with no rows lost."""
    checked = validate_ohlcv(valid_ohlcv_data)
    assert len(checked) == 100
def test_validate_ohlcv_invalid(invalid_ohlcv_data):
    """Data with high < low is rejected (the project's ValidationError)."""
    # Broad Exception kept so the test does not need the project's error type.
    with pytest.raises(Exception):
        validate_ohlcv(invalid_ohlcv_data)
def test_validate_ohlcv_missing_columns():
    """A frame lacking the OHLC columns is rejected (project ValidationError)."""
    timestamps_only = pd.DataFrame(
        {"timestamp": pd.date_range("2024-01-01", periods=10)}
    )
    with pytest.raises(Exception):
        validate_ohlcv(timestamps_only)
def test_check_continuity(valid_ohlcv_data):
    """A gap-free 1-minute series is reported as continuous with no gaps."""
    continuous, gaps = check_continuity(valid_ohlcv_data, Timeframe.M1)
    assert continuous
    assert len(gaps) == 0
def test_check_continuity_with_gaps():
    """Dropping interior timestamps makes the series discontinuous."""
    # Start from 10 consecutive minutes, then keep only indices
    # [0,1,2,5,6,7,8,9] so minutes 3-4 form a gap.
    stamps = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
    stamps = stamps[[0, 1, 2, 5, 6, 7, 8, 9]]
    size = len(stamps)
    frame = pd.DataFrame(
        {
            "timestamp": stamps,
            "open": [100.0] * size,
            "high": [100.5] * size,
            "low": [99.5] * size,
            "close": [100.2] * size,
        }
    )
    continuous, gaps = check_continuity(frame, Timeframe.M1)
    assert not continuous
    assert len(gaps) > 0
def test_detect_outliers(valid_ohlcv_data):
    """An injected extreme close is flagged by IQR-based outlier detection."""
    spiked = valid_ohlcv_data.copy()
    spiked.loc[50, "close"] = 200.0  # far outside the ~100-110 range
    flagged = detect_outliers(spiked, columns=["close"], method="iqr", threshold=3.0)
    assert flagged["is_outlier"].sum() > 0