feat(v0.2.0): data pipeline

Author: 0x_n3m0_
Date: 2026-01-05 11:34:18 +02:00
commit b5e7043df6 (parent 2527938680)
23 changed files with 2813 additions and 7 deletions

data/ict_trading.db (new binary file, content not shown)

@@ -17,7 +17,7 @@ colorlog>=6.7.0  # Optional, for colored console output
 # Data processing
 pyarrow>=12.0.0  # For Parquet support
+pytz>=2023.3     # Timezone support
 # Utilities
 click>=8.1.0     # CLI framework

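The new pytz pin backs the session filtering introduced later in this commit (filter_session in src/data/preprocessors.py). A minimal, self-contained sketch of the convention used there, assuming naive timestamps are UTC:

import pandas as pd
import pytz

# Hypothetical sample bars; naive timestamps are treated as UTC, mirroring
# the assumption filter_session() makes before converting to EST.
ts = pd.Series(pd.date_range("2024-01-02 08:00", periods=3, freq="1min"))
est = pytz.timezone("America/New_York")
ts_est = ts.dt.tz_localize("UTC").dt.tz_convert(est)
print(ts_est.dt.time.tolist())  # 03:00-03:02 local time, inside the 3-4 AM session
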
scripts/download_data.py (new executable file, 183 lines)

@@ -0,0 +1,183 @@
#!/usr/bin/env python3
"""Download DAX OHLCV data from external sources."""
import argparse
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.core.enums import Timeframe # noqa: E402
from src.logging import get_logger # noqa: E402
logger = get_logger(__name__)
def download_from_csv(
input_file: str,
symbol: str,
timeframe: Timeframe,
output_dir: Path,
) -> None:
"""
Copy/convert CSV file to standard format.
Args:
input_file: Path to input CSV file
symbol: Trading symbol
timeframe: Timeframe enum
output_dir: Output directory
"""
from src.data.loaders import CSVLoader
loader = CSVLoader()
df = loader.load(input_file, symbol=symbol, timeframe=timeframe)
# Ensure output directory exists
output_dir.mkdir(parents=True, exist_ok=True)
# Save as CSV
output_file = output_dir / f"{symbol}_{timeframe.value}.csv"
df.to_csv(output_file, index=False)
logger.info(f"Saved {len(df)} rows to {output_file}")
# Also save as Parquet for faster loading
output_parquet = output_dir / f"{symbol}_{timeframe.value}.parquet"
df.to_parquet(output_parquet, index=False)
logger.info(f"Saved {len(df)} rows to {output_parquet}")
def download_from_api(
symbol: str,
timeframe: Timeframe,
start_date: str,
end_date: str,
output_dir: Path,
api_provider: str = "manual",
) -> None:
"""
Download data from API (placeholder for future implementation).
Args:
symbol: Trading symbol
timeframe: Timeframe enum
start_date: Start date (YYYY-MM-DD)
end_date: End date (YYYY-MM-DD)
output_dir: Output directory
api_provider: API provider name
"""
logger.warning(
"API download not yet implemented. " "Please provide CSV file using --input-file option."
)
logger.info(
f"Would download {symbol} {timeframe.value} data " f"from {start_date} to {end_date}"
)
def main():
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Download DAX OHLCV data",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Download from CSV file
python scripts/download_data.py --input-file data.csv \\
--symbol DAX --timeframe 1min \\
--output data/raw/ohlcv/1min/
# Download from API (when implemented)
python scripts/download_data.py --symbol DAX --timeframe 5min \\
--start 2024-01-01 --end 2024-01-31 \\
--output data/raw/ohlcv/5min/
""",
)
# Input options
input_group = parser.add_mutually_exclusive_group(required=True)
input_group.add_argument(
"--input-file",
type=str,
help="Path to input CSV file",
)
input_group.add_argument(
"--api",
action="store_true",
help="Download from API (not yet implemented)",
)
# Required arguments
parser.add_argument(
"--symbol",
type=str,
default="DAX",
help="Trading symbol (default: DAX)",
)
parser.add_argument(
"--timeframe",
type=str,
choices=["1min", "5min", "15min"],
required=True,
help="Timeframe",
)
parser.add_argument(
"--output",
type=str,
required=True,
help="Output directory",
)
# Optional arguments for API download
parser.add_argument(
"--start",
type=str,
help="Start date (YYYY-MM-DD) for API download",
)
parser.add_argument(
"--end",
type=str,
help="End date (YYYY-MM-DD) for API download",
)
args = parser.parse_args()
try:
# Convert timeframe string to enum
timeframe_map = {
"1min": Timeframe.M1,
"5min": Timeframe.M5,
"15min": Timeframe.M15,
}
timeframe = timeframe_map[args.timeframe]
# Create output directory
output_dir = Path(args.output)
output_dir.mkdir(parents=True, exist_ok=True)
# Download data
if args.input_file:
logger.info(f"Downloading from CSV: {args.input_file}")
download_from_csv(args.input_file, args.symbol, timeframe, output_dir)
elif args.api:
if not args.start or not args.end:
parser.error("--start and --end are required for API download")
download_from_api(
args.symbol,
timeframe,
args.start,
args.end,
output_dir,
)
logger.info("Data download completed successfully")
return 0
except Exception as e:
logger.error(f"Data download failed: {e}", exc_info=True)
return 1
if __name__ == "__main__":
sys.exit(main())
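
The script writes both CSV and Parquet so later steps can reload quickly. A hedged sketch of reading the Parquet output back; the directory is hypothetical and follows the --output layout shown in the examples above:

from pathlib import Path

import pandas as pd

# Hypothetical output directory; download_data.py writes
# <SYMBOL>_<timeframe>.parquet next to the CSV in the chosen --output directory.
out_dir = Path("data/raw/ohlcv/1min")
for parquet_file in sorted(out_dir.glob("*.parquet")):
    df = pd.read_parquet(parquet_file)
    print(f"{parquet_file.name}: {len(df)} rows, columns={list(df.columns)}")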

scripts/process_data.py (new executable file, 269 lines)

@@ -0,0 +1,269 @@
#!/usr/bin/env python3
"""Batch process OHLCV data: clean, filter, and save."""
import argparse
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.core.enums import Timeframe # noqa: E402
from src.data.database import get_db_session # noqa: E402
from src.data.loaders import load_and_preprocess # noqa: E402
from src.data.models import OHLCVData # noqa: E402
from src.data.repositories import OHLCVRepository # noqa: E402
from src.logging import get_logger # noqa: E402
logger = get_logger(__name__)
def process_file(
input_file: Path,
symbol: str,
timeframe: Timeframe,
output_dir: Path,
save_to_db: bool = False,
filter_session_hours: bool = True,
) -> None:
"""
Process a single data file.
Args:
input_file: Path to input file
symbol: Trading symbol
timeframe: Timeframe enum
output_dir: Output directory
save_to_db: Whether to save to database
filter_session_hours: Whether to filter to trading session (3-4 AM EST)
"""
logger.info(f"Processing file: {input_file}")
# Load and preprocess
df = load_and_preprocess(
str(input_file),
loader_type="auto",
validate=True,
preprocess=True,
filter_to_session=filter_session_hours,
)
# Ensure symbol and timeframe columns
df["symbol"] = symbol
df["timeframe"] = timeframe.value
# Save processed CSV
output_dir.mkdir(parents=True, exist_ok=True)
output_csv = output_dir / f"{symbol}_{timeframe.value}_processed.csv"
df.to_csv(output_csv, index=False)
logger.info(f"Saved processed CSV: {output_csv} ({len(df)} rows)")
# Save processed Parquet
output_parquet = output_dir / f"{symbol}_{timeframe.value}_processed.parquet"
df.to_parquet(output_parquet, index=False)
logger.info(f"Saved processed Parquet: {output_parquet} ({len(df)} rows)")
# Save to database if requested
if save_to_db:
logger.info("Saving to database...")
with get_db_session() as session:
repo = OHLCVRepository(session=session)
# Convert DataFrame to OHLCVData models
records = []
for _, row in df.iterrows():
# Check if record already exists
if repo.exists(symbol, timeframe, row["timestamp"]):
continue
record = OHLCVData(
symbol=symbol,
timeframe=timeframe,
timestamp=row["timestamp"],
open=row["open"],
high=row["high"],
low=row["low"],
close=row["close"],
volume=row.get("volume"),
)
records.append(record)
if records:
repo.create_batch(records)
logger.info(f"Saved {len(records)} records to database")
else:
logger.info("No new records to save (all already exist)")
def process_directory(
input_dir: Path,
output_dir: Path,
symbol: str = "DAX",
save_to_db: bool = False,
filter_session_hours: bool = True,
) -> None:
"""
Process all data files in a directory.
Args:
input_dir: Input directory
output_dir: Output directory
symbol: Trading symbol
save_to_db: Whether to save to database
filter_session_hours: Whether to filter to trading session
"""
# Find all CSV and Parquet files
files = list(input_dir.glob("*.csv")) + list(input_dir.glob("*.parquet"))
if not files:
logger.warning(f"No data files found in {input_dir}")
return
# Detect timeframe from directory name or file
timeframe_map = {
"1min": Timeframe.M1,
"5min": Timeframe.M5,
"15min": Timeframe.M15,
}
timeframe = None
for tf_name, tf_enum in timeframe_map.items():
if tf_name in str(input_dir):
timeframe = tf_enum
break
if timeframe is None:
logger.error(f"Could not determine timeframe from directory: {input_dir}")
return
logger.info(f"Processing {len(files)} files from {input_dir}")
for file_path in files:
try:
process_file(
file_path,
symbol,
timeframe,
output_dir,
save_to_db,
filter_session_hours,
)
except Exception as e:
logger.error(f"Failed to process {file_path}: {e}", exc_info=True)
continue
logger.info("Batch processing completed")
def main():
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Batch process OHLCV data",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Process single file
python scripts/process_data.py --input data/raw/ohlcv/1min/m1.csv \\
--output data/processed/ --symbol DAX --timeframe 1min
# Process directory
python scripts/process_data.py --input data/raw/ohlcv/1min/ \\
--output data/processed/ --symbol DAX
# Process and save to database
python scripts/process_data.py --input data/raw/ohlcv/1min/ \\
--output data/processed/ --save-db
""",
)
parser.add_argument(
"--input",
type=str,
required=True,
help="Input file or directory",
)
parser.add_argument(
"--output",
type=str,
required=True,
help="Output directory",
)
parser.add_argument(
"--symbol",
type=str,
default="DAX",
help="Trading symbol (default: DAX)",
)
parser.add_argument(
"--timeframe",
type=str,
choices=["1min", "5min", "15min"],
help="Timeframe (required if processing single file)",
)
parser.add_argument(
"--save-db",
action="store_true",
help="Save processed data to database",
)
parser.add_argument(
"--no-session-filter",
action="store_true",
help="Don't filter to trading session hours (3-4 AM EST)",
)
args = parser.parse_args()
try:
input_path = Path(args.input)
output_dir = Path(args.output)
if not input_path.exists():
logger.error(f"Input path does not exist: {input_path}")
return 1
# Process single file or directory
if input_path.is_file():
if not args.timeframe:
parser.error("--timeframe is required when processing a single file")
return 1
timeframe_map = {
"1min": Timeframe.M1,
"5min": Timeframe.M5,
"15min": Timeframe.M15,
}
timeframe = timeframe_map[args.timeframe]
process_file(
input_path,
args.symbol,
timeframe,
output_dir,
save_to_db=args.save_db,
filter_session_hours=not args.no_session_filter,
)
elif input_path.is_dir():
process_directory(
input_path,
output_dir,
symbol=args.symbol,
save_to_db=args.save_db,
filter_session_hours=not args.no_session_filter,
)
else:
logger.error(f"Input path is neither file nor directory: {input_path}")
return 1
logger.info("Data processing completed successfully")
return 0
except Exception as e:
logger.error(f"Data processing failed: {e}", exc_info=True)
return 1
if __name__ == "__main__":
sys.exit(main())
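
process_file() can also be driven programmatically with the same arguments the CLI builds. A sketch under the assumption that the scripts directory is importable (nothing in this commit sets that up) and with hypothetical paths:

from pathlib import Path

from src.core.enums import Timeframe
from process_data import process_file  # assumes scripts/ is on sys.path

process_file(
    input_file=Path("data/raw/ohlcv/1min/DAX_1min.csv"),  # hypothetical input
    symbol="DAX",
    timeframe=Timeframe.M1,
    output_dir=Path("data/processed"),
    save_to_db=False,            # skip the database step
    filter_session_hours=True,   # keep only the 3-4 AM EST window
)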

scripts/setup_database.py (new executable file, 47 lines)

@@ -0,0 +1,47 @@
#!/usr/bin/env python3
"""Initialize database and create tables."""
import argparse
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.data.database import init_database # noqa: E402
from src.logging import get_logger # noqa: E402
logger = get_logger(__name__)
def main():
"""Main entry point."""
parser = argparse.ArgumentParser(description="Initialize database and create tables")
parser.add_argument(
"--skip-tables",
action="store_true",
help="Skip table creation (useful for testing connection only)",
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="Enable verbose logging",
)
args = parser.parse_args()
try:
logger.info("Initializing database...")
init_database(create_tables=not args.skip_tables)
logger.info("Database initialization completed successfully")
return 0
except Exception as e:
logger.error(f"Database initialization failed: {e}", exc_info=True)
return 1
if __name__ == "__main__":
sys.exit(main())

scripts/validate_data_pipeline.py (new executable file, 172 lines)

@@ -0,0 +1,172 @@
#!/usr/bin/env python3
"""Validate data pipeline setup for v0.2.0."""
import sys
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.core.enums import Timeframe # noqa: E402
from src.data.database import get_engine, init_database # noqa: E402
from src.data.loaders import CSVLoader, ParquetLoader # noqa: E402
from src.data.preprocessors import ( # noqa: E402
filter_session,
handle_missing_data,
remove_duplicates,
)
from src.data.repositories import OHLCVRepository # noqa: E402
from src.data.validators import check_continuity, detect_outliers, validate_ohlcv # noqa: E402
from src.logging import get_logger # noqa: E402
logger = get_logger(__name__)
def validate_imports():
"""Validate that all data module imports work."""
print("✓ Data module imports successful")
def validate_database():
"""Validate database setup."""
try:
engine = get_engine()
assert engine is not None
print("✓ Database engine created")
# Test initialization (will create tables if needed)
init_database(create_tables=True)
print("✓ Database initialization successful")
except Exception as e:
print(f"✗ Database validation failed: {e}")
raise
def validate_loaders():
"""Validate data loaders."""
try:
csv_loader = CSVLoader()
parquet_loader = ParquetLoader()
assert csv_loader is not None
assert parquet_loader is not None
print("✓ Data loaders initialized")
except Exception as e:
print(f"✗ Loader validation failed: {e}")
raise
def validate_preprocessors():
"""Validate preprocessors."""
import pandas as pd
import pytz # type: ignore[import-untyped]
# Create sample data with EST timezone (trading session is 3-4 AM EST)
est = pytz.timezone("America/New_York")
timestamps = pd.date_range("2024-01-01 03:00", periods=10, freq="1min", tz=est)
df = pd.DataFrame(
{
"timestamp": timestamps,
"open": [100.0] * 10,
"high": [100.5] * 10,
"low": [99.5] * 10,
"close": [100.2] * 10,
}
)
# Test preprocessors
df_processed = handle_missing_data(df)
df_processed = remove_duplicates(df_processed)
df_filtered = filter_session(df_processed)
assert len(df_filtered) > 0
print("✓ Preprocessors working")
def validate_validators():
"""Validate validators."""
import pandas as pd
# Create valid data (timezone not required for validators)
df = pd.DataFrame(
{
"timestamp": pd.date_range("2024-01-01 03:00", periods=10, freq="1min"),
"open": [100.0] * 10,
"high": [100.5] * 10,
"low": [99.5] * 10,
"close": [100.2] * 10,
}
)
# Test validators
df_validated = validate_ohlcv(df)
is_continuous, gaps = check_continuity(df_validated, Timeframe.M1)
_ = detect_outliers(df_validated) # Check it runs without error
assert len(df_validated) == 10
print("✓ Validators working")
def validate_repositories():
"""Validate repositories."""
from src.data.database import get_db_session
try:
with get_db_session() as session:
repo = OHLCVRepository(session=session)
assert repo is not None
print("✓ Repositories working")
except Exception as e:
print(f"✗ Repository validation failed: {e}")
raise
def validate_directories():
"""Validate directory structure."""
required_dirs = [
"data/raw/ohlcv/1min",
"data/raw/ohlcv/5min",
"data/raw/ohlcv/15min",
"data/processed/features",
"data/processed/patterns",
"data/labels/individual_patterns",
]
for dir_name in required_dirs:
dir_path = Path(dir_name)
if not dir_path.exists():
print(f"✗ Missing directory: {dir_name}")
return False
print("✓ Directory structure valid")
return True
def main():
"""Run all validation checks."""
print("Validating ICT ML Trading System v0.2.0 Data Pipeline...")
print("-" * 60)
try:
validate_imports()
validate_database()
validate_loaders()
validate_preprocessors()
validate_validators()
validate_repositories()
if not validate_directories():
    return 1
print("-" * 60)
print("✓ All validations passed!")
return 0
except Exception as e:
print(f"✗ Validation failed: {e}")
import traceback
traceback.print_exc()
return 1
if __name__ == "__main__":
sys.exit(main())


@@ -81,7 +81,7 @@ def load_config(config_path: Optional[Path] = None) -> Dict[str, Any]:
         _config = config
         logger.info("Configuration loaded successfully")
-        return config
+        return config  # type: ignore[no-any-return]
     except Exception as e:
         raise ConfigurationError(
@@ -150,4 +150,3 @@ def _substitute_env_vars(config: Any) -> Any:
         return config
     else:
         return config


@@ -1,7 +1,7 @@
 """Application-wide constants."""
 from pathlib import Path
-from typing import Dict, List
+from typing import Any, Dict, List
 # Project root directory
 PROJECT_ROOT = Path(__file__).parent.parent.parent
@@ -50,7 +50,7 @@ PATTERN_THRESHOLDS: Dict[str, float] = {
 }
 # Model configuration
-MODEL_CONFIG: Dict[str, any] = {
+MODEL_CONFIG: Dict[str, Any] = {
     "min_labels_per_pattern": 200,
     "train_test_split": 0.8,
     "validation_split": 0.1,
@@ -70,9 +70,8 @@ LOG_LEVELS: List[str] = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
 LOG_FORMATS: List[str] = ["json", "text"]
 # Database constants
-DB_CONSTANTS: Dict[str, any] = {
+DB_CONSTANTS: Dict[str, Any] = {
     "pool_size": 10,
     "max_overflow": 20,
     "pool_timeout": 30,
 }

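Note: the lowercase `any` being replaced here is Python's builtin function, not a type, so `Dict[str, any]` is rejected by static type checkers such as mypy; switching to `typing.Any` fixes the annotation without changing runtime behavior.
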
src/data/__init__.py (new file, 41 lines)

@@ -0,0 +1,41 @@
"""Data management module for ICT ML Trading System."""
from src.data.database import get_engine, get_session, init_database
from src.data.loaders import CSVLoader, DatabaseLoader, ParquetLoader
from src.data.models import DetectedPattern, OHLCVData, PatternLabel, SetupLabel, Trade
from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates
from src.data.repositories import OHLCVRepository, PatternRepository, Repository
from src.data.schemas import OHLCVSchema, PatternSchema
from src.data.validators import check_continuity, detect_outliers, validate_ohlcv
__all__ = [
# Database
"get_engine",
"get_session",
"init_database",
# Models
"OHLCVData",
"DetectedPattern",
"PatternLabel",
"SetupLabel",
"Trade",
# Loaders
"CSVLoader",
"ParquetLoader",
"DatabaseLoader",
# Preprocessors
"handle_missing_data",
"remove_duplicates",
"filter_session",
# Validators
"validate_ohlcv",
"check_continuity",
"detect_outliers",
# Repositories
"Repository",
"OHLCVRepository",
"PatternRepository",
# Schemas
"OHLCVSchema",
"PatternSchema",
]

src/data/database.py (new file, 212 lines)

@@ -0,0 +1,212 @@
"""Database connection and session management."""
import os
from contextlib import contextmanager
from typing import Generator, Optional
from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker
from src.config import get_config
from src.core.constants import DB_CONSTANTS
from src.core.exceptions import ConfigurationError, DataError
from src.logging import get_logger
logger = get_logger(__name__)
# Global engine and session factory
_engine: Optional[Engine] = None
_SessionLocal: Optional[sessionmaker] = None
def get_database_url() -> str:
"""
Get database URL from config or environment variable.
Returns:
Database URL string
Raises:
ConfigurationError: If database URL cannot be determined
"""
try:
config = get_config()
db_config = config.get("database", {})
database_url = os.getenv("DATABASE_URL") or db_config.get("database_url")
if not database_url:
raise ConfigurationError(
"Database URL not found in configuration or environment variables",
context={"config": db_config},
)
# Handle SQLite path expansion
if database_url.startswith("sqlite:///"):
db_path = database_url.replace("sqlite:///", "")
if not os.path.isabs(db_path):
# Relative path - make it absolute from project root
from src.core.constants import PROJECT_ROOT
db_path = str(PROJECT_ROOT / db_path)
database_url = f"sqlite:///{db_path}"
db_display = database_url.split("@")[-1] if "@" in database_url else "sqlite"
logger.debug(f"Database URL configured: {db_display}")
return database_url
except Exception as e:
raise ConfigurationError(
f"Failed to get database URL: {e}",
context={"error": str(e)},
) from e
def get_engine() -> Engine:
"""
Get or create SQLAlchemy engine with connection pooling.
Returns:
SQLAlchemy engine instance
"""
global _engine
if _engine is not None:
return _engine
database_url = get_database_url()
db_config = get_config().get("database", {})
# Connection pool settings
pool_size = db_config.get("pool_size", DB_CONSTANTS["pool_size"])
max_overflow = db_config.get("max_overflow", DB_CONSTANTS["max_overflow"])
pool_timeout = db_config.get("pool_timeout", DB_CONSTANTS["pool_timeout"])
pool_recycle = db_config.get("pool_recycle", 3600)
# SQLite-specific settings
connect_args = {}
if database_url.startswith("sqlite"):
sqlite_config = db_config.get("sqlite", {})
connect_args = {
"check_same_thread": sqlite_config.get("check_same_thread", False),
"timeout": sqlite_config.get("timeout", 20),
}
# PostgreSQL-specific settings
elif database_url.startswith("postgresql"):
postgres_config = db_config.get("postgresql", {})
connect_args = postgres_config.get("connect_args", {})
try:
_engine = create_engine(
database_url,
pool_size=pool_size,
max_overflow=max_overflow,
pool_timeout=pool_timeout,
pool_recycle=pool_recycle,
connect_args=connect_args,
echo=db_config.get("echo", False),
echo_pool=db_config.get("echo_pool", False),
)
# Add connection event listeners
@event.listens_for(_engine, "connect")
def set_sqlite_pragma(dbapi_conn, connection_record):
"""Set SQLite pragmas for better performance."""
if database_url.startswith("sqlite"):
cursor = dbapi_conn.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.execute("PRAGMA journal_mode=WAL")
cursor.close()
logger.info(f"Database engine created: pool_size={pool_size}, max_overflow={max_overflow}")
return _engine
except Exception as e:
raise DataError(
f"Failed to create database engine: {e}",
context={
"database_url": database_url.split("@")[-1] if "@" in database_url else "sqlite"
},
) from e
def get_session() -> sessionmaker:
"""
Get or create session factory.
Returns:
SQLAlchemy sessionmaker instance
"""
global _SessionLocal
if _SessionLocal is not None:
return _SessionLocal
engine = get_engine()
_SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
logger.debug("Session factory created")
return _SessionLocal
@contextmanager
def get_db_session() -> Generator[Session, None, None]:
"""
Context manager for database sessions.
Yields:
Database session
Example:
>>> with get_db_session() as session:
... data = session.query(OHLCVData).all()
"""
SessionLocal = get_session()
session = SessionLocal()
try:
yield session
session.commit()
except Exception as e:
session.rollback()
logger.error(f"Database session error: {e}", exc_info=True)
raise DataError(f"Database operation failed: {e}") from e
finally:
session.close()
def init_database(create_tables: bool = True) -> None:
"""
Initialize database and create tables.
Args:
create_tables: Whether to create tables if they don't exist
Raises:
DataError: If database initialization fails
"""
try:
engine = get_engine()
database_url = get_database_url()
# Create data directory for SQLite if needed
if database_url.startswith("sqlite"):
db_path = database_url.replace("sqlite:///", "")
db_dir = os.path.dirname(db_path)
if db_dir and not os.path.exists(db_dir):
os.makedirs(db_dir, exist_ok=True)
logger.info(f"Created database directory: {db_dir}")
if create_tables:
# Import models to register them with SQLAlchemy
from src.data.models import Base
Base.metadata.create_all(bind=engine)
logger.info("Database tables created successfully")
logger.info("Database initialized successfully")
except Exception as e:
raise DataError(
f"Failed to initialize database: {e}",
context={"create_tables": create_tables},
) from e
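
Typical call sequence for this module, as a minimal sketch (it assumes the SQLite defaults resolved by get_database_url; OHLCVRepository is defined in src/data/repositories.py below):

from src.core.enums import Timeframe
from src.data.database import get_db_session, init_database
from src.data.repositories import OHLCVRepository

# One-time setup: creates the data directory and tables if they are missing.
init_database(create_tables=True)

# get_db_session() commits on success and rolls back on any exception.
with get_db_session() as session:
    repo = OHLCVRepository(session=session)
    for bar in repo.get_latest("DAX", Timeframe.M1, limit=5):
        print(bar.timestamp, bar.close)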

src/data/loaders.py (new file, 337 lines)

@@ -0,0 +1,337 @@
"""Data loaders for various data sources."""
from pathlib import Path
from typing import Optional
import pandas as pd
from src.core.enums import Timeframe
from src.core.exceptions import DataError
from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates
from src.data.validators import validate_ohlcv
from src.logging import get_logger
logger = get_logger(__name__)
class BaseLoader:
"""Base class for data loaders."""
def load(self, source: str, **kwargs) -> pd.DataFrame:
"""
Load data from source.
Args:
source: Data source path/identifier
**kwargs: Additional loader-specific arguments
Returns:
DataFrame with loaded data
Raises:
DataError: If loading fails
"""
raise NotImplementedError("Subclasses must implement load()")
class CSVLoader(BaseLoader):
"""Loader for CSV files."""
def load( # type: ignore[override]
self,
file_path: str,
symbol: Optional[str] = None,
timeframe: Optional[Timeframe] = None,
**kwargs,
) -> pd.DataFrame:
"""
Load OHLCV data from CSV file.
Args:
file_path: Path to CSV file
symbol: Optional symbol to add to DataFrame
timeframe: Optional timeframe to add to DataFrame
**kwargs: Additional pandas.read_csv arguments
Returns:
DataFrame with OHLCV data
Raises:
DataError: If file cannot be loaded
"""
file_path_obj = Path(file_path)
if not file_path_obj.exists():
raise DataError(
f"CSV file not found: {file_path}",
context={"file_path": str(file_path)},
)
try:
# Default CSV reading options
read_kwargs = {
"parse_dates": ["timestamp"],
"index_col": False,
}
read_kwargs.update(kwargs)
df = pd.read_csv(file_path, **read_kwargs)
# Ensure timestamp column exists
if "timestamp" not in df.columns and "time" in df.columns:
df.rename(columns={"time": "timestamp"}, inplace=True)
# Add metadata if provided
if symbol:
df["symbol"] = symbol
if timeframe:
df["timeframe"] = timeframe.value
# Standardize column names (case-insensitive)
column_mapping = {
"open": "open",
"high": "high",
"low": "low",
"close": "close",
"volume": "volume",
}
for old_name, new_name in column_mapping.items():
if old_name.lower() in [col.lower() for col in df.columns]:
matching_col = [col for col in df.columns if col.lower() == old_name.lower()][0]
if matching_col != new_name:
df.rename(columns={matching_col: new_name}, inplace=True)
logger.info(f"Loaded {len(df)} rows from CSV: {file_path}")
return df
except Exception as e:
raise DataError(
f"Failed to load CSV file: {e}",
context={"file_path": str(file_path)},
) from e
class ParquetLoader(BaseLoader):
"""Loader for Parquet files."""
def load( # type: ignore[override]
self,
file_path: str,
symbol: Optional[str] = None,
timeframe: Optional[Timeframe] = None,
**kwargs,
) -> pd.DataFrame:
"""
Load OHLCV data from Parquet file.
Args:
file_path: Path to Parquet file
symbol: Optional symbol to add to DataFrame
timeframe: Optional timeframe to add to DataFrame
**kwargs: Additional pandas.read_parquet arguments
Returns:
DataFrame with OHLCV data
Raises:
DataError: If file cannot be loaded
"""
file_path_obj = Path(file_path)
if not file_path_obj.exists():
raise DataError(
f"Parquet file not found: {file_path}",
context={"file_path": str(file_path)},
)
try:
df = pd.read_parquet(file_path, **kwargs)
# Add metadata if provided
if symbol:
df["symbol"] = symbol
if timeframe:
df["timeframe"] = timeframe.value
logger.info(f"Loaded {len(df)} rows from Parquet: {file_path}")
return df
except Exception as e:
raise DataError(
f"Failed to load Parquet file: {e}",
context={"file_path": str(file_path)},
) from e
class DatabaseLoader(BaseLoader):
"""Loader for database data."""
def __init__(self, session=None):
"""
Initialize database loader.
Args:
session: Optional database session (creates new if not provided)
"""
self.session = session
def load( # type: ignore[override]
self,
symbol: str,
timeframe: Timeframe,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
limit: Optional[int] = None,
**kwargs,
) -> pd.DataFrame:
"""
Load OHLCV data from database.
Args:
symbol: Trading symbol
timeframe: Timeframe enum
start_date: Optional start date (ISO format or datetime string)
end_date: Optional end date (ISO format or datetime string)
limit: Optional limit on number of records
**kwargs: Additional query arguments
Returns:
DataFrame with OHLCV data
Raises:
DataError: If database query fails
"""
from src.data.database import get_db_session
from src.data.repositories import OHLCVRepository
try:
# Use provided session or create new one
if self.session:
repo = OHLCVRepository(session=self.session)
session_context = None
else:
session_context = get_db_session()
session = session_context.__enter__()
repo = OHLCVRepository(session=session)
# Parse dates
start = pd.to_datetime(start_date) if start_date else None
end = pd.to_datetime(end_date) if end_date else None
# Query database
if start and end:
records = repo.get_by_timestamp_range(symbol, timeframe, start, end, limit)
else:
records = repo.get_latest(symbol, timeframe, limit or 1000)
# Convert to DataFrame
data = []
for record in records:
data.append(
{
"id": record.id,
"symbol": record.symbol,
"timeframe": record.timeframe.value,
"timestamp": record.timestamp,
"open": float(record.open),
"high": float(record.high),
"low": float(record.low),
"close": float(record.close),
"volume": record.volume,
}
)
df = pd.DataFrame(data)
if session_context:
session_context.__exit__(None, None, None)
logger.info(
f"Loaded {len(df)} rows from database: {symbol} {timeframe.value} "
f"({start_date} to {end_date})"
)
return df
except Exception as e:
raise DataError(
f"Failed to load data from database: {e}",
context={
"symbol": symbol,
"timeframe": timeframe.value,
"start_date": start_date,
"end_date": end_date,
},
) from e
def load_and_preprocess(
source: str,
loader_type: str = "auto",
validate: bool = True,
preprocess: bool = True,
filter_to_session: bool = False,
**loader_kwargs,
) -> pd.DataFrame:
"""
Load data and optionally validate/preprocess it.
Args:
source: Data source (file path or database identifier)
loader_type: Loader type ('csv', 'parquet', 'database', 'auto')
validate: Whether to validate data
preprocess: Whether to preprocess data (handle missing, remove duplicates)
filter_to_session: Whether to filter to trading session hours
**loader_kwargs: Additional arguments for loader
Returns:
Processed DataFrame
Raises:
DataError: If loading or processing fails
"""
# Auto-detect loader type
if loader_type == "auto":
source_path = Path(source)
if source_path.exists():
if source_path.suffix.lower() == ".csv":
loader_type = "csv"
elif source_path.suffix.lower() == ".parquet":
loader_type = "parquet"
else:
raise DataError(
f"Cannot auto-detect loader type for: {source}",
context={"source": str(source)},
)
else:
loader_type = "database"
# Create appropriate loader
loader: BaseLoader
if loader_type == "csv":
loader = CSVLoader()
elif loader_type == "parquet":
loader = ParquetLoader()
elif loader_type == "database":
loader = DatabaseLoader()
else:
raise DataError(
f"Invalid loader type: {loader_type}",
context={"valid_types": ["csv", "parquet", "database", "auto"]},
)
# Load data
df = loader.load(source, **loader_kwargs)
# Validate
if validate:
df = validate_ohlcv(df)
# Preprocess
if preprocess:
df = handle_missing_data(df, method="forward_fill")
df = remove_duplicates(df)
# Filter to session
if filter_to_session:
df = filter_session(df)
logger.info(f"Loaded and processed {len(df)} rows from {source}")
return df
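
load_and_preprocess() is the convenience entry point the processing script relies on. A hedged sketch with a hypothetical CSV path:

from src.data.loaders import load_and_preprocess

# loader_type="auto" picks CSVLoader from the .csv suffix.
df = load_and_preprocess(
    "data/raw/ohlcv/1min/DAX_1min.csv",   # hypothetical path
    loader_type="auto",
    validate=True,           # run validate_ohlcv()
    preprocess=True,         # forward-fill gaps, drop duplicate timestamps
    filter_to_session=True,  # keep only the 3-4 AM EST window
)
print(df[["timestamp", "open", "high", "low", "close"]].head())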

src/data/models.py (new file, 223 lines)

@@ -0,0 +1,223 @@
"""SQLAlchemy ORM models for data storage."""
from datetime import datetime
from sqlalchemy import (
Boolean,
Column,
DateTime,
Enum,
Float,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from src.core.enums import (
Grade,
OrderType,
PatternDirection,
PatternType,
SetupType,
Timeframe,
TradeDirection,
TradeStatus,
)
Base = declarative_base()
class OHLCVData(Base): # type: ignore[valid-type,misc]
"""OHLCV market data table."""
__tablename__ = "ohlcv_data"
id = Column(Integer, primary_key=True, index=True)
symbol = Column(String(20), nullable=False, index=True)
timeframe = Column(Enum(Timeframe), nullable=False, index=True)
timestamp = Column(DateTime, nullable=False, index=True)
open = Column(Numeric(20, 5), nullable=False)
high = Column(Numeric(20, 5), nullable=False)
low = Column(Numeric(20, 5), nullable=False)
close = Column(Numeric(20, 5), nullable=False)
volume = Column(Integer, nullable=True)
# Metadata
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
patterns = relationship("DetectedPattern", back_populates="ohlcv_data")
# Composite index for common queries
__table_args__ = (Index("idx_symbol_timeframe_timestamp", "symbol", "timeframe", "timestamp"),)
def __repr__(self) -> str:
return (
f"<OHLCVData(id={self.id}, symbol={self.symbol}, "
f"timeframe={self.timeframe}, timestamp={self.timestamp})>"
)
class DetectedPattern(Base): # type: ignore[valid-type,misc]
"""Detected ICT patterns table."""
__tablename__ = "detected_patterns"
id = Column(Integer, primary_key=True, index=True)
pattern_type = Column(Enum(PatternType), nullable=False, index=True)
direction = Column(Enum(PatternDirection), nullable=False)
timeframe = Column(Enum(Timeframe), nullable=False, index=True)
symbol = Column(String(20), nullable=False, index=True)
# Pattern location
start_timestamp = Column(DateTime, nullable=False, index=True)
end_timestamp = Column(DateTime, nullable=False)
ohlcv_data_id = Column(Integer, ForeignKey("ohlcv_data.id"), nullable=True)
# Price levels
entry_level = Column(Numeric(20, 5), nullable=True)
stop_loss = Column(Numeric(20, 5), nullable=True)
take_profit = Column(Numeric(20, 5), nullable=True)
high_level = Column(Numeric(20, 5), nullable=True)
low_level = Column(Numeric(20, 5), nullable=True)
# Pattern metadata
size_pips = Column(Float, nullable=True)
strength_score = Column(Float, nullable=True)
context_data = Column(Text, nullable=True) # JSON string for additional context
# Metadata
detected_at = Column(DateTime, default=datetime.utcnow, nullable=False)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
# Relationships
ohlcv_data = relationship("OHLCVData", back_populates="patterns")
labels = relationship("PatternLabel", back_populates="pattern")
# Composite index
__table_args__ = (
Index("idx_pattern_type_symbol_timestamp", "pattern_type", "symbol", "start_timestamp"),
)
def __repr__(self) -> str:
return (
f"<DetectedPattern(id={self.id}, pattern_type={self.pattern_type}, "
f"direction={self.direction}, timestamp={self.start_timestamp})>"
)
class PatternLabel(Base): # type: ignore[valid-type,misc]
"""Labels for individual patterns."""
__tablename__ = "pattern_labels"
id = Column(Integer, primary_key=True, index=True)
pattern_id = Column(Integer, ForeignKey("detected_patterns.id"), nullable=False, index=True)
grade = Column(Enum(Grade), nullable=False, index=True)
notes = Column(Text, nullable=True)
# Labeler metadata
labeled_by = Column(String(100), nullable=True)
labeled_at = Column(DateTime, default=datetime.utcnow, nullable=False)
confidence = Column(Float, nullable=True) # Labeler's confidence (0-1)
# Quality checks
is_anchor = Column(Boolean, default=False, nullable=False, index=True)
reviewed = Column(Boolean, default=False, nullable=False)
# Relationships
pattern = relationship("DetectedPattern", back_populates="labels")
def __repr__(self) -> str:
return (
f"<PatternLabel(id={self.id}, pattern_id={self.pattern_id}, "
f"grade={self.grade}, labeled_at={self.labeled_at})>"
)
class SetupLabel(Base): # type: ignore[valid-type,misc]
"""Labels for complete trading setups."""
__tablename__ = "setup_labels"
id = Column(Integer, primary_key=True, index=True)
setup_type = Column(Enum(SetupType), nullable=False, index=True)
symbol = Column(String(20), nullable=False, index=True)
session_date = Column(DateTime, nullable=False, index=True)
# Setup components (pattern IDs)
fvg_id = Column(Integer, ForeignKey("detected_patterns.id"), nullable=True)
order_block_id = Column(Integer, ForeignKey("detected_patterns.id"), nullable=True)
liquidity_id = Column(Integer, ForeignKey("detected_patterns.id"), nullable=True)
# Label
grade = Column(Enum(Grade), nullable=False, index=True)
outcome = Column(String(50), nullable=True) # "win", "loss", "breakeven"
pnl = Column(Numeric(20, 2), nullable=True)
# Labeler metadata
labeled_by = Column(String(100), nullable=True)
labeled_at = Column(DateTime, default=datetime.utcnow, nullable=False)
notes = Column(Text, nullable=True)
def __repr__(self) -> str:
return (
f"<SetupLabel(id={self.id}, setup_type={self.setup_type}, "
f"session_date={self.session_date}, grade={self.grade})>"
)
class Trade(Base): # type: ignore[valid-type,misc]
"""Trade execution records."""
__tablename__ = "trades"
id = Column(Integer, primary_key=True, index=True)
symbol = Column(String(20), nullable=False, index=True)
direction = Column(Enum(TradeDirection), nullable=False)
order_type = Column(Enum(OrderType), nullable=False)
status = Column(Enum(TradeStatus), nullable=False, index=True)
# Entry
entry_price = Column(Numeric(20, 5), nullable=False)
entry_timestamp = Column(DateTime, nullable=False, index=True)
entry_size = Column(Integer, nullable=False)
# Exit
exit_price = Column(Numeric(20, 5), nullable=True)
exit_timestamp = Column(DateTime, nullable=True)
exit_size = Column(Integer, nullable=True)
# Risk management
stop_loss = Column(Numeric(20, 5), nullable=True)
take_profit = Column(Numeric(20, 5), nullable=True)
risk_amount = Column(Numeric(20, 2), nullable=True)
# P&L
pnl = Column(Numeric(20, 2), nullable=True)
pnl_pips = Column(Float, nullable=True)
commission = Column(Numeric(20, 2), nullable=True)
# Related patterns
pattern_id = Column(Integer, ForeignKey("detected_patterns.id"), nullable=True)
setup_id = Column(Integer, ForeignKey("setup_labels.id"), nullable=True)
# Metadata
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
notes = Column(Text, nullable=True)
# Composite index
__table_args__ = (Index("idx_symbol_status_timestamp", "symbol", "status", "entry_timestamp"),)
def __repr__(self) -> str:
return (
f"<Trade(id={self.id}, symbol={self.symbol}, direction={self.direction}, "
f"status={self.status}, entry_price={self.entry_price})>"
)
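
Rows are plain declarative models, so inserting one mirrors what process_data.py does above. A sketch with made-up prices, using the repository's exists() check to avoid duplicates:

from datetime import datetime

from src.core.enums import Timeframe
from src.data.database import get_db_session
from src.data.models import OHLCVData
from src.data.repositories import OHLCVRepository

candle = OHLCVData(
    symbol="DAX",
    timeframe=Timeframe.M1,
    timestamp=datetime(2024, 1, 2, 8, 0),
    open=16700.0, high=16705.5, low=16698.0, close=16703.2,  # made-up values
    volume=1200,
)

with get_db_session() as session:
    repo = OHLCVRepository(session=session)
    if not repo.exists(candle.symbol, candle.timeframe, candle.timestamp):
        repo.create(candle)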

src/data/preprocessors.py (new file, 181 lines)

@@ -0,0 +1,181 @@
"""Data preprocessing functions."""
from datetime import datetime
from typing import Optional
import pandas as pd
import pytz # type: ignore[import-untyped]
from src.core.constants import SESSION_TIMES
from src.core.exceptions import DataError
from src.logging import get_logger
logger = get_logger(__name__)
def handle_missing_data(
df: pd.DataFrame,
method: str = "forward_fill",
columns: Optional[list] = None,
) -> pd.DataFrame:
"""
Handle missing data in DataFrame.
Args:
df: DataFrame with potential missing values
method: Method to handle missing data
('forward_fill', 'backward_fill', 'drop', 'interpolate')
columns: Specific columns to process (defaults to all numeric columns)
Returns:
DataFrame with missing data handled
Raises:
DataError: If method is invalid
"""
if df.empty:
return df
if columns is None:
# Default to numeric columns
columns = df.select_dtypes(include=["number"]).columns.tolist()
df_processed = df.copy()
missing_before = df_processed[columns].isna().sum().sum()
if missing_before == 0:
logger.debug("No missing data found")
return df_processed
logger.info(f"Handling {missing_before} missing values using method: {method}")
for col in columns:
if col not in df_processed.columns:
continue
if method == "forward_fill":
df_processed[col] = df_processed[col].ffill()
elif method == "backward_fill":
df_processed[col] = df_processed[col].bfill()
elif method == "drop":
df_processed = df_processed.dropna(subset=[col])
elif method == "interpolate":
df_processed[col] = df_processed[col].interpolate(method="linear")
else:
raise DataError(
f"Invalid missing data method: {method}",
context={"valid_methods": ["forward_fill", "backward_fill", "drop", "interpolate"]},
)
missing_after = df_processed[columns].isna().sum().sum()
logger.info(f"Missing data handled: {missing_before} -> {missing_after}")
return df_processed
def remove_duplicates(
df: pd.DataFrame,
subset: Optional[list] = None,
keep: str = "first",
timestamp_col: str = "timestamp",
) -> pd.DataFrame:
"""
Remove duplicate rows from DataFrame.
Args:
df: DataFrame with potential duplicates
subset: Columns to consider for duplicates (defaults to timestamp)
keep: Which duplicates to keep ('first', 'last', False to drop all)
timestamp_col: Name of timestamp column
Returns:
DataFrame with duplicates removed
"""
if df.empty:
return df
if subset is None:
subset = [timestamp_col] if timestamp_col in df.columns else None
duplicates_before = len(df)
df_processed = df.drop_duplicates(subset=subset, keep=keep)
duplicates_removed = duplicates_before - len(df_processed)
if duplicates_removed > 0:
logger.info(f"Removed {duplicates_removed} duplicate rows")
else:
logger.debug("No duplicates found")
return df_processed
def filter_session(
df: pd.DataFrame,
timestamp_col: str = "timestamp",
session_start: Optional[str] = None,
session_end: Optional[str] = None,
timezone: str = "America/New_York",
) -> pd.DataFrame:
"""
Filter DataFrame to trading session hours (default: 3:00-4:00 AM EST).
Args:
df: DataFrame with timestamp column
timestamp_col: Name of timestamp column
session_start: Session start time (HH:MM format, defaults to config)
session_end: Session end time (HH:MM format, defaults to config)
timezone: Timezone for session times (defaults to EST)
Returns:
Filtered DataFrame
Raises:
DataError: If timestamp column is missing or invalid
"""
if df.empty:
return df
if timestamp_col not in df.columns:
raise DataError(
f"Timestamp column '{timestamp_col}' not found",
context={"columns": df.columns.tolist()},
)
# Get session times from config or use defaults
if session_start is None:
session_start = SESSION_TIMES.get("start", "03:00")
if session_end is None:
session_end = SESSION_TIMES.get("end", "04:00")
# Parse session times
start_time = datetime.strptime(session_start, "%H:%M").time()
end_time = datetime.strptime(session_end, "%H:%M").time()
# Ensure timestamp is datetime
if not pd.api.types.is_datetime64_any_dtype(df[timestamp_col]):
df[timestamp_col] = pd.to_datetime(df[timestamp_col])
# Convert to session timezone if needed
tz = pytz.timezone(timezone)
if df[timestamp_col].dt.tz is None:
# Assume UTC if no timezone
df[timestamp_col] = df[timestamp_col].dt.tz_localize("UTC")
df[timestamp_col] = df[timestamp_col].dt.tz_convert(tz)
# Filter by time of day
df_filtered = df[
(df[timestamp_col].dt.time >= start_time) & (df[timestamp_col].dt.time <= end_time)
].copy()
rows_before = len(df)
rows_after = len(df_filtered)
logger.info(
f"Filtered to session {session_start}-{session_end} {timezone}: "
f"{rows_before} -> {rows_after} rows"
)
return df_filtered
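
A small end-to-end sketch of the three functions on synthetic data (one missing close, one duplicated timestamp); naive timestamps are treated as UTC by filter_session:

import pandas as pd

from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates

df = pd.DataFrame(
    {
        "timestamp": pd.to_datetime(
            ["2024-01-02 08:00", "2024-01-02 08:01", "2024-01-02 08:01", "2024-01-02 08:02"]
        ),
        "open": [100.0, 100.2, 100.2, 100.1],
        "high": [100.5, 100.6, 100.6, 100.4],
        "low": [99.8, 100.0, 100.0, 99.9],
        "close": [100.2, None, None, 100.3],
    }
)

df = handle_missing_data(df, method="forward_fill")  # fills the missing closes
df = remove_duplicates(df)                           # drops the repeated 08:01 bar
df = filter_session(df)                              # 08:00 UTC is 03:00 EST, kept with the default session
print(df)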

src/data/repositories.py (new file, 355 lines)

@@ -0,0 +1,355 @@
"""Repository pattern for data access layer."""
from datetime import datetime
from typing import List, Optional
from sqlalchemy import and_, desc
from sqlalchemy.orm import Session
from src.core.enums import PatternType, Timeframe
from src.core.exceptions import DataError
from src.data.models import DetectedPattern, OHLCVData, PatternLabel
from src.logging import get_logger
logger = get_logger(__name__)
class Repository:
"""Base repository class with common database operations."""
def __init__(self, session: Optional[Session] = None):
"""
Initialize repository.
Args:
session: Optional database session (creates new if not provided)
"""
self._session = session
@property
def session(self) -> Session:
"""Get database session."""
if self._session is None:
# Use context manager for automatic cleanup
raise RuntimeError("Session must be provided or use context manager")
return self._session
class OHLCVRepository(Repository):
"""Repository for OHLCV data operations."""
def create(self, data: OHLCVData) -> OHLCVData:
"""
Create new OHLCV record.
Args:
data: OHLCVData instance
Returns:
Created OHLCVData instance
Raises:
DataError: If creation fails
"""
try:
self.session.add(data)
self.session.flush()
logger.debug(f"Created OHLCV record: {data.id}")
return data
except Exception as e:
logger.error(f"Failed to create OHLCV record: {e}", exc_info=True)
raise DataError(f"Failed to create OHLCV record: {e}") from e
def create_batch(self, data_list: List[OHLCVData]) -> List[OHLCVData]:
"""
Create multiple OHLCV records in batch.
Args:
data_list: List of OHLCVData instances
Returns:
List of created OHLCVData instances
Raises:
DataError: If batch creation fails
"""
try:
self.session.add_all(data_list)
self.session.flush()
logger.info(f"Created {len(data_list)} OHLCV records in batch")
return data_list
except Exception as e:
logger.error(f"Failed to create OHLCV records in batch: {e}", exc_info=True)
raise DataError(f"Failed to create OHLCV records: {e}") from e
def get_by_id(self, record_id: int) -> Optional[OHLCVData]:
"""
Get OHLCV record by ID.
Args:
record_id: Record ID
Returns:
OHLCVData instance or None if not found
"""
result = self.session.query(OHLCVData).filter(OHLCVData.id == record_id).first()
return result # type: ignore[no-any-return]
def get_by_timestamp_range(
self,
symbol: str,
timeframe: Timeframe,
start: datetime,
end: datetime,
limit: Optional[int] = None,
) -> List[OHLCVData]:
"""
Get OHLCV data for symbol/timeframe within timestamp range.
Args:
symbol: Trading symbol
timeframe: Timeframe enum
start: Start timestamp
end: End timestamp
limit: Optional limit on number of records
Returns:
List of OHLCVData instances
"""
query = (
self.session.query(OHLCVData)
.filter(
and_(
OHLCVData.symbol == symbol,
OHLCVData.timeframe == timeframe,
OHLCVData.timestamp >= start,
OHLCVData.timestamp <= end,
)
)
.order_by(OHLCVData.timestamp)
)
if limit:
query = query.limit(limit)
result = query.all()
return result # type: ignore[no-any-return]
def get_latest(self, symbol: str, timeframe: Timeframe, limit: int = 1) -> List[OHLCVData]:
"""
Get latest OHLCV records for symbol/timeframe.
Args:
symbol: Trading symbol
timeframe: Timeframe enum
limit: Number of records to return
Returns:
List of OHLCVData instances (most recent first)
"""
result = (
self.session.query(OHLCVData)
.filter(
and_(
OHLCVData.symbol == symbol,
OHLCVData.timeframe == timeframe,
)
)
.order_by(desc(OHLCVData.timestamp))
.limit(limit)
.all()
)
return result # type: ignore[no-any-return]
def exists(self, symbol: str, timeframe: Timeframe, timestamp: datetime) -> bool:
"""
Check if OHLCV record exists.
Args:
symbol: Trading symbol
timeframe: Timeframe enum
timestamp: Record timestamp
Returns:
True if record exists, False otherwise
"""
count = (
self.session.query(OHLCVData)
.filter(
and_(
OHLCVData.symbol == symbol,
OHLCVData.timeframe == timeframe,
OHLCVData.timestamp == timestamp,
)
)
.count()
)
return bool(count > 0)
def delete_by_timestamp_range(
self,
symbol: str,
timeframe: Timeframe,
start: datetime,
end: datetime,
) -> int:
"""
Delete OHLCV records within timestamp range.
Args:
symbol: Trading symbol
timeframe: Timeframe enum
start: Start timestamp
end: End timestamp
Returns:
Number of records deleted
"""
try:
deleted = (
self.session.query(OHLCVData)
.filter(
and_(
OHLCVData.symbol == symbol,
OHLCVData.timeframe == timeframe,
OHLCVData.timestamp >= start,
OHLCVData.timestamp <= end,
)
)
.delete(synchronize_session=False)
)
logger.info(f"Deleted {deleted} OHLCV records")
return int(deleted)
except Exception as e:
logger.error(f"Failed to delete OHLCV records: {e}", exc_info=True)
raise DataError(f"Failed to delete OHLCV records: {e}") from e
class PatternRepository(Repository):
"""Repository for detected pattern operations."""
def create(self, pattern: DetectedPattern) -> DetectedPattern:
"""
Create new pattern record.
Args:
pattern: DetectedPattern instance
Returns:
Created DetectedPattern instance
Raises:
DataError: If creation fails
"""
try:
self.session.add(pattern)
self.session.flush()
logger.debug(f"Created pattern record: {pattern.id} ({pattern.pattern_type})")
return pattern
except Exception as e:
logger.error(f"Failed to create pattern record: {e}", exc_info=True)
raise DataError(f"Failed to create pattern record: {e}") from e
def create_batch(self, patterns: List[DetectedPattern]) -> List[DetectedPattern]:
"""
Create multiple pattern records in batch.
Args:
patterns: List of DetectedPattern instances
Returns:
List of created DetectedPattern instances
Raises:
DataError: If batch creation fails
"""
try:
self.session.add_all(patterns)
self.session.flush()
logger.info(f"Created {len(patterns)} pattern records in batch")
return patterns
except Exception as e:
logger.error(f"Failed to create pattern records in batch: {e}", exc_info=True)
raise DataError(f"Failed to create pattern records: {e}") from e
def get_by_id(self, pattern_id: int) -> Optional[DetectedPattern]:
"""
Get pattern by ID.
Args:
pattern_id: Pattern ID
Returns:
DetectedPattern instance or None if not found
"""
result = (
self.session.query(DetectedPattern).filter(DetectedPattern.id == pattern_id).first()
)
return result # type: ignore[no-any-return]
def get_by_type_and_range(
self,
pattern_type: PatternType,
symbol: str,
start: datetime,
end: datetime,
timeframe: Optional[Timeframe] = None,
) -> List[DetectedPattern]:
"""
Get patterns by type within timestamp range.
Args:
pattern_type: Pattern type enum
symbol: Trading symbol
start: Start timestamp
end: End timestamp
timeframe: Optional timeframe filter
Returns:
List of DetectedPattern instances
"""
query = self.session.query(DetectedPattern).filter(
and_(
DetectedPattern.pattern_type == pattern_type,
DetectedPattern.symbol == symbol,
DetectedPattern.start_timestamp >= start,
DetectedPattern.start_timestamp <= end,
)
)
if timeframe:
query = query.filter(DetectedPattern.timeframe == timeframe)
return query.order_by(DetectedPattern.start_timestamp).all() # type: ignore[no-any-return]
def get_unlabeled(
self,
pattern_type: Optional[PatternType] = None,
symbol: Optional[str] = None,
limit: int = 100,
) -> List[DetectedPattern]:
"""
Get patterns that don't have labels yet.
Args:
pattern_type: Optional pattern type filter
symbol: Optional symbol filter
limit: Maximum number of records to return
Returns:
List of unlabeled DetectedPattern instances
"""
query = (
self.session.query(DetectedPattern)
.outerjoin(PatternLabel)
.filter(PatternLabel.id.is_(None))
)
if pattern_type:
query = query.filter(DetectedPattern.pattern_type == pattern_type)
if symbol:
query = query.filter(DetectedPattern.symbol == symbol)
result = query.order_by(desc(DetectedPattern.detected_at)).limit(limit).all()
return result # type: ignore[no-any-return]
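
Repositories never open their own sessions (the base class raises if none is supplied), so callers pair them with get_db_session. A sketch of the two query helpers most likely to be used by later milestones, with made-up dates:

from datetime import datetime

from src.core.enums import Timeframe
from src.data.database import get_db_session
from src.data.repositories import OHLCVRepository, PatternRepository

with get_db_session() as session:
    ohlcv = OHLCVRepository(session=session)
    bars = ohlcv.get_by_timestamp_range(
        "DAX",
        Timeframe.M1,
        start=datetime(2024, 1, 2, 8, 0),
        end=datetime(2024, 1, 2, 9, 0),
        limit=100,
    )
    print(f"{len(bars)} bars in range")

    patterns = PatternRepository(session=session)
    unlabeled = patterns.get_unlabeled(symbol="DAX", limit=10)
    print(f"{len(unlabeled)} patterns still need labels")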

src/data/schemas.py (new file, 91 lines)

@@ -0,0 +1,91 @@
"""Pydantic schemas for data validation."""
from datetime import datetime
from decimal import Decimal
from typing import Optional
from pydantic import BaseModel, Field, field_validator
from src.core.enums import PatternDirection, PatternType, Timeframe
class OHLCVSchema(BaseModel):
"""Schema for OHLCV data validation."""
symbol: str = Field(..., description="Trading symbol (e.g., 'DAX')")
timeframe: Timeframe = Field(..., description="Timeframe enum")
timestamp: datetime = Field(..., description="Candle timestamp")
open: Decimal = Field(..., gt=0, description="Open price")
high: Decimal = Field(..., gt=0, description="High price")
low: Decimal = Field(..., gt=0, description="Low price")
close: Decimal = Field(..., gt=0, description="Close price")
volume: Optional[int] = Field(None, ge=0, description="Volume")
@field_validator("high", "low")
@classmethod
def validate_price_range(cls, v: Decimal, info) -> Decimal:
"""Validate that high >= low and prices are within reasonable range."""
if info.field_name == "high":
low = info.data.get("low")
if low and v < low:
raise ValueError("High price must be >= low price")
elif info.field_name == "low":
high = info.data.get("high")
if high and v > high:
raise ValueError("Low price must be <= high price")
return v
@field_validator("open", "close")
@classmethod
def validate_price_bounds(cls, v: Decimal, info) -> Decimal:
"""Validate that open/close are within high/low range."""
high = info.data.get("high")
low = info.data.get("low")
if high and low:
if not (low <= v <= high):
raise ValueError(f"{info.field_name} must be between low and high")
return v
class Config:
"""Pydantic config."""
json_encoders = {
Decimal: str,
datetime: lambda v: v.isoformat(),
}
class PatternSchema(BaseModel):
"""Schema for detected pattern validation."""
pattern_type: PatternType = Field(..., description="Pattern type enum")
direction: PatternDirection = Field(..., description="Pattern direction")
timeframe: Timeframe = Field(..., description="Timeframe enum")
symbol: str = Field(..., description="Trading symbol")
start_timestamp: datetime = Field(..., description="Pattern start timestamp")
end_timestamp: datetime = Field(..., description="Pattern end timestamp")
entry_level: Optional[Decimal] = Field(None, description="Entry price level")
stop_loss: Optional[Decimal] = Field(None, description="Stop loss level")
take_profit: Optional[Decimal] = Field(None, description="Take profit level")
high_level: Optional[Decimal] = Field(None, description="Pattern high level")
low_level: Optional[Decimal] = Field(None, description="Pattern low level")
size_pips: Optional[float] = Field(None, ge=0, description="Pattern size in pips")
strength_score: Optional[float] = Field(None, ge=0, le=1, description="Strength score (0-1)")
context_data: Optional[str] = Field(None, description="Additional context as JSON string")
@field_validator("end_timestamp")
@classmethod
def validate_timestamp_order(cls, v: datetime, info) -> datetime:
"""Validate that end_timestamp >= start_timestamp."""
start = info.data.get("start_timestamp")
if start and v < start:
raise ValueError("end_timestamp must be >= start_timestamp")
return v
class Config:
"""Pydantic config."""
json_encoders = {
Decimal: str,
datetime: lambda v: v.isoformat(),
}
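
A quick sketch of the schema in action: a well-formed candle passes, while a candle whose high sits below its low is rejected by the field validators (all values are made up):

from datetime import datetime
from decimal import Decimal

from pydantic import ValidationError

from src.core.enums import Timeframe
from src.data.schemas import OHLCVSchema

ok = OHLCVSchema(
    symbol="DAX",
    timeframe=Timeframe.M1,
    timestamp=datetime(2024, 1, 2, 8, 0),
    open=Decimal("16700.0"),
    high=Decimal("16705.5"),
    low=Decimal("16698.0"),
    close=Decimal("16703.2"),
    volume=1200,
)
print(ok.symbol, ok.close)

try:
    OHLCVSchema(
        symbol="DAX",
        timeframe=Timeframe.M1,
        timestamp=datetime(2024, 1, 2, 8, 1),
        open=Decimal("16700.0"),
        high=Decimal("16690.0"),  # high below low, so validation fails
        low=Decimal("16698.0"),
        close=Decimal("16695.0"),
        volume=1000,
    )
except ValidationError as exc:
    print("rejected:", exc.errors()[0]["msg"])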

src/data/validators.py (new file, 231 lines)

@@ -0,0 +1,231 @@
"""Data validation functions."""
from datetime import datetime, timedelta
from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
from src.core.enums import Timeframe
from src.core.exceptions import ValidationError
from src.logging import get_logger
logger = get_logger(__name__)
def validate_ohlcv(df: pd.DataFrame, required_columns: Optional[List[str]] = None) -> pd.DataFrame:
"""
Validate OHLCV DataFrame structure and data quality.
Args:
df: DataFrame with OHLCV data
required_columns: Optional list of required columns (defaults to standard OHLCV)
Returns:
Validated DataFrame
Raises:
ValidationError: If validation fails
"""
if required_columns is None:
required_columns = ["timestamp", "open", "high", "low", "close"]
# Check required columns exist
missing_cols = [col for col in required_columns if col not in df.columns]
if missing_cols:
raise ValidationError(
f"Missing required columns: {missing_cols}",
context={"columns": df.columns.tolist(), "required": required_columns},
)
# Check for empty DataFrame
if df.empty:
raise ValidationError("DataFrame is empty")
# Validate price columns
price_cols = ["open", "high", "low", "close"]
for col in price_cols:
if col in df.columns:
# Check for negative or zero prices
if (df[col] <= 0).any():
invalid_count = (df[col] <= 0).sum()
raise ValidationError(
f"Invalid {col} values (<= 0): {invalid_count} rows",
context={"column": col, "invalid_rows": invalid_count},
)
# Check for infinite values
if np.isinf(df[col]).any():
invalid_count = np.isinf(df[col]).sum()
raise ValidationError(
f"Infinite {col} values: {invalid_count} rows",
context={"column": col, "invalid_rows": invalid_count},
)
# Validate high >= low
if "high" in df.columns and "low" in df.columns:
invalid = df["high"] < df["low"]
if invalid.any():
invalid_count = invalid.sum()
raise ValidationError(
f"High < Low in {invalid_count} rows",
context={"invalid_rows": invalid_count},
)
# Validate open/close within high/low range
if all(col in df.columns for col in ["open", "close", "high", "low"]):
invalid_open = (df["open"] < df["low"]) | (df["open"] > df["high"])
invalid_close = (df["close"] < df["low"]) | (df["close"] > df["high"])
if invalid_open.any() or invalid_close.any():
invalid_count = invalid_open.sum() + invalid_close.sum()
raise ValidationError(
f"Open/Close outside High/Low range: {invalid_count} rows",
context={"invalid_rows": invalid_count},
)
# Validate timestamp column
if "timestamp" in df.columns:
if not pd.api.types.is_datetime64_any_dtype(df["timestamp"]):
try:
df["timestamp"] = pd.to_datetime(df["timestamp"])
except Exception as e:
raise ValidationError(
f"Invalid timestamp format: {e}",
context={"column": "timestamp"},
) from e
# Check for duplicate timestamps
duplicates = df["timestamp"].duplicated().sum()
if duplicates > 0:
logger.warning(f"Found {duplicates} duplicate timestamps")
logger.debug(f"Validated OHLCV DataFrame: {len(df)} rows, {len(df.columns)} columns")
return df
def check_continuity(
df: pd.DataFrame,
timeframe: Timeframe,
timestamp_col: str = "timestamp",
max_gap_minutes: Optional[int] = None,
) -> Tuple[bool, List[datetime]]:
"""
Check for gaps in timestamp continuity.
Args:
df: DataFrame with timestamp column
timeframe: Expected timeframe
timestamp_col: Name of timestamp column
max_gap_minutes: Maximum allowed gap in minutes (defaults to timeframe duration)
Returns:
        Tuple of (is_continuous, list of timestamps immediately preceding each gap)
Raises:
ValidationError: If timestamp column is missing or invalid
"""
if timestamp_col not in df.columns:
raise ValidationError(
f"Timestamp column '{timestamp_col}' not found",
context={"columns": df.columns.tolist()},
)
if df.empty:
return True, []
# Determine expected interval
timeframe_minutes = {
Timeframe.M1: 1,
Timeframe.M5: 5,
Timeframe.M15: 15,
}
expected_interval = timedelta(minutes=timeframe_minutes.get(timeframe, 1))
if max_gap_minutes:
max_gap = timedelta(minutes=max_gap_minutes)
else:
max_gap = expected_interval * 2 # Allow 2x timeframe as max gap
# Sort by timestamp
df_sorted = df.sort_values(timestamp_col).copy()
timestamps = pd.to_datetime(df_sorted[timestamp_col])
# Find gaps
gaps = []
for i in range(len(timestamps) - 1):
gap = timestamps.iloc[i + 1] - timestamps.iloc[i]
if gap > max_gap:
gaps.append(timestamps.iloc[i])
is_continuous = len(gaps) == 0
if gaps:
logger.warning(
f"Found {len(gaps)} gaps in continuity (timeframe: {timeframe}, " f"max_gap: {max_gap})"
)
return is_continuous, gaps
def detect_outliers(
df: pd.DataFrame,
columns: Optional[List[str]] = None,
method: str = "iqr",
threshold: float = 3.0,
) -> pd.DataFrame:
"""
Detect outliers in price columns.
Args:
df: DataFrame with price data
columns: Columns to check (defaults to OHLCV price columns)
method: Detection method ('iqr' or 'zscore')
threshold: Threshold for outlier detection
Returns:
        Single-column DataFrame ("is_outlier") with a boolean mask (True = outlier)
Raises:
ValidationError: If method is invalid or columns missing
"""
if columns is None:
columns = [col for col in ["open", "high", "low", "close"] if col in df.columns]
if not columns:
raise ValidationError("No columns specified for outlier detection")
missing_cols = [col for col in columns if col not in df.columns]
if missing_cols:
raise ValidationError(
f"Columns not found: {missing_cols}",
context={"columns": df.columns.tolist()},
)
outlier_mask = pd.Series([False] * len(df), index=df.index)
for col in columns:
if method == "iqr":
Q1 = df[col].quantile(0.25)
Q3 = df[col].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - threshold * IQR
upper_bound = Q3 + threshold * IQR
col_outliers = (df[col] < lower_bound) | (df[col] > upper_bound)
elif method == "zscore":
z_scores = np.abs((df[col] - df[col].mean()) / df[col].std())
col_outliers = z_scores > threshold
else:
raise ValidationError(
f"Invalid outlier detection method: {method}",
context={"valid_methods": ["iqr", "zscore"]},
)
outlier_mask |= col_outliers
outlier_count = outlier_mask.sum()
if outlier_count > 0:
logger.warning(f"Detected {outlier_count} outliers using {method} method")
return outlier_mask.to_frame("is_outlier")
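A short usage sketch for the validators above (illustrative; assumes the module is importable as src.data.validators, as the tests in this commit do):

import pandas as pd

from src.core.enums import Timeframe
from src.data.validators import check_continuity, detect_outliers, validate_ohlcv

df = pd.DataFrame(
    {
        "timestamp": pd.date_range("2024-01-01 03:00", periods=5, freq="1min"),
        "open": [100.0, 100.2, 100.4, 100.6, 100.8],
        "high": [100.5, 100.7, 100.9, 101.1, 101.3],
        "low": [99.5, 99.7, 99.9, 100.1, 100.3],
        "close": [100.2, 100.4, 100.6, 100.8, 101.0],
    }
)

df = validate_ohlcv(df)                                   # raises ValidationError on bad data
is_continuous, gaps = check_continuity(df, Timeframe.M1)  # (True, []) for this series
outliers = detect_outliers(df, method="iqr")              # single "is_outlier" boolean column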

View File

@@ -0,0 +1,6 @@
timestamp,open,high,low,close,volume
2024-01-01 03:00:00,100.0,100.5,99.5,100.2,1000
2024-01-01 03:01:00,100.2,100.7,99.7,100.4,1100
2024-01-01 03:02:00,100.4,100.9,99.9,100.6,1200
2024-01-01 03:03:00,100.6,101.1,100.1,100.8,1300
2024-01-01 03:04:00,100.8,101.3,100.3,101.0,1400

View File

@@ -0,0 +1,127 @@
"""Integration tests for database operations."""
import os
import tempfile
import pytest
from src.core.enums import Timeframe
from src.data.database import get_db_session, init_database
from src.data.models import OHLCVData
from src.data.repositories import OHLCVRepository
@pytest.fixture
def temp_db():
"""Create temporary database for testing."""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
# Initialize database
init_database(create_tables=True)
yield db_path
# Cleanup
if os.path.exists(db_path):
os.unlink(db_path)
if "DATABASE_URL" in os.environ:
del os.environ["DATABASE_URL"]
def test_create_and_retrieve_ohlcv(temp_db):
"""Test creating and retrieving OHLCV records."""
from datetime import datetime
with get_db_session() as session:
repo = OHLCVRepository(session=session)
# Create record
record = OHLCVData(
symbol="DAX",
timeframe=Timeframe.M1,
timestamp=datetime(2024, 1, 1, 3, 0, 0),
open=100.0,
high=100.5,
low=99.5,
close=100.2,
volume=1000,
)
created = repo.create(record)
assert created.id is not None
# Retrieve record
retrieved = repo.get_by_id(created.id)
assert retrieved is not None
assert retrieved.symbol == "DAX"
assert retrieved.close == 100.2
def test_batch_create_ohlcv(temp_db):
"""Test batch creation of OHLCV records."""
from datetime import datetime, timedelta
with get_db_session() as session:
repo = OHLCVRepository(session=session)
# Create multiple records
records = []
base_time = datetime(2024, 1, 1, 3, 0, 0)
for i in range(10):
records.append(
OHLCVData(
symbol="DAX",
timeframe=Timeframe.M1,
timestamp=base_time + timedelta(minutes=i),
open=100.0 + i * 0.1,
high=100.5 + i * 0.1,
low=99.5 + i * 0.1,
close=100.2 + i * 0.1,
volume=1000,
)
)
created = repo.create_batch(records)
assert len(created) == 10
# Verify all records saved
retrieved = repo.get_by_timestamp_range(
"DAX",
Timeframe.M1,
base_time,
base_time + timedelta(minutes=10),
)
assert len(retrieved) == 10
def test_get_by_timestamp_range(temp_db):
"""Test retrieving records by timestamp range."""
from datetime import datetime, timedelta
with get_db_session() as session:
repo = OHLCVRepository(session=session)
# Create records
base_time = datetime(2024, 1, 1, 3, 0, 0)
for i in range(20):
record = OHLCVData(
symbol="DAX",
timeframe=Timeframe.M1,
timestamp=base_time + timedelta(minutes=i),
open=100.0,
high=100.5,
low=99.5,
close=100.2,
volume=1000,
)
repo.create(record)
# Retrieve subset
start = base_time + timedelta(minutes=5)
end = base_time + timedelta(minutes=15)
records = repo.get_by_timestamp_range("DAX", Timeframe.M1, start, end)
assert len(records) == 11 # Inclusive of start and end

View File

@@ -0,0 +1 @@
"""Unit tests for data module."""

View File

@@ -0,0 +1,65 @@
"""Tests for database connection and session management."""
import os
import tempfile
import pytest
from sqlalchemy import text
from src.data.database import get_db_session, get_engine, init_database
@pytest.fixture
def temp_db():
"""Create temporary database for testing."""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
# Set environment variable
os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
yield db_path
# Cleanup
if os.path.exists(db_path):
os.unlink(db_path)
if "DATABASE_URL" in os.environ:
del os.environ["DATABASE_URL"]
def test_get_engine(temp_db):
"""Test engine creation."""
engine = get_engine()
assert engine is not None
assert str(engine.url).startswith("sqlite")
def test_init_database(temp_db):
"""Test database initialization."""
init_database(create_tables=True)
assert os.path.exists(temp_db)
def test_get_db_session(temp_db):
"""Test database session context manager."""
init_database(create_tables=True)
with get_db_session() as session:
assert session is not None
# Session should be usable
        result = session.execute(text("SELECT 1")).scalar()
assert result == 1
def test_session_rollback_on_error(temp_db):
"""Test that session rolls back on error."""
init_database(create_tables=True)
try:
with get_db_session() as session:
# Cause an error
session.execute("SELECT * FROM nonexistent_table")
except Exception:
pass # Expected
    # The failed transaction should have been rolled back; a fresh session must still work
    with get_db_session() as session:
        assert session.execute(text("SELECT 1")).scalar() == 1
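The raw-SQL calls above wrap statements in SQLAlchemy's text() construct, which 1.4+/2.0 requires for plain string SQL; a minimal standalone sketch:

from sqlalchemy import create_engine, text

# In-memory SQLite engine used purely for illustration
engine = create_engine("sqlite:///:memory:")
with engine.connect() as conn:
    assert conn.execute(text("SELECT 1")).scalar() == 1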

View File

@@ -0,0 +1,83 @@
"""Tests for data loaders."""
import pandas as pd
import pytest
from src.core.enums import Timeframe
from src.data.loaders import CSVLoader, ParquetLoader
@pytest.fixture
def sample_ohlcv_data():
"""Create sample OHLCV DataFrame."""
dates = pd.date_range("2024-01-01 03:00", periods=100, freq="1min")
return pd.DataFrame(
{
"timestamp": dates,
"open": [100.0 + i * 0.1 for i in range(100)],
"high": [100.5 + i * 0.1 for i in range(100)],
"low": [99.5 + i * 0.1 for i in range(100)],
"close": [100.2 + i * 0.1 for i in range(100)],
"volume": [1000] * 100,
}
)
@pytest.fixture
def csv_file(sample_ohlcv_data, tmp_path):
"""Create temporary CSV file."""
csv_path = tmp_path / "test_data.csv"
sample_ohlcv_data.to_csv(csv_path, index=False)
return csv_path
@pytest.fixture
def parquet_file(sample_ohlcv_data, tmp_path):
"""Create temporary Parquet file."""
parquet_path = tmp_path / "test_data.parquet"
sample_ohlcv_data.to_parquet(parquet_path, index=False)
return parquet_path
def test_csv_loader(csv_file):
"""Test CSV loader."""
loader = CSVLoader()
df = loader.load(str(csv_file), symbol="DAX", timeframe=Timeframe.M1)
assert len(df) == 100
assert "symbol" in df.columns
assert "timeframe" in df.columns
assert df["symbol"].iloc[0] == "DAX"
assert df["timeframe"].iloc[0] == "1min"
def test_csv_loader_missing_file():
"""Test CSV loader with missing file."""
loader = CSVLoader()
with pytest.raises(Exception): # Should raise DataError
loader.load("nonexistent.csv")
def test_parquet_loader(parquet_file):
"""Test Parquet loader."""
loader = ParquetLoader()
df = loader.load(str(parquet_file), symbol="DAX", timeframe=Timeframe.M1)
assert len(df) == 100
assert "symbol" in df.columns
assert "timeframe" in df.columns
def test_load_and_preprocess(csv_file):
"""Test load_and_preprocess function."""
from src.data.loaders import load_and_preprocess
df = load_and_preprocess(
str(csv_file),
loader_type="csv",
validate=True,
preprocess=True,
)
assert len(df) == 100
assert "timestamp" in df.columns

View File

@@ -0,0 +1,87 @@
"""Tests for data preprocessors."""
import numpy as np
import pandas as pd
import pytest
from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates
@pytest.fixture
def sample_data_with_missing():
"""Create sample DataFrame with missing values."""
dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
df = pd.DataFrame(
{
"timestamp": dates,
"open": [100.0] * 10,
"high": [100.5] * 10,
"low": [99.5] * 10,
"close": [100.2] * 10,
}
)
# Add some missing values
df.loc[2, "close"] = np.nan
df.loc[5, "open"] = np.nan
return df
@pytest.fixture
def sample_data_with_duplicates():
"""Create sample DataFrame with duplicates."""
dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
df = pd.DataFrame(
{
"timestamp": dates,
"open": [100.0] * 10,
"high": [100.5] * 10,
"low": [99.5] * 10,
"close": [100.2] * 10,
}
)
# Add duplicate
df = pd.concat([df, df.iloc[[0]]], ignore_index=True)
return df
def test_handle_missing_data_forward_fill(sample_data_with_missing):
"""Test forward fill missing data."""
df = handle_missing_data(sample_data_with_missing, method="forward_fill")
assert df["close"].isna().sum() == 0
assert df["open"].isna().sum() == 0
def test_handle_missing_data_drop(sample_data_with_missing):
"""Test drop missing data."""
df = handle_missing_data(sample_data_with_missing, method="drop")
assert df["close"].isna().sum() == 0
assert df["open"].isna().sum() == 0
assert len(df) < len(sample_data_with_missing)
def test_remove_duplicates(sample_data_with_duplicates):
"""Test duplicate removal."""
df = remove_duplicates(sample_data_with_duplicates)
assert len(df) == 10 # Should remove duplicate
def test_filter_session():
"""Test session filtering."""
# Create data spanning multiple hours
dates = pd.date_range("2024-01-01 02:00", periods=120, freq="1min")
df = pd.DataFrame(
{
"timestamp": dates,
"open": [100.0] * 120,
"high": [100.5] * 120,
"low": [99.5] * 120,
"close": [100.2] * 120,
}
)
# Filter to 3-4 AM EST
df_filtered = filter_session(df, session_start="03:00", session_end="04:00")
# Should have approximately 60 rows (1 hour of 1-minute data)
assert len(df_filtered) > 0
assert len(df_filtered) <= 60

View File

@@ -0,0 +1,97 @@
"""Tests for data validators."""
import pandas as pd
import pytest
from src.core.enums import Timeframe
from src.core.exceptions import ValidationError
from src.data.validators import check_continuity, detect_outliers, validate_ohlcv
@pytest.fixture
def valid_ohlcv_data():
"""Create valid OHLCV DataFrame."""
dates = pd.date_range("2024-01-01 03:00", periods=100, freq="1min")
return pd.DataFrame(
{
"timestamp": dates,
"open": [100.0 + i * 0.1 for i in range(100)],
"high": [100.5 + i * 0.1 for i in range(100)],
"low": [99.5 + i * 0.1 for i in range(100)],
"close": [100.2 + i * 0.1 for i in range(100)],
"volume": [1000] * 100,
}
)
@pytest.fixture
def invalid_ohlcv_data():
"""Create invalid OHLCV DataFrame."""
dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
df = pd.DataFrame(
{
"timestamp": dates,
"open": [100.0] * 10,
"high": [99.0] * 10, # Invalid: high < low
"low": [99.5] * 10,
"close": [100.2] * 10,
}
)
return df
def test_validate_ohlcv_valid(valid_ohlcv_data):
"""Test validation with valid data."""
df = validate_ohlcv(valid_ohlcv_data)
assert len(df) == 100
def test_validate_ohlcv_invalid(invalid_ohlcv_data):
"""Test validation with invalid data."""
    with pytest.raises(ValidationError):
validate_ohlcv(invalid_ohlcv_data)
def test_validate_ohlcv_missing_columns():
"""Test validation with missing columns."""
df = pd.DataFrame({"timestamp": pd.date_range("2024-01-01", periods=10)})
    with pytest.raises(ValidationError):
validate_ohlcv(df)
def test_check_continuity(valid_ohlcv_data):
"""Test continuity check."""
is_continuous, gaps = check_continuity(valid_ohlcv_data, Timeframe.M1)
assert is_continuous
assert len(gaps) == 0
def test_check_continuity_with_gaps():
"""Test continuity check with gaps."""
# Create data with gaps
dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min")
# Remove some dates to create gaps
dates = dates[[0, 1, 2, 5, 6, 7, 8, 9]] # Gap between index 2 and 5
df = pd.DataFrame(
{
"timestamp": dates,
"open": [100.0] * len(dates),
"high": [100.5] * len(dates),
"low": [99.5] * len(dates),
"close": [100.2] * len(dates),
}
)
is_continuous, gaps = check_continuity(df, Timeframe.M1)
assert not is_continuous
assert len(gaps) > 0
def test_detect_outliers(valid_ohlcv_data):
"""Test outlier detection."""
# Add an outlier
df = valid_ohlcv_data.copy()
df.loc[50, "close"] = 200.0 # Extreme value
outliers = detect_outliers(df, columns=["close"], method="iqr", threshold=3.0)
assert outliers["is_outlier"].sum() > 0
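A quick worked check (illustrative) of why the injected 200.0 close is flagged by the IQR rule at threshold 3.0:

import pandas as pd

# Same series as the test above: 100.2, 100.3, ... with index 50 replaced by 200.0
close = pd.Series([100.2 + i * 0.1 for i in range(100)])
close.loc[50] = 200.0

q1, q3 = close.quantile(0.25), close.quantile(0.75)
iqr = q3 - q1
upper = q3 + 3.0 * iqr
print(upper, int((close > upper).sum()))  # upper bound sits well below 200.0 -> 1 outlier flagged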