Clawith/backend/app/api/chat_sessions.py

428 lines
16 KiB
Python

"""Chat session management API endpoints."""
import uuid
from datetime import datetime, timezone as tz
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.permissions import check_agent_access
from app.core.security import get_current_user
from app.database import get_db
from app.models.audit import ChatMessage
from app.models.chat_session import ChatSession
from app.models.agent import Agent
from app.models.user import User
router = APIRouter(prefix="/api/agents", tags=["chat-sessions"])
def _is_admin_or_creator(user: User, agent: Agent) -> bool:
    """Return True when *user* is a platform/org admin or the creator of *agent*.

    Ids are compared as strings so UUID and str representations match.
    """
    if user.role in ("platform_admin", "org_admin"):
        return True
    return str(user.id) == str(agent.creator_id)
def _can_view_all_agent_chat_sessions(user: User) -> bool:
    """Tell whether *user* may list/view/delete chat sessions owned by other users.

    Only the platform_admin, org_admin and agent_admin roles qualify.
    """
    privileged_roles = ("platform_admin", "org_admin", "agent_admin")
    return user.role in privileged_roles
class SessionOut(BaseModel):
    """Chat session summary returned by the session endpoints in this module.

    The endpoints here fill every id with ``str(...)`` and every timestamp with
    ``datetime.isoformat()``, so all of them are plain strings.
    """
    id: str
    agent_id: str
    user_id: str
    username: Optional[str] = None  # display label: user display_name/username, "Agent A - B" pair, or group name
    source_channel: str = "web"  # e.g. web / feishu / discord / slack / agent / trigger
    title: str
    created_at: str  # ISO-8601 string
    last_message_at: Optional[str] = None  # ISO-8601 string; None when the session has no last_message_at
    message_count: int = 0  # how messages are counted depends on the listing scope
    # Agent-to-agent session fields (only populated for source_channel == "agent")
    peer_agent_id: Optional[str] = None
    peer_agent_name: Optional[str] = None
    participant_type: str = "user"  # 'user' | 'agent' | 'group'
    # Group chat session fields
    is_group: bool = False
    group_name: Optional[str] = None

    class Config:
        # pydantic option allowing construction from attribute-bearing objects (e.g. ORM rows);
        # the endpoints in this module build the model from keyword arguments instead.
        from_attributes = True
class CreateSessionIn(BaseModel):
    """Request body for creating a chat session."""
    title: Optional[str] = None  # when missing/empty, create_session uses a "Session MM-DD HH:MM" default
class PatchSessionIn(BaseModel):
    """Request body for renaming a chat session."""
    title: str  # new title; stored as given (no trimming or length check here)
# Upper bound on ids placed in one SQL ``IN (...)`` clause, so agents with very
# many sessions do not exceed database-driver bind-parameter limits.
_IN_CLAUSE_BATCH_SIZE = 1000


def _in_batches(items: list, size: int = _IN_CLAUSE_BATCH_SIZE):
    """Yield consecutive slices of *items* holding at most *size* elements."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


async def _session_message_counts(
    db: AsyncSession,
    conversation_ids: list[str],
    agent_id: Optional[uuid.UUID] = None,
) -> dict[str, tuple[int, int]]:
    """Count messages per conversation using grouped queries instead of one query per session.

    Returns ``{conversation_id: (total_messages, user_role_messages)}``.
    Conversations without any matching message are absent from the mapping.
    When *agent_id* is given, only messages whose ``ChatMessage.agent_id``
    equals it are counted.
    """
    from sqlalchemy import case

    user_flag = case((ChatMessage.role == "user", 1), else_=0)
    counts: dict[str, tuple[int, int]] = {}
    for batch in _in_batches(conversation_ids):
        stmt = (
            select(
                ChatMessage.conversation_id,
                func.count(ChatMessage.id),
                func.sum(user_flag),
            )
            .where(ChatMessage.conversation_id.in_(batch))
            .group_by(ChatMessage.conversation_id)
        )
        if agent_id is not None:
            stmt = stmt.where(ChatMessage.agent_id == agent_id)
        for conversation_id, total, user_total in (await db.execute(stmt)).all():
            counts[str(conversation_id)] = (int(total or 0), int(user_total or 0))
    return counts


async def _agent_names_by_id(db: AsyncSession, agent_ids: set) -> dict[str, Optional[str]]:
    """Map ``str(agent.id)`` to ``agent.name`` for the given ids; unknown ids are omitted."""
    ids = [aid for aid in agent_ids if aid is not None]
    names: dict[str, Optional[str]] = {}
    for batch in _in_batches(ids):
        rows = await db.execute(select(Agent.id, Agent.name).where(Agent.id.in_(batch)))
        names.update({str(aid): name for aid, name in rows.all()})
    return names


async def _user_display_names_by_id(db: AsyncSession, user_ids: set) -> dict[str, Optional[str]]:
    """Map ``str(user.id)`` to ``coalesce(display_name, Identity.username)``.

    Users without a joined Identity row are omitted (inner join).
    """
    # User.username is an association_proxy, so join through Identity explicitly.
    from app.models.user import Identity

    ids = [uid for uid in user_ids if uid is not None]
    names: dict[str, Optional[str]] = {}
    for batch in _in_batches(ids):
        rows = await db.execute(
            select(User.id, func.coalesce(User.display_name, Identity.username))
            .join(Identity, User.identity_id == Identity.id)
            .where(User.id.in_(batch))
        )
        names.update({str(uid): name for uid, name in rows.all()})
    return names


async def _list_all_sessions(db: AsyncSession, agent_id: uuid.UUID) -> list[SessionOut]:
    """Build the scope=all listing: every non-empty session of the agent, including
    agent-to-agent sessions in which it is the peer."""
    result = await db.execute(
        select(ChatSession)
        .where(
            (ChatSession.agent_id == agent_id)
            | ((ChatSession.peer_agent_id == agent_id) & (ChatSession.source_channel == "agent"))
        )
        .order_by(ChatSession.last_message_at.desc().nulls_last(), ChatSession.created_at.desc())
    )
    sessions = result.scalars().all()
    # Here every message of the conversation counts, regardless of ChatMessage.agent_id.
    counts = await _session_message_counts(db, [str(s.id) for s in sessions])

    # Keep non-empty sessions and collect the ids whose names are needed.
    visible: list = []
    agent_ids: set = set()
    user_ids: set = set()
    for s in sessions:
        total = counts.get(str(s.id), (0, 0))[0]
        if total == 0:
            continue  # hide empty sessions
        visible.append((s, total))
        if s.source_channel == "agent" and s.peer_agent_id:
            agent_ids.update((s.agent_id, s.peer_agent_id))
        elif not s.is_group:
            user_ids.add(s.user_id)

    agent_names = await _agent_names_by_id(db, agent_ids)
    user_names = await _user_display_names_by_id(db, user_ids)

    out = []
    for session, count in visible:
        peer_agent_id = None
        peer_agent_name = None
        participant_type = "user"
        if session.source_channel == "agent" and session.peer_agent_id:
            # Agent-to-agent session: label with both agent names.
            participant_type = "agent"
            peer_agent_id = str(session.peer_agent_id)
            a1_name = agent_names.get(str(session.agent_id)) or "Agent"
            a2_name = agent_names.get(str(session.peer_agent_id)) or "Agent"
            peer_agent_name = a2_name
            display = f"Agent {a1_name} - {a2_name}"
        elif session.is_group:
            # Group chat session: show the group name instead of a username.
            display = session.group_name or session.title or "Group Chat"
        else:
            # Human session: show the owner's display name / username.
            display = user_names.get(str(session.user_id)) or "Unknown"
        out.append(SessionOut(
            id=str(session.id),
            agent_id=str(session.agent_id),
            user_id=str(session.user_id),
            username=display,
            source_channel=session.source_channel,
            title=session.title,
            created_at=session.created_at.isoformat(),
            last_message_at=session.last_message_at.isoformat() if session.last_message_at else None,
            message_count=count,
            peer_agent_id=peer_agent_id,
            peer_agent_name=peer_agent_name,
            participant_type="group" if session.is_group else participant_type,
            is_group=session.is_group,
            group_name=session.group_name,
        ))
    return out


async def _list_own_sessions(
    db: AsyncSession, agent_id: uuid.UUID, current_user: User
) -> list[SessionOut]:
    """Build the scope=mine listing: the caller's own direct sessions with this agent
    that contain at least one user-role message."""
    result = await db.execute(
        select(ChatSession)
        .where(
            ChatSession.agent_id == agent_id,
            ChatSession.user_id == current_user.id,
            ChatSession.is_group == False,  # noqa: E712 - SQL expression; group sessions are not "mine"
            ChatSession.source_channel.notin_(["agent", "trigger"]),  # Exclude agent-to-agent and reflection sessions
        )
        .order_by(ChatSession.last_message_at.desc().nulls_last(), ChatSession.created_at.desc())
    )
    sessions = result.scalars().all()
    # Only messages attributed to this agent are counted in this scope.
    counts = await _session_message_counts(db, [str(s.id) for s in sessions], agent_id=agent_id)

    out = []
    for session in sessions:
        total, user_msg_count = counts.get(str(session.id), (0, 0))
        if user_msg_count == 0:
            continue  # hide empty or orphan (assistant-only) sessions
        out.append(SessionOut(
            id=str(session.id),
            agent_id=str(session.agent_id),
            user_id=str(session.user_id),
            source_channel=session.source_channel,
            title=session.title,
            created_at=session.created_at.isoformat(),
            last_message_at=session.last_message_at.isoformat() if session.last_message_at else None,
            message_count=total,
        ))
    return out


@router.get("/{agent_id}/sessions")
async def list_sessions(
    agent_id: uuid.UUID,
    scope: str = Query("mine", description="'mine' or 'all'"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List chat sessions for an agent. scope=all for org/platform admins and agent_admin.

    - ``scope="all"``: every non-empty session of the agent (including
      agent-to-agent sessions where it is the peer), with display names resolved.
      Requires a role accepted by ``_can_view_all_agent_chat_sessions``.
    - any other value (default ``"mine"``): the caller's own non-group,
      non-agent, non-trigger sessions that contain at least one user message.

    Message counts and names are fetched with batched queries rather than one
    query per session.

    Raises:
        HTTPException: 404 if the agent does not exist, 403 if scope=all is
            requested without a privileged role.
    """
    # Verify agent exists
    agent_result = await db.execute(select(Agent).where(Agent.id == agent_id))
    agent = agent_result.scalar_one_or_none()
    if not agent:
        raise HTTPException(status_code=404, detail="Agent not found")
    await check_agent_access(db, current_user, agent_id)

    if scope == "all":
        if not _can_view_all_agent_chat_sessions(current_user):
            raise HTTPException(status_code=403, detail="Not authorized to view all sessions")
        return await _list_all_sessions(db, agent_id)
    # Any value other than "all" is treated as "mine".
    return await _list_own_sessions(db, agent_id, current_user)
@router.post("/{agent_id}/sessions", status_code=201)
async def create_session(
    agent_id: uuid.UUID,
    body: CreateSessionIn = CreateSessionIn(),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start a new web chat session owned by the calling user.

    After the agent access check, persists a ``ChatSession`` with
    ``source_channel="web"``. A missing or empty title is replaced by a
    UTC timestamp label such as ``Session 03-14 09:30``. Responds with the
    freshly created session as a ``SessionOut`` (no messages yet).
    """
    await check_agent_access(db, current_user, agent_id)

    created = datetime.now(tz.utc)
    fallback_title = f"Session {created.strftime('%m-%d %H:%M')}"
    new_session = ChatSession(
        id=uuid.uuid4(),
        agent_id=agent_id,
        user_id=current_user.id,
        title=body.title or fallback_title,
        source_channel="web",
        created_at=created,
    )

    db.add(new_session)
    await db.commit()
    await db.refresh(new_session)

    return SessionOut(
        id=str(new_session.id),
        agent_id=str(new_session.agent_id),
        user_id=str(new_session.user_id),
        source_channel=new_session.source_channel,
        title=new_session.title,
        created_at=new_session.created_at.isoformat(),
        last_message_at=None,
        message_count=0,
        participant_type="user",
        is_group=False,
    )
@router.patch("/{agent_id}/sessions/{session_id}")
async def rename_session(
    agent_id: uuid.UUID,
    session_id: uuid.UUID,
    body: PatchSessionIn,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Rename a chat session.

    After the agent access check, the session's owner may rename it; other
    users need a role accepted by ``_can_view_all_agent_chat_sessions``
    (platform_admin, org_admin or agent_admin).

    Returns:
        ``{"id": <session id>, "title": <new title>}``.

    Raises:
        HTTPException: 404 if the session does not belong to this agent,
            403 if the caller is neither the owner nor privileged.
    """
    await check_agent_access(db, current_user, agent_id)
    result = await db.execute(
        select(ChatSession).where(ChatSession.id == session_id, ChatSession.agent_id == agent_id)
    )
    session = result.scalar_one_or_none()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if str(session.user_id) != str(current_user.id) and not _can_view_all_agent_chat_sessions(current_user):
        raise HTTPException(status_code=403, detail="Not authorized")

    session.title = body.title
    # Build the response before committing: with expire_on_commit=True, reading
    # attributes after commit() would need an implicit refresh, which an
    # AsyncSession cannot perform.
    response = {"id": str(session.id), "title": session.title}
    await db.commit()
    return response
@router.delete("/{agent_id}/sessions/{session_id}", status_code=204)
async def delete_session(
    agent_id: uuid.UUID,
    session_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Remove a chat session together with its stored messages.

    After the agent access check, the session's owner may delete it; other
    users need a role accepted by ``_can_view_all_agent_chat_sessions``
    (platform_admin, org_admin or agent_admin).

    Raises:
        HTTPException: 404 if no such session belongs to this agent,
            403 if the caller is neither the owner nor privileged.
    """
    from sqlalchemy import delete as sql_delete

    await check_agent_access(db, current_user, agent_id)

    lookup = await db.execute(
        select(ChatSession).where(ChatSession.id == session_id, ChatSession.agent_id == agent_id)
    )
    target = lookup.scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Session not found")

    is_owner = str(target.user_id) == str(current_user.id)
    if not is_owner and not _can_view_all_agent_chat_sessions(current_user):
        raise HTTPException(status_code=403, detail="Not authorized")

    # Messages are keyed by conversation_id == str(session id); remove them before the session row.
    conversation_key = str(session_id)
    await db.execute(sql_delete(ChatMessage).where(ChatMessage.conversation_id == conversation_key))
    await db.delete(target)
    await db.commit()
    return None
@router.get("/{agent_id}/sessions/{session_id}/messages")
async def get_session_messages(
    agent_id: uuid.UUID,
    session_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get chat messages for a specific session.

    The session may belong to the agent directly or name it as
    ``peer_agent_id``. Access requires the agent access check plus being the
    session owner or holding a role accepted by
    ``_can_view_all_agent_chat_sessions``.

    Returns:
        Up to the 500 most recent messages, in chronological order, as dicts
        with ``role``/``content``/``created_at`` and, where applicable,
        ``thinking``, ``sender_name``, ``participant_id`` and the ``tool*``
        keys. For agent sessions, assistant messages with inline
        ```` ```tool_code ```` blocks are split into several entries by
        ``_split_inline_tools``.

    Raises:
        HTTPException: 404 if the session is not found for this agent,
            403 if the caller may not view it.
    """
    import json

    await check_agent_access(db, current_user, agent_id)
    # Allow looking up sessions where agent_id OR peer_agent_id matches
    result = await db.execute(
        select(ChatSession).where(
            ChatSession.id == session_id,
            (ChatSession.agent_id == agent_id) | (ChatSession.peer_agent_id == agent_id),
        )
    )
    session = result.scalar_one_or_none()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    # Permission: the session owner, or a role allowed to view other users' sessions.
    if str(session.user_id) != str(current_user.id) and not _can_view_all_agent_chat_sessions(current_user):
        raise HTTPException(status_code=403, detail="Not authorized to view this session")

    # Messages are keyed by conversation_id == str(session id). Order DESC before
    # LIMIT so long sessions keep their newest 500 messages, then restore
    # chronological order for the response.
    msgs_result = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.conversation_id == str(session_id))
        .order_by(ChatMessage.created_at.desc())
        .limit(500)
    )
    messages = list(reversed(msgs_result.scalars().all()))

    # Resolve sender names for agent sessions with a single query.
    sender_cache: dict[str, str] = {}
    if session.source_channel == "agent":
        from app.models.participant import Participant

        participant_ids = {m.participant_id for m in messages if m.participant_id}
        if participant_ids:
            p_rows = await db.execute(
                select(Participant.id, Participant.display_name).where(
                    Participant.id.in_(list(participant_ids))
                )
            )
            found = {str(pid): name for pid, name in p_rows.all()}
            sender_cache = {str(pid): found.get(str(pid)) or "Unknown" for pid in participant_ids}

    out: list[dict] = []
    for m in messages:
        created_at = m.created_at.isoformat() if m.created_at else None
        sender_name = sender_cache.get(str(m.participant_id)) if m.participant_id else None

        if m.role == "tool_call":
            entry: dict = {"role": m.role, "content": m.content, "created_at": created_at}
            try:
                data = json.loads(m.content)
            except (TypeError, ValueError, RecursionError):
                data = None  # not decodable JSON: keep the raw content
            # Only a JSON object carries tool fields; any other value keeps its raw content.
            if isinstance(data, dict):
                entry["content"] = ""
                entry["toolName"] = data.get("name", "")
                entry["toolArgs"] = data.get("args")
                entry["toolStatus"] = data.get("status", "done")
                entry["toolResult"] = data.get("result", "")
            if sender_name:
                entry["sender_name"] = sender_name
            out.append(entry)
            continue

        # For agent sessions, parse inline tool_code blocks from assistant messages
        if session.source_channel == "agent" and m.role == "assistant" and "```tool_code" in (m.content or ""):
            for part in _split_inline_tools(m.content):
                if sender_name:
                    part["sender_name"] = sender_name
                if m.participant_id:
                    part["participant_id"] = str(m.participant_id)
                out.append(part)
            continue

        entry = {"role": m.role, "content": m.content, "created_at": created_at}
        if getattr(m, "thinking", None):
            entry["thinking"] = m.thinking
        if sender_name:
            entry["sender_name"] = sender_name
        if m.participant_id:
            entry["participant_id"] = str(m.participant_id)
        out.append(entry)
    return out
import re
def _split_inline_tools(content: str) -> list[dict]:
"""Parse assistant content containing inline ```tool_code blocks.
Splits into alternating text segments and tool_call entries.
Format: ```tool_code\ntool_name\n``` ```json\n{args}\n```
"""
# Pattern: ```tool_code\n<name>\n``` optionally followed by ```json\n<args>\n```
pattern = re.compile(
r'```tool_code\s*\n\s*(\w+)\s*\n```' # tool name
r'(?:\s*```json\s*\n(.*?)\n```)?', # optional JSON args
re.DOTALL
)
parts: list[dict] = []
last_end = 0
for match in pattern.finditer(content):
# Text before this tool call
text_before = content[last_end:match.start()].strip()
if text_before:
parts.append({"role": "assistant", "content": text_before})
tool_name = match.group(1)
args_str = match.group(2)
tool_args = None
if args_str:
try:
import json
tool_args = json.loads(args_str.strip())
except Exception:
tool_args = {"raw": args_str.strip()}
parts.append({
"role": "tool_call",
"content": "",
"toolName": tool_name,
"toolArgs": tool_args,
"toolStatus": "done",
"toolResult": "",
})
last_end = match.end()
# Trailing text after last tool
trailing = content[last_end:].strip()
if trailing:
parts.append({"role": "assistant", "content": trailing})
# If no matches found, return the whole content as-is
if not parts:
parts.append({"role": "assistant", "content": content})
return parts