Source code for a2rl._dataframe

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import annotations

import warnings
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Collection, Iterable, Literal, Protocol

import gym
import numpy as np
import pandas as pd

import a2rl as wi

from ._io import Metadata, save_metadata


# TODO
def _get_constructor(klass, example: HasSarAttributes, *args, **kwargs) -> Callable:
    """Build a factory for ``klass`` that inherits the *sar* columns of ``example``.

    Args:
        klass: The class (or any callable) that the returned factory invokes.
        example: Object whose ``states``, ``actions`` and ``rewards`` are forwarded to
            ``klass`` as keyword arguments.
        args: Accepted, but not used by this function.
        kwargs: Accepted, but not used by this function.

    Returns:
        A callable that passes its own positional and keyword arguments on to ``klass``,
        together with the *sar* keyword arguments taken from ``example``.
    """

    def _constructor(*ctor_args, **ctor_kwargs):
        # The attributes are read on every call (not once up-front), so the factory always
        # forwards whatever ``example`` exposes at the time of the call.
        sar_kwargs = {
            "states": example.states,
            "actions": example.actions,
            "rewards": example.rewards,
        }
        return klass(*ctor_args, **sar_kwargs, **ctor_kwargs)

    return _constructor


# TODO
def _sar_from(o: HasSarAttributes):
    """Return the ``(states, actions, rewards)`` triple exposed by ``o``."""
    states, actions, rewards = o.states, o.actions, o.rewards
    return states, actions, rewards


def _pre_to_csv(
    o: WiDataFrame | WiSeries,
    path_or_buf: str | PathLike[str],
    forced_categories: None | list[str] = None,
    compact: bool = False,
) -> Path:
    """Create the output directory and write the dataset metadata into it.

    A warning is issued when the output path already exists. The metadata is written to
    ``metadata.yaml`` inside the output directory.

    Args:
        o: The data frame or series whose ``states``, ``actions`` and ``rewards`` are stored
            in the metadata.
        path_or_buf: The path name of the output dir.
        forced_categories: Stored as-is in the ``forced_categories`` field of the metadata.
        compact: Forwarded to ``save_metadata()``.

    Returns:
        Path: path to output directory.
    """
    outdir = path_or_buf
    if not isinstance(outdir, Path):
        outdir = Path(outdir)

    if outdir.exists():
        warnings.warn(f"Existing directory {outdir} will be overwritten.")
    outdir.mkdir(parents=True, exist_ok=True)

    metadata = Metadata(
        states=o.states,
        actions=o.actions,
        rewards=o.rewards,
        forced_categories=forced_categories,
    )
    save_metadata(metadata, outdir / "metadata.yaml", compact=compact)
    return outdir


class HasSarAttributes(Protocol):
    """Structural type for objects exposing ``states``, ``actions`` and ``rewards``.

    Each attribute is a read-only property typed as a list of column names. The bodies here
    are placeholders only and raise :exc:`NotImplementedError` when invoked.
    """

    @property
    def states(self) -> list[str]:
        """The expected column names for states."""
        raise NotImplementedError

    @property
    def actions(self) -> list[str]:
        """The expected column names for actions."""
        raise NotImplementedError

    @property
    def rewards(self) -> list[str]:
        """The expected column names for rewards."""
        raise NotImplementedError


class SarMixin:
    """Mixin exposing the expected *sar* column names kept in ``_states``, ``_actions`` and
    ``_rewards``.

    Every public accessor hands out a fresh list, so callers cannot mutate the stored names
    through the returned value.
    """

    # Declarations for mypy
    _states: list[str]  #: Expected state column names (list[str])
    _actions: list[str]  #: Expected action column names (list[str])
    _rewards: list[str]  #: Expected rewards column names (list[str])

    @property
    def sar_d(self) -> dict[str, list[str]]:
        """The dictionary of the expected *sar* column names.

        Returns:
            ``{'states': [str], 'actions': [str], 'rewards': [str]}``.

        See Also
        --------
        sar : The list of the expected *sar* columns.
        states : The expected state columns.
        actions : The expected action columns.
        rewards : The expected reward columns.
        """
        return {
            "states": list(self._states),
            "actions": list(self._actions),
            "rewards": list(self._rewards),
        }

    @property
    def sar(self) -> list[str]:
        """The flat list of the expected *sar* column names.

        Returns:
            The expected *sar* column names, in the order of states, actions, and rewards.

        See Also
        --------
        sar_d : The dictionary of the expected *sar* columns.
        states : The expected state columns.
        actions : The expected action columns.
        rewards : The expected reward columns.
        """
        return [*self.states, *self.actions, *self.rewards]

    def _set_sar(self, **kwargs) -> None:
        """Store ``list(v)`` on ``self._{k}`` for each ``k:v`` in ``kwargs``.

        Entries whose value is ``None`` are skipped, leaving the attribute untouched.
        """
        for name, columns in kwargs.items():
            if columns is None:
                continue
            setattr(self, f"_{name}", list(columns))

    @property
    def states(self) -> list[str]:
        """The list of the expected column names of states.

        Returns:
            A new list with the expected column names for states.

        See Also
        --------
        sar : The list of the expected *sar* columns.
        sar_d : The dictionary of the expected *sar* columns.
        actions : The expected action columns.
        rewards : The expected reward columns.
        """
        return list(self._states)

    @property
    def actions(self) -> list[str]:
        """The list of the expected column names of actions.

        Returns:
            A new list with the expected column names for actions.

        See Also
        --------
        sar : The list of the expected *sar* columns.
        sar_d : The dictionary of the expected *sar* columns.
        states : The expected state columns.
        rewards : The expected reward columns.
        """
        return list(self._actions)

    @property
    def rewards(self) -> list[str]:
        """The list of the expected column names of rewards.

        Returns:
            A new list with the expected column names for rewards.

        See Also
        --------
        sar : The list of the expected *sar* columns.
        sar_d : The dictionary of the expected *sar* columns.
        states : The expected state columns.
        actions : The expected action columns.
        """
        return list(self._rewards)


class WiSeries(pd.Series, SarMixin):  # type: ignore[misc]
    # Attribute names listed here are the ones pandas carries over to derived objects.
    _metadata = ["_states", "_actions", "_rewards"]

    def __init__(
        self,
        data=None,
        states: None | Collection[str] = None,
        actions: None | Collection[str] = None,
        rewards: None | Collection[str] = None,
        **kwargs,
    ) -> None:
        """A ``WiSeries`` object is a :class:`pandas.Series` with additional metadata on the
        expected column names for ``states``, ``actions``, and ``rewards`` (i.e., the *sar*
        columns).

        In addition to the standard :class:`pandas.Series` constructor arguments, a ``WiSeries``
        also accepts the following keyword arguments:

        Args:
            states: The expected column names for states.
            actions: The expected column names for actions.
            rewards: The expected column names for rewards.

        Raises:
            ValueError: When ``rewards`` contains more than two column names.

        When ``data`` already carries *sar* metadata and none of ``states``, ``actions`` and
        ``rewards`` is given, the metadata is inherited from ``data``.

        .. warning::

            This class is mainly used internally by ``whatif``. By design, the name of a
            ``WiSeries`` is equal to zero or one expected *sar* column name.

        See Also
        --------
        WiDataFrame
        pandas.Series

        Examples:
            Create a new ``WiSeries``:

            .. code-block:: python

                >>> import a2rl as wi
                >>> ser = wi.WiSeries(
                ...     [11, 12, 13],
                ...     name="s0",
                ...     states=["s1", "s2"],
                ...     actions=["a"],
                ...     rewards=["r"],
                ... )
                >>> ser
                0    11
                1    12
                2    13
                Name: s0, dtype: int64

                >>> ser.sar
                ['s1', 's2', 'a', 'r']

            Inherit *sar* columns from the source ``WiDataFrame``:

            .. code-block:: python

                >>> df = wi.WiDataFrame(
                ...     {
                ...         "s": [0, 1, 2],
                ...         "a": ["x", "y", "z"],
                ...         "r": [0.5, 1.5, 2.5],
                ...     },
                ...     states=["s"],
                ...     actions=["a"],
                ...     rewards=["r"],
                ... )
                >>> ser = df['a']
                >>> ser.sar
                ['s', 'a', 'r']
        """
        # Test against ``None`` explicitly: truth-testing a multi-element array-like (e.g. a
        # ``numpy.ndarray`` or ``pandas.Index`` of column names) raises "truth value is
        # ambiguous", although such objects are valid collections of names.
        if rewards is not None and len(rewards) > 2:
            raise ValueError(f"rewards can have at most two columns, but received {rewards}")
        super().__init__(data=data, **kwargs)

        # Same reasoning as above: identity checks instead of an element-wise list comparison.
        no_sar_given = states is None and actions is None and rewards is None
        if isinstance(data, SarMixin) and no_sar_given:
            states, actions, rewards = _sar_from(data)
        self._set_sar(states=states, actions=actions, rewards=rewards)

    @property
    def _constructor(self):
        return _get_constructor(WiSeries, self)

    @property
    def _constructor_expanddim(self):
        _c_e = _get_constructor(WiDataFrame, self)

        # See:
        # https://github.com/geopandas/geopandas/blob/51864acf3dd0bcbc74b2a922c6e012d7e57e46b5/geopandas/geoseries.py#L66-L69
        #
        # pd.concat (pandas/core/reshape/concat.py) requires this for the
        # concatenation of series since pandas 1.1
        # (https://github.com/pandas-dev/pandas/commit/f9e4c8c84bcef987973f2624cc2932394c171c8c)
        #
        # E.g., required by df.groupby().agg({'a': 'min', 'b': 'max'})
        _c_e._get_axis_number = WiDataFrame._get_axis_number
        return _c_e

    @property
    def _values(self) -> np.ndarray:
        # https://github.com/pandas-dev/pandas/issues/46554#issuecomment-1084305476
        return super()._values

    def to_csv_dataset(
        self,
        path_or_buf: str | PathLike[str],
        *args,
        forced_categories: None | Iterable[str] = None,
        compact: bool = False,
        **kwargs,
    ) -> None:
        """Save this series as a ``Whatif`` dataset.

        The output directory receives the metadata file plus ``data.csv``.

        This method has similar signatures to :meth:`pandas.Series.to_csv()`, however with some
        changes.

        Args:
            path_or_buf: Unlike :meth:`pandas.Series.to_csv()`, this accepts only path name of
                the output dir.
            args: passed to :meth:`pandas.Series.to_csv()`.
            forced_categories: Names recorded in the ``forced_categories`` field of the
                metadata. Any non-list iterable is converted to a list first.
            compact: When True, do not write ``None`` entries to the output metadata YAML.
            kwargs: passed to :meth:`pandas.Series.to_csv()`.
        """
        if not (forced_categories is None or isinstance(forced_categories, list)):
            forced_categories = list(forced_categories)
        outdir = _pre_to_csv(
            self,
            path_or_buf,
            forced_categories=forced_categories,
            compact=compact,
        )
        self.to_csv(outdir / "data.csv", *args, **kwargs)
class WiDataFrame(pd.DataFrame, SarMixin):
    # Attribute names listed here are the ones pandas carries over to derived objects.
    _metadata = ["_states", "_actions", "_rewards"]

    def __init__(
        self,
        data=None,
        states: None | Collection[str] = None,
        actions: None | Collection[str] = None,
        rewards: None | Collection[str] = None,
        **kwargs,
    ) -> None:
        """A ``WiDataFrame`` object is a :class:`pandas.DataFrame` with additional metadata on
        the expected column names for ``states``, ``actions``, and ``rewards`` (i.e., the *sar*
        columns).

        In addition to the standard :class:`pandas.DataFrame` constructor arguments, a
        ``WiDataFrame`` also accepts the following keyword arguments:

        Args:
            states: The expected column names for states.
            actions: The expected column names for actions.
            rewards: The expected column names for rewards.

        Raises:
            ValueError: When ``rewards`` contains more than two column names.

        When ``data`` already carries *sar* metadata and none of ``states``, ``actions`` and
        ``rewards`` is given, the metadata is inherited from ``data``.

        .. note::

            By design, a ``WiDataFrame`` itself may miss one or more of the *sar* columns.
            Downstream tasks should deal with missing *sar* columns. Some downstream tasks such
            as slicing ignores the discrepancy, while RL-related tasks may require all *sar*
            columns presented.

        See Also
        --------
        WiSeries
        pandas.DataFrame

        Examples
        --------
        Create a new ``WiDataFrame``:

        .. code-block:: python

            >>> import a2rl as wi
            >>> df = wi.WiDataFrame(
            ...     {
            ...         "s1": [1, 2, 3],
            ...         "s2": [3, 4, 5],
            ...         "sess": [0, 0, 0],
            ...         "z": [6, 7, 8],
            ...         "a": ["x", "y", "z"],
            ...         "r": [0.5, 1.5, 2.5],
            ...     },
            ...     states=["s1", "s2"],
            ...     actions=["a"],
            ...     rewards=["r"],
            ... )
            >>> df
               s1  s2  sess  z  a    r
            0   1   3     0  6  x  0.5
            1   2   4     0  7  y  1.5
            2   3   5     0  8  z  2.5

        Check the metadata:

        .. code-block:: python

            >>> df.sar
            ['s1', 's2', 'a', 'r']
            >>> df.sar_d
            {'states': ['s1', 's2'], 'actions': ['a'], 'rewards': ['r']}
            >>> df.states
            ['s1', 's2']
            >>> df.actions
            ['a']
            >>> df.rewards
            ['r']

        Slice the states. The resulted ``WiDataFrame`` or ``WiSeries`` inherits the expected
        *sar* columns from the source ``DataFrame``.

        .. code-block:: python

            >>> df[df.states]
               s1  s2
            0   1   3
            1   2   4
            2   3   5

            >>> df[df.states].sar
            ['s1', 's2', 'a', 'r']

        Take just the *sar* columns:

        .. code-block:: python

            >>> df.trim()
               s1  s2  a    r
            0   1   3  x  0.5
            1   2   4  y  1.5
            2   3   5  z  2.5
        """
        # Test against ``None`` explicitly: truth-testing a multi-element array-like (e.g. a
        # ``numpy.ndarray`` or ``pandas.Index`` of column names) raises "truth value is
        # ambiguous", although such objects are valid collections of names.
        if rewards is not None and len(rewards) > 2:
            raise ValueError(f"rewards can have at most two columns, but received {rewards}")
        super().__init__(data=data, **kwargs)

        # Same reasoning as above: identity checks instead of an element-wise list comparison.
        no_sar_given = states is None and actions is None and rewards is None
        if isinstance(data, SarMixin) and no_sar_given:
            states, actions, rewards = _sar_from(data)
        self._set_sar(states=states, actions=actions, rewards=rewards)

    def trim(self, copy: bool = False) -> WiDataFrame:
        """Get the *sar* columns of this data frame.

        Raise an error when any of the expected *sar* column names is missing from this data
        frame.

        Args:
            copy: True to return an explicit copy of the selected columns, False to return the
                column selection as-is.

        Returns:
            Data frame with only the *sar* columns, in the order of states, actions, and
            rewards. With ``copy=False`` the result of the column selection is returned
            directly (whether it shares memory with this data frame is decided by pandas).
        """
        view = self[self.states + self.actions + self.rewards]
        return view if not copy else view.copy()

    @property
    def _constructor(self):
        return _get_constructor(WiDataFrame, self)

    @property
    def _constructor_sliced(self):
        return _get_constructor(WiSeries, self)

    def to_csv_dataset(
        self,
        path_or_buf: str | PathLike[str],
        *args,
        forced_categories: None | Iterable[str] = None,
        compact: bool = False,
        **kwargs,
    ) -> None:
        """Save this data frame as a ``Whatif`` dataset.

        The output directory receives the metadata file plus ``data.csv``.

        This method has similar signatures to :meth:`pandas.DataFrame.to_csv()`, however with
        some changes.

        Args:
            path_or_buf: Unlike :meth:`pandas.DataFrame.to_csv()`, this accepts only path name
                of the output dir.
            args: passed to :meth:`pandas.DataFrame.to_csv()`.
            forced_categories: Names recorded in the ``forced_categories`` field of the
                metadata. Any non-list iterable is converted to a list first.
            compact: Forwarded to the metadata writer.
            kwargs: passed to :meth:`pandas.DataFrame.to_csv()`.

        See Also
        --------
        read_csv_dataset

        Example:
            Save a ``WiDataFrame`` to directory ``/tmp/my-dataset``.

            .. code-block:: python

                >>> from a2rl import WiDataFrame
                >>> df = WiDataFrame(
                ...     {
                ...         "i": [3, 4, 5],
                ...         "s": [1, 2, 3],
                ...         "j": [4, 5, 6],
                ...         "a": ["x", "y", "z"],
                ...         "k": ["z", "x", "y"],
                ...         "r": [0.5, 1.5, 2.5],
                ...     },
                ...     states=["s"],
                ...     actions=["a"],
                ...     rewards=["r"],
                ... )
                >>> df
                   i  s  j  a  k    r
                0  3  1  4  x  z  0.5
                1  4  2  5  y  x  1.5
                2  5  3  6  z  y  2.5

                >>> df.to_csv_dataset("/tmp/my-dataset")
        """
        if not (forced_categories is None or isinstance(forced_categories, list)):
            forced_categories = list(forced_categories)
        outdir = _pre_to_csv(
            self,
            path_or_buf,
            forced_categories=forced_categories,
            compact=compact,
        )
        self.to_csv(outdir / "data.csv", *args, **kwargs)

    @property
    def sequence(self) -> np.ndarray:
        """Return a 1D Numpy representation of the DataFrame, in row-major order.

        Returns:
            The sequence of the data frame.

        Example:

            .. code-block:: python

                >>> from a2rl import WiDataFrame
                >>> df = WiDataFrame(
                ...     {
                ...         "sess": [0, 0, 0],
                ...         "s": [1, 2, 3],
                ...         "a": ["x", "y", "z"],
                ...         "r": [0.5, 1.5, 2.5]
                ...     },
                ...     states=["s"],
                ...     actions=["a"],
                ...     rewards=["r"],
                ... )
                >>> df
                   sess  s  a    r
                0     0  1  x  0.5
                1     0  2  y  1.5
                2     0  3  z  2.5

                >>> df.sequence
                array([0, 1, 'x', 0.5, 0, 2, 'y', 1.5, 0, 3, 'z', 2.5], dtype=object)
        """
        return self.values.ravel()

    def _check_add_value_args(
        self,
        value_col: str,
        override: Literal["replace", "warn", "error"],
        alpha: float,
        gamma: float,
    ) -> None:
        """Validate the arguments shared by the ``add_value*()`` methods.

        Raises:
            ValueError: When no reward column is registered, when ``value_col`` equals the
                first reward column, when ``value_col`` already exists and ``override`` is
                ``"error"`` or not a known value, or when ``alpha`` or ``gamma`` is outside
                ``[0, 1]``.
        """
        if not getattr(self, "_rewards", None):
            raise ValueError(f"Unspecified reward column: {getattr(self, '_rewards', None)}")

        if value_col == self._rewards[0]:
            raise ValueError(f"value_col={value_col} conflicts with reward")

        # ``override`` is only consulted when the column is already present.
        if value_col in self.columns:
            if override == "error":
                raise ValueError(f"Column {value_col} already exists in this WiDataFrame")
            elif override == "warn":
                warnings.warn(f"Column {value_col} will be overwritten")
            elif override != "replace":
                raise ValueError(f"Unknown override: {override}")

        if not (0.0 <= alpha <= 1.0):
            raise ValueError(f"Learning rate alpha={alpha} not in 0 and 1.")

        if not (0.0 <= gamma <= 1.0):
            raise ValueError(f"Discount factor gamma={gamma} not in 0 and 1.")

    def add_value(
        self: WiDataFrame,
        alpha: float = 0.1,
        gamma: float = 0.6,
        sarsa: bool = True,
        value_col: str = "value",
        override: Literal["replace", "warn", "error"] = "replace",
    ) -> WiDataFrame:
        """Append column ``value_col`` into this dataframe (restriction: ``df`` must NOT contain
        column names ``_state``, ``_action``, ``_reward``, and the ``value_col``).

        The column is also registered as the second reward column of this dataframe.

        Args:
            alpha: Learning rate in `Q-Learning <https://en.wikipedia.org/wiki/Q-learning>`_
                and `SARSA <https://en.wikipedia.org/wiki/State-action-reward-state-action>`_.
                Must be within 0 and 1.
            gamma: Discount factor of future reward in Q-Learning and SARSA. Must be within 0
                and 1.
            sarsa: When ``True``, compute the value using the `SARSA Bellman equation
                <https://en.wikipedia.org/wiki/State-action-reward-state-action>`_ which is a
                conservative on-policy temporal difference update. When ``False``, use the
                `Q-Learning Bellman equation <https://en.wikipedia.org/wiki/Q-learning>`_ which
                is an off-policy temporal difference update.
            value_col: The column name for the computed values.
            override: What to do when this dataframe has had column ``value_col``. Valid values
                are ``replace`` to silently override, ``warn`` to show a warning, and ``error``
                to raise a :exc:`ValueError`.

        Returns:
            This dataframe, modified with an additional ``value_col`` column. This return value
            is provided to facilitate chaining as-per the functional programming style.

        Raises:
            ValueError: When the arguments fail validation.
        """
        self._check_add_value_args(value_col, override, alpha, gamma)

        # Only the first reward column takes part in the computation.
        if len(self._rewards) == 1:
            df = self
        else:
            df = WiDataFrame(
                self,
                states=self.states,
                actions=self.actions,
                rewards=self.rewards[:1],
            )
        df_t = wi.DiscreteTokenizer(n_bins=50).fit_transform(df.trim())

        # Temp df with only three columns: _state, _action, _reward
        df_t = pd.concat(  # type: ignore[assignment]
            [
                df_t[df_t.states].astype(str).apply("_".join, axis=1).astype("category").cat.codes,
                df_t[df_t.actions].astype(str).apply("_".join, axis=1).astype("category").cat.codes,
                df[df.rewards].reset_index(drop=True),
            ],
            axis=1,
            copy=False,
        )
        df_t.columns = ["_state", "_action", "_reward"]  # type: ignore[assignment]

        # One row per distinct state code, one column per distinct action code.
        q_table = np.zeros([df_t["_state"].nunique(), df_t["_action"].nunique()])
        iterations = 10
        for n in range(iterations):
            for i in range(0, len(df_t) - 1):
                state = int(df_t.loc[i, "_state"])
                next_state = int(df_t.loc[i + 1, "_state"])
                action = int(df_t.loc[i, "_action"])
                reward = df_t.loc[i + 1, "_reward"]
                old_value = q_table[state, action]

                if sarsa:
                    # NOTE(review): indexing with the argmax yields the row maximum, so this
                    # branch is numerically the same as the Q-Learning branch below. Confirm
                    # whether the action actually taken at row i + 1 was intended here.
                    next_value = q_table[next_state, np.argmax(q_table[next_state])]
                    new_value = old_value + alpha * (reward + gamma * next_value - old_value)
                else:
                    next_max = np.max(q_table[next_state])
                    new_value = old_value + alpha * (reward + gamma * next_max - old_value)

                q_table[state, action] = new_value

        self[value_col] = pd.Series(
            q_table[df_t["_state"].astype(int), df_t["_action"].astype(int)],
            index=self.index,
        )
        if len(self._rewards) < 2:
            self._rewards.append(value_col)
        else:
            self._rewards[1] = value_col

        return self

    def add_value_for_multi_episode_process(
        self,
        sarsa: bool = True,
        alpha: float = 0.1,
        gamma: float = 0.6,
        value_col: str = "value",
        episode_identifier: str = "episode",
        override: Literal["replace", "warn", "error"] = "replace",
    ) -> WiDataFrame:
        """Append column ``value_col`` into this dataframe (restriction: ``df`` must NOT contain
        column names ``_state``, ``_action``, ``_reward``, and the ``value_col``).

        The values are computed separately for each group of ``episode_identifier``, and the
        column is also registered as the second reward column of this dataframe.

        Args:
            sarsa: When ``True``, compute the value using the `SARSA Bellman equation
                <https://en.wikipedia.org/wiki/State-action-reward-state-action>`_ which is a
                conservative on-policy temporal difference update. When ``False``, use the
                `Q-Learning Bellman equation <https://en.wikipedia.org/wiki/Q-learning>`_ which
                is an off-policy temporal difference update.
            alpha: Learning rate in `Q-Learning <https://en.wikipedia.org/wiki/Q-learning>`_
                and `SARSA <https://en.wikipedia.org/wiki/State-action-reward-state-action>`_.
                Must be within 0 and 1.
            gamma: Discount factor of future reward in Q-Learning and SARSA. Must be within 0
                and 1.
            value_col: The column name for the computed values.
            episode_identifier: group-by key in this dataframe. Ensure that breaks BETWEEN
                episodes are tagged with a ``0`` group name.
            override: What to do when this dataframe has had column ``value_col``. Valid values
                are ``replace`` to silently override, ``warn`` to show a warning, and ``error``
                to raise a :exc:`ValueError`.

        Returns:
            This dataframe, modified with an additional ``value_col`` column. This return value
            is provided to facilitate chaining as-per the functional programming style.

        Raises:
            ValueError: When the arguments fail validation.
        """
        self._check_add_value_args(value_col, override, alpha, gamma)

        # Only the first reward column takes part in the computation.
        if len(self._rewards) == 1:
            df = self
        else:
            df = WiDataFrame(
                self,
                states=self.states,
                actions=self.actions,
                rewards=self.rewards[:1],
            )
        df_t = wi.DiscreteTokenizer(n_bins=50).fit_transform(df.trim())

        def calc_q_values(sub, df, x, alpha, gamma, sarsa=sarsa):
            # print(x.name)
            # Rows of the group named 0 get a constant value of 0.
            if sub.name == 0:
                return pd.Series(0, index=sub.index)

            # Restrict both the raw and the tokenized frame to the rows of this group.
            df = df.loc[sub.index]
            x = x.loc[sub.index]
            # Temp df with only three columns: _state, _action, _reward
            df_t = pd.concat(  # type: ignore[assignment]
                [
                    x[x.states].astype(str).apply("_".join, axis=1).astype("category").cat.codes,
                    x[x.actions].astype(str).apply("_".join, axis=1).astype("category").cat.codes,
                    df[df.rewards],  # .reset_index(drop=False),
                ],
                axis=1,
                copy=False,
            )
            df_t.columns = ["_state", "_action", "_reward"]  # type: ignore[assignment]
            q_table = np.zeros([df_t["_state"].nunique(), df_t["_action"].nunique()])
            iterations = 10
            for n in range(iterations):
                # NOTE(review): the label arithmetic (i + 1) assumes the group's index labels
                # are consecutive integers -- confirm against callers.
                for i in range(df_t.index[0], df_t.index[-1]):
                    state = int(df_t.loc[i, "_state"])
                    next_state = int(df_t.loc[i + 1, "_state"])
                    action = int(df_t.loc[i, "_action"])
                    reward = df_t.loc[i + 1, "_reward"]
                    old_value = q_table[state, action]

                    if sarsa:
                        # NOTE(review): indexing with the argmax yields the row maximum, so
                        # this branch is numerically the same as the Q-Learning branch below.
                        next_value = q_table[next_state, np.argmax(q_table[next_state])]
                        new_value = old_value + alpha * (reward + gamma * next_value - old_value)
                    else:
                        next_max = np.max(q_table[next_state])
                        new_value = old_value + alpha * (reward + gamma * next_max - old_value)

                    q_table[state, action] = new_value
            return pd.Series(
                q_table[df_t["_state"].astype(int), df_t["_action"].astype(int)],
                index=df.index,
            )

        x = df.groupby(episode_identifier).apply(lambda x: calc_q_values(x, df, df_t, alpha, gamma))
        # Drop the group level again and restore the original row order.
        x2 = x.reset_index(level=0)
        x2.rename({0: value_col}, axis=1, inplace=True)
        x2.drop([episode_identifier], axis=1, inplace=True)
        x2 = x2.sort_index()
        self[value_col] = x2
        if len(self._rewards) < 2:
            self._rewards.append(value_col)
        else:
            self._rewards[1] = value_col

        return self
class TransitionRecorder(gym.Wrapper[Any, np.ndarray]):
    """Record the transitions in the OpenAI gym :class:`gym.Env` into a Whatif data frame.

    Each recorded step appends one row to :attr:`df`, made of the flattened previous
    observation, the flattened action, and the reward.

    Args:
        env: a gym environment.
        recording: When `True`, immediately start capturing steps. When `False`, callers need
            to call :meth:`~TransitionRecorder.start()` to start capturing steps.

    Examples
    --------

    .. code-block:: python

        >>> import gym
        >>> import a2rl as wi
        >>> def do_steps(env):
        ...     env.reset()
        ...     for _ in range(5):
        ...         env.step(0)

        >>> env = wi.TransitionRecorder(env=gym.make("Taxi-v3"))
        >>> do_steps(env)
        >>> env.df.info()  # doctest: +NORMALIZE_WHITESPACE
        <class 'a2rl._dataframe.WiDataFrame'>
        Int64Index: 5 entries, 0 to 0
        Data columns (total 3 columns):
         #   Column  Non-Null Count  Dtype
        ---  ------  --------------  -----
         0   0       5 non-null      float64
         1   1       5 non-null      float64
         2   2       5 non-null      float64
        dtypes: float64(3)
        memory usage: ...

        >>> env.stop()
        >>> do_steps(env)
        >>> env.df.info()  # doctest: +NORMALIZE_WHITESPACE
        <class 'a2rl._dataframe.WiDataFrame'>
        Int64Index: 5 entries, 0 to 0
        Data columns (total 3 columns):
         #   Column  Non-Null Count  Dtype
        ---  ------  --------------  -----
         0   0       5 non-null      float64
         1   1       5 non-null      float64
         2   2       5 non-null      float64
        dtypes: float64(3)
        memory usage: ...

        >>> env.start();
        >>> do_steps(env)
        >>> env.df.info()  # doctest: +NORMALIZE_WHITESPACE
        <class 'a2rl._dataframe.WiDataFrame'>
        Int64Index: 10 entries, 0 to 0
        Data columns (total 3 columns):
         #   Column  Non-Null Count  Dtype
        ---  ------  --------------  -----
         0   0       10 non-null     float64
         1   1       10 non-null     float64
         2   2       10 non-null     float64
        dtypes: float64(3)
        memory usage: ...
    """

    def __init__(self, env: gym.Env, recording: bool = True):
        super().__init__(env)
        self.env = env
        self.recording = recording
        self.episode = 0

        # Sample once from each space (observation first, then action) purely to learn how
        # many flattened values an observation and an action occupy.
        n_state = np.array(env.observation_space.sample()).ravel().size
        n_action = np.array(env.action_space.sample()).ravel().size

        # Columns are laid out as: state values, then action values, then a single reward.
        self.sar_d = {
            "states": np.arange(n_state),
            "actions": np.arange(n_action) + n_state,
            "rewards": [n_action + n_state],
        }
        empty = pd.DataFrame(columns=np.arange(n_state + n_action + 1), dtype="float")
        self.df = WiDataFrame(empty, **self.sar_d)

        # Declared only; the value is assigned by reset() and step().
        self._state: Any

    def start(self) -> None:
        """Start capturing subsequent steps."""
        self.recording = True

    def stop(self) -> None:
        """Stop capturing steps."""
        self.recording = False

    def step(self, action: np.ndarray) -> tuple[Any, float, bool, dict]:
        """Wrapper to :func:`gym.Wrapper.step()` which records one timestep of the
        environment's dynamics.

        Args:
            action (object): an action provided by the agent

        Returns:
            The tuple ``(next_state, reward, done, info)``, where the first three items are
            the first three items of the wrapped environment's result, and ``info`` is its
            last item.
        """
        # See: https://github.com/openai/gym/commit/907b1b20dd9ac0cba5803225059b9c6673702467
        # - gym<0.24.0:  step_results = (next_state, reward, done, info)
        # - gym>=0.25.0: step_results = (next_state, reward, done, _, info)
        outcome = self.env.step(action)
        next_state, reward, done = outcome[:3]
        info = outcome[-1]

        if self.recording:
            # One flat row: previous observation, action, reward.
            row = np.hstack(
                [
                    np.array(self._state).ravel(),
                    np.array(action).ravel(),
                    np.array(reward),
                ]
            ).reshape(1, -1)
            transition = WiDataFrame(row, columns=list(self.df), **self.sar_d)
            self.df = pd.concat([self.df, transition])  # type: ignore[assignment]

        self._state = next_state
        return next_state, reward, done, info

    def reset(self, **kwargs) -> tuple[gym.core.ObsType, dict] | gym.core.ObsType:
        """Wrapper to :func:`gym.Wrapper.reset()` which resets the environment to an initial
        state and returns an initial observation.

        Returns:
            observation: Observation of the initial state.
            info (optional dictionary): returned only when ``return_info=True``.
        """
        self.episode += 1
        reset_result = self.env.reset(**kwargs)

        # With ``return_info=True`` the observation is the first item of the result.
        wants_info = kwargs.get("return_info", False)
        self._state = reset_result[0] if wants_info else reset_result
        return reset_result