# src/utils/context_manager.py
"""Context management utilities for keeping LLM message histories within a token budget."""

import copy
import json
import logging
from typing import List

from langgraph.runtime import Runtime
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)

from src.config import load_yaml_config

logger = logging.getLogger(__name__)


def get_search_config():
    """Load per-model token limit settings (MODEL_TOKEN_LIMITS) from conf.yaml."""
    config = load_yaml_config("conf.yaml")
    search_config = config.get("MODEL_TOKEN_LIMITS", {})
    return search_config


class ContextManager:
    """Context manager and compression class."""

    def __init__(self, token_limit: int | None, preserve_prefix_message_count: int = 0):
        """
        Initialize ContextManager.

        Args:
            token_limit: Maximum token limit; None disables compression
            preserve_prefix_message_count: Number of messages to preserve at the
                beginning of the context (e.g., the system message)
        """
        self.token_limit = token_limit
        self.preserve_prefix_message_count = preserve_prefix_message_count

    def count_tokens(self, messages: List[BaseMessage]) -> int:
        """
        Count tokens in a message list.

        Args:
            messages: List of messages

        Returns:
            Estimated number of tokens
        """
        total_tokens = 0
        for message in messages:
            total_tokens += self._count_message_tokens(message)
        return total_tokens

    def _count_message_tokens(self, message: BaseMessage) -> int:
        """
        Count tokens in a single message.

        Args:
            message: Message object

        Returns:
            Estimated number of tokens
        """
        # Estimate token count from character length (English and non-English
        # characters are weighted differently).
        token_count = 0

        # Count tokens in the content field
        if hasattr(message, "content") and message.content:
            # Only string content is counted by this estimator
            if isinstance(message.content, str):
                token_count += self._count_text_tokens(message.content)

        # Count role-related tokens
        if hasattr(message, "type"):
            token_count += self._count_text_tokens(message.type)

        # Apply per-type adjustment factors
        if isinstance(message, SystemMessage):
            # System messages are usually short but important; slightly increase estimate
            token_count = int(token_count * 1.1)
        elif isinstance(message, HumanMessage):
            # Human messages use the plain estimate
            pass
        elif isinstance(message, AIMessage):
            # AI messages may contain reasoning content; slightly increase estimate
            token_count = int(token_count * 1.2)
        elif isinstance(message, ToolMessage):
            # Tool messages may contain large amounts of structured data; increase estimate
            token_count = int(token_count * 1.3)

        # Account for additional information in additional_kwargs
        if hasattr(message, "additional_kwargs") and message.additional_kwargs:
            # Rough estimate of the extra fields
            extra_str = str(message.additional_kwargs)
            token_count += self._count_text_tokens(extra_str)

            # If there are tool_calls, add a flat estimate for the call metadata
            if "tool_calls" in message.additional_kwargs:
                token_count += 50

        # Ensure at least 1 token
        return max(1, token_count)

    def _count_text_tokens(self, text: str) -> int:
        """
        Count tokens in text, weighting English and non-English characters differently.
        English characters: 4 characters ≈ 1 token.
        Non-English characters (e.g., Chinese): 1 character ≈ 1 token.

        Args:
            text: Text to count tokens for

        Returns:
            Estimated number of tokens
        """
        if not text:
            return 0

        english_chars = 0
        non_english_chars = 0

        for char in text:
            # Treat ASCII (English letters, digits, punctuation) separately
            if ord(char) < 128:
                english_chars += 1
            else:
                non_english_chars += 1

        # English at 4 chars/token, everything else at 1 char/token
        english_tokens = english_chars // 4
        non_english_tokens = non_english_chars

        return english_tokens + non_english_tokens

    def is_over_limit(self, messages: List[BaseMessage]) -> bool:
        """
        Check whether messages exceed the token limit.

        Args:
            messages: List of messages

        Returns:
            True if the limit is exceeded
        """
        return self.count_tokens(messages) > self.token_limit

    def compress_messages(self, state: dict, runtime: Runtime | None = None) -> dict:
        """
        Compress messages to fit within the token limit.

        Args:
            state: State dict containing the original messages
            runtime: Optional runtime parameter (unused, but required for middleware compatibility)

        Returns:
            State with compressed messages
        """
        # If token_limit is not set, return the original state unchanged
        if self.token_limit is None:
            logger.info("No token_limit set; context management is disabled.")
            return state

        if not isinstance(state, dict) or "messages" not in state:
            logger.warning("No messages found in state")
            return state

        messages = state["messages"]

        if not self.is_over_limit(messages):
            logger.debug(
                f"Messages within limit ({self.count_tokens(messages)} <= {self.token_limit} tokens)"
            )
            return state

        # Compress messages
        original_token_count = self.count_tokens(messages)
        compressed_messages = self._compress_messages(messages)
        compressed_token_count = self.count_tokens(compressed_messages)

        logger.warning(
            f"Message compression executed (Issue #721): {original_token_count} -> {compressed_token_count} tokens "
            f"(limit: {self.token_limit}), {len(messages)} -> {len(compressed_messages)} messages"
        )

        state["messages"] = compressed_messages
        return state

    def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
        """
        Compress messages to fit within the token limit using two strategies:
        1. Truncate raw_content in web_search ToolMessages to 1024 chars.
        2. If still over the limit, drop the oldest messages while preserving the
           prefix messages (e.g., the system message) and the most recent ones.

        Args:
            messages: List of messages to compress

        Returns:
            List of messages with compressed content and/or dropped messages
        """
        # Create a deep copy to avoid mutating the original messages
        compressed = copy.deepcopy(messages)

        # Step 1: Compress raw_content in web_search ToolMessages
        for msg in compressed:
            # Only compress ToolMessages named 'web_search'
            if isinstance(msg, ToolMessage) and getattr(msg, "name", None) == "web_search":
                try:
                    # Determine the content type and check whether compression is needed
                    if isinstance(msg.content, str):
                        # Early exit for small payloads to avoid JSON parsing overhead:
                        # below 2KB, raw_content is unlikely to need truncation
                        if len(msg.content) < 2048:
                            continue
                        try:
                            content_data = json.loads(msg.content)
                        except json.JSONDecodeError as e:
                            logger.error(
                                f"Failed to parse JSON content in web_search ToolMessage: {e}. "
                                f"Content: {msg.content[:200]}"
                            )
                            continue
                    elif isinstance(msg.content, list):
                        content_data = copy.deepcopy(msg.content)
                    else:
                        continue

                    # Truncate raw_content item by item, tracking whether anything changed
                    modified = False
                    if isinstance(content_data, list):
                        for item in content_data:
                            if isinstance(item, dict) and "raw_content" in item:
                                raw_content = item.get("raw_content")
                                if (
                                    raw_content
                                    and isinstance(raw_content, str)
                                    and len(raw_content) > 1024
                                ):
                                    item["raw_content"] = raw_content[:1024]
                                    modified = True

                    # Re-serialize the message content only if changes were made
                    if modified:
                        msg.content = json.dumps(content_data, ensure_ascii=False)
                except Exception as e:
                    logger.error(f"Unexpected error during message compression: {e}")
                    continue

        # Step 2: If still over the limit after raw_content compression, drop the oldest
        # messages while preserving the prefix messages (e.g., the system message) and
        # the most recent messages
        if self.is_over_limit(compressed):
            # Messages to preserve at the beginning
            preserved_count = self.preserve_prefix_message_count
            preserved_messages = compressed[:preserved_count]
            remaining_messages = compressed[preserved_count:]

            # Re-add messages from newest to oldest until adding one would exceed the limit
            result_messages = list(preserved_messages)
            for msg in reversed(remaining_messages):
                result_messages.insert(len(preserved_messages), msg)
                if self.is_over_limit(result_messages):
                    # This message pushed the total over the limit; drop it and stop
                    result_messages.pop(len(preserved_messages))
                    break

            compressed = result_messages

        # Step 3: Verify that compression succeeded and warn if it did not
        if self.is_over_limit(compressed):
            current_tokens = self.count_tokens(compressed)
            logger.warning(
                f"Message compression failed to bring tokens below limit: "
                f"{current_tokens} > {self.token_limit} tokens. "
                f"Total messages: {len(compressed)}. "
                f"Consider increasing token_limit or preserve_prefix_message_count."
            )

        return compressed

    def _create_summary_message(self, messages: List[BaseMessage]) -> BaseMessage:
        """
        Create a summary message for a list of messages.

        Args:
            messages: Messages to summarize

        Returns:
            Summary message
        """
        # TODO: summary implementation
        pass


def validate_message_content(
    messages: List[BaseMessage], max_content_length: int = 100000
) -> List[BaseMessage]:
    """
    Validate and fix all messages to ensure they have valid content before sending to the LLM.

    This function ensures that:
    1. Every message has a content field.
    2. No message has None content (legitimate empty responses are left as empty strings).
    3. Complex objects (lists, dicts) are converted to JSON strings.
    4. Content is truncated if too long, to prevent token overflow.

    Args:
        messages: List of messages to validate
        max_content_length: Maximum allowed content length per message (default 100000)

    Returns:
        List of validated messages with fixed content
    """
    validated = []
    for i, msg in enumerate(messages):
        try:
            # Check if message has content attribute
            if not hasattr(msg, "content"):
                logger.warning(f"Message {i} ({type(msg).__name__}) has no content attribute")
                msg.content = ""
            # Handle None content
            elif msg.content is None:
                logger.warning(
                    f"Message {i} ({type(msg).__name__}) has None content, setting to empty string"
                )
                msg.content = ""
            # Handle complex content types (convert to JSON)
            elif isinstance(msg.content, (list, dict)):
                logger.debug(
                    f"Message {i} ({type(msg).__name__}) has complex content type "
                    f"{type(msg.content).__name__}, converting to JSON"
                )
                msg.content = json.dumps(msg.content, ensure_ascii=False)
            # Handle other non-string types
            elif not isinstance(msg.content, str):
                logger.debug(
                    f"Message {i} ({type(msg).__name__}) has non-string content type "
                    f"{type(msg.content).__name__}, converting to string"
                )
                msg.content = str(msg.content)

            # Validate content length
            if isinstance(msg.content, str) and len(msg.content) > max_content_length:
                logger.warning(
                    f"Message {i} content truncated from {len(msg.content)} to {max_content_length} chars"
                )
                msg.content = msg.content[:max_content_length].rstrip() + "..."

            validated.append(msg)
        except Exception as e:
            logger.error(f"Error validating message {i}: {e}")
            # Create a safe fallback message
            if isinstance(msg, ToolMessage):
                msg.content = json.dumps({"error": str(e)}, ensure_ascii=False)
            else:
                msg.content = f"[Error processing message: {str(e)}]"
            validated.append(msg)

    return validated
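

# --- Usage sketch (illustrative only, not part of the module's public surface) ---
# A minimal example of how ContextManager and validate_message_content might be wired
# together. The deliberately small token_limit of 500 is an assumption chosen so the
# compression path is exercised; in the application the limit would normally come from
# get_search_config() / MODEL_TOKEN_LIMITS in conf.yaml.
if __name__ == "__main__":
    sample_messages = [
        SystemMessage(content="You are a helpful research assistant."),
        HumanMessage(content="Here is a long source document: " + "lorem ipsum " * 800),
        AIMessage(content="Noted. I have read the document."),
        HumanMessage(content="Now give me a three-sentence summary."),
    ]

    # Preserve the system message and compress the rest to fit the (assumed) budget
    manager = ContextManager(token_limit=500, preserve_prefix_message_count=1)
    print(f"Tokens before compression: {manager.count_tokens(sample_messages)}")

    state = manager.compress_messages({"messages": sample_messages})
    compressed = validate_message_content(state["messages"])
    print(f"Tokens after compression: {manager.count_tokens(compressed)}")
    print(f"Messages kept: {[type(m).__name__ for m in compressed]}")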