a2rl.Simulator.beam_search_n_steps

Simulator.beam_search_n_steps(seq, n_steps, beam_width, randomness=False, overwrite_valid_tokens=None, start_col_idx=None, is_gpt_token=False, return_logprobs=False)

This method largely replaces a2rl.Simulator.gpt_sample_n_steps(). It is agnostic to states, actions, and rewards: it only generates the next n_steps tokens using beam search. It is intended to be used by a planner.
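To illustrate the idea, here is a minimal, hypothetical sketch of n-step beam search in NumPy. It is not the a2rl implementation: `beam_search_sketch` and `logprob_fn` (a stand-in for the model's next-token log-probabilities) are assumptions, and the sketch uses one fixed vocabulary instead of per-column valid tokens. With `randomness=True` it samples beams without replacement from all candidates in proportion to their probability, which is one plausible reading of "multinomial sampling"; the real method may restrict sampling to the top candidates.

```python
import numpy as np

def beam_search_sketch(seq, n_steps, beam_width, logprob_fn, randomness=False, rng=None):
    """Hypothetical n-step beam search, NOT the a2rl implementation.

    logprob_fn(seq) is an assumed callable returning a 1-D array of
    log-probabilities over the vocabulary for the token following ``seq``.
    Returns (tokens, accumulated log-probability of each beam).
    """
    rng = np.random.default_rng() if rng is None else rng
    beams = [np.asarray(seq)]  # every beam starts from the same prefix
    scores = np.zeros(1)       # accumulated log-probability per beam
    for _ in range(n_steps):
        # Score every (beam, next-token) pair: shape (n_beams, vocab_size).
        cand = np.stack([logprob_fn(b) for b in beams]) + scores[:, None]
        flat = cand.ravel()
        if beam_width > flat.size:
            # At the first step there is one beam, so candidates == vocab size.
            raise ValueError("beam_width exceeds the number of candidate tokens")
        if randomness:
            # Sample beams without replacement, proportional to probability.
            p = np.exp(flat - flat.max())
            p /= p.sum()
            chosen = rng.choice(flat.size, size=beam_width, replace=False, p=p)
        else:
            # Deterministic: keep the beam_width highest-scoring candidates.
            chosen = np.argsort(flat)[::-1][:beam_width]
        beam_idx, tok = np.divmod(chosen, cand.shape[1])
        beams = [np.append(beams[b], t) for b, t in zip(beam_idx, tok)]
        scores = flat[chosen]
    return np.stack(beams), scores
```

For example, with a toy model that always predicts log([0.5, 0.3, 0.2]), `beam_search_sketch(np.array([0]), n_steps=2, beam_width=1, logprob_fn=lambda s: np.log([0.5, 0.3, 0.2]))` keeps only the single most likely continuation at each step, i.e. greedy decoding.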

Parameters:
  • seq (ndarray) – A sequence of tokens (1-dimensional only)

  • n_steps (int) – number of tokens to generate

  • beam_width (int) – number of beams used in beam search. Must not exceed the vocabulary size of the starting column, which is determined by that column's valid tokens and by overwrite_valid_tokens, if used. Setting this to 1 is equivalent to behaviour cloning.

  • randomness (bool) – if True, will use multinomial sampling of the top-n tokens instead of deterministic beam search.

  • overwrite_valid_tokens (Optional[dict[str, list[int]]]) – mapping from column name to a list of GPT tokens that overrides the valid tokens of that column; useful when additional constraints must be applied during inference.

  • start_col_idx (Optional[int]) – index of the starting dataframe column. Defaults to len(seq) % len(columns) if None.

  • is_gpt_token (bool) – whether the tokens in seq are GPT tokens or DataFrame tokens

  • return_logprobs (bool) – if True, return a tuple of the generated tokens and the accumulated log-probabilities of each beam.
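The start_col_idx default and the effect of overwrite_valid_tokens can be sketched as follows. Only the `len(seq) % len(columns)` default comes from the parameter description above; the helper names, the `valid_tokens` dict, the token ids, and masking disallowed tokens with `-inf` are hypothetical and not a2rl internals. The sketch also assumes tokens cycle through the columns in order, so generation step k targets column `(start_col_idx + k) % len(columns)`.

```python
import numpy as np

def start_column(seq, columns, start_col_idx=None):
    """Documented default: len(seq) % len(columns) when start_col_idx is None."""
    if start_col_idx is None:
        start_col_idx = len(seq) % len(columns)
    return start_col_idx

def column_for_step(start_col_idx, k, columns):
    """Assumed cycling: the k-th generated token belongs to this column."""
    return columns[(start_col_idx + k) % len(columns)]

def allowed_tokens(col_name, valid_tokens, overwrite_valid_tokens=None):
    """Hypothetical helper: a column's valid tokens, replaced by the override if given."""
    if overwrite_valid_tokens and col_name in overwrite_valid_tokens:
        return list(overwrite_valid_tokens[col_name])
    return list(valid_tokens[col_name])

def mask_logprobs(logprobs, allowed):
    """Set every token outside ``allowed`` to -inf so the search can never pick it."""
    masked = np.full_like(logprobs, -np.inf, dtype=float)
    masked[allowed] = logprobs[allowed]
    return masked

# Hypothetical columns and GPT token ids, for illustration only.
columns = ["state", "action", "reward"]
valid_tokens = {"state": [0, 1, 2], "action": [3, 4], "reward": [5, 6, 7]}
```

For a 5-token `seq`, the search would start at column index 2 ("reward"). Overriding `{"action": [4]}` leaves a single allowed action token, so beam_width could be at most 1 whenever the search starts at the "action" column.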