Files
Quant_Code/5.课程代码/2.Option_spread_strategy/使用文档/20/template.py

319 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from abc import ABC
from typing import Union, Type, TYPE_CHECKING
from collections import defaultdict
from vnpy.trader.constant import Interval, Direction, Offset
from vnpy.trader.object import BarData, TickData, OrderData, TradeData, ContractData
from vnpy.trader.utility import virtual
from .object import PortfolioData
if TYPE_CHECKING:
from .engine import StrategyEngine
class StrategyTemplate(ABC):
    """Base template for option strategies.

    Subclasses override the ``on_*`` callbacks and use the trading helpers
    (``buy``/``sell``/``short``/``cover`` or ``set_target`` +
    ``execute_trading``) to manage positions. Every order, subscription,
    log and persistence request is delegated to ``self.engine``.
    """

    author: str = ""

    def __init__(self, engine: "StrategyEngine", name: str) -> None:
        """Create the strategy instance.

        Args:
            engine: Strategy engine that owns this strategy.
            name: Strategy instance name.
        """
        self.engine: "StrategyEngine" = engine
        self.name: str = name
        # Filled by init_portfolio / subscribe_options / subscribe_data.
        self.vt_symbols: set[str] = set()

        self.inited: bool = False
        self.trading: bool = False

        # Subclasses may declare `parameters` / `variables` as class
        # attributes; fall back to empty lists when they do not.
        if not hasattr(self, "parameters"):
            self.parameters: list[str] = []
        if not hasattr(self, "variables"):
            self.variables: list[str] = []
        # Built-in state is always reported ahead of subclass variables.
        # `+` builds a new list, so a class-level list is never mutated.
        self.variables = ["inited", "trading", "pos_data", "target_data"] + self.variables

        # Order cache: every order seen, plus the ids still active.
        self.orders: dict[str, OrderData] = {}
        self.active_orderids: set[str] = set()

        # Positions keyed by vt_symbol.
        self.pos_data: dict[str, int] = defaultdict(int)       # actual position
        self.target_data: dict[str, int] = defaultdict(int)    # target position

        # Option portfolios keyed by portfolio name.
        self.portfolios: dict[str, PortfolioData] = {}

    def update_setting(self, setting: dict) -> None:
        """Copy values from *setting* onto attributes named in ``parameters``.

        Keys that are not declared parameters are ignored.
        """
        for name in self.parameters:
            if name in setting:
                setattr(self, name, setting[name])

    def get_parameters(self) -> dict:
        """Return a ``{name: value}`` dict of the declared parameters."""
        return {name: getattr(self, name) for name in self.parameters}

    def get_variables(self) -> dict:
        """Return a ``{name: value}`` dict of the declared variables."""
        return {name: getattr(self, name) for name in self.variables}

    def get_data(self) -> dict:
        """Return a summary of the strategy's identity, parameters and state."""
        strategy_data: dict = {
            "strategy_name": self.name,
            # NOTE(review): this is the live set (not a copy), and sets are
            # not JSON-serializable — confirm consumers expect that.
            "vt_symbols": self.vt_symbols,
            "class_name": self.__class__.__name__,
            "author": self.author,
            "parameters": self.get_parameters(),
            "variables": self.get_variables(),
        }
        return strategy_data

    @virtual
    def on_init(self) -> None:
        """Callback when the strategy is initialized."""
        pass

    @virtual
    def on_start(self) -> None:
        """Callback when the strategy is started."""
        pass

    @virtual
    def on_stop(self) -> None:
        """Callback when the strategy is stopped."""
        pass

    @virtual
    def on_tick(self, tick: TickData) -> None:
        """Callback for a new tick."""
        pass

    @virtual
    def on_bars(self, bars: dict[str, BarData]) -> None:
        """Callback for a new slice of bars, keyed by vt_symbol."""
        pass

    def update_trade(self, trade: TradeData) -> None:
        """Apply a trade fill to ``pos_data``.

        A LONG fill increases the position; any other direction decreases it.
        """
        if trade.direction == Direction.LONG:
            self.pos_data[trade.vt_symbol] += trade.volume
        else:
            self.pos_data[trade.vt_symbol] -= trade.volume

    def update_order(self, order: OrderData) -> None:
        """Cache *order* and drop it from ``active_orderids`` once inactive."""
        self.orders[order.vt_orderid] = order

        if not order.is_active():
            self.active_orderids.discard(order.vt_orderid)

    def buy(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Send a buy-to-open order."""
        return self.send_order(vt_symbol, Direction.LONG, Offset.OPEN, price, volume)

    def sell(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Send a sell-to-close order."""
        return self.send_order(vt_symbol, Direction.SHORT, Offset.CLOSE, price, volume)

    def short(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Send a sell-to-open order."""
        return self.send_order(vt_symbol, Direction.SHORT, Offset.OPEN, price, volume)

    def cover(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Send a buy-to-close order."""
        return self.send_order(vt_symbol, Direction.LONG, Offset.CLOSE, price, volume)

    def send_order(
        self,
        vt_symbol: str,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
    ) -> list[str]:
        """Send an order through the engine.

        Returns:
            The vt_orderids returned by the engine; they are also added to
            ``active_orderids``. Empty list when the strategy is not trading.
        """
        if not self.trading:
            return []

        vt_orderids = self.engine.send_order(
            self,
            vt_symbol,
            direction,
            offset,
            price,
            volume
        )

        self.active_orderids.update(vt_orderids)

        return vt_orderids

    def cancel_order(self, vt_orderid: str) -> None:
        """Ask the engine to cancel *vt_orderid*; no-op when not trading."""
        if not self.trading:
            return

        self.engine.cancel_order(self, vt_orderid)

    def cancel_all(self) -> None:
        """Cancel every order currently in ``active_orderids``."""
        # Iterate over a snapshot in case the set changes during cancellation.
        for vt_orderid in list(self.active_orderids):
            self.cancel_order(vt_orderid)

    def get_pos(self, vt_symbol: str) -> int:
        """Return the current position of *vt_symbol* (0 if none)."""
        return self.pos_data.get(vt_symbol, 0)

    def get_target(self, vt_symbol: str) -> int:
        """Return the target position of *vt_symbol* (0 if none).

        Uses ``.get`` so a read does not insert a 0 entry into the
        ``defaultdict`` (and thus into the reported/synced state).
        """
        return self.target_data.get(vt_symbol, 0)

    def set_target(self, vt_symbol: str, target: int) -> None:
        """Set the target position of *vt_symbol*."""
        self.target_data[vt_symbol] = target

    def clear_targets(self) -> None:
        """Remove all target positions."""
        self.target_data.clear()

    def execute_trading(self, price_data: dict[str, float], percent_add: float) -> None:
        """Cancel active orders and send orders moving positions to target.

        Args:
            price_data: Reference price per vt_symbol. Only symbols present
                here (i.e. with market data in the current slice) are traded.
            percent_add: Fractional price offset. Buy orders are priced at
                ``price * (1 + percent_add)`` and sell orders at
                ``price * (1 - percent_add)``, presumably to improve fill odds.
        """
        self.cancel_all()

        for vt_symbol, price in price_data.items():
            target: int = self.get_target(vt_symbol)
            pos: int = self.get_pos(vt_symbol)
            diff: int = target - pos

            if diff > 0:
                # Buy side: first close any existing short, then open long.
                order_price: float = price * (1 + percent_add)
                cover_volume = min(diff, -pos) if pos < 0 else 0
                buy_volume = diff - cover_volume

                if cover_volume:
                    self.cover(vt_symbol, order_price, cover_volume)
                if buy_volume:
                    self.buy(vt_symbol, order_price, buy_volume)

            elif diff < 0:
                # Sell side: first close any existing long, then open short.
                order_price = price * (1 - percent_add)
                volume = -diff
                sell_volume = min(volume, pos) if pos > 0 else 0
                short_volume = volume - sell_volume

                if sell_volume:
                    self.sell(vt_symbol, order_price, sell_volume)
                if short_volume:
                    self.short(vt_symbol, order_price, short_volume)

    def write_log(self, msg: str) -> None:
        """Write a log message through the engine."""
        self.engine.write_log(msg, self)

    def get_portfolio(self, portfolio_name: str) -> PortfolioData:
        """Return a loaded portfolio; raises KeyError if it was never loaded."""
        return self.portfolios[portfolio_name]

    def load_bars(self, vt_symbol: str, days: int, interval: Interval) -> list[BarData]:
        """Load historical bars for *vt_symbol* through the engine."""
        return self.engine.load_bars(vt_symbol, days, interval)

    def init_portfolio(self, portfolio_name: str) -> bool:
        """Query the engine for a portfolio's option contracts and load them.

        Returns:
            True if the portfolio is (already or now) loaded, False if the
            engine returned no contracts.
        """
        # Already loaded: do not query the engine again.
        # NOTE(review): subscribe_options() uses the same guard, so a portfolio
        # loaded here makes a later subscribe_options() for the same name skip
        # engine.subscribe_options(). Confirm engine.init_portfolio also
        # subscribes market data.
        if portfolio_name in self.portfolios:
            return True

        contracts: list[ContractData] = self.engine.init_portfolio(self, portfolio_name)
        return self._add_portfolio(portfolio_name, contracts)

    def subscribe_options(self, portfolio_name: str) -> bool:
        """Subscribe to a portfolio's option quotes and load its contracts.

        Returns:
            True if the portfolio is (already or now) loaded, False if the
            engine returned no contracts.
        """
        # Already loaded: avoid a duplicate subscription (see the NOTE in
        # init_portfolio about the shared guard).
        if portfolio_name in self.portfolios:
            return True

        contracts: list[ContractData] = self.engine.subscribe_options(self, portfolio_name)
        return self._add_portfolio(portfolio_name, contracts)

    def _add_portfolio(self, portfolio_name: str, contracts: list[ContractData]) -> bool:
        """Build and register a PortfolioData from *contracts*; False if empty."""
        if not contracts:
            return False

        portfolio: PortfolioData = PortfolioData(contracts[0])
        self.portfolios[portfolio_name] = portfolio

        for contract in contracts:
            self.vt_symbols.add(contract.vt_symbol)
            portfolio.add_contract(contract)

        return True

    def subscribe_data(self, vt_symbol: str) -> bool:
        """Subscribe to market data of *vt_symbol* (e.g. the underlying).

        Returns:
            The engine's result; on success the symbol is added to
            ``vt_symbols``.
        """
        subscribed: bool = self.engine.subscribe_data(self, vt_symbol)
        if subscribed:
            self.vt_symbols.add(vt_symbol)
        return subscribed

    def put_event(self) -> None:
        """Push a strategy-update event through the engine once inited."""
        if self.inited:
            self.engine.put_strategy_event(self)

    def send_email(self, msg: str) -> None:
        """Send an email message through the engine once inited."""
        if self.inited:
            self.engine.send_email(msg, self)

    def sync_data(self) -> None:
        """Ask the engine to persist strategy state while trading."""
        if self.trading:
            self.engine.sync_strategy_data(self)