Files

280 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import matplotlib.pyplot as plt
import pandas as pd
# from datetime import datetime
import gzip
import numpy as np
import os
import io
from 专享08策略 import 专享08of # 导入您的 MyTrader 类
class BacktestEngine:
    """Tick-by-tick backtester that replays market data into a trader object.

    The trader (``trader_class()``) receives every tick through
    ``Join(tickdata=...)``. Its order calls are meant to be routed to
    :meth:`mock_insert_order`, which fills each order immediately and in full at
    the requested price. No fees, slippage, margin or contract multiplier are
    modelled. After every tick the account is marked to market and the equity
    is appended to ``equity_curve``.
    """

    def __init__(self, trader_class, initial_capital=1000000):
        """Instantiate the strategy and reset the simulated account.

        Args:
            trader_class: zero-argument callable (normally a class) that builds
                the strategy object; the object must provide ``Join(tickdata=...)``.
            initial_capital: starting cash balance.
        """
        self.trader = trader_class()
        self.initial_capital = initial_capital
        # List of (timestamp string, equity) pairs, one per processed tick.
        self.equity_curve = []
        # {instrument_id: {"long": {"today": n, "yesterday": n},
        #                  "short": {"today": n, "yesterday": n},
        #                  "last_price": latest mark price}}
        self.positions = {}
        self.cash = initial_capital
        # Calendar date of the most recent tick; used to detect day changes.
        self.current_date = None

    def run(self, data, start=0, end=None, start_date=None, end_date=None):
        """Replay ticks from *data* through the trader and record equity.

        Args:
            data: tick DataFrame with at least ActionDay, UpdateTime,
                UpdateMillisec, LastPrice and InstrumentID columns.
            start, end: positional row slice (``data.iloc[start:end]``).
            start_date, end_date: optional timestamps. Ticks before *start_date*
                are skipped. The replay stops at the first tick after
                *end_date*, which assumes the rows are in time order.
        """
        for _, row in data.iloc[start:end].iterrows():
            tick = row.to_dict()
            action_day = pd.to_datetime(tick["ActionDay"]).strftime("%Y-%m-%d")
            update_time = pd.to_datetime(tick["UpdateTime"]).strftime("%H:%M:%S")
            raw_millis = tick["UpdateMillisec"]
            # A column with gaps loads as float (possibly NaN), and "{:03d}"
            # rejects floats, so coerce to int and treat a missing value as 000.
            millis = 0 if pd.isna(raw_millis) else int(raw_millis)
            created_at = f"{action_day} {update_time}.{millis:03d}"
            tick_time = pd.to_datetime(created_at)
            if start_date is not None and tick_time < start_date:
                continue
            if end_date is not None and tick_time > end_date:
                break
            tick_date = tick_time.date()
            if self.current_date is None or tick_date != self.current_date:
                # New calendar day: roll today's positions into yesterday's.
                self.update_positions_day()
                self.current_date = tick_date
            self.trader.Join(tickdata=tick)
            self.update_account(created_at, tick["LastPrice"], tick["InstrumentID"])

    def update_positions_day(self):
        """Move every instrument's 'today' volume into 'yesterday' (day change)."""
        for position in self.positions.values():
            for side in ("long", "short"):
                position[side]["yesterday"] += position[side]["today"]
                position[side]["today"] = 0

    def update_account(self, datetime, last_price, instrument_id):
        """Mark all positions to market and append ``(datetime, equity)``.

        Args:
            datetime: timestamp label stored in the equity curve.
            last_price: latest traded price of *instrument_id*.
            instrument_id: instrument the current tick belongs to.
        """
        if instrument_id in self.positions:
            # Store the newest market price. Instruments without a fresh tick
            # are then valued at their latest market price, not their last
            # fill price.
            self.positions[instrument_id]["last_price"] = last_price
        position_value = 0
        for pos in self.positions.values():
            price = pos["last_price"]
            long_value = (pos["long"]["today"] + pos["long"]["yesterday"]) * price
            short_value = (pos["short"]["today"] + pos["short"]["yesterday"]) * price
            position_value += long_value - short_value
        self.equity_curve.append((datetime, self.cash + position_value))

    @staticmethod
    def _flag(value):
        """Return an order flag as str, accepting bytes (b"0") or str ("0")."""
        if isinstance(value, (bytes, bytearray)):
            return value.decode("ascii", errors="replace")
        return str(value)

    @staticmethod
    def _reduce_side(side, volume, today_only):
        """Remove *volume* from one side's {"today", "yesterday"} counts.

        Close-today orders draw only from 'today'. Other closes draw from
        'yesterday' first and take any remainder from 'today'. Over-closing is
        not rejected, so counts can become negative.
        """
        if today_only:
            side["today"] -= volume
        elif side["yesterday"] >= volume:
            side["yesterday"] -= volume
        else:
            side["today"] -= volume - side["yesterday"]
            side["yesterday"] = 0

    def mock_insert_order(
        self, exchange_id, instrument_id, price, volume, direction, offset
    ):
        """Simulate an immediate, complete fill of an order at *price*.

        Args:
            exchange_id: unused by the simulation.
            instrument_id: instrument to trade.
            price: fill price; also becomes the instrument's mark price.
            volume: number of lots.
            direction: "0" means buy, anything else sell (bytes or str).
            offset: "0" opens, "3" closes today's position, anything else
                closes (yesterday first, then today).
        """
        direction = self._flag(direction)
        offset = self._flag(offset)
        is_buy = direction == "0"
        is_open = offset == "0"
        is_close_today = offset == "3"
        if instrument_id not in self.positions:
            self.positions[instrument_id] = {
                "long": {"today": 0, "yesterday": 0},
                "short": {"today": 0, "yesterday": 0},
                "last_price": price,
            }
        position = self.positions[instrument_id]
        if is_open:
            if is_buy:
                position["long"]["today"] += volume
                self.cash -= price * volume
            else:
                position["short"]["today"] += volume
                self.cash += price * volume
        elif is_buy:
            # Buy to close a short position.
            self._reduce_side(position["short"], volume, is_close_today)
            self.cash -= price * volume
        else:
            # Sell to close a long position.
            self._reduce_side(position["long"], volume, is_close_today)
            self.cash += price * volume
        position["last_price"] = price

    def calculate_performance(self):
        """Compute total return, a Sharpe-style ratio and max drawdown.

        The ratio is mean/std of per-tick returns scaled by sqrt(number of
        points). It is not annualised and is reported as 0 when the returns
        have no (or undefined) variance.

        Returns:
            dict with "total_return", "sharpe_ratio", "max_drawdown" (fractions)
            and "equity_curve" (DataFrame indexed by time).
        """
        df = pd.DataFrame(self.equity_curve, columns=["time", "equity"])
        df["time"] = pd.to_datetime(df["time"])
        df.set_index("time", inplace=True)
        if len(df) < 2:
            print("警告:回测数据点不足,无法计算性能指标。")
            return {
                "total_return": 0,
                "sharpe_ratio": 0,
                "max_drawdown": 0,
                "equity_curve": df,
            }
        equity = df["equity"]
        df["returns"] = equity.pct_change()
        total_return = (equity.iloc[-1] - equity.iloc[0]) / equity.iloc[0]
        std = df["returns"].std()
        # Flat equity (e.g. no trades) gives std == 0, and exactly two points
        # give NaN std; report 0 instead of NaN/inf in both cases.
        if np.isfinite(std) and std > 0:
            sharpe_ratio = np.sqrt(len(df)) * df["returns"].mean() / std
        else:
            sharpe_ratio = 0.0
        max_drawdown = (equity / equity.cummax() - 1).min()
        return {
            "total_return": total_return,
            "sharpe_ratio": sharpe_ratio,
            "max_drawdown": max_drawdown,
            "equity_curve": df,
        }

    def plot_performance(self):
        """Print the performance metrics and plot the equity curve."""
        performance = self.calculate_performance()
        equity_curve = performance["equity_curve"]
        # Print first: plt.show() can block until the window is closed.
        print(f"Total Return: {performance['total_return']:.6%}")
        print(f"Sharpe Ratio: {performance['sharpe_ratio']:.6f}")
        print(f"Max Drawdown: {performance['max_drawdown']:.6%}")
        plt.figure(figsize=(12, 8))
        plt.plot(equity_curve.index, equity_curve["equity"])
        plt.title("Equity Curve")
        plt.xlabel("Time")
        plt.ylabel("Equity")
        plt.grid(True)
        plt.show()
# Translation table from the Chinese column headers used in the tick files to
# the English field names the backtester reads (e.g. "最新价" -> "LastPrice").
header_mapping = {
    # Trading day and contract identifiers
    "交易日": "TradingDay",
    "合约代码": "InstrumentID",
    "交易所代码": "ExchangeID",
    "合约在交易所的代码": "ExchangeInstID",
    # Prices, volume, turnover and open interest
    "最新价": "LastPrice",
    "上次结算价": "PreSettlementPrice",
    "昨收盘": "PreClosePrice",
    "昨持仓量": "PreOpenInterest",
    "今开盘": "OpenPrice",
    "最高价": "HighestPrice",
    "最低价": "LowestPrice",
    "数量": "Volume",
    "成交金额": "Turnover",
    "持仓量": "OpenInterest",
    "今收盘": "ClosePrice",
    "本次结算价": "SettlementPrice",
    "涨停板价": "UpperLimitPrice",
    "跌停板价": "LowerLimitPrice",
    "昨虚实度": "PreDelta",
    "今虚实度": "CurrDelta",
    # Update timestamp (time of day plus milliseconds)
    "最后修改时间": "UpdateTime",
    "最后修改毫秒": "UpdateMillisec",
    # Order book, levels one to five
    "申买价一": "BidPrice1",
    "申买量一": "BidVolume1",
    "申卖价一": "AskPrice1",
    "申卖量一": "AskVolume1",
    "申买价二": "BidPrice2",
    "申买量二": "BidVolume2",
    "申卖价二": "AskPrice2",
    "申卖量二": "AskVolume2",
    "申买价三": "BidPrice3",
    "申买量三": "BidVolume3",
    "申卖价三": "AskPrice3",
    "申卖量三": "AskVolume3",
    "申买价四": "BidPrice4",
    "申买量四": "BidVolume4",
    "申卖价四": "AskPrice4",
    "申卖量四": "AskVolume4",
    "申买价五": "BidPrice5",
    "申买量五": "BidVolume5",
    "申卖价五": "AskPrice5",
    "申卖量五": "AskVolume5",
    # Daily average price and calendar (action) date
    "当日均价": "AveragePrice",
    "业务日期": "ActionDay",
}
def load_and_process_data(folder_path, mapping=None):
    """Load every tick file in *folder_path* into one time-sorted DataFrame.

    Files ending in ``.gz`` are read as GBK-encoded gzip CSV and files ending
    in ``.csv`` as UTF-8 CSV. Other files are skipped. The "业务日期" and
    "最后修改时间" columns are parsed as dates, then columns are renamed with
    *mapping*. A file that fails to load is reported and skipped.

    Args:
        folder_path: directory to scan (not recursive).
        mapping: column-rename dict; defaults to the module-level
            ``header_mapping``.

    Returns:
        The concatenated data sorted by ActionDay, UpdateTime and
        UpdateMillisec with a fresh 0..n-1 index, or None if no file was read.
    """
    if mapping is None:
        mapping = header_mapping
    date_columns = ["业务日期", "最后修改时间"]
    frames = []
    # os.listdir order is arbitrary; sorting makes the merge order, and so
    # the final result, reproducible between runs.
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)
        try:
            if filename.endswith(".gz"):
                # Parse straight from the decompressing stream instead of first
                # copying the whole file into an in-memory string.
                with gzip.open(file_path, "rt", encoding="gbk") as gz_file:
                    df = pd.read_csv(gz_file, parse_dates=date_columns)
            elif filename.endswith(".csv"):
                df = pd.read_csv(
                    file_path,
                    encoding="utf-8",
                    parse_dates=date_columns,
                )
            else:
                print(f"Skipping {filename}: not a GZ or CSV file")
                continue
            frames.append(df.rename(columns=mapping))
            print(f"Successfully read {filename}")
        except Exception as e:
            # Best effort: one unreadable file should not abort the whole load.
            print(f"Error reading {filename}: {str(e)}")
            print("Skipping this file.")
    if not frames:
        print("没有找到可读取的GZ或CSV文件")
        return None
    data = pd.concat(frames, ignore_index=True)
    data = data.sort_values(["ActionDay", "UpdateTime", "UpdateMillisec"])
    return data.reset_index(drop=True)
# Example usage: run this module as a script to backtest the strategy.
if __name__ == "__main__":
    folder_path = "./回测数据"  # Replace with the path to your tick-data folder.
    # Read, merge and time-sort all tick files in the folder.
    data = load_and_process_data(folder_path)
    if data is not None:
        # Only show the data when loading succeeded (otherwise it is None).
        print(data)
        # Build the engine with the strategy class and the starting capital.
        backtest = BacktestEngine(专享08of, initial_capital=10000)
        # Route the strategy's order calls into the simulated fill logic.
        backtest.trader.insert_order = backtest.mock_insert_order
        # To restrict the replay, pass a row slice and/or a time window, e.g.
        #   backtest.run(data, start=1000, end=3000)
        #   backtest.run(data, start_date=pd.to_datetime("2023-01-01"),
        #                end_date=pd.to_datetime("2023-01-31"))
        backtest.run(data)
        # Show the equity curve and summary metrics.
        backtest.plot_performance()