"""Tick-replay backtest harness.

Loads tick CSV/GZ files from a folder, replays them through a trader object
and tracks cash, positions and an equity curve using simulated order fills.
"""

import gzip
import io
import os

import numpy as np
import pandas as pd


class BacktestEngine:
    """Replay ticks through a trader and keep simulated account state.

    The trader class is instantiated with no arguments and must expose a
    ``Join(tickdata=...)`` method, which is called once per tick.  Orders are
    simulated by ``mock_insert_order``; the caller is expected to attach it to
    the trader (see the script entry point at the bottom of this file).
    """

    def __init__(self, trader_class, initial_capital=1000000):
        """Create the trader instance and an empty account.

        Args:
            trader_class: Callable returning the trader object (no arguments).
            initial_capital: Starting cash balance.
        """
        self.trader = trader_class()
        self.initial_capital = initial_capital
        # List of (timestamp string, equity) tuples, one per processed tick.
        self.equity_curve = []
        # {instrument_id: {"long": {"today": 0, "yesterday": 0},
        #                  "short": {"today": 0, "yesterday": 0},
        #                  "last_price": price}}
        self.positions = {}
        self.cash = initial_capital
        # Calendar date of the most recently processed tick.
        self.current_date = None

    def run(self, data, start=0, end=None, start_date=None, end_date=None):
        """Feed ticks to the trader and record equity after each one.

        Args:
            data: DataFrame with at least the columns ``ActionDay``,
                ``UpdateTime``, ``UpdateMillisec``, ``LastPrice`` and
                ``InstrumentID``.
            start: First row position (``iloc``) to replay.
            end: Row position to stop before; ``None`` replays to the end.
            start_date: Optional timestamp; earlier ticks are skipped.
            end_date: Optional timestamp; replay stops at the first later
                tick, so rows are expected to be in chronological order.
        """
        for _, row in data.iloc[start:end].iterrows():
            tick = row.to_dict()
            action_day = pd.to_datetime(tick["ActionDay"]).strftime("%Y-%m-%d")
            update_time = pd.to_datetime(tick["UpdateTime"]).strftime("%H:%M:%S")
            # int() keeps the "d" format working when the millisecond column
            # was loaded as a float dtype.
            millisec = int(tick["UpdateMillisec"])
            created_at = f"{action_day} {update_time}.{millisec:03d}"
            current_time = pd.to_datetime(created_at)

            if start_date is not None and current_time < start_date:
                continue
            if end_date is not None and current_time > end_date:
                break

            # NOTE(review): the day boundary is taken from ActionDay (calendar
            # date), not TradingDay -- confirm this is the intended rollover.
            tick_date = current_time.date()
            if self.current_date is None or tick_date != self.current_date:
                self.update_positions_day()
                self.current_date = tick_date

            self.trader.Join(tickdata=tick)
            self.update_account(created_at, tick["LastPrice"], tick["InstrumentID"])

    def update_positions_day(self):
        """Fold every instrument's "today" volume into "yesterday"."""
        for position in self.positions.values():
            for side in ("long", "short"):
                position[side]["yesterday"] += position[side]["today"]
                position[side]["today"] = 0

    def update_account(self, datetime, last_price, instrument_id):
        """Mark all positions to market and append a point to the equity curve.

        Args:
            datetime: Timestamp label stored with the equity point.
            last_price: Latest price of ``instrument_id``.
            instrument_id: Instrument the current tick belongs to.
        """
        # Remember the tick price so this instrument is still valued at its
        # latest seen price when a later tick belongs to another instrument
        # (previously it fell back to the price of its last order).
        held = self.positions.get(instrument_id)
        if held is not None:
            held["last_price"] = last_price

        position_value = 0
        for pos in self.positions.values():
            price = pos["last_price"]
            long_value = (pos["long"]["today"] + pos["long"]["yesterday"]) * price
            short_value = (pos["short"]["today"] + pos["short"]["yesterday"]) * price
            position_value += long_value - short_value

        self.equity_curve.append((datetime, self.cash + position_value))

    def mock_insert_order(
        self, exchange_id, instrument_id, price, volume, direction, offset
    ):
        """Simulate an order that fills immediately at ``price`` for ``volume``.

        Cash moves by ``price * volume`` (out on buys, in on sells); no fee or
        multiplier is applied.  Volumes are not validated, so closing more
        than is held drives the "today" bucket negative.

        Args:
            exchange_id: Accepted for interface compatibility; unused.
            instrument_id: Key of the position to modify.
            price: Fill price.
            volume: Fill volume.
            direction: ``b"0"`` is treated as buy, anything else as sell.
            offset: ``b"0"`` opens, ``b"3"`` closes today's volume only, any
                other value closes yesterday's volume first and then today's.
        """
        is_buy = direction == b"0"
        is_open = offset == b"0"
        is_close_today = offset == b"3"

        if instrument_id not in self.positions:
            self.positions[instrument_id] = {
                "long": {"today": 0, "yesterday": 0},
                "short": {"today": 0, "yesterday": 0},
                "last_price": price,
            }
        position = self.positions[instrument_id]

        if is_open:
            side = position["long"] if is_buy else position["short"]
            side["today"] += volume
        else:
            # Buying closes a short position; selling closes a long one.
            side = position["short"] if is_buy else position["long"]
            if is_close_today:
                side["today"] -= volume
            elif side["yesterday"] >= volume:
                side["yesterday"] -= volume
            else:
                remaining = volume - side["yesterday"]
                side["yesterday"] = 0
                side["today"] -= remaining

        if is_buy:
            self.cash -= price * volume
        else:
            self.cash += price * volume

        position["last_price"] = price

    def calculate_performance(self):
        """Summarise the recorded equity curve.

        Returns:
            dict with ``total_return``, ``sharpe_ratio``, ``max_drawdown`` and
            ``equity_curve`` (a DataFrame indexed by time).  The Sharpe ratio
            is ``sqrt(number of points) * mean / std`` of per-point returns,
            and is reported as 0.0 when the standard deviation is zero or
            undefined instead of NaN/inf.
        """
        df = pd.DataFrame(self.equity_curve, columns=["time", "equity"])
        df["time"] = pd.to_datetime(df["time"])
        df.set_index("time", inplace=True)

        if len(df) < 2:
            print("警告:回测数据点不足,无法计算性能指标。")
            return {
                "total_return": 0,
                "sharpe_ratio": 0,
                "max_drawdown": 0,
                "equity_curve": df,
            }

        df["returns"] = df["equity"].pct_change()
        first_equity = df["equity"].iloc[0]
        total_return = (df["equity"].iloc[-1] - first_equity) / first_equity

        returns_std = df["returns"].std()
        if pd.isna(returns_std) or returns_std == 0:
            # Flat equity (or a single return) has no usable volatility.
            sharpe_ratio = 0.0
        else:
            sharpe_ratio = np.sqrt(len(df)) * df["returns"].mean() / returns_std

        drawdown = df["equity"] / df["equity"].cummax() - 1
        max_drawdown = drawdown.min()

        return {
            "total_return": total_return,
            "sharpe_ratio": sharpe_ratio,
            "max_drawdown": max_drawdown,
            "equity_curve": df,
        }

    def plot_performance(self):
        """Plot the equity curve, then print the summary metrics."""
        # Imported here so the engine and its metrics can be used without
        # matplotlib; only plotting needs it.
        import matplotlib.pyplot as plt

        performance = self.calculate_performance()
        equity_curve = performance["equity_curve"]

        plt.figure(figsize=(12, 8))
        plt.plot(equity_curve.index, equity_curve["equity"])
        plt.title("Equity Curve")
        plt.xlabel("Time")
        plt.ylabel("Equity")
        plt.grid(True)
        plt.show()

        print(f"Total Return: {performance['total_return']:.6%}")
        print(f"Sharpe Ratio: {performance['sharpe_ratio']:.6f}")
        print(f"Max Drawdown: {performance['max_drawdown']:.6%}")


# Mapping from the Chinese column headers in the data files to the English
# field names used by the rest of this module.
header_mapping = {
    "交易日": "TradingDay",
    "合约代码": "InstrumentID",
    "交易所代码": "ExchangeID",
    "合约在交易所的代码": "ExchangeInstID",
    "最新价": "LastPrice",
    "上次结算价": "PreSettlementPrice",
    "昨收盘": "PreClosePrice",
    "昨持仓量": "PreOpenInterest",
    "今开盘": "OpenPrice",
    "最高价": "HighestPrice",
    "最低价": "LowestPrice",
    "数量": "Volume",
    "成交金额": "Turnover",
    "持仓量": "OpenInterest",
    "今收盘": "ClosePrice",
    "本次结算价": "SettlementPrice",
    "涨停板价": "UpperLimitPrice",
    "跌停板价": "LowerLimitPrice",
    "昨虚实度": "PreDelta",
    "今虚实度": "CurrDelta",
    "最后修改时间": "UpdateTime",
    "最后修改毫秒": "UpdateMillisec",
    "申买价一": "BidPrice1",
    "申买量一": "BidVolume1",
    "申卖价一": "AskPrice1",
    "申卖量一": "AskVolume1",
    "申买价二": "BidPrice2",
    "申买量二": "BidVolume2",
    "申卖价二": "AskPrice2",
    "申卖量二": "AskVolume2",
    "申买价三": "BidPrice3",
    "申买量三": "BidVolume3",
    "申卖价三": "AskPrice3",
    "申卖量三": "AskVolume3",
    "申买价四": "BidPrice4",
    "申买量四": "BidVolume4",
    "申卖价四": "AskPrice4",
    "申卖量四": "AskVolume4",
    "申买价五": "BidPrice5",
    "申买量五": "BidVolume5",
    "申卖价五": "AskPrice5",
    "申卖量五": "AskVolume5",
    "当日均价": "AveragePrice",
    "业务日期": "ActionDay",
}


def load_and_process_data(folder_path):
    """Read every ``.gz`` / ``.csv`` file in a folder into one sorted DataFrame.

    ``.gz`` files are decoded as GBK and ``.csv`` files as UTF-8.  Columns are
    renamed through ``header_mapping`` and the combined rows are sorted by
    ``ActionDay``, ``UpdateTime`` and ``UpdateMillisec``.  A file that fails
    to load is reported and skipped.

    Args:
        folder_path: Directory containing the tick files.

    Returns:
        The combined DataFrame, or ``None`` when no file could be read.
    """
    date_columns = ["业务日期", "最后修改时间"]
    dfs = []

    # Sorted so the concatenation order does not depend on the order in which
    # the filesystem happens to list the directory.
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)
        try:
            if filename.endswith(".gz"):
                with gzip.open(file_path, "rt", encoding="gbk") as gz_file:
                    csv_data = io.StringIO(gz_file.read())
                df = pd.read_csv(csv_data, parse_dates=date_columns)
            elif filename.endswith(".csv"):
                df = pd.read_csv(
                    file_path,
                    encoding="utf-8",
                    parse_dates=date_columns,
                )
            else:
                print(f"Skipping {filename}: not a GZ or CSV file")
                continue

            df.rename(columns=header_mapping, inplace=True)
            dfs.append(df)
            print(f"Successfully read {filename}")
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole load.
            print(f"Error reading {filename}: {str(e)}")
            print("Skipping this file.")
            continue

    if not dfs:
        print("没有找到可读取的GZ或CSV文件")
        return None

    data = pd.concat(dfs, ignore_index=True)
    data.sort_values(["ActionDay", "UpdateTime", "UpdateMillisec"], inplace=True)
    return data.reset_index(drop=True)


# Example usage
if __name__ == "__main__":
    # The strategy is only needed when running this file as a script, so the
    # engine above stays importable without it.
    from 专享08策略 import 专享08of

    folder_path = "./回测数据"  # replace with the path of your data folder

    # Read, merge and sort the tick data.
    data = load_and_process_data(folder_path)
    print(data)

    if data is not None:
        # Create the engine with the strategy class and the starting capital.
        backtest = BacktestEngine(专享08of, initial_capital=10000)

        # Route the strategy's orders to the simulated fill logic.
        backtest.trader.insert_order = backtest.mock_insert_order

        # Run the backtest.  Optional keyword arguments restrict the replay,
        # e.g. start=1000, end=3000,
        # start_date=pd.to_datetime('2023-01-01'),
        # end_date=pd.to_datetime('2023-01-31').
        backtest.run(data)

        # Show the results.
        backtest.plot_performance()