319 lines
9.9 KiB
Python
319 lines
9.9 KiB
Python
from abc import ABC
|
||
from typing import Union, Type, TYPE_CHECKING
|
||
from collections import defaultdict
|
||
|
||
from vnpy.trader.constant import Interval, Direction, Offset
|
||
from vnpy.trader.object import BarData, TickData, OrderData, TradeData, ContractData
|
||
from vnpy.trader.utility import virtual
|
||
|
||
from .object import PortfolioData
|
||
|
||
if TYPE_CHECKING:
|
||
from .engine import StrategyEngine
|
||
|
||
|
||
class StrategyTemplate(ABC):
    """Option strategy template.

    Subclasses override the ``on_*`` callbacks and use the order helpers
    (``buy``/``sell``/``short``/``cover``), the target-position helpers and
    ``execute_trading`` to trade through the owning ``StrategyEngine``.
    """

    author: str = ""

    def __init__(self, engine: "StrategyEngine", name: str) -> None:
        """Create a strategy instance bound to *engine* under *name*."""
        self.engine: "StrategyEngine" = engine
        self.name: str = name

        self.vt_symbols: set[str] = set()
        self.inited: bool = False
        self.trading: bool = False

        # Parameter / variable name lists: subclasses normally declare these
        # as class attributes; fall back to empty lists otherwise.
        if not hasattr(self, "parameters"):
            self.parameters: list[str] = []

        if not hasattr(self, "variables"):
            self.variables: list[str] = []

        # Built-in state always comes first. This assignment creates an
        # instance attribute, so a class-level list is not mutated.
        self.variables = ["inited", "trading", "pos_data", "target_data"] + self.variables

        # Order cache
        self.orders: dict[str, OrderData] = {}
        self.active_orderids: set[str] = set()

        # Position data keyed by vt_symbol
        self.pos_data: dict[str, int] = defaultdict(int)     # actual positions
        self.target_data: dict[str, int] = defaultdict(int)  # target positions

        # Option portfolios keyed by portfolio name
        self.portfolios: dict[str, PortfolioData] = {}

    def update_setting(self, setting: dict) -> None:
        """Copy values for declared parameters from *setting* onto the strategy.

        Keys in *setting* that are not listed in ``self.parameters`` are ignored.
        """
        for name in self.parameters:
            if name in setting:
                setattr(self, name, setting[name])

    def get_parameters(self) -> dict:
        """Return a ``{name: value}`` mapping of the strategy parameters."""
        return {name: getattr(self, name) for name in self.parameters}

    def get_variables(self) -> dict:
        """Return a ``{name: value}`` mapping of the strategy variables."""
        return {name: getattr(self, name) for name in self.variables}

    def get_data(self) -> dict:
        """Return a snapshot of the strategy's identity, parameters and variables."""
        # NOTE(review): "vt_symbols" is returned as a set; convert it if this
        # dict must be JSON-serialized — confirm against the engine/UI.
        return {
            "strategy_name": self.name,
            "vt_symbols": self.vt_symbols,
            "class_name": self.__class__.__name__,
            "author": self.author,
            "parameters": self.get_parameters(),
            "variables": self.get_variables(),
        }

    @virtual
    def on_init(self) -> None:
        """Callback invoked when the strategy is initialized."""
        pass

    @virtual
    def on_start(self) -> None:
        """Callback invoked when the strategy is started."""
        pass

    @virtual
    def on_stop(self) -> None:
        """Callback invoked when the strategy is stopped."""
        pass

    @virtual
    def on_tick(self, tick: TickData) -> None:
        """Callback for a new tick."""
        pass

    @virtual
    def on_bars(self, bars: dict[str, BarData]) -> None:
        """Callback for a new bar slice, keyed by vt_symbol."""
        pass

    def update_trade(self, trade: TradeData) -> None:
        """Apply a fill to ``pos_data``.

        LONG fills add ``trade.volume``; any other direction subtracts it.
        """
        if trade.direction == Direction.LONG:
            self.pos_data[trade.vt_symbol] += trade.volume
        else:
            self.pos_data[trade.vt_symbol] -= trade.volume

    def update_order(self, order: OrderData) -> None:
        """Cache *order* and stop tracking it as active once it is no longer active."""
        self.orders[order.vt_orderid] = order

        if not order.is_active():
            # discard() is a no-op for ids that were never tracked.
            self.active_orderids.discard(order.vt_orderid)

    def buy(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Send a buy-to-open order (LONG / OPEN)."""
        return self.send_order(vt_symbol, Direction.LONG, Offset.OPEN, price, volume)

    def sell(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Send a sell-to-close order (SHORT / CLOSE)."""
        return self.send_order(vt_symbol, Direction.SHORT, Offset.CLOSE, price, volume)

    def short(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Send a sell-to-open order (SHORT / OPEN)."""
        return self.send_order(vt_symbol, Direction.SHORT, Offset.OPEN, price, volume)

    def cover(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Send a buy-to-close order (LONG / CLOSE)."""
        return self.send_order(vt_symbol, Direction.LONG, Offset.CLOSE, price, volume)

    def send_order(
        self,
        vt_symbol: str,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
    ) -> list[str]:
        """Send an order through the engine.

        :returns: the order ids returned by the engine, which are also tracked
            in ``active_orderids``; ``[]`` when the strategy is not trading.
        """
        if not self.trading:
            return []

        vt_orderids = self.engine.send_order(
            self,
            vt_symbol,
            direction,
            offset,
            price,
            volume
        )
        self.active_orderids.update(vt_orderids)
        return vt_orderids

    def cancel_order(self, vt_orderid: str) -> None:
        """Ask the engine to cancel *vt_orderid*; no-op when not trading."""
        if not self.trading:
            return

        self.engine.cancel_order(self, vt_orderid)

    def cancel_all(self) -> None:
        """Request cancellation of every active order."""
        # Iterate over a snapshot: the set may change while cancelling
        # (e.g. if order updates are delivered synchronously).
        for vt_orderid in list(self.active_orderids):
            self.cancel_order(vt_orderid)

    def get_pos(self, vt_symbol: str) -> int:
        """Return the current position of *vt_symbol* (0 if unknown)."""
        return self.pos_data.get(vt_symbol, 0)

    def get_target(self, vt_symbol: str) -> int:
        """Return the target position of *vt_symbol* (0 if unset)."""
        # Use .get() so a plain query does not insert a zero entry into the
        # defaultdict (target_data is exposed via get_variables and synced).
        return self.target_data.get(vt_symbol, 0)

    def set_target(self, vt_symbol: str, target: int) -> None:
        """Set the target position of *vt_symbol*."""
        self.target_data[vt_symbol] = target

    def clear_targets(self) -> None:
        """Remove all target positions."""
        self.target_data.clear()

    def execute_trading(self, price_data: dict[str, float], percent_add: float) -> None:
        """Send orders that move positions towards their targets.

        All active orders are cancelled first. Orders are then sent only for
        symbols present in *price_data* (those with a price in the current bar
        slice). Opposite positions are closed before any new position is opened.

        :param price_data: reference price per vt_symbol.
        :param percent_add: fractional price offset. Buy orders are priced at
            ``price * (1 + percent_add)`` and sell orders at
            ``price * (1 - percent_add)``.
        """
        self.cancel_all()

        for vt_symbol, price in price_data.items():
            pos: int = self.get_pos(vt_symbol)
            diff: int = self.get_target(vt_symbol) - pos

            if diff > 0:
                # Need to buy: cover an existing short first, open long with the rest.
                order_price: float = price * (1 + percent_add)
                cover_volume: int = min(diff, -pos) if pos < 0 else 0
                buy_volume: int = diff - cover_volume

                if cover_volume:
                    self.cover(vt_symbol, order_price, cover_volume)

                if buy_volume:
                    self.buy(vt_symbol, order_price, buy_volume)

            elif diff < 0:
                # Need to sell: close an existing long first, open short with the rest.
                order_price = price * (1 - percent_add)
                volume: int = -diff
                sell_volume: int = min(volume, pos) if pos > 0 else 0
                short_volume: int = volume - sell_volume

                if sell_volume:
                    self.sell(vt_symbol, order_price, sell_volume)

                if short_volume:
                    self.short(vt_symbol, order_price, short_volume)

    def write_log(self, msg: str) -> None:
        """Write a log message through the engine."""
        self.engine.write_log(msg, self)

    def get_portfolio(self, portfolio_name: str) -> PortfolioData:
        """Return the portfolio registered as *portfolio_name* (raises KeyError if absent)."""
        return self.portfolios[portfolio_name]

    def load_bars(self, vt_symbol: str, days: int, interval: Interval) -> list[BarData]:
        """Load historical bars for *vt_symbol* through the engine."""
        return self.engine.load_bars(vt_symbol, days, interval)

    def init_portfolio(self, portfolio_name: str) -> bool:
        """Query the option contracts of *portfolio_name* from the engine.

        :returns: True if the portfolio is already registered or the engine
            returned at least one contract; False otherwise.
        """
        # Avoid querying the same portfolio twice
        if portfolio_name in self.portfolios:
            return True

        contracts: list[ContractData] = self.engine.init_portfolio(self, portfolio_name)
        return self._add_portfolio(portfolio_name, contracts)

    def subscribe_options(self, portfolio_name: str) -> bool:
        """Subscribe to the option market data of *portfolio_name* via the engine.

        :returns: True if the portfolio is already registered or the engine
            returned at least one contract; False otherwise.
        """
        # Avoid subscribing to the same portfolio twice
        if portfolio_name in self.portfolios:
            return True

        contracts: list[ContractData] = self.engine.subscribe_options(self, portfolio_name)
        return self._add_portfolio(portfolio_name, contracts)

    def _add_portfolio(self, portfolio_name: str, contracts: list[ContractData]) -> bool:
        """Build a PortfolioData from *contracts* and register it; False if no contracts."""
        if not contracts:
            return False

        portfolio: PortfolioData = PortfolioData(contracts[0])
        self.portfolios[portfolio_name] = portfolio

        for contract in contracts:
            self.vt_symbols.add(contract.vt_symbol)
            portfolio.add_contract(contract)

        return True

    def subscribe_data(self, vt_symbol: str) -> bool:
        """Subscribe to *vt_symbol* market data; record the symbol on success."""
        subscribed: bool = self.engine.subscribe_data(self, vt_symbol)
        if subscribed:
            self.vt_symbols.add(vt_symbol)

        return subscribed

    def put_event(self) -> None:
        """Push a strategy-update event through the engine once inited."""
        if self.inited:
            self.engine.put_strategy_event(self)

    def send_email(self, msg: str) -> None:
        """Send an email message through the engine once inited."""
        if self.inited:
            self.engine.send_email(msg, self)

    def sync_data(self) -> None:
        """Ask the engine to persist the strategy state while trading."""
        if self.trading:
            self.engine.sync_strategy_data(self)