219 lines
8.5 KiB
JavaScript
219 lines
8.5 KiB
JavaScript
import { parentPort } from 'worker_threads';
|
||
import redis from '../../redis/index.js';
|
||
import initQueue from '../../redis/initQueue.js';
|
||
import { externalPostRequest } from '../../outside/generat.js';
|
||
import { platformData } from '../../config/Config.js';
|
||
|
||
// 批量转发待处理任务到各外部平台
|
||
async function generatTask(tasksData) {
|
||
// console.log('开始转发任务');
|
||
const generatTasks = []
|
||
for (const task of tasksData) {
|
||
// 2. 获取任务所属平台的生成接口地址
|
||
const generatTaskPromise = externalPostRequest(task) // { aigc, tasksData }
|
||
generatTasks.push(generatTaskPromise)
|
||
}
|
||
|
||
try {
|
||
const responseTasks = await Promise.all(generatTasks)
|
||
return responseTasks
|
||
} catch (error) {
|
||
console.error('Error:', error);
|
||
return []; // 确保总是返回数组
|
||
}
|
||
}
|
||
|
||
// 批量储存外部平台返回的任务数据到处理队列
|
||
async function storeGeneratTasks(tasks) {
|
||
// 确保tasks是数组
|
||
if (!tasks || !Array.isArray(tasks)) {
|
||
console.error('storeGeneratTasks函数接收到无效的tasks参数:', tasks);
|
||
return;
|
||
}
|
||
|
||
const multi = redis.multi();
|
||
let errorCount = 0;
|
||
const taskErrorCountMap = new Map();
|
||
const taskCountMap = new Map();
|
||
for (const task of tasks) {
|
||
// console.log('\n***************',task)
|
||
//错误任务
|
||
if(task.remoteTaskId?.type === 2){
|
||
console.log('储存在错误队列', task);
|
||
const aigc = task.AIGC || task.aigc;
|
||
const platform = task.platform || task.platformName;
|
||
|
||
// 存储错误信息到任务数据中
|
||
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'resultData', JSON.stringify(task.remoteTaskId.message));
|
||
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'status', 'failed');
|
||
|
||
// 推送任务ID到错误列表
|
||
multi.lPush(initQueue.errorList, task.taskId);
|
||
|
||
errorCount++;
|
||
// 存储相关平台信息
|
||
const key = `${aigc}:${platform}`;
|
||
if(taskErrorCountMap.has(key)){
|
||
taskErrorCountMap.set(key, taskErrorCountMap.get(key) + 1);
|
||
} else {
|
||
taskErrorCountMap.set(key, 1);
|
||
}
|
||
continue // 跳过错误任务
|
||
}
|
||
|
||
// 处理成功的任务
|
||
let externalTaskId;
|
||
if (task.remoteTaskId?.type === 1 && task.remoteTaskId?.data) {
|
||
// 使用解析后的响应数据提取外部平台任务ID
|
||
try {
|
||
const responseData = task.remoteTaskId.data;
|
||
// console.log('处理成功任务,响应数据:', responseData);
|
||
|
||
// 直接处理响应数据,提取任务ID
|
||
const platform = task.platform || task.platformName;
|
||
if ((responseData.msg === 'success' || platform === 'coze') && responseData.code === 0) {
|
||
// Coze平台返回的是execute_id,其他平台返回的是data.taskId
|
||
if (platform === 'coze') {
|
||
externalTaskId = responseData.execute_id;
|
||
} else {
|
||
externalTaskId = responseData.data?.taskId;
|
||
}
|
||
|
||
if (externalTaskId) {
|
||
console.log('成功提取外部平台任务ID:', externalTaskId);
|
||
} else {
|
||
console.error('无法从响应中提取外部平台任务ID:', responseData);
|
||
// 视为错误任务
|
||
const errorMessage = JSON.stringify({ message: '无法从响应中提取外部平台任务ID', response: responseData });
|
||
|
||
// 存储错误信息到任务数据中
|
||
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'resultData', errorMessage);
|
||
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'status', 'failed');
|
||
|
||
// 推送任务ID到错误列表
|
||
multi.lPush(initQueue.errorList, task.taskId);
|
||
|
||
errorCount++;
|
||
// 存储相关平台信息
|
||
const key = `${task.AIGC}:${task.platform}`;
|
||
if(taskErrorCountMap.has(key)){
|
||
taskErrorCountMap.set(key, taskErrorCountMap.get(key) + 1);
|
||
} else {
|
||
taskErrorCountMap.set(key, 1);
|
||
}
|
||
continue; // 跳过错误任务
|
||
}
|
||
} else {
|
||
console.error('外部平台返回错误:', responseData);
|
||
// 视为错误任务
|
||
const aigc = task.AIGC || task.aigc;
|
||
const platform = task.platform || task.platformName;
|
||
const errorMessage = JSON.stringify(responseData);
|
||
|
||
// 存储错误信息到任务数据中
|
||
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'resultData', errorMessage);
|
||
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'status', 'failed');
|
||
|
||
// 推送任务ID到错误列表
|
||
multi.lPush(initQueue.errorList, task.taskId);
|
||
|
||
errorCount++;
|
||
// 存储相关平台信息
|
||
const key = `${aigc}:${platform}`;
|
||
if(taskErrorCountMap.has(key)){
|
||
taskErrorCountMap.set(key, taskErrorCountMap.get(key) + 1);
|
||
} else {
|
||
taskErrorCountMap.set(key, 1);
|
||
}
|
||
continue; // 跳过错误任务
|
||
}
|
||
} catch (extractError) {
|
||
console.error('提取外部平台任务ID失败:', extractError);
|
||
// 视为错误任务
|
||
const aigc = task.AIGC || task.aigc;
|
||
const platform = task.platform || task.platformName;
|
||
const errorMessage = JSON.stringify({ message: '提取外部平台任务ID失败', error: extractError.message });
|
||
|
||
// 存储错误信息到任务数据中
|
||
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'resultData', errorMessage);
|
||
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'status', 'failed');
|
||
|
||
// 推送任务ID到错误列表
|
||
multi.lPush(initQueue.errorList, task.taskId);
|
||
|
||
errorCount++;
|
||
// 存储相关平台信息
|
||
const key = `${aigc}:${platform}`;
|
||
if(taskErrorCountMap.has(key)){
|
||
taskErrorCountMap.set(key, taskErrorCountMap.get(key) + 1);
|
||
} else {
|
||
taskErrorCountMap.set(key, 1);
|
||
}
|
||
continue; // 跳过错误任务
|
||
}
|
||
} else {
|
||
// 直接使用remoteTaskId作为外部平台任务ID
|
||
externalTaskId = task.remoteTaskId;
|
||
}
|
||
|
||
//回调任务
|
||
const aigc = task.AIGC || task.aigc;
|
||
const platform = task.platform || task.platformName;
|
||
if(platformData.callback.includes(platform)) {
|
||
console.log('储存在回调队列', externalTaskId, task.taskId);
|
||
multi.set(`${initQueue.callback}:${externalTaskId}`, task.taskId)
|
||
} else { // 轮询任务
|
||
// 按平台+AIGC类型存储轮询任务
|
||
const pollingKey = `${initQueue.prefix}:processPolling:${aigc}:${platform}`;
|
||
// 从task中提取workflow_id,优先使用task.workflowId
|
||
let workflowId = task.workflowId || '';
|
||
try {
|
||
if (!workflowId && task.taskData) {
|
||
// taskData 已经是字符串,直接解析
|
||
const taskDataParsed = JSON.parse(task.taskData);
|
||
workflowId = taskDataParsed.workflow_id || '';
|
||
}
|
||
} catch (e) {
|
||
console.error('[generatTask] 解析taskData获取workflow_id失败:', e);
|
||
}
|
||
console.log(`[generatTask] 提取到的workflowId: ${workflowId}`);
|
||
// 确保workflowId被传递到轮询任务中
|
||
const pollingData = {
|
||
taskId: task.taskId,
|
||
platform: platform,
|
||
AIGC: aigc,
|
||
workflowId: workflowId // 包含workflowId,为空则使用空字符串
|
||
};
|
||
console.log(`[generatTask] 添加轮询任务: pollingKey=${pollingKey}, externalTaskId=${externalTaskId}, pollingData=${JSON.stringify(pollingData)}`);
|
||
multi.hSet(pollingKey, externalTaskId, JSON.stringify(pollingData))
|
||
}
|
||
|
||
// 更新任务信息,添加 remoteTaskId 字段
|
||
multi.hSet(`${initQueue.prefix}:task:${task.taskId}`, 'remoteTaskId', externalTaskId);
|
||
// 确保任务有2小时的过期时间
|
||
multi.expire(`${initQueue.prefix}:task:${task.taskId}`, 7200);
|
||
|
||
// 记录相关队列处理的任务数
|
||
const key = `${aigc}:${platform}`;
|
||
if(taskCountMap.has(key)){
|
||
taskCountMap.set(key, taskCountMap.get(key) + 1);
|
||
} else {
|
||
taskCountMap.set(key, 1);
|
||
}
|
||
}
|
||
|
||
// 更新平台信息
|
||
if(errorCount > 0){
|
||
initQueue.addEQtaskALL(errorCount) // 添加错误队列任务数量
|
||
}
|
||
// 注意:这里不再调用addPlatformsProcess,因为PQtasks计数已经在updateTaskCounts函数中处理过了
|
||
// 避免同一个任务被两次增加PQtasks计数
|
||
await multi.exec();
|
||
}
|
||
|
||
parentPort.on('message', async (tasksData) => {
|
||
await generatTask(tasksData)
|
||
.then (tasks => storeGeneratTasks(tasks))
|
||
parentPort.postMessage({ status: 'completed' });
|
||
});
|