feat(v0.2.0): complete data pipeline with loaders, database, and validation
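A rough usage sketch of the pieces this commit wires together (the sample path, "TEST" symbol, and "forward_fill" option mirror the values used by the validation script below; treat the snippet as illustrative, not a stable API):

    from src.core.enums import Timeframe
    from src.data.database import init_database
    from src.data.loaders import CSVLoader
    from src.data.preprocessors import handle_missing_data, remove_duplicates
    from src.data.validators import validate_ohlcv

    init_database(create_tables=True)  # creates tables on first run

    loader = CSVLoader()
    df = loader.load(
        "tests/fixtures/sample_data/sample_ohlcv.csv",
        symbol="TEST",
        timeframe=Timeframe.M1,
    )
    df = remove_duplicates(handle_missing_data(df, method="forward_fill"))
    df = validate_ohlcv(df)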
@@ -1,170 +1,312 @@
#!/usr/bin/env python3
"""Validate data pipeline implementation (v0.2.0)."""

import argparse
import sys
from pathlib import Path

# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.logging import get_logger  # noqa: E402

logger = get_logger(__name__)


def validate_imports():
    """Validate that all data pipeline modules can be imported."""
    logger.info("Validating imports...")

    try:
        # Database
        from src.data.database import get_engine, get_session, init_database  # noqa: F401

        # Loaders
        from src.data.loaders import (  # noqa: F401
            CSVLoader,
            DatabaseLoader,
            ParquetLoader,
            load_and_preprocess,
        )

        # Models
        from src.data.models import (  # noqa: F401
            DetectedPattern,
            OHLCVData,
            PatternLabel,
            SetupLabel,
            Trade,
        )

        # Preprocessors
        from src.data.preprocessors import (  # noqa: F401
            filter_session,
            handle_missing_data,
            remove_duplicates,
        )

        # Repositories
        from src.data.repositories import (  # noqa: F401
            OHLCVRepository,
            PatternRepository,
            Repository,
        )

        # Schemas
        from src.data.schemas import OHLCVSchema, PatternSchema  # noqa: F401

        # Validators
        from src.data.validators import (  # noqa: F401
            check_continuity,
            detect_outliers,
            validate_ohlcv,
        )

        logger.info("✅ All imports successful")
        return True

    except Exception as e:
        logger.error(f"❌ Import validation failed: {e}", exc_info=True)
        return False


def validate_database():
    """Validate database connection and tables."""
    logger.info("Validating database...")

    try:
        from src.data.database import get_engine, init_database

        # Initialize database (will create tables if needed)
        init_database(create_tables=True)

        # Check engine
        engine = get_engine()
        if engine is None:
            raise RuntimeError("Failed to get database engine")

        # Check connection
        with engine.connect():
            logger.debug("Database connection successful")

        logger.info("✅ Database validation successful")
        return True

    except Exception as e:
        logger.error(f"❌ Database validation failed: {e}", exc_info=True)
        return False


def validate_loaders():
    """Validate data loaders with sample data."""
    logger.info("Validating data loaders...")

    try:
        from src.core.enums import Timeframe
        from src.data.loaders import CSVLoader

        # Check for sample data
        sample_file = project_root / "tests" / "fixtures" / "sample_data" / "sample_ohlcv.csv"
        if not sample_file.exists():
            logger.warning(f"Sample file not found: {sample_file}")
            return True  # Not critical

        # Load sample data
        loader = CSVLoader()
        df = loader.load(str(sample_file), symbol="TEST", timeframe=Timeframe.M1)

        if df.empty:
            raise RuntimeError("Loaded DataFrame is empty")

        logger.info(f"✅ Data loaders validated (loaded {len(df)} rows)")
        return True

    except Exception as e:
        logger.error(f"❌ Data loader validation failed: {e}", exc_info=True)
        return False


def validate_preprocessors():
    """Validate data preprocessors."""
    logger.info("Validating preprocessors...")

    try:
        import numpy as np
        import pandas as pd

        from src.data.preprocessors import handle_missing_data, remove_duplicates

        # Create test data with issues
        df = pd.DataFrame(
            {
                "timestamp": pd.date_range("2024-01-01", periods=10, freq="1min"),
                "value": [1, 2, np.nan, 4, 5, 5, 7, 8, 9, 10],
            }
        )

        # Test missing data handling
        df_clean = handle_missing_data(df.copy(), method="forward_fill")
        if df_clean["value"].isna().any():
            raise RuntimeError("Missing data not handled correctly")

        # Test duplicate removal
        df_nodup = remove_duplicates(df.copy())
        if len(df_nodup) >= len(df):
            logger.warning("No duplicates found (expected for test data)")

        logger.info("✅ Preprocessors validated")
        return True

    except Exception as e:
        logger.error(f"❌ Preprocessor validation failed: {e}", exc_info=True)
        return False


def validate_validators():
    """Validate data validators."""
    logger.info("Validating validators...")

    try:
        import pandas as pd

        from src.data.validators import validate_ohlcv

        # Create valid test data
        df = pd.DataFrame(
            {
                "timestamp": pd.date_range("2024-01-01", periods=10, freq="1min"),
                "open": [100.0] * 10,
                "high": [100.5] * 10,
                "low": [99.5] * 10,
                "close": [100.2] * 10,
                "volume": [1000] * 10,
            }
        )

        # Validate
        df_validated = validate_ohlcv(df)
        if df_validated.empty:
            raise RuntimeError("Validation removed all data")

        logger.info("✅ Validators validated")
        return True

    except Exception as e:
        logger.error(f"❌ Validator validation failed: {e}", exc_info=True)
        return False


def validate_directories():
    """Validate required directory structure."""
    logger.info("Validating directory structure...")

    required_dirs = [
        "data/raw/ohlcv/1min",
        "data/raw/ohlcv/5min",
        "data/raw/ohlcv/15min",
        "data/processed/features",
        "data/processed/patterns",
        "data/processed/snapshots",
        "data/labels/individual_patterns",
        "data/labels/complete_setups",
        "data/labels/anchors",
        "data/screenshots/patterns",
        "data/screenshots/setups",
    ]

    missing = []
    for dir_path in required_dirs:
        full_path = project_root / dir_path
        if not full_path.exists():
            missing.append(dir_path)

    if missing:
        logger.error(f"❌ Missing directories: {missing}")
        return False

    logger.info("✅ All required directories exist")
    return True


def validate_scripts():
    """Validate that utility scripts exist."""
    logger.info("Validating utility scripts...")

    required_scripts = [
        "scripts/setup_database.py",
        "scripts/download_data.py",
        "scripts/process_data.py",
    ]

    missing = []
    for script_path in required_scripts:
        full_path = project_root / script_path
        if not full_path.exists():
            missing.append(script_path)

    if missing:
        logger.error(f"❌ Missing scripts: {missing}")
        return False

    logger.info("✅ All required scripts exist")
    return True


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="Validate data pipeline implementation")
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--quick",
        action="store_true",
        help="Skip detailed validations (imports and directories only)",
    )

    args = parser.parse_args()
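
    # --verbose wiring is a sketch: the project's own logging configuration is
    # not shown in this commit, so fall back to raising the stdlib root logger
    # to DEBUG when the flag is set.
    if args.verbose:
        import logging

        logging.getLogger().setLevel(logging.DEBUG)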

    print("\n" + "=" * 70)
    print("Data Pipeline Validation (v0.2.0)")
    print("=" * 70 + "\n")

    results = []

    # Always run these
    results.append(("Imports", validate_imports()))
    results.append(("Directory Structure", validate_directories()))
    results.append(("Scripts", validate_scripts()))

    # Detailed validations
    if not args.quick:
        results.append(("Database", validate_database()))
        results.append(("Loaders", validate_loaders()))
        results.append(("Preprocessors", validate_preprocessors()))
        results.append(("Validators", validate_validators()))

    # Summary
    print("\n" + "=" * 70)
    print("Validation Summary")
    print("=" * 70)

    for name, passed in results:
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"{status:12} {name}")

    total = len(results)
    passed = sum(1 for _, p in results if p)

    print(f"\nTotal: {passed}/{total} checks passed")

    if passed == total:
        print("\n🎉 All validations passed! v0.2.0 Data Pipeline is complete.")
        return 0
    else:
        print("\n⚠️ Some validations failed. Please review the errors above.")
        return 1
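

# Entry-point guard (assumed standard; not visible in the excerpt above) so the
# script can be run directly and report success/failure via its exit code.
if __name__ == "__main__":
    sys.exit(main())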