shuzhiren-comfyui/message-dispatcher/src/websocket-server/index.js

384 lines
11 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { WebSocketServer as WSServer } from 'ws';
import logger from '../logger/index.js';
import bridgeManager from '../bridge-manager/index.js';
import taskScheduler from '../task-scheduler/index.js';
import mdWebSocketClient from '../md-websocket-client/index.js';
import { v4 as uuidv4 } from 'uuid';
/**
 * WebSocket endpoint that bridges connect to.
 *
 * Responsibilities visible in this class:
 *  - accept connections and keep them alive with protocol-level ping/pong;
 *  - route incoming JSON messages (`REGISTER`, `HEARTBEAT`, `TASK_ACK`,
 *    `TASK_END`, `INSTANCE_CHECK_ACK`, `INSTANCE_STATUS_UPDATE`, `PONG`);
 *  - send `TASK_ASSIGN` / `INSTANCE_CHECK` messages to a bridge and track the
 *    outstanding request until a response arrives, it times out, or the
 *    bridge disconnects.
 */
class WebSocketServer {
  constructor() {
    /** Underlying `ws` server instance; created in {@link start}. */
    this.wss = null;
    /**
     * requestId -> { resolve, reject, timeout, bridgeId, instanceId?, sentAt?,
     * ackReceived?, ackAt? } for every request awaiting a bridge response.
     */
    this.pendingRequests = new Map();
    /** instanceId -> requestId of the task most recently sent to / acked by it. */
    this.instanceTaskMap = new Map();
    /** Milliseconds to wait for a bridge response to a task before giving up. */
    this.TASK_TIMEOUT = 5 * 60 * 1000;
  }

  /**
   * Attaches the WebSocket server to an existing HTTP(S) server.
   *
   * @param {object} server - HTTP server instance handed to `ws` as `server`.
   */
  start(server) {
    this.wss = new WSServer({
      server,
      // NOTE(review): `keepalive` is not among the documented `ws` server
      // options — confirm whether it has any effect. Liveness is actually
      // enforced by the ping/pong logic in handleConnection().
      keepalive: true
    });
    logger.info('WebSocket 服务器已启动');
    this.wss.on('connection', (ws) => {
      this.handleConnection(ws);
    });
  }

  /**
   * Wires up a freshly accepted socket: ping/pong liveness check, message
   * routing, and cleanup when the socket closes or errors.
   *
   * @param {object} ws - The accepted `ws` WebSocket.
   */
  handleConnection(ws) {
    const PING_INTERVAL = 30000;
    const PONG_TIMEOUT = 10000;
    // Set once the peer sends REGISTER; stays null for unregistered sockets.
    let bridgeId = null;
    let pingInterval = null;
    let pongTimeout = null;

    logger.info('新的 WebSocket 连接已建立');

    const clearPongTimeout = () => {
      if (pongTimeout) {
        clearTimeout(pongTimeout);
        pongTimeout = null;
      }
    };

    const stopKeepalive = () => {
      clearInterval(pingInterval);
      clearPongTimeout();
    };

    const sendPing = () => {
      // The readyState constants live on the socket (and the WebSocket class),
      // not on the server class, so compare against `ws.OPEN`.
      if (ws.readyState !== ws.OPEN) {
        clearInterval(pingInterval);
        return;
      }
      ws.ping();
      // Never keep two pong deadlines armed at the same time.
      clearPongTimeout();
      pongTimeout = setTimeout(() => {
        logger.warn('PONG 响应超时,关闭连接');
        ws.terminate();
      }, PONG_TIMEOUT);
    };

    pingInterval = setInterval(sendPing, PING_INTERVAL);

    ws.on('message', (data) => {
      // handleMessage() catches and logs its own failures.
      this.handleMessage(ws, data, (id) => {
        bridgeId = id;
      });
    });

    ws.on('pong', clearPongTimeout);

    ws.on('close', (code) => {
      stopKeepalive();
      if (bridgeId) {
        // The synchronous part of handleBridgeDisconnect() runs before
        // unregisterBridge(), so it can still look the bridge up.
        this.handleBridgeDisconnect(bridgeId).catch((error) => {
          logger.error('处理桥接器断开失败:', error);
        });
        bridgeManager.unregisterBridge(bridgeId);
        this.cleanupPendingRequests(bridgeId);
      }
      logger.info(`WebSocket 连接已关闭 (code: ${code})`);
    });

    ws.on('error', (error) => {
      stopKeepalive();
      logger.error('WebSocket 连接错误:', error);
    });
  }

  /**
   * Marks every instance listed on the bridge as offline, asks the scheduler
   * to recover any task `handleInstanceOffline` reports as affected, then
   * pushes the current online capacity to the scheduler.
   *
   * @param {string} bridgeId - Id of the bridge whose socket closed.
   * @returns {Promise<void>}
   */
  async handleBridgeDisconnect(bridgeId) {
    const bridge = bridgeManager.getBridge(bridgeId);
    if (bridge?.info?.instances) {
      for (const instance of bridge.info.instances) {
        const affectedTaskId = bridgeManager.handleInstanceOffline(instance.id);
        if (affectedTaskId) {
          taskScheduler.recoverTask(affectedTaskId, 'bridge_disconnect');
        }
      }
    }
    const capacity = bridgeManager.getAvailableCapacity();
    await taskScheduler.setCurrentCapacity(capacity.online);
  }

  /**
   * Parses one incoming frame as JSON and dispatches it on `message.type`.
   * Parse errors and handler errors are logged separately; this method never
   * rejects because of them.
   *
   * @param {object} ws - Socket the frame arrived on.
   * @param {Buffer|string} data - Raw frame payload.
   * @param {(id: string) => void} setBridgeId - Callback used by REGISTER to
   *   bind the socket to its bridge id.
   * @returns {Promise<void>}
   */
  async handleMessage(ws, data, setBridgeId) {
    let message;
    try {
      message = JSON.parse(data.toString());
    } catch (error) {
      logger.error('解析消息失败:', error);
      return;
    }

    try {
      logger.debug(`收到消息:${message.type}`);
      // Async handlers are awaited so that their rejections land in the catch
      // below instead of becoming unhandled promise rejections.
      switch (message.type) {
        case 'REGISTER':
          await this.handleRegister(ws, message, setBridgeId);
          break;
        case 'HEARTBEAT':
          this.handleHeartbeat(message);
          break;
        case 'TASK_ACK':
          this.handleTaskAck(message);
          break;
        case 'TASK_END':
          await this.handleTaskEnd(message);
          break;
        case 'INSTANCE_CHECK_ACK':
          this.handleBridgeResponse(message);
          break;
        case 'INSTANCE_STATUS_UPDATE':
          await this.handleInstanceStatusUpdate(message);
          break;
        case 'PONG':
          break;
        default:
          logger.debug('未知消息类型:', message.type);
      }
    } catch (error) {
      logger.error(`处理消息失败 (${message?.type}):`, error);
    }
  }

  /**
   * Registers the bridge (generating an id when the peer supplies none),
   * refreshes scheduler capacity and replies with `REGISTER_ACK`.
   *
   * @param {object} ws - Socket of the registering bridge.
   * @param {object} message - Parsed REGISTER message.
   * @param {(id: string) => void} setBridgeId - Binds the socket to the id.
   * @returns {Promise<void>}
   */
  async handleRegister(ws, message, setBridgeId) {
    const bridgeId = message.data?.bridgeId || uuidv4();
    setBridgeId(bridgeId);
    bridgeManager.registerBridge(bridgeId, ws, message.data);

    const capacity = bridgeManager.getAvailableCapacity();
    await taskScheduler.setCurrentCapacity(capacity.online);

    const response = {
      type: 'REGISTER_ACK',
      data: {
        bridgeId,
        timestamp: new Date().toISOString()
      }
    };
    ws.send(JSON.stringify(response));
  }

  /**
   * Forwards a heartbeat to the bridge manager when it carries a bridge id.
   *
   * @param {object} message - Parsed HEARTBEAT message.
   */
  handleHeartbeat(message) {
    const bridgeId = message.data?.bridgeId;
    if (bridgeId) {
      bridgeManager.updateHeartbeat(bridgeId);
    }
  }

  /**
   * Handles a bridge's acknowledgement of a task: notifies the scheduler,
   * confirms the instance lock, records the instance -> request mapping and
   * settles the pending request (the send promise resolves on ACK).
   *
   * @param {object} message - Parsed TASK_ACK message.
   */
  handleTaskAck(message) {
    const { requestId, instanceId, bridgeId } = message.data ?? {};
    if (requestId) {
      taskScheduler.handleTaskAck(requestId, instanceId, bridgeId);
      if (instanceId) {
        bridgeManager.confirmInstanceLock(instanceId);
      }
      const pending = this.pendingRequests.get(requestId);
      if (pending) {
        pending.ackReceived = true;
        pending.ackAt = new Date().toISOString();
      }
      // Only record the mapping when the ACK names an instance; otherwise the
      // map would gain an `undefined` key.
      if (instanceId) {
        this.instanceTaskMap.set(instanceId, requestId);
      }
    }
    this.handleBridgeResponse(message);
  }

  /**
   * Handles task completion or failure reported by a bridge: releases the
   * instance lock, reports the outcome to the scheduler and settles any
   * request still pending under the same id.
   *
   * @param {object} message - Parsed TASK_END message.
   * @returns {Promise<void>}
   */
  async handleTaskEnd(message) {
    console.log(`[WebSocketServer] 收到 TASK_END 消息:`, JSON.stringify(message.data, null, 2));
    const { requestId, instanceId, result, error } = message.data ?? {};
    console.log(`[WebSocketServer] 解析参数requestId=${requestId}, instanceId=${instanceId}`);

    if (instanceId) {
      const existsInMap = this.instanceTaskMap.has(instanceId);
      console.log(`[WebSocketServer] 实例 ${instanceId} 在 instanceTaskMap 中:${existsInMap}`);
      this.instanceTaskMap.delete(instanceId);
      const released = bridgeManager.releaseInstanceLock(instanceId);
      console.log(`[WebSocketServer] 实例锁释放 ${released ? '成功' : '失败'}: ${instanceId}`);
      if (!released) {
        console.warn(`[WebSocketServer] 实例锁释放失败可能原因1) 锁已超时自动释放 2) 重复收到 TASK_END 消息`);
      }
    } else {
      console.warn(`[WebSocketServer] TASK_END 消息中缺少 instanceId无法释放锁`);
    }

    if (requestId) {
      if (error) {
        await taskScheduler.handleTaskFailure(requestId, error);
      } else {
        await taskScheduler.handleTaskComplete(requestId, result);
      }
    }
    this.handleBridgeResponse(message);
  }

  /**
   * Processes an instance status report: each instance reported `offline` is
   * passed to the bridge manager and any affected task is handed to the
   * scheduler for recovery; afterwards the scheduler capacity is refreshed.
   * Messages without an `instances` array are ignored.
   *
   * @param {object} message - Parsed INSTANCE_STATUS_UPDATE message.
   * @returns {Promise<void>}
   */
  async handleInstanceStatusUpdate(message) {
    const instances = message.data?.instances;
    if (!Array.isArray(instances)) {
      return;
    }
    for (const instance of instances) {
      if (instance?.status !== 'offline') {
        continue;
      }
      const affectedTaskId = bridgeManager.handleInstanceOffline(instance.id);
      if (affectedTaskId) {
        taskScheduler.recoverTask(affectedTaskId, 'instance_offline');
      }
    }
    const capacity = bridgeManager.getAvailableCapacity();
    await taskScheduler.setCurrentCapacity(capacity.online);
  }

  /**
   * Settles the pending request matching `message.data.requestId`, if any:
   * cancels its timeout, resolves its promise with the message and forgets it.
   *
   * @param {object} message - Parsed bridge message.
   */
  handleBridgeResponse(message) {
    const requestId = message.data?.requestId;
    if (!requestId) {
      return;
    }
    const pending = this.pendingRequests.get(requestId);
    if (!pending) {
      return;
    }
    // Cancel the timer so it does not linger until its deadline.
    clearTimeout(pending.timeout);
    this.pendingRequests.delete(requestId);
    pending.resolve(message);
  }

  /**
   * Stores a pending-request record and arms its timeout.
   *
   * @param {string} requestId - Key in `pendingRequests`.
   * @param {object} entry - Record fields (resolve, reject, bridgeId, ...).
   * @param {number} timeoutMs - Delay before the request is considered lost.
   * @param {() => void} onTimeout - Runs after the record has been removed.
   */
  trackPendingRequest(requestId, entry, timeoutMs, onTimeout) {
    const record = { ...entry, timeout: null };
    record.timeout = setTimeout(() => {
      // Identity check: a stale timer must not evict a newer record that
      // happens to reuse the same request id.
      if (this.pendingRequests.get(requestId) !== record) {
        return;
      }
      this.pendingRequests.delete(requestId);
      onTimeout();
    }, timeoutMs);
    this.pendingRequests.set(requestId, record);
  }

  /**
   * Sends a `TASK_ASSIGN` message to a bridge.
   *
   * @param {string} bridgeId - Target bridge.
   * @param {object} taskData - Task payload; `taskData.instanceId` identifies
   *   the instance lock released on failure or timeout.
   * @param {string} requestId - Correlation id for the response.
   * @returns {Promise<object>} Resolves with the first bridge message handled
   *   for `requestId`; rejects when sending fails, on timeout
   *   (`TASK_TIMEOUT`), or when the bridge disconnects.
   */
  sendTaskToBridge(bridgeId, taskData, requestId) {
    return new Promise((resolve, reject) => {
      const message = {
        type: 'TASK_ASSIGN',
        data: {
          ...taskData,
          requestId
        }
      };
      console.log('[分发] WebSocketServer 准备发送消息:', JSON.stringify(message, null, 2));

      if (!bridgeManager.sendToBridge(bridgeId, message)) {
        bridgeManager.releaseInstanceLock(taskData.instanceId);
        reject(new Error('发送任务失败'));
        return;
      }

      this.trackPendingRequest(
        requestId,
        {
          resolve,
          reject,
          bridgeId,
          instanceId: taskData.instanceId,
          sentAt: new Date().toISOString(),
          ackReceived: false
        },
        this.TASK_TIMEOUT,
        () => {
          bridgeManager.releaseInstanceLock(taskData.instanceId);
          reject(new Error('任务执行超时'));
        }
      );
    });
  }

  /**
   * Sends a `TASK_ASSIGN` message addressed to a specific instance.
   *
   * @param {string} bridgeId - Bridge that hosts the instance.
   * @param {string} instanceId - Target instance.
   * @param {object} taskData - Task payload.
   * @param {string} requestId - Correlation id for the response.
   * @returns {Promise<object>} Resolves with the first bridge message handled
   *   for `requestId`; rejects when sending fails, on timeout (the scheduler
   *   is also told the task failed), or when the bridge disconnects.
   */
  sendTaskToInstance(bridgeId, instanceId, taskData, requestId) {
    return new Promise((resolve, reject) => {
      const message = {
        type: 'TASK_ASSIGN',
        data: {
          ...taskData,
          requestId,
          instanceId
        }
      };

      if (!bridgeManager.sendToBridge(bridgeId, message)) {
        bridgeManager.releaseInstanceLock(instanceId);
        reject(new Error('发送任务到实例失败'));
        return;
      }

      this.trackPendingRequest(
        requestId,
        {
          resolve,
          reject,
          bridgeId,
          instanceId,
          sentAt: new Date().toISOString(),
          ackReceived: false
        },
        this.TASK_TIMEOUT,
        () => {
          bridgeManager.releaseInstanceLock(instanceId);
          // Drop the mapping only if it still points at this request.
          if (this.instanceTaskMap.get(instanceId) === requestId) {
            this.instanceTaskMap.delete(instanceId);
          }
          // Invoke synchronously, but route sync throws and async rejections
          // to the logger instead of leaving a floating promise.
          (async () => taskScheduler.handleTaskFailure(requestId, '任务发送超时'))().catch((error) => {
            logger.error('上报任务超时失败时出错:', error);
          });
          reject(new Error('任务执行超时'));
        }
      );
      this.instanceTaskMap.set(instanceId, requestId);
    });
  }

  /**
   * Sends an `INSTANCE_CHECK` request to a bridge.
   *
   * @param {string} bridgeId - Target bridge.
   * @param {string} checkType - Forwarded verbatim as `data.checkType`.
   * @param {string} instanceId - Instance to check.
   * @param {string} requestId - Correlation id for the response.
   * @returns {Promise<object>} Resolves with the bridge's response message;
   *   rejects when sending fails, after 30 s without a response, or when the
   *   bridge disconnects.
   */
  sendInstanceCheckToBridge(bridgeId, checkType, instanceId, requestId) {
    return new Promise((resolve, reject) => {
      const message = {
        type: 'INSTANCE_CHECK',
        data: {
          checkType,
          instanceId,
          requestId
        }
      };
      console.log('[分发] WebSocketServer 准备发送实例检查消息:', JSON.stringify(message, null, 2));

      if (!bridgeManager.sendToBridge(bridgeId, message)) {
        reject(new Error('发送实例检查请求失败'));
        return;
      }

      this.trackPendingRequest(requestId, { resolve, reject, bridgeId }, 30000, () => {
        reject(new Error('实例检查超时'));
      });
    });
  }

  /**
   * Rejects and removes every pending request that belongs to `bridgeId`.
   * For requests tied to an instance, the instance lock is released, the
   * instance -> request mapping is dropped and the scheduler is asked to
   * recover the task.
   *
   * @param {string} bridgeId - Bridge whose connection went away.
   */
  cleanupPendingRequests(bridgeId) {
    // Deleting the current entry while iterating a Map is safe in JavaScript.
    for (const [requestId, pending] of this.pendingRequests) {
      if (pending.bridgeId !== bridgeId) {
        continue;
      }
      clearTimeout(pending.timeout);
      if (pending.instanceId) {
        bridgeManager.releaseInstanceLock(pending.instanceId);
        this.instanceTaskMap.delete(pending.instanceId);
        taskScheduler.recoverTask(requestId, 'bridge_disconnect');
      }
      pending.reject(new Error('桥接器连接已断开'));
      this.pendingRequests.delete(requestId);
    }
  }

  /**
   * @returns {{pendingRequests: number, instanceTaskMap: number,
   *   pendingRequestIds: string[]}} Sizes of the tracking maps and the ids of
   *   all requests currently pending.
   */
  getStats() {
    return {
      pendingRequests: this.pendingRequests.size,
      instanceTaskMap: this.instanceTaskMap.size,
      pendingRequestIds: Array.from(this.pendingRequests.keys())
    };
  }
}
// Export one shared instance: every module importing this file gets the same
// server object, and therefore the same pendingRequests / instanceTaskMap state.
export default new WebSocketServer();