shuzhiren-comfyui/任务队列后端/worker_threads/wait/generatTask.js

219 lines
8.5 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { parentPort } from 'worker_threads';
import redis from '../../redis/index.js';
import initQueue from '../../redis/initQueue.js';
import { externalPostRequest } from '../../outside/generat.js';
import { platformData } from '../../config/Config.js';
// 批量转发待处理任务到各外部平台
async function generatTask(tasksData) {
// console.log('开始转发任务');
const generatTasks = []
for (const task of tasksData) {
// 2. 获取任务所属平台的生成接口地址
const generatTaskPromise = externalPostRequest(task) // { aigc, tasksData }
generatTasks.push(generatTaskPromise)
}
try {
const responseTasks = await Promise.all(generatTasks)
return responseTasks
} catch (error) {
console.error('Error:', error);
return []; // 确保总是返回数组
}
}
// 批量储存外部平台返回的任务数据到处理队列
async function storeGeneratTasks(tasks) {
// 确保tasks是数组
if (!tasks || !Array.isArray(tasks)) {
console.error('storeGeneratTasks函数接收到无效的tasks参数:', tasks);
return;
}
const multi = redis.multi();
let errorCount = 0;
const taskErrorCountMap = new Map();
const taskCountMap = new Map();
for (const task of tasks) {
// console.log('\n***************',task)
//错误任务
if(task.remoteTaskId?.type === 2){
console.log('储存在错误队列', task);
const aigc = task.AIGC || task.aigc;
const platform = task.platform || task.platformName;
// 存储错误信息到任务数据中
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'resultData', JSON.stringify(task.remoteTaskId.message));
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'status', 'failed');
// 推送任务ID到错误列表
multi.lPush(initQueue.errorList, task.taskId);
errorCount++;
// 存储相关平台信息
const key = `${aigc}:${platform}`;
if(taskErrorCountMap.has(key)){
taskErrorCountMap.set(key, taskErrorCountMap.get(key) + 1);
} else {
taskErrorCountMap.set(key, 1);
}
continue // 跳过错误任务
}
// 处理成功的任务
let externalTaskId;
if (task.remoteTaskId?.type === 1 && task.remoteTaskId?.data) {
// 使用解析后的响应数据提取外部平台任务ID
try {
const responseData = task.remoteTaskId.data;
// console.log('处理成功任务,响应数据:', responseData);
// 直接处理响应数据提取任务ID
const platform = task.platform || task.platformName;
if ((responseData.msg === 'success' || platform === 'coze') && responseData.code === 0) {
// Coze平台返回的是execute_id其他平台返回的是data.taskId
if (platform === 'coze') {
externalTaskId = responseData.execute_id;
} else {
externalTaskId = responseData.data?.taskId;
}
if (externalTaskId) {
console.log('成功提取外部平台任务ID:', externalTaskId);
} else {
console.error('无法从响应中提取外部平台任务ID:', responseData);
// 视为错误任务
const errorMessage = JSON.stringify({ message: '无法从响应中提取外部平台任务ID', response: responseData });
// 存储错误信息到任务数据中
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'resultData', errorMessage);
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'status', 'failed');
// 推送任务ID到错误列表
multi.lPush(initQueue.errorList, task.taskId);
errorCount++;
// 存储相关平台信息
const key = `${task.AIGC}:${task.platform}`;
if(taskErrorCountMap.has(key)){
taskErrorCountMap.set(key, taskErrorCountMap.get(key) + 1);
} else {
taskErrorCountMap.set(key, 1);
}
continue; // 跳过错误任务
}
} else {
console.error('外部平台返回错误:', responseData);
// 视为错误任务
const aigc = task.AIGC || task.aigc;
const platform = task.platform || task.platformName;
const errorMessage = JSON.stringify(responseData);
// 存储错误信息到任务数据中
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'resultData', errorMessage);
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'status', 'failed');
// 推送任务ID到错误列表
multi.lPush(initQueue.errorList, task.taskId);
errorCount++;
// 存储相关平台信息
const key = `${aigc}:${platform}`;
if(taskErrorCountMap.has(key)){
taskErrorCountMap.set(key, taskErrorCountMap.get(key) + 1);
} else {
taskErrorCountMap.set(key, 1);
}
continue; // 跳过错误任务
}
} catch (extractError) {
console.error('提取外部平台任务ID失败:', extractError);
// 视为错误任务
const aigc = task.AIGC || task.aigc;
const platform = task.platform || task.platformName;
const errorMessage = JSON.stringify({ message: '提取外部平台任务ID失败', error: extractError.message });
// 存储错误信息到任务数据中
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'resultData', errorMessage);
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'status', 'failed');
// 推送任务ID到错误列表
multi.lPush(initQueue.errorList, task.taskId);
errorCount++;
// 存储相关平台信息
const key = `${aigc}:${platform}`;
if(taskErrorCountMap.has(key)){
taskErrorCountMap.set(key, taskErrorCountMap.get(key) + 1);
} else {
taskErrorCountMap.set(key, 1);
}
continue; // 跳过错误任务
}
} else {
// 直接使用remoteTaskId作为外部平台任务ID
externalTaskId = task.remoteTaskId;
}
//回调任务
const aigc = task.AIGC || task.aigc;
const platform = task.platform || task.platformName;
if(platformData.callback.includes(platform)) {
console.log('储存在回调队列', externalTaskId, task.taskId);
multi.set(`${initQueue.callback}:${externalTaskId}`, task.taskId)
} else { // 轮询任务
// 按平台+AIGC类型存储轮询任务
const pollingKey = `${initQueue.prefix}:processPolling:${aigc}:${platform}`;
// 从task中提取workflow_id优先使用task.workflowId
let workflowId = task.workflowId || '';
try {
if (!workflowId && task.taskData) {
// taskData 已经是字符串,直接解析
const taskDataParsed = JSON.parse(task.taskData);
workflowId = taskDataParsed.workflow_id || '';
}
} catch (e) {
console.error('[generatTask] 解析taskData获取workflow_id失败:', e);
}
console.log(`[generatTask] 提取到的workflowId: ${workflowId}`);
// 确保workflowId被传递到轮询任务中
const pollingData = {
taskId: task.taskId,
platform: platform,
AIGC: aigc,
workflowId: workflowId // 包含workflowId为空则使用空字符串
};
console.log(`[generatTask] 添加轮询任务: pollingKey=${pollingKey}, externalTaskId=${externalTaskId}, pollingData=${JSON.stringify(pollingData)}`);
multi.hSet(pollingKey, externalTaskId, JSON.stringify(pollingData))
}
// 更新任务信息,添加 remoteTaskId 字段
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'remoteTaskId', externalTaskId);
// 确保任务有2小时的过期时间
multi.expire(`${initQueue.prefix}:task:${task.taskId}`, 7200);
// 记录相关队列处理的任务数
const key = `${aigc}:${platform}`;
if(taskCountMap.has(key)){
taskCountMap.set(key, taskCountMap.get(key) + 1);
} else {
taskCountMap.set(key, 1);
}
}
// 更新平台信息
if(errorCount > 0){
initQueue.addEQtaskALL(errorCount) // 添加错误队列任务数量
}
// 注意这里不再调用addPlatformsProcess因为PQtasks计数已经在updateTaskCounts函数中处理过了
// 避免同一个任务被两次增加PQtasks计数
await multi.exec();
}
parentPort.on('message', async (tasksData) => {
await generatTask(tasksData)
.then (tasks => storeGeneratTasks(tasks))
parentPort.postMessage({ status: 'completed' });
});