"""Tests for data preprocessors."""

import numpy as np
import pandas as pd
import pytest

from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates


def _make_ohlc(start, periods, tz=None):
    """Build a flat OHLC DataFrame with one row per minute.

    Args:
        start: First timestamp, as a string accepted by ``pd.date_range``.
        periods: Number of 1-minute rows to generate.
        tz: Optional timezone name. When given, the ``timestamp`` column is
            tz-aware (localized wall-clock times); otherwise it is naive.

    Returns:
        DataFrame with columns ``timestamp``, ``open``, ``high``, ``low`` and
        ``close``, where every price column holds a constant value.
    """
    dates = pd.date_range(start, periods=periods, freq="1min", tz=tz)
    return pd.DataFrame(
        {
            "timestamp": dates,
            "open": [100.0] * periods,
            "high": [100.5] * periods,
            "low": [99.5] * periods,
            "close": [100.2] * periods,
        }
    )


@pytest.fixture
def sample_data_with_missing():
    """Create a 10-row DataFrame with NaNs at row 2 (close) and row 5 (open)."""
    df = _make_ohlc("2024-01-01 03:00", periods=10)
    # The NaNs sit in different rows and different columns, and neither is in
    # the first row, so forward fill can repair both and drop must remove both.
    df.loc[2, "close"] = np.nan
    df.loc[5, "open"] = np.nan
    return df


@pytest.fixture
def sample_data_with_duplicates():
    """Create an 11-row DataFrame whose last row is an exact copy of row 0."""
    df = _make_ohlc("2024-01-01 03:00", periods=10)
    return pd.concat([df, df.iloc[[0]]], ignore_index=True)


def test_handle_missing_data_forward_fill(sample_data_with_missing):
    """Forward fill removes every NaN, keeps all rows and copies prior values."""
    df = handle_missing_data(sample_data_with_missing, method="forward_fill")
    assert df["close"].isna().sum() == 0
    assert df["open"].isna().sum() == 0
    # Every NaN has a valid value above it, so no row should be lost and each
    # gap should take the preceding row's value.
    assert len(df) == len(sample_data_with_missing)
    assert df["close"].iloc[2] == 100.2
    assert df["open"].iloc[5] == 100.0


def test_handle_missing_data_drop(sample_data_with_missing):
    """Drop removes exactly the rows that contain a NaN."""
    df = handle_missing_data(sample_data_with_missing, method="drop")
    assert df["close"].isna().sum() == 0
    assert df["open"].isna().sum() == 0
    # Rows 2 and 5 each hold one NaN; the other 8 rows are complete. Comparing
    # against a literal also stays correct if the input was mutated in place.
    assert len(df) == 8


def test_remove_duplicates(sample_data_with_duplicates):
    """Duplicate removal drops the single appended copy."""
    df = remove_duplicates(sample_data_with_duplicates)
    assert len(df) == 10
    assert df["timestamp"].is_unique


def test_filter_session():
    """Session filtering keeps only rows inside the 03:00-04:00 New York window."""
    tz = "America/New_York"
    # Two hours of tz-aware 1-minute bars: 02:00 through 03:59 New York time
    # (January, so EST). pandas localizes directly via ``tz=``.
    df = _make_ohlc("2024-01-01 02:00", periods=120, tz=tz)

    # Only the second hour (03:00-03:59, 60 rows) lies inside the window. The
    # data ends at 03:59, so 04:00 never appears and at most 60 rows can match,
    # whether or not the session end is inclusive.
    df_filtered = filter_session(
        df, session_start="03:00", session_end="04:00", timezone=tz
    )

    assert len(df_filtered) > 0, f"Expected filtered data but got {len(df_filtered)} rows"
    assert len(df_filtered) <= 60
    # A count check alone would accept the wrong hour (02:xx is also 60 rows),
    # so also verify that nothing before the session start survived.
    session_start = pd.Timestamp("2024-01-01 03:00", tz=tz)
    assert (df_filtered["timestamp"] >= session_start).all()