jupyter-ai
Support SageMaker endpoints that communicate in plain text, not JSON
Problem
When I send a prompt to the SageMaker endpoint jumpstart-dft-hf-translation-t5-small, which expects plain text rather than JSON as input, I get an error (see below). Jupyter AI attempts to parse the --request-schema value as JSON, which fails when the schema is the bare <prompt> placeholder.
In addition, the mandatory --response-path argument presumes that a SageMaker endpoint always returns JSON.
Proposed Solution
Support models that accept input in non-JSON formats, including plain text (and potentially others, such as audio and video). If SageMaker does not expose a model's input or output type through an API, let users declare the expected type with a parameter in the magic command.
If the expected output type is not JSON, do not accept any value for the --response-path argument.
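One possible shape for this, as a rough sketch only: the class name PlainTextOrJsonHandler, the io_format parameter, and the text/plain content type are all hypothetical and not part of jupyter-ai or SageMaker. The sketch skips JSON parsing of the request schema in text mode, rejects a response path when the output is not JSON, and leaves the JSON path lookup itself to the existing handler logic.

```python
import json
from typing import Optional


class PlainTextOrJsonHandler:
    """Hypothetical handler sketch; not part of jupyter-ai."""

    def __init__(self, request_schema: str,
                 response_path: Optional[str] = None,
                 io_format: str = "json"):
        if io_format not in ("json", "text"):
            raise ValueError(f"unsupported format: {io_format!r}")
        if io_format == "text" and response_path is not None:
            # Mirrors the proposal: no response path for non-JSON output.
            raise ValueError("--response-path is only valid for JSON output")
        if io_format == "json":
            # Only JSON mode requires the schema to be a valid JSON document.
            json.loads(request_schema)
        self.request_schema = request_schema
        self.response_path = response_path
        self.io_format = io_format

    @property
    def content_type(self) -> str:
        # Assumed values; a given endpoint may require a different type.
        return "application/json" if self.io_format == "json" else "text/plain"

    def transform_input(self, prompt: str) -> bytes:
        if self.io_format == "json":
            # Escape the prompt so it stays valid inside a JSON string.
            prompt = json.dumps(prompt)[1:-1]
        return self.request_schema.replace("<prompt>", prompt).encode("utf-8")

    def transform_output(self, body: bytes):
        text = body.decode("utf-8")
        if self.io_format == "text":
            return text
        # Path extraction with response_path is omitted from this sketch.
        return json.loads(text)


handler = PlainTextOrJsonHandler("<prompt>", io_format="text")
print(handler.transform_input("My dog has three corners."))
```

With a handler along these lines, the --request-schema=<prompt> value from the example in this issue would be accepted in text mode instead of being passed to json.loads.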
Additional context
The magic command below should translate the text from English to German:
%%ai sagemaker-endpoint:jumpstart-dft-hf-translation-t5-small --region-name=us-east-1 --request-schema=<prompt> --response-path=.[0]
My dog, it has three corners. Three corners has my dog. If my dog did not have three corners, it would not be my dog.
It produces an error:
---------------------------------------------------------------------------
JSONDecodeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 get_ipython().run_cell_magic('ai', 'sagemaker-endpoint:jumpstart-dft-hf-translation-t5-small --region-name=us-east-1 --request-schema=<prompt> --response-path=.[0]', 'My dog, it has three corners. Three corners has my dog. If my dog did not have three corners, it would not be my dog.\n')
File /opt/miniconda3/envs/jupyter-ai/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2478, in InteractiveShell.run_cell_magic(self, magic_name, line, cell)
2476 with self.builtin_trap:
2477 args = (magic_arg_s, cell)
-> 2478 result = fn(*args, **kwargs)
2480 # The code below prevents the output from being displayed
2481 # when using magics with decodator @output_can_be_silenced
2482 # when the last Python token in the expression is a ';'.
2483 if getattr(fn, magic.MAGIC_OUTPUT_CAN_BE_SILENCED, False):
File ~/git/jupyter-ai/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py:565, in AiMagics.ai(self, line, cell)
562 ip = get_ipython()
563 prompt = prompt.format_map(FormatDict(ip.user_ns))
--> 565 return self.run_ai_cell(args, prompt)
File ~/git/jupyter-ai/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py:499, in AiMagics.run_ai_cell(self, args, prompt)
496 provider_params["request_schema"] = args.request_schema
497 provider_params["response_path"] = args.response_path
--> 499 provider = Provider(**provider_params)
501 # generate output from model via provider
502 result = provider.generate([prompt])
File ~/git/jupyter-ai/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py:373, in SmEndpointProvider.__init__(self, *args, **kwargs)
371 request_schema = kwargs.pop('request_schema')
372 response_path = kwargs.pop('response_path')
--> 373 content_handler = JsonContentHandler(request_schema=request_schema, response_path=response_path)
374 super().__init__(*args, **kwargs, content_handler=content_handler)
File ~/git/jupyter-ai/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py:322, in JsonContentHandler.__init__(self, request_schema, response_path)
321 def __init__(self, request_schema, response_path):
--> 322 self.request_schema = json.loads(request_schema)
323 self.response_path = response_path
324 self.response_parser = parse(response_path)
File /opt/miniconda3/envs/jupyter-ai/lib/python3.10/json/__init__.py:346, in loads(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)
341 s = s.decode(detect_encoding(s), 'surrogatepass')
343 if (cls is None and object_hook is None and
344 parse_int is None and parse_float is None and
345 parse_constant is None and object_pairs_hook is None and not kw):
--> 346 return _default_decoder.decode(s)
347 if cls is None:
348 cls = JSONDecoder
File /opt/miniconda3/envs/jupyter-ai/lib/python3.10/json/decoder.py:337, in JSONDecoder.decode(self, s, _w)
332 def decode(self, s, _w=WHITESPACE.match):
333 """Return the Python representation of ``s`` (a ``str`` instance
334 containing a JSON document).
335
336 """
--> 337 obj, end = self.raw_decode(s, idx=_w(s, 0).end())
338 end = _w(s, end).end()
339 if end != len(s):
File /opt/miniconda3/envs/jupyter-ai/lib/python3.10/json/decoder.py:355, in JSONDecoder.raw_decode(self, s, idx)
353 obj, end = self.scan_once(s, idx)
354 except StopIteration as err:
--> 355 raise JSONDecodeError("Expecting value", s, err.value) from None
356 return obj, end
JSONDecodeError: Expecting value: line 1 column 1 (char 0)
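The failure can be reproduced without SageMaker. According to the traceback, it is raised by json.loads(request_schema) inside JsonContentHandler.__init__, before any request is sent. A minimal reproduction using only the standard library (the variable names here are illustrative):

```python
import json

# The value passed via --request-schema in the magic command above.
request_schema = "<prompt>"

try:
    json.loads(request_schema)
    error_message = None
except json.JSONDecodeError as err:
    # "<" cannot start any JSON value, so decoding fails at the first character.
    error_message = str(err)

print(error_message)  # Expecting value: line 1 column 1 (char 0)
```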