diff --git a/src/openlayer/lib/integrations/langchain_callback.py b/src/openlayer/lib/integrations/langchain_callback.py index 41fd1e29..9d995570 100644 --- a/src/openlayer/lib/integrations/langchain_callback.py +++ b/src/openlayer/lib/integrations/langchain_callback.py @@ -7,15 +7,14 @@ from uuid import UUID try: - try: - from langchain_core import messages as langchain_schema - from langchain_core.callbacks.base import ( - AsyncCallbackHandler, - BaseCallbackHandler, - ) - except ImportError: - from langchain import schema as langchain_schema - from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler + # langchain-core is sufficient for this handler. The legacy + # ``langchain.schema`` / ``langchain.callbacks.base`` import paths were + # removed in LangChain v1, so we import from ``langchain_core`` directly. + from langchain_core import messages as langchain_schema + from langchain_core.callbacks.base import ( + AsyncCallbackHandler, + BaseCallbackHandler, + ) HAVE_LANGCHAIN = True except ImportError: @@ -33,6 +32,34 @@ "amazon_bedrock_converse_chat": "Bedrock", } +# LangChain v1 injects a standardized ``metadata["ls_provider"]`` (e.g. +# "openai", "anthropic", "google_genai", "groq"...). This maps those values to +# Openlayer's canonical provider names. Values without an explicit entry fall +# back to a title-cased version of the ls_provider string. +LS_PROVIDER_TO_OPENLAYER_MAP = { + "openai": "OpenAI", + "azure": "Azure", + "azure_openai": "Azure", + "anthropic": "Anthropic", + "google": "Google", + "google_genai": "Google", + "google_vertexai": "Google", + "vertexai": "Google", + "cohere": "Cohere", + "mistralai": "Mistral", + "mistral": "Mistral", + "bedrock": "Bedrock", + "amazon_bedrock": "Bedrock", + "ollama": "Ollama", + "huggingface": "Hugging Face", + "replicate": "Replicate", + "together": "Together AI", + "groq": "Groq", + "deepseek": "DeepSeek", + "fireworks": "Fireworks AI", + "perplexity": "Perplexity", +} + # LiteLLM model prefixes to provider names. # When models are accessed via a LiteLLM proxy (e.g. "gemini/gemini-2.5-flash"), # the LangChain _type is "openai-chat" which incorrectly maps to "OpenAI". @@ -71,10 +98,7 @@ class OpenlayerHandlerMixin: def __init__(self, **kwargs: Any) -> None: if not HAVE_LANGCHAIN: - raise ImportError( - "LangChain library is not installed. Please install it with: pip " - "install langchain" - ) + raise ImportError("LangChain library is not installed. Please install it with: pip install langchain-core") super().__init__() self.metadata: Dict[str, Any] = kwargs or {} self.steps: Dict[UUID, steps.Step] = {} @@ -89,6 +113,9 @@ def __init__(self, **kwargs: Any) -> None: self._inference_id = kwargs.get("inference_id") # Extract metadata_transformer from kwargs if provided self._metadata_transformer = kwargs.get("metadata_transformer") + # Opt-out flag: auto-map LangGraph metadata["thread_id"] to the trace's + # session_id when the user has not explicitly provided one. + self._map_thread_id_to_session = kwargs.get("map_thread_id_to_session", True) def _start_step( self, @@ -197,17 +224,11 @@ def _end_step( # Only upload if this is a standalone trace (not integrated with external trace) # If current_step is set, we're part of a larger trace and shouldn't upload - if ( - is_root_step - and run_id in self._traces_by_root - and tracer.get_current_step() is None - ): + if is_root_step and run_id in self._traces_by_root and tracer.get_current_step() is None: trace = self._traces_by_root.pop(run_id) if tracer._resolve("background_publish_enabled"): ctx = contextvars.copy_context() - tracer._get_background_executor().submit( - ctx.run, self._process_and_upload_trace, trace - ) + tracer._get_background_executor().submit(ctx.run, self._process_and_upload_trace, trace) else: self._process_and_upload_trace(trace) @@ -314,11 +335,7 @@ def _convert_langchain_objects(self, obj: Any) -> Any: return self._message_to_dict(obj) # Handle ChatPromptValue objects which contain messages - if ( - hasattr(obj, "messages") - and hasattr(obj, "__class__") - and "ChatPromptValue" in obj.__class__.__name__ - ): + if hasattr(obj, "messages") and hasattr(obj, "__class__") and "ChatPromptValue" in obj.__class__.__name__: return [self._convert_langchain_objects(msg) for msg in obj.messages] # Handle dictionaries @@ -353,9 +370,7 @@ def _convert_langchain_objects(self, obj: Any) -> Any: pass # Handle objects with content attribute - if hasattr(obj, "content") and not isinstance( - obj, langchain_schema.BaseMessage - ): + if hasattr(obj, "content") and not isinstance(obj, langchain_schema.BaseMessage): return obj.content # Handle objects with value attribute @@ -373,9 +388,7 @@ def _convert_langchain_objects(self, obj: Any) -> Any: # For everything else, convert to string return str(obj) - def _message_to_dict( - self, message: "langchain_schema.BaseMessage" - ) -> Dict[str, str]: + def _message_to_dict(self, message: "langchain_schema.BaseMessage") -> Dict[str, Any]: """Convert a LangChain message to a JSON-serializable dictionary.""" message_type = getattr(message, "type", "user") @@ -385,11 +398,61 @@ def _message_to_dict( elif message_type == "system": role = "system" - return {"role": role, "content": str(message.content)} + content_str, content_blocks = self._normalize_content(message.content) + result: Dict[str, Any] = {"role": role, "content": content_str} + + # With ``LC_OUTPUT_VERSION=v1`` the content can be a list of typed + # blocks. Preserve any non-text blocks (reasoning, tool_call, image...) + # structurally rather than stringifying the whole list. + if content_blocks: + result["content_blocks"] = content_blocks + + # Preserve tool calls on assistant messages (e.g. AIMessage.tool_calls). + # Without this, agent turns that only emit tool calls would drop them + # from the recorded inputs. Done defensively so it never raises on + # messages that lack the attribute. + tool_calls = getattr(message, "tool_calls", None) + if tool_calls: + result["tool_calls"] = tool_calls + + return result + + @staticmethod + def _normalize_content(content: Any) -> tuple: + """Normalize message content into a (text, non_text_blocks) pair. + + Handles the three shapes ``message.content`` can take: + * a plain string (returned unchanged, no blocks); + * a list of strings (joined into the text); + * a list of typed blocks (dicts with a ``type`` key) as produced + under ``LC_OUTPUT_VERSION=v1`` -- text is extracted from + ``text``-type blocks while non-text blocks are preserved + structurally. + + Defensive: never raises on unexpected shapes. + """ + if isinstance(content, str): + return content, [] + + if isinstance(content, list): + text_parts: List[str] = [] + non_text_blocks: List[Any] = [] + for block in content: + if isinstance(block, str): + text_parts.append(block) + elif isinstance(block, dict): + if block.get("type") == "text" and "text" in block: + text_parts.append(str(block.get("text", ""))) + else: + non_text_blocks.append(block) + else: + non_text_blocks.append(block) + return "".join(text_parts), non_text_blocks + + # Fallback for any other shape. + return str(content), [] - def _messages_to_prompt_format( - self, messages: List[List["langchain_schema.BaseMessage"]] - ) -> List[Dict[str, str]]: + def _messages_to_prompt_format(self, messages: List[List["langchain_schema.BaseMessage"]]) -> List[Dict[str, str]]: """Convert LangChain messages to Openlayer prompt format using unified conversion.""" prompt = [] @@ -410,9 +473,17 @@ def _extract_model_info( invocation_params = invocation_params or {} metadata = metadata or {} - provider = invocation_params.get("_type") - if provider in LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP: - provider = LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP[provider] + # Provider resolution order: + # 1. metadata["ls_provider"] (standardized by LangChain v1) + # 2. invocation_params["_type"] via the legacy _type map + # 3. LiteLLM model prefix (handled below) + ls_provider = metadata.get("ls_provider") + if ls_provider: + provider = LS_PROVIDER_TO_OPENLAYER_MAP.get(ls_provider, str(ls_provider).replace("_", " ").title()) + else: + provider = invocation_params.get("_type") + if provider in LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP: + provider = LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP[provider] model = ( invocation_params.get("model_name") @@ -431,9 +502,7 @@ def _extract_model_info( model = model_name # Clean invocation params (remove internal LangChain params) - clean_params = { - k: v for k, v in invocation_params.items() if not k.startswith("_") - } + clean_params = {k: v for k, v in invocation_params.items() if not k.startswith("_")} return { "provider": provider, @@ -441,18 +510,46 @@ def _extract_model_info( "model_parameters": clean_params, } - def _extract_token_info( - self, response: "langchain_schema.LLMResult" - ) -> Dict[str, Any]: - """Extract token information generically from LLM response.""" - llm_output = response.llm_output or {} + def _extract_token_info(self, response: "langchain_schema.LLMResult") -> Dict[str, Any]: + """Extract token information generically from LLM response. - # Try standard token_usage location first - token_usage = ( - llm_output.get("token_usage") or llm_output.get("estimatedTokens") or {} - ) + In LangChain v1, ``AIMessage.usage_metadata`` is the guaranteed, + standardized token source, so it is read FIRST. The provider-specific + ``llm_output`` / ``generation_info`` paths (Ollama ``prompt_eval_count``, + Google ``usage_metadata`` in generation_info) remain as fallbacks so + nothing regresses. - # Fallback to generation info for providers like Ollama/Google + When ``usage_metadata`` carries ``input_token_details`` / + ``output_token_details`` (e.g. ``cache_creation``, ``cache_read``, + ``reasoning``, ``audio``), those are surfaced under a ``token_details`` + key so cost is accurate for prompt-caching / reasoning models. + """ + token_usage: Dict[str, Any] = {} + token_details: Dict[str, Any] = {} + + # 1. usage_metadata on the message (standardized, v1-first). + if response.generations: + gen = response.generations[0][0] + usage = getattr(getattr(gen, "message", None), "usage_metadata", None) + if usage: + token_usage = { + "prompt_tokens": usage.get("input_tokens", 0), + "completion_tokens": usage.get("output_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + } + input_details = usage.get("input_token_details") + if input_details: + token_details["input_token_details"] = dict(input_details) + output_details = usage.get("output_token_details") + if output_details: + token_details["output_token_details"] = dict(output_details) + + # 2. Fall back to llm_output for providers that surface it there. + if not token_usage: + llm_output = response.llm_output or {} + token_usage = llm_output.get("token_usage") or llm_output.get("estimatedTokens") or {} + + # 3. Fall back to generation_info for providers like Ollama/Google. if not token_usage and response.generations: gen = response.generations[0][0] generation_info = gen.generation_info or {} @@ -474,30 +571,56 @@ def _extract_token_info( "completion_tokens": usage.get("candidates_token_count", 0), "total_tokens": usage.get("total_token_count", 0), } - # AWS Bedrock / newer LangChain style - usage_metadata on the message - elif hasattr(gen, "message") and hasattr(gen.message, "usage_metadata"): - usage = gen.message.usage_metadata - if usage: - token_usage = { - "prompt_tokens": usage.get("input_tokens", 0), - "completion_tokens": usage.get("output_tokens", 0), - "total_tokens": usage.get("total_tokens", 0), - } - return { + result = { "prompt_tokens": token_usage.get("prompt_tokens", 0), "completion_tokens": token_usage.get("completion_tokens", 0), "tokens": token_usage.get("total_tokens", 0), } + if token_details: + result["token_details"] = token_details + return result def _extract_output(self, response: "langchain_schema.LLMResult") -> str: - """Extract output text from LLM response.""" + """Extract output text from LLM response. + + When a generation carries no text (e.g. an agent turn whose model + response is *only* tool calls, as in LangGraph / ``create_agent``), + fall back to serializing the message's tool calls so the step output + is not empty. + """ output = "" for generations in response.generations: for generation in generations: - output += generation.text.replace("\n", " ") + text = generation.text or "" + if text: + output += text.replace("\n", " ") + else: + output += self._tool_calls_to_str(generation) return output + def _tool_calls_to_str(self, generation: Any) -> str: + """Serialize tool calls from a generation's message, if any. + + Done defensively (getattr) so it never raises on generations or + messages that lack a ``message`` / ``tool_calls`` attribute, and works + whether or not langchain uses the v1 message shape. + """ + message = getattr(generation, "message", None) + if message is None: + return "" + + tool_calls = getattr(message, "tool_calls", None) + if not tool_calls: + return "" + + import json + + try: + return json.dumps(tool_calls, default=str) + except (TypeError, ValueError): + return str(tool_calls) + def _safe_parse_json(self, input_str: str) -> Any: """Safely parse JSON string, returning the string if parsing fails.""" try: @@ -523,9 +646,7 @@ def _handle_llm_start( ) -> Any: """Common logic for LLM start.""" invocation_params = kwargs.get("invocation_params", {}) - model_info = self._extract_model_info( - serialized, invocation_params, metadata or {} - ) + model_info = self._extract_model_info(serialized, invocation_params, metadata or {}) step_name = f"{model_info['provider'] or 'LLM'} Chat Completion" prompt = [{"role": "user", "content": text} for text in prompts] @@ -539,6 +660,7 @@ def _handle_llm_start( metadata={"tags": tags} if tags else None, **model_info, ) + self._apply_thread_id_session(run_id, metadata) def _handle_chat_model_start( self, @@ -554,9 +676,7 @@ def _handle_chat_model_start( ) -> Any: """Common logic for chat model start.""" invocation_params = kwargs.get("invocation_params", {}) - model_info = self._extract_model_info( - serialized, invocation_params, metadata or {} - ) + model_info = self._extract_model_info(serialized, invocation_params, metadata or {}) # Always use provider-based name for chat completions (e.g. "Google Chat Completion") # rather than the run_name from the caller (e.g. "Language Model") which is generic. @@ -572,6 +692,7 @@ def _handle_chat_model_start( metadata={"tags": tags} if tags else None, **model_info, ) + self._apply_thread_id_session(run_id, metadata) def _handle_llm_end( self, @@ -591,13 +712,16 @@ def _handle_llm_end( # Only extract token info if it hasn't been set during streaming step = self.steps[run_id] token_info = {} - if not ( - hasattr(step, "prompt_tokens") - and step.prompt_tokens is not None - and step.prompt_tokens > 0 - ): + if not (hasattr(step, "prompt_tokens") and step.prompt_tokens is not None and step.prompt_tokens > 0): token_info = self._extract_token_info(response) + # ChatCompletionStep has no dedicated field for fine-grained token + # details (cache_read/creation, reasoning, audio...), so surface them + # under step metadata where cost-relevant breakdowns can be read. + token_details = token_info.pop("token_details", None) + if token_details: + step.metadata = {**step.metadata, "token_details": token_details} + self._end_step( run_id=run_id, parent_run_id=parent_run_id, @@ -632,10 +756,22 @@ def _handle_chain_start( # Extract chain name from serialized data or use provided name # Handle case where serialized can be None serialized = serialized or {} + metadata = metadata or {} + + # Resolve the chain's display name. Prefer the runnable's own ``name`` + # (e.g. a node's name like "callAgent", or "RunnableSequence"/"Prompt" + # for nested LCEL runs), then fall back to LangGraph's injected + # ``metadata["langgraph_node"]`` only when no name is available, then the + # serialized id. This mirrors the TS handler's ``name ?? langgraph_node + # ?? id`` precedence. ``langgraph_node`` is inherited by *every* run + # nested inside a node, so preferring it over ``name`` would relabel all + # of a node's internal LCEL runs (RunnableSequence/Prompt/...) — and even + # nested sub-graphs — with the parent node's name. LangGraph already sets + # ``name`` to the node name at node boundaries, so this keeps nodes + # identifiable while preserving the inner structure's real names. + langgraph_node = metadata.get("langgraph_node") chain_name = ( - name - or (serialized.get("id", [])[-1] if serialized.get("id") else None) - or "Chain" + name or langgraph_node or (serialized.get("id", [])[-1] if serialized.get("id") else None) or "Chain" ) # Skip chains marked as hidden (e.g., internal LangGraph chains) @@ -656,10 +792,11 @@ def _handle_chain_start( metadata={ "tags": tags, "serialized": serialized, - **(metadata or {}), + **metadata, **kwargs, }, ) + self._apply_thread_id_session(run_id, metadata) def _handle_chain_end( self, @@ -695,9 +832,7 @@ def _handle_chain_end( step = self.steps[run_id] if step.name == "LangGraph" and outputs.get("messages"): if isinstance(outputs.get("messages"), list): - if isinstance( - outputs.get("messages")[-1], langchain_schema.BaseMessage - ): + if isinstance(outputs.get("messages")[-1], langchain_schema.BaseMessage): outputs = outputs.get("messages")[-1].content self._end_step( @@ -734,10 +869,7 @@ def _handle_tool_start( # Handle case where serialized can be None serialized = serialized or {} tool_name = ( - name - or serialized.get("name") - or (serialized.get("id", [])[-1] if serialized.get("id") else None) - or "Tool" + name or serialized.get("name") or (serialized.get("id", [])[-1] if serialized.get("id") else None) or "Tool" ) # Parse input - prefer structured inputs over string @@ -890,9 +1022,7 @@ def _handle_retriever_start( """Common logic for retriever start.""" # Handle case where serialized can be None serialized = serialized or {} - retriever_name = ( - serialized.get("id", [])[-1] if serialized.get("id") else "Retriever" - ) + retriever_name = serialized.get("id", [])[-1] if serialized.get("id") else "Retriever" self._start_step( run_id=run_id, @@ -929,6 +1059,11 @@ def _handle_retriever_end( else: doc_contents.append(str(doc)) + # Populate the trace-level `context` so the reserved column is set. + # Prefer an active (external) trace context; otherwise fall back to the + # standalone trace this handler owns for the run. Without the fallback, + # synchronous RAG pipelines (where the retriever is the root step and no + # external @trace context exists) would silently drop the context. current_trace = self._find_trace(run_id) if current_trace: current_trace.update_metadata(context=doc_contents) @@ -960,11 +1095,7 @@ def _handle_llm_new_token(self, token: str, **kwargs: Any) -> Any: """Common logic for LLM new token.""" # Safely check for chunk and usage_metadata chunk = kwargs.get("chunk") - if ( - chunk - and hasattr(chunk, "message") - and hasattr(chunk.message, "usage_metadata") - ): + if chunk and hasattr(chunk, "message") and hasattr(chunk.message, "usage_metadata"): usage = chunk.message.usage_metadata # Only proceed if usage is not None @@ -985,6 +1116,42 @@ def _handle_llm_new_token(self, token: str, **kwargs: Any) -> Any: step.log(**token_info) return + def _apply_thread_id_session(self, run_id: UUID, metadata: Optional[Dict[str, Any]]) -> None: + """Map LangGraph ``metadata["thread_id"]`` to the trace's session_id. + + Any LangGraph app with a checkpointer injects ``thread_id`` into + callback metadata. When ``map_thread_id_to_session`` is enabled (the + default) and the user has not explicitly set a session via + ``UserSessionContext``, the thread_id is written to the trace's + ``session_id`` metadata so it is promoted to the ``session_id`` column + by ``post_process_trace``. + + An explicitly-set session_id is never clobbered: ``post_process_trace`` + re-applies the ``UserSessionContext`` value after merging trace + metadata, and this method additionally skips writing when a context + session_id is present. + """ + if not self._map_thread_id_to_session or not metadata: + return + + thread_id = metadata.get("thread_id") + if not thread_id: + return + + # Do not clobber an explicitly-provided session_id. + if tracer.UserSessionContext.get_session_id() is not None: + return + + target_trace = self._find_trace(run_id) + if target_trace is None: + return + + existing = (target_trace.metadata or {}).get("session_id") + if existing: + return + + target_trace.update_metadata(session_id=str(thread_id)) + class OpenlayerHandler(OpenlayerHandlerMixin, BaseCallbackHandlerClass): # type: ignore[misc] """LangChain callback handler that logs to Openlayer.""" @@ -997,9 +1164,8 @@ def __init__( ignore_retriever=False, ignore_agent=False, inference_id: Optional[Any] = None, - metadata_transformer: Optional[ - Callable[[Dict[str, Any]], Dict[str, Any]] - ] = None, + metadata_transformer: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + map_thread_id_to_session: bool = True, **kwargs: Any, ) -> None: # Add both inference_id and metadata_transformer to kwargs so they get passed to mixin @@ -1007,6 +1173,7 @@ def __init__( kwargs["inference_id"] = inference_id if metadata_transformer is not None: kwargs["metadata_transformer"] = metadata_transformer + kwargs["map_thread_id_to_session"] = map_thread_id_to_session super().__init__(**kwargs) # Store the ignore flags as instance variables self._ignore_llm = ignore_llm @@ -1035,9 +1202,7 @@ def ignore_retriever(self) -> bool: def ignore_agent(self) -> bool: return self._ignore_agent - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> Any: + def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any: """Run when LLM starts running.""" if self.ignore_llm: return @@ -1060,9 +1225,7 @@ def on_llm_end(self, response: "langchain_schema.LLMResult", **kwargs: Any) -> A return return self._handle_llm_end(response, **kwargs) - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: + def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any: """Run when LLM errors.""" if self.ignore_llm: return @@ -1072,9 +1235,7 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> Any: """Run on new LLM token. Only available when streaming is enabled.""" return self._handle_llm_new_token(token, **kwargs) - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> Any: + def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any: """Run when chain starts running.""" if self.ignore_chain: return @@ -1086,17 +1247,13 @@ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any: return return self._handle_chain_end(outputs, **kwargs) - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: + def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any: """Run when chain errors.""" if self.ignore_chain: return return self._handle_chain_error(error, **kwargs) - def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> Any: + def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any: """Run when tool starts running.""" if self.ignore_retriever: return @@ -1108,29 +1265,41 @@ def on_tool_end(self, output: str, **kwargs: Any) -> Any: return return self._handle_tool_end(output, **kwargs) - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: + def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any: """Run when tool errors.""" if self.ignore_retriever: return return self._handle_tool_error(error, **kwargs) + def on_retriever_start(self, serialized: Dict[str, Any], query: str, **kwargs: Any) -> Any: + """Run when retriever starts running.""" + if self.ignore_retriever: + return + return self._handle_retriever_start(serialized, query, **kwargs) + + def on_retriever_end(self, documents: List[Any], **kwargs: Any) -> Any: + """Run when retriever ends running.""" + if self.ignore_retriever: + return + return self._handle_retriever_end(documents, **kwargs) + + def on_retriever_error(self, error: Exception, **kwargs: Any) -> Any: + """Run when retriever errors.""" + if self.ignore_retriever: + return + return self._handle_retriever_error(error, **kwargs) + def on_text(self, text: str, **kwargs: Any) -> Any: """Run on arbitrary text.""" pass - def on_agent_action( - self, action: "langchain_schema.AgentAction", **kwargs: Any - ) -> Any: + def on_agent_action(self, action: "langchain_schema.AgentAction", **kwargs: Any) -> Any: """Run on agent action.""" if self.ignore_agent: return return self._handle_agent_action(action, **kwargs) - def on_agent_finish( - self, finish: "langchain_schema.AgentFinish", **kwargs: Any - ) -> Any: + def on_agent_finish(self, finish: "langchain_schema.AgentFinish", **kwargs: Any) -> Any: """Run on agent end.""" if self.ignore_agent: return @@ -1148,9 +1317,8 @@ def __init__( ignore_retriever=False, ignore_agent=False, inference_id: Optional[Any] = None, - metadata_transformer: Optional[ - Callable[[Dict[str, Any]], Dict[str, Any]] - ] = None, + metadata_transformer: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + map_thread_id_to_session: bool = True, **kwargs: Any, ) -> None: # Add both inference_id and metadata_transformer to kwargs so they get passed to mixin @@ -1158,6 +1326,7 @@ def __init__( kwargs["inference_id"] = inference_id if metadata_transformer is not None: kwargs["metadata_transformer"] = metadata_transformer + kwargs["map_thread_id_to_session"] = map_thread_id_to_session super().__init__(**kwargs) # Store the ignore flags as instance variables self._ignore_llm = ignore_llm @@ -1313,9 +1482,7 @@ def _end_step( trace = self._traces_by_root.pop(run_id) if tracer._resolve("background_publish_enabled"): ctx = contextvars.copy_context() - tracer._get_background_executor().submit( - ctx.run, self._process_and_upload_async_trace, trace - ) + tracer._get_background_executor().submit(ctx.run, self._process_and_upload_async_trace, trace) else: self._process_and_upload_async_trace(trace) @@ -1378,9 +1545,7 @@ def _process_and_upload_async_trace(self, trace: traces.Trace) -> None: tracer.logger.error("Could not stream data to Openlayer %s", err) # All callback methods remain the same - just delegate to mixin - async def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> Any: + async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any: if self.ignore_llm: return return self._handle_llm_start(serialized, prompts, **kwargs) @@ -1395,16 +1560,12 @@ async def on_chat_model_start( return return self._handle_chat_model_start(serialized, messages, **kwargs) - async def on_llm_end( - self, response: "langchain_schema.LLMResult", **kwargs: Any - ) -> Any: + async def on_llm_end(self, response: "langchain_schema.LLMResult", **kwargs: Any) -> Any: if self.ignore_llm: return return self._handle_llm_end(response, **kwargs) - async def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: + async def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any: if self.ignore_llm: return return self._handle_llm_error(error, **kwargs) @@ -1412,9 +1573,7 @@ async def on_llm_error( async def on_llm_new_token(self, token: str, **kwargs: Any) -> Any: return self._handle_llm_new_token(token, **kwargs) - async def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> Any: + async def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any: if self.ignore_chain: return return self._handle_chain_start(serialized, inputs, **kwargs) @@ -1424,16 +1583,12 @@ async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any: return return self._handle_chain_end(outputs, **kwargs) - async def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: + async def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any: if self.ignore_chain: return return self._handle_chain_error(error, **kwargs) - async def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> Any: + async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any: if self.ignore_retriever: # Note: tool events use ignore_retriever flag return return self._handle_tool_start(serialized, input_str, **kwargs) @@ -1443,9 +1598,7 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any: return return self._handle_tool_end(output, **kwargs) - async def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: + async def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any: if self.ignore_retriever: return return self._handle_tool_error(error, **kwargs) @@ -1453,23 +1606,17 @@ async def on_tool_error( async def on_text(self, text: str, **kwargs: Any) -> Any: pass - async def on_agent_action( - self, action: "langchain_schema.AgentAction", **kwargs: Any - ) -> Any: + async def on_agent_action(self, action: "langchain_schema.AgentAction", **kwargs: Any) -> Any: if self.ignore_agent: return return self._handle_agent_action(action, **kwargs) - async def on_agent_finish( - self, finish: "langchain_schema.AgentFinish", **kwargs: Any - ) -> Any: + async def on_agent_finish(self, finish: "langchain_schema.AgentFinish", **kwargs: Any) -> Any: if self.ignore_agent: return return self._handle_agent_finish(finish, **kwargs) - async def on_retriever_start( - self, serialized: Dict[str, Any], query: str, **kwargs: Any - ) -> Any: + async def on_retriever_start(self, serialized: Dict[str, Any], query: str, **kwargs: Any) -> Any: if self.ignore_retriever: return return self._handle_retriever_start(serialized, query, **kwargs) diff --git a/tests/lib/__init__.py b/tests/lib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/lib/integrations/__init__.py b/tests/lib/integrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/lib/integrations/test_langchain_callback.py b/tests/lib/integrations/test_langchain_callback.py new file mode 100644 index 00000000..765c444a --- /dev/null +++ b/tests/lib/integrations/test_langchain_callback.py @@ -0,0 +1,585 @@ +"""Tests for the Openlayer LangChain callback handler. + +Regression coverage for OPEN-11315: + +1. The synchronous ``OpenlayerHandler`` must wire up the retriever callbacks + (``on_retriever_start`` / ``on_retriever_end`` / ``on_retriever_error``) so + synchronous RAG pipelines produce a RETRIEVER step and the auto-populated + ``context`` column on the trace. + +2. A chat-completion turn whose response contains ONLY tool calls (every agent + iteration in LangGraph / ``create_agent``) must still produce a non-empty + step output: ``_extract_output`` falls back to serializing the message's + tool calls, and ``_message_to_dict`` preserves ``tool_calls`` on the + assistant message. +""" + +# The handler under test lives in ``src/openlayer/lib`` (excluded from pyright +# via ``[tool.pyright] ignore``) and depends on LangChain, an OPTIONAL +# integration that is not installed in the lint/type-check environment. So its +# imports don't resolve there and its dynamic types surface as ``Unknown``. Use +# basic mode and disable the missing-import diagnostic for this test module; +# the runtime is guarded by ``pytest.importorskip`` below. +# pyright: basic, reportMissingImports=false + +from __future__ import annotations + +import uuid + +import pytest + +pytest.importorskip("langchain_core") + +from langchain_core.outputs import LLMResult, ChatGeneration +from langchain_core.messages import AIMessage, ToolMessage, HumanMessage +from langchain_core.documents import Document + +from openlayer.lib.tracing import enums, steps +from openlayer.lib.integrations.langchain_callback import OpenlayerHandler + + +@pytest.fixture(autouse=True) +def _disable_publish(monkeypatch: pytest.MonkeyPatch) -> None: + """Keep all tracer publish paths off during the test (fully offline).""" + monkeypatch.setenv("OPENLAYER_DISABLE_PUBLISH", "true") + monkeypatch.setenv("OPENLAYER_API_KEY", "fake") + + from openlayer.lib.tracing import tracer as _tracer + + monkeypatch.setattr(_tracer, "_publish", False, raising=False) + + +def _tool_call(name: str, args: dict, call_id: str) -> dict: + """Build a langchain_core tool-call dict.""" + return {"name": name, "args": args, "id": call_id, "type": "tool_call"} + + +# --------------------------------------------------------------------------- # +# Fix 2 (b): _message_to_dict preserves tool_calls +# --------------------------------------------------------------------------- # +class TestMessageToDict: + def test_preserves_tool_calls_on_ai_message(self) -> None: + handler = OpenlayerHandler() + tool_calls = [_tool_call("search", {"query": "openlayer"}, "call_1")] + message = AIMessage(content="", tool_calls=tool_calls) + + result = handler._message_to_dict(message) + + assert result["role"] == "assistant" + assert "tool_calls" in result, "tool_calls should be preserved" + assert result["tool_calls"] == tool_calls + + def test_human_message_without_tool_calls_is_backwards_compatible(self) -> None: + handler = OpenlayerHandler() + message = HumanMessage(content="hello there") + + result = handler._message_to_dict(message) + + assert result == {"role": "user", "content": "hello there"} + assert "tool_calls" not in result + + def test_ai_message_without_tool_calls_omits_key(self) -> None: + handler = OpenlayerHandler() + message = AIMessage(content="plain answer") + + result = handler._message_to_dict(message) + + assert result == {"role": "assistant", "content": "plain answer"} + assert "tool_calls" not in result + + def test_tool_message_does_not_raise(self) -> None: + handler = OpenlayerHandler() + message = ToolMessage(content="42", tool_call_id="call_1") + + result = handler._message_to_dict(message) + + assert result["content"] == "42" + assert "tool_calls" not in result + + +# --------------------------------------------------------------------------- # +# Fix 2 (a): _extract_output falls back to tool calls +# --------------------------------------------------------------------------- # +class TestExtractOutput: + def test_text_generation_returns_text(self) -> None: + handler = OpenlayerHandler() + gen = ChatGeneration(message=AIMessage(content="Hello world")) + response = LLMResult(generations=[[gen]]) + + assert handler._extract_output(response) == "Hello world" + + def test_tool_only_generation_returns_tool_calls(self) -> None: + handler = OpenlayerHandler() + tool_calls = [_tool_call("get_weather", {"city": "Rio"}, "call_42")] + message = AIMessage(content="", tool_calls=tool_calls) + gen = ChatGeneration(message=message) + response = LLMResult(generations=[[gen]]) + + output = handler._extract_output(response) + + assert output, "tool-only output must not be empty" + assert "get_weather" in output + + def test_tool_only_generation_via_callbacks(self) -> None: + """Drive on_chat_model_start / on_llm_end and assert step output is set.""" + handler = OpenlayerHandler() + run_id = uuid.uuid4() + + handler.on_chat_model_start( + serialized={"name": "gpt-4o"}, + messages=[[HumanMessage(content="What is the weather in Rio?")]], + run_id=run_id, + invocation_params={"model_name": "gpt-4o", "_type": "openai-chat"}, + ) + + # Capture the standalone trace before the root step is ended/popped. + trace = handler._traces_by_root[run_id] + step = handler.steps[run_id] + assert isinstance(step, steps.ChatCompletionStep) + + tool_calls = [_tool_call("get_weather", {"city": "Rio"}, "call_42")] + message = AIMessage(content="", tool_calls=tool_calls) + response = LLMResult(generations=[[ChatGeneration(message=message)]]) + + handler.on_llm_end(response, run_id=run_id) + + assert step.step_type == enums.StepType.CHAT_COMPLETION + assert step.output, "tool-only chat completion output must not be empty" + assert "get_weather" in step.output + assert trace.steps[0] is step + + +# --------------------------------------------------------------------------- # +# Fix 1: sync handler wires retriever callbacks +# --------------------------------------------------------------------------- # +class TestSyncRetrieverCallbacks: + def test_handler_exposes_retriever_callbacks(self) -> None: + handler = OpenlayerHandler() + assert hasattr(handler, "on_retriever_start") + assert hasattr(handler, "on_retriever_end") + assert hasattr(handler, "on_retriever_error") + + def test_sync_retriever_run_produces_step_and_context(self) -> None: + handler = OpenlayerHandler() + run_id = uuid.uuid4() + + handler.on_retriever_start( + serialized={"id": ["langchain", "retrievers", "VectorStoreRetriever"]}, + query="what is openlayer?", + run_id=run_id, + ) + + trace = handler._traces_by_root[run_id] + step = handler.steps[run_id] + assert isinstance(step, steps.RetrieverStep) + assert step.step_type == enums.StepType.RETRIEVER + assert step.inputs == {"query": "what is openlayer?"} + + documents = [ + Document(page_content="Openlayer is an evaluation platform."), + Document(page_content="It supports LangChain tracing."), + ] + handler.on_retriever_end(documents, run_id=run_id) + + # The retriever step captured the documents... + assert step.documents == [ + "Openlayer is an evaluation platform.", + "It supports LangChain tracing.", + ] + # ...and the trace gained the auto-populated `context` metadata. + assert trace.metadata is not None + assert trace.metadata.get("context") == [ + "Openlayer is an evaluation platform.", + "It supports LangChain tracing.", + ] + + def test_sync_retriever_respects_ignore_flag(self) -> None: + handler = OpenlayerHandler(ignore_retriever=True) + run_id = uuid.uuid4() + + handler.on_retriever_start( + serialized={"id": ["VectorStoreRetriever"]}, + query="ignored", + run_id=run_id, + ) + + assert run_id not in handler.steps + + def test_sync_retriever_error(self) -> None: + handler = OpenlayerHandler() + run_id = uuid.uuid4() + + handler.on_retriever_start( + serialized={"id": ["VectorStoreRetriever"]}, + query="boom", + run_id=run_id, + ) + step = handler.steps[run_id] + + handler.on_retriever_error(ValueError("retrieval failed"), run_id=run_id) + + assert step.metadata.get("error") == "retrieval failed" + assert run_id not in handler.steps + + +# --------------------------------------------------------------------------- # +# OPEN-11315 (medium/low items) +# --------------------------------------------------------------------------- # + + +def _ai_message_with_usage( + *, + input_tokens: int, + output_tokens: int, + total_tokens: int, + input_token_details: dict | None = None, + output_token_details: dict | None = None, + content: str = "done", +) -> AIMessage: + """Build an AIMessage carrying standardized usage_metadata (v1 shape).""" + usage: dict = { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": total_tokens, + } + if input_token_details is not None: + usage["input_token_details"] = input_token_details + if output_token_details is not None: + usage["output_token_details"] = output_token_details + return AIMessage(content=content, usage_metadata=usage) + + +# --------------------------------------------------------------------------- # +# Item 1: usage_metadata-first token extraction + token details +# --------------------------------------------------------------------------- # +class TestTokenExtraction: + def test_usage_metadata_is_read_first(self) -> None: + """usage_metadata on the message wins over llm_output token_usage.""" + handler = OpenlayerHandler() + message = _ai_message_with_usage(input_tokens=11, output_tokens=7, total_tokens=18) + response = LLMResult( + generations=[[ChatGeneration(message=message)]], + # Divergent llm_output that must NOT be used when usage_metadata exists. + llm_output={ + "token_usage": { + "prompt_tokens": 999, + "completion_tokens": 999, + "total_tokens": 999, + } + }, + ) + + info = handler._extract_token_info(response) + + assert info["prompt_tokens"] == 11 + assert info["completion_tokens"] == 7 + assert info["tokens"] == 18 + + def test_captures_input_and_output_token_details(self) -> None: + handler = OpenlayerHandler() + message = _ai_message_with_usage( + input_tokens=100, + output_tokens=40, + total_tokens=140, + input_token_details={"cache_read": 80, "cache_creation": 10}, + output_token_details={"reasoning": 25, "audio": 0}, + ) + response = LLMResult(generations=[[ChatGeneration(message=message)]]) + + info = handler._extract_token_info(response) + + assert info["tokens"] == 140 + details = info["token_details"] + assert details["input_token_details"] == { + "cache_read": 80, + "cache_creation": 10, + } + assert details["output_token_details"] == {"reasoning": 25, "audio": 0} + + def test_no_token_details_when_absent(self) -> None: + handler = OpenlayerHandler() + message = _ai_message_with_usage(input_tokens=1, output_tokens=1, total_tokens=2) + response = LLMResult(generations=[[ChatGeneration(message=message)]]) + + info = handler._extract_token_info(response) + + assert "token_details" not in info + + def test_falls_back_to_llm_output_token_usage(self) -> None: + """When no usage_metadata, the legacy llm_output path still works.""" + handler = OpenlayerHandler() + gen = ChatGeneration(message=AIMessage(content="hi")) + response = LLMResult( + generations=[[gen]], + llm_output={ + "token_usage": { + "prompt_tokens": 5, + "completion_tokens": 3, + "total_tokens": 8, + } + }, + ) + + info = handler._extract_token_info(response) + + assert info == {"prompt_tokens": 5, "completion_tokens": 3, "tokens": 8} + + def test_falls_back_to_ollama_generation_info(self) -> None: + handler = OpenlayerHandler() + gen = ChatGeneration( + message=AIMessage(content="hi"), + generation_info={"prompt_eval_count": 12, "eval_count": 4}, + ) + response = LLMResult(generations=[[gen]]) + + info = handler._extract_token_info(response) + + assert info == {"prompt_tokens": 12, "completion_tokens": 4, "tokens": 16} + + def test_falls_back_to_google_generation_info(self) -> None: + handler = OpenlayerHandler() + gen = ChatGeneration( + message=AIMessage(content="hi"), + generation_info={ + "usage_metadata": { + "prompt_token_count": 9, + "candidates_token_count": 6, + "total_token_count": 15, + } + }, + ) + response = LLMResult(generations=[[gen]]) + + info = handler._extract_token_info(response) + + assert info == {"prompt_tokens": 9, "completion_tokens": 6, "tokens": 15} + + def test_token_details_surface_on_step_metadata_via_callbacks(self) -> None: + handler = OpenlayerHandler() + run_id = uuid.uuid4() + + handler.on_chat_model_start( + serialized={"name": "claude"}, + messages=[[HumanMessage(content="hi")]], + run_id=run_id, + invocation_params={"model_name": "claude", "_type": "openai-chat"}, + metadata={"ls_provider": "anthropic"}, + ) + step = handler.steps[run_id] + assert isinstance(step, steps.ChatCompletionStep) + + message = _ai_message_with_usage( + input_tokens=100, + output_tokens=40, + total_tokens=140, + input_token_details={"cache_read": 80}, + ) + response = LLMResult(generations=[[ChatGeneration(message=message)]]) + + handler.on_llm_end(response, run_id=run_id) + + assert step.prompt_tokens == 100 + assert step.completion_tokens == 40 + assert step.tokens == 140 + assert step.metadata["token_details"]["input_token_details"] == {"cache_read": 80} + + +# --------------------------------------------------------------------------- # +# Item 2: provider detection via ls_provider +# --------------------------------------------------------------------------- # +class TestLsProviderDetection: + def test_ls_provider_is_primary_source(self) -> None: + handler = OpenlayerHandler() + info = handler._extract_model_info( + serialized={}, + # _type would map to OpenAI; ls_provider must win. + invocation_params={"_type": "openai-chat", "model_name": "claude-x"}, + metadata={"ls_provider": "anthropic"}, + ) + assert info["provider"] == "Anthropic" + assert info["model"] == "claude-x" + + def test_ls_provider_title_cases_unknown_values(self) -> None: + handler = OpenlayerHandler() + info = handler._extract_model_info( + serialized={}, + invocation_params={"model": "some-model"}, + metadata={"ls_provider": "some_new_provider"}, + ) + assert info["provider"] == "Some New Provider" + + def test_falls_back_to_type_map_without_ls_provider(self) -> None: + handler = OpenlayerHandler() + info = handler._extract_model_info( + serialized={}, + invocation_params={"_type": "chat-ollama", "model": "llama3"}, + metadata={}, + ) + assert info["provider"] == "Ollama" + + def test_litellm_prefix_overrides_ls_provider(self) -> None: + """A LiteLLM proxy reports ls_provider=openai but routes elsewhere.""" + handler = OpenlayerHandler() + info = handler._extract_model_info( + serialized={}, + invocation_params={"model": "gemini/gemini-2.5-flash"}, + metadata={"ls_provider": "openai"}, + ) + assert info["provider"] == "Google" + assert info["model"] == "gemini-2.5-flash" + + def test_chat_model_step_named_from_ls_provider(self) -> None: + handler = OpenlayerHandler() + run_id = uuid.uuid4() + handler.on_chat_model_start( + serialized={"name": "model"}, + messages=[[HumanMessage(content="hi")]], + run_id=run_id, + invocation_params={"_type": "openai-chat"}, + metadata={"ls_provider": "groq"}, + ) + step = handler.steps[run_id] + assert isinstance(step, steps.ChatCompletionStep) + assert step.name == "Groq Chat Completion" + assert step.provider == "Groq" + + +# --------------------------------------------------------------------------- # +# Item 3: LangGraph metadata (langgraph_node + thread_id -> session_id) +# --------------------------------------------------------------------------- # +class TestLangGraphMetadata: + def test_langgraph_node_names_chain_step_when_no_explicit_name(self) -> None: + # When the runnable carries no name, fall back to langgraph_node so graph + # nodes stay identifiable. + handler = OpenlayerHandler() + run_id = uuid.uuid4() + handler.on_chain_start( + serialized={"id": ["langgraph", "utils", "RunnableCallable"]}, + inputs={"messages": []}, + run_id=run_id, + metadata={"langgraph_node": "agent"}, + ) + step = handler.steps[run_id] + assert step.name == "agent" + + def test_explicit_name_wins_over_langgraph_node(self) -> None: + # An explicit runnable name takes precedence over langgraph_node (matches + # the TS handler's `name ?? langgraph_node ?? id`). langgraph_node is + # inherited by every run nested inside a node, so preferring it would + # relabel a node's internal LCEL runs with the node name. + handler = OpenlayerHandler() + run_id = uuid.uuid4() + handler.on_chain_start( + serialized={"id": ["langchain_core", "runnables", "RunnableSequence"]}, + inputs={"messages": []}, + run_id=run_id, + name="RunnableSequence", + metadata={"langgraph_node": "agent"}, + ) + step = handler.steps[run_id] + assert step.name == "RunnableSequence" + + def test_chain_name_unchanged_without_langgraph_node(self) -> None: + handler = OpenlayerHandler() + run_id = uuid.uuid4() + handler.on_chain_start( + serialized={"id": ["langchain", "chains", "LLMChain"]}, + inputs={}, + run_id=run_id, + ) + step = handler.steps[run_id] + assert step.name == "LLMChain" + + def test_thread_id_maps_to_session_id(self) -> None: + handler = OpenlayerHandler() + run_id = uuid.uuid4() + handler.on_chain_start( + serialized={"id": ["graph"]}, + inputs={}, + run_id=run_id, + metadata={"langgraph_node": "agent", "thread_id": "thread-123"}, + ) + trace = handler._traces_by_root[run_id] + assert trace.metadata is not None + assert trace.metadata.get("session_id") == "thread-123" + + def test_thread_id_does_not_clobber_explicit_session(self) -> None: + from openlayer.lib.tracing.context import UserSessionContext + + UserSessionContext.set_session_id("explicit-session") + try: + handler = OpenlayerHandler() + run_id = uuid.uuid4() + handler.on_chain_start( + serialized={"id": ["graph"]}, + inputs={}, + run_id=run_id, + metadata={"thread_id": "thread-999"}, + ) + trace = handler._traces_by_root[run_id] + session_in_metadata = (trace.metadata or {}).get("session_id") + assert session_in_metadata != "thread-999" + finally: + UserSessionContext.clear_context() + + def test_thread_id_mapping_opt_out(self) -> None: + handler = OpenlayerHandler(map_thread_id_to_session=False) + run_id = uuid.uuid4() + handler.on_chain_start( + serialized={"id": ["graph"]}, + inputs={}, + run_id=run_id, + metadata={"thread_id": "thread-123"}, + ) + trace = handler._traces_by_root[run_id] + assert (trace.metadata or {}).get("session_id") is None + + +# --------------------------------------------------------------------------- # +# Item 4: v1 content blocks in _message_to_dict +# --------------------------------------------------------------------------- # +class TestV1ContentBlocks: + def test_plain_string_content_unchanged(self) -> None: + handler = OpenlayerHandler() + result = handler._message_to_dict(HumanMessage(content="hello")) + assert result == {"role": "user", "content": "hello"} + assert "content_blocks" not in result + + def test_list_of_text_blocks_joined(self) -> None: + handler = OpenlayerHandler() + message = AIMessage( + content=[ + {"type": "text", "text": "hello "}, + {"type": "text", "text": "world"}, + ] + ) + result = handler._message_to_dict(message) + assert result["content"] == "hello world" + assert "content_blocks" not in result + + def test_non_text_blocks_preserved_structurally(self) -> None: + handler = OpenlayerHandler() + reasoning_block = {"type": "reasoning", "reasoning": "thinking..."} + message = AIMessage(content=[{"type": "text", "text": "answer"}, reasoning_block]) + result = handler._message_to_dict(message) + assert result["content"] == "answer" + assert result["content_blocks"] == [reasoning_block] + + def test_list_of_plain_strings(self) -> None: + handler = OpenlayerHandler() + message = AIMessage(content=["foo", "bar"]) + result = handler._message_to_dict(message) + assert result["content"] == "foobar" + assert "content_blocks" not in result + + +# --------------------------------------------------------------------------- # +# Item 5: import simplification keeps HAVE_LANGCHAIN working +# --------------------------------------------------------------------------- # +class TestImportSimplification: + def test_have_langchain_true_and_schema_from_core(self) -> None: + from openlayer.lib.integrations import langchain_callback as lc + + assert lc.HAVE_LANGCHAIN is True + # The schema alias must resolve to langchain_core, not the legacy path. + assert lc.langchain_schema.__name__.startswith("langchain_core")