364 lines
11 KiB
JavaScript
364 lines
11 KiB
JavaScript
import { parentPort, Worker } from 'worker_threads';
|
||
import redis from '../../redis/index.js';
|
||
import initQueue from '../../redis/initQueue.js';
|
||
|
||
/**
 * Redis keys read by this module. Both live under
 * `<PROJECT_PREFIX>:md`, where PROJECT_PREFIX comes from the environment
 * at module load time.
 */
const REDIS_KEYS = (() => {
  const namespace = `${process.env.PROJECT_PREFIX}:md`;
  return {
    // Parsed as a base-10 integer by getInternalCapacityFromRedis().
    CAPACITY: `${namespace}:capacity`,
    // Only its presence is checked by the dispatch logic in this file.
    JWT: `${namespace}:jwt`,
  };
})();
|
||
|
||
/**
 * Minimal console logger. Every line is prefixed with an ISO-8601 timestamp
 * and the level name, e.g. `[2024-01-01T00:00:00.000Z] INFO: message`.
 */
const logger = (() => {
  // Builds the "[timestamp] LEVEL: message" prefix shared by all levels.
  const format = (level, message) => `[${new Date().toISOString()}] ${level}: ${message}`;

  return {
    /** Writes an INFO line to stdout. */
    info(message) {
      console.log(format('INFO', message));
    },

    /**
     * Writes an ERROR line to stderr.
     * The optional second argument is passed to console.error as-is; a falsy
     * value is replaced by an empty string.
     */
    error(message, error) {
      console.error(format('ERROR', message), error ? error : '');
    },

    /** Writes a WARN line via console.warn. */
    warn(message) {
      console.warn(format('WARN', message));
    },
  };
})();
|
||
|
||
// Capacity/JWT snapshot that julgConcurrency() logged most recently. It is
// compared against the current values so a state line is logged only on change.
// Both fields start as null, so the first observed state always differs.
let lastCapacityState = {
  capacity: null,
  hasJwt: null,
};
|
||
|
||
/**
 * Reads the internal capacity stored under REDIS_KEYS.CAPACITY.
 *
 * @returns {Promise<number>} The stored value parsed as a base-10 integer, or
 *   0 when the key is missing/empty, the value is not numeric, or the Redis
 *   read fails (the failure is logged, never thrown).
 */
async function getInternalCapacityFromRedis() {
  try {
    const raw = await redis.get(REDIS_KEYS.CAPACITY);
    if (!raw) {
      return 0;
    }

    const parsed = Number.parseInt(raw, 10);
    // A non-numeric value parses to NaN. NaN never equals itself, so returning
    // it would make the change detection in julgConcurrency() log on every
    // poll. Treat an unparsable value as "no capacity".
    return Number.isNaN(parsed) ? 0 : parsed;
  } catch (error) {
    logger.error('从 Redis 读取算力信息失败:', error);
    return 0;
  }
}
|
||
|
||
/**
 * Fetches the value stored under REDIS_KEYS.JWT.
 *
 * @returns {Promise<string|null>} Whatever `redis.get` resolves with, or null
 *   when the read throws (the failure is logged, never rethrown).
 */
async function getJwtTokenFromRedis() {
  try {
    // `return await` keeps a rejected read inside this try block.
    return await redis.get(REDIS_KEYS.JWT);
  } catch (readError) {
    logger.error('从 Redis 读取 JWT Token 失败:', readError);
    return null;
  }
}
|
||
|
||
/**
 * Decides where each task is dispatched and counts how many tasks were handed
 * to the internal dispatcher since the last init().
 */
class DispatchStateManager {
  constructor() {
    this.internalCapacity = 0;
    this.assignedToInternal = 0;
    this.hasJwtToken = false;
  }

  /**
   * Refreshes the capacity/JWT snapshot from Redis and resets the
   * per-batch assignment counter.
   *
   * @returns {Promise<void>}
   */
  async init() {
    // The two reads are independent, so issue them together.
    const [capacity, jwtToken] = await Promise.all([
      getInternalCapacityFromRedis(),
      getJwtTokenFromRedis(),
    ]);

    this.internalCapacity = capacity;
    this.hasJwtToken = Boolean(jwtToken);
    this.assignedToInternal = 0;
  }

  /**
   * Classifies one task by platform name.
   *
   * For 'comfyui' the capacity and JWT are re-read from Redis on every call;
   * the snapshot stored by init() is not used. A successful classification
   * increments `assignedToInternal`.
   *
   * @param {string} platformName
   * @returns {Promise<string|null>} 'external' for 'runninghub'/'coze';
   *   for 'comfyui' one of 'messageDispatcher', 'error_no_jwt' or
   *   'error_no_capacity'; null for any other platform name.
   */
  async getDispatchType(platformName) {
    if (platformName === 'runninghub' || platformName === 'coze') {
      return 'external';
    }
    if (platformName !== 'comfyui') {
      return null;
    }

    const [capacity, jwtToken] = await Promise.all([
      getInternalCapacityFromRedis(),
      getJwtTokenFromRedis(),
    ]);

    if (!jwtToken) {
      return 'error_no_jwt';
    }

    // Only assignments made since the last init() count against capacity.
    if (this.assignedToInternal < capacity) {
      this.assignedToInternal++;
      return 'messageDispatcher';
    }

    return 'error_no_capacity';
  }
}
|
||
|
||
// Module-wide dispatch state. getBatchWaitTasks() calls init() on it at the
// start of every batch and then getDispatchType() once per task.
const dispatchStateManager = new DispatchStateManager();
|
||
// Worker thread running ./GenerateWorkerManager.js (resolved relative to this
// module). The polling loop below sends it task batches via postMessage().
const generateWorker = new Worker(new URL('./GenerateWorkerManager.js', import.meta.url));
|
||
|
||
/**
 * Works out, for every platform returned by initQueue.getPlatforms(), how
 * many tasks may be taken from its wait queue in this polling round.
 *
 * Rules, applied only to platforms whose wait queue is non-empty:
 * - 'comfyui' without a JWT: up to 50 tasks are taken (the caller routes
 *   them to the error path).
 * - 'comfyui' with a JWT and positive internal capacity: the limit is
 *   MAX_CONCURRENT + internal capacity.
 * - 'comfyui' with a JWT but no capacity: nothing is taken.
 * - any other platform: the limit is MAX_CONCURRENT.
 * For the limit-based cases, count = min(limit - PQtasks, queue length) and
 * the platform is skipped unless PQtasks < limit.
 *
 * Side effect: logs and stores a new `lastCapacityState` whenever the
 * capacity or JWT presence changed since the previous call.
 *
 * @returns {Promise<Array<{aigcPfName: string, info: object, count: number}>>}
 *   One entry per platform that has tasks to take; an empty array on failure.
 */
async function julgConcurrency() {
  // Upper bound on tasks drained per platform and round while no JWT exists.
  const NO_JWT_DRAIN_LIMIT = 50;

  try {
    // The three reads do not depend on each other, so run them together.
    const [platforms, internalCapacity, jwtToken] = await Promise.all([
      initQueue.getPlatforms(),
      getInternalCapacityFromRedis(),
      getJwtTokenFromRedis(),
    ]);
    const hasJwtToken = Boolean(jwtToken);
    const hasInternalCapacity = hasJwtToken && internalCapacity > 0;

    if (
      internalCapacity !== lastCapacityState.capacity ||
      hasJwtToken !== lastCapacityState.hasJwt
    ) {
      logger.info(`[waiting] 内部算力状态变更: 容量=${internalCapacity}, JWT=${jwtToken ? '存在' : '不存在'}`);
      lastCapacityState = { capacity: internalCapacity, hasJwt: hasJwtToken };
    }

    const wDeficiency = [];

    for (const [aigcPfName, info] of Object.entries(platforms)) {
      try {
        const actualQueueLength = await redis.lLen(info.waitQueue);
        // Every rule below requires at least one queued task.
        if (!(actualQueueLength > 0)) {
          continue;
        }

        let limit;
        if (info.platformName !== 'comfyui') {
          limit = info.MAX_CONCURRENT;
        } else if (!hasJwtToken) {
          const count = Math.min(NO_JWT_DRAIN_LIMIT, actualQueueLength);
          logger.warn(`[waiting] messageDispatcher 未连接,限制取出 ${count} 个任务丢入 error`);
          wDeficiency.push({ aigcPfName, info, count });
          continue;
        } else if (hasInternalCapacity) {
          // NOTE(review): assumes MAX_CONCURRENT is a number; a string value
          // would be concatenated here rather than added. Confirm against
          // initQueue.getPlatforms().
          limit = info.MAX_CONCURRENT + internalCapacity;
        } else {
          continue;
        }

        if (info.PQtasks < limit) {
          const count = Math.min(limit - info.PQtasks, actualQueueLength);
          wDeficiency.push({ aigcPfName, info, count });
        }
      } catch (error) {
        // A failure on one platform must not stop the others.
        logger.error(`检查平台 ${aigcPfName} 队列长度失败:`, error);
      }
    }

    return wDeficiency;
  } catch (error) {
    logger.error('判断并发数失败:', error);
    return [];
  }
}
|
||
|
||
/**
 * Reads, in one Redis MULTI, the first `count` task ids of every platform's
 * wait queue and stores them on the platform entry as `waitTaskID`.
 *
 * The queue is only read here (LRANGE); nothing is removed.
 *
 * @param {Array<{aigcPfName: string, info: {waitQueue: string}, count: number}>} platforms
 *   Entries as produced by julgConcurrency(); they are mutated in place.
 * @returns {Promise<Array>} The same `platforms` array. Every entry is
 *   guaranteed to carry a `waitTaskID` array, which is empty when the read
 *   failed or returned nothing for that entry.
 */
async function getBatchWaitTasksID(platforms) {
  try {
    const multi = redis.multi();
    for (const { info, count } of platforms) {
      multi.lRange(info.waitQueue, 0, count - 1);
    }

    const results = await multi.exec();

    // Iterate the platforms (not the results) so an entry without a matching
    // reply still ends up with an empty list.
    platforms.forEach((platform, index) => {
      platform.waitTaskID = results[index] || [];
    });

    return platforms;
  } catch (error) {
    logger.error('批量获取等待队列任务ID失败:', error);

    // Callers iterate `waitTaskID` and read its length; leaving it undefined
    // would turn one Redis failure into TypeErrors further down the pipeline.
    if (Array.isArray(platforms)) {
      for (const platform of platforms) {
        if (platform) {
          platform.waitTaskID = [];
        }
      }
    }

    return platforms;
  }
}
|
||
|
||
/**
 * Loads the task hashes for every id collected by getBatchWaitTasksID() and
 * tags each task with a dispatch type.
 *
 * Steps: reset the dispatch state, read all `<prefix>:task:<id>` hashes in one
 * MULTI, then ask dispatchStateManager.getDispatchType() once per task that
 * actually has data. Tasks whose hash is missing are logged and skipped.
 *
 * @param {Array<{aigcPfName: string, info: {platformName: string, AIGC: *}, waitTaskID: string[]}>} aigcPfTasks
 * @returns {Promise<Array<object>>} One record per loaded task with
 *   backendId, taskId, platformName, aigc, aigcPfName, taskData, workflowId
 *   and dispatchType; records typed 'error_no_jwt' additionally carry
 *   `errorType`. Whatever was collected so far is returned on failure.
 */
async function getBatchWaitTasks(aigcPfTasks) {
  const tasksData = [];

  try {
    await dispatchStateManager.init();

    const allTaskIds = [];
    const taskIdMap = new Map();

    for (const { aigcPfName, info, waitTaskID } of aigcPfTasks) {
      for (const taskId of waitTaskID) {
        if (!taskId) {
          continue;
        }
        allTaskIds.push(taskId);
        taskIdMap.set(taskId, {
          platformName: info.platformName,
          aigc: info.AIGC,
          aigcPfName,
        });
      }
    }

    if (allTaskIds.length === 0) {
      return tasksData;
    }

    const multi = redis.multi();
    for (const taskId of allTaskIds) {
      multi.hGetAll(`${initQueue.prefix}:task:${taskId}`);
    }
    const results = await multi.exec();

    for (let i = 0; i < results.length; i++) {
      const taskInfo = results[i];
      const taskId = allTaskIds[i];
      const platformInfo = taskIdMap.get(taskId);

      // HGETALL on a missing key yields an empty object rather than null, so
      // a plain truthiness check would let tasks without data through.
      if (!taskInfo || Object.keys(taskInfo).length === 0) {
        logger.error(`任务 ${taskId} 数据不存在`);
        continue;
      }

      try {
        const dispatchType = await dispatchStateManager.getDispatchType(platformInfo.platformName);

        const task = {
          backendId: taskInfo.backendId,
          taskId: taskInfo.taskId,
          platformName: platformInfo.platformName,
          aigc: platformInfo.aigc,
          aigcPfName: platformInfo.aigcPfName,
          taskData: taskInfo.payload,
          workflowId: taskInfo.workflowId || '',
          dispatchType,
        };

        if (dispatchType === 'error_no_jwt') {
          logger.warn(`[waiting] messageDispatcher 未连接,任务 ${taskId} 标记为待处理`);
          task.errorType = 'messageDispatcher 未连接';
        }

        tasksData.push(task);
      } catch (error) {
        // One bad task must not abort the rest of the batch.
        logger.error(`解析任务${taskId}数据失败:`, error);
      }
    }

    return tasksData;
  } catch (error) {
    logger.error('批量获取任务数据失败:', error);
    return tasksData;
  }
}
|
||
|
||
/**
 * Removes the fetched task ids from the head of each wait queue and moves the
 * per-platform counters from "waiting" to "processing".
 *
 * For every entry with at least one fetched id, an LTRIM dropping that many
 * leading elements is queued; all trims run in one MULTI. The number of
 * fetched ids is summed per `aigcPfName` and, if anything was fetched, passed
 * to initQueue.reducePlatformsWait() and then initQueue.addPlatformsProcess().
 * Any failure is logged and swallowed.
 *
 * @param {Array<{aigcPfName: string, info: {waitQueue: string}, waitTaskID: string[]}>} wDeficiency
 * @returns {Promise<void>}
 */
async function updateTaskCounts(wDeficiency) {
  try {
    const movedPerPlatform = new Map();
    const trimBatch = redis.multi();

    for (const entry of wDeficiency) {
      const { waitTaskID, info } = entry;
      const platformKey = entry.aigcPfName;
      const fetched = waitTaskID.length;

      if (!(fetched > 0)) {
        continue;
      }

      // Keep everything from index `fetched` onwards, i.e. drop the ids just read.
      trimBatch.lTrim(info.waitQueue, fetched, -1);

      const alreadyCounted = movedPerPlatform.has(platformKey)
        ? movedPerPlatform.get(platformKey) + fetched
        : fetched;
      movedPerPlatform.set(platformKey, alreadyCounted);
    }

    await trimBatch.exec();

    if (movedPerPlatform.size === 0) {
      return;
    }
    await initQueue.reducePlatformsWait(movedPerPlatform);
    await initQueue.addPlatformsProcess(movedPerPlatform);
  } catch (error) {
    logger.error('更新任务计数失败:', error);
  }
}
|
||
|
||
/**
 * Endless polling loop that moves tasks from the wait queues to the generate
 * worker.
 *
 * Each round:
 *  1. julgConcurrency() decides how many tasks to take per platform.
 *  2. getBatchWaitTasksID() / getBatchWaitTasks() load ids and task data.
 *  3. Tasks are split by dispatch type. 'messageDispatcher' and 'external'
 *     tasks are forwarded as-is. 'error_no_jwt' tasks are forwarded relabelled
 *     as 'messageDispatcher' (keeping their `errorType`), at most 50 per round;
 *     any beyond 50 are pushed back onto a wait queue.
 *  4. updateTaskCounts() trims every fetched id from the wait queues.
 *  5. The batch is posted to the worker.
 * The loop sleeps 15 s when there is nothing to do and 5 s after an error.
 */
(async () => {
  const IDLE_DELAY_MS = 15000;
  const ERROR_DELAY_MS = 5000;
  const MAX_NO_JWT_TASKS_PER_ROUND = 50;
  const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

  while (true) {
    try {
      const wDeficiency = await julgConcurrency();

      if (wDeficiency.length === 0) {
        await sleep(IDLE_DELAY_MS);
        continue;
      }

      const tasksWithIds = await getBatchWaitTasksID(wDeficiency);
      const tasksData = await getBatchWaitTasks(tasksWithIds);

      const tasksToProcess = [];
      const errorTasks = [];
      const undispatchedTasks = [];

      for (const task of tasksData) {
        if (task.dispatchType === 'error_no_jwt') {
          errorTasks.push(task);
        } else if (task.dispatchType === 'messageDispatcher' || task.dispatchType === 'external') {
          tasksToProcess.push(task);
        } else {
          undispatchedTasks.push(task);
        }
      }

      if (undispatchedTasks.length > 0) {
        // These tasks ('error_no_capacity' or an unknown platform) are not
        // forwarded, yet updateTaskCounts() below still trims them from the
        // wait queue. Make that visible instead of dropping them silently.
        // NOTE(review): they probably need to be requeued; that requires
        // knowing how initQueue's wait/process counters must be adjusted.
        logger.warn(`[waiting] ${undispatchedTasks.length} 个任务无法分发(无可用算力或平台未知),已从等待队列移除`);
      }

      if (errorTasks.length > 0) {
        const errorTasksToProcess = errorTasks.slice(0, MAX_NO_JWT_TASKS_PER_ROUND);
        const remainingTasks = errorTasks.slice(MAX_NO_JWT_TASKS_PER_ROUND);

        if (errorTasksToProcess.length > 0) {
          logger.warn(`[waiting] messageDispatcher 未连接,将 ${errorTasksToProcess.length} 个任务直接丢入 error 队列`);

          for (const task of errorTasksToProcess) {
            // Relabelled so the worker accepts them; `errorType` stays set.
            tasksToProcess.push({
              ...task,
              dispatchType: 'messageDispatcher',
            });
          }
        }

        if (remainingTasks.length > 0) {
          const multi = redis.multi();
          for (const task of remainingTasks) {
            // NOTE(review): the queue key is derived from aigcPfName instead
            // of using the platform's info.waitQueue — confirm they match.
            multi.rPush(task.aigcPfName.replace(/:.*/, `:${task.platformName}:wait`), task.taskId);
          }
          await multi.exec();
          logger.warn(`[waiting] 超过50个任务限制,${remainingTasks.length} 个任务重新放回等待队列`);
        }
      }

      // Must run after the requeue above: it trims from the queue head, while
      // requeued ids were appended to the tail.
      await updateTaskCounts(tasksWithIds);

      if (tasksToProcess.length === 0) {
        await sleep(IDLE_DELAY_MS);
        continue;
      }

      logger.info(`[waiting] 开始处理 ${tasksToProcess.length} 个任务`);
      generateWorker.postMessage(tasksToProcess);
    } catch (error) {
      logger.error('批量处理任务失败:', error);
      await sleep(ERROR_DELAY_MS);
    }
  }
})();
|
||
|
||
// Log a confirmation whenever the worker reports a finished batch; every
// other message is ignored.
generateWorker.on('message', (workerMessage) => {
  if (workerMessage.status !== 'completed') {
    return;
  }
  logger.info('[waiting] 任务批次处理完成');
});

// Log errors emitted by the worker.
generateWorker.on('error', (workerError) => {
  logger.error('生成 Worker 错误:', workerError);
});