Skip to content

Sagemaker

sagemaker

SageMakerEndpoint

SageMakerEndpoint(endpoint_name, model_id, generated_text_jmespath='generated_text', input_text_jmespath='inputs', token_count_jmespath='details.generated_tokens', region=None, boto3_session=None, **kwargs)

Bases: SageMakerBase[InvokeEndpointOutputTypeDef]

A class for handling invocations to a SageMaker endpoint.

This class extends SageMakerBase to provide functionality for invoking a SageMaker endpoint and parsing its response.

Source code in llmeter/endpoints/sagemaker.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def __init__(
    self,
    endpoint_name: str,
    model_id: str,
    # TODO: generated & token count jmespaths not actually used by streaming yet
    generated_text_jmespath: str = "generated_text",
    input_text_jmespath: str = "inputs",
    token_count_jmespath: str | None = "details.generated_tokens",
    region: str | None = None,
    boto3_session: boto3.Session | None = None,
    **kwargs,
):
    """Set up a client for non-streaming invocations of a SageMaker endpoint.

    Args:
        endpoint_name: Name of the deployed SageMaker endpoint.
        model_id: Identifier of the model served by the endpoint.
        generated_text_jmespath: JMESPath expression expected to locate the
            generated text in endpoint responses.
        input_text_jmespath: JMESPath expression expected to locate the input
            text in request payloads.
        token_count_jmespath: JMESPath expression for the generated-token
            count, or None to disable it.
        region: AWS region to use; falls back to the session's default.
        boto3_session: Existing boto3 session to reuse; a fresh default
            session is created when omitted.
        **kwargs: Extra keyword arguments, stored on the instance as-is.
    """
    super().__init__(
        endpoint_name=endpoint_name, model_id=model_id, provider="sagemaker"
    )

    # Resolve the session first, then derive the region from it if the
    # caller did not pin one explicitly.
    session = boto3_session or boto3.session.Session()
    self.region = region or session.region_name
    logger.info(f"Using AWS region: {self.region}")

    self._sagemaker_runtime = session.client(
        "sagemaker-runtime", region_name=self.region
    )

    self.generated_text_jmespath = generated_text_jmespath
    self.input_text_jmespath = input_text_jmespath
    self.token_count_jmespath = token_count_jmespath
    self.kwargs = kwargs

create_payload staticmethod

create_payload(input_text, max_tokens=256, inference_parameters={}, **kwargs)

Create a payload for the SageMaker API request.

Parameters:

Name Type Description Default
input_text str | list[ContentItem]

A single text string, or an ordered list mixing strings and `llmeter.prompt_utils.MediaContent` objects.

required
max_tokens int

Maximum tokens to generate. Defaults to 256.

256
inference_parameters dict

Additional inference parameters.

{}
**kwargs

Additional keyword arguments to include in the payload.

{}

Returns:

Name Type Description
dict

The formatted payload for the SageMaker API request.

Source code in llmeter/endpoints/sagemaker.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
@staticmethod
def create_payload(
    input_text: str | list[ContentItem],
    max_tokens: int = 256,
    inference_parameters: dict = {},
    **kwargs,
):
    """Create a payload for the SageMaker API request.

    Args:
        input_text: A single text string, or an ordered list mixing strings
            and :class:`~llmeter.prompt_utils.MediaContent` objects.
        max_tokens: Maximum tokens to generate. Defaults to 256.
        inference_parameters: Additional inference parameters.
        **kwargs: Additional keyword arguments to include in the payload.

    Returns:
        dict: The formatted payload for the SageMaker API request.
    """
    if not isinstance(max_tokens, int) or max_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer")

    if isinstance(input_text, str):
        items: list[ContentItem] = [input_text]
    elif isinstance(input_text, list):
        items = input_text
    else:
        raise TypeError(
            "input_text must be a str or list of str/MediaContent, "
            f"got {type(input_text).__name__}"
        )

    if not items:
        raise ValueError("input_text must not be empty")

    # Text-only shortcut
    if len(items) == 1 and isinstance(items[0], str):
        payload = {
            "inputs": items[0],
            "parameters": {"max_new_tokens": max_tokens, "details": True},
        }
        if inference_parameters:
            payload["parameters"].update(inference_parameters)
        payload.update(kwargs)
        return payload

    content_blocks = _build_content_blocks_sagemaker(items)
    payload = {
        "inputs": content_blocks,
        "parameters": {"max_new_tokens": max_tokens, "details": True},
    }
    if inference_parameters:
        payload["parameters"].update(inference_parameters)
    payload.update(kwargs)
    return payload

invoke

invoke(payload)

Invoke the SageMaker endpoint with the given payload.

Source code in llmeter/endpoints/sagemaker.py
195
196
197
198
199
200
201
202
203
204
205
@SageMakerBase.llmeter_invoke
def invoke(self, payload: dict) -> InvokeEndpointOutputTypeDef:
    """Send *payload* to the configured SageMaker endpoint.

    The payload is serialized to JSON bytes and posted to the runtime API;
    the raw boto3 response is returned unchanged.
    """
    body = json.dumps(payload).encode("utf-8")
    return self._sagemaker_runtime.invoke_endpoint(
        EndpointName=self.endpoint_name,
        ContentType="application/json",
        Body=body,
    )

SageMakerStreamEndpoint

SageMakerStreamEndpoint(endpoint_name, model_id, generated_text_jmespath='generated_text', input_text_jmespath='inputs', token_count_jmespath='details.generated_tokens', region=None, boto3_session=None, **kwargs)

Bases: SageMakerBase[InvokeEndpointWithResponseStreamOutputTypeDef]

A class for handling streaming invocations to a SageMaker endpoint.

This class extends SageMakerBase to provide functionality specific to streaming responses from a SageMaker endpoint.

Source code in llmeter/endpoints/sagemaker.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def __init__(
    self,
    endpoint_name: str,
    model_id: str,
    # TODO: generated & token count jmespaths not actually used by streaming yet
    generated_text_jmespath: str = "generated_text",
    input_text_jmespath: str = "inputs",
    token_count_jmespath: str | None = "details.generated_tokens",
    region: str | None = None,
    boto3_session: boto3.Session | None = None,
    **kwargs,
):
    """Set up a client for streaming invocations of a SageMaker endpoint.

    Args:
        endpoint_name: Name of the deployed SageMaker endpoint.
        model_id: Identifier of the model served by the endpoint.
        generated_text_jmespath: JMESPath expression expected to locate the
            generated text in responses (not yet used by streaming — see TODO).
        input_text_jmespath: JMESPath expression expected to locate the input
            text in request payloads.
        token_count_jmespath: JMESPath expression for the generated-token
            count, or None to disable it (not yet used by streaming).
        region: AWS region to use; falls back to the session's default.
        boto3_session: Existing boto3 session to reuse; a fresh default
            session is created when omitted.
        **kwargs: Extra keyword arguments, stored on the instance as-is.
    """
    super().__init__(
        endpoint_name=endpoint_name, model_id=model_id, provider="sagemaker"
    )

    # Resolve the session first, then derive the region from it if the
    # caller did not pin one explicitly.
    session = boto3_session or boto3.session.Session()
    self.region = region or session.region_name
    logger.info(f"Using AWS region: {self.region}")

    self._sagemaker_runtime = session.client(
        "sagemaker-runtime", region_name=self.region
    )

    self.generated_text_jmespath = generated_text_jmespath
    self.input_text_jmespath = input_text_jmespath
    self.token_count_jmespath = token_count_jmespath
    self.kwargs = kwargs

create_payload staticmethod

create_payload(input_text, max_tokens=256, inference_parameters={}, **kwargs)

Create a payload for the SageMaker streaming API request.

Parameters:

Name Type Description Default
input_text str | list[ContentItem]

A single text string, or an ordered list mixing strings and `llmeter.prompt_utils.MediaContent` objects.

required
max_tokens int

Maximum tokens to generate. Defaults to 256.

256
inference_parameters dict

Additional inference parameters.

{}
**kwargs

Additional keyword arguments to include in the payload.

{}

Returns:

Name Type Description
dict

The formatted payload for the SageMaker streaming API request.

Source code in llmeter/endpoints/sagemaker.py
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
@staticmethod
def create_payload(
    input_text: str | list[ContentItem],
    max_tokens: int = 256,
    inference_parameters: dict = {},
    **kwargs,
):
    """Create a payload for the SageMaker streaming API request.

    Args:
        input_text: A single text string, or an ordered list mixing strings
            and :class:`~llmeter.prompt_utils.MediaContent` objects.
        max_tokens: Maximum tokens to generate. Defaults to 256.
        inference_parameters: Additional inference parameters.
        **kwargs: Additional keyword arguments to include in the payload.

    Returns:
        dict: The formatted payload for the SageMaker streaming API request.
    """
    if not isinstance(max_tokens, int) or max_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer")

    if isinstance(input_text, str):
        items: list[ContentItem] = [input_text]
    elif isinstance(input_text, list):
        items = input_text
    else:
        raise TypeError(
            "input_text must be a str or list of str/MediaContent, "
            f"got {type(input_text).__name__}"
        )

    if not items:
        raise ValueError("input_text must not be empty")

    # Text-only shortcut
    if len(items) == 1 and isinstance(items[0], str):
        payload = {
            "inputs": items[0],
            "parameters": {"max_new_tokens": max_tokens, "details": True},
            "stream": True,
        }
        if inference_parameters:
            payload["parameters"].update(inference_parameters)
        payload.update(kwargs)
        return payload

    content_blocks = _build_content_blocks_sagemaker(items)
    payload = {
        "inputs": content_blocks,
        "parameters": {"max_new_tokens": max_tokens, "details": True},
        "stream": True,
    }
    if inference_parameters:
        payload["parameters"].update(inference_parameters)
    payload.update(kwargs)
    return payload

invoke

invoke(payload)

Invoke a SageMaker streaming endpoint with the given payload.

Source code in llmeter/endpoints/sagemaker.py
244
245
246
247
248
249
250
251
252
253
254
@SageMakerBase.llmeter_invoke
def invoke(self, payload: dict) -> InvokeEndpointWithResponseStreamOutputTypeDef:
    """Send *payload* to the SageMaker endpoint via the response-stream API.

    The payload is serialized to JSON and the raw boto3 streaming response
    is returned unchanged.
    """
    return self._sagemaker_runtime.invoke_endpoint_with_response_stream(
        EndpointName=self.endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )