Repository: esm

Unable to use ESMC 6B through Forge - RuntimeError: Could not infer dtype of dict

Status: Open. atarashansky opened this issue 2 months ago • 0 comments

When running the following snippet from the README to use ESMC 6B:

from esm.sdk.forge import ESM3ForgeInferenceClient
from esm.sdk.api import ESMProtein, LogitsConfig

# Apply for forge access and get an access token
forge_client = ESM3ForgeInferenceClient(model="esmc-6b-2024-12", url="https://forge.evolutionaryscale.ai", token="<your forge token>")
protein = ESMProtein(sequence="AAAAA")
protein_tensor = forge_client.encode(protein)
logits_output = forge_client.logits(
   protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
print(logits_output.logits, logits_output.embeddings)

I get the following runtime error:

RuntimeError                              Traceback (most recent call last)
Cell In[46], line 1
----> 1 logits_output = forge_client.logits(
      2    protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
      3 )

File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/esm/sdk/forge.py:194, in ESM3ForgeInferenceClient.retry_decorator.<locals>.wrapper(instance, *args, **kwargs)
    183 retry_decorator = retry(
    184     retry=retry_if_result(retry_if_specific_error),
    185     wait=wait_exponential(
   (...)
    191     before_sleep=log_retry_attempt,
    192 )
    193 # Apply the retry decorator to the function
--> 194 return retry_decorator(func)(instance, *args, **kwargs)

File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/tenacity/__init__.py:336, in BaseRetrying.wraps.<locals>.wrapped_f(*args, **kw)
    334 copy = self.copy()
    335 wrapped_f.statistics = copy.statistics  # type: ignore[attr-defined]
--> 336 return copy(f, *args, **kw)

File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/tenacity/__init__.py:475, in Retrying.__call__(self, fn, *args, **kwargs)
    473 retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
    474 while True:
--> 475     do = self.iter(retry_state=retry_state)
    476     if isinstance(do, DoAttempt):
    477         try:

File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/tenacity/__init__.py:376, in BaseRetrying.iter(self, retry_state)
    374 result = None
    375 for action in self.iter_state.actions:
--> 376     result = action(retry_state)
    377 return result

File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/tenacity/__init__.py:398, in BaseRetrying._post_retry_check_actions.<locals>.<lambda>(rs)
    396 def _post_retry_check_actions(self, retry_state: "RetryCallState") -> None:
    397     if not (self.iter_state.is_explicit_retry or self.iter_state.retry_run_result):
--> 398         self._add_action_func(lambda rs: rs.outcome.result())
    399         return
    401     if self.after is not None:

File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/concurrent/futures/_base.py:449, in Future.result(self, timeout)
    447     raise CancelledError()
    448 elif self._state == FINISHED:
--> 449     return self.__get_result()
    451 self._condition.wait(timeout)
    453 if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:

File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/concurrent/futures/_base.py:401, in Future.__get_result(self)
    399 if self._exception:
    400     try:
--> 401         raise self._exception
    402     finally:
    403         # Break a reference cycle with the exception in self._exception
    404         self = None

File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/tenacity/__init__.py:478, in Retrying.__call__(self, fn, *args, **kwargs)
    476 if isinstance(do, DoAttempt):
    477     try:
--> 478         result = fn(*args, **kwargs)
    479     except BaseException:  # noqa: B902
    480         retry_state.set_exception(sys.exc_info())  # type: ignore[arg-type]

File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/esm/sdk/forge.py:546, in ESM3ForgeInferenceClient.logits(self, input, config)
    535         return maybe_tensor(data["logits"][track])
    536     return None
    538 output = LogitsOutput(
    539     logits=ForwardTrackData(
    540         sequence=_maybe_logits("sequence"),
    541         structure=_maybe_logits("structure"),
    542         secondary_structure=_maybe_logits("secondary_structure"),
    543         sasa=_maybe_logits("sasa"),
    544         function=_maybe_logits("function"),
    545     ),
--> 546     embeddings=maybe_tensor(data["embeddings"]),
    547     residue_annotation_logits=_maybe_logits("residue_annotation"),
    548 )
    550 return output

File /opt/jupyter-envs/generate-cross-species/atar-vscode/conda/lib/python3.11/site-packages/esm/utils/misc.py:263, in maybe_tensor(x, convert_none_to_nan)
    261     x = np.array(x, copy=False, dtype=np.float32)
    262     x = np.where(x is None, np.nan, x)
--> 263 return torch.tensor(x)

RuntimeError: Could not infer dtype of dict

Posted by atarashansky on Dec 06 '24 02:12