diff --git a/frontend/src/core/threads/hooks.test.ts b/frontend/src/core/threads/hooks.test.ts index 646de7a7..ddd7d8f4 100644 --- a/frontend/src/core/threads/hooks.test.ts +++ b/frontend/src/core/threads/hooks.test.ts @@ -1,7 +1,7 @@ import assert from "node:assert/strict"; import test from "node:test"; -const { buildFilesForSubmit } = await import( +const { buildFilesForSubmit, materializeArtifactReferences } = await import( new URL("./submit-files.ts", import.meta.url).href ); @@ -48,3 +48,62 @@ void test("buildFilesForSubmit drops stale references without blocking submit", assert.equal(result.staleCount, 1); assert.equal(result.files.length, 0); }); + +void test("materializeArtifactReferences converts artifact references to upload paths", async () => { + const references = await materializeArtifactReferences( + [ + { + filename: "artifact.md", + path: "/mnt/user-data/outputs/artifact.md", + ref_kind: "mention", + ref_source: "artifact", + }, + { + filename: "uploaded.md", + path: "/mnt/user-data/uploads/uploaded.md", + ref_kind: "mention", + ref_source: "upload", + }, + ], + { + fetchArtifactBlob: async () => + new Blob(["artifact"], { type: "text/plain" }), + uploadArtifact: async () => ({ + filename: "artifact.md", + size: 8, + path: "/host/path/artifact.md", + virtual_path: "/mnt/user-data/uploads/artifact.md", + artifact_url: "/api/threads/t1/artifacts/mnt/user-data/uploads/artifact.md", + }), + }, + ); + + assert.equal(references.length, 2); + assert.equal(references[0]?.ref_source, "upload"); + assert.equal(references[0]?.path, "/mnt/user-data/uploads/artifact.md"); + assert.equal(references[1]?.path, "/mnt/user-data/uploads/uploaded.md"); +}); + +void test("materializeArtifactReferences marks artifact as stale on upload failure", async () => { + const references = await materializeArtifactReferences( + [ + { + filename: "broken.md", + path: "/mnt/user-data/outputs/broken.md", + ref_kind: "mention", + ref_source: "artifact", + }, + ], + { + fetchArtifactBlob: async () => new Blob(["artifact"]), + uploadArtifact: async () => null, + }, + ); + + assert.equal(references.length, 1); + assert.equal(references[0]?.stale, true); + + const result = buildFilesForSubmit([], references); + assert.equal(result.staleCount, 1); + assert.equal(result.files.length, 0); +}); diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index 6c036da4..b068cd52 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -10,6 +10,7 @@ import type { } from "@/components/ai-elements/prompt-input"; import { getAPIClient } from "../api"; +import { urlOfArtifact } from "../artifacts/utils"; import { getBackendBaseURL } from "../config"; import { useI18n } from "../i18n/hooks"; import type { FileInMessage } from "../messages/utils"; @@ -19,7 +20,7 @@ import type { UploadedFileInfo } from "../uploads"; import { uploadFiles } from "../uploads"; import type { UploadTarget } from "../uploads/api"; -import { buildFilesForSubmit } from "./submit-files"; +import { buildFilesForSubmit, materializeArtifactReferences } from "./submit-files"; import type { AgentThread, AgentThreadContext, @@ -59,6 +60,34 @@ const STREAM_CANCEL_PATTERNS = [ /\babort(?:ed|error)?\b/i, ]; +async function convertArtifactReferencesToUploads( + threadId: string, + references: PromptInputMessage["references"], +) { + return materializeArtifactReferences(references, { + fetchArtifactBlob: async (reference) => { + const filepath = reference.path; + if (!filepath) { + throw new Error("Missing artifact path"); + } + const response = await fetch( + urlOfArtifact({ + filepath, + threadId, + }), + ); + if (!response.ok) { + throw new Error("Failed to read artifact"); + } + return response.blob(); + }, + uploadArtifact: async (file) => { + const response = await uploadFiles(threadId, [file]); + return response.files[0]; + }, + }); +} + function readMessageCandidate(value: unknown): string | null { if (typeof value === "string" && value.trim()) { return value.trim(); @@ -553,9 +582,15 @@ export function useThreadStream({ } // Build files metadata for submission (single envelope for uploads + references) + const normalizedReferences = resolvedThreadId + ? await convertArtifactReferencesToUploads( + resolvedThreadId, + message.references, + ) + : (message.references ?? []); const { files: filesForSubmit, staleCount } = buildFilesForSubmit( uploadedFileInfo, - message.references, + normalizedReferences, ); if (staleCount > 0) { toast.error("部分引用已失效,已自动移除"); @@ -714,9 +749,12 @@ export function useSubmitThread({ } } + const normalizedReferences = threadId + ? await convertArtifactReferencesToUploads(threadId, message.references) + : (message.references ?? []); const { files: filesForSubmit, staleCount } = buildFilesForSubmit( [], - message.references, + normalizedReferences, ); if (staleCount > 0) { toast.error("部分引用已失效,已自动移除"); diff --git a/frontend/src/core/threads/submit-files.ts b/frontend/src/core/threads/submit-files.ts index a82d3091..02d1ab91 100644 --- a/frontend/src/core/threads/submit-files.ts +++ b/frontend/src/core/threads/submit-files.ts @@ -10,6 +10,51 @@ export type MentionReference = { stale?: boolean; }; +type ArtifactMaterializer = ( + file: File, +) => Promise; +type ArtifactBlobFetcher = (reference: MentionReference) => Promise; + +export async function materializeArtifactReferences( + references: MentionReference[] = [], + options: { + fetchArtifactBlob: ArtifactBlobFetcher; + uploadArtifact: ArtifactMaterializer; + }, +): Promise { + const result: MentionReference[] = []; + for (const reference of references) { + if ( + reference.ref_source !== "artifact" || + !reference.path || + reference.stale + ) { + result.push(reference); + continue; + } + + try { + const blob = await options.fetchArtifactBlob(reference); + const file = new File([blob], reference.filename, { + type: blob.type || "application/octet-stream", + }); + const uploaded = await options.uploadArtifact(file); + if (!uploaded?.virtual_path) { + result.push({ ...reference, stale: true }); + continue; + } + result.push({ + ...reference, + ref_source: "upload", + path: uploaded.virtual_path, + }); + } catch { + result.push({ ...reference, stale: true }); + } + } + return result; +} + export function buildFilesForSubmit( uploadedFileInfo: UploadedFileInfo[], references: MentionReference[] = [], diff --git a/frontend/tests/e2e/input-and-compose.spec.ts b/frontend/tests/e2e/input-and-compose.spec.ts index 6de5b93c..03862d0e 100644 --- a/frontend/tests/e2e/input-and-compose.spec.ts +++ b/frontend/tests/e2e/input-and-compose.spec.ts @@ -128,8 +128,11 @@ test.describe("聊天工作台 / 输入区与发送", () => { skipIfMissingThread(testInfo, THREAD_FOR_WELCOME, "FRONTEND_E2E_THREAD_ID"); await openChat(page, reuseThreadChatEntry(THREAD_FOR_WELCOME!)); + const expander = page.locator("div.absolute.inset-0.z-1.cursor-text"); + if ((await expander.count()) > 0) { + await expander.first().click(); + } const textarea = page.locator("textarea[name='message']"); - await textarea.click(); await textarea.fill("请基于这个文件回答 @"); const panel = page.getByTestId("mention-candidate-panel").first(); @@ -138,7 +141,7 @@ test.describe("聊天工作台 / 输入区与发送", () => { const itemCount = await items.count(); testInfo.skip(itemCount === 0, "当前线程没有可引用文件候选。"); - await items.first().click(); + await textarea.press("Enter"); await expect(textarea).toBeFocused(); await expect(textarea).toHaveValue(/请基于这个文件回答/); await expect(page.getByTestId("reference-inline-preview")).toBeVisible(); @@ -163,8 +166,11 @@ test.describe("聊天工作台 / 输入区与发送", () => { skipIfMissingThread(testInfo, THREAD_FOR_WELCOME, "FRONTEND_E2E_THREAD_ID"); await openChat(page, reuseThreadChatEntry(THREAD_FOR_WELCOME!)); + const expander = page.locator("div.absolute.inset-0.z-1.cursor-text"); + if ((await expander.count()) > 0) { + await expander.first().click(); + } const textarea = page.locator("textarea[name='message']"); - await textarea.click(); await textarea.fill("请参考这些文件 "); await textarea.type("@"); @@ -178,14 +184,20 @@ test.describe("聊天工作台 / 输入区与发送", () => { await textarea.type("@"); const currentPanel = page.getByTestId("mention-candidate-panel").first(); await expect(currentPanel).toBeVisible(); - await currentPanel.locator("button").nth(i).click(); + for (let step = 0; step < i; step += 1) { + await textarea.press("ArrowDown"); + } + await textarea.press("Enter"); } await expect(page.getByLabel("移除引用")).toHaveCount(6); await textarea.type("@"); await expect(panel).toBeVisible(); - await panel.locator("button").nth(6).click(); + for (let step = 0; step < 6; step += 1) { + await textarea.press("ArrowDown"); + } + await textarea.press("Enter"); await expect(page.getByLabel("移除引用")).toHaveCount(6); await expect(page.getByText("单条消息最多引用 6 个文件")).toBeVisible();