ai
ai copied to clipboard
useCompletion: expose reasoning to the client side
Description
Using `onChunk` in `streamText`, I can see reasoning parts streaming in on the server, but `useCompletion` does not expose them, so we cannot render them in the UI.
Perhaps I've missed something, but this currently seems to be possible only with `useChat`, and I'd rather not use a chat-based pattern for a one-shot generation.
AI SDK Version
v5.0.8
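For anyone else who needs this: for now, I've patched it in a simple, additive way using `pnpm patch`. The patch spans two packages, which is why the same `dist/` file names appear twice below. The first set adds `reasoning`/`isReasoning` state to the `useCompletion` hook. The second set teaches `callCompletionApi` to handle `reasoning-start`, `reasoning-delta`, and `reasoning-end` stream parts.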
Example:
diff --git a/dist/index.d.mts b/dist/index.d.mts
index a6a0798eb77eb46880f42164afc87f5f28c25ff1..211e16dc4987707b59bd02b1c66860fef4afcfc7 100644
--- a/dist/index.d.mts
+++ b/dist/index.d.mts
@@ -43,6 +43,10 @@ declare function useChat<UI_MESSAGE extends UIMessage = UIMessage>({ experimenta
type UseCompletionHelpers = {
/** The current completion result */
completion: string;
+ /** The current reasoning result */
+ reasoning: string;
+ /** Whether reasoning is currently being generated */
+ isReasoning: boolean;
/**
* Send a new prompt to the API endpoint and update the completion state.
*/
diff --git a/dist/index.d.ts b/dist/index.d.ts
index a6a0798eb77eb46880f42164afc87f5f28c25ff1..211e16dc4987707b59bd02b1c66860fef4afcfc7 100644
--- a/dist/index.d.ts
+++ b/dist/index.d.ts
@@ -43,6 +43,10 @@ declare function useChat<UI_MESSAGE extends UIMessage = UIMessage>({ experimenta
type UseCompletionHelpers = {
/** The current completion result */
completion: string;
+ /** The current reasoning result */
+ reasoning: string;
+ /** Whether reasoning is currently being generated */
+ isReasoning: boolean;
/**
* Send a new prompt to the API endpoint and update the completion state.
*/
diff --git a/dist/index.js b/dist/index.js
index 918815dab92fac1ae229ed411f5afc06a3102afe..deaa0a1f2b6d728b24e7e87a0c0ef1b468df7b54 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -261,12 +261,23 @@ function useCompletion({
const { data, mutate } = (0, import_swr.default)([api, completionId], null, {
fallbackData: initialCompletion
});
+ const { data: reasoningData, mutate: mutateReasoning } = (0, import_swr.default)(
+ [api, completionId, "reasoning"],
+ null,
+ { fallbackData: "" }
+ );
+ const { data: isReasoningData = false, mutate: mutateIsReasoning } = (0, import_swr.default)(
+ [completionId, "isReasoning"],
+ null
+ );
const { data: isLoading = false, mutate: mutateLoading } = (0, import_swr.default)(
[completionId, "loading"],
null
);
const [error, setError] = (0, import_react2.useState)(void 0);
const completion = data;
+ const reasoning = reasoningData;
+ const isReasoning = isReasoningData;
const [abortController, setAbortController] = (0, import_react2.useState)(null);
const extraMetadataRef = (0, import_react2.useRef)({
credentials,
@@ -297,6 +308,11 @@ function useCompletion({
(completion2) => mutate(completion2, false),
throttleWaitMs
),
+ setReasoning: throttle(
+ (reasoning2) => mutateReasoning(reasoning2, false),
+ throttleWaitMs
+ ),
+ setIsReasoning: mutateIsReasoning,
setLoading: mutateLoading,
setError,
setAbortController,
@@ -305,6 +321,8 @@ function useCompletion({
}),
[
mutate,
+ mutateReasoning,
+ mutateIsReasoning,
mutateLoading,
api,
extraMetadataRef,
@@ -352,6 +370,8 @@ function useCompletion({
);
return {
completion,
+ reasoning,
+ isReasoning,
complete,
error,
setCompletion,
diff --git a/dist/index.mjs b/dist/index.mjs
index 56fb62b49f2f29459f2592022885274e81d33cbd..4183d42c8beba01cb04718073efc2d2d7ab8edca 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -225,12 +225,23 @@ function useCompletion({
const { data, mutate } = useSWR([api, completionId], null, {
fallbackData: initialCompletion
});
+ const { data: reasoningData, mutate: mutateReasoning } = useSWR(
+ [api, completionId, "reasoning"],
+ null,
+ { fallbackData: "" }
+ );
+ const { data: isReasoningData = false, mutate: mutateIsReasoning } = useSWR(
+ [completionId, "isReasoning"],
+ null
+ );
const { data: isLoading = false, mutate: mutateLoading } = useSWR(
[completionId, "loading"],
null
);
const [error, setError] = useState(void 0);
const completion = data;
+ const reasoning = reasoningData;
+ const isReasoning = isReasoningData;
const [abortController, setAbortController] = useState(null);
const extraMetadataRef = useRef2({
credentials,
@@ -261,6 +272,11 @@ function useCompletion({
(completion2) => mutate(completion2, false),
throttleWaitMs
),
+ setReasoning: throttle(
+ (reasoning2) => mutateReasoning(reasoning2, false),
+ throttleWaitMs
+ ),
+ setIsReasoning: mutateIsReasoning,
setLoading: mutateLoading,
setError,
setAbortController,
@@ -269,6 +285,8 @@ function useCompletion({
}),
[
mutate,
+ mutateReasoning,
+ mutateIsReasoning,
mutateLoading,
api,
extraMetadataRef,
@@ -316,6 +334,8 @@ function useCompletion({
);
return {
completion,
+ reasoning,
+ isReasoning,
complete,
error,
setCompletion,
diff --git a/dist/index.d.mts b/dist/index.d.mts
index 9d2ebc6d03c742aa975b853e8003cde61fc6b8f5..a234f388c04c9dec47db02317093a66efee12bcd 100644
--- a/dist/index.d.mts
+++ b/dist/index.d.mts
@@ -3870,7 +3870,7 @@ declare function transcribe({ model, audio, providerOptions, maxRetries: maxRetr
}): Promise<TranscriptionResult>;
declare const getOriginalFetch: () => typeof fetch;
-declare function callCompletionApi({ api, prompt, credentials, headers, body, streamProtocol, setCompletion, setLoading, setError, setAbortController, onFinish, onError, fetch, }: {
+declare function callCompletionApi({ api, prompt, credentials, headers, body, streamProtocol, setCompletion, setLoading, setError, setAbortController, onFinish, onError, fetch, setReasoning, setIsReasoning, }: {
api: string;
prompt: string;
credentials: RequestCredentials | undefined;
@@ -3884,6 +3884,8 @@ declare function callCompletionApi({ api, prompt, credentials, headers, body, st
onFinish: ((prompt: string, completion: string) => void) | undefined;
onError: ((error: Error) => void) | undefined;
fetch: ReturnType<typeof getOriginalFetch> | undefined;
+ setReasoning?: (reasoning: string) => void;
+ setIsReasoning?: (isReasoning: boolean) => void;
}): Promise<string | null | undefined>;
interface UIMessageStreamWriter<UI_MESSAGE extends UIMessage = UIMessage> {
diff --git a/dist/index.d.ts b/dist/index.d.ts
index 9d2ebc6d03c742aa975b853e8003cde61fc6b8f5..a234f388c04c9dec47db02317093a66efee12bcd 100644
--- a/dist/index.d.ts
+++ b/dist/index.d.ts
@@ -3870,7 +3870,7 @@ declare function transcribe({ model, audio, providerOptions, maxRetries: maxRetr
}): Promise<TranscriptionResult>;
declare const getOriginalFetch: () => typeof fetch;
-declare function callCompletionApi({ api, prompt, credentials, headers, body, streamProtocol, setCompletion, setLoading, setError, setAbortController, onFinish, onError, fetch, }: {
+declare function callCompletionApi({ api, prompt, credentials, headers, body, streamProtocol, setCompletion, setLoading, setError, setAbortController, onFinish, onError, fetch, setReasoning, setIsReasoning, }: {
api: string;
prompt: string;
credentials: RequestCredentials | undefined;
@@ -3884,6 +3884,8 @@ declare function callCompletionApi({ api, prompt, credentials, headers, body, st
onFinish: ((prompt: string, completion: string) => void) | undefined;
onError: ((error: Error) => void) | undefined;
fetch: ReturnType<typeof getOriginalFetch> | undefined;
+ setReasoning?: (reasoning: string) => void;
+ setIsReasoning?: (isReasoning: boolean) => void;
}): Promise<string | null | undefined>;
interface UIMessageStreamWriter<UI_MESSAGE extends UIMessage = UIMessage> {
diff --git a/dist/index.js b/dist/index.js
index 9a3668f7437d25c8f216f027c0713ad305943f11..95b5b75763e85bf54adad06e86b71dc933d4ea2d 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -8882,7 +8882,9 @@ async function callCompletionApi({
setAbortController,
onFinish,
onError,
- fetch: fetch2 = getOriginalFetch()
+ fetch: fetch2 = getOriginalFetch(),
+ setReasoning,
+ setIsReasoning
}) {
var _a16;
try {
@@ -8891,6 +8893,8 @@ async function callCompletionApi({
const abortController = new AbortController();
setAbortController(abortController);
setCompletion("");
+ setReasoning?.("");
+ setIsReasoning?.(false);
const response = await fetch2(api, {
method: "POST",
body: JSON.stringify({
@@ -8919,6 +8923,7 @@ async function callCompletionApi({
throw new Error("The response body is empty.");
}
let result = "";
+ let reasoning = "";
switch (streamProtocol) {
case "text": {
await processTextStream({
@@ -8945,6 +8950,13 @@ async function callCompletionApi({
if (streamPart.type === "text-delta") {
result += streamPart.delta;
setCompletion(result);
+ } else if (streamPart.type === "reasoning-start") {
+ setIsReasoning?.(true);
+ } else if (streamPart.type === "reasoning-delta") {
+ reasoning += streamPart.delta;
+ setReasoning?.(reasoning);
+ } else if (streamPart.type === "reasoning-end") {
+ setIsReasoning?.(false);
} else if (streamPart.type === "error") {
throw new Error(streamPart.errorText);
}
@@ -8965,6 +8977,7 @@ async function callCompletionApi({
if (onFinish) {
onFinish(prompt, result);
}
+ setIsReasoning?.(false);
setAbortController(null);
return result;
} catch (err) {
@@ -8977,6 +8990,7 @@ async function callCompletionApi({
onError(err);
}
}
+ setIsReasoning?.(false);
setError(err);
} finally {
setLoading(false);
diff --git a/dist/index.mjs b/dist/index.mjs
index bd1771ed44abe4045ffbff8f583bf3cfc8f1c444..9d2712aa26d3be0f204f5603d43e06a8177f9a8e 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -8840,7 +8840,9 @@ async function callCompletionApi({
setAbortController,
onFinish,
onError,
- fetch: fetch2 = getOriginalFetch()
+ fetch: fetch2 = getOriginalFetch(),
+ setReasoning,
+ setIsReasoning
}) {
var _a16;
try {
@@ -8849,6 +8851,8 @@ async function callCompletionApi({
const abortController = new AbortController();
setAbortController(abortController);
setCompletion("");
+ setReasoning?.("");
+ setIsReasoning?.(false);
const response = await fetch2(api, {
method: "POST",
body: JSON.stringify({
@@ -8877,6 +8881,7 @@ async function callCompletionApi({
throw new Error("The response body is empty.");
}
let result = "";
+ let reasoning = "";
switch (streamProtocol) {
case "text": {
await processTextStream({
@@ -8903,6 +8908,13 @@ async function callCompletionApi({
if (streamPart.type === "text-delta") {
result += streamPart.delta;
setCompletion(result);
+ } else if (streamPart.type === "reasoning-start") {
+ setIsReasoning?.(true);
+ } else if (streamPart.type === "reasoning-delta") {
+ reasoning += streamPart.delta;
+ setReasoning?.(reasoning);
+ } else if (streamPart.type === "reasoning-end") {
+ setIsReasoning?.(false);
} else if (streamPart.type === "error") {
throw new Error(streamPart.errorText);
}
@@ -8923,6 +8935,7 @@ async function callCompletionApi({
if (onFinish) {
onFinish(prompt, result);
}
+ setIsReasoning?.(false);
setAbortController(null);
return result;
} catch (err) {
@@ -8935,6 +8948,7 @@ async function callCompletionApi({
onError(err);
}
}
+ setIsReasoning?.(false);
setError(err);
} finally {
setLoading(false);