a2rl.Simulator.beam_search_n_steps
- Simulator.beam_search_n_steps(seq, n_steps, beam_width, randomness=False, overwrite_valid_tokens=None, start_col_idx=None, is_gpt_token=False, return_logprobs=False)
This function largely replaces A2RL Simulator.gpt_sample_n_steps(). It does not concern states/actions/rewards and only generates the next n_steps tokens using beam search. This function is intended to be used by a planner.
- Parameters:
  - seq (ndarray) – A sequence of tokens (1-dimensional only).
  - n_steps (int) – Number of tokens to generate.
  - beam_width (int) – Number of beams used in beam search. Must be <= the vocab size of the starting column (determined by both the valid tokens of that column and overwrite_valid_tokens, if used). Setting this to 1 is equivalent to behaviour cloning.
  - randomness (bool) – If True, use multinomial sampling over the top-n tokens instead of deterministic beam search.
  - overwrite_valid_tokens (Optional[dict[str, list[int]]]) – dict[ col_name : list of GPT tokens ]; overwrites the valid tokens in a column. Useful if additional constraints need to be applied during inference.
  - start_col_idx – Indicates the starting dataframe column index. Defaults to len(seq) % len(columns) if None.
  - is_gpt_token (bool) – Whether the tokens in seq are GPT tokens or DataFrame tokens.
  - return_logprobs (bool) – If True, the return value is a tuple of the tokens and the accumulated logprobs of each beam.
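For intuition, the core algorithm can be sketched in plain Python. This is an illustration of deterministic beam search, not A2RL's implementation: the `log_probs` toy model below is hypothetical and stands in for the Simulator's GPT model, and the `(tokens, logprob)` pairs mirror the tuple returned when return_logprobs=True.

```python
import math

# Hypothetical toy model standing in for A2RL's GPT model: returns the
# log-probabilities of the next token over a 4-token vocab, deterministically
# favouring token (last_token + 1) % 4.
def log_probs(prefix):
    last = prefix[-1] if prefix else 0
    scores = [1.0 if t == (last + 1) % 4 else 0.1 for t in range(4)]
    total = sum(scores)
    return [math.log(s / total) for s in scores]

def beam_search_n_steps(seq, n_steps, beam_width):
    """Generate n_steps tokens, keeping the beam_width best sequences."""
    # Each beam is (token list, accumulated logprob).
    beams = [(list(seq), 0.0)]
    for _ in range(n_steps):
        candidates = []
        for tokens, lp in beams:
            for tok, tok_lp in enumerate(log_probs(tokens)):
                candidates.append((tokens + [tok], lp + tok_lp))
        # Deterministic pruning: keep the top beam_width candidates by
        # accumulated logprob. (With randomness=True, A2RL instead samples
        # multinomially from the top candidates.)
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return beams

# beam_width=1 degenerates to greedy decoding (behaviour cloning).
print(beam_search_n_steps([0], n_steps=3, beam_width=1)[0][0])  # [0, 1, 2, 3]
```

Note that pruning to beam_width candidates at every step is what bounds the search: the accumulated logprob of each surviving beam is exactly the quantity returned alongside the tokens when return_logprobs=True.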