esm
Unable to use ESMC 6B through Forge — RuntimeError: Could not infer dtype of dict
When I run the following snippet from the README to use ESMC 6B through Forge:
from esm.sdk.forge import ESM3ForgeInferenceClient
from esm.sdk.api import ESMProtein, LogitsConfig
# Apply for forge access and get an access token
forge_client = ESM3ForgeInferenceClient(model="esmc-6b-2024-12", url="https://forge.evolutionaryscale.ai", token="<your forge token>")
protein = ESMProtein(sequence="AAAAA")
protein_tensor = forge_client.encode(protein)
logits_output = forge_client.logits(
protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
print(logits_output.logits, logits_output.embeddings)
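I get the following runtime error. It is raised while building the `LogitsOutput`, at `embeddings=maybe_tensor(data["embeddings"])` in `esm/sdk/forge.py`. Inside `maybe_tensor`, `torch.tensor(x)` fails with "Could not infer dtype of dict". This suggests the `embeddings` field returned by the Forge API is a dict rather than a nested list of numbers: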
RuntimeError Traceback (most recent call last)
Cell In[46], line 1
----> 1 logits_output = forge_client.logits(
2 protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
3 )
File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/esm/sdk/forge.py:194, in ESM3ForgeInferenceClient.retry_decorator.<locals>.wrapper(instance, *args, **kwargs)
183 retry_decorator = retry(
184 retry=retry_if_result(retry_if_specific_error),
185 wait=wait_exponential(
(...)
191 before_sleep=log_retry_attempt,
192 )
193 # Apply the retry decorator to the function
--> 194 return retry_decorator(func)(instance, *args, **kwargs)
File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/tenacity/__init__.py:336, in BaseRetrying.wraps.<locals>.wrapped_f(*args, **kw)
334 copy = self.copy()
335 wrapped_f.statistics = copy.statistics # type: ignore[attr-defined]
--> 336 return copy(f, *args, **kw)
File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/tenacity/__init__.py:475, in Retrying.__call__(self, fn, *args, **kwargs)
473 retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
474 while True:
--> 475 do = self.iter(retry_state=retry_state)
476 if isinstance(do, DoAttempt):
477 try:
File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/tenacity/__init__.py:376, in BaseRetrying.iter(self, retry_state)
374 result = None
375 for action in self.iter_state.actions:
--> 376 result = action(retry_state)
377 return result
File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/tenacity/__init__.py:398, in BaseRetrying._post_retry_check_actions.<locals>.<lambda>(rs)
396 def _post_retry_check_actions(self, retry_state: "RetryCallState") -> None:
397 if not (self.iter_state.is_explicit_retry or self.iter_state.retry_run_result):
--> 398 self._add_action_func(lambda rs: rs.outcome.result())
399 return
401 if self.after is not None:
File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/concurrent/futures/_base.py:449, in Future.result(self, timeout)
447 raise CancelledError()
448 elif self._state == FINISHED:
--> 449 return self.__get_result()
451 self._condition.wait(timeout)
453 if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/concurrent/futures/_base.py:401, in Future.__get_result(self)
399 if self._exception:
400 try:
--> 401 raise self._exception
402 finally:
403 # Break a reference cycle with the exception in self._exception
404 self = None
File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/tenacity/__init__.py:478, in Retrying.__call__(self, fn, *args, **kwargs)
476 if isinstance(do, DoAttempt):
477 try:
--> 478 result = fn(*args, **kwargs)
479 except BaseException: # noqa: B902
480 retry_state.set_exception(sys.exc_info()) # type: ignore[arg-type]
File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/esm/sdk/forge.py:546, in ESM3ForgeInferenceClient.logits(self, input, config)
535 return maybe_tensor(data["logits"][track])
536 return None
538 output = LogitsOutput(
539 logits=ForwardTrackData(
540 sequence=_maybe_logits("sequence"),
541 structure=_maybe_logits("structure"),
542 secondary_structure=_maybe_logits("secondary_structure"),
543 sasa=_maybe_logits("sasa"),
544 function=_maybe_logits("function"),
545 ),
--> 546 embeddings=maybe_tensor(data["embeddings"]),
547 residue_annotation_logits=_maybe_logits("residue_annotation"),
548 )
550 return output
File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/esm/utils/misc.py:263, in maybe_tensor(x, convert_none_to_nan)
261 x = np.array(x, copy=False, dtype=np.float32)
262 x = np.where(x is None, np.nan, x)
--> 263 return torch.tensor(x)
RuntimeError: Could not infer dtype of dict