96 lines
2.9 KiB
Python
96 lines
2.9 KiB
Python
"""Tests for data preprocessors."""
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import pytest
|
|
|
|
from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates
|
|
|
|
|
|
@pytest.fixture
def sample_data_with_missing():
    """Return ten 1-minute OHLC bars starting 2024-01-01 03:00 with two NaN cells.

    Every price cell holds a constant value except ``close`` in row 2 and
    ``open`` in row 5, which are NaN. The index is the default 0..9 range.
    """
    n_bars = 10
    frame = pd.DataFrame(
        {
            "timestamp": pd.date_range("2024-01-01 03:00", periods=n_bars, freq="1min"),
            "open": [100.0] * n_bars,
            "high": [100.5] * n_bars,
            "low": [99.5] * n_bars,
            "close": [100.2] * n_bars,
        }
    )
    # Inject the gaps in two different rows and two different columns.
    frame.at[2, "close"] = np.nan
    frame.at[5, "open"] = np.nan
    return frame
|
|
|
|
|
|
@pytest.fixture
def sample_data_with_duplicates():
    """Return ten distinct 1-minute OHLC bars followed by a copy of the first bar.

    The result has 11 rows indexed 0..10; row 10 is an exact repeat of row 0
    (same timestamp and prices).
    """
    n_bars = 10
    base = pd.DataFrame(
        {
            "timestamp": pd.date_range("2024-01-01 03:00", periods=n_bars, freq="1min"),
            "open": [100.0] * n_bars,
            "high": [100.5] * n_bars,
            "low": [99.5] * n_bars,
            "close": [100.2] * n_bars,
        }
    )
    # ignore_index renumbers the appended copy as row 10 rather than a second row 0.
    return pd.concat([base, base.head(1)], ignore_index=True)
|
|
|
|
|
|
def test_handle_missing_data_forward_fill(sample_data_with_missing):
    """Forward fill must remove every NaN while keeping all rows.

    The fixture has a NaN ``close`` at position 2 and a NaN ``open`` at
    position 5; after filling, each gap should carry the value of the row
    just before it.
    """
    # Capture the size up front in case the implementation mutates its input.
    n_before = len(sample_data_with_missing)

    df = handle_missing_data(sample_data_with_missing, method="forward_fill")

    assert df["close"].isna().sum() == 0
    assert df["open"].isna().sum() == 0
    # Filling (unlike dropping) must not change the number of rows.
    assert len(df) == n_before
    # The gaps take the preceding row's value. NOTE: the fixture prices are
    # constant, so these checks cannot tell forward fill from backward fill;
    # they do catch fills with a wrong value (e.g. 0) and lost rows.
    assert df["close"].iloc[2] == df["close"].iloc[1] == 100.2
    assert df["open"].iloc[5] == df["open"].iloc[4] == 100.0
|
|
|
|
|
|
def test_handle_missing_data_drop(sample_data_with_missing):
    """Dropping must remove exactly the rows that contain a NaN.

    The fixture has NaNs in two different rows (``close`` at row 2 and
    ``open`` at row 5), so a correct drop leaves ``n - 2`` rows.
    """
    # Record the length before the call: if the implementation drops in place,
    # re-reading len(sample_data_with_missing) afterwards would be meaningless.
    n_before = len(sample_data_with_missing)

    df = handle_missing_data(sample_data_with_missing, method="drop")

    assert df["close"].isna().sum() == 0
    assert df["open"].isna().sum() == 0
    # Exactly the two NaN-bearing rows go; dropping more would be a bug too.
    assert len(df) == n_before - 2
|
|
|
|
|
|
def test_remove_duplicates(sample_data_with_duplicates):
    """Duplicate removal must drop only the repeated bar.

    The fixture is ten distinct bars plus one exact copy of the first bar.
    """
    n_before = len(sample_data_with_duplicates)

    df = remove_duplicates(sample_data_with_duplicates)

    # Exactly one row (the appended copy) should be removed ...
    assert len(df) == n_before - 1 == 10
    # ... and it must be the duplicate, not some other bar: every remaining
    # timestamp is distinct.
    assert df["timestamp"].is_unique
|
|
|
|
|
|
def test_filter_session():
    """Session filtering keeps only the bars inside the requested local window.

    The input is 120 one-minute bars from 02:00 to 03:59 America/New_York.
    Filtering to the 03:00-04:00 session should keep the 60 bars from 03:00
    through 03:59 (input positions 60-119).
    """
    n_bars = 120
    # Let pandas attach the time zone directly instead of pytz.localize: this
    # avoids a pytz dependency, which is optional as of pandas 3.0.
    # 2024-01-01 is in standard time (EST, UTC-5).
    dates = pd.date_range(
        start="2024-01-01 02:00:00", periods=n_bars, freq="1min", tz="America/New_York"
    )

    df = pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0] * n_bars,
            "high": [100.5] * n_bars,
            "low": [99.5] * n_bars,
            "close": [100.2] * n_bars,
        }
    )

    df_filtered = filter_session(
        df, session_start="03:00", session_end="04:00", timezone="America/New_York"
    )

    # The last input bar is 03:59, so there is no 04:00 bar: the expected count
    # is 60 whether or not session_end is treated as inclusive.
    assert len(df_filtered) == 60, (
        f"Expected 60 rows in the 03:00-04:00 session, got {len(df_filtered)}"
    )
|