feat(v0.2.0): data pipeline

Author: 0x_n3m0_
Date: 2026-01-05 11:34:18 +02:00
parent 2527938680
commit b5e7043df6
23 changed files with 2813 additions and 7 deletions


@@ -0,0 +1 @@
"""Unit tests for data module."""


@@ -0,0 +1,65 @@
"""Tests for database connection and session management."""
import os
import tempfile

import pytest

from src.data.database import get_db_session, get_engine, init_database


@pytest.fixture
def temp_db():
    """Create temporary database for testing."""
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name
    # Set environment variable
    os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
    yield db_path
    # Cleanup
    if os.path.exists(db_path):
        os.unlink(db_path)
    if "DATABASE_URL" in os.environ:
        del os.environ["DATABASE_URL"]


def test_get_engine(temp_db):
    """Test engine creation."""
    engine = get_engine()
    assert engine is not None
    assert str(engine.url).startswith("sqlite")


def test_init_database(temp_db):
    """Test database initialization."""
    init_database(create_tables=True)
    assert os.path.exists(temp_db)


def test_get_db_session(temp_db):
    """Test database session context manager."""
    init_database(create_tables=True)
    with get_db_session() as session:
        assert session is not None
        # Session should be usable
        result = session.execute("SELECT 1").scalar()
        assert result == 1


def test_session_rollback_on_error(temp_db):
    """Test that session rolls back on error."""
    init_database(create_tables=True)
    try:
        with get_db_session() as session:
            # Cause an error
            session.execute("SELECT * FROM nonexistent_table")
    except Exception:
        pass  # Expected
    # Session should have been rolled back and closed
    assert True  # If we get here, rollback worked
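
For reference, a minimal sketch of the session contract these tests exercise, assuming the module is built on SQLAlchemy (as the engine/session vocabulary suggests). The names mirror the imports above, but the bodies are assumptions, not the actual src/data/database.py:

# Hedged sketch only: assumes SQLAlchemy and the DATABASE_URL environment
# variable set by the temp_db fixture. The real module may differ.
import os
from contextlib import contextmanager

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


def get_engine():
    # Build an engine from DATABASE_URL; the sqlite fallback is an assumption.
    return create_engine(os.environ.get("DATABASE_URL", "sqlite:///data.db"))


@contextmanager
def get_db_session():
    # Yield a session, commit on success, roll back on error, always close.
    session = sessionmaker(bind=get_engine())()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()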


@@ -0,0 +1,83 @@
"""Tests for data loaders."""
import pandas as pd
import pytest

from src.core.enums import Timeframe
from src.data.loaders import CSVLoader, ParquetLoader


@pytest.fixture
def sample_ohlcv_data():
    """Create sample OHLCV DataFrame."""
    dates = pd.date_range("2024-01-01 03:00", periods=100, freq="1min")
    return pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0 + i * 0.1 for i in range(100)],
            "high": [100.5 + i * 0.1 for i in range(100)],
            "low": [99.5 + i * 0.1 for i in range(100)],
            "close": [100.2 + i * 0.1 for i in range(100)],
            "volume": [1000] * 100,
        }
    )


@pytest.fixture
def csv_file(sample_ohlcv_data, tmp_path):
    """Create temporary CSV file."""
    csv_path = tmp_path / "test_data.csv"
    sample_ohlcv_data.to_csv(csv_path, index=False)
    return csv_path


@pytest.fixture
def parquet_file(sample_ohlcv_data, tmp_path):
    """Create temporary Parquet file."""
    parquet_path = tmp_path / "test_data.parquet"
    sample_ohlcv_data.to_parquet(parquet_path, index=False)
    return parquet_path


def test_csv_loader(csv_file):
    """Test CSV loader."""
    loader = CSVLoader()
    df = loader.load(str(csv_file), symbol="DAX", timeframe=Timeframe.M1)
    assert len(df) == 100
    assert "symbol" in df.columns
    assert "timeframe" in df.columns
    assert df["symbol"].iloc[0] == "DAX"
    assert df["timeframe"].iloc[0] == "1min"


def test_csv_loader_missing_file():
    """Test CSV loader with missing file."""
    loader = CSVLoader()
    with pytest.raises(Exception):  # Should raise DataError
        loader.load("nonexistent.csv")


def test_parquet_loader(parquet_file):
    """Test Parquet loader."""
    loader = ParquetLoader()
    df = loader.load(str(parquet_file), symbol="DAX", timeframe=Timeframe.M1)
    assert len(df) == 100
    assert "symbol" in df.columns
    assert "timeframe" in df.columns


def test_load_and_preprocess(csv_file):
    """Test load_and_preprocess function."""
    from src.data.loaders import load_and_preprocess

    df = load_and_preprocess(
        str(csv_file),
        loader_type="csv",
        validate=True,
        preprocess=True,
    )
    assert len(df) == 100
    assert "timestamp" in df.columns
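
For context, a hypothetical sketch of the loader contract these tests pin down: tag each row with the symbol and the timeframe's string value, and fail loudly on a missing file. The real CSVLoader in src/data/loaders.py may do more, and DataError is only approximated here:

# Hedged sketch, not the actual implementation. Assumes Timeframe.M1.value == "1min",
# which the assertion on df["timeframe"] above implies.
import pandas as pd


class CSVLoader:
    def load(self, path, symbol=None, timeframe=None):
        # Read OHLCV rows and tag them with the requested symbol/timeframe.
        try:
            df = pd.read_csv(path, parse_dates=["timestamp"])
        except FileNotFoundError as exc:
            # Stand-in for the project's DataError referenced in the tests.
            raise RuntimeError(f"cannot load {path}") from exc
        if symbol is not None:
            df["symbol"] = symbol
        if timeframe is not None:
            df["timeframe"] = timeframe.value
        return df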


@@ -0,0 +1,87 @@
"""Tests for data preprocessors."""
import numpy as np
import pandas as pd
import pytest

from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates


@pytest.fixture
def sample_data_with_missing():
    """Create sample DataFrame with missing values."""
    dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
    df = pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0] * 10,
            "high": [100.5] * 10,
            "low": [99.5] * 10,
            "close": [100.2] * 10,
        }
    )
    # Add some missing values
    df.loc[2, "close"] = np.nan
    df.loc[5, "open"] = np.nan
    return df


@pytest.fixture
def sample_data_with_duplicates():
    """Create sample DataFrame with duplicates."""
    dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
    df = pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0] * 10,
            "high": [100.5] * 10,
            "low": [99.5] * 10,
            "close": [100.2] * 10,
        }
    )
    # Add duplicate
    df = pd.concat([df, df.iloc[[0]]], ignore_index=True)
    return df


def test_handle_missing_data_forward_fill(sample_data_with_missing):
    """Test forward fill missing data."""
    df = handle_missing_data(sample_data_with_missing, method="forward_fill")
    assert df["close"].isna().sum() == 0
    assert df["open"].isna().sum() == 0


def test_handle_missing_data_drop(sample_data_with_missing):
    """Test drop missing data."""
    df = handle_missing_data(sample_data_with_missing, method="drop")
    assert df["close"].isna().sum() == 0
    assert df["open"].isna().sum() == 0
    assert len(df) < len(sample_data_with_missing)


def test_remove_duplicates(sample_data_with_duplicates):
    """Test duplicate removal."""
    df = remove_duplicates(sample_data_with_duplicates)
    assert len(df) == 10  # Should remove duplicate


def test_filter_session():
    """Test session filtering."""
    # Create data spanning multiple hours
    dates = pd.date_range("2024-01-01 02:00", periods=120, freq="1min")
    df = pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0] * 120,
            "high": [100.5] * 120,
            "low": [99.5] * 120,
            "close": [100.2] * 120,
        }
    )
    # Filter to 3-4 AM EST
    df_filtered = filter_session(df, session_start="03:00", session_end="04:00")
    # Should have approximately 60 rows (1 hour of 1-minute data)
    assert len(df_filtered) > 0
    assert len(df_filtered) <= 60
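
As a sketch of what the session filter under test might do, assuming "HH:MM" boundaries compared against each bar's time of day; this is consistent with the assertions above but is not the actual src/data/preprocessors.py:

# Hedged sketch only. With the fixture above (02:00-03:59), inclusive boundaries
# keep the 60 bars from 03:00 to 03:59.
import pandas as pd


def filter_session(df, session_start="03:00", session_end="04:00"):
    # Keep only rows whose timestamp falls inside the session window.
    times = pd.to_datetime(df["timestamp"]).dt.time
    start = pd.to_datetime(session_start).time()
    end = pd.to_datetime(session_end).time()
    return df.loc[(times >= start) & (times <= end)].reset_index(drop=True)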


@@ -0,0 +1,97 @@
"""Tests for data validators."""
import pandas as pd
import pytest

from src.core.enums import Timeframe
from src.data.validators import check_continuity, detect_outliers, validate_ohlcv


@pytest.fixture
def valid_ohlcv_data():
    """Create valid OHLCV DataFrame."""
    dates = pd.date_range("2024-01-01 03:00", periods=100, freq="1min")
    return pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0 + i * 0.1 for i in range(100)],
            "high": [100.5 + i * 0.1 for i in range(100)],
            "low": [99.5 + i * 0.1 for i in range(100)],
            "close": [100.2 + i * 0.1 for i in range(100)],
            "volume": [1000] * 100,
        }
    )


@pytest.fixture
def invalid_ohlcv_data():
    """Create invalid OHLCV DataFrame."""
    dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
    df = pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0] * 10,
            "high": [99.0] * 10,  # Invalid: high < low
            "low": [99.5] * 10,
            "close": [100.2] * 10,
        }
    )
    return df


def test_validate_ohlcv_valid(valid_ohlcv_data):
    """Test validation with valid data."""
    df = validate_ohlcv(valid_ohlcv_data)
    assert len(df) == 100


def test_validate_ohlcv_invalid(invalid_ohlcv_data):
    """Test validation with invalid data."""
    with pytest.raises(Exception):  # Should raise ValidationError
        validate_ohlcv(invalid_ohlcv_data)


def test_validate_ohlcv_missing_columns():
    """Test validation with missing columns."""
    df = pd.DataFrame({"timestamp": pd.date_range("2024-01-01", periods=10)})
    with pytest.raises(Exception):  # Should raise ValidationError
        validate_ohlcv(df)


def test_check_continuity(valid_ohlcv_data):
    """Test continuity check."""
    is_continuous, gaps = check_continuity(valid_ohlcv_data, Timeframe.M1)
    assert is_continuous
    assert len(gaps) == 0


def test_check_continuity_with_gaps():
    """Test continuity check with gaps."""
    # Create data with gaps
    dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
    # Remove some dates to create gaps
    dates = dates[[0, 1, 2, 5, 6, 7, 8, 9]]  # Gap between index 2 and 5
    df = pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0] * len(dates),
            "high": [100.5] * len(dates),
            "low": [99.5] * len(dates),
            "close": [100.2] * len(dates),
        }
    )
    is_continuous, gaps = check_continuity(df, Timeframe.M1)
    assert not is_continuous
    assert len(gaps) > 0


def test_detect_outliers(valid_ohlcv_data):
    """Test outlier detection."""
    # Add an outlier
    df = valid_ohlcv_data.copy()
    df.loc[50, "close"] = 200.0  # Extreme value
    outliers = detect_outliers(df, columns=["close"], method="iqr", threshold=3.0)
    assert outliers["is_outlier"].sum() > 0
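
To make the validation rules these tests rely on concrete, a minimal sketch of validate_ohlcv (required columns plus a high-not-below-low check), with ValueError standing in for the project's ValidationError; the real src/data/validators.py likely checks more:

# Hedged sketch only, inferred from the fixtures above; not the real validator.
import pandas as pd

REQUIRED_COLUMNS = ("timestamp", "open", "high", "low", "close")


def validate_ohlcv(df: pd.DataFrame) -> pd.DataFrame:
    # Reject frames with missing columns or bars whose high is below the low.
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")  # stand-in for ValidationError
    bad = df[df["high"] < df["low"]]
    if not bad.empty:
        raise ValueError(f"{len(bad)} rows have high < low")
    return df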