428 lines
16 KiB
Python
428 lines
16 KiB
Python
"""Chat session management API endpoints."""
|
|
|
|
import uuid
|
|
from datetime import datetime, timezone as tz
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.permissions import check_agent_access
|
|
from app.core.security import get_current_user
|
|
from app.database import get_db
|
|
from app.models.audit import ChatMessage
|
|
from app.models.chat_session import ChatSession
|
|
from app.models.agent import Agent
|
|
from app.models.user import User
|
|
|
|
# Every route declared in this module is registered on this router, so all
# paths below are relative to the "/api/agents" prefix and carry the
# "chat-sessions" tag.
router = APIRouter(prefix="/api/agents", tags=["chat-sessions"])
|
|
|
|
|
|
def _is_admin_or_creator(user: User, agent: Agent) -> bool:
    """Return True when *user* has an org/platform admin role or created *agent*."""
    # Role check first: admins qualify regardless of who created the agent.
    if user.role in ("platform_admin", "org_admin"):
        return True
    # Compare as strings so differing id representations still match.
    creator = str(agent.creator_id)
    return creator == str(user.id)
|
|
|
|
|
|
def _can_view_all_agent_chat_sessions(user: User) -> bool:
    """Tell whether *user* holds a role allowed to list/view/delete other users' chat sessions."""
    # The three admin roles that are granted cross-user access to sessions.
    privileged_roles = ("platform_admin", "org_admin", "agent_admin")
    return user.role in privileged_roles
|
|
|
|
|
|
# Response model describing one chat session as returned by the endpoints in
# this module. All identifiers and timestamps are carried as strings; the
# call sites below build it from keyword arguments using str(...) and
# .isoformat().
class SessionOut(BaseModel):
    id: str  # stringified session id
    agent_id: str  # stringified id of the agent that owns the session
    user_id: str  # stringified id of the session's user
    # Display label. Only list_sessions(scope="all") fills it in:
    #   - human session: coalesce(User.display_name, Identity.username), else "Unknown"
    #   - agent-to-agent session: "Agent <name> - <peer name>"
    #   - group session: group_name, falling back to title, then "Group Chat"
    username: Optional[str] = None
    # web / feishu / discord / slack / agent ("trigger" is also referenced by
    # the scope="mine" filter in list_sessions)
    source_channel: str = "web"
    title: str
    created_at: str  # session.created_at.isoformat()
    last_message_at: Optional[str] = None  # isoformat() string, or None when the session has no timestamp
    message_count: int = 0  # number of ChatMessage rows counted for the session
    # Agent-to-agent session fields
    peer_agent_id: Optional[str] = None
    peer_agent_name: Optional[str] = None
    # 'user' | 'agent' | 'group' — list_sessions emits "group" whenever
    # session.is_group is set, overriding the other two values.
    participant_type: str = "user"
    # Group chat session fields
    is_group: bool = False
    group_name: Optional[str] = None

    class Config:
        # NOTE(review): every construction site visible in this file passes
        # explicit keyword arguments, so attribute-based loading looks unused
        # here — confirm before removing.
        from_attributes = True
|
|
|
|
|
|
# Request body accepted by create_session; the only field is optional.
class CreateSessionIn(BaseModel):
    # Custom title. When omitted or empty, create_session substitutes a
    # "Session <MM-DD HH:MM>" title built from the current UTC time.
    title: Optional[str] = None
|
|
|
|
|
|
# Request body accepted by rename_session.
class PatchSessionIn(BaseModel):
    # New title; required, and assigned to the session unchanged.
    title: str
|
|
|
|
|
|
@router.get("/{agent_id}/sessions")
async def list_sessions(
    agent_id: uuid.UUID,
    scope: str = Query("mine", description="'mine' or 'all'"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List chat sessions for an agent. scope=all for org/platform admins and agent_admin."""
    # Verify agent exists before running the access check, so an unknown id
    # is reported as 404.
    agent_result = await db.execute(select(Agent).where(Agent.id == agent_id))
    agent = agent_result.scalar_one_or_none()
    if not agent:
        raise HTTPException(status_code=404, detail="Agent not found")
    await check_agent_access(db, current_user, agent_id)

    if scope == "all":
        if not _can_view_all_agent_chat_sessions(current_user):
            raise HTTPException(status_code=403, detail="Not authorized to view all sessions")
        return await _list_all_agent_sessions(db, agent_id)

    # Any value other than "all" is handled as scope == "mine".
    return await _list_own_agent_sessions(db, agent_id, current_user)


async def _count_chat_messages(db: AsyncSession, *conditions) -> int:
    """Return the number of ChatMessage rows matching every condition (0 when none)."""
    result = await db.execute(select(func.count(ChatMessage.id)).where(*conditions))
    return result.scalar() or 0


async def _list_all_agent_sessions(db: AsyncSession, agent_id: uuid.UUID) -> list[SessionOut]:
    """Build the scope=all listing: every non-empty session that involves the agent."""
    # User.username is an association_proxy, so the display name has to be
    # resolved by joining through Identity. Imported once per call instead of
    # once per human session.
    from app.models.user import Identity

    # Fetch all sessions (including agent-to-agent where this agent is peer)
    result = await db.execute(
        select(ChatSession)
        .where(
            (ChatSession.agent_id == agent_id)
            | ((ChatSession.peer_agent_id == agent_id) & (ChatSession.source_channel == "agent"))
        )
        .order_by(ChatSession.last_message_at.desc().nulls_last(), ChatSession.created_at.desc())
    )
    sessions = result.scalars().all()

    # Per-request memo tables: the same agent or user usually appears in many
    # sessions, so each name is looked up at most once.
    agent_names: dict = {}
    user_labels: dict = {}

    async def agent_name(target_id) -> str:
        if target_id not in agent_names:
            row = await db.execute(select(Agent.name).where(Agent.id == target_id))
            agent_names[target_id] = row.scalar_one_or_none() or "Agent"
        return agent_names[target_id]

    async def user_label(target_id) -> str:
        if target_id not in user_labels:
            row = await db.execute(
                select(func.coalesce(User.display_name, Identity.username))
                .join(Identity, User.identity_id == Identity.id)
                .where(User.id == target_id)
            )
            user_labels[target_id] = row.scalar_one_or_none() or "Unknown"
        return user_labels[target_id]

    out = []
    for session in sessions:
        count = await _count_chat_messages(db, ChatMessage.conversation_id == str(session.id))
        if count == 0:
            continue  # hide empty sessions

        # Determine display name based on session type
        peer_agent_id = None
        peer_agent_name = None
        participant_type = "user"

        if session.source_channel == "agent" and session.peer_agent_id:
            # Agent-to-agent session: label it with both agent names.
            participant_type = "agent"
            peer_agent_id = str(session.peer_agent_id)
            own_name = await agent_name(session.agent_id)
            peer_agent_name = await agent_name(session.peer_agent_id)
            display = f"Agent {own_name} - {peer_agent_name}"
        elif session.is_group:
            # Group chat session — display group name instead of username
            display = session.group_name or session.title or "Group Chat"
        else:
            # Human session — resolve username
            display = await user_label(session.user_id)

        out.append(SessionOut(
            id=str(session.id),
            agent_id=str(session.agent_id),
            user_id=str(session.user_id),
            username=display,
            source_channel=session.source_channel,
            title=session.title,
            created_at=session.created_at.isoformat(),
            last_message_at=session.last_message_at.isoformat() if session.last_message_at else None,
            message_count=count,
            peer_agent_id=peer_agent_id,
            peer_agent_name=peer_agent_name,
            # Group sessions always report "group", whatever was derived above.
            participant_type="group" if session.is_group else participant_type,
            is_group=session.is_group,
            group_name=session.group_name,
        ))
    return out


async def _list_own_agent_sessions(
    db: AsyncSession, agent_id: uuid.UUID, current_user: User
) -> list[SessionOut]:
    """Build the scope=mine listing: the caller's own direct sessions with the agent."""
    result = await db.execute(
        select(ChatSession)
        .where(
            ChatSession.agent_id == agent_id,
            ChatSession.user_id == current_user.id,
            ChatSession.is_group == False,  # noqa: E712 — SQL expression; group sessions are not "mine"
            ChatSession.source_channel.notin_(["agent", "trigger"]),  # Exclude agent-to-agent and reflection sessions
        )
        .order_by(ChatSession.last_message_at.desc().nulls_last(), ChatSession.created_at.desc())
    )
    sessions = result.scalars().all()

    out = []
    for session in sessions:
        conversation = ChatMessage.conversation_id == str(session.id)
        for_agent = ChatMessage.agent_id == agent_id

        # Skip sessions with no user messages (orphan assistant-only records).
        user_msg_count = await _count_chat_messages(
            db, conversation, for_agent, ChatMessage.role == "user"
        )
        if user_msg_count == 0:
            continue  # hide empty or orphan sessions

        # Total message count for display
        count = await _count_chat_messages(db, conversation, for_agent)
        out.append(SessionOut(
            id=str(session.id),
            agent_id=str(session.agent_id),
            user_id=str(session.user_id),
            source_channel=session.source_channel,
            title=session.title,
            created_at=session.created_at.isoformat(),
            last_message_at=session.last_message_at.isoformat() if session.last_message_at else None,
            message_count=count,
        ))
    return out
|
|
|
|
|
|
@router.post("/{agent_id}/sessions", status_code=201)
async def create_session(
    agent_id: uuid.UUID,
    body: CreateSessionIn = CreateSessionIn(),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Open a fresh web chat session with the agent on behalf of the calling user."""
    await check_agent_access(db, current_user, agent_id)

    created = datetime.now(tz.utc)
    # A missing or empty title falls back to a timestamp-based one.
    if body.title:
        chosen_title = body.title
    else:
        chosen_title = f"Session {created.strftime('%m-%d %H:%M')}"

    session = ChatSession(
        id=uuid.uuid4(),
        agent_id=agent_id,
        user_id=current_user.id,
        title=chosen_title,
        source_channel="web",
        created_at=created,
    )
    db.add(session)
    await db.commit()
    # Reload the row after the commit before reading its attributes back.
    await db.refresh(session)

    # A brand-new session has no messages and no last-message timestamp yet.
    payload = {
        "id": str(session.id),
        "agent_id": str(session.agent_id),
        "user_id": str(session.user_id),
        "source_channel": session.source_channel,
        "title": session.title,
        "created_at": session.created_at.isoformat(),
        "last_message_at": None,
        "message_count": 0,
        "participant_type": "user",
        "is_group": False,
    }
    return SessionOut(**payload)
|
|
|
|
|
|
@router.patch("/{agent_id}/sessions/{session_id}")
async def rename_session(
    agent_id: uuid.UUID,
    session_id: uuid.UUID,
    body: PatchSessionIn,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Rename a session. Allowed for its owner, or for an admin role that may manage others' sessions."""
    await check_agent_access(db, current_user, agent_id)

    # The session must belong to the agent named in the path.
    lookup = select(ChatSession).where(
        ChatSession.id == session_id, ChatSession.agent_id == agent_id
    )
    session = (await db.execute(lookup)).scalar_one_or_none()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Owners may always rename; anyone else needs a privileged role.
    is_owner = str(session.user_id) == str(current_user.id)
    if not (is_owner or _can_view_all_agent_chat_sessions(current_user)):
        raise HTTPException(status_code=403, detail="Not authorized")

    session.title = body.title
    await db.commit()
    return {"id": str(session.id), "title": session.title}
|
|
|
|
|
|
@router.delete("/{agent_id}/sessions/{session_id}", status_code=204)
async def delete_session(
    agent_id: uuid.UUID,
    session_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a chat session and its messages. Allowed for its owner, or for an admin role that may manage others' sessions."""
    from sqlalchemy import delete as sql_delete

    await check_agent_access(db, current_user, agent_id)

    # The session must belong to the agent named in the path.
    lookup = select(ChatSession).where(
        ChatSession.id == session_id, ChatSession.agent_id == agent_id
    )
    session = (await db.execute(lookup)).scalar_one_or_none()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Owners may always delete; anyone else needs a privileged role.
    is_owner = str(session.user_id) == str(current_user.id)
    if not (is_owner or _can_view_all_agent_chat_sessions(current_user)):
        raise HTTPException(status_code=403, detail="Not authorized")

    # Messages reference the session through a stringified conversation_id,
    # so they are removed explicitly before the session row itself.
    purge_messages = sql_delete(ChatMessage).where(
        ChatMessage.conversation_id == str(session_id)
    )
    await db.execute(purge_messages)
    await db.delete(session)
    await db.commit()
    return None
|
|
|
|
|
|
@router.get("/{agent_id}/sessions/{session_id}/messages")
async def get_session_messages(
    agent_id: uuid.UUID,
    session_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get chat messages for a specific session."""
    # Kept function-local as in the original, but imported once per call
    # rather than once per tool_call message.
    import json

    await check_agent_access(db, current_user, agent_id)

    # Allow looking up sessions where agent_id OR peer_agent_id matches
    result = await db.execute(
        select(ChatSession).where(
            ChatSession.id == session_id,
            (ChatSession.agent_id == agent_id) | (ChatSession.peer_agent_id == agent_id),
        )
    )
    session = result.scalar_one_or_none()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Permission: the session owner, or a user whose role may view other
    # users' sessions (platform_admin / org_admin / agent_admin).
    if str(session.user_id) != str(current_user.id) and not _can_view_all_agent_chat_sessions(current_user):
        raise HTTPException(status_code=403, detail="Not authorized to view this session")

    # Query messages by conversation_id only (agent-to-agent uses session_agent_id).
    # NOTE(review): ascending order plus limit(500) returns the OLDEST 500
    # messages, so longer sessions lose their newest ones — confirm intended.
    msgs_result = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.conversation_id == str(session_id))
        .order_by(ChatMessage.created_at.asc())
        .limit(500)
    )
    messages = msgs_result.scalars().all()

    # Resolve sender names for agent sessions: one lookup per distinct participant.
    sender_cache: dict = {}
    if session.source_channel == "agent":
        from app.models.participant import Participant

        for m in messages:
            if m.participant_id and str(m.participant_id) not in sender_cache:
                p_r = await db.execute(
                    select(Participant.display_name).where(Participant.id == m.participant_id)
                )
                sender_cache[str(m.participant_id)] = p_r.scalar_one_or_none() or "Unknown"

    out = []
    for m in messages:
        sender_name = sender_cache.get(str(m.participant_id)) if m.participant_id else None
        created_at = m.created_at.isoformat() if m.created_at else None

        if m.role == "tool_call":
            entry: dict = {"role": m.role, "content": m.content, "created_at": created_at}
            # Parse and validate BEFORE touching the entry. Previously the
            # content was blanked first, so a payload that was valid JSON but
            # not an object lost its content and gained no tool fields.
            try:
                data = json.loads(m.content)
            except (TypeError, ValueError, RecursionError):
                # Best effort: missing or malformed payloads are returned raw.
                data = None
            if isinstance(data, dict):
                entry["content"] = ""
                entry["toolName"] = data.get("name", "")
                entry["toolArgs"] = data.get("args")
                entry["toolStatus"] = data.get("status", "done")
                entry["toolResult"] = data.get("result", "")
            if sender_name:
                entry["sender_name"] = sender_name
            out.append(entry)
            continue

        # For agent sessions, parse inline tool_code blocks from assistant messages
        if session.source_channel == "agent" and m.role == "assistant" and "```tool_code" in (m.content or ""):
            for part in _split_inline_tools(m.content):
                # Every piece inherits the timestamp of the message it was
                # split from, so split messages match the other entries.
                part["created_at"] = created_at
                if sender_name:
                    part["sender_name"] = sender_name
                if m.participant_id:
                    part["participant_id"] = str(m.participant_id)
                out.append(part)
        else:
            entry = {"role": m.role, "content": m.content, "created_at": created_at}
            thinking = getattr(m, "thinking", None)
            if thinking:
                entry["thinking"] = thinking
            if sender_name:
                entry["sender_name"] = sender_name
            if m.participant_id:
                entry["participant_id"] = str(m.participant_id)
            out.append(entry)

    return out
|
|
|
|
|
|
import re
|
|
|
|
def _split_inline_tools(content: str) -> list[dict]:
|
|
"""Parse assistant content containing inline ```tool_code blocks.
|
|
|
|
Splits into alternating text segments and tool_call entries.
|
|
Format: ```tool_code\ntool_name\n``` ```json\n{args}\n```
|
|
"""
|
|
# Pattern: ```tool_code\n<name>\n``` optionally followed by ```json\n<args>\n```
|
|
pattern = re.compile(
|
|
r'```tool_code\s*\n\s*(\w+)\s*\n```' # tool name
|
|
r'(?:\s*```json\s*\n(.*?)\n```)?', # optional JSON args
|
|
re.DOTALL
|
|
)
|
|
|
|
parts: list[dict] = []
|
|
last_end = 0
|
|
|
|
for match in pattern.finditer(content):
|
|
# Text before this tool call
|
|
text_before = content[last_end:match.start()].strip()
|
|
if text_before:
|
|
parts.append({"role": "assistant", "content": text_before})
|
|
|
|
tool_name = match.group(1)
|
|
args_str = match.group(2)
|
|
tool_args = None
|
|
if args_str:
|
|
try:
|
|
import json
|
|
tool_args = json.loads(args_str.strip())
|
|
except Exception:
|
|
tool_args = {"raw": args_str.strip()}
|
|
|
|
parts.append({
|
|
"role": "tool_call",
|
|
"content": "",
|
|
"toolName": tool_name,
|
|
"toolArgs": tool_args,
|
|
"toolStatus": "done",
|
|
"toolResult": "",
|
|
})
|
|
last_end = match.end()
|
|
|
|
# Trailing text after last tool
|
|
trailing = content[last_end:].strip()
|
|
if trailing:
|
|
parts.append({"role": "assistant", "content": trailing})
|
|
|
|
# If no matches found, return the whole content as-is
|
|
if not parts:
|
|
parts.append({"role": "assistant", "content": content})
|
|
|
|
return parts
|