Clawith/backend/app/api/tasks.py

175 lines
5.9 KiB
Python

"""Task management API routes."""
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.permissions import check_agent_access
from app.core.security import get_current_user
from app.database import get_db
from app.models.task import Task, TaskLog
from app.models.user import User
from app.schemas.schemas import TaskCreate, TaskLogCreate, TaskLogOut, TaskOut, TaskUpdate
# Router for the task endpoints in this module; every route is mounted under
# /agents/{agent_id}/tasks and tagged "tasks".
router = APIRouter(prefix="/agents/{agent_id}/tasks", tags=["tasks"])
async def _enrich_task_out(task: Task, db: AsyncSession) -> TaskOut:
    """Build a ``TaskOut`` from ``task`` and fill in ``creator_username``.

    The username is looked up only when ``task.created_by`` is set; if no
    matching user row is found, ``creator_username`` is left as validated.
    """
    enriched = TaskOut.model_validate(task)
    if not task.created_by:
        return enriched

    lookup = await db.execute(select(User).where(User.id == task.created_by))
    creator = lookup.scalar_one_or_none()
    if creator:
        enriched.creator_username = creator.username
    return enriched
@router.get("/", response_model=list[TaskOut])
async def list_tasks(
    agent_id: uuid.UUID,
    status_filter: str | None = None,
    type_filter: str | None = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the agent's tasks, newest first, with creator usernames attached.

    ``status_filter`` and ``type_filter`` narrow the result when they are
    non-empty. Access to the agent is checked via ``check_agent_access``
    before any task is read.
    """
    await check_agent_access(db, current_user, agent_id)

    stmt = select(Task).where(Task.agent_id == agent_id)
    for column, wanted in ((Task.status, status_filter), (Task.type, type_filter)):
        if wanted:
            stmt = stmt.where(column == wanted)
    stmt = stmt.order_by(Task.created_at.desc())
    rows = (await db.execute(stmt)).scalars().all()

    # Resolve all creator usernames with a single query instead of one per task.
    author_ids = {row.created_by for row in rows if row.created_by}
    usernames = {}
    if author_ids:
        found = await db.execute(select(User).where(User.id.in_(author_ids)))
        usernames = {person.id: person.username for person in found.scalars().all()}

    payload = []
    for row in rows:
        item = TaskOut.model_validate(row)
        # Tasks without a (known) creator get None here.
        item.creator_username = usernames.get(row.created_by)
        payload.append(item)
    return payload
# Strong references to background executions started by create_task.
# asyncio's event loop keeps only weak references to tasks, so a task whose
# handle is dropped may be garbage-collected before it finishes.
_create_task_executions: set = set()


@router.post("/", response_model=TaskOut, status_code=status.HTTP_201_CREATED)
async def create_task(
    agent_id: uuid.UUID,
    data: TaskCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a new task for an agent.

    The task is recorded with the current user as its creator and committed
    immediately. When ``data.type`` is ``"todo"``, ``execute_task`` is started
    in the background for the new task; the response does not wait for it.

    Returns:
        The created task as ``TaskOut`` with ``creator_username`` populated.
    """
    await check_agent_access(db, current_user, agent_id)
    task = Task(
        agent_id=agent_id,
        title=data.title,
        description=data.description,
        type=data.type,
        priority=data.priority,
        due_date=data.due_date,
        created_by=current_user.id,
        supervision_target_name=data.supervision_target_name,
        supervision_channel=data.supervision_channel,
        remind_schedule=data.remind_schedule,
    )
    db.add(task)
    await db.flush()
    task_out = await _enrich_task_out(task, db)
    # Read the id before committing: if the session expires attributes on
    # commit, touching task.id afterwards would require an implicit refresh.
    new_task_id = task.id
    # Commit so the background executor can see the task in its own session
    await db.commit()
    # Fire background execution for todo tasks
    if data.type == "todo":
        import asyncio
        from app.services.task_executor import execute_task

        execution = asyncio.create_task(execute_task(new_task_id, agent_id))
        # Hold the task until it completes, then drop the reference.
        _create_task_executions.add(execution)
        execution.add_done_callback(_create_task_executions.discard)
    return task_out
@router.patch("/{task_id}", response_model=TaskOut)
async def update_task(
    agent_id: uuid.UUID,
    task_id: uuid.UUID,
    data: TaskUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Apply a partial update to one of the agent's tasks.

    Only fields explicitly present in the request body are written. The
    change is flushed here; this function does not commit.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists under ``agent_id``.
    """
    await check_agent_access(db, current_user, agent_id)

    stmt = select(Task).where(Task.id == task_id, Task.agent_id == agent_id)
    task = (await db.execute(stmt)).scalar_one_or_none()
    if not task:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found")

    changes = data.model_dump(exclude_unset=True)
    for name in changes:
        setattr(task, name, changes[name])

    await db.flush()
    enriched = await _enrich_task_out(task, db)
    return enriched
@router.get("/{task_id}/logs", response_model=list[TaskLogOut])
async def get_task_logs(
    agent_id: uuid.UUID,
    task_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get progress logs for a task, oldest first.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists under ``agent_id``.
    """
    await check_agent_access(db, current_user, agent_id)
    # Access was checked for agent_id only, so confirm the task really belongs
    # to that agent; otherwise logs of another agent's task could be read just
    # by supplying its task id.
    owned = await db.execute(
        select(Task.id).where(Task.id == task_id, Task.agent_id == agent_id)
    )
    if owned.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found")
    result = await db.execute(
        select(TaskLog).where(TaskLog.task_id == task_id).order_by(TaskLog.created_at.asc())
    )
    return [TaskLogOut.model_validate(entry) for entry in result.scalars().all()]
@router.post("/{task_id}/logs", response_model=TaskLogOut, status_code=status.HTTP_201_CREATED)
async def add_task_log(
    agent_id: uuid.UUID,
    task_id: uuid.UUID,
    data: TaskLogCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Add a progress log entry to a task.

    The entry is flushed here; this function does not commit.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists under ``agent_id``.
    """
    await check_agent_access(db, current_user, agent_id)
    # Access was checked for agent_id only, so confirm the task really belongs
    # to that agent before attaching a log to it.
    owned = await db.execute(
        select(Task.id).where(Task.id == task_id, Task.agent_id == agent_id)
    )
    if owned.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found")
    log = TaskLog(task_id=task_id, content=data.content)
    db.add(log)
    await db.flush()
    return TaskLogOut.model_validate(log)
# Strong references to background executions started by trigger_task.
# asyncio's event loop keeps only weak references to tasks, so a task whose
# handle is dropped may be garbage-collected before it finishes.
_triggered_executions: set = set()


@router.post("/{task_id}/trigger")
async def trigger_task(
    agent_id: uuid.UUID,
    task_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Manually trigger a supervision task execution (for testing).

    ``execute_task`` is started in the background and the response returns
    immediately. The task's type is not checked here.

    Returns:
        ``{"status": "triggered", "task_id": <task id as string>}``.

    Raises:
        HTTPException: 403 when the agent has expired; 404 when no task with
            ``task_id`` exists under ``agent_id``.
    """
    from app.core.permissions import is_agent_expired

    agent, _access = await check_agent_access(db, current_user, agent_id)
    if is_agent_expired(agent):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Agent has expired")
    result = await db.execute(select(Task).where(Task.id == task_id, Task.agent_id == agent_id))
    task = result.scalar_one_or_none()
    if not task:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found")

    import asyncio
    from app.services.task_executor import execute_task

    execution = asyncio.create_task(execute_task(task.id, agent_id))
    # Hold the task until it completes, then drop the reference.
    _triggered_executions.add(execution)
    execution.add_done_callback(_triggered_executions.discard)
    return {"status": "triggered", "task_id": str(task_id)}