
11.15 Analyzing the FinRL framework

董霖
2023-12-01

1. Creating the agent

agent = DRLAgent(env = env_train)
SAC_PARAMS = {
    "batch_size": 128,
    "buffer_size": 1000000,
    "learning_rate": 0.0001,
    "learning_starts": 100,
    "ent_coef": "auto_0.1",
    "device": "cpu", 
}

model_sac = agent.get_model("sac", model_kwargs=SAC_PARAMS)


from stable_baselines3.common.logger import configure

if if_using_sac:
  # set up a logger that writes to stdout, CSV and TensorBoard under RESULTS_DIR/sac
  tmp_path = RESULTS_DIR + '/sac'
  new_logger_sac = configure(tmp_path, ["stdout", "csv", "tensorboard"])
  # attach the new logger to the model
  model_sac.set_logger(new_logger_sac)

agent is an instance of the DRLAgent class, SAC_PARAMS supplies the hyperparameters, and model_sac is returned by DRLAgent's get_model function.

2. The DRLAgent class loads the corresponding SAC model from the stable_baselines3 package

from stable_baselines3 import SAC

DRLAgent(env) takes the training environment env as its argument.
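The constructor itself is tiny; a minimal sketch of the idea (the real FinRL class also defines get_model, train_model and prediction helpers):

# Minimal sketch of the DRLAgent wrapper (simplified, not the library's exact code).
class DRLAgentSketch:
    def __init__(self, env):
        # The vectorized training environment is kept on the instance so that
        # get_model can later hand it to the stable_baselines3 constructor.
        self.env = env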

3. The get_model function

def get_model(
        self,
        model_name,
        policy="MlpPolicy",
        policy_kwargs=None,
        model_kwargs=None,
        verbose=1,
        seed=None,
        tensorboard_log=None,
    ):
        if model_name not in MODELS:
            raise NotImplementedError("NotImplementedError")

        if model_kwargs is None:
            model_kwargs = MODEL_KWARGS[model_name]

        if "action_noise" in model_kwargs:
            n_actions = self.env.action_space.shape[-1]
            model_kwargs["action_noise"] = NOISE[model_kwargs["action_noise"]](
                mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)
            )
        print(model_kwargs)
        return MODELS[model_name](
            policy=policy,
            env=self.env,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            policy_kwargs=policy_kwargs,
            seed=seed,
            **model_kwargs,
        )

The environment env stored on the agent is passed into the SAC model here.
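get_model above relies on three module-level lookup tables. A sketch of their shape, based on the FinRL source (the exact default hyperparameters vary by version):

from stable_baselines3 import A2C, DDPG, PPO, SAC, TD3
from stable_baselines3.common.noise import (
    NormalActionNoise,
    OrnsteinUhlenbeckActionNoise,
)

# Name -> algorithm class used by MODELS[model_name](...) in get_model.
MODELS = {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO}

# Fallback hyperparameters used when model_kwargs is None (values illustrative).
MODEL_KWARGS = {"sac": {"batch_size": 128, "buffer_size": 100000, "learning_rate": 0.0001}}

# Name -> action-noise class, instantiated inside get_model for off-policy models.
NOISE = {
    "normal": NormalActionNoise,
    "ornstein_uhlenbeck": OrnsteinUhlenbeckActionNoise,
}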

4. Training the model

trained_sac = agent.train_model(model=model_sac,
                                tb_log_name='sac',
                                total_timesteps=30000) if if_using_sac else None
def train_model(self, model, tb_log_name, total_timesteps=5000):
        model = model.learn(
            total_timesteps=total_timesteps,
            tb_log_name=tb_log_name,
            callback=TensorboardCallback(),
        )
        return model
class SAC(OffPolicyAlgorithm):
    def learn(
        self: SACSelf,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "SAC",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SACSelf:

        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )
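model.learn ultimately runs OffPolicyAlgorithm.learn. The TensorboardCallback passed by train_model is a small stable_baselines3 callback; a sketch of the idea (not FinRL's exact class) that logs the per-step reward to the logger configured earlier:

from stable_baselines3.common.callbacks import BaseCallback

# Sketch of a reward-logging callback in the spirit of FinRL's TensorboardCallback:
# on every environment step, write the latest reward to the logger attached with
# configure(tmp_path, ["stdout", "csv", "tensorboard"]).
class RewardLoggingCallback(BaseCallback):
    def _on_step(self) -> bool:
        rewards = self.locals.get("rewards")
        if rewards is not None:
            self.logger.record("train/reward", float(rewards[0]))
        return True  # returning False would stop training early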

5. The environment env

from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
env_kwargs = {
    "hmax": 100,
    "initial_amount": 1000000,
    "buy_cost_pct": 0.001,
    "sell_cost_pct": 0.001,
    "state_space": state_space,  # 511 = 1 + 2 * 30 + 15 * 30
    "stock_dim": stock_dimension,  # 30
    "tech_indicator_list": ratio_list,  # list of 15 technical indicators
    "action_space": stock_dimension,  # 30
    "reward_scaling": 1e-4,
}
e_train_gym = StockTradingEnv(df=train_data, **env_kwargs)
env_train, _ = e_train_gym.get_sb_env()
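get_sb_env wraps the single Gym environment in a stable_baselines3 DummyVecEnv (batch size 1) and resets it; a simplified sketch of what it does:

from stable_baselines3.common.vec_env import DummyVecEnv

# Simplified sketch of StockTradingEnv.get_sb_env: wrap this env in a DummyVecEnv
# and return the vectorized env together with the first observation.
def get_sb_env(self):
    vec_env = DummyVecEnv([lambda: self])
    obs = vec_env.reset()
    return vec_env, obs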

# This model scans 30 stocks.
# Preprocessing the data produces 15 technical-indicator factors.
# state space = 1 + 2 * 30 + 15 * 30 = 511  (the dimension of the state space)
# action space = 30
# The model is trained on this environment;
# decisions are made with actions = agent.predict(state),
# and the 30 stocks are then traded simultaneously through the broker API.
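The state_space value used in env_kwargs follows directly from these counts; a quick check:

stock_dimension = 30   # number of tickers
num_indicators = 15    # number of technical indicators per ticker

# 1 cash balance + (close price + shares held) per stock + indicators per stock
state_space = 1 + 2 * stock_dimension + num_indicators * stock_dimension
print(state_space)  # 511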
class StockTradingEnv(gym.Env):
    def __init__(self):
        ''' parameters omitted '''
        # initialize the state
        self.state = self._initiate_state()

        # initialize the reward and bookkeeping counters
        self.reward = 0
        self.turbulence = 0
        self.cost = 0
        self.trades = 0
        self.episode = 0
        # memorize all the changes in total balance over the episode
        self.asset_memory = [self.initial_amount]
        self.rewards_memory = []
        self.actions_memory = []
        self.date_memory = [self._get_date()]
        # self.reset()
        self._seed()
    def _sell_stock(self, index, action): ...
    def _buy_stock(self, index, action): ...
    # the key method
    def step(self, actions):
        # if not terminal: first turn the actions into share counts
        actions = actions * self.hmax  # actions from the policy are scaled to [-1, 1]
        actions = actions.astype(int)
        # update the state
        self.state = self._update_state()
        # compute begin_total_asset and end_total_asset
        # and store some bookkeeping information
        self.asset_memory.append(end_total_asset)
        self.date_memory.append(self._get_date())
        self.reward = end_total_asset - begin_total_asset  # reward for this step
        self.rewards_memory.append(self.reward)
        self.reward = self.reward * self.reward_scaling
        return self.state, self.reward, self.terminal, {}
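Inside step, between scaling the actions and updating the state, the scaled share counts are dispatched to _sell_stock and _buy_stock. A simplified sketch of that dispatch (the real method also checks the turbulence threshold before trading):

import numpy as np

# Simplified sketch of how the scaled actions are routed to _sell_stock/_buy_stock.
def dispatch_actions(env, actions):
    argsort_actions = np.argsort(actions)                           # most negative first
    sell_index = argsort_actions[: np.where(actions < 0)[0].shape[0]]
    buy_index = argsort_actions[::-1][: np.where(actions > 0)[0].shape[0]]
    for index in sell_index:
        env._sell_stock(index, actions[index])                      # sells first to free up cash
    for index in buy_index:
        env._buy_stock(index, actions[index])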


# 1. Key point: look at what the input state that the agent learns from actually looks like
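A rough sketch of how that state vector is laid out, based on _initiate_state and assuming 30 stocks and 15 indicators (the real method also handles restarting from a previous state):

import numpy as np

# Rough layout of the 511-dimensional state for one trading day:
#   [cash] + [close price of each stock] + [shares held of each stock]
#   + [indicator_1 of each stock] + ... + [indicator_15 of each stock]
def initiate_state_sketch(cash, close_prices, holdings, indicator_table):
    # close_prices, holdings: length-30 arrays; indicator_table: 15 x 30 array
    return np.concatenate(
        [[cash], close_prices, holdings, indicator_table.flatten()]
    )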

# 2. An introduction to the SAC model
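As a starting point for that note: SAC maximizes an entropy-regularized return, and the "ent_coef": "auto_0.1" setting in SAC_PARAMS tells stable_baselines3 to tune the entropy temperature automatically from an initial value of 0.1. A toy illustration of the objective (not the library's implementation; the numbers are made up):

import numpy as np

# SAC trains the policy to maximize E[reward + alpha * entropy(policy)],
# so a larger alpha (ent_coef) keeps the policy more exploratory.
def soft_objective(rewards, policy_entropies, alpha=0.1):
    rewards = np.asarray(rewards, dtype=float)
    policy_entropies = np.asarray(policy_entropies, dtype=float)
    return float(np.mean(rewards + alpha * policy_entropies))

print(soft_objective([1.0, 0.5, -0.2], [0.7, 0.9, 1.1]))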

pandas sort_values

DataFrame.sort_values(by,
               axis=0,
               ascending=True,
               inplace=False,
               kind='quicksort',
               na_position='last',  # 'last' or 'first'; default is 'last'
               ignore_index=False,
               key=None)
by: the column(s) or index level(s) to sort by; one or several may be given
axis: whether to sort along rows or columns; default is axis=0 (sort rows)
ascending: whether the result is in ascending or descending order; default is ascending
inplace: whether to modify the original data in place or return a new DataFrame
kind: the sorting algorithm: quicksort, mergesort, heapsort or stable; default is quicksort
na_position: where missing values are placed; default is 'last', the alternative is 'first'
ignore_index: whether to rebuild the index of the resulting frame; default is False (keep the original index)
key: a function applied to the values before sorting
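For example, FinRL's preprocessed price tables are usually kept sorted by date and then ticker; a small self-contained illustration (the values are made up):

import pandas as pd

# Toy frame in the date/tic layout FinRL uses.
df = pd.DataFrame({
    "date": ["2023-01-02", "2023-01-01", "2023-01-01", "2023-01-02"],
    "tic":  ["AAPL", "MSFT", "AAPL", "MSFT"],
    "close": [125.0, 239.0, 124.0, 240.0],
})

# Sort by date first, then ticker, and rebuild a clean 0..n-1 index.
df = df.sort_values(by=["date", "tic"], ascending=True, ignore_index=True)
print(df)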

# Try using 30-minute S&P futures data as the training set and see how it does
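If that experiment is tried, the futures prints would first need to be reshaped into the date/tic layout the environment expects. A rough sketch, assuming a local CSV with timestamp and price columns (the file name and column names are hypothetical):

import pandas as pd

# Hypothetical input: one row per tick with 'timestamp' and 'price' columns.
raw = pd.read_csv("es_futures_ticks.csv", parse_dates=["timestamp"])

# Resample to 30-minute OHLC bars and reshape into the date/tic layout
# used by StockTradingEnv-style preprocessing.
bars = (
    raw.set_index("timestamp")["price"]
    .resample("30min")
    .ohlc()
    .dropna()
    .reset_index()
    .rename(columns={"timestamp": "date"})
)
bars["tic"] = "ES"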

# Download TA-Lib from https://www.lfd.uci.edu/~gohlke/pythonlibs/
