feat(v0.2.0): data pipeline

tests/unit/test_data/__init__.py (new file)
@@ -0,0 +1 @@
"""Unit tests for data module."""

tests/unit/test_data/test_database.py (new file)
@@ -0,0 +1,65 @@
"""Tests for database connection and session management."""

import os
import tempfile

import pytest
from sqlalchemy import text

from src.data.database import get_db_session, get_engine, init_database


@pytest.fixture
def temp_db():
    """Create temporary database for testing."""
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name

    # Point the application at the temporary database
    os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"

    yield db_path

    # Cleanup
    if os.path.exists(db_path):
        os.unlink(db_path)
    if "DATABASE_URL" in os.environ:
        del os.environ["DATABASE_URL"]


def test_get_engine(temp_db):
    """Test engine creation."""
    engine = get_engine()
    assert engine is not None
    assert str(engine.url).startswith("sqlite")


def test_init_database(temp_db):
    """Test database initialization."""
    init_database(create_tables=True)
    assert os.path.exists(temp_db)


def test_get_db_session(temp_db):
    """Test database session context manager."""
    init_database(create_tables=True)

    with get_db_session() as session:
        assert session is not None
        # Raw SQL must be wrapped in text() on SQLAlchemy 1.4+/2.x
        result = session.execute(text("SELECT 1")).scalar()
        assert result == 1


def test_session_rollback_on_error(temp_db):
    """Test that the session rolls back and re-raises on error."""
    init_database(create_tables=True)

    # Querying a missing table should raise; the context manager is
    # expected to roll back the session before propagating the error.
    with pytest.raises(Exception):
        with get_db_session() as session:
            session.execute(text("SELECT * FROM nonexistent_table"))
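
The suite above pins down the contract of src.data.database without showing it. Below is a minimal sketch of an implementation that would satisfy these tests, assuming SQLAlchemy 2.x; the three function names come from the test imports, while the DATABASE_URL default and the Base declarative registry are illustrative assumptions, not the project's actual code:

# Illustrative sketch of src/data/database.py -- not the actual module.
import os
from contextlib import contextmanager

from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()  # models would register against this (assumed)


def get_engine():
    """Build an engine from the DATABASE_URL environment variable."""
    # Default URL is a placeholder; the tests above always set DATABASE_URL.
    return create_engine(os.environ.get("DATABASE_URL", "sqlite:///app.db"))


def init_database(create_tables=False):
    """Initialise the database, optionally creating all mapped tables."""
    engine = get_engine()
    if create_tables:
        Base.metadata.create_all(engine)


@contextmanager
def get_db_session():
    """Yield a session; commit on success, roll back and re-raise on error."""
    session = sessionmaker(bind=get_engine())()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()  # the behaviour test_session_rollback_on_error relies on
        raise
    finally:
        session.close()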

tests/unit/test_data/test_loaders.py (new file)
@@ -0,0 +1,83 @@
"""Tests for data loaders."""

import pandas as pd
import pytest

from src.core.enums import Timeframe
from src.data.loaders import CSVLoader, ParquetLoader


@pytest.fixture
def sample_ohlcv_data():
    """Create sample OHLCV DataFrame."""
    dates = pd.date_range("2024-01-01 03:00", periods=100, freq="1min")
    return pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0 + i * 0.1 for i in range(100)],
            "high": [100.5 + i * 0.1 for i in range(100)],
            "low": [99.5 + i * 0.1 for i in range(100)],
            "close": [100.2 + i * 0.1 for i in range(100)],
            "volume": [1000] * 100,
        }
    )


@pytest.fixture
def csv_file(sample_ohlcv_data, tmp_path):
    """Create temporary CSV file."""
    csv_path = tmp_path / "test_data.csv"
    sample_ohlcv_data.to_csv(csv_path, index=False)
    return csv_path


@pytest.fixture
def parquet_file(sample_ohlcv_data, tmp_path):
    """Create temporary Parquet file."""
    parquet_path = tmp_path / "test_data.parquet"
    sample_ohlcv_data.to_parquet(parquet_path, index=False)
    return parquet_path


def test_csv_loader(csv_file):
    """Test CSV loader."""
    loader = CSVLoader()
    df = loader.load(str(csv_file), symbol="DAX", timeframe=Timeframe.M1)

    assert len(df) == 100
    assert "symbol" in df.columns
    assert "timeframe" in df.columns
    assert df["symbol"].iloc[0] == "DAX"
    assert df["timeframe"].iloc[0] == "1min"


def test_csv_loader_missing_file():
    """Test CSV loader with missing file."""
    loader = CSVLoader()
    with pytest.raises(Exception):  # should raise DataError
        loader.load("nonexistent.csv")


def test_parquet_loader(parquet_file):
    """Test Parquet loader."""
    loader = ParquetLoader()
    df = loader.load(str(parquet_file), symbol="DAX", timeframe=Timeframe.M1)

    assert len(df) == 100
    assert "symbol" in df.columns
    assert "timeframe" in df.columns


def test_load_and_preprocess(csv_file):
    """Test load_and_preprocess function."""
    from src.data.loaders import load_and_preprocess

    df = load_and_preprocess(
        str(csv_file),
        loader_type="csv",
        validate=True,
        preprocess=True,
    )

    assert len(df) == 100
    assert "timestamp" in df.columns
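
For reference, a hedged sketch of the loader contract these tests exercise: loaders annotate the frame with symbol and timeframe columns, and Timeframe.M1.value is assumed to be the string "1min" (the value test_csv_loader compares against). Error types and real signatures may differ from the project's:

# Illustrative sketch of the src/data/loaders.py contract -- not the real module.
from pathlib import Path

import pandas as pd


class CSVLoader:
    def load(self, path, symbol=None, timeframe=None):
        """Load OHLCV rows from CSV and tag them with symbol/timeframe."""
        if not Path(path).exists():
            raise FileNotFoundError(path)  # the real loader raises DataError
        df = pd.read_csv(path, parse_dates=["timestamp"])
        if symbol is not None:
            df["symbol"] = symbol
        if timeframe is not None:
            df["timeframe"] = timeframe.value  # assumed: Timeframe.M1.value == "1min"
        return df


class ParquetLoader:
    def load(self, path, symbol=None, timeframe=None):
        """Same contract as CSVLoader, reading Parquet instead."""
        df = pd.read_parquet(path)
        if symbol is not None:
            df["symbol"] = symbol
        if timeframe is not None:
            df["timeframe"] = timeframe.value
        return df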

tests/unit/test_data/test_preprocessors.py (new file)
@@ -0,0 +1,87 @@
"""Tests for data preprocessors."""

import numpy as np
import pandas as pd
import pytest

from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates


@pytest.fixture
def sample_data_with_missing():
    """Create sample DataFrame with missing values."""
    dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
    df = pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0] * 10,
            "high": [100.5] * 10,
            "low": [99.5] * 10,
            "close": [100.2] * 10,
        }
    )
    # Add some missing values
    df.loc[2, "close"] = np.nan
    df.loc[5, "open"] = np.nan
    return df


@pytest.fixture
def sample_data_with_duplicates():
    """Create sample DataFrame with a duplicated row."""
    dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
    df = pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0] * 10,
            "high": [100.5] * 10,
            "low": [99.5] * 10,
            "close": [100.2] * 10,
        }
    )
    # Duplicate the first row
    df = pd.concat([df, df.iloc[[0]]], ignore_index=True)
    return df


def test_handle_missing_data_forward_fill(sample_data_with_missing):
    """Test forward fill of missing data."""
    df = handle_missing_data(sample_data_with_missing, method="forward_fill")
    assert df["close"].isna().sum() == 0
    assert df["open"].isna().sum() == 0


def test_handle_missing_data_drop(sample_data_with_missing):
    """Test dropping rows with missing data."""
    df = handle_missing_data(sample_data_with_missing, method="drop")
    assert df["close"].isna().sum() == 0
    assert df["open"].isna().sum() == 0
    assert len(df) < len(sample_data_with_missing)


def test_remove_duplicates(sample_data_with_duplicates):
    """Test duplicate removal."""
    df = remove_duplicates(sample_data_with_duplicates)
    assert len(df) == 10  # the duplicated row should be dropped


def test_filter_session():
    """Test session filtering."""
    # Create data spanning multiple hours
    dates = pd.date_range("2024-01-01 02:00", periods=120, freq="1min")
    df = pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0] * 120,
            "high": [100.5] * 120,
            "low": [99.5] * 120,
            "close": [100.2] * 120,
        }
    )

    # Filter to the 3-4 AM EST session
    df_filtered = filter_session(df, session_start="03:00", session_end="04:00")

    # At most 60 rows survive (one hour of 1-minute bars); the exact count
    # depends on whether the session boundaries are inclusive.
    assert len(df_filtered) > 0
    assert len(df_filtered) <= 60
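
A minimal sketch of preprocessor implementations consistent with these assertions. The function names and string arguments come from the tests; the half-open session interval and the backfill of a leading NaN are assumptions:

# Illustrative sketch of src/data/preprocessors.py -- assumptions noted inline.
from datetime import datetime

import pandas as pd


def handle_missing_data(df, method="forward_fill"):
    """Fill or drop missing values according to `method`."""
    if method == "forward_fill":
        return df.ffill().bfill()  # bfill covers a NaN in the first row (assumed)
    if method == "drop":
        return df.dropna()
    raise ValueError(f"unknown method: {method}")


def remove_duplicates(df):
    """Drop rows with a duplicated timestamp, keeping the first occurrence."""
    return df.drop_duplicates(subset=["timestamp"], keep="first").reset_index(drop=True)


def filter_session(df, session_start="03:00", session_end="04:00"):
    """Keep rows whose time of day falls in [session_start, session_end)."""
    times = df["timestamp"].dt.time
    start = datetime.strptime(session_start, "%H:%M").time()
    end = datetime.strptime(session_end, "%H:%M").time()
    return df[(times >= start) & (times < end)].reset_index(drop=True)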

tests/unit/test_data/test_validators.py (new file)
@@ -0,0 +1,97 @@
"""Tests for data validators."""

import pandas as pd
import pytest

from src.core.enums import Timeframe
from src.data.validators import check_continuity, detect_outliers, validate_ohlcv


@pytest.fixture
def valid_ohlcv_data():
    """Create valid OHLCV DataFrame."""
    dates = pd.date_range("2024-01-01 03:00", periods=100, freq="1min")
    return pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0 + i * 0.1 for i in range(100)],
            "high": [100.5 + i * 0.1 for i in range(100)],
            "low": [99.5 + i * 0.1 for i in range(100)],
            "close": [100.2 + i * 0.1 for i in range(100)],
            "volume": [1000] * 100,
        }
    )


@pytest.fixture
def invalid_ohlcv_data():
    """Create invalid OHLCV DataFrame."""
    dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
    df = pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0] * 10,
            "high": [99.0] * 10,  # invalid: high < low
            "low": [99.5] * 10,
            "close": [100.2] * 10,
        }
    )
    return df


def test_validate_ohlcv_valid(valid_ohlcv_data):
    """Test validation with valid data."""
    df = validate_ohlcv(valid_ohlcv_data)
    assert len(df) == 100


def test_validate_ohlcv_invalid(invalid_ohlcv_data):
    """Test validation with invalid data."""
    with pytest.raises(Exception):  # should raise ValidationError
        validate_ohlcv(invalid_ohlcv_data)


def test_validate_ohlcv_missing_columns():
    """Test validation with missing columns."""
    df = pd.DataFrame({"timestamp": pd.date_range("2024-01-01", periods=10)})
    with pytest.raises(Exception):  # should raise ValidationError
        validate_ohlcv(df)


def test_check_continuity(valid_ohlcv_data):
    """Test continuity check on gap-free data."""
    is_continuous, gaps = check_continuity(valid_ohlcv_data, Timeframe.M1)
    assert is_continuous
    assert len(gaps) == 0


def test_check_continuity_with_gaps():
    """Test continuity check with gaps."""
    dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
    # Drop some timestamps to create a gap between index 2 and 5
    dates = dates[[0, 1, 2, 5, 6, 7, 8, 9]]

    df = pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0] * len(dates),
            "high": [100.5] * len(dates),
            "low": [99.5] * len(dates),
            "close": [100.2] * len(dates),
        }
    )

    is_continuous, gaps = check_continuity(df, Timeframe.M1)
    assert not is_continuous
    assert len(gaps) > 0


def test_detect_outliers(valid_ohlcv_data):
    """Test outlier detection."""
    # Inject an extreme close far outside the 100-110 range
    df = valid_ohlcv_data.copy()
    df.loc[50, "close"] = 200.0

    outliers = detect_outliers(df, columns=["close"], method="iqr", threshold=3.0)
    assert outliers["is_outlier"].sum() > 0
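
Finally, a sketch of validator implementations that would pass this suite. The IQR fence logic and the shape of the gaps return value are assumptions, and ValueError stands in for the project's ValidationError:

# Illustrative sketch of src/data/validators.py -- error types and return
# shapes are assumptions, not the project's actual API.
import pandas as pd

REQUIRED_COLUMNS = ["timestamp", "open", "high", "low", "close"]


def validate_ohlcv(df):
    """Return df unchanged if structurally valid, else raise."""
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    bad = df["high"] < df["low"]
    if bad.any():
        raise ValueError(f"{int(bad.sum())} rows have high < low")
    return df


def check_continuity(df, timeframe):
    """Return (is_continuous, gaps) for the expected bar spacing."""
    expected = pd.Timedelta(timeframe.value)  # assumed: M1.value == "1min"
    deltas = df["timestamp"].diff().dropna()
    gaps = deltas[deltas > expected]
    return gaps.empty, list(gaps.index)


def detect_outliers(df, columns, method="iqr", threshold=3.0):
    """Flag values beyond threshold * IQR outside the quartiles."""
    out = df.copy()
    out["is_outlier"] = False
    for col in columns:
        q1, q3 = df[col].quantile([0.25, 0.75])
        iqr = q3 - q1
        mask = (df[col] < q1 - threshold * iqr) | (df[col] > q3 + threshold * iqr)
        out.loc[mask, "is_outlier"] = True
    return out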