Files
dax-ml/tests/unit/test_data/test_preprocessors.py
2026-01-05 11:34:18 +02:00

88 lines
2.5 KiB
Python

"""Tests for data preprocessors."""
import numpy as np
import pandas as pd
import pytest
from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates
@pytest.fixture
def sample_data_with_missing():
    """Provide a 10-row, 1-minute OHLC frame containing two NaN cells.

    Timestamps start at 2024-01-01 03:00. Row 2 has a missing ``close`` and
    row 5 a missing ``open``; every other price cell holds a constant value.
    """
    n_rows = 10
    columns = {
        "timestamp": pd.date_range("2024-01-01 03:00", periods=n_rows, freq="1min"),
        "open": [100.0] * n_rows,
        "high": [100.5] * n_rows,
        "low": [99.5] * n_rows,
        "close": [100.2] * n_rows,
    }
    frame = pd.DataFrame(columns)
    # Blank out one cell in each of two different columns.
    for row, column in ((2, "close"), (5, "open")):
        frame.loc[row, column] = np.nan
    return frame
@pytest.fixture
def sample_data_with_duplicates():
    """Provide an 11-row OHLC frame whose last row repeats the first.

    The first 10 rows are 1-minute bars starting at 2024-01-01 03:00 with
    constant prices; a copy of row 0 is appended and the index is renumbered.
    """
    n_rows = 10
    unique_rows = pd.DataFrame(
        {
            "timestamp": pd.date_range(
                "2024-01-01 03:00", periods=n_rows, freq="1min"
            ),
            "open": [100.0] * n_rows,
            "high": [100.5] * n_rows,
            "low": [99.5] * n_rows,
            "close": [100.2] * n_rows,
        }
    )
    # A one-row frame (list indexer keeps it 2-D) holding a copy of row 0.
    repeated_first = unique_rows.iloc[[0]]
    return pd.concat([unique_rows, repeated_first], ignore_index=True)
def test_handle_missing_data_forward_fill(sample_data_with_missing):
    """Forward-fill mode must leave no NaN in the ``close`` or ``open`` columns."""
    filled = handle_missing_data(sample_data_with_missing, method="forward_fill")
    # These are the two columns the fixture put NaN values into.
    for column in ("close", "open"):
        assert filled[column].isna().sum() == 0
def test_handle_missing_data_drop(sample_data_with_missing):
    """Drop mode must leave no NaN in ``close``/``open`` and shrink the frame."""
    remaining = handle_missing_data(sample_data_with_missing, method="drop")
    for column in ("close", "open"):
        assert remaining[column].isna().sum() == 0
    # The result must have fewer rows than the fixture frame it was given.
    assert len(remaining) < len(sample_data_with_missing)
def test_remove_duplicates(sample_data_with_duplicates):
    """De-duplicating the 11-row fixture must leave exactly 10 rows."""
    deduplicated = remove_duplicates(sample_data_with_duplicates)
    # The fixture appends one copy of its first row; only that copy should go.
    expected_rows = 10
    assert len(deduplicated) == expected_rows
def test_filter_session():
    """Filtering a two-hour frame to 03:00-04:00 keeps between 1 and 60 rows."""
    # 120 one-minute bars: 02:00 through 03:59 on 2024-01-01.
    n_minutes = 120
    bars = pd.DataFrame(
        {
            "timestamp": pd.date_range(
                "2024-01-01 02:00", periods=n_minutes, freq="1min"
            ),
            "open": [100.0] * n_minutes,
            "high": [100.5] * n_minutes,
            "low": [99.5] * n_minutes,
            "close": [100.2] * n_minutes,
        }
    )

    # NOTE(review): the original comment described this window as "3-4 AM EST",
    # but the timestamps built here are timezone-naive -- confirm how
    # filter_session interprets them.
    in_session = filter_session(bars, session_start="03:00", session_end="04:00")

    # One hour of 1-minute data is at most 60 rows; the result must be non-empty.
    kept = len(in_session)
    assert kept > 0
    assert kept <= 60