shuzhiren-comfyui/任务队列后端/worker_threads/wait/waiting.js

364 lines
11 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { parentPort, Worker } from 'worker_threads';
import redis from '../../redis/index.js';
import initQueue from '../../redis/initQueue.js';
// Redis keys read by this module. Both are namespaced with the PROJECT_PREFIX
// environment variable: CAPACITY holds the internal capacity value and JWT
// holds the JWT token.
const REDIS_KEYS = ((prefix) => ({
  CAPACITY: `${prefix}:md:capacity`,
  JWT: `${prefix}:md:jwt`,
}))(process.env.PROJECT_PREFIX);
/**
 * Minimal console logger. Every line is prefixed with the current ISO-8601
 * timestamp and a severity tag (INFO / ERROR / WARN).
 */
const logger = {
  /** Logs an informational message through console.log. */
  info(message) {
    console.log(`[${new Date().toISOString()}] INFO: ${message}`);
  },

  /**
   * Logs an error message through console.error. The optional error object is
   * passed as a second argument; an empty string is passed when it is falsy.
   */
  error(message, error) {
    const detail = error || '';
    console.error(`[${new Date().toISOString()}] ERROR: ${message}`, detail);
  },

  /** Logs a warning message through console.warn. */
  warn(message) {
    console.warn(`[${new Date().toISOString()}] WARN: ${message}`);
  },
};
// Most recently logged internal-capacity snapshot. julgConcurrency compares the
// current capacity / JWT presence against it and only logs (and overwrites it)
// when one of the two values differs. Both fields start as null, so the first
// poll always logs.
let lastCapacityState = { capacity: null, hasJwt: null };
/**
 * Reads the internal capacity stored under REDIS_KEYS.CAPACITY.
 *
 * Never throws: read failures are logged and reported as 0.
 *
 * @returns {Promise<number>} The stored value parsed as a base-10 integer, or 0
 *   when the key is missing/empty, the value is not numeric, or Redis cannot
 *   be read.
 */
async function getInternalCapacityFromRedis() {
  try {
    const raw = await redis.get(REDIS_KEYS.CAPACITY);
    if (!raw) {
      return 0;
    }
    const capacity = Number.parseInt(raw, 10);
    if (Number.isNaN(capacity)) {
      // NaN makes every `<` / `>` capacity comparison false and, because
      // NaN !== NaN, makes any "did the capacity change?" check fire on every
      // poll. Treat an unparsable value as "no capacity" instead.
      logger.warn(`[waiting] Redis 中的算力值无效: ${raw}`);
      return 0;
    }
    return capacity;
  } catch (error) {
    logger.error('从 Redis 读取算力信息失败:', error);
    return 0;
  }
}
/**
 * Fetches the JWT token stored under REDIS_KEYS.JWT.
 *
 * Never throws: a failed read is logged and reported as null.
 *
 * @returns {Promise<*>} Whatever redis.get resolves to for the key, or null
 *   when the read throws.
 */
async function getJwtTokenFromRedis() {
  try {
    return await redis.get(REDIS_KEYS.JWT);
  } catch (readError) {
    logger.error('从 Redis 读取 JWT Token 失败:', readError);
    return null;
  }
}
/**
 * Decides where a task of a given platform is dispatched and counts how many
 * tasks have been assigned to the internal dispatcher since the last init().
 *
 * Fields:
 *  - internalCapacity:   capacity value from the most recent Redis read.
 *  - hasJwtToken:        whether a JWT token was present at the most recent read.
 *  - assignedToInternal: tasks given the 'messageDispatcher' type since init().
 */
class DispatchStateManager {
  constructor() {
    this.internalCapacity = 0;
    this.assignedToInternal = 0;
    this.hasJwtToken = false;
  }

  /**
   * Re-reads capacity and JWT presence from Redis and caches them on the
   * instance. The two reads are independent, so they are issued concurrently.
   *
   * @returns {Promise<{capacity: number, hasJwt: boolean}>} The fresh snapshot.
   */
  async refreshState() {
    const [capacity, jwtToken] = await Promise.all([
      getInternalCapacityFromRedis(),
      getJwtTokenFromRedis(),
    ]);
    const hasJwt = Boolean(jwtToken);
    this.internalCapacity = capacity;
    this.hasJwtToken = hasJwt;
    return { capacity, hasJwt };
  }

  /** Refreshes the cached state and resets the per-batch assignment counter. */
  async init() {
    await this.refreshState();
    this.assignedToInternal = 0;
  }

  /**
   * Returns the dispatch type for one task of the given platform.
   *
   * @param {string} platformName
   * @returns {Promise<string|null>}
   *   'external' for 'runninghub' / 'coze' (no Redis access);
   *   for 'comfyui': 'error_no_jwt' when no JWT token is stored,
   *   'messageDispatcher' while fewer tasks than the current capacity have
   *   been assigned since init() (the counter is incremented), otherwise
   *   'error_no_capacity';
   *   null for any other platform name.
   */
  async getDispatchType(platformName) {
    if (platformName === 'runninghub' || platformName === 'coze') {
      return 'external';
    }
    if (platformName !== 'comfyui') {
      return null;
    }

    // Capacity and JWT are re-read for every task so that a change made while
    // a batch is being classified is still honoured.
    const { capacity, hasJwt } = await this.refreshState();
    if (!hasJwt) {
      return 'error_no_jwt';
    }
    if (this.assignedToInternal < capacity) {
      this.assignedToInternal++;
      return 'messageDispatcher';
    }
    return 'error_no_capacity';
  }
}
// Shared dispatch bookkeeping; getBatchWaitTasks calls init() on it at the
// start of every batch and asks it for a dispatch type per task.
const dispatchStateManager = new DispatchStateManager();
// Worker thread running GenerateWorkerManager.js (resolved relative to this
// module). The polling loop posts each batch of tasks to it via postMessage.
const generateWorker = new Worker(new URL('./GenerateWorkerManager.js', import.meta.url));
/**
 * Works out, for every platform returned by initQueue.getPlatforms(), how many
 * waiting tasks should be pulled from its wait queue in this round.
 *
 * Rules, applied only to platforms whose wait queue is non-empty:
 *  - 'comfyui' without a JWT token: up to 50 tasks are pulled (a warning is
 *    logged) regardless of concurrency.
 *  - 'comfyui' with a JWT token and positive internal capacity: the limit is
 *    info.MAX_CONCURRENT + internal capacity.
 *  - 'comfyui' with a JWT token but no internal capacity: nothing is pulled.
 *  - any other platform: the limit is info.MAX_CONCURRENT.
 * When info.PQtasks is below the limit, the count is the free slots capped by
 * the queue length.
 *
 * Also logs once whenever capacity or JWT presence differs from the last
 * logged snapshot (lastCapacityState), which is then updated.
 *
 * NOTE(review): the arithmetic assumes info.MAX_CONCURRENT and info.PQtasks
 * are numbers — confirm against initQueue.getPlatforms().
 *
 * @returns {Promise<Array<{aigcPfName: string, info: object, count: number}>>}
 *   One entry per platform that has tasks to pull; [] when nothing is due or
 *   when the lookup fails (the failure is logged).
 */
async function julgConcurrency() {
  try {
    const platforms = await initQueue.getPlatforms();
    const internalCapacity = await getInternalCapacityFromRedis();
    const jwtToken = await getJwtTokenFromRedis();
    const hasJwtToken = Boolean(jwtToken);
    const hasInternalCapacity = internalCapacity > 0 && hasJwtToken;

    const stateChanged =
      internalCapacity !== lastCapacityState.capacity ||
      hasJwtToken !== lastCapacityState.hasJwt;
    if (stateChanged) {
      logger.info(`[waiting] 内部算力状态变更: 容量=${internalCapacity}, JWT=${hasJwtToken ? '存在' : '不存在'}`);
      lastCapacityState = { capacity: internalCapacity, hasJwt: hasJwtToken };
    }

    const wDeficiency = [];
    for (const [aigcPfName, info] of Object.entries(platforms)) {
      try {
        const queued = await redis.lLen(info.waitQueue);
        // Every rule below requires at least one waiting task.
        if (!(queued > 0)) {
          continue;
        }

        let limit;
        if (info.platformName !== 'comfyui') {
          limit = info.MAX_CONCURRENT;
        } else if (!hasJwtToken) {
          // Dispatcher not connected: pull a bounded number of tasks so they
          // can be failed downstream, ignoring the concurrency limit.
          const count = Math.min(50, queued);
          logger.warn(`[waiting] messageDispatcher 未连接,限制取出 ${count} 个任务丢入 error`);
          wDeficiency.push({ aigcPfName, info, count });
          continue;
        } else if (hasInternalCapacity) {
          limit = info.MAX_CONCURRENT + internalCapacity;
        } else {
          continue;
        }

        if (info.PQtasks < limit) {
          const freeSlots = limit - info.PQtasks;
          const count = freeSlots > queued ? queued : freeSlots;
          wDeficiency.push({ aigcPfName, info, count });
        }
      } catch (error) {
        logger.error(`检查平台 ${aigcPfName} 队列长度失败:`, error);
      }
    }
    return wDeficiency;
  } catch (error) {
    logger.error('判断并发数失败:', error);
    return [];
  }
}
/**
 * Reads (without removing) the first `count` task IDs of every platform's wait
 * queue in one Redis MULTI and stores them on each entry as `waitTaskID`.
 *
 * The entries of `platforms` are mutated in place. After this call every entry
 * has an array in `waitTaskID`, even when Redis fails, so callers can iterate
 * it without checking.
 *
 * @param {Array<{info: {waitQueue: string}, count: number}>} platforms
 * @returns {Promise<Array>} The same `platforms` array.
 */
async function getBatchWaitTasksID(platforms) {
  try {
    for (const platform of platforms) {
      platform.waitTaskID = [];
    }

    // A non-positive count would become LRANGE 0 -1 (or lower), and a stop
    // index of -1 means "to the end of the list" in Redis, i.e. the whole
    // queue. Such entries simply keep an empty ID list.
    const fetchable = platforms.filter((platform) => platform.count > 0);
    if (fetchable.length === 0) {
      return platforms;
    }

    const multi = redis.multi();
    for (const platform of fetchable) {
      multi.lRange(platform.info.waitQueue, 0, platform.count - 1);
    }
    const results = await multi.exec();

    fetchable.forEach((platform, index) => {
      const taskIDs = results[index];
      if (Array.isArray(taskIDs)) {
        platform.waitTaskID = taskIDs;
      }
    });
    return platforms;
  } catch (error) {
    logger.error('批量获取等待队列任务ID失败:', error);
    return platforms;
  }
}
/**
 * Loads the task hashes for every task ID collected by getBatchWaitTasksID and
 * tags each task with a dispatch type.
 *
 * dispatchStateManager.init() is called first, so the internal assignment
 * counter starts from zero for every batch. All hashes are fetched in one
 * Redis MULTI from `${initQueue.prefix}:task:${taskId}`.
 *
 * Never throws: failures are logged and whatever was built so far is returned.
 *
 * @param {Array<{aigcPfName: string, info: {platformName: string, AIGC: *},
 *   waitTaskID: Array}>} aigcPfTasks
 * @returns {Promise<Array<object>>} One object per task with stored data:
 *   { backendId, taskId, platformName, aigc, aigcPfName, taskData, workflowId,
 *   dispatchType } plus `errorType` when dispatchType is 'error_no_jwt'.
 *   Falsy IDs and tasks whose hash is missing are skipped.
 */
async function getBatchWaitTasks(aigcPfTasks) {
  const tasksData = [];
  try {
    await dispatchStateManager.init();

    // Flatten the IDs (keeping platform order) and remember each ID's platform.
    const allTaskIds = [];
    const taskIdMap = new Map();
    for (const { aigcPfName, info, waitTaskID } of aigcPfTasks) {
      for (const taskId of waitTaskID || []) {
        if (!taskId) {
          continue;
        }
        allTaskIds.push(taskId);
        taskIdMap.set(taskId, {
          platformName: info.platformName,
          aigc: info.AIGC,
          aigcPfName,
        });
      }
    }
    if (allTaskIds.length === 0) {
      return tasksData;
    }

    const multi = redis.multi();
    for (const taskId of allTaskIds) {
      multi.hGetAll(`${initQueue.prefix}:task:${taskId}`);
    }
    const results = await multi.exec();

    for (const [index, taskId] of allTaskIds.entries()) {
      const taskInfo = results[index];
      // HGETALL on a missing key yields an empty object, which is truthy, so
      // emptiness has to be checked explicitly to detect a missing task.
      if (!taskInfo || Object.keys(taskInfo).length === 0) {
        logger.error(`任务 ${taskId} 数据不存在`);
        continue;
      }

      const { platformName, aigc, aigcPfName } = taskIdMap.get(taskId);
      try {
        const dispatchType = await dispatchStateManager.getDispatchType(platformName);
        const task = {
          backendId: taskInfo.backendId,
          taskId: taskInfo.taskId,
          platformName,
          aigc,
          aigcPfName,
          taskData: taskInfo.payload,
          workflowId: taskInfo.workflowId || '',
          dispatchType,
        };
        if (dispatchType === 'error_no_jwt') {
          logger.warn(`[waiting] messageDispatcher 未连接,任务 ${taskId} 标记为待处理`);
          task.errorType = 'messageDispatcher 未连接';
        }
        tasksData.push(task);
      } catch (error) {
        logger.error(`解析任务${taskId}数据失败:`, error);
      }
    }
    return tasksData;
  } catch (error) {
    logger.error('批量获取任务数据失败:', error);
    return tasksData;
  }
}
/**
 * Removes the fetched task IDs from the head of each platform's wait queue and
 * moves the corresponding counts from "waiting" to "processing".
 *
 * For every entry with at least one ID in `waitTaskID`, an LTRIM dropping that
 * many head elements is queued on a single Redis MULTI. After the MULTI has
 * run, the per-platform totals (keyed by aigcPfName) are passed to
 * initQueue.reducePlatformsWait and then initQueue.addPlatformsProcess.
 *
 * Never throws: any failure is logged.
 *
 * @param {Array<{aigcPfName: string, info: {waitQueue: string},
 *   waitTaskID: Array}>} wDeficiency
 * @returns {Promise<void>}
 */
async function updateTaskCounts(wDeficiency) {
  try {
    const movedByPlatform = new Map();
    const pipeline = redis.multi();

    for (const entry of wDeficiency) {
      const taken = entry.waitTaskID.length;
      if (!(taken > 0)) {
        continue;
      }
      pipeline.lTrim(entry.info.waitQueue, taken, -1);
      const alreadyCounted = movedByPlatform.has(entry.aigcPfName)
        ? movedByPlatform.get(entry.aigcPfName)
        : 0;
      movedByPlatform.set(entry.aigcPfName, alreadyCounted + taken);
    }

    await pipeline.exec();
    if (movedByPlatform.size === 0) {
      return;
    }
    await initQueue.reducePlatformsWait(movedByPlatform);
    await initQueue.addPlatformsProcess(movedByPlatform);
  } catch (error) {
    logger.error('更新任务计数失败:', error);
  }
}
// Polling loop. Each round:
//   1. julgConcurrency() decides how many waiting tasks to pull per platform;
//   2. getBatchWaitTasksID() / getBatchWaitTasks() read those tasks (the IDs
//      are still in the wait queues at this point);
//   3. the tasks that can be handed to the generate worker are selected;
//   4. updateTaskCounts() removes exactly the handed-off entries from the
//      wait queues, and the tasks are posted to the worker.
// The loop sleeps 15 s when there is nothing to hand off and 5 s after an
// unexpected error; it never exits.
(async () => {
  const IDLE_DELAY_MS = 15000;
  const FAILURE_DELAY_MS = 5000;
  // Upper bound on 'error_no_jwt' tasks forwarded to the worker per round.
  const MAX_ERROR_TASKS_PER_ROUND = 50;

  const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

  /**
   * Splits one fetched batch into the tasks to hand off and the queue entries
   * that may be consumed.
   *
   * Tasks typed 'messageDispatcher' / 'external' are handed off as they are.
   * Up to MAX_ERROR_TASKS_PER_ROUND 'error_no_jwt' tasks are handed off
   * re-labelled as 'messageDispatcher' (their errorType is kept).
   * NOTE(review): presumably the worker routes tasks carrying errorType to the
   * error queue — confirm in GenerateWorkerManager.js.
   * Every other task (no capacity, unknown platform, error overflow) is
   * deferred: it must stay in its wait queue for a later round.
   *
   * updateTaskCounts removes entries from the HEAD of a queue by count, so
   * only the leading run of non-deferred IDs of each platform can be
   * consumed; everything from the first deferred ID onwards is kept, and any
   * task in that kept range is withdrawn from the hand-off.
   */
  const planHandOff = (platformBatches, tasksData) => {
    const regularTasks = [];
    const errorTasks = [];
    const deferredIds = new Set();
    for (const task of tasksData) {
      if (task.dispatchType === 'messageDispatcher' || task.dispatchType === 'external') {
        regularTasks.push(task);
      } else if (
        task.dispatchType === 'error_no_jwt' &&
        errorTasks.length < MAX_ERROR_TASKS_PER_ROUND
      ) {
        errorTasks.push({ ...task, dispatchType: 'messageDispatcher' });
      } else {
        deferredIds.add(String(task.taskId));
      }
    }

    const keptIds = new Set();
    const consumedBatches = platformBatches.map((batch) => {
      const ids = batch.waitTaskID || [];
      const firstDeferred = ids.findIndex((id) => deferredIds.has(String(id)));
      if (firstDeferred === -1) {
        return batch;
      }
      for (const id of ids.slice(firstDeferred)) {
        keptIds.add(String(id));
      }
      return { ...batch, waitTaskID: ids.slice(0, firstDeferred) };
    });

    const isHandedOff = (task) => !keptIds.has(String(task.taskId));
    return {
      regularTasks: regularTasks.filter(isHandedOff),
      errorTasks: errorTasks.filter(isHandedOff),
      consumedBatches,
      deferredCount: keptIds.size,
    };
  };

  while (true) {
    try {
      const wDeficiency = await julgConcurrency();
      if (wDeficiency.length === 0) {
        await sleep(IDLE_DELAY_MS);
        continue;
      }

      const tasksWithIds = await getBatchWaitTasksID(wDeficiency);
      const tasksData = await getBatchWaitTasks(tasksWithIds);
      const { regularTasks, errorTasks, consumedBatches, deferredCount } =
        planHandOff(tasksWithIds, tasksData);

      if (errorTasks.length > 0) {
        logger.warn(`[waiting] messageDispatcher 未连接,将 ${errorTasks.length} 个任务直接丢入 error 队列`);
      }
      if (deferredCount > 0) {
        logger.warn(`[waiting] ${deferredCount} 个任务本轮无法派发,保留在等待队列中`);
      }

      // Same order as before: regular tasks first, then the error tasks.
      const tasksToProcess = [...regularTasks, ...errorTasks];

      // Queue trimming and counters only cover what is actually handed off.
      await updateTaskCounts(consumedBatches);

      if (tasksToProcess.length === 0) {
        await sleep(IDLE_DELAY_MS);
        continue;
      }
      logger.info(`[waiting] 开始处理 ${tasksToProcess.length} 个任务`);
      generateWorker.postMessage(tasksToProcess);
    } catch (error) {
      logger.error('批量处理任务失败:', error);
      await sleep(FAILURE_DELAY_MS);
    }
  }
})();
// Worker event wiring: log when the worker reports a completed batch (a
// message whose status is 'completed') and log any error the worker emits.
generateWorker
  .on('message', (message) => {
    if (message.status !== 'completed') {
      return;
    }
    logger.info('[waiting] 任务批次处理完成');
  })
  .on('error', (workerError) => {
    logger.error('生成 Worker 错误:', workerError);
  });