deerflow2/frontend/src/core/threads/hooks.ts

785 lines
23 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import type { AIMessage, Message } from "@langchain/langgraph-sdk";
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
import { useStream, type UseStream } from "@langchain/langgraph-sdk/react";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { useCallback, useEffect, useRef, useState } from "react";
import { toast } from "sonner";
import type { PromptInputMessage } from "@/components/ai-elements/prompt-input";
import { getAPIClient } from "../api";
import { getBackendBaseURL } from "../config";
import { useI18n } from "../i18n/hooks";
import type { FileInMessage } from "../messages/utils";
import type { LocalSettings } from "../settings";
import { useUpdateSubtask } from "../tasks/context";
import type { UploadedFileInfo } from "../uploads";
import { uploadFiles } from "../uploads";
import type { UploadTarget } from "../uploads/api";
import type {
AgentThread,
AgentThreadContext,
AgentThreadState,
} from "./types";
/** Payload handed to `onToolEnd` for every `on_tool_end` LangChain event. */
export type ToolEndEvent = {
  /** Name of the tool that finished, copied from the event's `name`. */
  name: string;
  /** The event's raw `data`, forwarded without inspection. */
  data: unknown;
};

/** Options accepted by `useThreadStream`. */
export type ThreadStreamOptions = {
  /** Thread to attach to; `null`/`undefined` means "no thread yet". */
  threadId: string | null | undefined;
  /** User settings merged into the run context of each submission. */
  context: LocalSettings["context"];
  /** Reset the given thread id before the next submission instead of continuing it. */
  createNewSession?: boolean;
  /** Forwarded to `getAPIClient` to select the mock client. */
  isMock?: boolean;
  /** Invoked with the thread id when a stream session starts. */
  onStart?: (threadId: string) => void;
  /** Invoked for each `on_tool_end` LangChain event. */
  onToolEnd?: (event: ToolEndEvent) => void;
  /** Invoked with the final state values when a run finishes. */
  onFinish?: (state: AgentThreadState) => void;
};

/** Options accepted by `useThreadStreamLegacy`. */
export type LegacyThreadStreamOptions = {
  /** Thread to attach to; ignored while `isNewThread` is true. */
  threadId: string | null | undefined;
  /** When true, no thread id is handed to `useStream`. */
  isNewThread: boolean;
  /** Forwarded to `useStream`; defaults to `true` in the hook. */
  fetchStateHistory?: boolean;
  /** Not read by `useThreadStreamLegacy` itself. */
  useSubmitThread?: boolean;
  /** Invoked with the final state values when a run finishes. */
  onFinish?: (state: AgentThreadState) => void;
};
/** Fallback text used when no usable message can be extracted from an error. */
const DEFAULT_STREAM_ERROR_MESSAGE = "Request failed.";

/** Maximum number of nested `error` properties followed (also guards against cycles). */
const MAX_ERROR_NESTING = 4;

/**
 * Searches `error` for a non-blank message.
 *
 * Accepts a non-blank string as-is. For objects (including `Error` instances),
 * it uses a non-blank string `message`. Otherwise it descends into the object's
 * `error` property, up to `MAX_ERROR_NESTING` levels deep.
 *
 * @returns The first usable message, or `undefined` if none was found.
 */
function extractErrorMessage(
  error: unknown,
  depth: number,
): string | undefined {
  if (typeof error === "string") {
    return error.trim() ? error : undefined;
  }
  if (typeof error !== "object" || error === null) {
    return undefined;
  }
  // Reflect.get also covers Error instances, whose `message` may be a getter.
  const message: unknown = Reflect.get(error, "message");
  if (typeof message === "string" && message.trim()) {
    return message;
  }
  if (depth >= MAX_ERROR_NESTING) {
    return undefined;
  }
  // Payloads such as `{ error: "..." }`, `{ error: Error }` or
  // `{ error: { message: "..." } }` keep the useful text one level down.
  return extractErrorMessage(Reflect.get(error, "error"), depth + 1);
}

/**
 * Turns an arbitrary stream error (string, `Error`, or a plain/nested error
 * payload) into user-facing toast text.
 *
 * @param error - Whatever the stream reported as its error.
 * @returns A non-blank message, or `"Request failed."` when none can be found.
 */
function getStreamErrorMessage(error: unknown): string {
  return extractErrorMessage(error, 0) ?? DEFAULT_STREAM_ERROR_MESSAGE;
}
/**
 * Legacy stream hook: wraps `useStream` for the `lead_agent` assistant without
 * optimistic messages or upload handling.
 *
 * @param options.threadId - Thread to attach to; ignored when `isNewThread` is true.
 * @param options.isNewThread - When true, no thread id is passed to `useStream`.
 * @param options.fetchStateHistory - Forwarded to `useStream` (defaults to `true`).
 * @param options.onFinish - Called with the final state values when a run finishes.
 * @returns The `useStream` handle.
 */
export function useThreadStreamLegacy({
  threadId,
  isNewThread,
  fetchStateHistory = true,
  onFinish,
}: LegacyThreadStreamOptions): UseStream<AgentThreadState> {
  const queryClient = useQueryClient();
  const updateSubtask = useUpdateSubtask();
  const thread = useStream<AgentThreadState>({
    client: getAPIClient(),
    assistantId: "lead_agent",
    threadId: isNewThread ? undefined : threadId,
    reconnectOnMount: true,
    fetchStateHistory,
    onCustomEvent(event: unknown) {
      console.info(event);
      // Forward the latest message of a running subtask to the task context.
      if (
        typeof event === "object" &&
        event !== null &&
        "type" in event &&
        event.type === "task_running"
      ) {
        const e = event as {
          type: "task_running";
          task_id: string;
          message: AIMessage;
        };
        updateSubtask({ id: e.task_id, latestMessage: e.message });
      }
    },
    onFinish(state) {
      onFinish?.(state.values);
      // Patch this thread's title in every cached thread-search result.
      queryClient.setQueriesData(
        {
          queryKey: ["threads", "search"],
          exact: false,
        },
        // A matching query may have no data yet (e.g. still pending); leave it
        // as-is instead of crashing on `undefined.map`.
        (oldData: Array<AgentThread> | undefined) =>
          oldData?.map((t) =>
            t.thread_id === threadId
              ? {
                  ...t,
                  values: {
                    ...t.values,
                    title: state.values.title,
                  },
                }
              : t,
          ),
      );
    },
  });
  return thread as UseStream<AgentThreadState>;
}
/**
 * Wraps `useStream` for the `lead_agent` assistant. On top of it, the hook adds:
 * - optimistic messages,
 * - attachment uploads,
 * - thread-list cache updates.
 *
 * @param options.threadId - Thread to attach to; `null`/`undefined` resets the
 *   hook for a new thread.
 * @param options.context - Settings merged into the run context of every submission.
 * @param options.createNewSession - When true, `sendMessage` deletes the given
 *   thread first (only if it has local messages) and leaves `onStart` to `onCreated`.
 * @param options.isMock - Forwarded to `getAPIClient`.
 * @param options.onStart - Called at most once until `threadId` becomes null again.
 * @param options.onFinish - Called with the final state values when a run finishes.
 * @param options.onToolEnd - Called for each `on_tool_end` LangChain event.
 * @returns A `const` tuple `[thread, sendMessage, isUploading]`. `thread` has any
 *   pending optimistic messages appended to `messages`.
 */
export function useThreadStream({
  threadId,
  context,
  createNewSession = false,
  isMock,
  onStart,
  onFinish,
  onToolEnd,
}: ThreadStreamOptions) {
  const { t } = useI18n();
  // Track the thread ID that is currently streaming to handle thread changes during streaming
  // NOTE(review): this state only follows the `threadId` prop when it becomes
  // null/undefined (effect below) or when `onCreated` fires. A switch from one
  // non-null id to another does not re-point `useStream` — presumably callers
  // remount the hook per thread; verify.
  const [onStreamThreadId, setOnStreamThreadId] = useState(() => threadId);
  // Ref to track current thread ID across async callbacks without causing re-renders,
  // and to allow access to the current thread id in onUpdateEvent
  const threadIdRef = useRef<string | null>(threadId ?? null);
  // Ensures `onStart` fires at most once until `threadId` is reset to null.
  const startedRef = useRef(false);
  const listeners = useRef({
    onStart,
    onFinish,
    onToolEnd,
  });
  // Keep listeners ref updated with latest callbacks
  useEffect(() => {
    listeners.current = { onStart, onFinish, onToolEnd };
  }, [onStart, onFinish, onToolEnd]);
  useEffect(() => {
    const normalizedThreadId = threadId ?? null;
    if (!normalizedThreadId) {
      // Just reset for new thread creation when threadId becomes null/undefined
      startedRef.current = false;
      setOnStreamThreadId(normalizedThreadId);
    }
    threadIdRef.current = normalizedThreadId;
  }, [threadId]);
  // Calls the latest `onStart` listener, but only once per session.
  const _handleOnStart = useCallback((id: string) => {
    if (!startedRef.current) {
      listeners.current.onStart?.(id);
      startedRef.current = true;
    }
  }, []);
  // Records the thread id the stream is running on and fires `onStart`.
  const handleStreamStart = useCallback(
    (_threadId: string) => {
      threadIdRef.current = _threadId;
      _handleOnStart(_threadId);
    },
    [_handleOnStart],
  );
  const queryClient = useQueryClient();
  const updateSubtask = useUpdateSubtask();
  const apiClient = getAPIClient(isMock);
  const thread = useStream<AgentThreadState>({
    client: apiClient,
    assistantId: "lead_agent",
    threadId: onStreamThreadId,
    reconnectOnMount: true,
    // Request at most one state-history entry.
    fetchStateHistory: { limit: 1 },
    onCreated(meta) {
      // Adopt the thread id reported when the run is created.
      handleStreamStart(meta.thread_id);
      setOnStreamThreadId(meta.thread_id);
    },
    // Only subscribe to LangChain events when a tool-end listener was supplied.
    onLangChainEvent:
      onToolEnd == null
        ? undefined
        : (event) => {
            if (event.event === "on_tool_end") {
              listeners.current.onToolEnd?.({
                name: event.name,
                data: event.data,
              });
            }
          },
    onUpdateEvent(data) {
      // Each value of `data` is treated as a partial state update (presumably
      // keyed by graph node name — verify against the SDK).
      const updates: Array<Partial<AgentThreadState> | null> = Object.values(
        data || {},
      );
      for (const update of updates) {
        // Mirror a non-empty `title` into every cached thread-search result.
        if (update && "title" in update && update.title) {
          void queryClient.setQueriesData(
            {
              queryKey: ["threads", "search"],
              exact: false,
            },
            (oldData: Array<AgentThread> | undefined) => {
              return oldData?.map((t) => {
                if (t.thread_id === threadIdRef.current) {
                  return {
                    ...t,
                    values: {
                      ...t.values,
                      title: update.title,
                    },
                  };
                }
                return t;
              });
            },
          );
        }
      }
    },
    onCustomEvent(event: unknown) {
      // `task_running` events carry the latest message of a running subtask.
      if (
        typeof event === "object" &&
        event !== null &&
        "type" in event &&
        event.type === "task_running"
      ) {
        const e = event as {
          type: "task_running";
          task_id: string;
          message: AIMessage;
        };
        updateSubtask({ id: e.task_id, latestMessage: e.message });
      }
    },
    onError(error) {
      // `setOptimisticMessages` is declared further down. This relies on
      // `onError` only being invoked after this render completed (presumably
      // always true for stream callbacks).
      setOptimisticMessages([]);
      toast.error(getStreamErrorMessage(error));
    },
    onFinish(state) {
      listeners.current.onFinish?.(state.values);
      void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
    },
  });
  // Optimistic messages shown before the server stream responds
  const [optimisticMessages, setOptimisticMessages] = useState<Message[]>([]);
  const [isUploading, setIsUploading] = useState(false);
  // Drops re-entrant `sendMessage` calls while one is still in progress.
  const sendInFlightRef = useRef(false);
  // Track message count before sending so we know when server has responded
  const prevMsgCountRef = useRef(thread.messages.length);
  // Clear optimistic when server messages arrive (count increases)
  // NOTE(review): after a `createNewSession` reset, the server-side message list
  // presumably restarts from zero. The count may then stay below the captured
  // value, so the optimistic messages could linger — confirm.
  useEffect(() => {
    if (
      optimisticMessages.length > 0 &&
      thread.messages.length > prevMsgCountRef.current
    ) {
      setOptimisticMessages([]);
    }
  }, [thread.messages.length, optimisticMessages.length]);
  /**
   * Sends a prompt: shows optimistic messages, uploads any attachments to the
   * thread, then submits the run.
   *
   * @param threadId - Explicit thread id; falls back to the current thread id.
   * @param message - Prompt text and optional file attachments.
   * @param extraContext - Extra run-context entries. Keys from `context` and the
   *   derived mode flags override them.
   * @throws Rethrows upload/submit errors after clearing optimistic state.
   */
  const sendMessage = useCallback(
    async (
      threadId: string | undefined,
      message: PromptInputMessage,
      extraContext?: Record<string, unknown>,
    ) => {
      if (sendInFlightRef.current) {
        return;
      }
      sendInFlightRef.current = true;
      const text = message.text.trim();
      const resolvedThreadId =
        threadId ?? threadIdRef.current ?? undefined;
      // "new" is presumably a route placeholder rather than a real thread id.
      if (resolvedThreadId === "new") {
        toast.error("Invalid thread id 'new'. Please refresh and retry.");
        sendInFlightRef.current = false;
        return;
      }
      // Capture current count before showing optimistic messages
      prevMsgCountRef.current = thread.messages.length;
      // Build optimistic files list with uploading status
      const optimisticFiles: FileInMessage[] = (message.files ?? []).map(
        (f) => ({
          filename: f.filename ?? "",
          size: 0,
          status: "uploading" as const,
        }),
      );
      // Create optimistic human message (shown immediately)
      const optimisticHumanMsg: Message = {
        type: "human",
        id: `opt-human-${Date.now()}`,
        content: text ? [{ type: "text", text }] : "",
        additional_kwargs:
          optimisticFiles.length > 0 ? { files: optimisticFiles } : {},
      };
      const newOptimistic: Message[] = [optimisticHumanMsg];
      if (optimisticFiles.length > 0) {
        // Mock AI message while files are being uploaded
        newOptimistic.push({
          type: "ai",
          id: `opt-ai-${Date.now()}`,
          content: t.uploads.uploadingFiles,
          additional_kwargs: { element: "task" },
        });
      }
      setOptimisticMessages(newOptimistic);
      // For "new chat with prefilled thread_id" flows, calling onStart before
      // submit can trigger route switch too early, which causes the new page to
      // fetch history before the thread/run is actually created.
      // Let useStream.onCreated -> handleStreamStart drive onStart instead.
      if (resolvedThreadId && !createNewSession) {
        _handleOnStart(resolvedThreadId);
      }
      let uploadedFileInfo: UploadedFileInfo[] = [];
      try {
        // In new-session mode, only reset the old thread when there is already
        // local message history. For a brand-new thread_id, skip the extra
        // DELETE /threads/{id} request (it would usually return 404).
        if (createNewSession && resolvedThreadId && thread.messages.length > 0) {
          await apiClient.threads.delete(resolvedThreadId).catch(() => undefined);
        }
        // Upload files first if any
        if (message.files && message.files.length > 0) {
          setIsUploading(true);
          try {
            // Convert FileUIPart to File objects by fetching blob URLs
            const filePromises = message.files.map(async (fileUIPart) => {
              if (fileUIPart.url && fileUIPart.filename) {
                try {
                  // Fetch the blob URL to get the file data
                  const response = await fetch(fileUIPart.url);
                  const blob = await response.blob();
                  // Create a File object from the blob
                  return new File([blob], fileUIPart.filename, {
                    type: fileUIPart.mediaType || blob.type,
                  });
                } catch (error) {
                  console.error(
                    `Failed to fetch file ${fileUIPart.filename}:`,
                    error,
                  );
                  return null;
                }
              }
              return null;
            });
            const conversionResults = await Promise.all(filePromises);
            const files = conversionResults.filter(
              (file): file is File => file !== null,
            );
            // Any attachment that could not be read aborts the whole send.
            const failedConversions = conversionResults.length - files.length;
            if (failedConversions > 0) {
              throw new Error(
                `Failed to prepare ${failedConversions} attachment(s) for upload. Please retry.`,
              );
            }
            if (!resolvedThreadId) {
              throw new Error("Thread is not ready for file upload.");
            }
            if (files.length > 0) {
              const uploadResponse = await uploadFiles(resolvedThreadId, files);
              uploadedFileInfo = uploadResponse.files;
              // Update optimistic human message with uploaded status + paths
              const uploadedFiles: FileInMessage[] = uploadedFileInfo.map(
                (info) => ({
                  filename: info.filename,
                  size: info.size,
                  path: info.virtual_path,
                  status: "uploaded" as const,
                }),
              );
              setOptimisticMessages((messages) => {
                if (messages.length > 1 && messages[0]) {
                  const humanMessage: Message = messages[0];
                  return [
                    {
                      ...humanMessage,
                      additional_kwargs: { files: uploadedFiles },
                    },
                    ...messages.slice(1),
                  ];
                }
                return messages;
              });
            }
          } catch (error) {
            console.error("Failed to upload files:", error);
            const errorMessage =
              error instanceof Error
                ? error.message
                : "Failed to upload files.";
            toast.error(errorMessage);
            setOptimisticMessages([]);
            throw error;
          } finally {
            setIsUploading(false);
          }
        }
        // Build files metadata for submission (included in additional_kwargs)
        const filesForSubmit: FileInMessage[] = uploadedFileInfo.map(
          (info) => ({
            filename: info.filename,
            size: info.size,
            path: info.virtual_path,
            status: "uploaded" as const,
          }),
        );
        await thread.submit(
          {
            messages: [
              {
                type: "human",
                content: [
                  {
                    type: "text",
                    text,
                  },
                ],
                additional_kwargs:
                  filesForSubmit.length > 0 ? { files: filesForSubmit } : {},
              },
            ],
          },
          {
            threadId: resolvedThreadId,
            streamSubgraphs: true,
            streamResumable: true,
            config: {
              recursion_limit: 1000,
            },
            // Run context: extraContext < user context < flags derived from `context.mode`.
            context: {
              ...extraContext,
              ...context,
              thinking_enabled: context.mode !== "flash",
              is_plan_mode: context.mode === "pro" || context.mode === "ultra",
              subagent_enabled: context.mode === "ultra",
              // An explicit reasoning_effort wins; otherwise it is derived from the mode.
              reasoning_effort:
                context.reasoning_effort ??
                (context.mode === "ultra"
                  ? "high"
                  : context.mode === "pro"
                    ? "medium"
                    : context.mode === "thinking"
                      ? "low"
                      : undefined),
              ...(resolvedThreadId ? { thread_id: resolvedThreadId } : {}),
            },
          },
        );
        void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
      } catch (error) {
        setOptimisticMessages([]);
        setIsUploading(false);
        throw error;
      } finally {
        sendInFlightRef.current = false;
      }
    },
    [
      thread,
      _handleOnStart,
      t.uploads.uploadingFiles,
      context,
      queryClient,
      apiClient,
      createNewSession,
    ],
  );
  // Merge thread with optimistic messages for display
  const mergedThread =
    optimisticMessages.length > 0
      ? ({
          ...thread,
          messages: [...thread.messages, ...optimisticMessages],
        } as typeof thread)
      : thread;
  return [
    mergedThread as UseStream<AgentThreadState>,
    sendMessage,
    isUploading,
  ] as const;
}
/**
 * Returns a callback that submits a prompt (text and/or files) through an
 * existing `useStream` handle, then refreshes the thread list.
 *
 * When `createNewSession` is set and `threadId` is known, the thread is deleted
 * (delete errors ignored) and re-created before submitting. Attachments are
 * uploaded best-effort. A failure is logged and shown as a toast, and the text
 * message is still submitted.
 *
 * @param options.createNewSession - Reset and re-create `threadId` before submitting.
 * @param options.threadId - Target thread; `"new"` is rejected with a toast.
 * @param options.thread - `useStream` handle used to submit the run.
 * @param options.threadContext - Run context; `thread_id` is added when known.
 * @param options.uploadTarget - Forwarded to `uploadFiles`.
 * @param options.afterSubmit - Called after the submit call resolves.
 * @returns The memoized submit callback.
 */
export function useSubmitThread({
  threadId,
  thread,
  threadContext,
  createNewSession,
  uploadTarget,
  afterSubmit,
}: {
  createNewSession: boolean;
  threadId: string | null | undefined;
  thread: UseStream<AgentThreadState>;
  threadContext: Omit<AgentThreadContext, "thread_id">;
  uploadTarget?: UploadTarget;
  afterSubmit?: () => void;
}) {
  const queryClient = useQueryClient();
  const apiClient = getAPIClient();
  const callback = useCallback(
    async (message: PromptInputMessage) => {
      if (threadId === "new") {
        toast.error("Invalid thread id 'new'. Please refresh and retry.");
        return;
      }
      const text = message.text.trim();
      const hasFiles = !!(message.files && message.files.length > 0);
      // Nothing to send.
      if (!text && !hasFiles) {
        return;
      }
      if (createNewSession && threadId) {
        await apiClient.threads.delete(threadId).catch(() => undefined);
        await apiClient.threads.create({
          threadId,
          ifExists: "do_nothing",
        });
      }
      if (message.files && message.files.length > 0) {
        try {
          // Read each attachment's URL into a File; unreadable ones become null.
          const filePromises = message.files.map(async (fileUIPart) => {
            if (fileUIPart.url && fileUIPart.filename) {
              try {
                const response = await fetch(fileUIPart.url);
                const blob = await response.blob();
                return new File([blob], fileUIPart.filename, {
                  type: fileUIPart.mediaType || blob.type,
                });
              } catch (error) {
                console.error(
                  `Failed to fetch file ${fileUIPart.filename}:`,
                  error,
                );
                return null;
              }
            }
            return null;
          });
          const files = (await Promise.all(filePromises)).filter(
            (file): file is File => file !== null,
          );
          if (files.length > 0 && threadId) {
            await uploadFiles(threadId, files, { target: uploadTarget });
          }
        } catch (error) {
          // Best-effort: the text is still submitted below, but the user must
          // be told that the attachments were not uploaded.
          console.error("Failed to upload files:", error);
          toast.error(
            error instanceof Error ? error.message : "Failed to upload files.",
          );
        }
      }
      await thread.submit(
        {
          messages: [
            {
              type: "human",
              content: [
                {
                  type: "text",
                  text,
                },
              ],
            },
          ] as Message[],
        },
        {
          // Normalize a missing id to `undefined` instead of forcing it with `!`,
          // which could forward `null` at runtime.
          threadId: createNewSession ? (threadId ?? undefined) : undefined,
          streamSubgraphs: true,
          streamResumable: true,
          streamMode: ["values", "messages-tuple", "custom"],
          config: {
            recursion_limit: 1000,
          },
          context: {
            ...threadContext,
            ...(threadId ? { thread_id: threadId } : {}),
          },
        },
      );
      void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
      afterSubmit?.();
    },
    [
      thread,
      createNewSession,
      threadId,
      threadContext,
      uploadTarget,
      queryClient,
      apiClient,
      afterSubmit,
    ],
  );
  return callback;
}
/**
 * Queries the thread list, paging through `threads.search` transparently.
 *
 * - `limit <= 0`: a single search call with `params` unchanged.
 * - positive `limit`: pages of at most 50 until `limit` threads are collected
 *   or a short page signals the end.
 * - no `limit`: pages of 50 until a short page is returned.
 *
 * Paging starts at `params.offset` (default 0).
 */
export function useThreads(
  params: Parameters<ThreadsClient["search"]>[0] = {
    limit: 50,
    sortBy: "updated_at",
    sortOrder: "desc",
    select: ["thread_id", "updated_at", "values"],
  },
) {
  const apiClient = getAPIClient();
  return useQuery<AgentThread[]>({
    queryKey: ["threads", "search", params],
    queryFn: async () => {
      const PAGE_CAP = 50;
      const { limit } = params;
      // Non-positive limits keep their original meaning: one untouched call.
      if (limit !== undefined && limit <= 0) {
        const single = await apiClient.threads.search<AgentThreadState>(params);
        return single as AgentThread[];
      }
      const pageSize =
        typeof limit === "number" && limit > 0
          ? Math.min(PAGE_CAP, limit)
          : PAGE_CAP;
      const collected: AgentThread[] = [];
      let cursor = params.offset ?? 0;
      for (;;) {
        // Infinity when unbounded, so it never ends the loop or shrinks a page.
        const remaining =
          typeof limit === "number" ? limit - collected.length : Infinity;
        if (remaining <= 0) {
          break;
        }
        const batchSize = Math.min(pageSize, remaining);
        const page = (await apiClient.threads.search<AgentThreadState>({
          ...params,
          limit: batchSize,
          offset: cursor,
        })) as AgentThread[];
        collected.push(...page);
        // A short page means the server has nothing more to return.
        if (page.length < batchSize) {
          break;
        }
        cursor += page.length;
      }
      return collected;
    },
    refetchOnWindowFocus: false,
  });
}
/**
 * Mutation that deletes a thread in two steps:
 * 1. Deletes it through the LangGraph API client.
 * 2. Deletes the backend's local data for it via `DELETE /api/threads/{id}`.
 *
 * On success the thread is removed from every cached thread-search result. The
 * search queries are invalidated once the mutation settles, whatever the outcome.
 *
 * @throws Error with the backend's `detail` (or a generic message) when the
 *   local-data delete responds with a non-2xx status.
 */
export function useDeleteThread() {
  const queryClient = useQueryClient();
  const apiClient = getAPIClient();
  return useMutation({
    mutationFn: async ({ threadId }: { threadId: string }) => {
      await apiClient.threads.delete(threadId);
      const response = await fetch(
        `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}`,
        {
          method: "DELETE",
        },
      );
      if (!response.ok) {
        // The body may be absent, non-JSON, `null`, or have a non-string
        // `detail` (e.g. a validation-error list). Only a non-blank string is
        // used as the message.
        const body: unknown = await response.json().catch(() => null);
        const detail =
          typeof body === "object" && body !== null
            ? Reflect.get(body, "detail")
            : undefined;
        throw new Error(
          typeof detail === "string" && detail.trim()
            ? detail
            : "Failed to delete local thread data.",
        );
      }
    },
    onSuccess(_, { threadId }) {
      queryClient.setQueriesData(
        {
          queryKey: ["threads", "search"],
          exact: false,
        },
        (oldData: Array<AgentThread> | undefined) => {
          if (oldData == null) {
            return oldData;
          }
          return oldData.filter((t) => t.thread_id !== threadId);
        },
      );
    },
    onSettled() {
      void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
    },
  });
}
/**
 * Mutation that renames a thread by writing `title` into its state values. On
 * success it patches the title in every cached thread-search result.
 */
export function useRenameThread() {
  const queryClient = useQueryClient();
  const apiClient = getAPIClient();
  return useMutation({
    mutationFn: async ({
      threadId,
      title,
    }: {
      threadId: string;
      title: string;
    }) => {
      await apiClient.threads.updateState(threadId, {
        values: { title },
      });
    },
    onSuccess(_, { threadId, title }) {
      queryClient.setQueriesData(
        {
          queryKey: ["threads", "search"],
          exact: false,
        },
        // A matching query may have no data yet (e.g. still pending); leave it
        // untouched instead of crashing on `undefined.map`.
        (oldData: Array<AgentThread> | undefined) =>
          oldData?.map((t) =>
            t.thread_id === threadId
              ? {
                  ...t,
                  values: {
                    ...t.values,
                    title,
                  },
                }
              : t,
          ),
      );
    },
  });
}