-
Notifications
You must be signed in to change notification settings - Fork 2
feat(closes OPEN-10341): add native async runner for testset batches #648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,9 +4,10 @@ | |
| import abc | ||
| import json | ||
| import time | ||
| import asyncio | ||
| import inspect | ||
| import argparse | ||
| from typing import Any, Dict, Tuple | ||
| from typing import Any, Dict, List, Tuple, Optional | ||
| from dataclasses import field, dataclass | ||
|
|
||
| import pandas as pd | ||
|
|
@@ -36,6 +37,12 @@ class OpenlayerModel(abc.ABC): | |
|
|
||
| It is more conventional to implement the `run` method. | ||
|
|
||
| ``run`` may be defined as either ``def run`` (called sequentially per row) | ||
| or ``async def run``. When ``run`` is async, ``run_batch_from_df`` will drive | ||
| rows concurrently with ``asyncio.gather``; pass ``max_workers > 1`` to enable | ||
| concurrent execution. Use async-native I/O (``httpx``, ``openai-async``, etc.) | ||
| inside an async ``run`` to actually benefit from concurrency. | ||
|
|
||
| Refer to Openlayer's templates for examples of how to implement this class. | ||
| """ | ||
|
|
||
|
|
@@ -59,6 +66,15 @@ def run_from_cli(self) -> None: | |
| required=False, | ||
| help="Custom arguments in format 'key1=value1,key2=value2'", | ||
| ) | ||
| parser.add_argument( | ||
| "--max-workers", | ||
| type=int, | ||
| default=None, | ||
| help=( | ||
| "Max concurrent rows when run() is async. " | ||
| "Defaults to 4 for async run, 1 for sync run." | ||
| ), | ||
| ) | ||
|
|
||
| # Parse the arguments | ||
| args = parser.parse_args() | ||
|
|
@@ -76,9 +92,12 @@ def run_from_cli(self) -> None: | |
| return self.batch( | ||
| dataset_path=args.dataset_path, | ||
| output_dir=args.output_dir, | ||
| max_workers=args.max_workers, | ||
| ) | ||
|
|
||
| def batch(self, dataset_path: str, output_dir: str) -> None: | ||
| def batch( | ||
| self, dataset_path: str, output_dir: str, max_workers: Optional[int] = None | ||
| ) -> None: | ||
| """Reads the dataset from a file and runs the model on it.""" | ||
| # Load the dataset into a pandas DataFrame | ||
| fmt = "csv" | ||
|
|
@@ -91,50 +110,125 @@ def batch(self, dataset_path: str, output_dir: str) -> None: | |
| raise ValueError(f"Unsupported dataset format: {dataset_path}") | ||
|
|
||
| # Call the model's run_batch method, passing in the DataFrame | ||
| output_df, config = self.run_batch_from_df(df) | ||
| output_df, config = self.run_batch_from_df(df, max_workers=max_workers) | ||
| self.write_output_to_directory(output_df, config, output_dir, fmt) | ||
|
|
||
| def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: | ||
| """Function that runs the model and returns the result.""" | ||
| # Ensure the 'output' column exists | ||
| if "output" not in df.columns: | ||
| df["output"] = None | ||
| def run_batch_from_df( | ||
| self, df: pd.DataFrame, max_workers: Optional[int] = None | ||
| ) -> Tuple[pd.DataFrame, dict]: | ||
| """Function that runs the model and returns the result. | ||
|
|
||
| # Get the signature of the 'run' method | ||
| If ``run`` is defined as ``async def run(...)``, rows are dispatched | ||
| concurrently with ``asyncio.gather`` gated by ``asyncio.Semaphore(max_workers)``. | ||
| ``max_workers`` defaults to 4 for an async ``run`` (writing `async def` | ||
| is the opt-in signal that interleaving is safe). For a synchronous | ||
| ``run``, rows are processed sequentially and ``max_workers`` must be 1. | ||
|
|
||
| A row's exception propagates and aborts the batch. For the async path, | ||
| ``asyncio.gather`` cancels in-flight siblings before re-raising. | ||
| """ | ||
| run_signature = inspect.signature(self.run) | ||
| valid_params = set(run_signature.parameters) | ||
| is_async = inspect.iscoroutinefunction(self.run) | ||
|
|
||
| if max_workers is None: | ||
| max_workers = 4 if is_async else 1 | ||
| elif max_workers < 1: | ||
| raise ValueError("max_workers must be >= 1") | ||
|
|
||
| if max_workers > 1 and not is_async: | ||
| raise ValueError( | ||
| "max_workers > 1 requires an async `run` method. " | ||
| "Define `run` as `async def run(self, ...)` to enable " | ||
| "concurrent execution." | ||
| ) | ||
|
|
||
| for col in ("output", "steps", "latency", "cost", "tokens", "context"): | ||
| if col not in df.columns: | ||
| df[col] = None | ||
|
|
||
| rows = [ | ||
| ( | ||
| idx, | ||
| {k: v for k, v in row.to_dict().items() if k in valid_params}, | ||
| ) | ||
| for idx, row in df.iterrows() | ||
| ] | ||
|
|
||
| if is_async: | ||
| try: | ||
| asyncio.get_running_loop() | ||
| except RuntimeError: | ||
| pass | ||
| else: | ||
| raise RuntimeError( | ||
| "run_batch_from_df was called from inside a running event " | ||
| "loop. Call `await self._run_rows_async(...)` directly " | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Error message points at an internal. The guidance to "Call |
||
| "from async code." | ||
| ) | ||
| results = asyncio.run(self._run_rows_async(rows, max_workers)) | ||
| else: | ||
| results = [ | ||
| (idx, self.run(**kwargs), tracer.get_current_trace()) | ||
| for idx, kwargs in rows | ||
| ] | ||
|
|
||
| for index, output, trace in results: | ||
| self._apply_row_result(df, index, output, trace) | ||
|
|
||
| for index, row in df.iterrows(): | ||
| # Filter row_dict to only include keys that are valid parameters | ||
| # for the 'run' method | ||
| row_dict = row.to_dict() | ||
| filtered_kwargs = { | ||
| k: v for k, v in row_dict.items() if k in run_signature.parameters | ||
| } | ||
|
|
||
| # Call the run method with filtered kwargs | ||
| output = self.run(**filtered_kwargs) | ||
|
|
||
| df.at[index, "output"] = output.output | ||
|
|
||
| for k, v in output.other_fields.items(): | ||
| if k not in df.columns: | ||
| df[k] = None | ||
| df.at[index, k] = v | ||
|
|
||
| trace = tracer.get_current_trace() | ||
| if trace: | ||
| processed_trace, _ = tracer.post_process_trace(trace_obj=trace) | ||
| df.at[index, "steps"] = trace.to_dict() | ||
| if "latency" in processed_trace: | ||
| df.at[index, "latency"] = processed_trace["latency"] | ||
| if "cost" in processed_trace: | ||
| df.at[index, "cost"] = processed_trace["cost"] | ||
| if "tokens" in processed_trace: | ||
| df.at[index, "tokens"] = processed_trace["tokens"] | ||
| if "context" in processed_trace: | ||
| df.at[index, "context"] = processed_trace["context"] | ||
|
|
||
| config = { | ||
| return df, self._build_config(run_signature, df) | ||
|
|
||
| async def _run_rows_async( | ||
| self, | ||
| rows: List[Tuple[Any, Dict[str, Any]]], | ||
| max_workers: int, | ||
| ) -> List[Tuple[Any, RunReturn, Optional[Any]]]: | ||
| """Drive an async ``run`` over all rows with bounded concurrency. | ||
|
|
||
| The first row to raise causes ``asyncio.gather`` to cancel in-flight | ||
| siblings and re-raise the original exception. | ||
| """ | ||
| sem = asyncio.Semaphore(max_workers) | ||
|
|
||
| async def _one(index: Any, kwargs: Dict[str, Any]): | ||
| async with sem: | ||
| output = await self.run(**kwargs) | ||
| return index, output, tracer.get_current_trace() | ||
|
|
||
| return await asyncio.gather(*(_one(i, k) for i, k in rows)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Eager task creation won't scale to large testsets. For typical testsets this is fine, but since the motivating use case is large batches against slow APIs, consider a bounded worker pool (N workers pulling from an |
||
|
|
||
| def _apply_row_result( | ||
| self, | ||
| df: pd.DataFrame, | ||
| index: Any, | ||
| output: RunReturn, | ||
| trace: Optional[Any], | ||
| ) -> None: | ||
| """Write a single row's output and trace fields into ``df`` in place.""" | ||
| df.at[index, "output"] = output.output | ||
|
|
||
| for k, v in output.other_fields.items(): | ||
| if k not in df.columns: | ||
| df[k] = None | ||
| df.at[index, k] = v | ||
|
|
||
| if trace: | ||
| processed_trace, _ = tracer.post_process_trace(trace_obj=trace) | ||
| df.at[index, "steps"] = trace.to_dict() | ||
| if "latency" in processed_trace: | ||
| df.at[index, "latency"] = processed_trace["latency"] | ||
| if "cost" in processed_trace: | ||
| df.at[index, "cost"] = processed_trace["cost"] | ||
| if "tokens" in processed_trace: | ||
| df.at[index, "tokens"] = processed_trace["tokens"] | ||
| if "context" in processed_trace: | ||
| df.at[index, "context"] = processed_trace["context"] | ||
|
|
||
| def _build_config( | ||
| self, run_signature: inspect.Signature, df: pd.DataFrame | ||
| ) -> Dict[str, Any]: | ||
| """Build the config dict returned alongside the output DataFrame.""" | ||
| config: Dict[str, Any] = { | ||
| "outputColumnName": "output", | ||
| "inputVariableNames": list(run_signature.parameters.keys()), | ||
| "metadata": { | ||
|
|
@@ -154,7 +248,7 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: | |
| for k, v in self.custom_args.items(): | ||
| config["metadata"][k] = v | ||
|
|
||
| return df, config | ||
| return config | ||
|
|
||
| def write_output_to_directory( | ||
| self, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Defaulting async to 4 is opinionated and silent. The "writing
async defmeans interleaving is safe" contract is reasonable, but this jumps an asyncrunfrom sequential to 4 concurrent invocations with no explicit opt-in at the call site. That can surprise arunthat hits a rate-limited API or holds non-reentrant state.Two options: (a) default async to 1 and require
--max-workers Nto scale, or (b) keep 4 but call it out prominently in the changelog/user docs. Either is fine. Flagging so it's a deliberate choice rather than an accident.