77 lines
2.5 KiB
Python
77 lines
2.5 KiB
Python
"""Debounced queue for per-thread memory updates."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import threading
|
|
from dataclasses import dataclass, field
|
|
from datetime import UTC, datetime
|
|
from typing import Any
|
|
|
|
from deerflow.config.thread_memory_config import get_thread_memory_config
|
|
|
|
|
|
@dataclass
|
|
class ThreadConversationContext:
|
|
thread_id: str
|
|
messages: list[Any]
|
|
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
|
|
|
|
|
class ThreadMemoryUpdateQueue:
    """Debounce memory updates per conversation thread.

    Each :meth:`add` replaces any pending context for the thread and restarts
    that thread's timer. As a result, only the most recent messages are
    processed, once no new ``add`` has arrived for ``debounce_seconds``. At
    most one update per thread runs at a time. A timer that fires while an
    update for the same thread is still running is re-armed instead of
    running concurrently.
    """

    def __init__(self):
        # Latest pending context per thread id; a newer add() overwrites it.
        self._queue_by_thread: dict[str, ThreadConversationContext] = {}
        # Guards the dicts/set below; timer callbacks run on their own threads.
        self._lock = threading.Lock()
        # The currently armed debounce timer per thread id.
        self._timers: dict[str, threading.Timer] = {}
        # Thread ids whose memory update is currently running.
        self._processing_threads: set[str] = set()

    def add(self, thread_id: str, messages: list[Any]) -> None:
        """Queue ``messages`` for ``thread_id`` and (re)start its debounce timer.

        Does nothing when thread memory is disabled in config. A later call for
        the same thread replaces the pending messages rather than appending to them.
        """
        config = get_thread_memory_config()
        if not config.enabled:
            return
        with self._lock:
            self._queue_by_thread[thread_id] = ThreadConversationContext(thread_id=thread_id, messages=messages)
            self._reset_timer(thread_id)

    def _reset_timer(self, thread_id: str) -> None:
        """Cancel the thread's armed timer, if any, and arm a fresh one.

        Must be called with ``self._lock`` held.
        """
        config = get_thread_memory_config()
        timer = self._timers.get(thread_id)
        if timer is not None:
            # cancel() cannot stop a callback that has already started. Such a
            # superseded callback is detected and ignored in _process_thread.
            timer.cancel()
        timer = threading.Timer(config.debounce_seconds, self._process_thread, args=(thread_id,))
        timer.daemon = True
        self._timers[thread_id] = timer
        timer.start()

    def _process_thread(self, thread_id: str) -> None:
        """Run the memory update for ``thread_id`` using its latest pending context.

        This is normally invoked by the debounce timer armed in :meth:`_reset_timer`.
        A timer callback that is no longer the registered timer for the thread
        returns without doing any work. That happens when it fired just as
        ``add`` replaced it, and returning keeps the newer timer's full debounce
        interval.
        """
        from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater

        current = threading.current_thread()
        with self._lock:
            if isinstance(current, threading.Timer) and self._timers.get(thread_id) is not current:
                # Stale callback: a newer timer owns this thread now (or the
                # entry was already handled). Processing here would bypass the
                # debounce and orphan the newer timer.
                return
            if thread_id in self._processing_threads:
                # An update is already running; retry after another debounce period.
                self._reset_timer(thread_id)
                return
            timer = self._timers.pop(thread_id, None)
            if timer is not None and timer is not current:
                # Invoked directly rather than by the timer: drop the pending timer.
                timer.cancel()
            context = self._queue_by_thread.pop(thread_id, None)
            if context is None:
                return
            self._processing_threads.add(thread_id)

        # Run the (potentially slow) update outside the lock so add() is never blocked.
        try:
            updater = ThreadMemoryUpdater()
            updater.update_memory(context.messages, context.thread_id)
        finally:
            with self._lock:
                self._processing_threads.discard(thread_id)
|
|
|
|
|
|
# Lazily created process-wide queue instance; None until first requested.
_thread_queue: ThreadMemoryUpdateQueue | None = None
# Serializes creation of ``_thread_queue`` so concurrent callers share one instance.
_lock = threading.Lock()


def get_thread_memory_queue() -> ThreadMemoryUpdateQueue:
    """Return the shared :class:`ThreadMemoryUpdateQueue`, creating it on first call.

    The check-and-create runs under ``_lock``, so every caller receives the same instance.
    """
    global _thread_queue
    with _lock:
        queue = _thread_queue
        if queue is None:
            queue = ThreadMemoryUpdateQueue()
            _thread_queue = queue
    return queue
|