Skip to content

Sagemaker

sagemaker

SageMakerEndpoint

SageMakerEndpoint(endpoint_name, model_id, generated_text_jmespath='generated_text', input_text_jmespath='inputs', token_count_jmespath='details.generated_tokens', region=None, boto3_session=None, **kwargs)

Bases: SageMakerBase

A class for handling invocations to a SageMaker endpoint.

This class extends SageMakerBase to provide functionality for invoking a SageMaker endpoint and parsing its response.

Source code in llmeter/endpoints/sagemaker.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def __init__(
    self,
    endpoint_name: str,
    model_id: str,
    generated_text_jmespath: str = "generated_text",
    input_text_jmespath: str = "inputs",
    token_count_jmespath: str | None = "details.generated_tokens",
    region: str | None = None,
    boto3_session: boto3.Session | None = None,
    **kwargs,
):
    """Configure an invoker for a deployed SageMaker endpoint.

    Args:
        endpoint_name: Name of the SageMaker endpoint to call.
        model_id: Identifier of the model served by the endpoint.
        generated_text_jmespath: JMESPath locating the generated text in the
            endpoint response.
        input_text_jmespath: JMESPath locating the input text in the request
            payload.
        token_count_jmespath: JMESPath locating the output-token count in the
            response, or None to skip token counting.
        region: AWS region; when omitted, the session's configured region is
            used.
        boto3_session: Existing boto3 session to reuse; a fresh session is
            created when omitted.
        **kwargs: Additional options retained on the instance for later use.
    """
    super().__init__(
        endpoint_name=endpoint_name, model_id=model_id, provider="sagemaker"
    )
    # JMESPath expressions used to pull fields out of requests/responses.
    self.generated_text_jmespath = generated_text_jmespath
    self.input_text_jmespath = input_text_jmespath
    self.token_count_jmespath = token_count_jmespath
    self.kwargs = kwargs

    # Fall back to a fresh session, then to its configured region.
    session = boto3_session if boto3_session is not None else boto3.session.Session()
    self.region = region if region is not None else session.region_name
    logger.info(f"Using AWS region: {self.region}")

    self._sagemaker_runtime = session.client(
        "sagemaker-runtime", region_name=self.region
    )

invoke

invoke(payload)

Invoke the SageMaker endpoint with the given payload.

This method sends a request to the SageMaker endpoint, processes the response, and returns an InvocationResponse object with the results.

Parameters:

Name Type Description Default
payload Dict

The input payload for the model.

required

Returns:

Name Type Description
InvocationResponse InvocationResponse

An object containing the model's response and associated metrics.

Raises:

Type Description
ClientError

If there's an error during the invocation of the SageMaker endpoint.

Exception

If there's any other error during the invocation or parsing of the response.

Source code in llmeter/endpoints/sagemaker.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def invoke(self, payload: dict) -> InvocationResponse:
    """
    Invoke the SageMaker endpoint with the given payload.

    This method sends a request to the SageMaker endpoint, processes the response,
    and returns an InvocationResponse object with the results.

    Args:
        payload (Dict): The input payload for the model.

    Returns:
        InvocationResponse: An object containing the model's response and associated metrics.

    Raises:
        ClientError: If there's an error during the invocation of the SageMaker endpoint.
        Exception: If there's any other error during the invocation or parsing of the response.
    """

    json_payload = json.dumps(payload)
    input_prompt = self._parse_input(payload)

    start_t = time.perf_counter()
    try:
        client_response = self._sagemaker_runtime.invoke_endpoint(
            EndpointName=self.endpoint_name,
            ContentType="application/json",
            Body=bytes(json_payload, "utf-8"),
        )
    # `Exception` already subsumes `ClientError`; invocation failures are
    # reported as an error response rather than raised to the caller.
    except Exception as e:
        logger.error(e)
        return InvocationResponse.error_output(
            input_payload=payload,
            id=uuid4().hex,
            error=str(e),
        )

    time_to_last_token = time.perf_counter() - start_t

    # Default both fields up front: if the parsed response is falsy these
    # names were previously unbound and the constructor below raised
    # UnboundLocalError.
    response_text = ""
    num_tokens_output = None
    parsed_response = self._parse_client_response(client_response)
    if parsed_response:
        response_text = parsed_response.get("output_text", "")
        num_tokens_output = parsed_response.get("num_tokens_output", None)

    return InvocationResponse(
        input_payload=payload,
        id=uuid4().hex,
        response_text=response_text,
        time_to_last_token=time_to_last_token,
        input_prompt=input_prompt,
        # Preserve original semantics: a falsy count (0/None) is reported as None.
        num_tokens_output=num_tokens_output if num_tokens_output else None,
    )

SageMakerStreamEndpoint

SageMakerStreamEndpoint(endpoint_name, model_id, generated_text_jmespath='generated_text', input_text_jmespath='inputs', token_count_jmespath='details.generated_tokens', region=None, boto3_session=None, **kwargs)

Bases: SageMakerBase

A class for handling streaming invocations to a SageMaker endpoint.

This class extends SageMakerBase to provide functionality specific to streaming responses from a SageMaker endpoint.

Source code in llmeter/endpoints/sagemaker.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def __init__(
    self,
    endpoint_name: str,
    model_id: str,
    generated_text_jmespath: str = "generated_text",
    input_text_jmespath: str = "inputs",
    token_count_jmespath: str | None = "details.generated_tokens",
    region: str | None = None,
    boto3_session: boto3.Session | None = None,
    **kwargs,
):
    """Set up a streaming invoker for a SageMaker endpoint.

    Args:
        endpoint_name: Name of the SageMaker endpoint to stream from.
        model_id: Identifier of the model served by the endpoint.
        generated_text_jmespath: JMESPath for the generated text in each
            response chunk.
        input_text_jmespath: JMESPath for the input text in the request payload.
        token_count_jmespath: JMESPath for the output-token count, or None.
        region: AWS region override; defaults to the session's region.
        boto3_session: Optional pre-built boto3 session; one is created
            on demand otherwise.
        **kwargs: Extra options kept on the instance.
    """
    super().__init__(
        endpoint_name=endpoint_name, model_id=model_id, provider="sagemaker"
    )

    # Response/request parsing configuration.
    self.generated_text_jmespath = generated_text_jmespath
    self.input_text_jmespath = input_text_jmespath
    self.token_count_jmespath = token_count_jmespath
    self.kwargs = kwargs

    # Resolve region: explicit argument wins, otherwise ask the session.
    _session = boto3_session or boto3.session.Session()
    resolved_region = region or _session.region_name
    self.region = resolved_region
    logger.info(f"Using AWS region: {self.region}")

    # Runtime client used by invoke() for the streaming API call.
    self._sagemaker_runtime = _session.client(
        "sagemaker-runtime", region_name=resolved_region
    )

invoke

invoke(payload)

Invoke a SageMaker endpoint with the given payload.

This method sends a request to the SageMaker endpoint and handles the streaming response.

Parameters:

Name Type Description Default
payload Dict

The input payload for the model.

required

Returns:

Name Type Description
InvocationResponse InvocationResponse

An object containing the model's response and metrics.

Raises:

Type Description
Exception

If there's an error during the invocation or parsing of the response.

Source code in llmeter/endpoints/sagemaker.py
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
def invoke(self, payload: dict) -> InvocationResponse:
    """
    Invoke a SageMaker endpoint with the given payload.

    This method sends a request to the SageMaker endpoint and handles
    the streaming response.

    Args:
        payload (Dict): The input payload for the model.

    Returns:
        InvocationResponse: An object containing the model's response and metrics.

    Raises:
        Exception: If there's an error during the invocation or parsing of the response.
    """

    # Shallow-copy the payload (and its "parameters" sub-dict) before editing:
    # the original code aliased the caller's dict, so popping
    # "decoder_input_details" and injecting "stream" mutated the caller's
    # payload across invocations.
    _payload = dict(payload)
    if "parameters" in _payload:
        _payload["parameters"] = dict(_payload["parameters"])
        _payload["parameters"].pop("decoder_input_details", None)
    if "stream" not in _payload:
        warnings.warn("stream not specified in payload, defaulting to True")
        _payload["stream"] = True

    json_payload = json.dumps(_payload)
    input_prompt = self._parse_input(_payload)

    start_t = time.perf_counter()
    try:
        client_response = (
            self._sagemaker_runtime.invoke_endpoint_with_response_stream(
                EndpointName=self.endpoint_name,
                Body=json_payload,
                ContentType="application/json",
            )
        )
    except Exception as e:
        logger.error(e)
        return InvocationResponse.error_output(input_payload=payload, error=str(e))

    try:
        # Streaming parse measures latency relative to start_t internally.
        response = self._parse_client_response(client_response, start_t)
        response.input_payload = payload
        response.input_prompt = input_prompt
        return response
    except Exception as e:
        return InvocationResponse.error_output(input_payload=payload, error=str(e))