1020 lines
48 KiB
Python
1020 lines
48 KiB
Python
"""WebSocket chat endpoint for real-time agent conversations."""
|
||
|
||
import json
|
||
import uuid
|
||
|
||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||
from loguru import logger
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.orm import selectinload
|
||
|
||
from app.core.security import decode_access_token
|
||
from app.core.permissions import check_agent_access, is_agent_expired
|
||
from app.database import async_session
|
||
from app.models.agent import Agent
|
||
from app.models.audit import ChatMessage
|
||
from app.models.llm import LLMModel
|
||
from app.models.user import User
|
||
|
||
router = APIRouter(tags=["websocket"])
|
||
|
||
|
||
class ConnectionManager:
|
||
"""Manage WebSocket connections per agent."""
|
||
|
||
def __init__(self):
|
||
# agent_id_str -> list of (WebSocket, session_id_str | None)
|
||
self.active_connections: dict[str, list[tuple]] = {}
|
||
|
||
async def connect(self, agent_id: str, websocket: WebSocket, session_id: str = None):
|
||
await websocket.accept()
|
||
if agent_id not in self.active_connections:
|
||
self.active_connections[agent_id] = []
|
||
self.active_connections[agent_id].append((websocket, session_id))
|
||
|
||
def disconnect(self, agent_id: str, websocket: WebSocket):
|
||
if agent_id in self.active_connections:
|
||
self.active_connections[agent_id] = [
|
||
(ws, sid) for ws, sid in self.active_connections[agent_id] if ws != websocket
|
||
]
|
||
|
||
async def send_message(self, agent_id: str, message: dict):
|
||
if agent_id in self.active_connections:
|
||
for ws, _sid in self.active_connections[agent_id]:
|
||
try:
|
||
await ws.send_json(message)
|
||
except Exception:
|
||
pass
|
||
|
||
async def send_to_session(self, agent_id: str, session_id: str, message: dict):
|
||
"""Send message only to WebSocket connections matching the given session_id."""
|
||
if agent_id in self.active_connections:
|
||
for ws, sid in self.active_connections[agent_id]:
|
||
if sid == session_id:
|
||
try:
|
||
await ws.send_json(message)
|
||
except Exception:
|
||
pass
|
||
|
||
def get_active_session_ids(self, agent_id: str) -> list[str]:
|
||
"""Return distinct session IDs for all active WS connections of an agent."""
|
||
if agent_id not in self.active_connections:
|
||
return []
|
||
return list(set(sid for _ws, sid in self.active_connections[agent_id] if sid))
|
||
|
||
|
||
manager = ConnectionManager()
|
||
|
||
|
||
from fastapi import Depends
|
||
from app.core.security import get_current_user
|
||
from app.database import get_db
|
||
from app.models.user import User
|
||
|
||
|
||
@router.get("/api/chat/{agent_id}/history")
|
||
async def get_chat_history(
|
||
agent_id: uuid.UUID,
|
||
current_user: User = Depends(get_current_user),
|
||
db: AsyncSession = Depends(get_db),
|
||
):
|
||
"""Return web chat message history for this user + agent."""
|
||
conv_id = f"web_{current_user.id}"
|
||
result = await db.execute(
|
||
select(ChatMessage)
|
||
.where(ChatMessage.agent_id == agent_id, ChatMessage.conversation_id == conv_id)
|
||
.order_by(ChatMessage.created_at.asc())
|
||
.limit(200)
|
||
)
|
||
messages = result.scalars().all()
|
||
out = []
|
||
for m in messages:
|
||
entry: dict = {"role": m.role, "content": m.content, "created_at": m.created_at.isoformat() if m.created_at else None}
|
||
if getattr(m, 'thinking', None):
|
||
entry["thinking"] = m.thinking
|
||
if m.role == "tool_call":
|
||
# Parse JSON-encoded tool call data
|
||
try:
|
||
import json
|
||
data = json.loads(m.content)
|
||
entry["content"] = ""
|
||
entry["toolName"] = data.get("name", "")
|
||
entry["toolArgs"] = data.get("args")
|
||
entry["toolStatus"] = data.get("status", "done")
|
||
entry["toolResult"] = data.get("result", "")
|
||
except Exception:
|
||
pass
|
||
out.append(entry)
|
||
return out
|
||
|
||
|
||
async def call_llm(
|
||
model: LLMModel,
|
||
messages: list[dict],
|
||
agent_name: str,
|
||
role_description: str,
|
||
agent_id=None,
|
||
user_id=None,
|
||
session_id: str = "",
|
||
on_chunk=None,
|
||
on_tool_call=None,
|
||
on_thinking=None,
|
||
supports_vision=False,
|
||
max_tool_rounds_override: int | None = None,
|
||
) -> str:
|
||
"""Call LLM via unified client with function-calling tool loop.
|
||
|
||
Args:
|
||
on_chunk: Optional async callback(text: str) for streaming chunks to client.
|
||
on_thinking: Optional async callback(text: str) for reasoning/thinking content.
|
||
on_tool_call: Optional async callback(dict) for tool call status updates.
|
||
"""
|
||
from app.services.agent_tools import AGENT_TOOLS, execute_tool, get_agent_tools_for_llm
|
||
from app.services.llm_utils import create_llm_client, get_max_tokens, LLMMessage, LLMError
|
||
|
||
# ── Token limit check & config ──
|
||
_max_tool_rounds = 50 # default
|
||
if agent_id:
|
||
try:
|
||
from app.models.agent import Agent as AgentModel
|
||
async with async_session() as _db:
|
||
_ar = await _db.execute(select(AgentModel).where(AgentModel.id == agent_id))
|
||
_agent = _ar.scalar_one_or_none()
|
||
if _agent:
|
||
_max_tool_rounds = _agent.max_tool_rounds or 50
|
||
if max_tool_rounds_override and max_tool_rounds_override < _max_tool_rounds:
|
||
_max_tool_rounds = max_tool_rounds_override
|
||
if _agent.max_tokens_per_day and _agent.tokens_used_today >= _agent.max_tokens_per_day:
|
||
return f"⚠️ Daily token usage has reached the limit ({_agent.tokens_used_today:,}/{_agent.max_tokens_per_day:,}). Please try again tomorrow or ask admin to increase the limit."
|
||
if _agent.max_tokens_per_month and _agent.tokens_used_month >= _agent.max_tokens_per_month:
|
||
return f"⚠️ Monthly token usage has reached the limit ({_agent.tokens_used_month:,}/{_agent.max_tokens_per_month:,}). Please ask admin to increase the limit."
|
||
except Exception:
|
||
pass
|
||
|
||
if max_tool_rounds_override and max_tool_rounds_override < _max_tool_rounds:
|
||
_max_tool_rounds = max_tool_rounds_override
|
||
|
||
# Build rich prompt with soul, memory, skills, relationships
|
||
from app.services.agent_context import build_agent_context
|
||
# Look up current user's display name so the agent knows who it's talking to
|
||
_current_user_name = None
|
||
if user_id:
|
||
try:
|
||
from app.models.user import User as _UserModel
|
||
async with async_session() as _udb:
|
||
_ur = await _udb.execute(select(_UserModel).where(_UserModel.id == user_id))
|
||
_u = _ur.scalar_one_or_none()
|
||
if _u:
|
||
_current_user_name = _u.display_name or _u.username
|
||
except Exception:
|
||
pass
|
||
static_prompt, dynamic_prompt = await build_agent_context(agent_id, agent_name, role_description, current_user_name=_current_user_name)
|
||
|
||
# Load tools dynamically from DB
|
||
tools_for_llm = await get_agent_tools_for_llm(agent_id) if agent_id else AGENT_TOOLS
|
||
|
||
# Convert messages to LLMMessage format
|
||
api_messages = [LLMMessage(role="system", content=static_prompt, dynamic_content=dynamic_prompt)]
|
||
for msg in messages:
|
||
api_messages.append(LLMMessage(
|
||
role=msg.get("role", "user"),
|
||
content=msg.get("content"),
|
||
tool_calls=msg.get("tool_calls"),
|
||
tool_call_id=msg.get("tool_call_id"),
|
||
))
|
||
|
||
# ── Vision format conversion ──
|
||
# If the model supports vision, convert image markers in user messages
|
||
# to OpenAI Vision API format: content becomes an array of parts.
|
||
if supports_vision:
|
||
import re as _re_v
|
||
for i, msg in enumerate(api_messages):
|
||
if msg.role != "user" or not msg.content or not isinstance(msg.content, str):
|
||
continue
|
||
content_str = msg.content
|
||
# Find [image_data:data:image/...;base64,...] markers
|
||
pattern = r'\[image_data:(data:image/[^;]+;base64,[A-Za-z0-9+/=]+)\]'
|
||
images = _re_v.findall(pattern, content_str)
|
||
if not images:
|
||
continue
|
||
# Build content array
|
||
text = _re_v.sub(pattern, '', content_str).strip()
|
||
parts = []
|
||
for img_url in images:
|
||
parts.append({"type": "image_url", "image_url": {"url": img_url}})
|
||
if text:
|
||
parts.append({"type": "text", "text": text})
|
||
# Replace the message content with the array format
|
||
api_messages[i] = LLMMessage(
|
||
role=msg.role,
|
||
content=parts, # type: ignore # This is valid for vision models
|
||
)
|
||
else:
|
||
# Strip base64 image markers for non-vision models to avoid wasting tokens
|
||
import re as _re_strip
|
||
_img_pattern = r'\[image_data:data:image/[^;]+;base64,[A-Za-z0-9+/=]+\]'
|
||
for i, msg in enumerate(api_messages):
|
||
if msg.role != "user" or not isinstance(msg.content, str):
|
||
continue
|
||
if "[image_data:" in msg.content:
|
||
_n_imgs = len(_re_strip.findall(_img_pattern, msg.content))
|
||
cleaned = _re_strip.sub(_img_pattern, '', msg.content).strip()
|
||
if _n_imgs > 0:
|
||
cleaned += f"\n[用户发送了 {_n_imgs} 张图片,但当前模型不支持视觉,无法查看图片内容]"
|
||
api_messages[i] = LLMMessage(
|
||
role=msg.role,
|
||
content=cleaned,
|
||
)
|
||
|
||
# Create the unified LLM client
|
||
try:
|
||
client = create_llm_client(
|
||
provider=model.provider,
|
||
api_key=model.api_key_encrypted,
|
||
model=model.model,
|
||
base_url=model.base_url,
|
||
timeout=float(getattr(model, 'request_timeout', None) or 120.0),
|
||
)
|
||
except Exception as e:
|
||
return f"[Error] Failed to create LLM client: {e}"
|
||
|
||
max_tokens = get_max_tokens(model.provider, model.model, getattr(model, 'max_output_tokens', None))
|
||
|
||
# ── Per-round token accumulator ──
|
||
from app.services.token_tracker import record_token_usage, extract_usage_tokens, estimate_tokens_from_chars
|
||
_accumulated_tokens = 0
|
||
|
||
# Tool-calling loop (configurable per agent, default 50)
|
||
for round_i in range(_max_tool_rounds):
|
||
# ── Dynamic tool-call limit warning (Aware engine) ──
|
||
# Don't tell the agent about limits at the start — only warn when approaching.
|
||
# This prevents models from rushing to complete tasks prematurely.
|
||
_warn_threshold_80 = int(_max_tool_rounds * 0.8)
|
||
_warn_threshold_96 = _max_tool_rounds - 2
|
||
if round_i == _warn_threshold_80:
|
||
api_messages.append(LLMMessage(
|
||
role="user",
|
||
content=(
|
||
f"⚠️ 你已使用 {round_i}/{_max_tool_rounds} 轮工具调用。"
|
||
"如果当前任务尚未完成,请尽快保存进度到 focus.md,"
|
||
"并使用 set_trigger 设置续接触发器,在剩余轮次中做好收尾。"
|
||
),
|
||
))
|
||
elif round_i == _warn_threshold_96:
|
||
api_messages.append(LLMMessage(
|
||
role="user",
|
||
content=f"🚨 仅剩 2 轮工具调用。请立即保存进度到 focus.md 并设置续接触发器。",
|
||
))
|
||
|
||
try:
|
||
# Use streaming API for real-time responses
|
||
response = await client.stream(
|
||
messages=api_messages,
|
||
tools=tools_for_llm if tools_for_llm else None,
|
||
temperature=model.temperature,
|
||
max_tokens=max_tokens,
|
||
on_chunk=on_chunk,
|
||
on_thinking=on_thinking,
|
||
)
|
||
except LLMError as e:
|
||
# Record accumulated tokens before returning error
|
||
logger.error(
|
||
f"[LLM] LLMError provider={getattr(model, 'provider', '?')} "
|
||
f"model={getattr(model, 'model', '?')} round={round_i + 1}: {e}"
|
||
)
|
||
if agent_id and _accumulated_tokens > 0:
|
||
await record_token_usage(agent_id, _accumulated_tokens)
|
||
return f"[LLM Error] {e}"
|
||
except Exception as e:
|
||
logger.error(
|
||
f"[LLM] Unexpected error provider={getattr(model, 'provider', '?')} "
|
||
f"model={getattr(model, 'model', '?')} round={round_i + 1}: "
|
||
f"{type(e).__name__}: {str(e)[:300]}"
|
||
)
|
||
if agent_id and _accumulated_tokens > 0:
|
||
await record_token_usage(agent_id, _accumulated_tokens)
|
||
return f"[LLM call error] {type(e).__name__}: {str(e)[:200]}"
|
||
|
||
# ── Track tokens for this round ──
|
||
logger.debug(f"[LLM] stream() returned: {len(response.content or '')} chars, finish={response.finish_reason}, tools={len(response.tool_calls or [])}")
|
||
real_tokens = extract_usage_tokens(response.usage)
|
||
if real_tokens:
|
||
_accumulated_tokens += real_tokens
|
||
else:
|
||
round_chars = sum(len(m.content or '') if isinstance(m.content, str) else 0 for m in api_messages) + len(response.content or '')
|
||
_accumulated_tokens += estimate_tokens_from_chars(round_chars)
|
||
|
||
# If no tool calls, return the final content
|
||
if not response.tool_calls:
|
||
if agent_id and _accumulated_tokens > 0:
|
||
await record_token_usage(agent_id, _accumulated_tokens)
|
||
await client.close()
|
||
return response.content or "[LLM returned empty content]"
|
||
|
||
# Execute tool calls
|
||
logger.info(f"[LLM] Round {round_i+1}: {len(response.tool_calls)} tool call(s), finish_reason={response.finish_reason}")
|
||
|
||
# Add assistant message with tool calls
|
||
api_messages.append(LLMMessage(
|
||
role="assistant",
|
||
content=response.content or None,
|
||
tool_calls=[{
|
||
"id": tc["id"],
|
||
"type": "function",
|
||
"function": tc["function"],
|
||
} for tc in response.tool_calls],
|
||
reasoning_content=response.reasoning_content,
|
||
))
|
||
|
||
full_reasoning_content = response.reasoning_content or ""
|
||
|
||
# Tools that require arguments — if LLM sends empty args, skip and ask to retry
|
||
_TOOLS_REQUIRING_ARGS = {"write_file", "read_file", "delete_file", "read_document", "send_message_to_agent", "send_feishu_message", "send_email"}
|
||
|
||
for tc in response.tool_calls:
|
||
fn = tc["function"]
|
||
tool_name = fn["name"]
|
||
raw_args = fn.get("arguments", "{}")
|
||
logger.info(f"[LLM] Raw arguments for {tool_name} (len={len(raw_args)}): {repr(raw_args[:300])}")
|
||
try:
|
||
args = json.loads(raw_args) if raw_args else {}
|
||
except json.JSONDecodeError:
|
||
args = {}
|
||
|
||
# Guard: if a tool that requires arguments received empty args,
|
||
# return an error to LLM instead of executing (Claude sometimes
|
||
# emits tool_use blocks with no input_json_delta events)
|
||
if not args and tool_name in _TOOLS_REQUIRING_ARGS:
|
||
logger.warning(f"[LLM] Empty arguments for {tool_name}, asking LLM to retry")
|
||
api_messages.append(LLMMessage(
|
||
role="tool",
|
||
content=f"Error: {tool_name} was called with empty arguments. You must provide the required parameters. Please retry with the correct arguments.",
|
||
tool_call_id=tc.get("id", ""),
|
||
))
|
||
continue
|
||
|
||
logger.info(f"[LLM] Calling tool: {tool_name}({args})")
|
||
# Notify client about tool call (in-progress)
|
||
if on_tool_call:
|
||
try:
|
||
await on_tool_call({
|
||
"name": tool_name,
|
||
"args": args,
|
||
"status": "running",
|
||
"reasoning_content": full_reasoning_content
|
||
})
|
||
except Exception:
|
||
pass
|
||
|
||
result = await execute_tool(
|
||
tool_name, args,
|
||
agent_id=agent_id,
|
||
user_id=user_id or agent_id,
|
||
session_id=session_id,
|
||
)
|
||
logger.debug(f"[LLM] Tool result: {result[:100]}")
|
||
|
||
# Notify client about tool call result
|
||
if on_tool_call:
|
||
try:
|
||
await on_tool_call({
|
||
"name": tool_name,
|
||
"args": args,
|
||
"status": "done",
|
||
"result": result,
|
||
"reasoning_content": full_reasoning_content
|
||
})
|
||
except Exception as _cb_err:
|
||
logger.warning(f"[LLM] on_tool_call callback error: {_cb_err}")
|
||
|
||
# ── Vision injection for screenshot tools ──
|
||
# If the model supports vision, try to inject the actual screenshot
|
||
# image into the tool result so the LLM can SEE what's on screen.
|
||
# Without this, the LLM only gets text like "Screenshot saved to ..."
|
||
# and blindly guesses the page content.
|
||
tool_content: str | list = str(result)
|
||
if supports_vision and agent_id:
|
||
try:
|
||
from app.services.vision_inject import try_inject_screenshot_vision
|
||
from app.services.agent_tools import WORKSPACE_ROOT
|
||
ws_path = WORKSPACE_ROOT / str(agent_id)
|
||
vision_content = try_inject_screenshot_vision(tool_name, str(result), ws_path)
|
||
if vision_content:
|
||
tool_content = vision_content
|
||
logger.info(f"[LLM] Injected screenshot vision for {tool_name}")
|
||
except Exception as e:
|
||
logger.warning(f"[LLM] Vision injection failed for {tool_name}: {e}")
|
||
|
||
api_messages.append(LLMMessage(
|
||
role="tool",
|
||
tool_call_id=tc["id"],
|
||
content=tool_content,
|
||
))
|
||
|
||
# Record tokens even on "too many rounds" exit
|
||
if agent_id and _accumulated_tokens > 0:
|
||
await record_token_usage(agent_id, _accumulated_tokens)
|
||
await client.close()
|
||
return "[Error] Too many tool call rounds"
|
||
|
||
|
||
@router.websocket("/ws/chat/{agent_id}")
|
||
async def websocket_chat(
|
||
websocket: WebSocket,
|
||
agent_id: uuid.UUID,
|
||
token: str = Query(...),
|
||
session_id: str = Query(None),
|
||
):
|
||
"""WebSocket endpoint for real-time chat with an agent.
|
||
|
||
Flow:
|
||
1. Client connects with JWT token + optional session_id as query params
|
||
2. Server accepts immediately so browser onopen fires quickly
|
||
3. Server authenticates and checks agent access
|
||
4. If session_id provided, uses it; otherwise finds/creates the user's latest session
|
||
5. Client sends messages as JSON: {"content": "..."}
|
||
6. Server calls the agent's configured LLM and sends response back
|
||
7. Messages are persisted to chat_messages table under the session
|
||
"""
|
||
# Accept immediately so browser sees onopen without waiting for DB setup
|
||
await websocket.accept()
|
||
|
||
# Authenticate
|
||
try:
|
||
payload = decode_access_token(token)
|
||
user_id = uuid.UUID(payload["sub"])
|
||
except Exception:
|
||
await websocket.send_json({"type": "error", "content": "Authentication failed"})
|
||
await websocket.close(code=4001)
|
||
return
|
||
|
||
# Verify access and load agent + model
|
||
agent_name = ""
|
||
agent_type = "" # Track agent type for OpenClaw routing
|
||
role_description = ""
|
||
welcome_message = ""
|
||
llm_model = None
|
||
fallback_llm_model = None
|
||
history_messages = []
|
||
|
||
try:
|
||
async with async_session() as db:
|
||
logger.info(f"[WS] Looking up user {user_id}")
|
||
result = await db.execute(select(User).where(User.id == user_id))
|
||
user = result.scalar_one_or_none()
|
||
if not user:
|
||
logger.info("[WS] User not found")
|
||
await websocket.send_json({"type": "error", "content": "User not found"})
|
||
await websocket.close(code=4001)
|
||
return
|
||
|
||
logger.info(f"[WS] Checking agent access for {agent_id}")
|
||
agent, _ = await check_agent_access(db, user, agent_id)
|
||
# Check agent expiry
|
||
if is_agent_expired(agent):
|
||
await websocket.send_json({"type": "error", "content": "This Agent has expired and is off duty. Please contact your admin to extend its service."})
|
||
await websocket.close(code=4003)
|
||
return
|
||
agent_name = agent.name
|
||
agent_type = agent.agent_type or ""
|
||
role_description = agent.role_description or ""
|
||
welcome_message = agent.welcome_message or ""
|
||
from app.models.agent import DEFAULT_CONTEXT_WINDOW_SIZE
|
||
ctx_size = agent.context_window_size or DEFAULT_CONTEXT_WINDOW_SIZE
|
||
logger.info(f"[WS] Agent: {agent_name}, type: {agent_type}, model_id: {agent.primary_model_id}, ctx: {ctx_size}")
|
||
|
||
# Load the agent's primary model
|
||
if agent.primary_model_id:
|
||
model_result = await db.execute(
|
||
select(LLMModel).where(LLMModel.id == agent.primary_model_id)
|
||
)
|
||
llm_model = model_result.scalar_one_or_none()
|
||
# Treat disabled models as unavailable at runtime
|
||
if llm_model and not llm_model.enabled:
|
||
logger.info(f"[WS] Primary model {llm_model.model} is disabled, skipping")
|
||
llm_model = None
|
||
else:
|
||
logger.info(f"[WS] Primary model loaded: {llm_model.model if llm_model else 'None'}")
|
||
|
||
# Load fallback model
|
||
if agent.fallback_model_id:
|
||
fb_result = await db.execute(
|
||
select(LLMModel).where(LLMModel.id == agent.fallback_model_id)
|
||
)
|
||
fallback_llm_model = fb_result.scalar_one_or_none()
|
||
# Treat disabled fallback models as unavailable
|
||
if fallback_llm_model and not fallback_llm_model.enabled:
|
||
logger.info(f"[WS] Fallback model {fallback_llm_model.model} is disabled, skipping")
|
||
fallback_llm_model = None
|
||
elif fallback_llm_model:
|
||
logger.info(f"[WS] Fallback model loaded: {fallback_llm_model.model}")
|
||
|
||
# Config-level fallback: primary missing -> use fallback
|
||
if not llm_model and fallback_llm_model:
|
||
llm_model = fallback_llm_model
|
||
fallback_llm_model = None # No further fallback available
|
||
logger.info(f"[WS] Primary model unavailable, using fallback: {llm_model.model}")
|
||
|
||
# Resolve or create chat session
|
||
from app.models.chat_session import ChatSession
|
||
from sqlalchemy import select as _sel
|
||
from datetime import datetime as _dt, timezone as _tz
|
||
conv_id = session_id
|
||
if conv_id:
|
||
# Validate the session belongs to this agent and to this user (no hijacking others' sessions).
|
||
try:
|
||
_sid = uuid.UUID(conv_id)
|
||
except (ValueError, TypeError):
|
||
conv_id = None
|
||
_existing = None
|
||
else:
|
||
_sr = await db.execute(
|
||
_sel(ChatSession).where(
|
||
ChatSession.id == _sid,
|
||
ChatSession.agent_id == agent_id,
|
||
)
|
||
)
|
||
_existing = _sr.scalar_one_or_none()
|
||
if not _existing:
|
||
conv_id = None
|
||
elif _existing.source_channel != "agent" and str(_existing.user_id) != str(user_id):
|
||
await websocket.send_json({"type": "error", "content": "Not authorized for this session"})
|
||
await websocket.close(code=4003)
|
||
return
|
||
if not conv_id:
|
||
# Find most recent session for this user+agent
|
||
_sr = await db.execute(
|
||
_sel(ChatSession)
|
||
.where(ChatSession.agent_id == agent_id, ChatSession.user_id == user_id)
|
||
.order_by(ChatSession.last_message_at.desc().nulls_last(), ChatSession.created_at.desc())
|
||
.limit(1)
|
||
)
|
||
_latest = _sr.scalar_one_or_none()
|
||
if _latest:
|
||
conv_id = str(_latest.id)
|
||
else:
|
||
# Create a default session
|
||
now = _dt.now(_tz.utc)
|
||
_new_session = ChatSession(
|
||
agent_id=agent_id, user_id=user_id,
|
||
title=f"Session {now.strftime('%m-%d %H:%M')}",
|
||
source_channel="web",
|
||
created_at=now,
|
||
)
|
||
db.add(_new_session)
|
||
await db.commit()
|
||
await db.refresh(_new_session)
|
||
conv_id = str(_new_session.id)
|
||
logger.info(f"[WS] Created default session {conv_id}")
|
||
|
||
try:
|
||
history_result = await db.execute(
|
||
select(ChatMessage)
|
||
.where(ChatMessage.agent_id == agent_id, ChatMessage.conversation_id == conv_id)
|
||
.order_by(ChatMessage.created_at.desc())
|
||
.limit(ctx_size)
|
||
)
|
||
history_messages = list(reversed(history_result.scalars().all()))
|
||
logger.info(f"[WS] Loaded {len(history_messages)} history messages for session {conv_id}")
|
||
except Exception as e:
|
||
logger.warning(f"[WS] History load failed (non-fatal): {e}")
|
||
except Exception as e:
|
||
logger.error(f"[WS] Setup error: {type(e).__name__}: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
await websocket.send_json({"type": "error", "content": "Setup failed"})
|
||
await websocket.close(code=4002) # Config error — client should NOT retry
|
||
return
|
||
|
||
agent_id_str = str(agent_id)
|
||
if agent_id_str not in manager.active_connections:
|
||
manager.active_connections[agent_id_str] = []
|
||
manager.active_connections[agent_id_str].append((websocket, conv_id))
|
||
logger.info(f"[WS] Ready! Agent={agent_name}")
|
||
|
||
# Send session_id to frontend so Take Control can reference the correct session
|
||
await websocket.send_json({"type": "connected", "session_id": conv_id})
|
||
|
||
# Build conversation context from history
|
||
# IMPORTANT: Include tool_call messages so the LLM maintains tool-calling behavior.
|
||
# Without them, Claude sees user→assistant-text patterns and learns to skip tools.
|
||
conversation: list[dict] = []
|
||
for msg in history_messages:
|
||
if msg.role == "tool_call":
|
||
# Convert stored tool_call JSON into OpenAI-format assistant+tool pair
|
||
try:
|
||
import json as _j_hist
|
||
tc_data = _j_hist.loads(msg.content)
|
||
tc_name = tc_data.get("name", "unknown")
|
||
tc_args = tc_data.get("args", {})
|
||
tc_result = tc_data.get("result", "")
|
||
tc_id = f"call_{msg.id}" # synthetic tool_call_id
|
||
# Assistant message with tool_calls array
|
||
asst_msg = {
|
||
"role": "assistant",
|
||
"content": None,
|
||
"tool_calls": [{
|
||
"id": tc_id,
|
||
"type": "function",
|
||
"function": {"name": tc_name, "arguments": _j_hist.dumps(tc_args, ensure_ascii=False)},
|
||
}],
|
||
}
|
||
if tc_data.get("reasoning_content"):
|
||
asst_msg["reasoning_content"] = tc_data["reasoning_content"]
|
||
conversation.append(asst_msg)
|
||
# Tool result message.
|
||
# Sanitize any stale [ImageID: ...] markers left by the ephemeral
|
||
# screenshot cache — those images are gone from memory and would
|
||
# confuse the LLM if sent as-is.
|
||
from app.services.vision_inject import sanitize_history_tool_result
|
||
sanitized_result = sanitize_history_tool_result(str(tc_result))
|
||
conversation.append({
|
||
"role": "tool",
|
||
"tool_call_id": tc_id,
|
||
"content": sanitized_result[:500],
|
||
})
|
||
except Exception:
|
||
continue # Skip malformed tool_call records
|
||
else:
|
||
entry = {"role": msg.role, "content": msg.content}
|
||
if hasattr(msg, 'thinking') and msg.thinking:
|
||
entry["thinking"] = msg.thinking
|
||
conversation.append(entry)
|
||
|
||
try:
|
||
# Send welcome message on new session (no history)
|
||
if welcome_message and not history_messages:
|
||
await websocket.send_json({"type": "done", "role": "assistant", "content": welcome_message})
|
||
|
||
while True:
|
||
logger.info(f"[WS] Waiting for message from {agent_name}...")
|
||
data = await websocket.receive_json()
|
||
|
||
# Set a unique trace ID for this specific message processing
|
||
from app.core.logging_config import set_trace_id
|
||
import uuid as _trace_uuid
|
||
trace_id = str(_trace_uuid.uuid4())[:12]
|
||
set_trace_id(trace_id)
|
||
|
||
content = data.get("content", "")
|
||
display_content = data.get("display_content", "") # User-facing display text
|
||
file_name = data.get("file_name", "") # Original file name for attachment display
|
||
logger.info(f"[WS] Received: {content[:50]}")
|
||
|
||
if not content:
|
||
continue
|
||
|
||
# ── Quota checks ──
|
||
try:
|
||
from app.services.quota_guard import (
|
||
check_conversation_quota, increment_conversation_usage,
|
||
check_agent_expired, check_agent_llm_quota, increment_agent_llm_usage,
|
||
QuotaExceeded, AgentExpired,
|
||
)
|
||
await check_conversation_quota(user_id)
|
||
await check_agent_expired(agent_id)
|
||
except QuotaExceeded as qe:
|
||
await websocket.send_json({"type": "done", "role": "assistant", "content": f"⚠️ {qe.message}"})
|
||
continue
|
||
except AgentExpired as ae:
|
||
await websocket.send_json({"type": "done", "role": "assistant", "content": f"⚠️ {ae.message}"})
|
||
continue
|
||
|
||
# Add user message to conversation (full LLM context)
|
||
conversation.append({"role": "user", "content": content})
|
||
|
||
# Save user message to DB.
|
||
#
|
||
# Strategy:
|
||
# - If the LLM content contains [image_data:...] markers (i.e. the user
|
||
# attached an image and the model supports vision), persist the FULL
|
||
# content including the base64 marker. This makes history self-contained
|
||
# so subsequent turns can forward the image to the LLM without any
|
||
# disk-based rehydration step.
|
||
# - For all other messages (text, non-image files) use display_content for
|
||
# cleaner history (avoids e.g. the raw file-text blob appearing in chat).
|
||
#
|
||
# The call_llm() path already strips [image_data:] for non-vision models
|
||
# (websocket.py ~line 210), so no extra handling is needed at read time.
|
||
HAS_IMAGE_MARKER = "[image_data:" in content
|
||
if HAS_IMAGE_MARKER:
|
||
# Preserve the full LLM content (includes base64) for multi-turn context.
|
||
# Prefix with [file:name] for the UI history parser if a file name exists.
|
||
saved_content = f"[file:{file_name}]\n{content}" if file_name else content
|
||
else:
|
||
saved_content = display_content if display_content else content
|
||
if file_name:
|
||
saved_content = f"[file:{file_name}]\n{saved_content}"
|
||
async with async_session() as db:
|
||
user_msg = ChatMessage(
|
||
agent_id=agent_id,
|
||
user_id=user_id,
|
||
role="user",
|
||
content=saved_content,
|
||
conversation_id=conv_id,
|
||
)
|
||
db.add(user_msg)
|
||
# Update session last_message_at + auto-title on first message
|
||
from app.models.chat_session import ChatSession as _CS
|
||
from datetime import datetime as _dt2, timezone as _tz2
|
||
_now = _dt2.now(_tz2.utc)
|
||
_sess_r = await db.execute(
|
||
select(_CS).where(_CS.id == uuid.UUID(conv_id))
|
||
)
|
||
_sess = _sess_r.scalar_one_or_none()
|
||
if _sess:
|
||
_sess.last_message_at = _now
|
||
if not history_messages and _sess.title.startswith("Session "):
|
||
# Always use display_content for title (never expose raw base64)
|
||
title_src = display_content if display_content else content
|
||
# Clean up common prefixes from image/file messages
|
||
clean_title = title_src.replace("[图片] ", "📷 ").replace("[image_data:", "").strip()
|
||
if file_name and not clean_title:
|
||
clean_title = f"📎 {file_name}"
|
||
_sess.title = clean_title[:40] if clean_title else content[:40]
|
||
await db.commit()
|
||
logger.info("[WS] User message saved")
|
||
|
||
|
||
# ── OpenClaw routing: insert into gateway_messages instead of LLM ──
|
||
if agent_type == "openclaw":
|
||
from app.models.gateway_message import GatewayMessage as GwMsg
|
||
async with async_session() as db:
|
||
gw_msg = GwMsg(
|
||
agent_id=agent_id,
|
||
sender_user_id=user_id,
|
||
conversation_id=conv_id,
|
||
content=content,
|
||
status="pending",
|
||
)
|
||
db.add(gw_msg)
|
||
await db.commit()
|
||
logger.info("[WS] OpenClaw: message queued for gateway poll")
|
||
await websocket.send_json({
|
||
"type": "done",
|
||
"role": "assistant",
|
||
"content": "Message forwarded to OpenClaw agent. Waiting for response..."
|
||
})
|
||
continue
|
||
|
||
# Detect task creation intent
|
||
import re
|
||
task_match = re.search(
|
||
r'(?:创建|新建|添加|建一个|帮我建|create|add)(?:一个|a )?(?:任务|待办|todo|task)[,,:::\\s]*(.+)',
|
||
content, re.IGNORECASE
|
||
)
|
||
|
||
# Track thinking content for storage (initialize before condition)
|
||
thinking_content = []
|
||
|
||
# Call LLM with streaming
|
||
if llm_model:
|
||
try:
|
||
logger.info(f"[WS] Calling LLM {llm_model.model} (streaming)...")
|
||
|
||
# Accumulate partial content for abort handling
|
||
partial_chunks: list[str] = []
|
||
|
||
async def stream_to_ws(text: str):
|
||
"""Send each chunk to client in real-time."""
|
||
partial_chunks.append(text)
|
||
await websocket.send_json({"type": "chunk", "content": text})
|
||
|
||
# Track which agentbay live URLs have been sent to avoid redundant pushes
|
||
_sent_live_envs: set[str] = set()
|
||
|
||
async def tool_call_to_ws(data: dict):
|
||
"""Send tool call info to client and persist completed ones."""
|
||
# ── AgentBay live preview: embed screenshot URL in tool_call message ──
|
||
# We embed live preview data directly in the tool_call payload
|
||
# because separate WebSocket messages get silently dropped by nginx.
|
||
if data.get("status") == "done":
|
||
try:
|
||
from app.services.agentbay_live import detect_agentbay_env, get_desktop_screenshot, get_browser_snapshot
|
||
import re as _re_live
|
||
tool_name = data.get("name", "")
|
||
env = detect_agentbay_env(tool_name)
|
||
if env:
|
||
tool_result = data.get("result", "") or ""
|
||
if env == "desktop":
|
||
b64_url = await get_desktop_screenshot(agent_id, session_id=conv_id)
|
||
if b64_url:
|
||
data["live_preview"] = {"env": env, "screenshot_url": b64_url}
|
||
logger.info(f"[WS][LivePreview] Embedded {env} base64 in tool_call")
|
||
elif env == "browser":
|
||
b64_url = await get_browser_snapshot(agent_id, session_id=conv_id)
|
||
if b64_url:
|
||
data["live_preview"] = {"env": env, "screenshot_url": b64_url}
|
||
logger.info(f"[WS][LivePreview] Embedded {env} base64 in tool_call")
|
||
elif env == "code":
|
||
data["live_preview"] = {"env": "code", "output": tool_result[:5000]}
|
||
except Exception as _lp_err:
|
||
logger.warning(f"[WS][LivePreview] Embed failed: {_lp_err}")
|
||
|
||
await websocket.send_json({"type": "tool_call", **data})
|
||
|
||
# Save completed tool calls to DB so they persist in chat history
|
||
if data.get("status") == "done":
|
||
try:
|
||
import json as _json_tc
|
||
async with async_session() as _tc_db:
|
||
tc_msg = ChatMessage(
|
||
agent_id=agent_id,
|
||
user_id=user_id,
|
||
role="tool_call",
|
||
content=_json_tc.dumps({
|
||
"name": data.get("name", ""),
|
||
"args": data.get("args"),
|
||
"status": "done",
|
||
"result": (data.get("result") or "")[:500],
|
||
"reasoning_content": data.get("reasoning_content"),
|
||
}),
|
||
conversation_id=conv_id,
|
||
)
|
||
_tc_db.add(tc_msg)
|
||
await _tc_db.commit()
|
||
except Exception as _tc_err:
|
||
logger.warning(f"[WS] Failed to save tool_call: {_tc_err}")
|
||
|
||
# Track thinking content for storage
|
||
thinking_content = []
|
||
|
||
async def thinking_to_ws(text: str):
|
||
"""Send thinking chunks to client for collapsible display."""
|
||
thinking_content.append(text)
|
||
await websocket.send_json({"type": "thinking", "content": text})
|
||
|
||
import asyncio as _aio
|
||
|
||
# Run call_llm as a cancellable task
|
||
llm_task = _aio.create_task(call_llm(
|
||
llm_model,
|
||
conversation[-ctx_size:],
|
||
agent_name,
|
||
role_description,
|
||
agent_id=agent_id,
|
||
user_id=user_id,
|
||
session_id=conv_id,
|
||
on_chunk=stream_to_ws,
|
||
on_tool_call=tool_call_to_ws,
|
||
on_thinking=thinking_to_ws,
|
||
supports_vision=getattr(llm_model, 'supports_vision', False),
|
||
))
|
||
|
||
# Listen for abort while LLM is running
|
||
aborted = False
|
||
queued_messages: list[dict] = []
|
||
while not llm_task.done():
|
||
try:
|
||
msg = await _aio.wait_for(
|
||
websocket.receive_json(), timeout=0.5
|
||
)
|
||
if msg.get("type") == "abort":
|
||
logger.info(f"[WS] Abort received, cancelling LLM task")
|
||
llm_task.cancel()
|
||
aborted = True
|
||
break
|
||
else:
|
||
# Queue non-abort messages for later
|
||
queued_messages.append(msg)
|
||
except _aio.TimeoutError:
|
||
continue
|
||
except WebSocketDisconnect:
|
||
llm_task.cancel()
|
||
raise
|
||
|
||
if aborted:
|
||
# Wait for task to finish cancelling
|
||
try:
|
||
await llm_task
|
||
except (_aio.CancelledError, Exception):
|
||
pass
|
||
partial_text = "".join(partial_chunks).strip()
|
||
if partial_text:
|
||
assistant_response = partial_text + "\n\n*[Generation stopped]*"
|
||
else:
|
||
assistant_response = "*[Generation stopped]*"
|
||
logger.info(f"[WS] LLM aborted, partial: {assistant_response[:80]}")
|
||
else:
|
||
assistant_response = await llm_task
|
||
logger.info(f"[WS] LLM response: {assistant_response[:80]}")
|
||
|
||
# Update last_active_at
|
||
from datetime import datetime, timezone as tz
|
||
async with async_session() as _db:
|
||
from app.models.agent import Agent as AgentModel
|
||
_ar = await _db.execute(select(AgentModel).where(AgentModel.id == agent_id))
|
||
_agent = _ar.scalar_one_or_none()
|
||
if _agent:
|
||
_agent.last_active_at = datetime.now(tz.utc)
|
||
await _db.commit()
|
||
|
||
# Increment quota usage
|
||
try:
|
||
await increment_conversation_usage(user_id)
|
||
await increment_agent_llm_usage(agent_id)
|
||
except Exception:
|
||
pass
|
||
|
||
# Log activity
|
||
from app.services.activity_logger import log_activity
|
||
await log_activity(agent_id, "chat_reply", f"Replied to web chat: {assistant_response[:80]}", detail={"channel": "web", "user_text": content[:200], "reply": assistant_response[:500]})
|
||
except WebSocketDisconnect:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"[WS] LLM error: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
# Runtime fallback: primary model failed -> retry with fallback model
|
||
if fallback_llm_model:
|
||
logger.info(f"[WS] Primary model failed, retrying with fallback: {fallback_llm_model.model}")
|
||
try:
|
||
await websocket.send_json({"type": "info", "content": f"Primary model error, switching to fallback model ({fallback_llm_model.model})..."})
|
||
assistant_response = await call_llm(
|
||
fallback_llm_model,
|
||
conversation[-ctx_size:],
|
||
agent_name,
|
||
role_description,
|
||
agent_id=agent_id,
|
||
user_id=user_id,
|
||
session_id=conv_id,
|
||
on_chunk=stream_to_ws,
|
||
on_tool_call=tool_call_to_ws,
|
||
on_thinking=thinking_to_ws,
|
||
supports_vision=getattr(fallback_llm_model, 'supports_vision', False),
|
||
)
|
||
logger.info(f"[WS] Fallback LLM response: {assistant_response[:80]}")
|
||
except Exception as e2:
|
||
logger.error(f"[WS] Fallback LLM also failed: {e2}")
|
||
traceback.print_exc()
|
||
assistant_response = f"[LLM call error] Primary: {str(e)[:100]} | Fallback: {str(e2)[:100]}"
|
||
else:
|
||
assistant_response = f"[LLM call error] {str(e)[:200]}"
|
||
else:
|
||
assistant_response = f"⚠️ {agent_name} has no LLM model configured. Please select a model in the agent's Settings tab."
|
||
|
||
# If task creation detected, create a real Task record
|
||
if task_match:
|
||
task_title = task_match.group(1).strip()
|
||
if task_title:
|
||
try:
|
||
from app.models.task import Task
|
||
from app.services.task_executor import execute_task
|
||
import asyncio as _asyncio
|
||
async with async_session() as db:
|
||
task = Task(
|
||
agent_id=agent_id,
|
||
title=task_title,
|
||
created_by=user_id,
|
||
status="pending",
|
||
priority="medium",
|
||
)
|
||
db.add(task)
|
||
await db.commit()
|
||
await db.refresh(task)
|
||
task_id = task.id
|
||
_asyncio.create_task(execute_task(task_id, agent_id))
|
||
assistant_response += f"\n\n📋 Task synced to task board: [{task_title}]"
|
||
logger.info(f"[WS] Created task: {task_title}")
|
||
except Exception as e:
|
||
logger.error(f"[WS] Failed to create task: {e}")
|
||
|
||
# Add assistant response to conversation
|
||
conversation.append({"role": "assistant", "content": assistant_response})
|
||
|
||
# Save assistant message
|
||
async with async_session() as db:
|
||
asst_msg = ChatMessage(
|
||
agent_id=agent_id,
|
||
user_id=user_id,
|
||
role="assistant",
|
||
content=assistant_response,
|
||
thinking=''.join(thinking_content) if thinking_content else None,
|
||
conversation_id=conv_id,
|
||
)
|
||
db.add(asst_msg)
|
||
await db.commit()
|
||
logger.info("[WS] Assistant message saved")
|
||
|
||
# Send done signal with final content (for non-streaming clients)
|
||
await websocket.send_json({
|
||
"type": "done",
|
||
"role": "assistant",
|
||
"content": assistant_response,
|
||
})
|
||
logger.info("[WS] Response done sent to client")
|
||
|
||
except WebSocketDisconnect:
|
||
logger.info(f"[WS] Client disconnected: {agent_name}")
|
||
manager.disconnect(agent_id_str, websocket)
|
||
except Exception as e:
|
||
logger.error(f"[WS] Error in message loop: {type(e).__name__}: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
manager.disconnect(agent_id_str, websocket)
|
||
try:
|
||
await websocket.close(code=1011)
|
||
except Exception:
|
||
pass
|