增加交易策略、交易指标、量化库代码等文件夹
This commit is contained in:
@@ -0,0 +1,279 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
# from datetime import datetime
|
||||
import gzip
|
||||
import numpy as np
|
||||
import os
|
||||
import io
|
||||
from 专享08策略 import 专享08of # 导入您的 MyTrader 类
|
||||
|
||||
|
||||
class BacktestEngine:
|
||||
def __init__(self, trader_class, initial_capital=1000000):
|
||||
self.trader = trader_class()
|
||||
self.initial_capital = initial_capital
|
||||
self.equity_curve = []
|
||||
self.positions = {} # {instrument_id: {'long': {'today': 0, 'yesterday': 0}, 'short': {'today': 0, 'yesterday': 0}}}
|
||||
self.cash = initial_capital
|
||||
self.current_date = None
|
||||
|
||||
def run(self, data, start=0, end=None, start_date=None, end_date=None):
|
||||
for i, (_, row) in enumerate(data.iloc[start:end].iterrows()):
|
||||
tick = row.to_dict()
|
||||
|
||||
action_day = pd.to_datetime(tick["ActionDay"]).strftime("%Y-%m-%d")
|
||||
update_time = pd.to_datetime(tick["UpdateTime"]).strftime("%H:%M:%S")
|
||||
created_at = f"{action_day} {update_time}.{tick['UpdateMillisec']:03d}"
|
||||
|
||||
current_date = pd.to_datetime(created_at)
|
||||
|
||||
if start_date is not None and current_date < start_date:
|
||||
continue
|
||||
if end_date is not None and current_date > end_date:
|
||||
break
|
||||
|
||||
tick_date = pd.to_datetime(created_at).date()
|
||||
if self.current_date is None or tick_date != self.current_date:
|
||||
self.update_positions_day()
|
||||
self.current_date = tick_date
|
||||
|
||||
self.trader.Join(tickdata=tick)
|
||||
self.update_account(created_at, tick["LastPrice"], tick["InstrumentID"])
|
||||
|
||||
def update_positions_day(self):
|
||||
for position in self.positions.values():
|
||||
position["long"]["yesterday"] += position["long"]["today"]
|
||||
position["long"]["today"] = 0
|
||||
position["short"]["yesterday"] += position["short"]["today"]
|
||||
position["short"]["today"] = 0
|
||||
|
||||
def update_account(self, datetime, last_price, instrument_id):
|
||||
position_value = 0
|
||||
for inst, pos in self.positions.items():
|
||||
price = (
|
||||
last_price
|
||||
if inst == instrument_id
|
||||
else self.positions[inst]["last_price"]
|
||||
)
|
||||
long_value = (pos["long"]["today"] + pos["long"]["yesterday"]) * price
|
||||
short_value = (pos["short"]["today"] + pos["short"]["yesterday"]) * price
|
||||
position_value += long_value - short_value
|
||||
current_equity = self.cash + position_value
|
||||
self.equity_curve.append((datetime, current_equity))
|
||||
|
||||
def mock_insert_order(
|
||||
self, exchange_id, instrument_id, price, volume, direction, offset
|
||||
):
|
||||
is_buy = direction == b"0"
|
||||
is_open = offset == b"0"
|
||||
is_close_today = offset == b"3"
|
||||
|
||||
if instrument_id not in self.positions:
|
||||
self.positions[instrument_id] = {
|
||||
"long": {"today": 0, "yesterday": 0},
|
||||
"short": {"today": 0, "yesterday": 0},
|
||||
"last_price": price,
|
||||
}
|
||||
|
||||
position = self.positions[instrument_id]
|
||||
|
||||
if is_open:
|
||||
if is_buy:
|
||||
position["long"]["today"] += volume
|
||||
self.cash -= price * volume
|
||||
else:
|
||||
position["short"]["today"] += volume
|
||||
self.cash += price * volume
|
||||
else: # close
|
||||
if is_buy: # buy to close short
|
||||
if is_close_today:
|
||||
position["short"]["today"] -= volume
|
||||
else:
|
||||
if position["short"]["yesterday"] >= volume:
|
||||
position["short"]["yesterday"] -= volume
|
||||
else:
|
||||
remaining = volume - position["short"]["yesterday"]
|
||||
position["short"]["yesterday"] = 0
|
||||
position["short"]["today"] -= remaining
|
||||
self.cash -= price * volume
|
||||
else: # sell to close long
|
||||
if is_close_today:
|
||||
position["long"]["today"] -= volume
|
||||
else:
|
||||
if position["long"]["yesterday"] >= volume:
|
||||
position["long"]["yesterday"] -= volume
|
||||
else:
|
||||
remaining = volume - position["long"]["yesterday"]
|
||||
position["long"]["yesterday"] = 0
|
||||
position["long"]["today"] -= remaining
|
||||
self.cash += price * volume
|
||||
|
||||
position["last_price"] = price
|
||||
|
||||
def calculate_performance(self):
|
||||
df = pd.DataFrame(self.equity_curve, columns=["time", "equity"])
|
||||
df["time"] = pd.to_datetime(df["time"])
|
||||
df.set_index("time", inplace=True)
|
||||
|
||||
if len(df) < 2:
|
||||
print("警告:回测数据点不足,无法计算性能指标。")
|
||||
return {
|
||||
"total_return": 0,
|
||||
"sharpe_ratio": 0,
|
||||
"max_drawdown": 0,
|
||||
"equity_curve": df,
|
||||
}
|
||||
|
||||
df["returns"] = df["equity"].pct_change()
|
||||
|
||||
total_return = (df["equity"].iloc[-1] - df["equity"].iloc[0]) / df[
|
||||
"equity"
|
||||
].iloc[0]
|
||||
sharpe_ratio = np.sqrt(len(df)) * df["returns"].mean() / df["returns"].std()
|
||||
|
||||
drawdown = df["equity"] / df["equity"].cummax() - 1
|
||||
max_drawdown = drawdown.min()
|
||||
|
||||
return {
|
||||
"total_return": total_return,
|
||||
"sharpe_ratio": sharpe_ratio,
|
||||
"max_drawdown": max_drawdown,
|
||||
"equity_curve": df,
|
||||
}
|
||||
|
||||
def plot_performance(self):
|
||||
performance = self.calculate_performance()
|
||||
equity_curve = performance["equity_curve"]
|
||||
|
||||
plt.figure(figsize=(12, 8))
|
||||
plt.plot(equity_curve.index, equity_curve["equity"])
|
||||
plt.title("Equity Curve")
|
||||
plt.xlabel("Time")
|
||||
plt.ylabel("Equity")
|
||||
plt.grid(True)
|
||||
plt.show()
|
||||
|
||||
print(f"Total Return: {performance['total_return']:.6%}")
|
||||
print(f"Sharpe Ratio: {performance['sharpe_ratio']:.6f}")
|
||||
print(f"Max Drawdown: {performance['max_drawdown']:.6%}")
|
||||
|
||||
|
||||
# 定义中文表头到英文表头的映射
|
||||
header_mapping = {
|
||||
"交易日": "TradingDay",
|
||||
"合约代码": "InstrumentID",
|
||||
"交易所代码": "ExchangeID",
|
||||
"合约在交易所的代码": "ExchangeInstID",
|
||||
"最新价": "LastPrice",
|
||||
"上次结算价": "PreSettlementPrice",
|
||||
"昨收盘": "PreClosePrice",
|
||||
"昨持仓量": "PreOpenInterest",
|
||||
"今开盘": "OpenPrice",
|
||||
"最高价": "HighestPrice",
|
||||
"最低价": "LowestPrice",
|
||||
"数量": "Volume",
|
||||
"成交金额": "Turnover",
|
||||
"持仓量": "OpenInterest",
|
||||
"今收盘": "ClosePrice",
|
||||
"本次结算价": "SettlementPrice",
|
||||
"涨停板价": "UpperLimitPrice",
|
||||
"跌停板价": "LowerLimitPrice",
|
||||
"昨虚实度": "PreDelta",
|
||||
"今虚实度": "CurrDelta",
|
||||
"最后修改时间": "UpdateTime",
|
||||
"最后修改毫秒": "UpdateMillisec",
|
||||
"申买价一": "BidPrice1",
|
||||
"申买量一": "BidVolume1",
|
||||
"申卖价一": "AskPrice1",
|
||||
"申卖量一": "AskVolume1",
|
||||
"申买价二": "BidPrice2",
|
||||
"申买量二": "BidVolume2",
|
||||
"申卖价二": "AskPrice2",
|
||||
"申卖量二": "AskVolume2",
|
||||
"申买价三": "BidPrice3",
|
||||
"申买量三": "BidVolume3",
|
||||
"申卖价三": "AskPrice3",
|
||||
"申卖量三": "AskVolume3",
|
||||
"申买价四": "BidPrice4",
|
||||
"申买量四": "BidVolume4",
|
||||
"申卖价四": "AskPrice4",
|
||||
"申卖量四": "AskVolume4",
|
||||
"申买价五": "BidPrice5",
|
||||
"申买量五": "BidVolume5",
|
||||
"申卖价五": "AskPrice5",
|
||||
"申卖量五": "AskVolume5",
|
||||
"当日均价": "AveragePrice",
|
||||
"业务日期": "ActionDay",
|
||||
}
|
||||
|
||||
|
||||
def load_and_process_data(folder_path):
|
||||
dfs = []
|
||||
for filename in os.listdir(folder_path):
|
||||
file_path = os.path.join(folder_path, filename)
|
||||
|
||||
try:
|
||||
if filename.endswith(".gz"):
|
||||
# 处理 GZ 文件
|
||||
with gzip.open(file_path, "rt", encoding="gbk") as gz_file:
|
||||
csv_data = io.StringIO(gz_file.read())
|
||||
df = pd.read_csv(csv_data, parse_dates=["业务日期", "最后修改时间"])
|
||||
elif filename.endswith(".csv"):
|
||||
# 处理 CSV 文件
|
||||
df = pd.read_csv(
|
||||
file_path,
|
||||
encoding="utf-8",
|
||||
parse_dates=["业务日期", "最后修改时间"],
|
||||
)
|
||||
else:
|
||||
# 跳过非 GZ 和非 CSV 文件
|
||||
print(f"Skipping {filename}: not a GZ or CSV file")
|
||||
continue
|
||||
|
||||
# 重命名列
|
||||
df.rename(columns=header_mapping, inplace=True)
|
||||
|
||||
dfs.append(df)
|
||||
print(f"Successfully read {filename}")
|
||||
except Exception as e:
|
||||
print(f"Error reading {filename}: {str(e)}")
|
||||
print("Skipping this file.")
|
||||
continue
|
||||
|
||||
if dfs:
|
||||
data = pd.concat(dfs, ignore_index=True)
|
||||
data.sort_values(["ActionDay", "UpdateTime", "UpdateMillisec"], inplace=True)
|
||||
data = data.reset_index(drop=True) # 重置索引
|
||||
return data
|
||||
else:
|
||||
print("没有找到可读取的GZ或CSV文件")
|
||||
return None
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
# 使用示例
|
||||
folder_path = "./回测数据" # 替换为您的数据文件夹路径
|
||||
# 读取,排序,合并tick数据
|
||||
data = load_and_process_data(folder_path)
|
||||
print(data)
|
||||
if data is not None:
|
||||
# 初始化回测引擎,设置策略和初始资金
|
||||
backtest = BacktestEngine(专享08of, initial_capital=10000)
|
||||
|
||||
# 替换MyTrader中的insert_order方法
|
||||
backtest.trader.insert_order = backtest.mock_insert_order
|
||||
|
||||
# # 运行回测
|
||||
# backtest.run(data)
|
||||
|
||||
backtest.run(
|
||||
data,
|
||||
# start=1000,
|
||||
# end=3000,
|
||||
# start_date=pd.to_datetime('2023-01-01'),
|
||||
# end_date=pd.to_datetime('2023-01-31')
|
||||
)
|
||||
# 显示回测结果
|
||||
backtest.plot_performance()
|
||||
Reference in New Issue
Block a user