336 lines
14 KiB
Python
336 lines
14 KiB
Python
"""Feishu WebSocket Long Connection Manager."""
|
|
|
|
import asyncio
|
|
import json
|
|
import threading
|
|
from typing import Any, Dict
|
|
import uuid
|
|
|
|
from loguru import logger
|
|
try:
|
|
import lark_oapi as lark
|
|
import lark_oapi.ws as ws
|
|
_HAS_LARK = True
|
|
except ImportError:
|
|
lark = None # type: ignore
|
|
ws = None # type: ignore
|
|
_HAS_LARK = False
|
|
|
|
# The scoped proxy bypass needs both lark-oapi and the websockets package.
# Start from "unavailable" and flip the flag only once everything resolved.
_PROXY_PATCH_AVAILABLE = False

if _HAS_LARK:
    try:
        import websockets as _websockets

        # Reference to the unpatched connect; handed to
        # _make_no_proxy_connect() so the bypass wraps the original callable.
        _orig_websockets_connect = _websockets.connect
    except ImportError:
        pass
    else:
        _PROXY_PATCH_AVAILABLE = True
|
|
|
|
|
|
def _make_no_proxy_connect(orig_connect):
    """Build a scoped proxy bypass around ``websockets.connect``.

    The patch is intentionally NOT applied at module import time, to avoid
    polluting the global ``websockets`` namespace for other modules in the
    process. Instead the returned factory produces an async context manager
    that swaps ``websockets.connect`` only while its body runs (used around
    lark-oapi's ``_connect()`` call).

    Overlapping scopes are reference-counted: the first scope to enter
    installs the wrapper and the last scope to exit restores the saved
    callable. Without this, two clients connecting concurrently could each
    save "whatever is currently installed", and the scope exiting last would
    re-install the wrapper permanently.

    Args:
        orig_connect: The unpatched ``websockets.connect`` callable to wrap.

    Returns:
        A zero-argument callable returning an async context manager.
    """
    import contextlib

    class _NoProxyConnect:
        """Stand-in for ``websockets.connect`` that defaults ``proxy`` to None.

        ``proxy=None`` is only a default (``setdefault``): a caller that
        passes ``proxy=...`` explicitly still gets what it asked for. The
        object is both awaitable and usable as an async context manager.
        """

        def __init__(self, *args, **kwargs):
            kwargs.setdefault("proxy", None)
            self._coro = orig_connect(*args, **kwargs)
            self._ws = None

        def __await__(self):
            return self._coro.__await__()

        async def __aenter__(self):
            self._ws = await self._coro
            return self._ws

        async def __aexit__(self, *exc):
            if self._ws is not None:
                await self._ws.close()

    # Nesting state is stored on the factory function itself so that it is
    # shared by every scope created through any call of this function (each
    # call builds a fresh closure, so a closure-local counter would not be
    # shared). While scopes overlap, the wrapper installed by the outermost
    # scope stays in place.
    state = _make_no_proxy_connect.__dict__.setdefault(
        "_scope_state", {"depth": 0, "saved": None}
    )

    @contextlib.asynccontextmanager
    async def _scoped_no_proxy():
        """Temporarily replace ``websockets.connect`` for the enclosed block."""
        if not _PROXY_PATCH_AVAILABLE:
            yield
            return

        if state["depth"] == 0:
            state["saved"] = _websockets.connect
            _websockets.connect = _NoProxyConnect
            logger.debug("[Feishu WS] Scoped websockets proxy bypass: active")
        state["depth"] += 1
        try:
            yield
        finally:
            state["depth"] -= 1
            if state["depth"] == 0:
                _websockets.connect = state["saved"]
                state["saved"] = None
                logger.debug("[Feishu WS] Scoped websockets proxy bypass: restored")

    return _scoped_no_proxy
|
|
|
|
from app.database import async_session
|
|
from app.models.channel_config import ChannelConfig
|
|
from sqlalchemy import select
|
|
|
|
|
|
if not _HAS_LARK:
    # Emitted once at import time so operators can see why Feishu WebSocket
    # support is inactive.
    logger.warning(
        "[Feishu WS] lark-oapi package not installed. Feishu WebSocket features will be disabled. Install with: pip install lark-oapi"
    )
|
|
|
|
|
|
class FeishuWSManager:
    """Manages Feishu WebSocket clients for all agents.

    One ``ws.Client`` and one long-lived asyncio task are tracked per agent.
    Incoming ``im.message.receive_v1`` events are normalised into a plain
    dict and passed to ``app.api.feishu.process_feishu_event``.
    """

    def __init__(self):
        # Quoted so the annotation never depends on lark-oapi being installed.
        self._clients: Dict[uuid.UUID, "ws.Client"] = {}
        # Long-lived runner task per agent, kept so it can be cancelled later.
        self._tasks: Dict[uuid.UUID, asyncio.Task] = {}
        # Strong references to in-flight per-event tasks; the event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._event_tasks: set = set()

    @staticmethod
    def _parse_event_body(data: Any) -> "Dict[str, Any] | None":
        """Normalise an SDK event payload into a plain dict.

        Accepted shapes, in order of precedence:
          * an object with a truthy ``raw_body`` (UTF-8 JSON bytes),
          * a ``dict`` (returned as-is),
          * an object exposing ``header`` and/or ``event`` (falling back to a
            string ``content`` attribute when ``event`` is absent).

        Returns:
            The event dict, or ``None`` (after logging a warning) when the
            payload has neither ``header`` nor ``event``.

        Raises:
            json.JSONDecodeError, UnicodeDecodeError: if ``raw_body`` is not
            valid UTF-8 JSON.
        """
        raw_body = getattr(data, "raw_body", None)
        if raw_body:
            return json.loads(raw_body.decode("utf-8"))

        # Some SDK versions pass the dict directly.
        if isinstance(data, dict):
            return data

        has_header = hasattr(data, "header")
        has_event = hasattr(data, "event")
        if not has_header and not has_event:
            logger.warning(f"[Feishu WS] Unexpected event data type with no recognizable fields: {type(data)}")
            return None

        # Handle lark_oapi.event.custom.CustomizedEvent-style objects.
        if has_header:
            header_obj = data.header
            if hasattr(header_obj, "__dict__"):
                # Copy: vars() returns the live attribute dict, and the
                # event_type default below must not be written back onto the
                # SDK object.
                header = dict(vars(header_obj))
            else:
                header = {
                    "event_type": getattr(header_obj, "event_type", "im.message.receive_v1"),
                    "event_id": getattr(header_obj, "event_id", ""),
                    "create_time": getattr(header_obj, "create_time", ""),
                }
            # Ensure event_type is present as it's required downstream.
            if "event_type" not in header:
                header["event_type"] = getattr(header_obj, "event_type", "im.message.receive_v1")
        else:
            header = {"event_type": "im.message.receive_v1"}

        body_dict: Dict[str, Any] = {"header": header}
        if has_event:
            body_dict["event"] = data.event
        elif isinstance(getattr(data, "content", None), str):
            try:
                body_dict["event"] = json.loads(data.content)
            except json.JSONDecodeError:
                body_dict["event"] = {"content": data.content}
        return body_dict

    def _create_event_handler(self, agent_id: uuid.UUID) -> "lark.EventDispatcherHandler":
        """Create an event dispatcher for a specific agent.

        The return annotation is quoted on purpose: an unquoted
        ``lark.EventDispatcherHandler`` is evaluated when the class body runs
        and fails when lark-oapi is not installed (``lark`` is ``None``).

        Args:
            agent_id: Agent whose events the dispatcher will forward.

        Returns:
            A dispatcher with ``handle_message`` registered for
            ``im.message.receive_v1``.
        """
        # Remember the loop this handler is created on, so events delivered
        # from a thread without a running loop can still be scheduled.
        try:
            owner_loop = asyncio.get_running_loop()
        except RuntimeError:
            owner_loop = None

        def handle_message(data: Any) -> None:
            """Schedule async processing of an im.message.receive_v1 event."""
            logger.info(f"[Feishu WS] Received event: {data}")

            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None

            if loop is not None:
                task = loop.create_task(self._async_handle_message(agent_id, data))
                self._event_tasks.add(task)
                task.add_done_callback(self._event_tasks.discard)
                return

            # No running loop in this thread: hand over to the owning loop.
            if owner_loop is None or owner_loop.is_closed():
                logger.error("[Feishu WS] Could not dispatch event to main loop: no event loop available")
                return
            try:
                asyncio.run_coroutine_threadsafe(
                    self._async_handle_message(agent_id, data), owner_loop
                )
            except Exception as e:
                logger.exception(f"[Feishu WS] Could not dispatch event to main loop: {e}")

        dispatcher = (
            lark.EventDispatcherHandler.builder("", "")
            .register_p2_customized_event("im.message.receive_v1", handle_message)
            .build()
        )
        return dispatcher

    async def _async_handle_message(self, agent_id: uuid.UUID, data: Any) -> None:
        """Handle im.message.receive_v1 events from Feishu WebSocket asynchronously.

        Normalises ``data`` via ``_parse_event_body`` and forwards it to
        ``process_feishu_event`` with a fresh database session. All exceptions
        are logged and swallowed.

        Args:
            agent_id: Agent the event belongs to.
            data: Raw SDK payload (dict or SDK event object).
        """
        try:
            body_dict = self._parse_event_body(data)
            if body_dict is None:
                return

            event_type = body_dict.get("header", {}).get("event_type", "unknown")
            logger.info(f"[Feishu WS] Event received for agent {agent_id}: {event_type}")

            # Import here to avoid circular dependencies
            from app.api.feishu import process_feishu_event

            async with async_session() as db:
                await process_feishu_event(agent_id, body_dict, db)

        except Exception as e:
            logger.exception(f"[Feishu WS] Error processing event for {agent_id}: {e}")

    async def start_client(
        self,
        agent_id: uuid.UUID,
        app_id: str,
        app_secret: str,
        stop_existing: bool = True,
    ):
        """Spawns a WebSocket client fully asynchronously inside the running loop.

        Args:
            agent_id: Agent the client belongs to.
            app_id: Feishu app ID.
            app_secret: Feishu app secret.
            stop_existing: When True, cancel a previously registered runner
                task for ``agent_id`` before starting the new one.

        Returns:
            None. Returns early (after logging) when lark-oapi is missing,
            credentials are empty, or the event handler cannot be built.
        """
        if not _HAS_LARK:
            logger.warning("[Feishu WS] lark-oapi not installed, cannot start client")
            return
        if not app_id or not app_secret:
            logger.warning(f"[Feishu WS] Missing app_id or app_secret for {agent_id}, skipping")
            return

        logger.info(f"[Feishu WS] Starting async WS client for agent {agent_id} (App ID: {app_id})")

        # Stop existing client task if any
        if stop_existing:
            old_task = self._tasks.pop(agent_id, None)
            if old_task is not None and not old_task.done():
                old_task.cancel()
                logger.info(f"[Feishu WS] Cancelled old WS task for {agent_id}")

        try:
            event_handler = self._create_event_handler(agent_id)
        except Exception as e:
            logger.exception(f"[Feishu WS] Failed to create event handler for {agent_id}: {e}")
            return

        client = ws.Client(
            app_id,
            app_secret,
            event_handler=event_handler,
            log_level=lark.LogLevel.INFO,
        )
        self._clients[agent_id] = client

        # Build scoped proxy bypass: active only during _connect() to avoid
        # permanently replacing websockets.connect for the whole process.
        bypass_scope = (
            _make_no_proxy_connect(_orig_websockets_connect)
            if _PROXY_PATCH_AVAILABLE
            else None
        )

        # Direct async runner used instead of client.start().
        async def _run_async_client():
            ping_task = None
            try:
                # Wrap _connect() in the scoped proxy bypass so system proxy
                # settings cannot interfere with the WebSocket handshake.
                if bypass_scope is not None:
                    async with bypass_scope():
                        await client._connect()
                else:
                    await client._connect()

                # Start ping loop after the connection is established.
                ping_task = asyncio.create_task(client._ping_loop())

                # Park this task; it exists to own the connection until it is
                # cancelled.
                while True:
                    await asyncio.sleep(3600)
            except asyncio.CancelledError:
                logger.info(f"[Feishu WS] Async client task cancelled for {agent_id}")
                await client._disconnect()
                # Propagate so the task is reported as cancelled.
                raise
            except Exception as e:
                logger.exception(f"[Feishu WS] Async client exception for {agent_id}: {e}")
                try:
                    await client._disconnect()
                except Exception as disconnect_err:
                    logger.error(f"[Feishu WS] Error disconnecting client for {agent_id}: {disconnect_err}")
                # Only drop the registry entry if it still refers to this
                # client; a newer client may have replaced it meanwhile.
                if self._clients.get(agent_id) is client:
                    self._clients.pop(agent_id, None)
            finally:
                # The ping loop must not outlive the runner task.
                if ping_task is not None and not ping_task.done():
                    ping_task.cancel()

        task = asyncio.create_task(_run_async_client(), name=f"feishu-ws-async-{str(agent_id)[:8]}")
        self._tasks[agent_id] = task
        logger.info(f"[Feishu WS] Async WS task scheduled for agent {agent_id}")

    async def stop_client(self, agent_id: uuid.UUID):
        """Stops an actively running WebSocket client for an agent.

        Cancels the runner task (if any) and disconnects the registered
        client (if any). Disconnect errors are logged, not raised.
        """
        task = self._tasks.pop(agent_id, None)
        if task is not None:
            if not task.done():
                task.cancel()
            logger.info(f"[Feishu WS] Stopped client task for agent {agent_id}")

        client = self._clients.pop(agent_id, None)
        if client is not None:
            try:
                await client._disconnect()
            except Exception as e:
                logger.error(f"[Feishu WS] Error disconnecting client for {agent_id}: {e}")

    async def start_all(self):
        """Start WS clients for all configured Feishu agents.

        Loads every configured ``feishu`` ChannelConfig and starts a client
        for those whose ``extra_config["connection_mode"]`` is
        ``"websocket"``; the mode defaults to ``"webhook"`` when unset.
        """
        if not _HAS_LARK:
            logger.info("[Feishu WS] lark-oapi not installed, skipping Feishu WS initialization")
            return

        logger.info("[Feishu WS] Initializing all active Feishu channels...")
        async with async_session() as db:
            result = await db.execute(
                select(ChannelConfig).where(
                    # "== True" builds a SQL expression here; it is not a
                    # Python truthiness test.
                    ChannelConfig.is_configured == True,  # noqa: E712
                    ChannelConfig.channel_type == "feishu",
                )
            )
            configs = result.scalars().all()

        for config in configs:
            extra = config.extra_config or {}
            if extra.get("connection_mode", "webhook") != "websocket":
                continue
            if config.app_id and config.app_secret:
                await self.start_client(
                    config.agent_id, config.app_id, config.app_secret, stop_existing=False
                )
            else:
                logger.warning(f"[Feishu WS] Skipping agent {config.agent_id}: missing credentials")

    def status(self) -> dict:
        """Return ``{agent_id_str: is_running}`` for every tracked runner task."""
        return {str(aid): not task.done() for aid, task in self._tasks.items()}
|
|
|
|
|
|
# Module-level FeishuWSManager instance, created when this module is imported.
feishu_ws_manager = FeishuWSManager()
|