"""Task management API routes.""" import uuid from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.permissions import check_agent_access from app.core.security import get_current_user from app.database import get_db from app.models.task import Task, TaskLog from app.models.user import User from app.schemas.schemas import TaskCreate, TaskLogCreate, TaskLogOut, TaskOut, TaskUpdate router = APIRouter(prefix="/agents/{agent_id}/tasks", tags=["tasks"]) async def _enrich_task_out(task: Task, db: AsyncSession) -> TaskOut: """Convert Task to TaskOut with creator_username populated.""" out = TaskOut.model_validate(task) if task.created_by: user_result = await db.execute(select(User).where(User.id == task.created_by)) user = user_result.scalar_one_or_none() if user: out.creator_username = user.username return out @router.get("/", response_model=list[TaskOut]) async def list_tasks( agent_id: uuid.UUID, status_filter: str | None = None, type_filter: str | None = None, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): """List tasks for an agent.""" await check_agent_access(db, current_user, agent_id) query = select(Task).where(Task.agent_id == agent_id) if status_filter: query = query.where(Task.status == status_filter) if type_filter: query = query.where(Task.type == type_filter) query = query.order_by(Task.created_at.desc()) result = await db.execute(query) tasks_list = result.scalars().all() # Batch-load creator usernames creator_ids = {t.created_by for t in tasks_list if t.created_by} creator_map = {} if creator_ids: users_result = await db.execute(select(User).where(User.id.in_(creator_ids))) creator_map = {u.id: u.username for u in users_result.scalars().all()} out_list = [] for t in tasks_list: t_out = TaskOut.model_validate(t) t_out.creator_username = creator_map.get(t.created_by) out_list.append(t_out) return out_list @router.post("/", response_model=TaskOut, status_code=status.HTTP_201_CREATED) async def create_task( agent_id: uuid.UUID, data: TaskCreate, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): """Create a new task for an agent.""" await check_agent_access(db, current_user, agent_id) task = Task( agent_id=agent_id, title=data.title, description=data.description, type=data.type, priority=data.priority, due_date=data.due_date, created_by=current_user.id, supervision_target_name=data.supervision_target_name, supervision_channel=data.supervision_channel, remind_schedule=data.remind_schedule, ) db.add(task) await db.flush() task_out = await _enrich_task_out(task, db) # Commit so the background executor can see the task in its own session await db.commit() # Fire background execution for todo tasks if data.type == "todo": import asyncio from app.services.task_executor import execute_task asyncio.create_task(execute_task(task.id, agent_id)) return task_out @router.patch("/{task_id}", response_model=TaskOut) async def update_task( agent_id: uuid.UUID, task_id: uuid.UUID, data: TaskUpdate, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): """Update a task.""" await check_agent_access(db, current_user, agent_id) result = await db.execute(select(Task).where(Task.id == task_id, Task.agent_id == agent_id)) task = result.scalar_one_or_none() if not task: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found") for field, value in data.model_dump(exclude_unset=True).items(): setattr(task, field, value) await db.flush() return await _enrich_task_out(task, db) @router.get("/{task_id}/logs", response_model=list[TaskLogOut]) async def get_task_logs( agent_id: uuid.UUID, task_id: uuid.UUID, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): """Get progress logs for a task.""" await check_agent_access(db, current_user, agent_id) result = await db.execute( select(TaskLog).where(TaskLog.task_id == task_id).order_by(TaskLog.created_at.asc()) ) return [TaskLogOut.model_validate(l) for l in result.scalars().all()] @router.post("/{task_id}/logs", response_model=TaskLogOut, status_code=status.HTTP_201_CREATED) async def add_task_log( agent_id: uuid.UUID, task_id: uuid.UUID, data: TaskLogCreate, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): """Add a progress log entry to a task.""" await check_agent_access(db, current_user, agent_id) log = TaskLog(task_id=task_id, content=data.content) db.add(log) await db.flush() return TaskLogOut.model_validate(log) @router.post("/{task_id}/trigger") async def trigger_task( agent_id: uuid.UUID, task_id: uuid.UUID, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): """Manually trigger a supervision task execution (for testing).""" from app.core.permissions import is_agent_expired agent, _access = await check_agent_access(db, current_user, agent_id) if is_agent_expired(agent): raise HTTPException(status_code=403, detail="Agent has expired") result = await db.execute(select(Task).where(Task.id == task_id, Task.agent_id == agent_id)) task = result.scalar_one_or_none() if not task: raise HTTPException(status_code=404, detail="Task not found") import asyncio from app.services.task_executor import execute_task asyncio.create_task(execute_task(task.id, agent_id)) return {"status": "triggered", "task_id": str(task_id)}