From b5e7043df6497c2d23ff3b1163b7326c2436fe94 Mon Sep 17 00:00:00 2001 From: 0x_n3m0_ <0x_n3m0_@Mac.lan> Date: Mon, 5 Jan 2026 11:34:18 +0200 Subject: [PATCH] feat(v0.2.0): data pipeline --- data/ict_trading.db | Bin 0 -> 135168 bytes requirements.txt | 2 +- scripts/download_data.py | 183 ++++++++++ scripts/process_data.py | 269 +++++++++++++++ scripts/setup_database.py | 47 +++ scripts/validate_data_pipeline.py | 172 ++++++++++ src/config/config_loader.py | 3 +- src/core/constants.py | 7 +- src/data/__init__.py | 41 +++ src/data/database.py | 212 ++++++++++++ src/data/loaders.py | 337 +++++++++++++++++++ src/data/models.py | 223 ++++++++++++ src/data/preprocessors.py | 181 ++++++++++ src/data/repositories.py | 355 ++++++++++++++++++++ src/data/schemas.py | 91 +++++ src/data/validators.py | 231 +++++++++++++ tests/fixtures/sample_data/sample_ohlcv.csv | 6 + tests/integration/test_database.py | 127 +++++++ tests/unit/test_data/__init__.py | 1 + tests/unit/test_data/test_database.py | 65 ++++ tests/unit/test_data/test_loaders.py | 83 +++++ tests/unit/test_data/test_preprocessors.py | 87 +++++ tests/unit/test_data/test_validators.py | 97 ++++++ 23 files changed, 2813 insertions(+), 7 deletions(-) create mode 100644 data/ict_trading.db create mode 100755 scripts/download_data.py create mode 100755 scripts/process_data.py create mode 100755 scripts/setup_database.py create mode 100755 scripts/validate_data_pipeline.py create mode 100644 src/data/__init__.py create mode 100644 src/data/database.py create mode 100644 src/data/loaders.py create mode 100644 src/data/models.py create mode 100644 src/data/preprocessors.py create mode 100644 src/data/repositories.py create mode 100644 src/data/schemas.py create mode 100644 src/data/validators.py create mode 100644 tests/fixtures/sample_data/sample_ohlcv.csv create mode 100644 tests/integration/test_database.py create mode 100644 tests/unit/test_data/__init__.py create mode 100644 tests/unit/test_data/test_database.py create mode 100644 tests/unit/test_data/test_loaders.py create mode 100644 tests/unit/test_data/test_preprocessors.py create mode 100644 tests/unit/test_data/test_validators.py diff --git a/data/ict_trading.db b/data/ict_trading.db new file mode 100644 index 0000000000000000000000000000000000000000..b1db4d17dd4c0fcc90757ab9d7e4d419505303a8 GIT binary patch literal 135168 zcmeI*%WvC89tUtswna;}V>@ZXI&Hf2!GKt}jqSoMil9gphfa+;Qde<;+WT-hXs1+|FXcMmmP9M4u=mZ8~ETAz6ILYoX2nG z!VpW-zEBQ@Xy=yhl$`9`e!ckbJmaf zvG=}FC%w#bzj=A&x!-0dXMUWyRs5kiH}!O?AU!SoQdpb#Y2v-{pT}?Jf6f0Zt129a zz(9e*m4ctXf4LxUsJXArjzLb%Q@z`8U1E3iR%4H}9Njd;16MZdT4h&LwySrvM@m|L zrB+qMGD_Jr)b4^L-&AwGVn(0?yI~Nq+~py${32=L`~>ycyd-Nh;`NM(L9loi7lwr8 z<FAA4^U$(`GR%B_Nc^x^TKF(geV)H0$*P)ram#J&wPLhl z#JPUN-OBA9EphTnMbn_?dsn-w)s<>(SE)YO*;!Fa#h{XmWOjNUR_dFdR_f&qRVdrB zT;eFZ+M`{sCcopv(D%-iAphFxsv1qA>+Kyn@OHQFX+lL!tFvz!q|+qI)=sVBR~I5E zw`;W>tx^?A*yPwGC&Wli{Jg$>uTp=ke5O72hppFyty*2%zFYN#!X>qLZB~*W+{pDl4#q$tr|!57+tppOP3TYUZX&~__;^UtGdJUyUV;vm$*-BpriK+Vm{eVXJwp8x6NXukgX*?yYsCiyV^{ zUFYv<_3h2_`s#|Z;V;IHYjyRO3ivmySUkx4> z3i7*4xppvvc?NnG_%Bj?a_}VYY|$lxAMk5#P0|HvDY78(biSzD4qJ$k~Kl;~XwUr+%?C0X`7#G`ds`Dh??05sbz%1JHt2s$tTFSjR@s~ATnE#aT zM_>yXKR#du>w{-4En&a2mZq?M8B0o+bYc@c7PtSUfO!NM86aE7M2tWV=5P$##AOHafKmY;|fWU|fl)Qf7 z`F}*uF1iH)2tWV=5P$##AOHafKmY;|U;=pl$4Gzx1Rwwb2tWV=5P$##AOHafjJ^P# z|408EqlXZH00bZa0SG_<0uX=z1RwwbJpW@1KmY;|fB*y_009U<00Izz00c%~0MGxU ze~!^Z2tWV=5P$##AOHafKmY;|fB^pfe~bYLKmY;|fB*y_009U<00Izzz~~F$`~T>l zWAqRL5P$##AOHafKmY;|fB*y_faiaV0SG_<0uX=z1Rwwb2tWV=5P-nw3*h;G^v^MR 
z2muH{00Izz00bZa0SG_<0uaFSKgIwAAOHafKmY;|fB*y_009U{6G5V7(Ijl z1Rwwb2tWV=5P$##AOHaf;Q1e800Izz00bZa0SG_<0uX=z1RyZ_0{H*`js7`C4x{*N&L0SG_<0uX=z1Rwwb2tWV=5Ey*{JpYgWIYtj5009U< z00Izz00bZa0SG_<0(kz%7=Qo-AOHafKmY;|fB*y_009V$z5t&8NBgiNLdRq9Uur~42#CzjEkKfGyI`+jmc>GDfBFW3kxo>lBW3NRV;vRMN zR%4H}oaoQ5ZPvBQuBPl(ZtrMHG({RAkxQvo28{ zR_dFdR_f)obyduA*50#PVe4?pf^kf!#286F0;rSVwNtYDbb+RWC}n)1v*nr`5MN%j>Hv%DR8l zha-l*cNTVOZIu-Wk{WJ$M>>{E9A#H~wCh#>d42m{rT$p?Ona=9O~bEzt5(;x?^eBI zW!7y~scTzWU8`ycyd-Nh;`NM(L9loi7lwr8<MPs6NkMl4MoQy}0Es9bydex4C{qwtR?b0~U}VBcI^H+#fh^#rYJ>zE_x-O{=qS z8l=-C%GOS;;!n>ILAhP4?P!&%&@P)Co8*KTsfoj8Ht+LG&92_HSxJ6yBiH*l7z2rQ zKzH1R?dq=CCiJIvH<95|d^{v_#H;WIMk13^js;SC@61SYjW({y8fWm??_Cvi$6x&i zv}{&v*|72%%?e1pm@5hkO8*p9)ZWyzByZEkE@w-6hWmD-O%k1-5nFnyCvy1R^b2P+ zGpt;7N^~isEeV&hWM!t2I5s`J33MW*pE-Jo=`9r5$y&?qe^fFNzcf64w0uTOqFK~u zQ*?o&5a9xs2-+V*iQ~oLjbQ^JqZuqweX$}*@{Jq0?<9YXON2=rnCi!6D_tV_B^EpJ ze(3}8Zp1KcdXcBMMB)9#hhm0s5QQ1u7c=yG9o?zqHRk$$g2ItHJ#=~=A9R`b}q zZ)?PEwn)dd&-4~KCgKg$hJTOaxE8(rv>ZpcxngN<;|bBbwzY4%DXE9%!C_{#E$bvR z%`v~E)tX(0FYk_Plg@#As5?!|jx781Qs9zP*Y8$%s}`Q)tUJ4=P3Y~bp=YHJyCMjb z!bO3n_;)}h^7y4uq&^{<=s(!x@B zfA|pRuKIPx*cXwvh4T16Hs;c-GPq_B(573fq^lq;yFqHt~hqE^h@Gluj3cLCH3UCT!)EqT*L!e%`I%o*vH$-Z-f5r;2tWV=5P$##AOHafKmY;| zfWQS8@V@_F$^Bz&?*65}&aTb;WoCW)@6(T_{yp{m)Qa>s>65}g$NxOOH-0JqZGMAh z!#@{Mp!2pQ-&)M|7J}EXuD8L!(b=o-@GTcRx)gttm!8>sm&a0+a;zS;ex1JcUF2_l z`!%rFas5IWQTng!B9*fvYA^p5{q#aRvB5gQ-*nP9i0+ZY-!ifQzY_ywXODJ-#qvMJ z(JEJzM91Dt@0i~L7WZv2?x`Mh-p?Zh^N*-o*9!9TV(t&Ib7LB;;Ps69*lTf?$A2aM z&9y~IUR=z5Kj*(fWhv~xZ}sb7uTFz9+1q3W?}15v@`>+I@i)APpN-=CXqZOuX7(fT zlMrL?+9v(P9o^FsixWFL=`U$nChwa{@S;?BM@x5rxbz!NGLrW?)lUqkM_7ND6s^@e z`05{vY}0w7H`>-wNBoq;cJZX|?V(3wpNZ(cptk96P!ibsm9Sknv9DijOBCPFa2{`w z#hnhrz9t3VSLqku|9$sR(0rk&xEqYtTT)NnkmSm8uJ?K{%c6O7_KP=Qqle#&&%9`2 zq>MMGqe=9Wj=yQmt55w}rdx7o;BZS$G$*ZuZzqc;TpJRamr4uHQ`IL|CHW&7eloq^ z{zQ)kUK|{lEey&2vQJole7zvAE$5C>W0^)Y&Km>cu(d6k-(TSlUz6kyX=ie&0sPH3 z(K1&D#t64kMT;n9wvh^(5KF~@o0+f)u~dw1f5aw4y1SFj9c)6xCPZvP40aXMX+fV) zqPrc{|Eo=i`2K$pAL!902tWV=5P$##AOHafKmY;|fWY|*;Q9Z2Jy9D3AOHafKmY;| zfB*y_009U<;35j(`TruG6|@Ng5P$##AOHafKmY;|fB*y_aJ~X~{y$$&)CK_vKmY;| zfB*y_009U<00I!Whyvc<|2H;&5nF~fK>z{}fB*y_009U<00Izz00ba#UIO_3e_oQP z2Lcd)00bZa0SG_<0uX=z1R!v61@QcTanB7}1_1~_00Izz00bZa0SG_<0uVSa0X+Yo zmn7=6.7.0 # Optional, for colored console output # Data processing pyarrow>=12.0.0 # For Parquet support +pytz>=2023.3 # Timezone support # Utilities click>=8.1.0 # CLI framework - diff --git a/scripts/download_data.py b/scripts/download_data.py new file mode 100755 index 0000000..bdae3de --- /dev/null +++ b/scripts/download_data.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +"""Download DAX OHLCV data from external sources.""" + +import argparse +import sys +from pathlib import Path + +# Add project root to path +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from src.core.enums import Timeframe # noqa: E402 +from src.logging import get_logger # noqa: E402 + +logger = get_logger(__name__) + + +def download_from_csv( + input_file: str, + symbol: str, + timeframe: Timeframe, + output_dir: Path, +) -> None: + """ + Copy/convert CSV file to standard format. 
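+
+    Example (illustrative sketch; the input file name is a placeholder):
+
+        >>> download_from_csv(
+        ...     "m1_export.csv",
+        ...     symbol="DAX",
+        ...     timeframe=Timeframe.M1,
+        ...     output_dir=Path("data/raw/ohlcv/1min"),
+        ... )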
+ + Args: + input_file: Path to input CSV file + symbol: Trading symbol + timeframe: Timeframe enum + output_dir: Output directory + """ + from src.data.loaders import CSVLoader + + loader = CSVLoader() + df = loader.load(input_file, symbol=symbol, timeframe=timeframe) + + # Ensure output directory exists + output_dir.mkdir(parents=True, exist_ok=True) + + # Save as CSV + output_file = output_dir / f"{symbol}_{timeframe.value}.csv" + df.to_csv(output_file, index=False) + logger.info(f"Saved {len(df)} rows to {output_file}") + + # Also save as Parquet for faster loading + output_parquet = output_dir / f"{symbol}_{timeframe.value}.parquet" + df.to_parquet(output_parquet, index=False) + logger.info(f"Saved {len(df)} rows to {output_parquet}") + + +def download_from_api( + symbol: str, + timeframe: Timeframe, + start_date: str, + end_date: str, + output_dir: Path, + api_provider: str = "manual", +) -> None: + """ + Download data from API (placeholder for future implementation). + + Args: + symbol: Trading symbol + timeframe: Timeframe enum + start_date: Start date (YYYY-MM-DD) + end_date: End date (YYYY-MM-DD) + output_dir: Output directory + api_provider: API provider name + """ + logger.warning( + "API download not yet implemented. " "Please provide CSV file using --input-file option." + ) + logger.info( + f"Would download {symbol} {timeframe.value} data " f"from {start_date} to {end_date}" + ) + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Download DAX OHLCV data", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Download from CSV file + python scripts/download_data.py --input-file data.csv \\ + --symbol DAX --timeframe 1min \\ + --output data/raw/ohlcv/1min/ + + # Download from API (when implemented) + python scripts/download_data.py --symbol DAX --timeframe 5min \\ + --start 2024-01-01 --end 2024-01-31 \\ + --output data/raw/ohlcv/5min/ + """, + ) + + # Input options + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument( + "--input-file", + type=str, + help="Path to input CSV file", + ) + input_group.add_argument( + "--api", + action="store_true", + help="Download from API (not yet implemented)", + ) + + # Required arguments + parser.add_argument( + "--symbol", + type=str, + default="DAX", + help="Trading symbol (default: DAX)", + ) + parser.add_argument( + "--timeframe", + type=str, + choices=["1min", "5min", "15min"], + required=True, + help="Timeframe", + ) + parser.add_argument( + "--output", + type=str, + required=True, + help="Output directory", + ) + + # Optional arguments for API download + parser.add_argument( + "--start", + type=str, + help="Start date (YYYY-MM-DD) for API download", + ) + parser.add_argument( + "--end", + type=str, + help="End date (YYYY-MM-DD) for API download", + ) + + args = parser.parse_args() + + try: + # Convert timeframe string to enum + timeframe_map = { + "1min": Timeframe.M1, + "5min": Timeframe.M5, + "15min": Timeframe.M15, + } + timeframe = timeframe_map[args.timeframe] + + # Create output directory + output_dir = Path(args.output) + output_dir.mkdir(parents=True, exist_ok=True) + + # Download data + if args.input_file: + logger.info(f"Downloading from CSV: {args.input_file}") + download_from_csv(args.input_file, args.symbol, timeframe, output_dir) + elif args.api: + if not args.start or not args.end: + parser.error("--start and --end are required for API download") + download_from_api( + args.symbol, + timeframe, + args.start, + 
args.end, + output_dir, + ) + + logger.info("Data download completed successfully") + return 0 + + except Exception as e: + logger.error(f"Data download failed: {e}", exc_info=True) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/process_data.py b/scripts/process_data.py new file mode 100755 index 0000000..37effee --- /dev/null +++ b/scripts/process_data.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 +"""Batch process OHLCV data: clean, filter, and save.""" + +import argparse +import sys +from pathlib import Path + +# Add project root to path +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from src.core.enums import Timeframe # noqa: E402 +from src.data.database import get_db_session # noqa: E402 +from src.data.loaders import load_and_preprocess # noqa: E402 +from src.data.models import OHLCVData # noqa: E402 +from src.data.repositories import OHLCVRepository # noqa: E402 +from src.logging import get_logger # noqa: E402 + +logger = get_logger(__name__) + + +def process_file( + input_file: Path, + symbol: str, + timeframe: Timeframe, + output_dir: Path, + save_to_db: bool = False, + filter_session_hours: bool = True, +) -> None: + """ + Process a single data file. + + Args: + input_file: Path to input file + symbol: Trading symbol + timeframe: Timeframe enum + output_dir: Output directory + save_to_db: Whether to save to database + filter_session_hours: Whether to filter to trading session (3-4 AM EST) + """ + logger.info(f"Processing file: {input_file}") + + # Load and preprocess + df = load_and_preprocess( + str(input_file), + loader_type="auto", + validate=True, + preprocess=True, + filter_to_session=filter_session_hours, + ) + + # Ensure symbol and timeframe columns + df["symbol"] = symbol + df["timeframe"] = timeframe.value + + # Save processed CSV + output_dir.mkdir(parents=True, exist_ok=True) + output_csv = output_dir / f"{symbol}_{timeframe.value}_processed.csv" + df.to_csv(output_csv, index=False) + logger.info(f"Saved processed CSV: {output_csv} ({len(df)} rows)") + + # Save processed Parquet + output_parquet = output_dir / f"{symbol}_{timeframe.value}_processed.parquet" + df.to_parquet(output_parquet, index=False) + logger.info(f"Saved processed Parquet: {output_parquet} ({len(df)} rows)") + + # Save to database if requested + if save_to_db: + logger.info("Saving to database...") + with get_db_session() as session: + repo = OHLCVRepository(session=session) + + # Convert DataFrame to OHLCVData models + records = [] + for _, row in df.iterrows(): + # Check if record already exists + if repo.exists(symbol, timeframe, row["timestamp"]): + continue + + record = OHLCVData( + symbol=symbol, + timeframe=timeframe, + timestamp=row["timestamp"], + open=row["open"], + high=row["high"], + low=row["low"], + close=row["close"], + volume=row.get("volume"), + ) + records.append(record) + + if records: + repo.create_batch(records) + logger.info(f"Saved {len(records)} records to database") + else: + logger.info("No new records to save (all already exist)") + + +def process_directory( + input_dir: Path, + output_dir: Path, + symbol: str = "DAX", + save_to_db: bool = False, + filter_session_hours: bool = True, +) -> None: + """ + Process all data files in a directory. 
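+
+    Example (illustrative sketch; the timeframe is inferred from the input
+    directory name, which must contain "1min", "5min" or "15min"):
+
+        >>> process_directory(
+        ...     Path("data/raw/ohlcv/1min"),
+        ...     Path("data/processed"),
+        ...     symbol="DAX",
+        ...     save_to_db=True,
+        ... )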
+ + Args: + input_dir: Input directory + output_dir: Output directory + symbol: Trading symbol + save_to_db: Whether to save to database + filter_session_hours: Whether to filter to trading session + """ + # Find all CSV and Parquet files + files = list(input_dir.glob("*.csv")) + list(input_dir.glob("*.parquet")) + + if not files: + logger.warning(f"No data files found in {input_dir}") + return + + # Detect timeframe from directory name or file + timeframe_map = { + "1min": Timeframe.M1, + "5min": Timeframe.M5, + "15min": Timeframe.M15, + } + + timeframe = None + for tf_name, tf_enum in timeframe_map.items(): + if tf_name in str(input_dir): + timeframe = tf_enum + break + + if timeframe is None: + logger.error(f"Could not determine timeframe from directory: {input_dir}") + return + + logger.info(f"Processing {len(files)} files from {input_dir}") + + for file_path in files: + try: + process_file( + file_path, + symbol, + timeframe, + output_dir, + save_to_db, + filter_session_hours, + ) + except Exception as e: + logger.error(f"Failed to process {file_path}: {e}", exc_info=True) + continue + + logger.info("Batch processing completed") + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Batch process OHLCV data", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Process single file + python scripts/process_data.py --input data/raw/ohlcv/1min/m1.csv \\ + --output data/processed/ --symbol DAX --timeframe 1min + + # Process directory + python scripts/process_data.py --input data/raw/ohlcv/1min/ \\ + --output data/processed/ --symbol DAX + + # Process and save to database + python scripts/process_data.py --input data/raw/ohlcv/1min/ \\ + --output data/processed/ --save-db + """, + ) + + parser.add_argument( + "--input", + type=str, + required=True, + help="Input file or directory", + ) + parser.add_argument( + "--output", + type=str, + required=True, + help="Output directory", + ) + parser.add_argument( + "--symbol", + type=str, + default="DAX", + help="Trading symbol (default: DAX)", + ) + parser.add_argument( + "--timeframe", + type=str, + choices=["1min", "5min", "15min"], + help="Timeframe (required if processing single file)", + ) + parser.add_argument( + "--save-db", + action="store_true", + help="Save processed data to database", + ) + parser.add_argument( + "--no-session-filter", + action="store_true", + help="Don't filter to trading session hours (3-4 AM EST)", + ) + + args = parser.parse_args() + + try: + input_path = Path(args.input) + output_dir = Path(args.output) + + if not input_path.exists(): + logger.error(f"Input path does not exist: {input_path}") + return 1 + + # Process single file or directory + if input_path.is_file(): + if not args.timeframe: + parser.error("--timeframe is required when processing a single file") + return 1 + + timeframe_map = { + "1min": Timeframe.M1, + "5min": Timeframe.M5, + "15min": Timeframe.M15, + } + timeframe = timeframe_map[args.timeframe] + + process_file( + input_path, + args.symbol, + timeframe, + output_dir, + save_to_db=args.save_db, + filter_session_hours=not args.no_session_filter, + ) + + elif input_path.is_dir(): + process_directory( + input_path, + output_dir, + symbol=args.symbol, + save_to_db=args.save_db, + filter_session_hours=not args.no_session_filter, + ) + + else: + logger.error(f"Input path is neither file nor directory: {input_path}") + return 1 + + logger.info("Data processing completed successfully") + return 0 + + except Exception as e: + 
logger.error(f"Data processing failed: {e}", exc_info=True) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/setup_database.py b/scripts/setup_database.py new file mode 100755 index 0000000..0772407 --- /dev/null +++ b/scripts/setup_database.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +"""Initialize database and create tables.""" + +import argparse +import sys +from pathlib import Path + +# Add project root to path +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from src.data.database import init_database # noqa: E402 +from src.logging import get_logger # noqa: E402 + +logger = get_logger(__name__) + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser(description="Initialize database and create tables") + parser.add_argument( + "--skip-tables", + action="store_true", + help="Skip table creation (useful for testing connection only)", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + try: + logger.info("Initializing database...") + init_database(create_tables=not args.skip_tables) + logger.info("Database initialization completed successfully") + return 0 + + except Exception as e: + logger.error(f"Database initialization failed: {e}", exc_info=True) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/validate_data_pipeline.py b/scripts/validate_data_pipeline.py new file mode 100755 index 0000000..aff8037 --- /dev/null +++ b/scripts/validate_data_pipeline.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +"""Validate data pipeline setup for v0.2.0.""" + +import sys +from pathlib import Path + +# Add project root to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from src.core.enums import Timeframe # noqa: E402 +from src.data.database import get_engine, init_database # noqa: E402 +from src.data.loaders import CSVLoader, ParquetLoader # noqa: E402 +from src.data.preprocessors import ( # noqa: E402 + filter_session, + handle_missing_data, + remove_duplicates, +) +from src.data.repositories import OHLCVRepository # noqa: E402 +from src.data.validators import check_continuity, detect_outliers, validate_ohlcv # noqa: E402 +from src.logging import get_logger # noqa: E402 + +logger = get_logger(__name__) + + +def validate_imports(): + """Validate that all data module imports work.""" + print("✓ Data module imports successful") + + +def validate_database(): + """Validate database setup.""" + try: + engine = get_engine() + assert engine is not None + print("✓ Database engine created") + + # Test initialization (will create tables if needed) + init_database(create_tables=True) + print("✓ Database initialization successful") + except Exception as e: + print(f"✗ Database validation failed: {e}") + raise + + +def validate_loaders(): + """Validate data loaders.""" + try: + csv_loader = CSVLoader() + parquet_loader = ParquetLoader() + assert csv_loader is not None + assert parquet_loader is not None + print("✓ Data loaders initialized") + except Exception as e: + print(f"✗ Loader validation failed: {e}") + raise + + +def validate_preprocessors(): + """Validate preprocessors.""" + import pandas as pd + import pytz # type: ignore[import-untyped] + + # Create sample data with EST timezone (trading session is 3-4 AM EST) + est = pytz.timezone("America/New_York") + timestamps = pd.date_range("2024-01-01 03:00", periods=10, freq="1min", tz=est) + + df = pd.DataFrame( + { + 
"timestamp": timestamps, + "open": [100.0] * 10, + "high": [100.5] * 10, + "low": [99.5] * 10, + "close": [100.2] * 10, + } + ) + + # Test preprocessors + df_processed = handle_missing_data(df) + df_processed = remove_duplicates(df_processed) + df_filtered = filter_session(df_processed) + + assert len(df_filtered) > 0 + print("✓ Preprocessors working") + + +def validate_validators(): + """Validate validators.""" + import pandas as pd + + # Create valid data (timezone not required for validators) + df = pd.DataFrame( + { + "timestamp": pd.date_range("2024-01-01 03:00", periods=10, freq="1min"), + "open": [100.0] * 10, + "high": [100.5] * 10, + "low": [99.5] * 10, + "close": [100.2] * 10, + } + ) + + # Test validators + df_validated = validate_ohlcv(df) + is_continuous, gaps = check_continuity(df_validated, Timeframe.M1) + _ = detect_outliers(df_validated) # Check it runs without error + + assert len(df_validated) == 10 + print("✓ Validators working") + + +def validate_repositories(): + """Validate repositories.""" + from src.data.database import get_db_session + + try: + with get_db_session() as session: + repo = OHLCVRepository(session=session) + assert repo is not None + print("✓ Repositories working") + except Exception as e: + print(f"✗ Repository validation failed: {e}") + raise + + +def validate_directories(): + """Validate directory structure.""" + required_dirs = [ + "data/raw/ohlcv/1min", + "data/raw/ohlcv/5min", + "data/raw/ohlcv/15min", + "data/processed/features", + "data/processed/patterns", + "data/labels/individual_patterns", + ] + + for dir_name in required_dirs: + dir_path = Path(dir_name) + if not dir_path.exists(): + print(f"✗ Missing directory: {dir_name}") + return False + + print("✓ Directory structure valid") + return True + + +def main(): + """Run all validation checks.""" + print("Validating ICT ML Trading System v0.2.0 Data Pipeline...") + print("-" * 60) + + try: + validate_imports() + validate_database() + validate_loaders() + validate_preprocessors() + validate_validators() + validate_repositories() + validate_directories() + + print("-" * 60) + print("✓ All validations passed!") + return 0 + + except Exception as e: + print(f"✗ Validation failed: {e}") + import traceback + + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/config/config_loader.py b/src/config/config_loader.py index 1292729..1e9a48d 100644 --- a/src/config/config_loader.py +++ b/src/config/config_loader.py @@ -81,7 +81,7 @@ def load_config(config_path: Optional[Path] = None) -> Dict[str, Any]: _config = config logger.info("Configuration loaded successfully") - return config + return config # type: ignore[no-any-return] except Exception as e: raise ConfigurationError( @@ -150,4 +150,3 @@ def _substitute_env_vars(config: Any) -> Any: return config else: return config - diff --git a/src/core/constants.py b/src/core/constants.py index efa409b..cc107a0 100644 --- a/src/core/constants.py +++ b/src/core/constants.py @@ -1,7 +1,7 @@ """Application-wide constants.""" from pathlib import Path -from typing import Dict, List +from typing import Any, Dict, List # Project root directory PROJECT_ROOT = Path(__file__).parent.parent.parent @@ -50,7 +50,7 @@ PATTERN_THRESHOLDS: Dict[str, float] = { } # Model configuration -MODEL_CONFIG: Dict[str, any] = { +MODEL_CONFIG: Dict[str, Any] = { "min_labels_per_pattern": 200, "train_test_split": 0.8, "validation_split": 0.1, @@ -70,9 +70,8 @@ LOG_LEVELS: List[str] = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] 
LOG_FORMATS: List[str] = ["json", "text"] # Database constants -DB_CONSTANTS: Dict[str, any] = { +DB_CONSTANTS: Dict[str, Any] = { "pool_size": 10, "max_overflow": 20, "pool_timeout": 30, } - diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..8ed18ab --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1,41 @@ +"""Data management module for ICT ML Trading System.""" + +from src.data.database import get_engine, get_session, init_database +from src.data.loaders import CSVLoader, DatabaseLoader, ParquetLoader +from src.data.models import DetectedPattern, OHLCVData, PatternLabel, SetupLabel, Trade +from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates +from src.data.repositories import OHLCVRepository, PatternRepository, Repository +from src.data.schemas import OHLCVSchema, PatternSchema +from src.data.validators import check_continuity, detect_outliers, validate_ohlcv + +__all__ = [ + # Database + "get_engine", + "get_session", + "init_database", + # Models + "OHLCVData", + "DetectedPattern", + "PatternLabel", + "SetupLabel", + "Trade", + # Loaders + "CSVLoader", + "ParquetLoader", + "DatabaseLoader", + # Preprocessors + "handle_missing_data", + "remove_duplicates", + "filter_session", + # Validators + "validate_ohlcv", + "check_continuity", + "detect_outliers", + # Repositories + "Repository", + "OHLCVRepository", + "PatternRepository", + # Schemas + "OHLCVSchema", + "PatternSchema", +] diff --git a/src/data/database.py b/src/data/database.py new file mode 100644 index 0000000..5714226 --- /dev/null +++ b/src/data/database.py @@ -0,0 +1,212 @@ +"""Database connection and session management.""" + +import os +from contextlib import contextmanager +from typing import Generator, Optional + +from sqlalchemy import create_engine, event +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session, sessionmaker + +from src.config import get_config +from src.core.constants import DB_CONSTANTS +from src.core.exceptions import ConfigurationError, DataError +from src.logging import get_logger + +logger = get_logger(__name__) + +# Global engine and session factory +_engine: Optional[Engine] = None +_SessionLocal: Optional[sessionmaker] = None + + +def get_database_url() -> str: + """ + Get database URL from config or environment variable. + + Returns: + Database URL string + + Raises: + ConfigurationError: If database URL cannot be determined + """ + try: + config = get_config() + db_config = config.get("database", {}) + database_url = os.getenv("DATABASE_URL") or db_config.get("database_url") + + if not database_url: + raise ConfigurationError( + "Database URL not found in configuration or environment variables", + context={"config": db_config}, + ) + + # Handle SQLite path expansion + if database_url.startswith("sqlite:///"): + db_path = database_url.replace("sqlite:///", "") + if not os.path.isabs(db_path): + # Relative path - make it absolute from project root + from src.core.constants import PROJECT_ROOT + + db_path = str(PROJECT_ROOT / db_path) + database_url = f"sqlite:///{db_path}" + + db_display = database_url.split("@")[-1] if "@" in database_url else "sqlite" + logger.debug(f"Database URL configured: {db_display}") + return database_url + + except Exception as e: + raise ConfigurationError( + f"Failed to get database URL: {e}", + context={"error": str(e)}, + ) from e + + +def get_engine() -> Engine: + """ + Get or create SQLAlchemy engine with connection pooling. 
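+
+    Example (illustrative; the engine is created once and then cached):
+
+        >>> engine = get_engine()
+        >>> engine is get_engine()
+        True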
+ + Returns: + SQLAlchemy engine instance + """ + global _engine + + if _engine is not None: + return _engine + + database_url = get_database_url() + db_config = get_config().get("database", {}) + + # Connection pool settings + pool_size = db_config.get("pool_size", DB_CONSTANTS["pool_size"]) + max_overflow = db_config.get("max_overflow", DB_CONSTANTS["max_overflow"]) + pool_timeout = db_config.get("pool_timeout", DB_CONSTANTS["pool_timeout"]) + pool_recycle = db_config.get("pool_recycle", 3600) + + # SQLite-specific settings + connect_args = {} + if database_url.startswith("sqlite"): + sqlite_config = db_config.get("sqlite", {}) + connect_args = { + "check_same_thread": sqlite_config.get("check_same_thread", False), + "timeout": sqlite_config.get("timeout", 20), + } + + # PostgreSQL-specific settings + elif database_url.startswith("postgresql"): + postgres_config = db_config.get("postgresql", {}) + connect_args = postgres_config.get("connect_args", {}) + + try: + _engine = create_engine( + database_url, + pool_size=pool_size, + max_overflow=max_overflow, + pool_timeout=pool_timeout, + pool_recycle=pool_recycle, + connect_args=connect_args, + echo=db_config.get("echo", False), + echo_pool=db_config.get("echo_pool", False), + ) + + # Add connection event listeners + @event.listens_for(_engine, "connect") + def set_sqlite_pragma(dbapi_conn, connection_record): + """Set SQLite pragmas for better performance.""" + if database_url.startswith("sqlite"): + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.execute("PRAGMA journal_mode=WAL") + cursor.close() + + logger.info(f"Database engine created: pool_size={pool_size}, max_overflow={max_overflow}") + return _engine + + except Exception as e: + raise DataError( + f"Failed to create database engine: {e}", + context={ + "database_url": database_url.split("@")[-1] if "@" in database_url else "sqlite" + }, + ) from e + + +def get_session() -> sessionmaker: + """ + Get or create session factory. + + Returns: + SQLAlchemy sessionmaker instance + """ + global _SessionLocal + + if _SessionLocal is not None: + return _SessionLocal + + engine = get_engine() + _SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False) + logger.debug("Session factory created") + return _SessionLocal + + +@contextmanager +def get_db_session() -> Generator[Session, None, None]: + """ + Context manager for database sessions. + + Yields: + Database session + + Example: + >>> with get_db_session() as session: + ... data = session.query(OHLCVData).all() + """ + SessionLocal = get_session() + session = SessionLocal() + try: + yield session + session.commit() + except Exception as e: + session.rollback() + logger.error(f"Database session error: {e}", exc_info=True) + raise DataError(f"Database operation failed: {e}") from e + finally: + session.close() + + +def init_database(create_tables: bool = True) -> None: + """ + Initialize database and create tables. 
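+
+    Example (illustrative):
+
+        >>> init_database()                     # create engine and all ORM tables
+        >>> init_database(create_tables=False)  # engine/connection setup only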
+ + Args: + create_tables: Whether to create tables if they don't exist + + Raises: + DataError: If database initialization fails + """ + try: + engine = get_engine() + database_url = get_database_url() + + # Create data directory for SQLite if needed + if database_url.startswith("sqlite"): + db_path = database_url.replace("sqlite:///", "") + db_dir = os.path.dirname(db_path) + if db_dir and not os.path.exists(db_dir): + os.makedirs(db_dir, exist_ok=True) + logger.info(f"Created database directory: {db_dir}") + + if create_tables: + # Import models to register them with SQLAlchemy + from src.data.models import Base + + Base.metadata.create_all(bind=engine) + logger.info("Database tables created successfully") + + logger.info("Database initialized successfully") + + except Exception as e: + raise DataError( + f"Failed to initialize database: {e}", + context={"create_tables": create_tables}, + ) from e diff --git a/src/data/loaders.py b/src/data/loaders.py new file mode 100644 index 0000000..130c7cd --- /dev/null +++ b/src/data/loaders.py @@ -0,0 +1,337 @@ +"""Data loaders for various data sources.""" + +from pathlib import Path +from typing import Optional + +import pandas as pd + +from src.core.enums import Timeframe +from src.core.exceptions import DataError +from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates +from src.data.validators import validate_ohlcv +from src.logging import get_logger + +logger = get_logger(__name__) + + +class BaseLoader: + """Base class for data loaders.""" + + def load(self, source: str, **kwargs) -> pd.DataFrame: + """ + Load data from source. + + Args: + source: Data source path/identifier + **kwargs: Additional loader-specific arguments + + Returns: + DataFrame with loaded data + + Raises: + DataError: If loading fails + """ + raise NotImplementedError("Subclasses must implement load()") + + +class CSVLoader(BaseLoader): + """Loader for CSV files.""" + + def load( # type: ignore[override] + self, + file_path: str, + symbol: Optional[str] = None, + timeframe: Optional[Timeframe] = None, + **kwargs, + ) -> pd.DataFrame: + """ + Load OHLCV data from CSV file. 
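+
+        Example (illustrative; the CSV path is a placeholder):
+
+            >>> loader = CSVLoader()
+            >>> df = loader.load(
+            ...     "data/raw/ohlcv/1min/DAX_1min.csv",
+            ...     symbol="DAX",
+            ...     timeframe=Timeframe.M1,
+            ... )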
+ + Args: + file_path: Path to CSV file + symbol: Optional symbol to add to DataFrame + timeframe: Optional timeframe to add to DataFrame + **kwargs: Additional pandas.read_csv arguments + + Returns: + DataFrame with OHLCV data + + Raises: + DataError: If file cannot be loaded + """ + file_path_obj = Path(file_path) + if not file_path_obj.exists(): + raise DataError( + f"CSV file not found: {file_path}", + context={"file_path": str(file_path)}, + ) + + try: + # Default CSV reading options + read_kwargs = { + "parse_dates": ["timestamp"], + "index_col": False, + } + read_kwargs.update(kwargs) + + df = pd.read_csv(file_path, **read_kwargs) + + # Ensure timestamp column exists + if "timestamp" not in df.columns and "time" in df.columns: + df.rename(columns={"time": "timestamp"}, inplace=True) + + # Add metadata if provided + if symbol: + df["symbol"] = symbol + if timeframe: + df["timeframe"] = timeframe.value + + # Standardize column names (case-insensitive) + column_mapping = { + "open": "open", + "high": "high", + "low": "low", + "close": "close", + "volume": "volume", + } + for old_name, new_name in column_mapping.items(): + if old_name.lower() in [col.lower() for col in df.columns]: + matching_col = [col for col in df.columns if col.lower() == old_name.lower()][0] + if matching_col != new_name: + df.rename(columns={matching_col: new_name}, inplace=True) + + logger.info(f"Loaded {len(df)} rows from CSV: {file_path}") + return df + + except Exception as e: + raise DataError( + f"Failed to load CSV file: {e}", + context={"file_path": str(file_path)}, + ) from e + + +class ParquetLoader(BaseLoader): + """Loader for Parquet files.""" + + def load( # type: ignore[override] + self, + file_path: str, + symbol: Optional[str] = None, + timeframe: Optional[Timeframe] = None, + **kwargs, + ) -> pd.DataFrame: + """ + Load OHLCV data from Parquet file. + + Args: + file_path: Path to Parquet file + symbol: Optional symbol to add to DataFrame + timeframe: Optional timeframe to add to DataFrame + **kwargs: Additional pandas.read_parquet arguments + + Returns: + DataFrame with OHLCV data + + Raises: + DataError: If file cannot be loaded + """ + file_path_obj = Path(file_path) + if not file_path_obj.exists(): + raise DataError( + f"Parquet file not found: {file_path}", + context={"file_path": str(file_path)}, + ) + + try: + df = pd.read_parquet(file_path, **kwargs) + + # Add metadata if provided + if symbol: + df["symbol"] = symbol + if timeframe: + df["timeframe"] = timeframe.value + + logger.info(f"Loaded {len(df)} rows from Parquet: {file_path}") + return df + + except Exception as e: + raise DataError( + f"Failed to load Parquet file: {e}", + context={"file_path": str(file_path)}, + ) from e + + +class DatabaseLoader(BaseLoader): + """Loader for database data.""" + + def __init__(self, session=None): + """ + Initialize database loader. + + Args: + session: Optional database session (creates new if not provided) + """ + self.session = session + + def load( # type: ignore[override] + self, + symbol: str, + timeframe: Timeframe, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + limit: Optional[int] = None, + **kwargs, + ) -> pd.DataFrame: + """ + Load OHLCV data from database. 
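+
+        Example (illustrative; the date range is a placeholder):
+
+            >>> loader = DatabaseLoader()
+            >>> df = loader.load(
+            ...     "DAX",
+            ...     Timeframe.M1,
+            ...     start_date="2024-01-01",
+            ...     end_date="2024-01-31",
+            ... )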
+ + Args: + symbol: Trading symbol + timeframe: Timeframe enum + start_date: Optional start date (ISO format or datetime string) + end_date: Optional end date (ISO format or datetime string) + limit: Optional limit on number of records + **kwargs: Additional query arguments + + Returns: + DataFrame with OHLCV data + + Raises: + DataError: If database query fails + """ + from src.data.database import get_db_session + from src.data.repositories import OHLCVRepository + + try: + # Use provided session or create new one + if self.session: + repo = OHLCVRepository(session=self.session) + session_context = None + else: + session_context = get_db_session() + session = session_context.__enter__() + repo = OHLCVRepository(session=session) + + # Parse dates + start = pd.to_datetime(start_date) if start_date else None + end = pd.to_datetime(end_date) if end_date else None + + # Query database + if start and end: + records = repo.get_by_timestamp_range(symbol, timeframe, start, end, limit) + else: + records = repo.get_latest(symbol, timeframe, limit or 1000) + + # Convert to DataFrame + data = [] + for record in records: + data.append( + { + "id": record.id, + "symbol": record.symbol, + "timeframe": record.timeframe.value, + "timestamp": record.timestamp, + "open": float(record.open), + "high": float(record.high), + "low": float(record.low), + "close": float(record.close), + "volume": record.volume, + } + ) + + df = pd.DataFrame(data) + + if session_context: + session_context.__exit__(None, None, None) + + logger.info( + f"Loaded {len(df)} rows from database: {symbol} {timeframe.value} " + f"({start_date} to {end_date})" + ) + return df + + except Exception as e: + raise DataError( + f"Failed to load data from database: {e}", + context={ + "symbol": symbol, + "timeframe": timeframe.value, + "start_date": start_date, + "end_date": end_date, + }, + ) from e + + +def load_and_preprocess( + source: str, + loader_type: str = "auto", + validate: bool = True, + preprocess: bool = True, + filter_to_session: bool = False, + **loader_kwargs, +) -> pd.DataFrame: + """ + Load data and optionally validate/preprocess it. 
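+
+    Example (illustrative; the source path is a placeholder):
+
+        >>> df = load_and_preprocess(
+        ...     "data/raw/ohlcv/1min/DAX_1min.csv",
+        ...     loader_type="csv",
+        ...     filter_to_session=True,
+        ... )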
+ + Args: + source: Data source (file path or database identifier) + loader_type: Loader type ('csv', 'parquet', 'database', 'auto') + validate: Whether to validate data + preprocess: Whether to preprocess data (handle missing, remove duplicates) + filter_to_session: Whether to filter to trading session hours + **loader_kwargs: Additional arguments for loader + + Returns: + Processed DataFrame + + Raises: + DataError: If loading or processing fails + """ + # Auto-detect loader type + if loader_type == "auto": + source_path = Path(source) + if source_path.exists(): + if source_path.suffix.lower() == ".csv": + loader_type = "csv" + elif source_path.suffix.lower() == ".parquet": + loader_type = "parquet" + else: + raise DataError( + f"Cannot auto-detect loader type for: {source}", + context={"source": str(source)}, + ) + else: + loader_type = "database" + + # Create appropriate loader + loader: BaseLoader + if loader_type == "csv": + loader = CSVLoader() + elif loader_type == "parquet": + loader = ParquetLoader() + elif loader_type == "database": + loader = DatabaseLoader() + else: + raise DataError( + f"Invalid loader type: {loader_type}", + context={"valid_types": ["csv", "parquet", "database", "auto"]}, + ) + + # Load data + df = loader.load(source, **loader_kwargs) + + # Validate + if validate: + df = validate_ohlcv(df) + + # Preprocess + if preprocess: + df = handle_missing_data(df, method="forward_fill") + df = remove_duplicates(df) + + # Filter to session + if filter_to_session: + df = filter_session(df) + + logger.info(f"Loaded and processed {len(df)} rows from {source}") + return df diff --git a/src/data/models.py b/src/data/models.py new file mode 100644 index 0000000..e4cb9d1 --- /dev/null +++ b/src/data/models.py @@ -0,0 +1,223 @@ +"""SQLAlchemy ORM models for data storage.""" + +from datetime import datetime + +from sqlalchemy import ( + Boolean, + Column, + DateTime, + Enum, + Float, + ForeignKey, + Index, + Integer, + Numeric, + String, + Text, +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship + +from src.core.enums import ( + Grade, + OrderType, + PatternDirection, + PatternType, + SetupType, + Timeframe, + TradeDirection, + TradeStatus, +) + +Base = declarative_base() + + +class OHLCVData(Base): # type: ignore[valid-type,misc] + """OHLCV market data table.""" + + __tablename__ = "ohlcv_data" + + id = Column(Integer, primary_key=True, index=True) + symbol = Column(String(20), nullable=False, index=True) + timeframe = Column(Enum(Timeframe), nullable=False, index=True) + timestamp = Column(DateTime, nullable=False, index=True) + open = Column(Numeric(20, 5), nullable=False) + high = Column(Numeric(20, 5), nullable=False) + low = Column(Numeric(20, 5), nullable=False) + close = Column(Numeric(20, 5), nullable=False) + volume = Column(Integer, nullable=True) + + # Metadata + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + + # Relationships + patterns = relationship("DetectedPattern", back_populates="ohlcv_data") + + # Composite index for common queries + __table_args__ = (Index("idx_symbol_timeframe_timestamp", "symbol", "timeframe", "timestamp"),) + + def __repr__(self) -> str: + return ( + f"" + ) + + +class DetectedPattern(Base): # type: ignore[valid-type,misc] + """Detected ICT patterns table.""" + + __tablename__ = "detected_patterns" + + id = Column(Integer, primary_key=True, index=True) + pattern_type 
= Column(Enum(PatternType), nullable=False, index=True) + direction = Column(Enum(PatternDirection), nullable=False) + timeframe = Column(Enum(Timeframe), nullable=False, index=True) + symbol = Column(String(20), nullable=False, index=True) + + # Pattern location + start_timestamp = Column(DateTime, nullable=False, index=True) + end_timestamp = Column(DateTime, nullable=False) + ohlcv_data_id = Column(Integer, ForeignKey("ohlcv_data.id"), nullable=True) + + # Price levels + entry_level = Column(Numeric(20, 5), nullable=True) + stop_loss = Column(Numeric(20, 5), nullable=True) + take_profit = Column(Numeric(20, 5), nullable=True) + high_level = Column(Numeric(20, 5), nullable=True) + low_level = Column(Numeric(20, 5), nullable=True) + + # Pattern metadata + size_pips = Column(Float, nullable=True) + strength_score = Column(Float, nullable=True) + context_data = Column(Text, nullable=True) # JSON string for additional context + + # Metadata + detected_at = Column(DateTime, default=datetime.utcnow, nullable=False) + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + + # Relationships + ohlcv_data = relationship("OHLCVData", back_populates="patterns") + labels = relationship("PatternLabel", back_populates="pattern") + + # Composite index + __table_args__ = ( + Index("idx_pattern_type_symbol_timestamp", "pattern_type", "symbol", "start_timestamp"), + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + +class PatternLabel(Base): # type: ignore[valid-type,misc] + """Labels for individual patterns.""" + + __tablename__ = "pattern_labels" + + id = Column(Integer, primary_key=True, index=True) + pattern_id = Column(Integer, ForeignKey("detected_patterns.id"), nullable=False, index=True) + grade = Column(Enum(Grade), nullable=False, index=True) + notes = Column(Text, nullable=True) + + # Labeler metadata + labeled_by = Column(String(100), nullable=True) + labeled_at = Column(DateTime, default=datetime.utcnow, nullable=False) + confidence = Column(Float, nullable=True) # Labeler's confidence (0-1) + + # Quality checks + is_anchor = Column(Boolean, default=False, nullable=False, index=True) + reviewed = Column(Boolean, default=False, nullable=False) + + # Relationships + pattern = relationship("DetectedPattern", back_populates="labels") + + def __repr__(self) -> str: + return ( + f"" + ) + + +class SetupLabel(Base): # type: ignore[valid-type,misc] + """Labels for complete trading setups.""" + + __tablename__ = "setup_labels" + + id = Column(Integer, primary_key=True, index=True) + setup_type = Column(Enum(SetupType), nullable=False, index=True) + symbol = Column(String(20), nullable=False, index=True) + session_date = Column(DateTime, nullable=False, index=True) + + # Setup components (pattern IDs) + fvg_id = Column(Integer, ForeignKey("detected_patterns.id"), nullable=True) + order_block_id = Column(Integer, ForeignKey("detected_patterns.id"), nullable=True) + liquidity_id = Column(Integer, ForeignKey("detected_patterns.id"), nullable=True) + + # Label + grade = Column(Enum(Grade), nullable=False, index=True) + outcome = Column(String(50), nullable=True) # "win", "loss", "breakeven" + pnl = Column(Numeric(20, 2), nullable=True) + + # Labeler metadata + labeled_by = Column(String(100), nullable=True) + labeled_at = Column(DateTime, default=datetime.utcnow, nullable=False) + notes = Column(Text, nullable=True) + + def __repr__(self) -> str: + return ( + f"" + ) + + +class Trade(Base): # type: ignore[valid-type,misc] + """Trade execution records.""" + + __tablename__ = 
"trades" + + id = Column(Integer, primary_key=True, index=True) + symbol = Column(String(20), nullable=False, index=True) + direction = Column(Enum(TradeDirection), nullable=False) + order_type = Column(Enum(OrderType), nullable=False) + status = Column(Enum(TradeStatus), nullable=False, index=True) + + # Entry + entry_price = Column(Numeric(20, 5), nullable=False) + entry_timestamp = Column(DateTime, nullable=False, index=True) + entry_size = Column(Integer, nullable=False) + + # Exit + exit_price = Column(Numeric(20, 5), nullable=True) + exit_timestamp = Column(DateTime, nullable=True) + exit_size = Column(Integer, nullable=True) + + # Risk management + stop_loss = Column(Numeric(20, 5), nullable=True) + take_profit = Column(Numeric(20, 5), nullable=True) + risk_amount = Column(Numeric(20, 2), nullable=True) + + # P&L + pnl = Column(Numeric(20, 2), nullable=True) + pnl_pips = Column(Float, nullable=True) + commission = Column(Numeric(20, 2), nullable=True) + + # Related patterns + pattern_id = Column(Integer, ForeignKey("detected_patterns.id"), nullable=True) + setup_id = Column(Integer, ForeignKey("setup_labels.id"), nullable=True) + + # Metadata + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + notes = Column(Text, nullable=True) + + # Composite index + __table_args__ = (Index("idx_symbol_status_timestamp", "symbol", "status", "entry_timestamp"),) + + def __repr__(self) -> str: + return ( + f"" + ) diff --git a/src/data/preprocessors.py b/src/data/preprocessors.py new file mode 100644 index 0000000..016e498 --- /dev/null +++ b/src/data/preprocessors.py @@ -0,0 +1,181 @@ +"""Data preprocessing functions.""" + +from datetime import datetime +from typing import Optional + +import pandas as pd +import pytz # type: ignore[import-untyped] + +from src.core.constants import SESSION_TIMES +from src.core.exceptions import DataError +from src.logging import get_logger + +logger = get_logger(__name__) + + +def handle_missing_data( + df: pd.DataFrame, + method: str = "forward_fill", + columns: Optional[list] = None, +) -> pd.DataFrame: + """ + Handle missing data in DataFrame. 
+ + Args: + df: DataFrame with potential missing values + method: Method to handle missing data + ('forward_fill', 'backward_fill', 'drop', 'interpolate') + columns: Specific columns to process (defaults to all numeric columns) + + Returns: + DataFrame with missing data handled + + Raises: + DataError: If method is invalid + """ + if df.empty: + return df + + if columns is None: + # Default to numeric columns + columns = df.select_dtypes(include=["number"]).columns.tolist() + + df_processed = df.copy() + missing_before = df_processed[columns].isna().sum().sum() + + if missing_before == 0: + logger.debug("No missing data found") + return df_processed + + logger.info(f"Handling {missing_before} missing values using method: {method}") + + for col in columns: + if col not in df_processed.columns: + continue + + if method == "forward_fill": + df_processed[col] = df_processed[col].ffill() + + elif method == "backward_fill": + df_processed[col] = df_processed[col].bfill() + + elif method == "drop": + df_processed = df_processed.dropna(subset=[col]) + + elif method == "interpolate": + df_processed[col] = df_processed[col].interpolate(method="linear") + + else: + raise DataError( + f"Invalid missing data method: {method}", + context={"valid_methods": ["forward_fill", "backward_fill", "drop", "interpolate"]}, + ) + + missing_after = df_processed[columns].isna().sum().sum() + logger.info(f"Missing data handled: {missing_before} -> {missing_after}") + + return df_processed + + +def remove_duplicates( + df: pd.DataFrame, + subset: Optional[list] = None, + keep: str = "first", + timestamp_col: str = "timestamp", +) -> pd.DataFrame: + """ + Remove duplicate rows from DataFrame. + + Args: + df: DataFrame with potential duplicates + subset: Columns to consider for duplicates (defaults to timestamp) + keep: Which duplicates to keep ('first', 'last', False to drop all) + timestamp_col: Name of timestamp column + + Returns: + DataFrame with duplicates removed + """ + if df.empty: + return df + + if subset is None: + subset = [timestamp_col] if timestamp_col in df.columns else None + + duplicates_before = len(df) + df_processed = df.drop_duplicates(subset=subset, keep=keep) + duplicates_removed = duplicates_before - len(df_processed) + + if duplicates_removed > 0: + logger.info(f"Removed {duplicates_removed} duplicate rows") + else: + logger.debug("No duplicates found") + + return df_processed + + +def filter_session( + df: pd.DataFrame, + timestamp_col: str = "timestamp", + session_start: Optional[str] = None, + session_end: Optional[str] = None, + timezone: str = "America/New_York", +) -> pd.DataFrame: + """ + Filter DataFrame to trading session hours (default: 3:00-4:00 AM EST). 
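+
+    Example (illustrative; keeps only candles between 03:00 and 04:00 EST):
+
+        >>> session_df = filter_session(df, session_start="03:00", session_end="04:00")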
+ + Args: + df: DataFrame with timestamp column + timestamp_col: Name of timestamp column + session_start: Session start time (HH:MM format, defaults to config) + session_end: Session end time (HH:MM format, defaults to config) + timezone: Timezone for session times (defaults to EST) + + Returns: + Filtered DataFrame + + Raises: + DataError: If timestamp column is missing or invalid + """ + if df.empty: + return df + + if timestamp_col not in df.columns: + raise DataError( + f"Timestamp column '{timestamp_col}' not found", + context={"columns": df.columns.tolist()}, + ) + + # Get session times from config or use defaults + if session_start is None: + session_start = SESSION_TIMES.get("start", "03:00") + if session_end is None: + session_end = SESSION_TIMES.get("end", "04:00") + + # Parse session times + start_time = datetime.strptime(session_start, "%H:%M").time() + end_time = datetime.strptime(session_end, "%H:%M").time() + + # Ensure timestamp is datetime + if not pd.api.types.is_datetime64_any_dtype(df[timestamp_col]): + df[timestamp_col] = pd.to_datetime(df[timestamp_col]) + + # Convert to session timezone if needed + tz = pytz.timezone(timezone) + if df[timestamp_col].dt.tz is None: + # Assume UTC if no timezone + df[timestamp_col] = df[timestamp_col].dt.tz_localize("UTC") + df[timestamp_col] = df[timestamp_col].dt.tz_convert(tz) + + # Filter by time of day + df_filtered = df[ + (df[timestamp_col].dt.time >= start_time) & (df[timestamp_col].dt.time <= end_time) + ].copy() + + rows_before = len(df) + rows_after = len(df_filtered) + logger.info( + f"Filtered to session {session_start}-{session_end} {timezone}: " + f"{rows_before} -> {rows_after} rows" + ) + + return df_filtered diff --git a/src/data/repositories.py b/src/data/repositories.py new file mode 100644 index 0000000..4f44178 --- /dev/null +++ b/src/data/repositories.py @@ -0,0 +1,355 @@ +"""Repository pattern for data access layer.""" + +from datetime import datetime +from typing import List, Optional + +from sqlalchemy import and_, desc +from sqlalchemy.orm import Session + +from src.core.enums import PatternType, Timeframe +from src.core.exceptions import DataError +from src.data.models import DetectedPattern, OHLCVData, PatternLabel +from src.logging import get_logger + +logger = get_logger(__name__) + + +class Repository: + """Base repository class with common database operations.""" + + def __init__(self, session: Optional[Session] = None): + """ + Initialize repository. + + Args: + session: Optional database session (creates new if not provided) + """ + self._session = session + + @property + def session(self) -> Session: + """Get database session.""" + if self._session is None: + # Use context manager for automatic cleanup + raise RuntimeError("Session must be provided or use context manager") + return self._session + + +class OHLCVRepository(Repository): + """Repository for OHLCV data operations.""" + + def create(self, data: OHLCVData) -> OHLCVData: + """ + Create new OHLCV record. + + Args: + data: OHLCVData instance + + Returns: + Created OHLCVData instance + + Raises: + DataError: If creation fails + """ + try: + self.session.add(data) + self.session.flush() + logger.debug(f"Created OHLCV record: {data.id}") + return data + except Exception as e: + logger.error(f"Failed to create OHLCV record: {e}", exc_info=True) + raise DataError(f"Failed to create OHLCV record: {e}") from e + + def create_batch(self, data_list: List[OHLCVData]) -> List[OHLCVData]: + """ + Create multiple OHLCV records in batch. 
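+
+        Example (illustrative; records is a prepared list of OHLCVData instances):
+
+            >>> from src.data.database import get_db_session
+            >>> with get_db_session() as session:
+            ...     repo = OHLCVRepository(session=session)
+            ...     repo.create_batch(records)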
+ + Args: + data_list: List of OHLCVData instances + + Returns: + List of created OHLCVData instances + + Raises: + DataError: If batch creation fails + """ + try: + self.session.add_all(data_list) + self.session.flush() + logger.info(f"Created {len(data_list)} OHLCV records in batch") + return data_list + except Exception as e: + logger.error(f"Failed to create OHLCV records in batch: {e}", exc_info=True) + raise DataError(f"Failed to create OHLCV records: {e}") from e + + def get_by_id(self, record_id: int) -> Optional[OHLCVData]: + """ + Get OHLCV record by ID. + + Args: + record_id: Record ID + + Returns: + OHLCVData instance or None if not found + """ + result = self.session.query(OHLCVData).filter(OHLCVData.id == record_id).first() + return result # type: ignore[no-any-return] + + def get_by_timestamp_range( + self, + symbol: str, + timeframe: Timeframe, + start: datetime, + end: datetime, + limit: Optional[int] = None, + ) -> List[OHLCVData]: + """ + Get OHLCV data for symbol/timeframe within timestamp range. + + Args: + symbol: Trading symbol + timeframe: Timeframe enum + start: Start timestamp + end: End timestamp + limit: Optional limit on number of records + + Returns: + List of OHLCVData instances + """ + query = ( + self.session.query(OHLCVData) + .filter( + and_( + OHLCVData.symbol == symbol, + OHLCVData.timeframe == timeframe, + OHLCVData.timestamp >= start, + OHLCVData.timestamp <= end, + ) + ) + .order_by(OHLCVData.timestamp) + ) + + if limit: + query = query.limit(limit) + + result = query.all() + return result # type: ignore[no-any-return] + + def get_latest(self, symbol: str, timeframe: Timeframe, limit: int = 1) -> List[OHLCVData]: + """ + Get latest OHLCV records for symbol/timeframe. + + Args: + symbol: Trading symbol + timeframe: Timeframe enum + limit: Number of records to return + + Returns: + List of OHLCVData instances (most recent first) + """ + result = ( + self.session.query(OHLCVData) + .filter( + and_( + OHLCVData.symbol == symbol, + OHLCVData.timeframe == timeframe, + ) + ) + .order_by(desc(OHLCVData.timestamp)) + .limit(limit) + .all() + ) + return result # type: ignore[no-any-return] + + def exists(self, symbol: str, timeframe: Timeframe, timestamp: datetime) -> bool: + """ + Check if OHLCV record exists. + + Args: + symbol: Trading symbol + timeframe: Timeframe enum + timestamp: Record timestamp + + Returns: + True if record exists, False otherwise + """ + count = ( + self.session.query(OHLCVData) + .filter( + and_( + OHLCVData.symbol == symbol, + OHLCVData.timeframe == timeframe, + OHLCVData.timestamp == timestamp, + ) + ) + .count() + ) + return bool(count > 0) + + def delete_by_timestamp_range( + self, + symbol: str, + timeframe: Timeframe, + start: datetime, + end: datetime, + ) -> int: + """ + Delete OHLCV records within timestamp range. 
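+
+        Example (illustrative; removes one day of 1-minute DAX candles):
+
+            >>> repo.delete_by_timestamp_range(
+            ...     "DAX",
+            ...     Timeframe.M1,
+            ...     datetime(2024, 1, 1),
+            ...     datetime(2024, 1, 2),
+            ... )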
+ + Args: + symbol: Trading symbol + timeframe: Timeframe enum + start: Start timestamp + end: End timestamp + + Returns: + Number of records deleted + """ + try: + deleted = ( + self.session.query(OHLCVData) + .filter( + and_( + OHLCVData.symbol == symbol, + OHLCVData.timeframe == timeframe, + OHLCVData.timestamp >= start, + OHLCVData.timestamp <= end, + ) + ) + .delete(synchronize_session=False) + ) + logger.info(f"Deleted {deleted} OHLCV records") + return int(deleted) + except Exception as e: + logger.error(f"Failed to delete OHLCV records: {e}", exc_info=True) + raise DataError(f"Failed to delete OHLCV records: {e}") from e + + +class PatternRepository(Repository): + """Repository for detected pattern operations.""" + + def create(self, pattern: DetectedPattern) -> DetectedPattern: + """ + Create new pattern record. + + Args: + pattern: DetectedPattern instance + + Returns: + Created DetectedPattern instance + + Raises: + DataError: If creation fails + """ + try: + self.session.add(pattern) + self.session.flush() + logger.debug(f"Created pattern record: {pattern.id} ({pattern.pattern_type})") + return pattern + except Exception as e: + logger.error(f"Failed to create pattern record: {e}", exc_info=True) + raise DataError(f"Failed to create pattern record: {e}") from e + + def create_batch(self, patterns: List[DetectedPattern]) -> List[DetectedPattern]: + """ + Create multiple pattern records in batch. + + Args: + patterns: List of DetectedPattern instances + + Returns: + List of created DetectedPattern instances + + Raises: + DataError: If batch creation fails + """ + try: + self.session.add_all(patterns) + self.session.flush() + logger.info(f"Created {len(patterns)} pattern records in batch") + return patterns + except Exception as e: + logger.error(f"Failed to create pattern records in batch: {e}", exc_info=True) + raise DataError(f"Failed to create pattern records: {e}") from e + + def get_by_id(self, pattern_id: int) -> Optional[DetectedPattern]: + """ + Get pattern by ID. + + Args: + pattern_id: Pattern ID + + Returns: + DetectedPattern instance or None if not found + """ + result = ( + self.session.query(DetectedPattern).filter(DetectedPattern.id == pattern_id).first() + ) + return result # type: ignore[no-any-return] + + def get_by_type_and_range( + self, + pattern_type: PatternType, + symbol: str, + start: datetime, + end: datetime, + timeframe: Optional[Timeframe] = None, + ) -> List[DetectedPattern]: + """ + Get patterns by type within timestamp range. + + Args: + pattern_type: Pattern type enum + symbol: Trading symbol + start: Start timestamp + end: End timestamp + timeframe: Optional timeframe filter + + Returns: + List of DetectedPattern instances + """ + query = self.session.query(DetectedPattern).filter( + and_( + DetectedPattern.pattern_type == pattern_type, + DetectedPattern.symbol == symbol, + DetectedPattern.start_timestamp >= start, + DetectedPattern.start_timestamp <= end, + ) + ) + + if timeframe: + query = query.filter(DetectedPattern.timeframe == timeframe) + + return query.order_by(DetectedPattern.start_timestamp).all() # type: ignore[no-any-return] + + def get_unlabeled( + self, + pattern_type: Optional[PatternType] = None, + symbol: Optional[str] = None, + limit: int = 100, + ) -> List[DetectedPattern]: + """ + Get patterns that don't have labels yet. 
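+
+        Example (illustrative; most recently detected patterns are returned first):
+
+            >>> pending = repo.get_unlabeled(symbol="DAX", limit=50)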
+ + Args: + pattern_type: Optional pattern type filter + symbol: Optional symbol filter + limit: Maximum number of records to return + + Returns: + List of unlabeled DetectedPattern instances + """ + query = ( + self.session.query(DetectedPattern) + .outerjoin(PatternLabel) + .filter(PatternLabel.id.is_(None)) + ) + + if pattern_type: + query = query.filter(DetectedPattern.pattern_type == pattern_type) + + if symbol: + query = query.filter(DetectedPattern.symbol == symbol) + + result = query.order_by(desc(DetectedPattern.detected_at)).limit(limit).all() + return result # type: ignore[no-any-return] diff --git a/src/data/schemas.py b/src/data/schemas.py new file mode 100644 index 0000000..04db263 --- /dev/null +++ b/src/data/schemas.py @@ -0,0 +1,91 @@ +"""Pydantic schemas for data validation.""" + +from datetime import datetime +from decimal import Decimal +from typing import Optional + +from pydantic import BaseModel, Field, field_validator + +from src.core.enums import PatternDirection, PatternType, Timeframe + + +class OHLCVSchema(BaseModel): + """Schema for OHLCV data validation.""" + + symbol: str = Field(..., description="Trading symbol (e.g., 'DAX')") + timeframe: Timeframe = Field(..., description="Timeframe enum") + timestamp: datetime = Field(..., description="Candle timestamp") + open: Decimal = Field(..., gt=0, description="Open price") + high: Decimal = Field(..., gt=0, description="High price") + low: Decimal = Field(..., gt=0, description="Low price") + close: Decimal = Field(..., gt=0, description="Close price") + volume: Optional[int] = Field(None, ge=0, description="Volume") + + @field_validator("high", "low") + @classmethod + def validate_price_range(cls, v: Decimal, info) -> Decimal: + """Validate that high >= low and prices are within reasonable range.""" + if info.field_name == "high": + low = info.data.get("low") + if low and v < low: + raise ValueError("High price must be >= low price") + elif info.field_name == "low": + high = info.data.get("high") + if high and v > high: + raise ValueError("Low price must be <= high price") + return v + + @field_validator("open", "close") + @classmethod + def validate_price_bounds(cls, v: Decimal, info) -> Decimal: + """Validate that open/close are within high/low range.""" + high = info.data.get("high") + low = info.data.get("low") + if high and low: + if not (low <= v <= high): + raise ValueError(f"{info.field_name} must be between low and high") + return v + + class Config: + """Pydantic config.""" + + json_encoders = { + Decimal: str, + datetime: lambda v: v.isoformat(), + } + + +class PatternSchema(BaseModel): + """Schema for detected pattern validation.""" + + pattern_type: PatternType = Field(..., description="Pattern type enum") + direction: PatternDirection = Field(..., description="Pattern direction") + timeframe: Timeframe = Field(..., description="Timeframe enum") + symbol: str = Field(..., description="Trading symbol") + start_timestamp: datetime = Field(..., description="Pattern start timestamp") + end_timestamp: datetime = Field(..., description="Pattern end timestamp") + entry_level: Optional[Decimal] = Field(None, description="Entry price level") + stop_loss: Optional[Decimal] = Field(None, description="Stop loss level") + take_profit: Optional[Decimal] = Field(None, description="Take profit level") + high_level: Optional[Decimal] = Field(None, description="Pattern high level") + low_level: Optional[Decimal] = Field(None, description="Pattern low level") + size_pips: Optional[float] = Field(None, ge=0, 
description="Pattern size in pips") + strength_score: Optional[float] = Field(None, ge=0, le=1, description="Strength score (0-1)") + context_data: Optional[str] = Field(None, description="Additional context as JSON string") + + @field_validator("end_timestamp") + @classmethod + def validate_timestamp_order(cls, v: datetime, info) -> datetime: + """Validate that end_timestamp >= start_timestamp.""" + start = info.data.get("start_timestamp") + if start and v < start: + raise ValueError("end_timestamp must be >= start_timestamp") + return v + + class Config: + """Pydantic config.""" + + json_encoders = { + Decimal: str, + datetime: lambda v: v.isoformat(), + } diff --git a/src/data/validators.py b/src/data/validators.py new file mode 100644 index 0000000..290d900 --- /dev/null +++ b/src/data/validators.py @@ -0,0 +1,231 @@ +"""Data validation functions.""" + +from datetime import datetime, timedelta +from typing import List, Optional, Tuple + +import numpy as np +import pandas as pd + +from src.core.enums import Timeframe +from src.core.exceptions import ValidationError +from src.logging import get_logger + +logger = get_logger(__name__) + + +def validate_ohlcv(df: pd.DataFrame, required_columns: Optional[List[str]] = None) -> pd.DataFrame: + """ + Validate OHLCV DataFrame structure and data quality. + + Args: + df: DataFrame with OHLCV data + required_columns: Optional list of required columns (defaults to standard OHLCV) + + Returns: + Validated DataFrame + + Raises: + ValidationError: If validation fails + """ + if required_columns is None: + required_columns = ["timestamp", "open", "high", "low", "close"] + + # Check required columns exist + missing_cols = [col for col in required_columns if col not in df.columns] + if missing_cols: + raise ValidationError( + f"Missing required columns: {missing_cols}", + context={"columns": df.columns.tolist(), "required": required_columns}, + ) + + # Check for empty DataFrame + if df.empty: + raise ValidationError("DataFrame is empty") + + # Validate price columns + price_cols = ["open", "high", "low", "close"] + for col in price_cols: + if col in df.columns: + # Check for negative or zero prices + if (df[col] <= 0).any(): + invalid_count = (df[col] <= 0).sum() + raise ValidationError( + f"Invalid {col} values (<= 0): {invalid_count} rows", + context={"column": col, "invalid_rows": invalid_count}, + ) + + # Check for infinite values + if np.isinf(df[col]).any(): + invalid_count = np.isinf(df[col]).sum() + raise ValidationError( + f"Infinite {col} values: {invalid_count} rows", + context={"column": col, "invalid_rows": invalid_count}, + ) + + # Validate high >= low + if "high" in df.columns and "low" in df.columns: + invalid = df["high"] < df["low"] + if invalid.any(): + invalid_count = invalid.sum() + raise ValidationError( + f"High < Low in {invalid_count} rows", + context={"invalid_rows": invalid_count}, + ) + + # Validate open/close within high/low range + if all(col in df.columns for col in ["open", "close", "high", "low"]): + invalid_open = (df["open"] < df["low"]) | (df["open"] > df["high"]) + invalid_close = (df["close"] < df["low"]) | (df["close"] > df["high"]) + if invalid_open.any() or invalid_close.any(): + invalid_count = invalid_open.sum() + invalid_close.sum() + raise ValidationError( + f"Open/Close outside High/Low range: {invalid_count} rows", + context={"invalid_rows": invalid_count}, + ) + + # Validate timestamp column + if "timestamp" in df.columns: + if not pd.api.types.is_datetime64_any_dtype(df["timestamp"]): + try: + 
df["timestamp"] = pd.to_datetime(df["timestamp"]) + except Exception as e: + raise ValidationError( + f"Invalid timestamp format: {e}", + context={"column": "timestamp"}, + ) from e + + # Check for duplicate timestamps + duplicates = df["timestamp"].duplicated().sum() + if duplicates > 0: + logger.warning(f"Found {duplicates} duplicate timestamps") + + logger.debug(f"Validated OHLCV DataFrame: {len(df)} rows, {len(df.columns)} columns") + return df + + +def check_continuity( + df: pd.DataFrame, + timeframe: Timeframe, + timestamp_col: str = "timestamp", + max_gap_minutes: Optional[int] = None, +) -> Tuple[bool, List[datetime]]: + """ + Check for gaps in timestamp continuity. + + Args: + df: DataFrame with timestamp column + timeframe: Expected timeframe + timestamp_col: Name of timestamp column + max_gap_minutes: Maximum allowed gap in minutes (defaults to timeframe duration) + + Returns: + Tuple of (is_continuous, list_of_gaps) + + Raises: + ValidationError: If timestamp column is missing or invalid + """ + if timestamp_col not in df.columns: + raise ValidationError( + f"Timestamp column '{timestamp_col}' not found", + context={"columns": df.columns.tolist()}, + ) + + if df.empty: + return True, [] + + # Determine expected interval + timeframe_minutes = { + Timeframe.M1: 1, + Timeframe.M5: 5, + Timeframe.M15: 15, + } + expected_interval = timedelta(minutes=timeframe_minutes.get(timeframe, 1)) + + if max_gap_minutes: + max_gap = timedelta(minutes=max_gap_minutes) + else: + max_gap = expected_interval * 2 # Allow 2x timeframe as max gap + + # Sort by timestamp + df_sorted = df.sort_values(timestamp_col).copy() + timestamps = pd.to_datetime(df_sorted[timestamp_col]) + + # Find gaps + gaps = [] + for i in range(len(timestamps) - 1): + gap = timestamps.iloc[i + 1] - timestamps.iloc[i] + if gap > max_gap: + gaps.append(timestamps.iloc[i]) + + is_continuous = len(gaps) == 0 + + if gaps: + logger.warning( + f"Found {len(gaps)} gaps in continuity (timeframe: {timeframe}, " f"max_gap: {max_gap})" + ) + + return is_continuous, gaps + + +def detect_outliers( + df: pd.DataFrame, + columns: Optional[List[str]] = None, + method: str = "iqr", + threshold: float = 3.0, +) -> pd.DataFrame: + """ + Detect outliers in price columns. 
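+    With method="iqr", values outside [Q1 - threshold*IQR, Q3 + threshold*IQR]
+    are flagged; with method="zscore", rows where |x - mean| / std exceeds the
+    threshold are flagged.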
+ + Args: + df: DataFrame with price data + columns: Columns to check (defaults to OHLCV price columns) + method: Detection method ('iqr' or 'zscore') + threshold: Threshold for outlier detection + + Returns: + DataFrame with boolean mask (True = outlier) + + Raises: + ValidationError: If method is invalid or columns missing + """ + if columns is None: + columns = [col for col in ["open", "high", "low", "close"] if col in df.columns] + + if not columns: + raise ValidationError("No columns specified for outlier detection") + + missing_cols = [col for col in columns if col not in df.columns] + if missing_cols: + raise ValidationError( + f"Columns not found: {missing_cols}", + context={"columns": df.columns.tolist()}, + ) + + outlier_mask = pd.Series([False] * len(df), index=df.index) + + for col in columns: + if method == "iqr": + Q1 = df[col].quantile(0.25) + Q3 = df[col].quantile(0.75) + IQR = Q3 - Q1 + lower_bound = Q1 - threshold * IQR + upper_bound = Q3 + threshold * IQR + col_outliers = (df[col] < lower_bound) | (df[col] > upper_bound) + + elif method == "zscore": + z_scores = np.abs((df[col] - df[col].mean()) / df[col].std()) + col_outliers = z_scores > threshold + + else: + raise ValidationError( + f"Invalid outlier detection method: {method}", + context={"valid_methods": ["iqr", "zscore"]}, + ) + + outlier_mask |= col_outliers + + outlier_count = outlier_mask.sum() + if outlier_count > 0: + logger.warning(f"Detected {outlier_count} outliers using {method} method") + + return outlier_mask.to_frame("is_outlier") diff --git a/tests/fixtures/sample_data/sample_ohlcv.csv b/tests/fixtures/sample_data/sample_ohlcv.csv new file mode 100644 index 0000000..24574a3 --- /dev/null +++ b/tests/fixtures/sample_data/sample_ohlcv.csv @@ -0,0 +1,6 @@ +timestamp,open,high,low,close,volume +2024-01-01 03:00:00,100.0,100.5,99.5,100.2,1000 +2024-01-01 03:01:00,100.2,100.7,99.7,100.4,1100 +2024-01-01 03:02:00,100.4,100.9,99.9,100.6,1200 +2024-01-01 03:03:00,100.6,101.1,100.1,100.8,1300 +2024-01-01 03:04:00,100.8,101.3,100.3,101.0,1400 diff --git a/tests/integration/test_database.py b/tests/integration/test_database.py new file mode 100644 index 0000000..e6ca714 --- /dev/null +++ b/tests/integration/test_database.py @@ -0,0 +1,127 @@ +"""Integration tests for database operations.""" + +import os +import tempfile + +import pytest + +from src.core.enums import Timeframe +from src.data.database import get_db_session, init_database +from src.data.models import OHLCVData +from src.data.repositories import OHLCVRepository + + +@pytest.fixture +def temp_db(): + """Create temporary database for testing.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + + os.environ["DATABASE_URL"] = f"sqlite:///{db_path}" + + # Initialize database + init_database(create_tables=True) + + yield db_path + + # Cleanup + if os.path.exists(db_path): + os.unlink(db_path) + if "DATABASE_URL" in os.environ: + del os.environ["DATABASE_URL"] + + +def test_create_and_retrieve_ohlcv(temp_db): + """Test creating and retrieving OHLCV records.""" + from datetime import datetime + + with get_db_session() as session: + repo = OHLCVRepository(session=session) + + # Create record + record = OHLCVData( + symbol="DAX", + timeframe=Timeframe.M1, + timestamp=datetime(2024, 1, 1, 3, 0, 0), + open=100.0, + high=100.5, + low=99.5, + close=100.2, + volume=1000, + ) + + created = repo.create(record) + assert created.id is not None + + # Retrieve record + retrieved = repo.get_by_id(created.id) + assert retrieved is 
not None + assert retrieved.symbol == "DAX" + assert retrieved.close == 100.2 + + +def test_batch_create_ohlcv(temp_db): + """Test batch creation of OHLCV records.""" + from datetime import datetime, timedelta + + with get_db_session() as session: + repo = OHLCVRepository(session=session) + + # Create multiple records + records = [] + base_time = datetime(2024, 1, 1, 3, 0, 0) + for i in range(10): + records.append( + OHLCVData( + symbol="DAX", + timeframe=Timeframe.M1, + timestamp=base_time + timedelta(minutes=i), + open=100.0 + i * 0.1, + high=100.5 + i * 0.1, + low=99.5 + i * 0.1, + close=100.2 + i * 0.1, + volume=1000, + ) + ) + + created = repo.create_batch(records) + assert len(created) == 10 + + # Verify all records saved + retrieved = repo.get_by_timestamp_range( + "DAX", + Timeframe.M1, + base_time, + base_time + timedelta(minutes=10), + ) + assert len(retrieved) == 10 + + +def test_get_by_timestamp_range(temp_db): + """Test retrieving records by timestamp range.""" + from datetime import datetime, timedelta + + with get_db_session() as session: + repo = OHLCVRepository(session=session) + + # Create records + base_time = datetime(2024, 1, 1, 3, 0, 0) + for i in range(20): + record = OHLCVData( + symbol="DAX", + timeframe=Timeframe.M1, + timestamp=base_time + timedelta(minutes=i), + open=100.0, + high=100.5, + low=99.5, + close=100.2, + volume=1000, + ) + repo.create(record) + + # Retrieve subset + start = base_time + timedelta(minutes=5) + end = base_time + timedelta(minutes=15) + records = repo.get_by_timestamp_range("DAX", Timeframe.M1, start, end) + + assert len(records) == 11 # Inclusive of start and end diff --git a/tests/unit/test_data/__init__.py b/tests/unit/test_data/__init__.py new file mode 100644 index 0000000..f190e65 --- /dev/null +++ b/tests/unit/test_data/__init__.py @@ -0,0 +1 @@ +"""Unit tests for data module.""" diff --git a/tests/unit/test_data/test_database.py b/tests/unit/test_data/test_database.py new file mode 100644 index 0000000..e94570e --- /dev/null +++ b/tests/unit/test_data/test_database.py @@ -0,0 +1,65 @@ +"""Tests for database connection and session management.""" + +import os +import tempfile + +import pytest + +from src.data.database import get_db_session, get_engine, init_database + + +@pytest.fixture +def temp_db(): + """Create temporary database for testing.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + + # Set environment variable + os.environ["DATABASE_URL"] = f"sqlite:///{db_path}" + + yield db_path + + # Cleanup + if os.path.exists(db_path): + os.unlink(db_path) + if "DATABASE_URL" in os.environ: + del os.environ["DATABASE_URL"] + + +def test_get_engine(temp_db): + """Test engine creation.""" + engine = get_engine() + assert engine is not None + assert str(engine.url).startswith("sqlite") + + +def test_init_database(temp_db): + """Test database initialization.""" + init_database(create_tables=True) + assert os.path.exists(temp_db) + + +def test_get_db_session(temp_db): + """Test database session context manager.""" + init_database(create_tables=True) + + with get_db_session() as session: + assert session is not None + # Session should be usable + result = session.execute("SELECT 1").scalar() + assert result == 1 + + +def test_session_rollback_on_error(temp_db): + """Test that session rolls back on error.""" + init_database(create_tables=True) + + try: + with get_db_session() as session: + # Cause an error + session.execute("SELECT * FROM nonexistent_table") + except Exception: + pass # Expected 
+ + # Session should have been rolled back and closed + assert True # If we get here, rollback worked diff --git a/tests/unit/test_data/test_loaders.py b/tests/unit/test_data/test_loaders.py new file mode 100644 index 0000000..ffb46fe --- /dev/null +++ b/tests/unit/test_data/test_loaders.py @@ -0,0 +1,83 @@ +"""Tests for data loaders.""" + +import pandas as pd +import pytest + +from src.core.enums import Timeframe +from src.data.loaders import CSVLoader, ParquetLoader + + +@pytest.fixture +def sample_ohlcv_data(): + """Create sample OHLCV DataFrame.""" + dates = pd.date_range("2024-01-01 03:00", periods=100, freq="1min") + return pd.DataFrame( + { + "timestamp": dates, + "open": [100.0 + i * 0.1 for i in range(100)], + "high": [100.5 + i * 0.1 for i in range(100)], + "low": [99.5 + i * 0.1 for i in range(100)], + "close": [100.2 + i * 0.1 for i in range(100)], + "volume": [1000] * 100, + } + ) + + +@pytest.fixture +def csv_file(sample_ohlcv_data, tmp_path): + """Create temporary CSV file.""" + csv_path = tmp_path / "test_data.csv" + sample_ohlcv_data.to_csv(csv_path, index=False) + return csv_path + + +@pytest.fixture +def parquet_file(sample_ohlcv_data, tmp_path): + """Create temporary Parquet file.""" + parquet_path = tmp_path / "test_data.parquet" + sample_ohlcv_data.to_parquet(parquet_path, index=False) + return parquet_path + + +def test_csv_loader(csv_file): + """Test CSV loader.""" + loader = CSVLoader() + df = loader.load(str(csv_file), symbol="DAX", timeframe=Timeframe.M1) + + assert len(df) == 100 + assert "symbol" in df.columns + assert "timeframe" in df.columns + assert df["symbol"].iloc[0] == "DAX" + assert df["timeframe"].iloc[0] == "1min" + + +def test_csv_loader_missing_file(): + """Test CSV loader with missing file.""" + loader = CSVLoader() + with pytest.raises(Exception): # Should raise DataError + loader.load("nonexistent.csv") + + +def test_parquet_loader(parquet_file): + """Test Parquet loader.""" + loader = ParquetLoader() + df = loader.load(str(parquet_file), symbol="DAX", timeframe=Timeframe.M1) + + assert len(df) == 100 + assert "symbol" in df.columns + assert "timeframe" in df.columns + + +def test_load_and_preprocess(csv_file): + """Test load_and_preprocess function.""" + from src.data.loaders import load_and_preprocess + + df = load_and_preprocess( + str(csv_file), + loader_type="csv", + validate=True, + preprocess=True, + ) + + assert len(df) == 100 + assert "timestamp" in df.columns diff --git a/tests/unit/test_data/test_preprocessors.py b/tests/unit/test_data/test_preprocessors.py new file mode 100644 index 0000000..ea2206d --- /dev/null +++ b/tests/unit/test_data/test_preprocessors.py @@ -0,0 +1,87 @@ +"""Tests for data preprocessors.""" + +import numpy as np +import pandas as pd +import pytest + +from src.data.preprocessors import filter_session, handle_missing_data, remove_duplicates + + +@pytest.fixture +def sample_data_with_missing(): + """Create sample DataFrame with missing values.""" + dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min") + df = pd.DataFrame( + { + "timestamp": dates, + "open": [100.0] * 10, + "high": [100.5] * 10, + "low": [99.5] * 10, + "close": [100.2] * 10, + } + ) + # Add some missing values + df.loc[2, "close"] = np.nan + df.loc[5, "open"] = np.nan + return df + + +@pytest.fixture +def sample_data_with_duplicates(): + """Create sample DataFrame with duplicates.""" + dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min") + df = pd.DataFrame( + { + "timestamp": dates, + "open": [100.0] * 10, + "high": 
[100.5] * 10, + "low": [99.5] * 10, + "close": [100.2] * 10, + } + ) + # Add duplicate + df = pd.concat([df, df.iloc[[0]]], ignore_index=True) + return df + + +def test_handle_missing_data_forward_fill(sample_data_with_missing): + """Test forward fill missing data.""" + df = handle_missing_data(sample_data_with_missing, method="forward_fill") + assert df["close"].isna().sum() == 0 + assert df["open"].isna().sum() == 0 + + +def test_handle_missing_data_drop(sample_data_with_missing): + """Test drop missing data.""" + df = handle_missing_data(sample_data_with_missing, method="drop") + assert df["close"].isna().sum() == 0 + assert df["open"].isna().sum() == 0 + assert len(df) < len(sample_data_with_missing) + + +def test_remove_duplicates(sample_data_with_duplicates): + """Test duplicate removal.""" + df = remove_duplicates(sample_data_with_duplicates) + assert len(df) == 10 # Should remove duplicate + + +def test_filter_session(): + """Test session filtering.""" + # Create data spanning multiple hours + dates = pd.date_range("2024-01-01 02:00", periods=120, freq="1min") + df = pd.DataFrame( + { + "timestamp": dates, + "open": [100.0] * 120, + "high": [100.5] * 120, + "low": [99.5] * 120, + "close": [100.2] * 120, + } + ) + + # Filter to 3-4 AM EST + df_filtered = filter_session(df, session_start="03:00", session_end="04:00") + + # Should have approximately 60 rows (1 hour of 1-minute data) + assert len(df_filtered) > 0 + assert len(df_filtered) <= 60 diff --git a/tests/unit/test_data/test_validators.py b/tests/unit/test_data/test_validators.py new file mode 100644 index 0000000..a1bf8d0 --- /dev/null +++ b/tests/unit/test_data/test_validators.py @@ -0,0 +1,97 @@ +"""Tests for data validators.""" + +import pandas as pd +import pytest + +from src.core.enums import Timeframe +from src.data.validators import check_continuity, detect_outliers, validate_ohlcv + + +@pytest.fixture +def valid_ohlcv_data(): + """Create valid OHLCV DataFrame.""" + dates = pd.date_range("2024-01-01 03:00", periods=100, freq="1min") + return pd.DataFrame( + { + "timestamp": dates, + "open": [100.0 + i * 0.1 for i in range(100)], + "high": [100.5 + i * 0.1 for i in range(100)], + "low": [99.5 + i * 0.1 for i in range(100)], + "close": [100.2 + i * 0.1 for i in range(100)], + "volume": [1000] * 100, + } + ) + + +@pytest.fixture +def invalid_ohlcv_data(): + """Create invalid OHLCV DataFrame.""" + dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min") + df = pd.DataFrame( + { + "timestamp": dates, + "open": [100.0] * 10, + "high": [99.0] * 10, # Invalid: high < low + "low": [99.5] * 10, + "close": [100.2] * 10, + } + ) + return df + + +def test_validate_ohlcv_valid(valid_ohlcv_data): + """Test validation with valid data.""" + df = validate_ohlcv(valid_ohlcv_data) + assert len(df) == 100 + + +def test_validate_ohlcv_invalid(invalid_ohlcv_data): + """Test validation with invalid data.""" + with pytest.raises(Exception): # Should raise ValidationError + validate_ohlcv(invalid_ohlcv_data) + + +def test_validate_ohlcv_missing_columns(): + """Test validation with missing columns.""" + df = pd.DataFrame({"timestamp": pd.date_range("2024-01-01", periods=10)}) + with pytest.raises(Exception): # Should raise ValidationError + validate_ohlcv(df) + + +def test_check_continuity(valid_ohlcv_data): + """Test continuity check.""" + is_continuous, gaps = check_continuity(valid_ohlcv_data, Timeframe.M1) + assert is_continuous + assert len(gaps) == 0 + + +def test_check_continuity_with_gaps(): + """Test continuity check with 
gaps.""" + # Create data with gaps + dates = pd.date_range("2024-01-01 03:00", periods=10, freq="1min") + # Remove some dates to create gaps + dates = dates[[0, 1, 2, 5, 6, 7, 8, 9]] # Gap between index 2 and 5 + + df = pd.DataFrame( + { + "timestamp": dates, + "open": [100.0] * len(dates), + "high": [100.5] * len(dates), + "low": [99.5] * len(dates), + "close": [100.2] * len(dates), + } + ) + + is_continuous, gaps = check_continuity(df, Timeframe.M1) + assert not is_continuous + assert len(gaps) > 0 + + +def test_detect_outliers(valid_ohlcv_data): + """Test outlier detection.""" + # Add an outlier + df = valid_ohlcv_data.copy() + df.loc[50, "close"] = 200.0 # Extreme value + + outliers = detect_outliers(df, columns=["close"], method="iqr", threshold=3.0) + assert outliers["is_outlier"].sum() > 0
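
For reference, a minimal end-to-end sketch of how the new loaders, validators, and
repository layer chain together, stitched from the calls exercised in the tests above.
The CSV path is a hypothetical placeholder, and the straight-line flow (load, validate,
flag outliers, persist) is illustrative rather than prescriptive.

    from src.core.enums import Timeframe
    from src.data.database import get_db_session, init_database
    from src.data.loaders import CSVLoader
    from src.data.models import OHLCVData
    from src.data.repositories import OHLCVRepository
    from src.data.validators import check_continuity, detect_outliers, validate_ohlcv

    CSV_PATH = "data/raw/dax_1min.csv"  # hypothetical input file

    init_database(create_tables=True)

    # Load raw candles and run structural/quality checks.
    df = CSVLoader().load(CSV_PATH, symbol="DAX", timeframe=Timeframe.M1)
    df = validate_ohlcv(df)

    is_continuous, gaps = check_continuity(df, Timeframe.M1)
    if not is_continuous:
        print(f"{len(gaps)} continuity gaps found")

    # Flag (but keep) statistical outliers in the close column.
    outliers = detect_outliers(df, columns=["close"], method="iqr", threshold=3.0)
    print(f"{int(outliers['is_outlier'].sum())} outliers flagged")

    # Persist the validated rows through the repository layer.
    records = [
        OHLCVData(
            symbol="DAX",
            timeframe=Timeframe.M1,
            timestamp=row["timestamp"],
            open=row["open"],
            high=row["high"],
            low=row["low"],
            close=row["close"],
            volume=row.get("volume"),
        )
        for _, row in df.iterrows()
    ]

    with get_db_session() as session:
        OHLCVRepository(session=session).create_batch(records)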