# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import annotations
import warnings
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Collection, Iterable, Literal, Protocol
import gym
import numpy as np
import pandas as pd
import a2rl as wi
from ._io import Metadata, save_metadata
# Return a constructor that forwards the *sar* columns of ``example`` to the new
# object; pandas calls this when it derives a new frame or series (e.g., slicing).
def _get_constructor(klass, example: HasSarAttributes, *args, **kwargs) -> Callable:
def _constructor(*args, **kwargs):
return klass(
*args,
states=example.states,
actions=example.actions,
rewards=example.rewards,
**kwargs,
)
return _constructor
# Return the (states, actions, rewards) column names of ``o``.
def _sar_from(o: HasSarAttributes):
return o.states, o.actions, o.rewards
def _pre_to_csv(
o: WiDataFrame | WiSeries,
path_or_buf: str | PathLike[str],
forced_categories: None | list[str] = None,
compact: bool = False,
) -> Path:
"""Create output directory and metadata.
Args:
path_or_buf: The path name of the output dir.
Returns:
Path: path to output directory.
"""
p = path_or_buf if isinstance(path_or_buf, Path) else Path(path_or_buf)
if p.exists():
warnings.warn(f"Existing directory {p} will be overwritten.")
p.mkdir(parents=True, exist_ok=True)
save_metadata(
Metadata(
states=o.states,
actions=o.actions,
rewards=o.rewards,
forced_categories=forced_categories,
),
p / "metadata.yaml",
compact=compact,
)
return p
class HasSarAttributes(Protocol):
@property
def states(self) -> list[str]:
raise NotImplementedError
@property
def actions(self) -> list[str]:
raise NotImplementedError
@property
def rewards(self) -> list[str]:
raise NotImplementedError
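
# Both WiDataFrame and WiSeries satisfy this protocol (via SarMixin), as does any
# object that exposes ``states``, ``actions``, and ``rewards`` list-of-str properties.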
class SarMixin:
# Declarations for mypy
    _states: list[str]  #: Expected state column names (list[str])
_actions: list[str] #: Expected action column names (list[str])
_rewards: list[str] #: Expected rewards column names (list[str])
@property
def sar_d(self) -> dict[str, list[str]]:
"""The dictionary of 585 expected *sar* column names.
Returns:
``{'states': [str], 'actions': [str], 'rewards': [str]}``.
See Also
--------
sar : The list of the expected *sar* columns.
states : The expected state columns.
actions : The expected action columns.
rewards : The expected reward columns.
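
        Example:
            A small illustration with a toy dataframe:

            .. code-block:: python

                >>> import a2rl as wi
                >>> df = wi.WiDataFrame(
                ...     {"s": [1, 2], "a": ["x", "y"], "r": [0.5, 1.5]},
                ...     states=["s"],
                ...     actions=["a"],
                ...     rewards=["r"],
                ... )
                >>> df.sar_d
                {'states': ['s'], 'actions': ['a'], 'rewards': ['r']}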
"""
return dict(
states=list(self._states),
actions=list(self._actions),
rewards=list(self._rewards),
)
@property
def sar(self) -> list[str]:
"""The list of the expected *sar* column names.
Returns:
            The expected *sar* column names, in the order of states, actions, and rewards.
See Also
--------
sar_d : The dictionary of the expected *sar* columns.
states : The expected state columns.
actions : The expected action columns.
rewards : The expected reward columns.
"""
return self.states + self.actions + self.rewards
def _set_sar(self, **kwargs) -> None:
"""Set ``self._{k} = {v}`` for each ``k:v`` in ``kwargs``."""
for k, v in kwargs.items():
if v is not None:
setattr(self, f"_{k}", list(v))
@property
def states(self) -> list[str]:
"""The list of the expected column names of states.
Returns:
The expected column names for states.
See Also
--------
sar : The list of the expected *sar* columns.
sar_d : The dictionary of the expected *sar* columns.
actions : The expected action columns.
rewards : The expected reward columns.
"""
return list(self._states)
@property
def actions(self) -> list[str]:
"""The list of the expected column names of actions.
Returns:
The expected column names for actions.
See Also
--------
sar : The list of the expected *sar* columns.
sar_d : The dictionary of the expected *sar* columns.
states : The expected state columns.
rewards : The expected reward columns.
"""
return list(self._actions)
@property
def rewards(self) -> list[str]:
"""The list of the expected column names of rewards.
Returns:
The expected column names for rewards.
See Also
--------
sar : The list of the expected *sar* columns.
sar_d : The dictionary of the expected *sar* columns.
states : The expected state columns.
actions : The expected action columns.
"""
return list(self._rewards)
class WiSeries(pd.Series, SarMixin):  # type: ignore[misc]
_metadata = ["_states", "_actions", "_rewards"]
def __init__(
self,
data=None,
states: None | Collection[str] = None,
actions: None | Collection[str] = None,
rewards: None | Collection[str] = None,
**kwargs,
) -> None:
"""A ``WiSeries`` object is a :class:`pandas.Series` with additional metadata on the
expected column names for ``states``, ``actions``, and ``rewards`` (i.e., the *sar*
columns).
In addition to the standard :class:`pandas.Series` constructor arguments, a ``WiSeries``
also accepts the following keyword arguments:
Args:
states: The expected column names for states.
actions: The expected column names for actions.
rewards: The expected column names for rewards.
.. warning::
This class is mainly used internally by ``whatif``.
        By design, the name of a ``WiSeries`` matches either zero or one of the expected *sar*
        column names.
See Also
--------
WiDataFrame
pandas.Series
Examples:
Create a new ``WiSeries``:
.. code-block:: python
>>> import a2rl as wi
>>> ser = wi.WiSeries(
... [11, 12, 13],
... name="s0",
... states=["s1", "s2"],
... actions=["a"],
... rewards=["r"],
... )
>>> ser
0 11
1 12
2 13
Name: s0, dtype: int64
>>> ser.sar
['s1', 's2', 'a', 'r']
Inherit *sar* columns from the source ``WiDataFrame``:
.. code-block:: python
>>> df = wi.WiDataFrame(
... {
... "s": [0, 1, 2],
... "a": ["x", "y", "z"],
... "r": [0.5, 1.5, 2.5],
... },
... states=["s"],
... actions=["a"],
... rewards=["r"],
... )
>>> ser = df['a']
>>> ser.sar
['s', 'a', 'r']
"""
if rewards and len(rewards) > 2:
raise ValueError(f"rewards can have at most two columns, but received {rewards}")
super().__init__(data=data, **kwargs)
if isinstance(data, SarMixin) and [states, actions, rewards] == [None] * 3:
states, actions, rewards = _sar_from(data)
self._set_sar(states=states, actions=actions, rewards=rewards)
@property
def _constructor(self):
return _get_constructor(WiSeries, self)
@property
def _constructor_expanddim(self):
_c_e = _get_constructor(WiDataFrame, self)
# See:
# https://github.com/geopandas/geopandas/blob/51864acf3dd0bcbc74b2a922c6e012d7e57e46b5/geopandas/geoseries.py#L66-L69
#
# pd.concat (pandas/core/reshape/concat.py) requires this for the
# concatenation of series since pandas 1.1
# (https://github.com/pandas-dev/pandas/commit/f9e4c8c84bcef987973f2624cc2932394c171c8c)
#
# E.g., required by df.groupby().agg({'a': 'min', 'b': 'max'})
_c_e._get_axis_number = WiDataFrame._get_axis_number
return _c_e
@property
def _values(self) -> np.ndarray:
# https://github.com/pandas-dev/pandas/issues/46554#issuecomment-1084305476
return super()._values
    def to_csv_dataset(
self,
path_or_buf: str | PathLike[str],
*args,
forced_categories: None | Iterable[str] = None,
compact: bool = False,
**kwargs,
) -> None:
"""Save this series as a ``Whatif`` dataset.
        This method has a similar signature to :meth:`pandas.Series.to_csv()`, but with a few
        changes.

        Args:
            path_or_buf: Unlike :meth:`pandas.Series.to_csv()`, this accepts only the path name
                of the output dir.
            args: Passed to :meth:`pandas.Series.to_csv()`.
            forced_categories: Column names to always treat as categorical, recorded in the
                output metadata.
            compact: When True, do not write ``None`` entries to the output metadata YAML.
            kwargs: Passed to :meth:`pandas.Series.to_csv()`.
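
        Example:
            Save the reward series of a ``WiDataFrame`` to directory ``/tmp/my-series-dataset``
            (an illustrative path):

            .. code-block:: python

                >>> import a2rl as wi
                >>> df = wi.WiDataFrame(
                ...     {"s": [1, 2], "a": ["x", "y"], "r": [0.5, 1.5]},
                ...     states=["s"],
                ...     actions=["a"],
                ...     rewards=["r"],
                ... )
                >>> df["r"].to_csv_dataset("/tmp/my-series-dataset")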
"""
if not (forced_categories is None or isinstance(forced_categories, list)):
forced_categories = list(forced_categories)
outdir = _pre_to_csv(
self,
path_or_buf,
forced_categories=forced_categories,
compact=compact,
)
self.to_csv(outdir / "data.csv", *args, **kwargs)
class WiDataFrame(pd.DataFrame, SarMixin):
_metadata = ["_states", "_actions", "_rewards"]
def __init__(
self,
data=None,
states: None | Collection[str] = None,
actions: None | Collection[str] = None,
rewards: None | Collection[str] = None,
**kwargs,
) -> None:
"""A ``WiDataFrame`` object is a :class:`pandas.DataFrame` with additional metadata on the
expected column names for ``states``, ``actions``, and ``rewards`` (i.e., the *sar*
columns).
In addition to the standard :class:`pandas.DataFrame` constructor arguments, a
``WiDataFrame`` also accepts the following keyword arguments:
Args:
states: The expected column names for states.
actions: The expected column names for actions.
rewards: The expected column names for rewards.
.. note::
By design, a ``WiDataFrame`` itself may miss one or more of the *sar* columns.
Downstream tasks should deal with missing *sar* columns.
            Some downstream tasks such as slicing ignore the discrepancy, while RL-related tasks
            may require all *sar* columns to be present.
See Also
--------
WiSeries
pandas.DataFrame
Examples
--------
Create a new ``WiDataFrame``:
.. code-block:: python
>>> import a2rl as wi
>>> df = wi.WiDataFrame(
... {
... "s1": [1, 2, 3],
... "s2": [3, 4, 5],
... "sess": [0, 0, 0],
... "z": [6, 7, 8],
... "a": ["x", "y", "z"],
... "r": [0.5, 1.5, 2.5],
... },
... states=["s1", "s2"],
... actions=["a"],
... rewards=["r"],
... )
>>> df
s1 s2 sess z a r
0 1 3 0 6 x 0.5
1 2 4 0 7 y 1.5
2 3 5 0 8 z 2.5
Check the metadata:
.. code-block:: python
>>> df.sar
['s1', 's2', 'a', 'r']
>>> df.sar_d
{'states': ['s1', 's2'], 'actions': ['a'], 'rewards': ['r']}
>>> df.states
['s1', 's2']
>>> df.actions
['a']
>>> df.rewards
['r']
        Slice the states. The resulting ``WiDataFrame`` or ``WiSeries`` inherits the expected
        *sar* columns from the source ``DataFrame``.
.. code-block:: python
>>> df[df.states]
s1 s2
0 1 3
1 2 4
2 3 5
>>> df[df.states].sar
['s1', 's2', 'a', 'r']
Take just the *sar* columns:
.. code-block:: python
>>> df.trim()
s1 s2 a r
0 1 3 x 0.5
1 2 4 y 1.5
2 3 5 z 2.5
"""
if rewards and len(rewards) > 2:
raise ValueError(f"rewards can have at most two columns, but received {rewards}")
super().__init__(data=data, **kwargs)
if isinstance(data, SarMixin) and [states, actions, rewards] == [None] * 3:
states, actions, rewards = _sar_from(data)
self._set_sar(states=states, actions=actions, rewards=rewards)
    def trim(self, copy: bool = False) -> WiDataFrame:
"""Get the *sar* columns of this data frame.
Raise an error when any of the expected *sar* column names is missing from this data frame.
Args:
copy: True to return a new copy of data frame, False to return a view to this data
frame.
Returns:
            Data frame with only the *sar* columns. If ``copy=False``, the returned data frame
            is a view of this data frame, else a new data frame.
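
        Example:
            Keep only the *sar* columns (mirrors the class-level example):

            .. code-block:: python

                >>> import a2rl as wi
                >>> df = wi.WiDataFrame(
                ...     {
                ...         "s1": [1, 2, 3],
                ...         "s2": [3, 4, 5],
                ...         "sess": [0, 0, 0],
                ...         "z": [6, 7, 8],
                ...         "a": ["x", "y", "z"],
                ...         "r": [0.5, 1.5, 2.5],
                ...     },
                ...     states=["s1", "s2"],
                ...     actions=["a"],
                ...     rewards=["r"],
                ... )
                >>> df.trim()
                   s1  s2  a    r
                0   1   3  x  0.5
                1   2   4  y  1.5
                2   3   5  z  2.5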
"""
view = self[self.states + self.actions + self.rewards]
return view if not copy else view.copy()
@property
def _constructor(self):
return _get_constructor(WiDataFrame, self)
@property
def _constructor_sliced(self):
return _get_constructor(WiSeries, self)
    def to_csv_dataset(
self,
path_or_buf: str | PathLike[str],
*args,
forced_categories: None | Iterable[str] = None,
compact: bool = False,
**kwargs,
) -> None:
"""Save this data frame as a ``Whatif`` dataset.
        This method has a similar signature to :meth:`pandas.DataFrame.to_csv()`, but with a few
        changes.

        Args:
            path_or_buf: Unlike :meth:`pandas.DataFrame.to_csv()`, this accepts only the path
                name of the output dir.
            args: Passed to :meth:`pandas.DataFrame.to_csv()`.
            forced_categories: Column names to always treat as categorical, recorded in the
                output metadata.
            compact: When True, do not write ``None`` entries to the output metadata YAML.
            kwargs: Passed to :meth:`pandas.DataFrame.to_csv()`.

See Also
--------
read_csv_dataset
Example:
Save a ``WiDataFrame`` to directory ``/tmp/my-dataset``.
.. code-block:: python
>>> from a2rl import WiDataFrame
>>> df = WiDataFrame(
... {
... "i": [3, 4, 5],
... "s": [1, 2, 3],
... "j": [4, 5, 6],
... "a": ["x", "y", "z"],
... "k": ["z", "x", "y"],
... "r": [0.5, 1.5, 2.5],
... },
... states=["s"],
... actions=["a"],
... rewards=["r"],
... )
>>> df
i s j a k r
0 3 1 4 x z 0.5
1 4 2 5 y x 1.5
2 5 3 6 z y 2.5
>>> df.to_csv_dataset("/tmp/my-dataset")
"""
if not (forced_categories is None or isinstance(forced_categories, list)):
forced_categories = list(forced_categories)
outdir = _pre_to_csv(
self,
path_or_buf,
forced_categories=forced_categories,
compact=compact,
)
self.to_csv(outdir / "data.csv", *args, **kwargs)
@property
def sequence(self) -> np.ndarray:
"""Return a 1D Numpy representation of the DataFrame, in row-major order.
Returns:
The sequence of the data frame.
Example:
.. code-block:: python
>>> from a2rl import WiDataFrame
>>> df = WiDataFrame(
... {
... "sess": [0, 0, 0],
... "s": [1, 2, 3],
... "a": ["x", "y", "z"],
... "r": [0.5, 1.5, 2.5]
... },
... states=["s"],
... actions=["a"],
... rewards=["r"],
... )
>>> df
sess s a r
0 0 1 x 0.5
1 0 2 y 1.5
2 0 3 z 2.5
>>> df.sequence
array([0, 1, 'x', 0.5, 0, 2, 'y', 1.5, 0, 3, 'z', 2.5], dtype=object)
"""
return self.values.ravel()
def _check_add_value_args(
self,
value_col: str,
override: Literal["replace", "warn", "error"],
alpha: float,
gamma: float,
) -> None:
if not getattr(self, "_rewards", None):
raise ValueError(f"Unspecified reward column: {getattr(self, '_rewards', None)}")
if value_col == self._rewards[0]:
raise ValueError(f"value_col={value_col} conflicts with reward")
if value_col in self.columns:
if override == "error":
raise ValueError(f"Column {value_col} already exists in this WiDataFrame")
elif override == "warn":
warnings.warn(f"Column {value_col} will be overwritten")
elif override != "replace":
raise ValueError(f"Unknown override: {override}")
        if not (0.0 <= alpha <= 1.0):
            raise ValueError(f"Learning rate alpha={alpha} not within [0, 1].")
        if not (0.0 <= gamma <= 1.0):
            raise ValueError(f"Discount factor gamma={gamma} not within [0, 1].")
    def add_value(
self: WiDataFrame,
alpha: float = 0.1,
gamma: float = 0.6,
sarsa: bool = True,
value_col: str = "value",
override: Literal["replace", "warn", "error"] = "replace",
) -> WiDataFrame:
"""Append column ``value_col`` into this dataframe (restriction: ``df`` must NOT contain
column names ``_state``, ``_action``, ``_reward``, and the ``value_col``).
Args:
            alpha: Learning rate in `Q-Learning <https://en.wikipedia.org/wiki/Q-learning>`_ and
                `SARSA <https://en.wikipedia.org/wiki/State-action-reward-state-action>`_. Must
                be between 0 and 1.
            gamma: Discount factor of future reward in Q-Learning and SARSA. Must be between 0
                and 1.
sarsa: When ``True``, compute the value using the `SARSA Bellman equation
<https://en.wikipedia.org/wiki/State-action-reward-state-action>`_ which is a
conservative on-policy temporal difference update. When ``False``, use the
`Q-Learning Bellman equation <https://en.wikipedia.org/wiki/Q-learning>`_ which is
an off-policy temporal difference update.
value_col: The column name for the computed values.
            override: What to do when this dataframe already has column ``value_col``. Valid
                values are ``replace`` to silently override, ``warn`` to show a warning, and
                ``error`` to raise a :exc:`ValueError`.
Returns:
This dataframe, modified with an additional ``value_col`` column. This return value is
provided to facilitate chaining as-per the functional programming style.
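
        Example:
            A minimal sketch: ``add_value()`` appends the computed column and registers it as
            the second reward column. The numbers depend on ``alpha``, ``gamma``, and the
            tokenizer binning, so the calls are skipped by the doctest runner.

            .. code-block:: python

                >>> import a2rl as wi
                >>> df = wi.WiDataFrame(
                ...     {"s": [1, 2, 3], "a": ["x", "y", "x"], "r": [0.5, 1.5, 2.5]},
                ...     states=["s"],
                ...     actions=["a"],
                ...     rewards=["r"],
                ... )
                >>> df = df.add_value()  # doctest: +SKIP
                >>> df.rewards  # doctest: +SKIP
                ['r', 'value']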
"""
self._check_add_value_args(value_col, override, alpha, gamma)
if len(self._rewards) == 1:
df = self
else:
df = WiDataFrame(
self,
states=self.states,
actions=self.actions,
rewards=self.rewards[:1],
)
df_t = wi.DiscreteTokenizer(n_bins=50).fit_transform(df.trim())
# Temp df with only three columns: _state, _action, _reward
df_t = pd.concat( # type: ignore[assignment]
[
df_t[df_t.states].astype(str).apply("_".join, axis=1).astype("category").cat.codes,
df_t[df_t.actions].astype(str).apply("_".join, axis=1).astype("category").cat.codes,
df[df.rewards].reset_index(drop=True),
],
axis=1,
copy=False,
)
df_t.columns = ["_state", "_action", "_reward"] # type: ignore[assignment]
q_table = np.zeros([df_t["_state"].nunique(), df_t["_action"].nunique()])
iterations = 10
for n in range(iterations):
for i in range(0, len(df_t) - 1):
state = int(df_t.loc[i, "_state"])
next_state = int(df_t.loc[i + 1, "_state"])
action = int(df_t.loc[i, "_action"])
reward = df_t.loc[i + 1, "_reward"]
old_value = q_table[state, action]
                if sarsa:
                    # On-policy (SARSA) update: bootstrap from the action actually taken
                    # at the next step in the logged data.
                    next_action = int(df_t.loc[i + 1, "_action"])
                    next_value = q_table[next_state, next_action]
                    new_value = old_value + alpha * (reward + gamma * next_value - old_value)
                else:
                    # Off-policy (Q-learning) update: bootstrap from the greedy next action.
                    next_max = np.max(q_table[next_state])
                    new_value = old_value + alpha * (reward + gamma * next_max - old_value)
q_table[state, action] = new_value
self[value_col] = pd.Series(
q_table[df_t["_state"].astype(int), df_t["_action"].astype(int)],
index=self.index,
)
if len(self._rewards) < 2:
self._rewards.append(value_col)
else:
self._rewards[1] = value_col
return self
    def add_value_for_multi_episode_process(
self,
sarsa: bool = True,
alpha: float = 0.1,
gamma: float = 0.6,
value_col: str = "value",
episode_identifier: str = "episode",
override: Literal["replace", "warn", "error"] = "replace",
) -> WiDataFrame:
"""Append column ``value_col`` into this dataframe (restriction: ``df`` must NOT contain
column names ``_state``, ``_action``, ``_reward``, and the ``value_col``).
Args:
sarsa: When ``True``, compute the value using the `SARSA Bellman equation
<https://en.wikipedia.org/wiki/State-action-reward-state-action>`_ which is a
conservative on-policy temporal difference update. When ``False``, use the
`Q-Learning Bellman equation <https://en.wikipedia.org/wiki/Q-learning>`_ which is
an off-policy temporal difference update.
            alpha: Learning rate in `Q-Learning <https://en.wikipedia.org/wiki/Q-learning>`_ and
                `SARSA <https://en.wikipedia.org/wiki/State-action-reward-state-action>`_. Must
                be between 0 and 1.
            gamma: Discount factor of future reward in Q-Learning and SARSA. Must be between 0
                and 1.
value_col: The column name for the computed values.
            override: What to do when this dataframe already has column ``value_col``. Valid
                values are ``replace`` to silently override, ``warn`` to show a warning, and
                ``error`` to raise a :exc:`ValueError`.
            episode_identifier: Group-by key in this dataframe. Ensure that breaks BETWEEN
                episodes are tagged with a ``0`` group name.
Returns:
This dataframe, modified with an additional ``value_col`` column. This return value is
provided to facilitate chaining as-per the functional programming style.
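
        Example:
            A minimal sketch, assuming an ``episode`` column in which ``0`` tags the break
            rows between episodes; the computed values are data-dependent, so the call is
            skipped by the doctest runner.

            .. code-block:: python

                >>> import a2rl as wi
                >>> df = wi.WiDataFrame(
                ...     {
                ...         "episode": [1, 1, 0, 2, 2],
                ...         "s": [1, 2, 3, 4, 5],
                ...         "a": ["x", "y", "x", "y", "x"],
                ...         "r": [0.5, 1.5, 0.0, 2.5, 3.5],
                ...     },
                ...     states=["s"],
                ...     actions=["a"],
                ...     rewards=["r"],
                ... )
                >>> df = df.add_value_for_multi_episode_process()  # doctest: +SKIP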
"""
self._check_add_value_args(value_col, override, alpha, gamma)
if len(self._rewards) == 1:
df = self
else:
df = WiDataFrame(
self,
states=self.states,
actions=self.actions,
rewards=self.rewards[:1],
)
df_t = wi.DiscreteTokenizer(n_bins=50).fit_transform(df.trim())
def calc_q_values(sub, df, x, alpha, gamma, sarsa=sarsa):
if sub.name == 0:
return pd.Series(0, index=sub.index)
df = df.loc[sub.index]
x = x.loc[sub.index]
# Temp df with only three columns: _state, _action, _reward
df_t = pd.concat( # type: ignore[assignment]
[
x[x.states].astype(str).apply("_".join, axis=1).astype("category").cat.codes,
x[x.actions].astype(str).apply("_".join, axis=1).astype("category").cat.codes,
                    df[df.rewards],
],
axis=1,
copy=False,
)
df_t.columns = ["_state", "_action", "_reward"] # type: ignore[assignment]
q_table = np.zeros([df_t["_state"].nunique(), df_t["_action"].nunique()])
iterations = 10
for n in range(iterations):
for i in range(df_t.index[0], df_t.index[-1]):
state = int(df_t.loc[i, "_state"])
next_state = int(df_t.loc[i + 1, "_state"])
action = int(df_t.loc[i, "_action"])
reward = df_t.loc[i + 1, "_reward"]
old_value = q_table[state, action]
                    if sarsa:
                        # On-policy (SARSA) update: bootstrap from the action actually
                        # taken at the next step in the logged data.
                        next_action = int(df_t.loc[i + 1, "_action"])
                        next_value = q_table[next_state, next_action]
                        new_value = old_value + alpha * (reward + gamma * next_value - old_value)
                    else:
                        # Off-policy (Q-learning) update: bootstrap from the greedy next action.
                        next_max = np.max(q_table[next_state])
                        new_value = old_value + alpha * (reward + gamma * next_max - old_value)
q_table[state, action] = new_value
return pd.Series(
q_table[df_t["_state"].astype(int), df_t["_action"].astype(int)],
index=df.index,
)
x = df.groupby(episode_identifier).apply(lambda x: calc_q_values(x, df, df_t, alpha, gamma))
x2 = x.reset_index(level=0)
x2.rename({0: value_col}, axis=1, inplace=True)
x2.drop([episode_identifier], axis=1, inplace=True)
x2 = x2.sort_index()
self[value_col] = x2
if len(self._rewards) < 2:
self._rewards.append(value_col)
else:
self._rewards[1] = value_col
return self
class TransitionRecorder(gym.Wrapper[Any, np.ndarray]):
"""Record the transitions in the OpenAI gym :class:`gym.Env` into a Whatif data frame.
Args:
env: a gym environment.
recording: When `True`, immediately start capturing steps. When `False`, callers need to
call :meth:`~TransitionRecorder.start()` to start capturing steps.
Examples
--------
.. code-block:: python
>>> import gym
>>> import a2rl as wi
>>> def do_steps(env):
... env.reset()
... for _ in range(5):
... env.step(0)
>>> env = wi.TransitionRecorder(env=gym.make("Taxi-v3"))
>>> do_steps(env)
>>> env.df.info() # doctest: +NORMALIZE_WHITESPACE
<class 'a2rl._dataframe.WiDataFrame'>
Int64Index: 5 entries, 0 to 0
Data columns (total 3 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 0 5 non-null float64
1 1 5 non-null float64
2 2 5 non-null float64
dtypes: float64(3)
memory usage: ...
>>> env.stop()
>>> do_steps(env)
>>> env.df.info() # doctest: +NORMALIZE_WHITESPACE
<class 'a2rl._dataframe.WiDataFrame'>
Int64Index: 5 entries, 0 to 0
Data columns (total 3 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 0 5 non-null float64
1 1 5 non-null float64
2 2 5 non-null float64
dtypes: float64(3)
memory usage: ...
        >>> env.start()
>>> do_steps(env)
>>> env.df.info() # doctest: +NORMALIZE_WHITESPACE
<class 'a2rl._dataframe.WiDataFrame'>
Int64Index: 10 entries, 0 to 0
Data columns (total 3 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 0 10 non-null float64
1 1 10 non-null float64
2 2 10 non-null float64
dtypes: float64(3)
memory usage: ...
"""
def __init__(self, env: gym.Env, recording: bool = True):
super().__init__(env)
self.env = env
self.recording = recording
self.episode = 0
state = env.observation_space.sample()
state = np.array(state).ravel()
state_length = state.size
action = env.action_space.sample()
action = np.array(action).ravel()
action_length = action.size
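        # The recorded frame lays columns out as [state dims | action dims | reward],
        # so the *sar* metadata below is just consecutive integer column indices.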
self.sar_d = {
"states": np.arange(state.size),
"actions": np.arange(action.size) + state.size,
"rewards": [action.size + state.size],
}
self.df = WiDataFrame(
pd.DataFrame(columns=np.arange(state_length + action_length + 1), dtype="float"),
**self.sar_d,
)
self._state: Any
    def start(self) -> None:
"""Start capturing subsequent steps."""
self.recording = True
    def stop(self) -> None:
"""Stop capturing steps."""
self.recording = False
    def step(self, action: np.ndarray) -> tuple[Any, float, bool, dict]:
"""Wrapper to :func:`gym.Wrapper.step()` which records one timestep of the environment's
dynamics.
Args:
action (object): an action provided by the agent
"""
# See: https://github.com/openai/gym/commit/907b1b20dd9ac0cba5803225059b9c6673702467
# - gym<0.24.0: step_results = (next_state, reward, done, info)
# - gym>=0.25.0: step_results = (next_state, reward, done, _, info)
step_results = self.env.step(action)
next_state, reward, done = step_results[:3]
info = step_results[-1]
if self.recording:
action = np.array(action).ravel()
stacked = np.hstack(
[
np.array(self._state).ravel(),
np.array(action).ravel(),
np.array(reward),
]
)
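            # Append the transition as a one-row frame (index 0); the recorded
            # dataframe's index therefore repeats 0, as seen in the class doctest.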
self.df = pd.concat( # type: ignore[assignment]
[
self.df,
WiDataFrame(stacked.reshape(1, -1), columns=list(self.df), **self.sar_d),
]
)
self._state = next_state
return next_state, reward, done, info
    def reset(self, **kwargs) -> tuple[gym.core.ObsType, dict] | gym.core.ObsType:
"""Wrapper to :func:`gym.Wrapper.reset()` which resets the environment to an initial state
and returns an initial observation.
Returns:
observation: Observation of the initial state.
info (optional dictionary): returned only when ``return_info=True``.
"""
self.episode += 1
reset_result = self.env.reset(**kwargs)
if kwargs.get("return_info", False):
obs = reset_result[0]
else:
obs = reset_result # type: ignore[assignment]
self._state = obs
return reset_result