Support SageMaker endpoints that communicate in plain text, not JSON

Open JasonWeill opened this issue 2 years ago • 0 comments

Problem

When I send a prompt to the SageMaker endpoint jumpstart-dft-hf-translation-t5-small, which expects plain text rather than JSON as input, I get an error (see below). This happens because Jupyter AI tries to parse the value of --request-schema as JSON.

In addition, the mandatory --response-path argument presumes that a SageMaker endpoint will always return text in JSON format.

Proposed Solution

Support models that accept input in non-JSON formats, including plain text (and potentially other formats, like audio and video). If SageMaker does not expose a model's input or output type through an API, let users declare the expected type with a parameter in the magic command.
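As a rough illustration of the plain-text case, a content handler could skip JSON entirely and pass the prompt through as encoded text. The sketch below is self-contained and only mirrors the transform_input / transform_output shape of the existing JsonContentHandler. The class name PlainTextContentHandler, the "text/plain" default, and the sample German text are assumptions for illustration, not part of Jupyter AI.

```python
import io


class PlainTextContentHandler:
    """Hypothetical handler for endpoints that take and return plain text."""

    def __init__(self, content_type: str = "text/plain", encoding: str = "utf-8"):
        # The MIME type is an assumption; use whatever the endpoint documents.
        self.content_type = content_type
        self.accepts = content_type
        self.encoding = encoding

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        # Send the prompt as-is, with no JSON envelope around it.
        return prompt.encode(self.encoding)

    def transform_output(self, output) -> str:
        # `output` is any file-like object with read(), e.g. a streaming response body.
        return output.read().decode(self.encoding)


handler = PlainTextContentHandler()
request_body = handler.transform_input("My dog has three corners.", {})
# Simulate an endpoint response with an in-memory stream.
response_text = handler.transform_output(
    io.BytesIO("Mein Hund hat drei Ecken.".encode("utf-8"))
)
```

With a handler like this, neither --request-schema nor --response-path would be needed for a text-in, text-out endpoint.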

If the expected output type is not JSON, do not accept any value for the --response-path argument.
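The rule about --response-path could be enforced with a small check before the provider is constructed. This is a minimal sketch: the function name validate_response_args and the idea of a response format value ("json" or "text") are hypothetical and do not correspond to an existing Jupyter AI option.

```python
from typing import Optional


def validate_response_args(response_format: str, response_path: Optional[str]) -> None:
    """Raise ValueError if the response path does not fit the response format.

    `response_format` is a hypothetical user-supplied value ("json" or "text").
    """
    if response_format == "json":
        if not response_path:
            raise ValueError("--response-path is required when the response is JSON")
    elif response_path:
        # A JSONPath expression is meaningless for a non-JSON response body.
        raise ValueError(
            f"--response-path is not allowed when the response format is {response_format!r}"
        )


validate_response_args("json", ".[0]")  # accepted
validate_response_args("text", None)    # accepted
```

Rejecting the argument outright, rather than silently ignoring it, makes a mismatch between the user's expectation and the endpoint's behavior visible immediately.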

Additional context

The magic command below should translate the text from English to German:

%%ai sagemaker-endpoint:jumpstart-dft-hf-translation-t5-small --region-name=us-east-1 --request-schema=<prompt> --response-path=.[0]
My dog, it has three corners. Three corners has my dog. If my dog did not have three corners, it would not be my dog.

It produces an error:

---------------------------------------------------------------------------
JSONDecodeError                           Traceback (most recent call last)
Cell In[6], line 1
----> 1 get_ipython().run_cell_magic('ai', 'sagemaker-endpoint:jumpstart-dft-hf-translation-t5-small --region-name=us-east-1 --request-schema=<prompt> --response-path=.[0]', 'My dog, it has three corners. Three corners has my dog. If my dog did not have three corners, it would not be my dog.\n')

File /opt/miniconda3/envs/jupyter-ai/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2478, in InteractiveShell.run_cell_magic(self, magic_name, line, cell)
   2476 with self.builtin_trap:
   2477     args = (magic_arg_s, cell)
-> 2478     result = fn(*args, **kwargs)
   2480 # The code below prevents the output from being displayed
   2481 # when using magics with decodator @output_can_be_silenced
   2482 # when the last Python token in the expression is a ';'.
   2483 if getattr(fn, magic.MAGIC_OUTPUT_CAN_BE_SILENCED, False):

File ~/git/jupyter-ai/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py:565, in AiMagics.ai(self, line, cell)
    562 ip = get_ipython()
    563 prompt = prompt.format_map(FormatDict(ip.user_ns))
--> 565 return self.run_ai_cell(args, prompt)

File ~/git/jupyter-ai/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py:499, in AiMagics.run_ai_cell(self, args, prompt)
    496     provider_params["request_schema"] = args.request_schema
    497     provider_params["response_path"] = args.response_path
--> 499 provider = Provider(**provider_params)
    501 # generate output from model via provider
    502 result = provider.generate([prompt])

File ~/git/jupyter-ai/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py:373, in SmEndpointProvider.__init__(self, *args, **kwargs)
    371 request_schema = kwargs.pop('request_schema')
    372 response_path = kwargs.pop('response_path')
--> 373 content_handler = JsonContentHandler(request_schema=request_schema, response_path=response_path)
    374 super().__init__(*args, **kwargs, content_handler=content_handler)

File ~/git/jupyter-ai/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py:322, in JsonContentHandler.__init__(self, request_schema, response_path)
    321 def __init__(self, request_schema, response_path):
--> 322     self.request_schema = json.loads(request_schema)
    323     self.response_path = response_path
    324     self.response_parser = parse(response_path)

File /opt/miniconda3/envs/jupyter-ai/lib/python3.10/json/__init__.py:346, in loads(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)
    341     s = s.decode(detect_encoding(s), 'surrogatepass')
    343 if (cls is None and object_hook is None and
    344         parse_int is None and parse_float is None and
    345         parse_constant is None and object_pairs_hook is None and not kw):
--> 346     return _default_decoder.decode(s)
    347 if cls is None:
    348     cls = JSONDecoder

File /opt/miniconda3/envs/jupyter-ai/lib/python3.10/json/decoder.py:337, in JSONDecoder.decode(self, s, _w)
    332 def decode(self, s, _w=WHITESPACE.match):
    333     """Return the Python representation of ``s`` (a ``str`` instance
    334     containing a JSON document).
    335 
    336     """
--> 337     obj, end = self.raw_decode(s, idx=_w(s, 0).end())
    338     end = _w(s, end).end()
    339     if end != len(s):

File /opt/miniconda3/envs/jupyter-ai/lib/python3.10/json/decoder.py:355, in JSONDecoder.raw_decode(self, s, idx)
    353     obj, end = self.scan_once(s, idx)
    354 except StopIteration as err:
--> 355     raise JSONDecodeError("Expecting value", s, err.value) from None
    356 return obj, end

JSONDecodeError: Expecting value: line 1 column 1 (char 0)
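
The failure can be reproduced outside Jupyter AI, since the traceback shows the request schema string being passed directly to json.loads. A minimal sketch using only the standard library:

```python
import json

message = None
try:
    # The same call that JsonContentHandler.__init__ makes on the schema string.
    json.loads("<prompt>")
except json.JSONDecodeError as err:
    message = str(err)

print(message)  # Expecting value: line 1 column 1 (char 0)
```

The bare string <prompt> is not a valid JSON document, so parsing fails before any request reaches the endpoint.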

JasonWeill, Jun 22 '23 22:06