Files
dax-ml/tests/unit/test_data/test_loaders.py
2026-01-05 11:34:18 +02:00

84 lines
2.2 KiB
Python

"""Tests for data loaders."""
import pandas as pd
import pytest
from src.core.enums import Timeframe
from src.data.loaders import CSVLoader, ParquetLoader
@pytest.fixture
def sample_ohlcv_data():
    """Build a 100-row OHLCV DataFrame of one-minute bars.

    Timestamps start at 2024-01-01 03:00. Each price column climbs by 0.1 per
    row from its own base value, and volume is a constant 1000.
    """
    n_rows = 100
    # Per-row price offset shared by the four price columns.
    offsets = [i * 0.1 for i in range(n_rows)]
    columns = {
        "timestamp": pd.date_range("2024-01-01 03:00", periods=n_rows, freq="1min"),
        "open": [100.0 + step for step in offsets],
        "high": [100.5 + step for step in offsets],
        "low": [99.5 + step for step in offsets],
        "close": [100.2 + step for step in offsets],
        "volume": [1000] * n_rows,
    }
    return pd.DataFrame(columns)
@pytest.fixture
def csv_file(sample_ohlcv_data, tmp_path):
    """Write the sample OHLCV frame to a CSV under ``tmp_path`` and return its path."""
    target = tmp_path.joinpath("test_data.csv")
    # index=False keeps the row index out of the file.
    sample_ohlcv_data.to_csv(target, index=False)
    return target
@pytest.fixture
def parquet_file(sample_ohlcv_data, tmp_path):
    """Write the sample OHLCV frame to a Parquet file under ``tmp_path`` and return its path."""
    target = tmp_path.joinpath("test_data.parquet")
    # index=False keeps the row index out of the file.
    sample_ohlcv_data.to_parquet(target, index=False)
    return target
def test_csv_loader(csv_file):
    """Load the sample CSV and check row count plus the symbol/timeframe columns."""
    frame = CSVLoader().load(
        str(csv_file),
        symbol="DAX",
        timeframe=Timeframe.M1,
    )
    # Columns the test expects, mapped to the value expected in the first row.
    expected_first_row = {"symbol": "DAX", "timeframe": "1min"}

    assert len(frame) == 100
    for column in expected_first_row:
        assert column in frame.columns
    for column, value in expected_first_row.items():
        assert frame[column].iloc[0] == value
def test_csv_loader_missing_file(tmp_path):
    """Check that ``CSVLoader.load`` raises when the target file does not exist.

    The path is built under pytest's ``tmp_path`` so that it is guaranteed to
    be absent no matter which directory the test run was started from.
    """
    missing_path = tmp_path / "nonexistent.csv"
    # Guard the precondition so a failure here points at the setup, not the loader.
    assert not missing_path.exists()

    loader = CSVLoader()
    # NOTE(review): the original comment says DataError is expected, but that
    # class is not imported in this module, so the broad Exception is kept.
    # TODO: narrow to DataError once its import path is confirmed.
    with pytest.raises(Exception):  # Should raise DataError
        loader.load(str(missing_path))
def test_parquet_loader(parquet_file):
    """Load the sample Parquet file and check row count plus the added columns."""
    frame = ParquetLoader().load(
        str(parquet_file),
        symbol="DAX",
        timeframe=Timeframe.M1,
    )

    assert len(frame) == 100
    for column in ("symbol", "timeframe"):
        assert column in frame.columns
def test_load_and_preprocess(csv_file):
    """Run ``load_and_preprocess`` on the sample CSV with validation and preprocessing on."""
    from src.data.loaders import load_and_preprocess

    options = {
        "loader_type": "csv",
        "validate": True,
        "preprocess": True,
    }
    result = load_and_preprocess(str(csv_file), **options)

    # All 100 rows are expected back, still carrying the timestamp column.
    assert len(result) == 100
    assert "timestamp" in result.columns