280 lines
9.8 KiB
Python
280 lines
9.8 KiB
Python
import matplotlib.pyplot as plt
|
||
import pandas as pd
|
||
|
||
# from datetime import datetime
|
||
import gzip
|
||
import numpy as np
|
||
import os
|
||
import io
|
||
from 专享08策略 import 专享08of # 导入您的 MyTrader 类
|
||
|
||
|
||
class BacktestEngine:
    """Replay tick data through a trader object and track account equity.

    For every tick the engine calls ``trader.Join(tickdata=tick)`` and then
    marks all open positions to market, appending ``(timestamp, equity)`` to
    ``equity_curve``. Orders are filled by ``mock_insert_order`` (the script
    at the bottom of this file binds it as ``trader.insert_order``): fills are
    immediate and complete at the requested price, with no commission,
    slippage, margin or contract multiplier.
    """

    def __init__(self, trader_class, initial_capital=1000000):
        """Create the strategy instance and reset the simulated account.

        Args:
            trader_class: Zero-argument callable (normally the strategy class)
                whose instances provide ``Join(tickdata=...)``.
            initial_capital: Starting cash balance.
        """
        self.trader = trader_class()
        self.initial_capital = initial_capital
        # One (timestamp string, equity) pair per processed tick.
        self.equity_curve = []
        # {instrument_id: {'long': {'today': 0, 'yesterday': 0},
        #                  'short': {'today': 0, 'yesterday': 0},
        #                  'last_price': <latest mark price>}}
        self.positions = {}
        self.cash = initial_capital
        # Calendar date (datetime.date) of the last processed tick.
        self.current_date = None

    def run(self, data, start=0, end=None, start_date=None, end_date=None):
        """Replay ticks from ``data`` through the trader.

        Args:
            data: DataFrame of ticks with at least ``ActionDay``,
                ``UpdateTime``, ``UpdateMillisec``, ``LastPrice`` and
                ``InstrumentID`` columns, in chronological order.
            start, end: Positional row slice (``data.iloc[start:end]``)
                applied before any time filtering.
            start_date, end_date: Optional inclusive time window; anything
                ``pd.to_datetime`` accepts (str, datetime, Timestamp) works.
                Ticks before ``start_date`` are skipped and replay stops at
                the first tick after ``end_date`` (so ``data`` must be sorted).
        """
        # Normalize the bounds once so that plain strings compare correctly
        # with the per-tick Timestamps (Timestamp < str raises TypeError).
        if start_date is not None:
            start_date = pd.to_datetime(start_date)
        if end_date is not None:
            end_date = pd.to_datetime(end_date)

        for _, row in data.iloc[start:end].iterrows():
            tick = row.to_dict()

            action_day = pd.to_datetime(tick["ActionDay"]).strftime("%Y-%m-%d")
            update_time = pd.to_datetime(tick["UpdateTime"]).strftime("%H:%M:%S")
            # int() tolerates millisecond columns that pandas loaded as float.
            millis = int(tick["UpdateMillisec"])
            created_at = f"{action_day} {update_time}.{millis:03d}"
            timestamp = pd.to_datetime(created_at)

            if start_date is not None and timestamp < start_date:
                continue
            if end_date is not None and timestamp > end_date:
                break

            # Roll today's positions into yesterday's when the date changes.
            # NOTE(review): this keys on the calendar date derived from
            # ActionDay, not TradingDay; for sessions that presumably cross
            # midnight the two may differ - confirm which one the close-today
            # rules should follow.
            tick_date = timestamp.date()
            if self.current_date is None or tick_date != self.current_date:
                self.update_positions_day()
                self.current_date = tick_date

            self.trader.Join(tickdata=tick)
            self.update_account(created_at, tick["LastPrice"], tick["InstrumentID"])

    def update_positions_day(self):
        """Move every instrument's today position into its yesterday bucket."""
        for position in self.positions.values():
            for side in ("long", "short"):
                position[side]["yesterday"] += position[side]["today"]
                position[side]["today"] = 0

    def update_account(self, datetime, last_price, instrument_id):
        """Mark all positions to market and record the resulting equity.

        Args:
            datetime: Timestamp label stored alongside the equity value.
            last_price: Latest traded price of ``instrument_id``.
            instrument_id: Instrument the current tick belongs to.
        """
        # Remember the latest tick price of this instrument so that ticks of
        # *other* instruments value it at its current market price instead of
        # at its last fill price.
        if instrument_id in self.positions:
            self.positions[instrument_id]["last_price"] = last_price

        position_value = 0
        for pos in self.positions.values():
            price = pos["last_price"]
            long_value = (pos["long"]["today"] + pos["long"]["yesterday"]) * price
            short_value = (pos["short"]["today"] + pos["short"]["yesterday"]) * price
            # Short sales credited cash on open, so they count as a liability.
            position_value += long_value - short_value

        self.equity_curve.append((datetime, self.cash + position_value))

    @staticmethod
    def _reduce_position(book, volume, close_today):
        """Remove ``volume`` from one side's today/yesterday buckets.

        Close-today orders reduce only the today bucket; other closes consume
        the yesterday bucket first and take any remainder from today.
        NOTE(review): over-closing is not checked, so buckets can go negative.
        """
        if close_today:
            book["today"] -= volume
        elif book["yesterday"] >= volume:
            book["yesterday"] -= volume
        else:
            book["today"] -= volume - book["yesterday"]
            book["yesterday"] = 0

    def mock_insert_order(
        self, exchange_id, instrument_id, price, volume, direction, offset
    ):
        """Simulate an immediate, complete fill of an order at ``price``.

        Args:
            exchange_id: Unused here; presumably kept to mirror the trader's
                ``insert_order`` signature.
            instrument_id: Instrument to trade.
            price: Fill price.
            volume: Filled quantity.
            direction: ``b"0"`` buys; any other value sells.
            offset: ``b"0"`` opens, ``b"3"`` closes today's position, and any
                other value closes yesterday's position first, then today's.
                For both flags a ``str`` such as ``"0"`` is accepted as well.
        """
        # Flags are compared as bytes; accept str values too instead of
        # silently treating "0" as a sell/close.
        if isinstance(direction, str):
            direction = direction.encode()
        if isinstance(offset, str):
            offset = offset.encode()

        is_buy = direction == b"0"
        is_open = offset == b"0"
        is_close_today = offset == b"3"

        if instrument_id not in self.positions:
            self.positions[instrument_id] = {
                "long": {"today": 0, "yesterday": 0},
                "short": {"today": 0, "yesterday": 0},
                "last_price": price,
            }
        position = self.positions[instrument_id]

        if is_open:
            if is_buy:
                position["long"]["today"] += volume
                self.cash -= price * volume
            else:
                position["short"]["today"] += volume
                self.cash += price * volume
        elif is_buy:  # buy to close short
            self._reduce_position(position["short"], volume, is_close_today)
            self.cash -= price * volume
        else:  # sell to close long
            self._reduce_position(position["long"], volume, is_close_today)
            self.cash += price * volume

        position["last_price"] = price

    def calculate_performance(self):
        """Compute summary statistics from ``equity_curve``.

        Returns:
            dict with ``total_return`` (fractional change from the first to the
            last recorded equity), ``sharpe_ratio``, ``max_drawdown`` (most
            negative peak-to-trough fraction, <= 0) and ``equity_curve`` (the
            DataFrame indexed by time; it has a ``returns`` column when at
            least two points exist). With fewer than two points a warning is
            printed and all metrics are 0.
        """
        df = pd.DataFrame(self.equity_curve, columns=["time", "equity"])
        df["time"] = pd.to_datetime(df["time"])
        df.set_index("time", inplace=True)

        if len(df) < 2:
            print("警告:回测数据点不足,无法计算性能指标。")
            return {
                "total_return": 0,
                "sharpe_ratio": 0,
                "max_drawdown": 0,
                "equity_curve": df,
            }

        df["returns"] = df["equity"].pct_change()

        first_equity = df["equity"].iloc[0]
        total_return = (df["equity"].iloc[-1] - first_equity) / first_equity

        # Per-tick Sharpe scaled by sqrt(number of points); this is not an
        # annualized figure. A flat curve (std == 0) or a single return
        # (std is NaN) has no meaningful ratio, so report 0 instead of inf/NaN.
        returns_std = df["returns"].std()
        if pd.isna(returns_std) or returns_std == 0:
            sharpe_ratio = 0.0
        else:
            sharpe_ratio = np.sqrt(len(df)) * df["returns"].mean() / returns_std

        drawdown = df["equity"] / df["equity"].cummax() - 1
        max_drawdown = drawdown.min()

        return {
            "total_return": total_return,
            "sharpe_ratio": sharpe_ratio,
            "max_drawdown": max_drawdown,
            "equity_curve": df,
        }

    def plot_performance(self):
        """Print the summary statistics and plot the equity curve."""
        performance = self.calculate_performance()
        equity_curve = performance["equity_curve"]

        # Print before plotting: plt.show() may block until the window closes.
        print(f"Total Return: {performance['total_return']:.6%}")
        print(f"Sharpe Ratio: {performance['sharpe_ratio']:.6f}")
        print(f"Max Drawdown: {performance['max_drawdown']:.6%}")

        plt.figure(figsize=(12, 8))
        plt.plot(equity_curve.index, equity_curve["equity"])
        plt.title("Equity Curve")
        plt.xlabel("Time")
        plt.ylabel("Equity")
        plt.grid(True)
        plt.show()
|
||
|
||
|
||
# Mapping from the Chinese column headers found in the tick files to the
# English field names that the rest of this file indexes (e.g. ActionDay,
# UpdateTime, UpdateMillisec, LastPrice, InstrumentID).
header_mapping = {
    "交易日": "TradingDay",
    "合约代码": "InstrumentID",
    "交易所代码": "ExchangeID",
    "合约在交易所的代码": "ExchangeInstID",
    "最新价": "LastPrice",
    "上次结算价": "PreSettlementPrice",
    "昨收盘": "PreClosePrice",
    "昨持仓量": "PreOpenInterest",
    "今开盘": "OpenPrice",
    "最高价": "HighestPrice",
    "最低价": "LowestPrice",
    "数量": "Volume",
    "成交金额": "Turnover",
    "持仓量": "OpenInterest",
    "今收盘": "ClosePrice",
    "本次结算价": "SettlementPrice",
    "涨停板价": "UpperLimitPrice",
    "跌停板价": "LowerLimitPrice",
    "昨虚实度": "PreDelta",
    "今虚实度": "CurrDelta",
    "最后修改时间": "UpdateTime",
    "最后修改毫秒": "UpdateMillisec",
    # Order-book levels 1-5: bid/ask price and volume for each level.
    "申买价一": "BidPrice1",
    "申买量一": "BidVolume1",
    "申卖价一": "AskPrice1",
    "申卖量一": "AskVolume1",
    "申买价二": "BidPrice2",
    "申买量二": "BidVolume2",
    "申卖价二": "AskPrice2",
    "申卖量二": "AskVolume2",
    "申买价三": "BidPrice3",
    "申买量三": "BidVolume3",
    "申卖价三": "AskPrice3",
    "申卖量三": "AskVolume3",
    "申买价四": "BidPrice4",
    "申买量四": "BidVolume4",
    "申卖价四": "AskPrice4",
    "申卖量四": "AskVolume4",
    "申买价五": "BidPrice5",
    "申买量五": "BidVolume5",
    "申卖价五": "AskPrice5",
    "申卖量五": "AskVolume5",
    "当日均价": "AveragePrice",
    "业务日期": "ActionDay",
}
|
||
|
||
|
||
def load_and_process_data(folder_path):
    """Load every ``.gz``/``.csv`` tick file in ``folder_path`` into one DataFrame.

    ``.gz`` files are decompressed and decoded as GBK; ``.csv`` files are read
    as UTF-8. In both cases the ``业务日期`` and ``最后修改时间`` columns are
    parsed as dates, the Chinese headers are renamed via ``header_mapping``,
    and the combined rows are sorted by ``ActionDay``, ``UpdateTime`` and
    ``UpdateMillisec``. Files that fail to load are reported and skipped.

    Args:
        folder_path: Directory containing the tick files (not searched
            recursively).

    Returns:
        The merged, sorted DataFrame with a fresh ``RangeIndex``, or ``None``
        when no file could be read.
    """
    date_columns = ["业务日期", "最后修改时间"]
    dfs = []

    # os.listdir order is arbitrary; sorting the names makes the concat order,
    # and therefore the final order of rows with identical timestamps,
    # reproducible across machines.
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)

        try:
            if filename.endswith(".gz"):
                # Stream the decompressed text straight into pandas instead of
                # first copying the whole file into an in-memory buffer.
                with gzip.open(file_path, "rt", encoding="gbk") as gz_file:
                    df = pd.read_csv(gz_file, parse_dates=date_columns)
            elif filename.endswith(".csv"):
                df = pd.read_csv(
                    file_path,
                    encoding="utf-8",
                    parse_dates=date_columns,
                )
            else:
                # Neither GZ nor CSV: ignore.
                print(f"Skipping {filename}: not a GZ or CSV file")
                continue

            df.rename(columns=header_mapping, inplace=True)
            dfs.append(df)
            print(f"Successfully read {filename}")
        except Exception as e:
            # Best-effort loading: report the bad file and keep going.
            print(f"Error reading {filename}: {str(e)}")
            print("Skipping this file.")
            continue

    if not dfs:
        print("没有找到可读取的GZ或CSV文件")
        return None

    data = pd.concat(dfs, ignore_index=True)
    data.sort_values(["ActionDay", "UpdateTime", "UpdateMillisec"], inplace=True)
    return data.reset_index(drop=True)
|
||
|
||
|
||
# Example usage
if __name__ == "__main__":
    # Load, merge and sort every tick file found in the data folder
    # (replace the path with your own data directory).
    data = load_and_process_data("./回测数据")
    print(data)

    if data is not None:
        # Engine wired to the strategy class with the chosen starting capital.
        engine = BacktestEngine(专享08of, initial_capital=10000)

        # Route the strategy's order calls to the simulated fill logic.
        engine.trader.insert_order = engine.mock_insert_order

        # Replay the whole dataset. To narrow the run, pass for example
        # start=1000, end=3000 (positional row slice) or
        # start_date=pd.to_datetime('2023-01-01'),
        # end_date=pd.to_datetime('2023-01-31') (time window).
        engine.run(data)

        # Print the summary statistics and show the equity curve.
        engine.plot_performance()
|