deerflow2/frontend/src/core/threads/hooks.ts

922 lines
26 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import type { AIMessage, Message } from "@langchain/langgraph-sdk";
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
import { useStream, type UseStream } from "@langchain/langgraph-sdk/react";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { useCallback, useEffect, useRef, useState } from "react";
import { toast } from "sonner";
import type {
PromptInputMessage,
} from "@/components/ai-elements/prompt-input";
import { getAPIClient } from "../api";
import { getBackendBaseURL } from "../config";
import { useI18n } from "../i18n/hooks";
import type { FileInMessage } from "../messages/utils";
import type { LocalSettings } from "../settings";
import { useUpdateSubtask } from "../tasks/context";
import type { UploadedFileInfo } from "../uploads";
import { uploadFiles } from "../uploads";
import type { UploadTarget } from "../uploads/api";
import { buildFilesForSubmit } from "./submit-files";
import type {
AgentThread,
AgentThreadContext,
AgentThreadState,
} from "./types";
/** Payload forwarded to `onToolEnd` for each `on_tool_end` LangChain event. */
export type ToolEndEvent = {
// Tool name as reported by the LangChain event (`event.name`).
name: string;
// Raw event payload (`event.data`); its shape depends on the tool.
data: unknown;
};
/** Options accepted by `useThreadStream`. */
export type ThreadStreamOptions = {
// Thread to stream; `null`/`undefined` means no thread exists yet.
threadId: string | null | undefined;
// Settings context spread into the run context on submit; its `mode` also
// derives the thinking / plan-mode / subagent / reasoning-effort flags.
context: LocalSettings["context"];
// When true, sendMessage deletes the existing thread first (only if it has
// local messages) and defers onStart until useStream reports the created thread.
createNewSession?: boolean;
// Forwarded to getAPIClient (presumably selects a mock API client — confirm).
isMock?: boolean;
// Called once with the thread id when streaming starts; re-armed when the
// threadId option becomes empty.
onStart?: (threadId: string) => void;
// Called with the final thread state values when a run finishes.
onFinish?: (state: AgentThreadState) => void;
// Called for each `on_tool_end` event; when omitted, LangChain events are not subscribed to.
onToolEnd?: (event: ToolEndEvent) => void;
};
/** Options accepted by `useThreadStreamLegacy`. */
export type LegacyThreadStreamOptions = {
// When true, no thread id is passed to useStream (presumably so the first submit creates one).
isNewThread: boolean;
threadId: string | null | undefined;
// Forwarded to useStream; useThreadStreamLegacy defaults it to true.
fetchStateHistory?: boolean;
// Called with the final thread state values when a run finishes.
onFinish?: (state: AgentThreadState) => void;
// NOTE(review): not read by useThreadStreamLegacy; confirm whether any caller still needs it.
useSubmitThread?: boolean;
};
// Returned by getStreamErrorMessage when no message can be extracted from an error.
const STREAM_ERROR_FALLBACK_MESSAGE = "Request failed.";
// User-facing toast text for stream errors (Chinese for "Some error occurred.").
const STREAM_ERROR_TOAST_MESSAGE = "出现了某些错误。";
// Identical error messages within this many milliseconds produce only one toast.
const STREAM_ERROR_TOAST_DEDUPE_WINDOW_MS = 2000;
// Case-insensitive patterns that mark an error as a cancellation (Stop pressed,
// request aborted) rather than a real failure. They are matched against both the
// extracted message and the error's `name`. `cancellederror` is listed on its own
// because `\bcancelled\b` cannot match inside "CancelledError" (no word boundary
// between "cancelled" and "Error").
const STREAM_CANCEL_PATTERNS = [
/\bcancellederror\b/i,
/\bcancelled\b/i,
/\bcanceled\b/i,
/\babort(?:ed|error)?\b/i,
];
/**
 * Extracts a trimmed, non-blank message from a string or an `Error`.
 *
 * @param value - Any value that might carry a human-readable message.
 * @returns The trimmed text, or `null` when `value` is neither a non-blank
 *   string nor an `Error` whose `message` is non-blank.
 */
function readMessageCandidate(value: unknown): string | null {
  let source: string;
  if (typeof value === "string") {
    source = value;
  } else if (value instanceof Error) {
    source = value.message;
  } else {
    return null;
  }
  const trimmed = source.trim();
  return trimmed ? trimmed : null;
}
/**
 * Best-effort extraction of a human-readable message from an arbitrary
 * stream/transport error.
 *
 * Non-blank strings and `Error` messages are used directly. Otherwise the
 * value is walked breadth-first. At each object the `message`, `detail`,
 * `error` and `cause` keys are tried first. `cause` is a non-enumerable own
 * property on native `Error`s, so `Object.values` alone would never reach it.
 * The walk then descends into array elements and nested object values.
 *
 * @param error - Whatever the stream rejected with / reported.
 * @returns The first non-blank message found, or
 *   `STREAM_ERROR_FALLBACK_MESSAGE` when none is found.
 */
function getStreamErrorMessage(error: unknown): string {
  const preferredKeys = ["message", "detail", "error", "cause"] as const;
  // Upper bound on distinct objects inspected, so a pathological error graph
  // (e.g. one referencing large SDK/DOM objects) cannot stall the UI thread.
  const maxVisitedObjects = 200;
  const visited = new Set<object>();
  const queue: unknown[] = [error];
  // Index-based dequeue: Array#shift is O(n) per call.
  for (let head = 0; head < queue.length; head += 1) {
    const current = queue[head];
    if (current == null) {
      continue;
    }
    const message = readMessageCandidate(current);
    if (message) {
      return message;
    }
    if (typeof current !== "object" || visited.has(current)) {
      continue;
    }
    if (visited.size >= maxVisitedObjects) {
      break;
    }
    visited.add(current);
    for (const key of preferredKeys) {
      const candidate: unknown = Reflect.get(current, key);
      const parsed = readMessageCandidate(candidate);
      if (parsed) {
        return parsed;
      }
      if (candidate && typeof candidate === "object") {
        queue.push(candidate);
      }
    }
    if (Array.isArray(current)) {
      // Array elements may themselves be plain strings, so queue them all.
      queue.push(...current);
      continue;
    }
    for (const value of Object.values(current)) {
      if (value && typeof value === "object") {
        queue.push(value);
      }
    }
  }
  return STREAM_ERROR_FALLBACK_MESSAGE;
}
/**
 * Decides whether a stream error is merely a cancellation (Stop pressed,
 * request aborted) rather than a failure that should be reported.
 *
 * @param error - The raw error; its `name` is also inspected when it is a string.
 * @param message - The message already extracted from `error`.
 * @returns `true` if the message or the error name matches any of
 *   `STREAM_CANCEL_PATTERNS`.
 */
function isStreamCancellation(error: unknown, message: string): boolean {
  let errorName = "";
  if (typeof error === "object" && error !== null && "name" in error) {
    const rawName: unknown = Reflect.get(error, "name");
    if (typeof rawName === "string") {
      errorName = rawName;
    }
  }
  const looksCancelled = (text: string): boolean =>
    STREAM_CANCEL_PATTERNS.some((pattern) => pattern.test(text));
  return looksCancelled(message) || looksCancelled(errorName);
}
/**
 * Canonicalizes a thread id coming from routes or props.
 *
 * @param value - Raw thread id, possibly padded, empty, or the `"new"` placeholder.
 * @returns The trimmed id, or `undefined` for `null`/`undefined`, blank
 *   strings, and the `"new"` route placeholder.
 */
function normalizeThreadId(
  value: string | null | undefined,
): string | undefined {
  const trimmed = value ? value.trim() : "";
  const isPlaceholder = trimmed === "" || trimmed === "new";
  return isPlaceholder ? undefined : trimmed;
}
/**
 * Legacy thread stream: a thin wrapper around `useStream` for the
 * `lead_agent` assistant, without optimistic messages or upload handling.
 *
 * - `task_running` custom events are forwarded to the subtask context.
 * - When a run finishes, `onFinish` receives the final state values and the
 *   thread's title is patched into every cached `["threads", "search"]` result.
 *
 * @returns The `useStream` handle for the thread.
 */
export function useThreadStreamLegacy({
  threadId,
  isNewThread,
  fetchStateHistory = true,
  onFinish,
}: LegacyThreadStreamOptions): UseStream<AgentThreadState> {
  const queryClient = useQueryClient();
  const updateSubtask = useUpdateSubtask();
  const thread = useStream<AgentThreadState>({
    client: getAPIClient(),
    assistantId: "lead_agent",
    // No id is passed for a new thread.
    threadId: isNewThread ? undefined : threadId,
    reconnectOnMount: true,
    fetchStateHistory,
    onCustomEvent(event: unknown) {
      console.info(event);
      if (
        typeof event === "object" &&
        event !== null &&
        "type" in event &&
        event.type === "task_running"
      ) {
        const e = event as {
          type: "task_running";
          task_id: string;
          message: AIMessage;
        };
        updateSubtask({ id: e.task_id, latestMessage: e.message });
      }
    },
    onFinish(state) {
      onFinish?.(state.values);
      queryClient.setQueriesData(
        {
          queryKey: ["threads", "search"],
          exact: false,
        },
        // A matching query may have no data yet (pending/errored); returning
        // undefined leaves it untouched instead of crashing on `.map`.
        (oldData: Array<AgentThread> | undefined) =>
          oldData?.map((t) =>
            t.thread_id === threadId
              ? {
                  ...t,
                  values: {
                    ...t.values,
                    title: state.values.title,
                  },
                }
              : t,
          ),
      );
    },
  });
  return thread as UseStream<AgentThreadState>;
}
/**
 * Streams a `lead_agent` thread via `useStream` and wraps submission with
 * optimistic UI, attachment upload, and de-duplicated error toasts.
 *
 * @param options - See `ThreadStreamOptions`.
 * @returns A readonly tuple `[thread, sendMessage, isUploading]`:
 *   - `thread`: the `useStream` handle; while optimistic messages exist they
 *     are appended to `thread.messages`.
 *   - `sendMessage(threadId, message, extraContext?)`: submits a prompt.
 *     Calls made while a previous send is still in flight are ignored.
 *     Upload/submit errors clear the optimistic messages and are rethrown.
 *   - `isUploading`: true while attachments are being uploaded.
 */
export function useThreadStream({
  threadId,
  context,
  createNewSession = false,
  isMock,
  onStart,
  onFinish,
  onToolEnd,
}: ThreadStreamOptions) {
  const { t } = useI18n();
  // Thread id handed to useStream. Normalized from the very first render so
  // placeholders such as "new" (or blank ids) never reach useStream, which
  // would otherwise try to load history for a thread that does not exist.
  const [onStreamThreadId, setOnStreamThreadId] = useState<
    string | null | undefined
  >(() => normalizeThreadId(threadId) ?? null);
  // Latest thread id for async callbacks (e.g. onUpdateEvent) without
  // causing re-renders.
  const threadIdRef = useRef<string | null>(
    normalizeThreadId(threadId) ?? null,
  );
  // Ensures onStart fires at most once until the thread id is cleared.
  const startedRef = useRef(false);
  // Last toast shown, used to suppress duplicates within the dedupe window.
  const lastErrorToastRef = useRef<{
    message: string;
    timestamp: number;
  } | null>(null);
  const listeners = useRef({
    onStart,
    onFinish,
    onToolEnd,
  });
  // Optimistic messages shown before the server stream responds. Declared
  // ahead of useStream because its onError callback clears them.
  const [optimisticMessages, setOptimisticMessages] = useState<Message[]>([]);
  const [isUploading, setIsUploading] = useState(false);
  // Rejects re-entrant sendMessage calls (e.g. a double submit).
  const sendInFlightRef = useRef(false);

  // Keep listeners ref updated with latest callbacks
  useEffect(() => {
    listeners.current = { onStart, onFinish, onToolEnd };
  }, [onStart, onFinish, onToolEnd]);

  useEffect(() => {
    const normalizedThreadId = normalizeThreadId(threadId) ?? null;
    if (!normalizedThreadId) {
      // Just reset for new thread creation when threadId becomes null/undefined
      startedRef.current = false;
    }
    setOnStreamThreadId((prev) =>
      prev === normalizedThreadId ? prev : normalizedThreadId,
    );
    threadIdRef.current = normalizedThreadId;
  }, [threadId]);

  const _handleOnStart = useCallback((id: string) => {
    if (!startedRef.current) {
      listeners.current.onStart?.(id);
      startedRef.current = true;
    }
  }, []);

  const showStreamErrorToast = useCallback((error: unknown) => {
    const message = getStreamErrorMessage(error);
    if (isStreamCancellation(error, message)) {
      // Cancellation is expected when user presses "Stop" or stream disconnects.
      console.info("[useThreadStream] stream cancelled:", message);
      return;
    }
    const now = Date.now();
    const lastToast = lastErrorToastRef.current;
    if (
      lastToast &&
      lastToast.message === message &&
      now - lastToast.timestamp < STREAM_ERROR_TOAST_DEDUPE_WINDOW_MS
    ) {
      return;
    }
    lastErrorToastRef.current = { message, timestamp: now };
    console.error("[useThreadStream] conversation stream error:", error);
    console.error("[useThreadStream] parsed error message:", message);
    toast.error(STREAM_ERROR_TOAST_MESSAGE);
  }, []);

  const handleStreamStart = useCallback(
    (_threadId: string) => {
      threadIdRef.current = _threadId;
      _handleOnStart(_threadId);
    },
    [_handleOnStart],
  );

  const queryClient = useQueryClient();
  const updateSubtask = useUpdateSubtask();
  const apiClient = getAPIClient(isMock);
  const thread = useStream<AgentThreadState>({
    client: apiClient,
    assistantId: "lead_agent",
    threadId: onStreamThreadId,
    reconnectOnMount: true,
    fetchStateHistory: { limit: 1 },
    onCreated(meta) {
      handleStreamStart(meta.thread_id);
      setOnStreamThreadId(meta.thread_id);
    },
    onLangChainEvent:
      onToolEnd == null
        ? undefined
        : (event) => {
            if (event.event === "on_tool_end") {
              listeners.current.onToolEnd?.({
                name: event.name,
                data: event.data,
              });
            }
          },
    onUpdateEvent(data) {
      const updates: Array<Partial<AgentThreadState> | null> = Object.values(
        data || {},
      );
      for (const update of updates) {
        if (update && "title" in update && update.title) {
          void queryClient.setQueriesData(
            {
              queryKey: ["threads", "search"],
              exact: false,
            },
            (oldData: Array<AgentThread> | undefined) => {
              return oldData?.map((t) => {
                if (t.thread_id === threadIdRef.current) {
                  return {
                    ...t,
                    values: {
                      ...t.values,
                      title: update.title,
                    },
                  };
                }
                return t;
              });
            },
          );
        }
      }
    },
    onCustomEvent(event: unknown) {
      if (
        typeof event === "object" &&
        event !== null &&
        "type" in event &&
        event.type === "task_running"
      ) {
        const e = event as {
          type: "task_running";
          task_id: string;
          message: AIMessage;
        };
        updateSubtask({ id: e.task_id, latestMessage: e.message });
      }
    },
    onError(error) {
      setOptimisticMessages([]);
      showStreamErrorToast(error);
    },
    onFinish(state) {
      listeners.current.onFinish?.(state.values);
      void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
    },
  });

  // Message count captured at send time; once the server's count exceeds it,
  // the optimistic placeholders are dropped.
  const prevMsgCountRef = useRef(thread.messages.length);
  // Clear optimistic when server messages arrive (count increases)
  useEffect(() => {
    if (
      optimisticMessages.length > 0 &&
      thread.messages.length > prevMsgCountRef.current
    ) {
      setOptimisticMessages([]);
    }
  }, [thread.messages.length, optimisticMessages.length]);

  useEffect(() => {
    if (!thread.error) {
      return;
    }
    showStreamErrorToast(thread.error);
  }, [thread.error, showStreamErrorToast]);

  const sendMessage = useCallback(
    async (
      threadId: string | undefined,
      message: PromptInputMessage,
      extraContext?: Record<string, unknown>,
    ) => {
      if (sendInFlightRef.current) {
        return;
      }
      sendInFlightRef.current = true;
      // Everything after taking the in-flight flag runs inside try/finally, so
      // a throw anywhere (including the onStart listener) can never leave the
      // flag stuck and silently block all later sends.
      try {
        const text = message.text.trim();
        // normalizeThreadId maps blank ids and the "new" placeholder to
        // undefined; submitting without an id lets useStream create a thread.
        const resolvedThreadId =
          normalizeThreadId(threadId) ??
          normalizeThreadId(threadIdRef.current);
        // Capture current count before showing optimistic messages
        prevMsgCountRef.current = thread.messages.length;
        // Build optimistic files list with uploading status
        const optimisticFiles: FileInMessage[] = (message.files ?? []).map(
          (f) => ({
            filename: f.filename ?? "",
            size: 0,
            status: "uploading" as const,
          }),
        );
        // Create optimistic human message (shown immediately)
        const optimisticHumanMsg: Message = {
          type: "human",
          id: `opt-human-${Date.now()}`,
          content: text ? [{ type: "text", text }] : "",
          additional_kwargs:
            optimisticFiles.length > 0 ? { files: optimisticFiles } : {},
        };
        const newOptimistic: Message[] = [optimisticHumanMsg];
        if (optimisticFiles.length > 0) {
          // Mock AI message while files are being uploaded
          newOptimistic.push({
            type: "ai",
            id: `opt-ai-${Date.now()}`,
            content: t.uploads.uploadingFiles,
            additional_kwargs: { element: "task" },
          });
        }
        setOptimisticMessages(newOptimistic);
        // For "new chat with prefilled thread_id" flows, calling onStart before
        // submit can trigger route switch too early, which causes the new page to
        // fetch history before the thread/run is actually created.
        // Let useStream.onCreated -> handleStreamStart drive onStart instead.
        if (resolvedThreadId && !createNewSession) {
          _handleOnStart(resolvedThreadId);
        }
        let uploadedFileInfo: UploadedFileInfo[] = [];
        // In new-session mode, reset the old thread only when local history
        // already exists. For a brand-new thread_id this avoids an extra
        // DELETE /threads/{id} request (which would usually 404).
        if (
          createNewSession &&
          resolvedThreadId &&
          thread.messages.length > 0
        ) {
          await apiClient.threads
            .delete(resolvedThreadId)
            .catch(() => undefined);
        }
        // Upload files first if any
        if (message.files && message.files.length > 0) {
          setIsUploading(true);
          try {
            // Convert FileUIPart to File objects by fetching blob URLs
            const filePromises = message.files.map(async (fileUIPart) => {
              if (fileUIPart.url && fileUIPart.filename) {
                try {
                  // Fetch the blob URL to get the file data
                  const response = await fetch(fileUIPart.url);
                  const blob = await response.blob();
                  // Create a File object from the blob
                  return new File([blob], fileUIPart.filename, {
                    type: fileUIPart.mediaType || blob.type,
                  });
                } catch (error) {
                  console.error(
                    `Failed to fetch file ${fileUIPart.filename}:`,
                    error,
                  );
                  return null;
                }
              }
              return null;
            });
            const conversionResults = await Promise.all(filePromises);
            const files = conversionResults.filter(
              (file): file is File => file !== null,
            );
            const failedConversions = conversionResults.length - files.length;
            if (failedConversions > 0) {
              throw new Error(
                `Failed to prepare ${failedConversions} attachment(s) for upload. Please retry.`,
              );
            }
            if (!resolvedThreadId) {
              throw new Error("Thread is not ready for file upload.");
            }
            if (files.length > 0) {
              const uploadResponse = await uploadFiles(resolvedThreadId, files);
              uploadedFileInfo = uploadResponse.files;
              // Update optimistic human message with uploaded status + paths
              const uploadedFiles: FileInMessage[] = uploadedFileInfo.map(
                (info) => ({
                  filename: info.filename,
                  size: info.size,
                  path: info.virtual_path,
                  status: "uploaded" as const,
                }),
              );
              setOptimisticMessages((messages) => {
                if (messages.length > 1 && messages[0]) {
                  const humanMessage: Message = messages[0];
                  return [
                    {
                      ...humanMessage,
                      additional_kwargs: { files: uploadedFiles },
                    },
                    ...messages.slice(1),
                  ];
                }
                return messages;
              });
            }
          } catch (error) {
            console.error("Failed to upload files:", error);
            const errorMessage =
              error instanceof Error
                ? error.message
                : "Failed to upload files.";
            toast.error(errorMessage);
            setOptimisticMessages([]);
            throw error;
          } finally {
            setIsUploading(false);
          }
        }
        // Build files metadata for submission (single envelope for uploads + references)
        const normalizedReferences = message.references ?? [];
        const { files: filesForSubmit, staleCount } = buildFilesForSubmit(
          uploadedFileInfo,
          normalizedReferences,
        );
        if (staleCount > 0) {
          // Some referenced files were stale; they were dropped and the message is still sent.
          toast.error("部分引用文件已失效,已自动移除并继续发送。");
        }
        await thread.submit(
          {
            messages: [
              {
                type: "human",
                content: [
                  {
                    type: "text",
                    text,
                  },
                ],
                additional_kwargs:
                  filesForSubmit.length > 0 ? { files: filesForSubmit } : {},
              },
            ],
          },
          {
            threadId: resolvedThreadId,
            streamSubgraphs: true,
            streamResumable: true,
            config: {
              recursion_limit: 1000,
            },
            context: {
              ...extraContext,
              ...context,
              thinking_enabled: context.mode !== "flash",
              is_plan_mode: context.mode === "pro" || context.mode === "ultra",
              subagent_enabled: context.mode === "ultra",
              reasoning_effort:
                context.reasoning_effort ??
                (context.mode === "ultra"
                  ? "high"
                  : context.mode === "pro"
                    ? "medium"
                    : context.mode === "thinking"
                      ? "low"
                      : undefined),
              ...(resolvedThreadId ? { thread_id: resolvedThreadId } : {}),
            },
          },
        );
        void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
      } catch (error) {
        setOptimisticMessages([]);
        setIsUploading(false);
        throw error;
      } finally {
        sendInFlightRef.current = false;
      }
    },
    [
      thread,
      _handleOnStart,
      t.uploads.uploadingFiles,
      context,
      queryClient,
      apiClient,
      createNewSession,
    ],
  );

  // Merge thread with optimistic messages for display
  const mergedThread =
    optimisticMessages.length > 0
      ? ({
          ...thread,
          messages: [...thread.messages, ...optimisticMessages],
        } as typeof thread)
      : thread;
  return [
    mergedThread as UseStream<AgentThreadState>,
    sendMessage,
    isUploading,
  ] as const;
}
/**
 * Returns a submit callback for an externally owned `thread` stream.
 *
 * The callback:
 * - rejects the `"new"` placeholder id with a toast;
 * - ignores empty prompts (no text, files, or references);
 * - when `createNewSession` is set and an id exists, deletes and recreates the thread;
 * - uploads attachments best-effort (failures are toasted but do not block sending);
 * - submits the human message with referenced-file metadata, then invalidates
 *   cached thread searches and calls `afterSubmit`.
 */
export function useSubmitThread({
  threadId,
  thread,
  threadContext,
  createNewSession,
  uploadTarget,
  afterSubmit,
}: {
  createNewSession: boolean;
  threadId: string | null | undefined;
  thread: UseStream<AgentThreadState>;
  threadContext: Omit<AgentThreadContext, "thread_id">;
  uploadTarget?: UploadTarget;
  afterSubmit?: () => void;
}) {
  const queryClient = useQueryClient();
  const apiClient = getAPIClient();
  const callback = useCallback(
    async (message: PromptInputMessage) => {
      if (threadId === "new") {
        toast.error("Invalid thread id 'new'. Please refresh and retry.");
        return;
      }
      const text = message.text.trim();
      const hasFiles = !!(message.files && message.files.length > 0);
      const hasReferences = !!(
        message.references && message.references.length > 0
      );
      if (!text && !hasFiles && !hasReferences) {
        return;
      }
      if (createNewSession && threadId) {
        // Start from a clean thread under the same id.
        await apiClient.threads.delete(threadId).catch(() => undefined);
        await apiClient.threads.create({
          threadId,
          ifExists: "do_nothing",
        });
      }
      if (message.files && message.files.length > 0) {
        // Best-effort upload: problems are reported to the user instead of
        // being swallowed, but sending the text/references still proceeds.
        try {
          const conversionResults = await Promise.all(
            message.files.map(async (fileUIPart) => {
              if (fileUIPart.url && fileUIPart.filename) {
                try {
                  const response = await fetch(fileUIPart.url);
                  const blob = await response.blob();
                  return new File([blob], fileUIPart.filename, {
                    type: fileUIPart.mediaType || blob.type,
                  });
                } catch (error) {
                  console.error(
                    `Failed to fetch file ${fileUIPart.filename}:`,
                    error,
                  );
                  return null;
                }
              }
              return null;
            }),
          );
          const files = conversionResults.filter(
            (file): file is File => file !== null,
          );
          const failedConversions = conversionResults.length - files.length;
          if (failedConversions > 0) {
            toast.error(
              `Failed to prepare ${failedConversions} attachment(s) for upload.`,
            );
          }
          if (files.length > 0) {
            if (threadId) {
              await uploadFiles(threadId, files, { target: uploadTarget });
            } else {
              toast.error("Thread is not ready for file upload.");
            }
          }
        } catch (error) {
          console.error("Failed to upload files:", error);
          toast.error(
            error instanceof Error ? error.message : "Failed to upload files.",
          );
        }
      }
      // NOTE(review): uploaded file info is not passed to buildFilesForSubmit
      // here (useThreadStream does pass it); confirm this is intended.
      const normalizedReferences = message.references ?? [];
      const { files: filesForSubmit, staleCount } = buildFilesForSubmit(
        [],
        normalizedReferences,
      );
      if (staleCount > 0) {
        // Some referenced files were stale; they were dropped and the message is still sent.
        toast.error("部分引用文件已失效,已自动移除并继续发送。");
      }
      await thread.submit(
        {
          messages: [
            {
              type: "human",
              content: [
                {
                  type: "text",
                  text,
                },
              ],
              additional_kwargs:
                filesForSubmit.length > 0 ? { files: filesForSubmit } : {},
            },
          ] as Message[],
        },
        {
          // Only target the (re)created thread explicitly in new-session mode;
          // otherwise the stream's own thread is used.
          threadId: createNewSession && threadId ? threadId : undefined,
          streamSubgraphs: true,
          streamResumable: true,
          streamMode: ["values", "messages-tuple", "custom"],
          config: {
            recursion_limit: 1000,
          },
          context: {
            ...threadContext,
            ...(threadId ? { thread_id: threadId } : {}),
          },
        },
      );
      void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
      afterSubmit?.();
    },
    [
      thread,
      createNewSession,
      threadId,
      threadContext,
      uploadTarget,
      queryClient,
      apiClient,
      afterSubmit,
    ],
  );
  return callback;
}
/**
 * Lists threads via `threads.search`, transparently paging in chunks of at
 * most 50.
 *
 * - `params.limit` caps the total number of threads returned.
 * - When `params.limit` is omitted, pages are fetched until a short page
 *   signals the end.
 * - A non-positive limit is passed straight through as a single search call.
 * - `params.offset` (default 0) is the starting offset.
 *
 * @param params - `threads.search` parameters; defaults to the 50 most
 *   recently updated threads selecting `thread_id`, `updated_at` and `values`.
 */
export function useThreads(
  params: Parameters<ThreadsClient["search"]>[0] = {
    limit: 50,
    sortBy: "updated_at",
    sortOrder: "desc",
    select: ["thread_id", "updated_at", "values"],
  },
) {
  const apiClient = getAPIClient();
  return useQuery<AgentThread[]>({
    queryKey: ["threads", "search", params],
    queryFn: async () => {
      const DEFAULT_PAGE_SIZE = 50;
      const requestedTotal = params.limit;

      // Non-positive explicit limit: keep the prior single pass-through call.
      if (requestedTotal !== undefined && requestedTotal <= 0) {
        return (await apiClient.threads.search<AgentThreadState>(
          params,
        )) as AgentThread[];
      }

      const pageSize =
        typeof requestedTotal === "number"
          ? Math.min(DEFAULT_PAGE_SIZE, requestedTotal)
          : DEFAULT_PAGE_SIZE;
      const collected: AgentThread[] = [];
      let cursor = params.offset ?? 0;

      for (;;) {
        // Remaining budget; unbounded when no total limit was requested.
        const remaining =
          typeof requestedTotal === "number"
            ? requestedTotal - collected.length
            : Number.POSITIVE_INFINITY;
        if (remaining <= 0) {
          break;
        }
        const batchSize = Math.min(pageSize, remaining);
        const page = (await apiClient.threads.search<AgentThreadState>({
          ...params,
          limit: batchSize,
          offset: cursor,
        })) as AgentThread[];
        collected.push(...page);
        // A short page means the server has no further matches.
        if (page.length < batchSize) {
          break;
        }
        cursor += page.length;
      }
      return collected;
    },
    refetchOnWindowFocus: false,
  });
}
/**
 * Mutation that deletes a thread in two steps:
 * 1. through the LangGraph threads API;
 * 2. the backend's local thread data via `DELETE {backend}/api/threads/{id}`.
 *
 * On success the thread is removed from cached `["threads", "search"]`
 * results. Those queries are invalidated once the mutation settles, whether
 * it succeeded or failed.
 *
 * @throws Error carrying the backend's string `detail` (or a generic message)
 *   when the local-data deletion responds with a non-2xx status.
 */
export function useDeleteThread() {
  const queryClient = useQueryClient();
  const apiClient = getAPIClient();
  return useMutation({
    mutationFn: async ({ threadId }: { threadId: string }) => {
      await apiClient.threads.delete(threadId);
      const response = await fetch(
        `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}`,
        {
          method: "DELETE",
        },
      );
      if (!response.ok) {
        const fallback = "Failed to delete local thread data.";
        // Parse as unknown: the body may be non-JSON, or `detail` may be a
        // non-string (e.g. a list of validation errors), which would
        // otherwise produce "[object Object]" or an empty message.
        const body: unknown = await response.json().catch(() => null);
        const detail: unknown =
          typeof body === "object" && body !== null
            ? Reflect.get(body, "detail")
            : undefined;
        throw new Error(
          typeof detail === "string" && detail.trim() ? detail : fallback,
        );
      }
    },
    onSuccess(_, { threadId }) {
      queryClient.setQueriesData(
        {
          queryKey: ["threads", "search"],
          exact: false,
        },
        (oldData: Array<AgentThread> | undefined) => {
          if (oldData == null) {
            return oldData;
          }
          return oldData.filter((t) => t.thread_id !== threadId);
        },
      );
    },
    onSettled() {
      void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
    },
  });
}
/**
 * Mutation that renames a thread by writing `title` into its state values.
 * On success, it patches the title into every cached `["threads", "search"]`
 * result.
 */
export function useRenameThread() {
  const queryClient = useQueryClient();
  const apiClient = getAPIClient();
  return useMutation({
    mutationFn: async ({
      threadId,
      title,
    }: {
      threadId: string;
      title: string;
    }) => {
      await apiClient.threads.updateState(threadId, {
        values: { title },
      });
    },
    onSuccess(_, { threadId, title }) {
      queryClient.setQueriesData(
        {
          queryKey: ["threads", "search"],
          exact: false,
        },
        // A matching query may have no data yet (pending/errored); returning
        // undefined leaves it untouched instead of crashing on `.map`.
        (oldData: Array<AgentThread> | undefined) =>
          oldData?.map((t) =>
            t.thread_id === threadId
              ? {
                  ...t,
                  values: {
                    ...t.values,
                    title,
                  },
                }
              : t,
          ),
      );
    },
  });
}