175 lines
5.9 KiB
Python
175 lines
5.9 KiB
Python
"""Task management API routes."""
|
|
|
|
import uuid
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.permissions import check_agent_access
|
|
from app.core.security import get_current_user
|
|
from app.database import get_db
|
|
from app.models.task import Task, TaskLog
|
|
from app.models.user import User
|
|
from app.schemas.schemas import TaskCreate, TaskLogCreate, TaskLogOut, TaskOut, TaskUpdate
|
|
|
|
# Every route in this module is mounted under a single agent: the
# ``agent_id`` path parameter comes from the router prefix.
router = APIRouter(prefix="/agents/{agent_id}/tasks", tags=["tasks"])
|
|
|
|
|
|
async def _enrich_task_out(task: Task, db: AsyncSession) -> TaskOut:
    """Build the ``TaskOut`` for *task* and fill in ``creator_username``.

    Args:
        task: The task to serialise.
        db: Session used to look up the creating user.

    Returns:
        The validated ``TaskOut``. ``creator_username`` is only assigned when
        ``task.created_by`` is set and a matching ``User`` row is found.
    """
    enriched = TaskOut.model_validate(task)

    # Nothing to resolve when the task carries no creator reference.
    if not task.created_by:
        return enriched

    creator_lookup = select(User).where(User.id == task.created_by)
    creator = (await db.execute(creator_lookup)).scalar_one_or_none()
    if creator:
        enriched.creator_username = creator.username
    return enriched
|
|
|
|
|
|
@router.get("/", response_model=list[TaskOut])
async def list_tasks(
    agent_id: uuid.UUID,
    status_filter: str | None = None,
    type_filter: str | None = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the agent's tasks, newest first.

    Args:
        agent_id: Agent whose tasks are listed.
        status_filter: When non-empty, only tasks with this status are returned.
        type_filter: When non-empty, only tasks of this type are returned.
        current_user: Authenticated caller, checked against the agent.
        db: Request-scoped database session.

    Returns:
        A list of ``TaskOut`` objects ordered by ``created_at`` descending,
        each with ``creator_username`` set (``None`` when no creator matched).
    """
    await check_agent_access(db, current_user, agent_id)

    # Collect the WHERE conditions first, then build the statement in one go.
    conditions = [Task.agent_id == agent_id]
    if status_filter:
        conditions.append(Task.status == status_filter)
    if type_filter:
        conditions.append(Task.type == type_filter)
    statement = select(Task).where(*conditions).order_by(Task.created_at.desc())
    rows = (await db.execute(statement)).scalars().all()

    # Resolve all creator usernames with a single query rather than one per task.
    usernames_by_id = {}
    wanted_ids = {row.created_by for row in rows if row.created_by}
    if wanted_ids:
        found = await db.execute(select(User).where(User.id.in_(wanted_ids)))
        usernames_by_id = {user.id: user.username for user in found.scalars().all()}

    def _serialise(row):
        # Validate the ORM row, then attach the pre-fetched username.
        item = TaskOut.model_validate(row)
        item.creator_username = usernames_by_id.get(row.created_by)
        return item

    return [_serialise(row) for row in rows]
|
|
|
|
|
|
# Strong references to executor jobs started by ``create_task``. The event
# loop only keeps weak references to tasks, so a job that nothing else refers
# to may be garbage-collected before it finishes.
_CREATE_TASK_BACKGROUND_JOBS: set = set()


@router.post("/", response_model=TaskOut, status_code=status.HTTP_201_CREATED)
async def create_task(
    agent_id: uuid.UUID,
    data: TaskCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a new task for an agent.

    Args:
        agent_id: Agent the task is created for.
        data: Task fields supplied by the client.
        current_user: Authenticated caller; recorded as the task's creator.
        db: Request-scoped database session.

    Returns:
        The created task as ``TaskOut`` with ``creator_username`` populated.

    Side effects:
        Commits the session. For tasks of type ``"todo"`` a background
        ``execute_task`` job is scheduled on the running event loop.
    """
    await check_agent_access(db, current_user, agent_id)

    task = Task(
        agent_id=agent_id,
        title=data.title,
        description=data.description,
        type=data.type,
        priority=data.priority,
        due_date=data.due_date,
        created_by=current_user.id,
        supervision_target_name=data.supervision_target_name,
        supervision_channel=data.supervision_channel,
        remind_schedule=data.remind_schedule,
    )
    db.add(task)
    await db.flush()

    task_out = await _enrich_task_out(task, db)

    # Read the primary key before committing. NOTE(review): if the session
    # expires attributes on commit (its configuration is not visible here),
    # touching ``task.id`` afterwards would need an implicit refresh, which an
    # async session cannot perform from plain attribute access.
    new_task_id = task.id

    # Commit so the background executor can see the task in its own session
    await db.commit()

    # Fire background execution for todo tasks
    if data.type == "todo":
        import asyncio

        from app.services.task_executor import execute_task

        job = asyncio.create_task(execute_task(new_task_id, agent_id))
        # Hold a reference until the job completes, then drop it.
        _CREATE_TASK_BACKGROUND_JOBS.add(job)
        job.add_done_callback(_CREATE_TASK_BACKGROUND_JOBS.discard)

    return task_out
|
|
|
|
|
|
@router.patch("/{task_id}", response_model=TaskOut)
async def update_task(
    agent_id: uuid.UUID,
    task_id: uuid.UUID,
    data: TaskUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Apply a partial update to one of the agent's tasks.

    Args:
        agent_id: Agent that must own the task.
        task_id: Task to modify.
        data: Update payload; only fields the client explicitly set are applied.
        current_user: Authenticated caller, checked against the agent.
        db: Request-scoped database session.

    Returns:
        The updated task as ``TaskOut`` with ``creator_username`` populated.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists for the agent.
    """
    await check_agent_access(db, current_user, agent_id)

    lookup = select(Task).where(Task.id == task_id, Task.agent_id == agent_id)
    task = (await db.execute(lookup)).scalar_one_or_none()
    if not task:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found")

    # ``exclude_unset`` keeps omitted fields from overwriting stored values.
    changes = data.model_dump(exclude_unset=True)
    for attribute in changes:
        setattr(task, attribute, changes[attribute])

    await db.flush()
    return await _enrich_task_out(task, db)
|
|
|
|
|
|
@router.get("/{task_id}/logs", response_model=list[TaskLogOut])
async def get_task_logs(
    agent_id: uuid.UUID,
    task_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get progress logs for a task.

    Args:
        agent_id: Agent that must own the task.
        task_id: Task whose logs are returned.
        current_user: Authenticated caller, checked against the agent.
        db: Request-scoped database session.

    Returns:
        The task's log entries as ``TaskLogOut`` objects, ordered by
        ``created_at`` ascending.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists for the agent.
    """
    await check_agent_access(db, current_user, agent_id)

    # Access was checked for ``agent_id`` only, so confirm the task really
    # belongs to that agent; otherwise any task's logs could be read through
    # an agent the caller happens to have access to.
    owned = await db.execute(
        select(Task.id).where(Task.id == task_id, Task.agent_id == agent_id)
    )
    if owned.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found")

    result = await db.execute(
        select(TaskLog).where(TaskLog.task_id == task_id).order_by(TaskLog.created_at.asc())
    )
    return [TaskLogOut.model_validate(entry) for entry in result.scalars().all()]
|
|
|
|
|
|
@router.post("/{task_id}/logs", response_model=TaskLogOut, status_code=status.HTTP_201_CREATED)
async def add_task_log(
    agent_id: uuid.UUID,
    task_id: uuid.UUID,
    data: TaskLogCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Add a progress log entry to a task.

    Args:
        agent_id: Agent that must own the task.
        task_id: Task the log entry is attached to.
        data: Payload carrying the log ``content``.
        current_user: Authenticated caller, checked against the agent.
        db: Request-scoped database session.

    Returns:
        The newly created log entry as ``TaskLogOut``.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists for the agent.
    """
    await check_agent_access(db, current_user, agent_id)

    # Access was checked for ``agent_id`` only, so confirm the task belongs to
    # that agent before writing; this also rejects unknown task ids up front
    # instead of letting the insert reach the database.
    owned = await db.execute(
        select(Task.id).where(Task.id == task_id, Task.agent_id == agent_id)
    )
    if owned.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found")

    log = TaskLog(task_id=task_id, content=data.content)
    db.add(log)
    # Flush so the new row is sent to the database before it is serialised.
    await db.flush()
    return TaskLogOut.model_validate(log)
|
|
|
|
|
|
# Strong references to executor jobs started by ``trigger_task``. The event
# loop only keeps weak references to tasks, so a job that nothing else refers
# to may be garbage-collected before it finishes.
_TRIGGER_TASK_BACKGROUND_JOBS: set = set()


@router.post("/{task_id}/trigger")
async def trigger_task(
    agent_id: uuid.UUID,
    task_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Manually trigger a supervision task execution (for testing).

    Args:
        agent_id: Agent that must own the task.
        task_id: Task to execute.
        current_user: Authenticated caller, checked against the agent.
        db: Request-scoped database session.

    Returns:
        ``{"status": "triggered", "task_id": <task id as string>}``. The
        execution itself runs in the background and is not awaited.

    Raises:
        HTTPException: 403 when the agent has expired; 404 when no task with
            ``task_id`` exists for the agent.
    """
    from app.core.permissions import is_agent_expired

    agent, _access = await check_agent_access(db, current_user, agent_id)
    if is_agent_expired(agent):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Agent has expired")

    result = await db.execute(select(Task).where(Task.id == task_id, Task.agent_id == agent_id))
    task = result.scalar_one_or_none()
    if not task:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found")

    import asyncio

    from app.services.task_executor import execute_task

    job = asyncio.create_task(execute_task(task.id, agent_id))
    # Hold a reference until the job completes, then drop it.
    _TRIGGER_TASK_BACKGROUND_JOBS.add(job)
    job.add_done_callback(_TRIGGER_TASK_BACKGROUND_JOBS.discard)

    return {"status": "triggered", "task_id": str(task_id)}
|