diff --git a/README.md b/README.md
index b148540..d8c6d91 100644
--- a/README.md
+++ b/README.md
@@ -785,6 +785,70 @@ finally:
+
+ Dual-channel: mic + system audio in one session
+
+For note-taker apps that capture two live sources (microphone **and** system/speaker output) but want them handled as **one** streaming session — while still knowing which source each word came from — wrap the client in a `ChannelStreamer`.
+
+You declare named channels and feed each channel's PCM separately. The SDK runs per-channel energy VAD, mixes the channels into a single mono stream over one websocket, and — for handlers registered on the coordinator — delivers an enriched `DualChannelTurnEvent` whose words/turn carry their originating channel (`turn.channel` and per-word `word.channel`). The base `Word` / `TurnEvent` stay unchanged, so single-stream payloads aren't affected. Attribution is fully client-side and model-agnostic, so it composes with `speaker_labels`, multilingual, and `u3-rt-pro`. It is a **separate dimension from diarization** — `word.channel` (physical source) is independent of `word.speaker` (voice): two people on the same `system` channel get distinct speaker labels, while one person heard on two channels keeps a single speaker label.
+
+Unlike a browser sample, the SDK does not capture audio — you supply 16-bit PCM for each channel (from `sounddevice`, `pyaudio`, a loopback device, files, …).
+
+```python
+from assemblyai.streaming.v3 import (
+ ChannelStreamer, StreamingClient, StreamingClientOptions,
+ StreamingEvents, StreamingParameters,
+)
+
+def on_turn(client, event): # event is a DualChannelTurnEvent
+ print(f"[{event.channel}] {event.transcript}")
+ for w in event.words:
+ print(f" {w.text!r} -> channel={w.channel} speaker={w.speaker}")
+
+client = StreamingClient(StreamingClientOptions(api_key="<YOUR_API_KEY>"))
+
+# Declare the channels and the session sample rate (audio must be pcm_s16le).
+mixer = ChannelStreamer(client, channels=["mic", "system"], sample_rate=16000)
+# Register handlers on the mixer: Turn handlers receive the enriched event,
+# other events (Begin/Error/…) are forwarded to the client.
+mixer.on(StreamingEvents.Turn, on_turn)
+client.connect(StreamingParameters(
+ sample_rate=16000, speech_model="u3-rt-pro", speaker_labels=True,
+))
+
+# Feed each source separately — e.g. from two capture callbacks. Send
+# continuous PCM for every channel (silence as zeros), at the same rate.
+mixer.stream("mic", mic_pcm)
+mixer.stream("system", system_pcm)
+
+mixer.flush() # push trailing buffered audio
+client.disconnect(terminate=True)
+```
+
+`AsyncChannelStreamer` is the asyncio-native equivalent (`await mixer.stream(...)` / `await mixer.close_channel(...)` / `await mixer.flush()`); register handlers the same way with `mixer.on(...)`.
+
+**Sources that end mid-session.** Mixing keeps channels aligned by consuming the shortest buffer, so it assumes every channel keeps delivering PCM (send silence as zeros, don't omit it). When a source genuinely ends (file EOF, screen share stopped, device removed), call `mixer.close_channel(name)` so the session degrades to the surviving channel(s) instead of stalling — the ended channel is then padded with silence.
+
+**Swappable VAD.** The default detector is the built-in energy-based `EnergyVad`. Supply your own (e.g. a DNN VAD such as Silero) via `ChannelAttributionOptions.create_vad`, which is called once per channel with the channel name; subclass `VadDetector` (`process(frame) -> VadResult`, `reset()`). Pass `on_vad=callback` to observe raw per-frame activity (e.g. a live "who's talking" meter). Tune the default with `EnergyVad(threshold_ratio=3.0, noise_floor_alpha=0.05, hangover_frames=10)` — `threshold_ratio` below ~2 is too sensitive, above ~6 misses quiet onsets/offsets.
+
+**Resolving unknown channels.** A word is `"unknown"` when no channel was clearly dominant in its window — silence, or two channels too close to call (the top must beat the runner-up by `dominance_ratio`, default 4). `ChannelAttributionOptions.resolve_unknown_channels_method` back-fills these:
+
+- `"window"` (default) — from the dominant non-`"unknown"` channel among ±`resolution_window_words` neighbor words.
+- `"speaker-history"` — from the speaker's session-wide channel evidence (requires `speaker_labels`).
+- `"none"` — leave `"unknown"` as-is.
+
+Back-filled words are flagged `word.channel_resolved = True`; confident per-word decisions are never overwritten. The method is validated at construction, so a typo raises immediately rather than silently disabling resolution.
+
+**Caveats.**
+
+- Requires 16-bit PCM (`pcm_s16le`, the default) — linear mixing is invalid for `pcm_mulaw`.
+- Capturing the system/speaker output is platform-specific: macOS needs a loopback driver (e.g. BlackHole); Windows uses WASAPI loopback; Linux a PulseAudio/PipeWire monitor source.
+- If the mic physically picks up the speakers, that bleed can pull attribution toward `mic`. Apply acoustic echo cancellation at capture (`getUserMedia({ audio: { echoCancellation: true } })` in browser front-ends, or an AEC-capable native path) — the SDK only receives already-captured PCM, so it can't apply AEC itself. This affects only the `channel` field; transcription quality is unchanged.
+
+See [`examples/streaming_dual_channel.py`](./examples/streaming_dual_channel.py) for a complete runnable demo.
+
+
+
Stream a local file (async)
diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py
index 0baa0e8..bffaa07 100644
--- a/assemblyai/__version__.py
+++ b/assemblyai/__version__.py
@@ -1 +1 @@
-__version__ = "0.64.20"
+__version__ = "0.64.21"
diff --git a/assemblyai/streaming/v3/__init__.py b/assemblyai/streaming/v3/__init__.py
index bd9ca87..f72aa0f 100644
--- a/assemblyai/streaming/v3/__init__.py
+++ b/assemblyai/streaming/v3/__init__.py
@@ -1,5 +1,17 @@
from .async_client import AsyncStreamingClient
from .client import StreamingClient
+from .extras import (
+ AsyncChannelStreamer,
+ ChannelAttributionOptions,
+ ChannelStreamer,
+ DualChannelTurnEvent,
+ DualChannelWord,
+ EnergyVad,
+ VadDetector,
+ VadFrame,
+ VadResult,
+ attribute_turn,
+)
from .models import (
BeginEvent,
Encoding,
@@ -27,8 +39,14 @@
)
__all__ = [
+ "AsyncChannelStreamer",
"AsyncStreamingClient",
"BeginEvent",
+ "ChannelAttributionOptions",
+ "ChannelStreamer",
+ "DualChannelTurnEvent",
+ "DualChannelWord",
+ "EnergyVad",
"Encoding",
"EventMessage",
"LLMGatewayResponseEvent",
@@ -50,6 +68,10 @@
"StreamingSessionParameters",
"TerminationEvent",
"TurnEvent",
+ "VadDetector",
+ "VadFrame",
+ "VadResult",
"WarningEvent",
"Word",
+ "attribute_turn",
]
diff --git a/assemblyai/streaming/v3/async_client.py b/assemblyai/streaming/v3/async_client.py
index f04c9b6..03f3ef0 100644
--- a/assemblyai/streaming/v3/async_client.py
+++ b/assemblyai/streaming/v3/async_client.py
@@ -37,6 +37,7 @@
ErrorEvent,
EventMessage,
ForceEndpoint,
+ KeepAlive,
OperationMessage,
StreamingClientOptions,
StreamingError,
@@ -80,10 +81,10 @@ class AsyncStreamingClient(_BaseStreamingClient):
Behavioral notes vs. the sync ``StreamingClient``:
- - ``stream`` / ``set_params`` / ``force_endpoint`` raise ``RuntimeError``
- when called before ``connect()`` — silent drop would diverge from the
- sync client (which buffers pre-connect data) in a way that's easy to
- miss. After the connection has closed, the same calls are silent
+ - ``stream`` / ``set_params`` / ``force_endpoint`` / ``keep_alive`` raise
+ ``RuntimeError`` when called before ``connect()`` — silent drop would
+ diverge from the sync client (which buffers pre-connect data) in a way
+ that's easy to miss. After the connection has closed, the same calls are silent
no-ops so cleanup paths don't need defensive try/except.
- ``disconnect(terminate=True)`` waits at most 2.0s for the write task to
drain the ``TerminateSession`` frame before forcing teardown. The sync
@@ -288,6 +289,12 @@ async def force_endpoint(self) -> None:
return
await write_queue.put(ForceEndpoint())
+    async def keep_alive(self) -> None:
+        """Queue a ``KeepAlive`` message unless the stop event is already set."""
+        queue, stopping = self._ensure_connected("keep_alive")
+        if not stopping.is_set():
+            await queue.put(KeepAlive())
+
def _ensure_connected(
self, method: str
) -> "tuple[asyncio.Queue[OperationMessage], asyncio.Event]":
diff --git a/assemblyai/streaming/v3/client.py b/assemblyai/streaming/v3/client.py
index 8dc5084..64740c3 100644
--- a/assemblyai/streaming/v3/client.py
+++ b/assemblyai/streaming/v3/client.py
@@ -24,6 +24,7 @@
ErrorEvent,
EventMessage,
ForceEndpoint,
+ KeepAlive,
OperationMessage,
StreamingClientOptions,
StreamingError,
@@ -194,6 +195,10 @@ def force_endpoint(self):
message = ForceEndpoint()
self._write_queue.put(message)
+    def keep_alive(self):
+        """Queue a ``KeepAlive`` control message on the client's write queue."""
+        self._write_queue.put(KeepAlive())
+
def _write_message(self) -> None:
while True:
if not self._websocket:
diff --git a/assemblyai/streaming/v3/extras.py b/assemblyai/streaming/v3/extras.py
new file mode 100644
index 0000000..f29006c
--- /dev/null
+++ b/assemblyai/streaming/v3/extras.py
@@ -0,0 +1,742 @@
+"""Client-side dual / multi-channel support for streaming v3.
+
+Note-taker use case: capture two live sources (microphone + system audio) as
+one streaming session while still knowing which physical source each word came
+from.
+
+- Each named channel's PCM is fed in separately via ``ChannelStreamer.stream``.
+- Per-channel energy VAD records which channel was acoustically active when.
+- The channels are averaged into one mono stream sent over the existing single
+ websocket session.
+- On every ``Turn``, words are attributed back to a channel by matching their
+  server timestamps against the per-channel VAD timeline; the enriched
+  ``DualChannelTurnEvent`` (``turn.channel`` + per-word ``word.channel``) is
+  delivered to the handler registered on the coordinator.
+
+Attribution is purely client-side, so any ``speech_model`` works unchanged and
+channel (physical source) stays independent of diarization ``speaker``.
+"""
+
+import inspect
+import logging
+import math
+import sys
+import threading
+from array import array
+from dataclasses import dataclass
+from typing import (
+ TYPE_CHECKING,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Union,
+)
+
+from .models import StreamingEvents, TurnEvent, Word
+
+if TYPE_CHECKING: # avoid an import cycle; only used for type hints
+ from .async_client import AsyncStreamingClient
+ from .client import StreamingClient
+
+logger = logging.getLogger(__name__)
+
+UNKNOWN_CHANNEL = "unknown"
+
+# 20 ms VAD frames, and the server's per-message audio duration limits.
+_VAD_FRAME_MS = 20
+_MIN_CHUNK_MS = 50
+_MAX_CHUNK_MS = 200
+
+
+# Enriched event types subclass the base streaming models and add the
+# client-side channel fields, so the base ``Word`` / ``TurnEvent`` payloads stay
+# clean for single-stream users. A coordinator builds these from the base events
+# and delivers them to handlers registered on the coordinator.
+class DualChannelWord(Word):
+ """A ``Word`` enriched with channel attribution (independent of ``speaker``)."""
+
+ # Physical input channel (e.g. "mic" / "system"), or "unknown" if no channel
+ # was clearly dominant during the word's window.
+ channel: Optional[str] = None
+ # True when ``channel`` was inferred by an unknown-resolution strategy rather
+ # than measured directly by VAD.
+ channel_resolved: Optional[bool] = None
+
+
+class DualChannelTurnEvent(TurnEvent):
+ """A ``TurnEvent`` enriched with per-word and turn-level channel attribution."""
+
+ words: List[DualChannelWord] = []
+ # Duration-weighted majority channel across ``words``, or "unknown".
+ channel: Optional[str] = None
+
+
+@dataclass
+class VadResult:
+ """One frame's voice-activity decision: whether speech is present and the
+ frame's RMS energy (used to weight per-channel attribution)."""
+
+ active: bool
+ energy: float
+
+
+@dataclass
+class VadFrame:
+    """A per-channel, per-frame VAD observation.
+
+    ``ts`` is stream-relative milliseconds from the channel's own sample counter
+    — the same reference frame as ``Word.start`` / ``Word.end``.
+    """
+
+    ts: float  # ms; ``_ChannelMixer.ingest`` stamps it at the frame's last sample
+    channel: str  # name of the channel the frame was ingested on
+    active: bool  # the detector's speech decision (``VadResult.active``)
+    rms: float  # the detector's reported energy (``VadResult.energy``)
+
+
+class VadDetector:
+ """Pluggable per-channel voice-activity detector.
+
+ A separate instance is held per channel. ``process`` receives a fixed-size
+ frame of float samples in ``[-1.0, 1.0]`` at the session's sample rate.
+ Subclass / duck-type this to drop in a DNN-backed detector for noisy
+ environments.
+ """
+
+ def process(self, frame: Sequence[float]) -> VadResult: # pragma: no cover
+ raise NotImplementedError
+
+ def reset(self) -> None: # pragma: no cover
+ raise NotImplementedError
+
+
+class EnergyVad(VadDetector):
+ """Energy-based VAD with adaptive noise-floor tracking and hangover.
+
+ Pure Python, no dependencies. Suited to "which physical channel is speaking"
+ since the channels are already physically separated at capture; hand the
+ harder speech-vs-noise problem to a custom ``VadDetector`` via
+ ``ChannelAttributionOptions.create_vad``.
+
+ Tuning: ``threshold_ratio`` below 2 is over-sensitive, above 6 misses quiet
+ onsets/offsets; ``noise_floor_alpha`` above 0.1 adapts to non-stationary
+ background faster but risks creeping up onto a sustained quiet voice.
+ """
+
+ def __init__(
+ self,
+ threshold_ratio: float = 3.0,
+ noise_floor_alpha: float = 0.05,
+ hangover_frames: int = 10,
+ initial_noise_floor: float = 1e-4,
+ ):
+ self._threshold_ratio = threshold_ratio
+ self._noise_floor_alpha = noise_floor_alpha
+ self._hangover_frames = hangover_frames
+ self._initial_noise_floor = initial_noise_floor
+ self._noise_floor = initial_noise_floor
+ self._hangover_remaining = 0
+
+ def process(self, frame: Sequence[float]) -> VadResult:
+ n = len(frame)
+ sum_sq = 0.0
+ for s in frame:
+ sum_sq += s * s
+ rms = math.sqrt(sum_sq / n) if n > 0 else 0.0
+
+ threshold = self._noise_floor * self._threshold_ratio
+ active = rms > threshold
+
+ if active:
+ self._hangover_remaining = self._hangover_frames
+ elif self._hangover_remaining > 0:
+ self._hangover_remaining -= 1
+ active = True
+ # In hangover, don't update the floor — RMS may still be tail energy.
+ else:
+ self._noise_floor = (
+ self._noise_floor * (1 - self._noise_floor_alpha)
+ + rms * self._noise_floor_alpha
+ )
+
+ return VadResult(active=active, energy=rms)
+
+ def reset(self) -> None:
+ self._noise_floor = self._initial_noise_floor
+ self._hangover_remaining = 0
+
+
+class VadTimeline:
+    """Append-only buffer of recent ``VadFrame``s shared by every channel.
+
+    Frames are kept in push order, which is only roughly timestamp order:
+    each channel appends batches stamped from its own sample counter.
+    ``push_frame`` is amortized O(1); ``frames_in_window`` is O(n) over kept
+    frames, fine for the per-word lookups done here.
+    """
+
+    def __init__(self, window_ms: int):
+        self._window_ms = window_ms
+        self._frames: List[VadFrame] = []
+        self._head = 0
+        # The threaded ``StreamingClient`` runs ``frames_in_window`` on the read
+        # thread while the user thread runs ``push_frame``; compaction swaps
+        # ``_frames`` / ``_head`` non-atomically. The lock keeps push/read/compact
+        # mutually exclusive. Uncontended on the async / single-threaded paths.
+        self._lock = threading.Lock()
+
+    def push_frame(self, frame: VadFrame) -> None:
+        """Append ``frame`` and retire leading frames older than the window."""
+        with self._lock:
+            self._frames.append(frame)
+            cutoff = frame.ts - self._window_ms
+            while (
+                self._head < len(self._frames) and self._frames[self._head].ts < cutoff
+            ):
+                self._head += 1
+            # Compact occasionally so the list doesn't grow without bound.
+            if self._head > 1024 and self._head * 2 > len(self._frames):
+                self._frames = self._frames[self._head :]
+                self._head = 0
+
+    def frames_in_window(self, start_ms: float, end_ms: float) -> List[VadFrame]:
+        """Return every kept frame with ``start_ms <= ts <= end_ms``, in push order."""
+        with self._lock:
+            kept = self._frames[self._head :]
+        # Scan all kept frames: stopping at the first ``ts > end_ms`` would
+        # drop in-window frames that another channel pushed later.
+        return [f for f in kept if start_ms <= f.ts <= end_ms]
+
+    def clear(self) -> None:
+        """Drop every stored frame."""
+        with self._lock:
+            self._frames = []
+            self._head = 0
+
+
+def _score_channels(frames: Iterable[VadFrame]) -> Dict[str, float]:
+    """Total the RMS of active frames per channel; a channel with no active
+    frame does not appear in the returned mapping."""
+    totals: Dict[str, float] = {}
+    for frame in frames:
+        if frame.active:
+            name = frame.channel
+            totals[name] = totals.get(name, 0.0) + frame.rms
+    return totals
+
+
+def _top_by_ratio(scores: Dict[str, float], dominance_ratio: float) -> Optional[str]:
+    """Pick the clearly dominant channel from a per-channel score map.
+
+    A lone entry always wins. With two or more, the highest score wins only
+    if the runner-up is non-positive or the highest is at least
+    ``dominance_ratio`` times the runner-up. ``None`` for an empty map or a
+    lead that is too small.
+
+    Raising the ratio yields more ``None`` (surfaced as "unknown").
+    """
+    if not scores:
+        return None
+    by_score = sorted(scores.items(), key=lambda item: item[1], reverse=True)
+    best_name, best = by_score[0]
+    if len(by_score) == 1:
+        return best_name
+    runner_up = by_score[1][1]
+    clear_lead = runner_up <= 0 or best >= dominance_ratio * runner_up
+    return best_name if clear_lead else None
+
+
+def attribute_word(
+    word: Word,
+    timeline: VadTimeline,
+    dominance_ratio: float,
+) -> str:
+    """Return the channel whose active VAD energy dominates the frames stamped
+    within ``[word.start, word.end]``, else "unknown" when there is no winner."""
+    frames = timeline.frames_in_window(word.start, word.end)
+    dominant = _top_by_ratio(_score_channels(frames), dominance_ratio)
+    return UNKNOWN_CHANNEL if dominant is None else dominant
+
+
+def roll_up_turn_channel(words: Sequence[DualChannelWord]) -> str:
+    """Return the channel holding the most word time (``end - start``, floored
+    at 0); "unknown" when no word is attributed or the top two totals tie."""
+    time_by_channel: Dict[str, float] = {}
+    for word in words:
+        name = word.channel
+        if name and name != UNKNOWN_CHANNEL:
+            span = max(0, word.end - word.start)
+            time_by_channel[name] = time_by_channel.get(name, 0.0) + span
+    if not time_by_channel:
+        return UNKNOWN_CHANNEL
+    ordered = sorted(
+        time_by_channel.items(), key=lambda item: item[1], reverse=True
+    )
+    leader, leader_time = ordered[0]
+    if len(ordered) > 1 and leader_time == ordered[1][1]:
+        return UNKNOWN_CHANNEL
+    return leader
+
+
+def attribute_turn(
+ turn: DualChannelTurnEvent,
+ timeline: VadTimeline,
+ dominance_ratio: float,
+) -> None:
+ """Write ``turn.words[i].channel`` for every word and set ``turn.channel``
+ to the duration-weighted rollup. Mutates the enriched ``turn`` in place."""
+ for w in turn.words:
+ w.channel = attribute_word(w, timeline, dominance_ratio)
+ turn.channel = roll_up_turn_channel(turn.words)
+
+
+def resolve_unknown_channels_by_window(
+ turn: DualChannelTurnEvent,
+ resolution_window_words: int,
+) -> None:
+ """Back-fill "unknown" words from the dominant non-"unknown" channel among
+ +/-N neighbor words in the same turn. Words with no resolved neighbors stay
+ "unknown"; confident decisions are never modified; resolved words are tagged
+ ``channel_resolved = True``."""
+ words = turn.words
+ mutated = False
+ for i, target in enumerate(words):
+ if target.channel != UNKNOWN_CHANNEL:
+ continue
+ tally: Dict[str, int] = {}
+ lo = max(0, i - resolution_window_words)
+ hi = min(len(words) - 1, i + resolution_window_words)
+ for j in range(lo, hi + 1):
+ if j == i:
+ continue
+ ch = words[j].channel
+ if not ch or ch == UNKNOWN_CHANNEL:
+ continue
+ tally[ch] = tally.get(ch, 0) + 1
+ if not tally:
+ continue
+ top: Optional[str] = None
+ top_count = 0
+ tied = False
+ for name, count in tally.items():
+ if count > top_count:
+ top, top_count, tied = name, count, False
+ elif count == top_count:
+ tied = True
+ if top is not None and not tied:
+ target.channel = top
+ target.channel_resolved = True
+ mutated = True
+ if mutated:
+ turn.channel = roll_up_turn_channel(words)
+
+
+def resolve_unknown_channels_by_speaker_history(
+ turn: DualChannelTurnEvent,
+ timeline: VadTimeline,
+ speaker_history: Dict[str, Dict[str, float]],
+ min_rms_evidence: float,
+ dominance_ratio: float,
+) -> None:
+ """Back-fill "unknown" words using each speaker's session-wide channel
+ evidence: per speaker, sum active VAD-frame RMS per channel across all their
+ words so far (accumulated into ``speaker_history``). A speaker is resolvable
+ once their total evidence clears ``min_rms_evidence`` and their top channel
+ beats the runner-up by ``dominance_ratio``. Only touches "unknown" words;
+ ``speaker`` is never modified."""
+ # 1. Accumulate evidence from this turn's words.
+ for w in turn.words:
+ if not w.speaker:
+ continue
+ entry = speaker_history.setdefault(w.speaker, {})
+ for f in timeline.frames_in_window(w.start, w.end):
+ if not f.active:
+ continue
+ entry[f.channel] = entry.get(f.channel, 0.0) + f.rms
+
+ # 2. Fill unknown words whose speakers have dominant evidence.
+ mutated = False
+ for w in turn.words:
+ if w.channel != UNKNOWN_CHANNEL or not w.speaker:
+ continue
+ entry_or_none = speaker_history.get(w.speaker)
+ if not entry_or_none or sum(entry_or_none.values()) < min_rms_evidence:
+ continue
+ winner = _top_by_ratio(entry_or_none, dominance_ratio)
+ if winner is not None:
+ w.channel = winner
+ w.channel_resolved = True
+ mutated = True
+ if mutated:
+ turn.channel = roll_up_turn_channel(turn.words)
+
+
+RESOLVE_UNKNOWN_METHODS = ("none", "window", "speaker-history")
+
+
+@dataclass
+class ChannelAttributionOptions:
+ """Tuning for client-side channel attribution. All fields have sane defaults
+ matching the Node SDK; override only when needed."""
+
+ # Per-word energy ratio above which a channel is declared dominant.
+ dominance_ratio: float = 4.0
+ # How far back the VAD timeline retains frames for per-word lookups.
+ timeline_window_ms: int = 30_000
+ # Factory for the per-channel detector; called once per channel with its
+ # name. Defaults to a fresh ``EnergyVad`` per channel.
+ create_vad: Optional[Callable[[str], VadDetector]] = None
+ # How to fill words VAD couldn't attribute: "none" | "window" |
+ # "speaker-history".
+ resolve_unknown_channels_method: str = "window"
+ # "window": +/-N neighbor words consulted to fill an unknown word.
+ resolution_window_words: int = 2
+ # "speaker-history": minimum cumulative active-RMS before a speaker's
+ # channel is considered established.
+ speaker_history_min_rms_evidence: float = 0.5
+ # "speaker-history": top channel must beat the runner-up by this ratio.
+ speaker_history_dominance_ratio: float = 3.0
+
+ def __post_init__(self) -> None:
+ # Fail fast at construction; an invalid method would otherwise be
+ # swallowed inside the Turn handler, silently disabling resolution.
+ if self.resolve_unknown_channels_method not in RESOLVE_UNKNOWN_METHODS:
+ raise ValueError(
+ "resolve_unknown_channels_method must be one of "
+ f"{RESOLVE_UNKNOWN_METHODS}; got "
+ f"{self.resolve_unknown_channels_method!r}."
+ )
+
+ def _make_vad(self, channel: str) -> VadDetector:
+ if self.create_vad is not None:
+ return self.create_vad(channel)
+ return EnergyVad()
+
+
+_BIG_ENDIAN = sys.byteorder == "big"
+
+
+def _pcm16_to_array(data: bytes) -> array:
+    """Decode little-endian 16-bit PCM into an ``array`` of signed shorts."""
+    byte_count = len(data)
+    if byte_count % 2:
+        raise ValueError(
+            f"PCM data length must be even (16-bit samples); got {byte_count} bytes."
+        )
+    pcm = array("h", bytes(data))
+    if _BIG_ENDIAN:
+        pcm.byteswap()  # input is little-endian; swap into native order
+    return pcm
+
+
+def _array_to_pcm16(samples: array) -> bytes:
+ """Serialize a signed-short ``array`` back to little-endian 16-bit PCM."""
+ if _BIG_ENDIAN:
+ samples = samples[:]
+ samples.byteswap()
+ return samples.tobytes()
+
+
+class _ChannelMixer:
+ """Owns per-channel PCM buffers, per-channel VAD, the shared VAD timeline,
+ and the mono mixing math. Runtime-agnostic — both coordinators drive it."""
+
+ def __init__(
+ self,
+ channels: Sequence[str],
+ sample_rate: int,
+ options: ChannelAttributionOptions,
+ ):
+ self.channels = list(channels)
+ self.sample_rate = sample_rate
+ self.timeline = VadTimeline(options.timeline_window_ms)
+
+ self._vad_frame_samples = max(1, round(sample_rate * _VAD_FRAME_MS / 1000))
+ self._min_chunk_samples = max(1, round(sample_rate * _MIN_CHUNK_MS / 1000))
+ self._max_chunk_samples = max(
+ self._min_chunk_samples, round(sample_rate * _MAX_CHUNK_MS / 1000)
+ )
+
+ self._buffers: Dict[str, array] = {n: array("h") for n in self.channels}
+ self._vad_frame: Dict[str, List[float]] = {n: [] for n in self.channels}
+ self._received: Dict[str, int] = {n: 0 for n in self.channels}
+ self._vads: Dict[str, VadDetector] = {
+ n: options._make_vad(n) for n in self.channels
+ }
+ # Channels closed via ``close_channel`` — treated as silence so their
+ # absence no longer gates mixing for the survivors.
+ self._ended: Set[str] = set()
+
+ def close_channel(self, channel: str) -> None:
+ """Mark a channel finished (source ended). Subsequent ``drain`` calls
+ stop waiting on it and pad it with silence."""
+ self._ended.add(channel)
+
+ def ingest(
+ self,
+ channel: str,
+ data: bytes,
+ on_vad: Optional[Callable[[VadFrame], None]] = None,
+ ) -> None:
+ samples = _pcm16_to_array(data)
+ self._buffers[channel].extend(samples)
+
+ vad = self._vads[channel]
+ frame_buf = self._vad_frame[channel]
+ received = self._received[channel]
+ for s in samples:
+ frame_buf.append(s / 0x8000)
+ received += 1
+ if len(frame_buf) == self._vad_frame_samples:
+ result = vad.process(frame_buf)
+ frame = VadFrame(
+ ts=received / self.sample_rate * 1000,
+ channel=channel,
+ active=result.active,
+ rms=result.energy,
+ )
+ self.timeline.push_frame(frame)
+ if on_vad is not None:
+ on_vad(frame)
+ frame_buf.clear()
+ self._received[channel] = received
+
+    def drain(self, force: bool = False) -> List[bytes]:
+        """Mix buffered audio into mono PCM chunks, each clamped to
+        ``[_MIN_CHUNK_MS, _MAX_CHUNK_MS]`` (the ``_MIN_CHUNK_MS`` floor applies
+        only while not ``force``).
+
+        Mixing gates on the shortest buffer among channels that are still live,
+        so survivors stay time-aligned with each other; a closed channel
+        (``close_channel``) never gates and is zero-padded once its buffer runs
+        out. On the final ``force`` flush, or once every channel is closed,
+        shorter buffers are zero-padded up to the longest so nothing buffered
+        is dropped. Each sample is the average over all declared channels.
+        """
+        bufs = [self._buffers[n] for n in self.channels]
+        divisor = len(bufs)
+        live = [
+            self._buffers[n] for n in self.channels if n not in self._ended
+        ]
+        # Pad to the longest buffer only when no live channel must be waited on.
+        pad = force or not live
+        out_chunks: List[bytes] = []
+        while True:
+            if pad:
+                mix_len = max((len(b) for b in bufs), default=0)
+            else:
+                mix_len = min(len(b) for b in live)
+            if mix_len == 0:
+                break
+            if not force and mix_len < self._min_chunk_samples:
+                break
+            if mix_len > self._max_chunk_samples:
+                mix_len = self._max_chunk_samples
+            out = array("h", bytes(2 * mix_len))
+            for i in range(mix_len):
+                total = 0
+                for b in bufs:
+                    total += b[i] if i < len(b) else 0  # pad past buffer end
+                avg = round(total / divisor)
+                out[i] = -32768 if avg < -32768 else (32767 if avg > 32767 else avg)
+            for b in bufs:
+                del b[: min(mix_len, len(b))]
+            out_chunks.append(_array_to_pcm16(out))
+        return out_chunks
+
+
+def _validate_channels(channels: Sequence[str]) -> List[str]:
+ names = list(channels)
+ if len(names) < 2:
+ raise ValueError("channels must declare at least 2 channel names.")
+ if len(set(names)) != len(names):
+ raise ValueError("channels names must be unique.")
+ if any(not isinstance(n, str) or not n for n in names):
+ raise ValueError("channel names must be non-empty strings.")
+ return names
+
+
+class _BaseChannelStreamer:
+ """Shared dual/multi-channel coordination independent of the wrapped
+ client's sync/async I/O. Channel config lives here, never on
+ ``StreamingParameters`` (it must not reach the websocket URL); the wrapped
+ client streams ordinary mono audio and is otherwise untouched.
+ """
+
+ def __init__(
+ self,
+ channels: Sequence[str],
+ sample_rate: int,
+ options: Optional[ChannelAttributionOptions],
+ on_vad: Optional[Callable[[VadFrame], None]],
+ ):
+ self.channels = _validate_channels(channels)
+ self._options = options or ChannelAttributionOptions()
+ self._on_vad = on_vad
+ self._mixer = _ChannelMixer(self.channels, sample_rate, self._options)
+ self._speaker_history: Dict[str, Dict[str, float]] = {}
+ self._turn_handlers: List[Callable] = []
+ # Set by each subclass in __init__ (the concrete sync/async client).
+ self._client: Union["StreamingClient", "AsyncStreamingClient"]
+
+ def on(self, event: StreamingEvents, handler: Callable) -> None:
+ """Register an event handler. ``Turn`` events are delivered as an
+ enriched ``DualChannelTurnEvent``; all other events are forwarded to the
+ underlying client unchanged."""
+ if event == StreamingEvents.Turn:
+ self._turn_handlers.append(handler)
+ else:
+ self._client.on(event, handler)
+
+ def _check_channel(self, channel: str) -> None:
+ if channel not in self._mixer._buffers:
+ raise ValueError(
+ f'Unknown channel "{channel}"; declared channels: '
+ f"{', '.join(self.channels)}."
+ )
+
+ def _enrich(self, base_turn: TurnEvent) -> DualChannelTurnEvent:
+ """Build a ``DualChannelTurnEvent`` from the base turn (left untouched)
+ and run channel attribution + the configured unknown-resolution."""
+ # Dump the base event keeping None fields (pydantic v1/v2).
+ data = (
+ base_turn.model_dump()
+ if hasattr(base_turn, "model_dump")
+ else base_turn.dict()
+ )
+ enriched = DualChannelTurnEvent(**data)
+ attribute_turn(enriched, self._mixer.timeline, self._options.dominance_ratio)
+ method = self._options.resolve_unknown_channels_method
+ if method == "window":
+ resolve_unknown_channels_by_window(
+ enriched, self._options.resolution_window_words
+ )
+ elif method == "speaker-history":
+ resolve_unknown_channels_by_speaker_history(
+ enriched,
+ self._mixer.timeline,
+ self._speaker_history,
+ self._options.speaker_history_min_rms_evidence,
+ self._options.speaker_history_dominance_ratio,
+ )
+ # method == "none": leave "unknown" words as-is (validated at construction).
+ return enriched
+
+ @staticmethod
+ def _as_chunks(data: Union[bytes, Iterable[bytes]]) -> Iterable[bytes]:
+ if isinstance(data, (bytes, bytearray, memoryview)):
+ return [bytes(data)]
+ return data
+
+
+class ChannelStreamer(_BaseChannelStreamer):
+ """Dual/multi-channel coordinator for the threaded ``StreamingClient``.
+
+ Feed each named channel's 16-bit little-endian PCM via ``stream(channel,
+ data)``; the channels are summed into one mono stream over the client's
+ single session. Register handlers on the coordinator (``mixer.on(...)``):
+ ``Turn`` handlers receive an enriched ``DualChannelTurnEvent``; all other
+ events are forwarded to the wrapped client, keeping the base payloads clean
+ for single-stream use.
+
+ Requires ``pcm_s16le`` (linear mixing is invalid for ``pcm_mulaw``). Feed
+ every channel continuous PCM at the same ``sample_rate`` (silence as zeros);
+ call ``close_channel(name)`` when a source ends mid-session so the session
+ degrades to the survivors, and ``flush()`` before
+ ``client.disconnect(terminate=True)`` to push the trailing audio.
+ """
+
+ def __init__(
+ self,
+ client: "StreamingClient",
+ channels: Sequence[str],
+ sample_rate: int,
+ attribution: Optional[ChannelAttributionOptions] = None,
+ on_vad: Optional[Callable[[VadFrame], None]] = None,
+ ):
+ super().__init__(channels, sample_rate, attribution, on_vad)
+ self._client = client
+ client.on(StreamingEvents.Turn, self._handle_turn)
+
+ def _handle_turn(self, client: object, base_turn: TurnEvent) -> None:
+ enriched = self._enrich(base_turn)
+ for handler in self._turn_handlers:
+ try:
+ handler(client, enriched)
+ except Exception:
+ logger.exception("dual-channel on_turn handler raised")
+
+ def stream(self, channel: str, data: Union[bytes, Iterable[bytes]]) -> None:
+ """Ingest PCM for ``channel``, then send whatever can now be mixed."""
+ self._check_channel(channel)
+ for chunk in self._as_chunks(data):
+ self._mixer.ingest(channel, chunk, on_vad=self._on_vad)
+ for mixed in self._mixer.drain():
+ self._client.stream(mixed)
+
+ def close_channel(self, channel: str) -> None:
+ """Signal that ``channel``'s source has ended; the session keeps
+ streaming the survivors, sending any newly mixable audio immediately."""
+ self._check_channel(channel)
+ self._mixer.close_channel(channel)
+ for mixed in self._mixer.drain():
+ self._client.stream(mixed)
+
+ def flush(self) -> None:
+ """Mix and send any remaining buffered audio. Call before
+ ``client.disconnect``."""
+ for mixed in self._mixer.drain(force=True):
+ self._client.stream(mixed)
+
+
+class AsyncChannelStreamer(_BaseChannelStreamer):
+ """Asyncio-native counterpart to ``ChannelStreamer`` (wraps
+ ``AsyncStreamingClient``); ``stream`` / ``close_channel`` / ``flush`` are
+ coroutines. ``Turn`` handlers may be sync or ``async`` (awaited inline on the
+ read task). See ``ChannelStreamer`` for requirements.
+ """
+
+ def __init__(
+ self,
+ client: "AsyncStreamingClient",
+ channels: Sequence[str],
+ sample_rate: int,
+ attribution: Optional[ChannelAttributionOptions] = None,
+ on_vad: Optional[Callable[[VadFrame], None]] = None,
+ ):
+ super().__init__(channels, sample_rate, attribution, on_vad)
+ self._client: "AsyncStreamingClient" = client
+ client.on(StreamingEvents.Turn, self._handle_turn)
+
+ async def _handle_turn(self, client: object, base_turn: TurnEvent) -> None:
+ enriched = self._enrich(base_turn)
+ for handler in self._turn_handlers:
+ try:
+ result = handler(client, enriched)
+ if inspect.isawaitable(result):
+ await result
+ except Exception:
+ logger.exception("dual-channel on_turn handler raised")
+
+ async def stream(self, channel: str, data: Union[bytes, Iterable[bytes]]) -> None:
+ self._check_channel(channel)
+ for chunk in self._as_chunks(data):
+ self._mixer.ingest(channel, chunk, on_vad=self._on_vad)
+ for mixed in self._mixer.drain():
+ await self._client.stream(mixed)
+
+ async def close_channel(self, channel: str) -> None:
+ """Signal that ``channel``'s source has ended; the session keeps
+ streaming the survivors instead of stalling."""
+ self._check_channel(channel)
+ self._mixer.close_channel(channel)
+ for mixed in self._mixer.drain():
+ await self._client.stream(mixed)
+
+ async def flush(self) -> None:
+ for mixed in self._mixer.drain(force=True):
+ await self._client.stream(mixed)
diff --git a/assemblyai/streaming/v3/models.py b/assemblyai/streaming/v3/models.py
index 089a98f..f64dede 100644
--- a/assemblyai/streaming/v3/models.py
+++ b/assemblyai/streaming/v3/models.py
@@ -128,6 +128,10 @@ class ForceEndpoint(BaseModel):
type: Literal["ForceEndpoint"] = "ForceEndpoint"
+class KeepAlive(BaseModel):
+ type: Literal["KeepAlive"] = "KeepAlive"  # fixed message discriminator; the model has no other fields
+
+
class StreamingSessionParameters(BaseModel):
end_of_turn_confidence_threshold: Optional[float] = None
min_end_of_turn_silence_when_confident: Optional[int] = (
@@ -287,6 +291,7 @@ class UpdateConfiguration(StreamingSessionParameters):
bytes,
TerminateSession,
ForceEndpoint,
+ KeepAlive,
UpdateConfiguration,
]
diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py
index fc84951..a1ad588 100644
--- a/tests/unit/test_streaming.py
+++ b/tests/unit/test_streaming.py
@@ -27,7 +27,10 @@
Word,
)
from assemblyai.streaming.v3._base import _build_uri
-from assemblyai.streaming.v3.models import TerminateSession
+from assemblyai.streaming.v3.models import (
+ KeepAlive,
+ TerminateSession,
+)
def _disable_rw_threads(mocker: MockFixture):
@@ -558,6 +561,33 @@ def mocked_websocket_connect(
assert isinstance(client._write_queue.get(timeout=1), bytes)
+def test_client_keep_alive_enqueues_keep_alive_message(mocker: MockFixture):
+ # Given: a connected client with read/write threads disabled
+ mocker.patch(
+ "assemblyai.streaming.v3.client.websocket_connect",
+ return_value=None,
+ )
+ _disable_rw_threads(mocker)
+ client = StreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.connect(
+ StreamingParameters(
+ sample_rate=16000,
+ speech_model=SpeechModel.universal_streaming_english,
+ )
+ )
+
+ # When: keep_alive is called
+ client.keep_alive()
+
+ # Then: a KeepAlive message is enqueued for the write thread
+ assert client._write_queue.qsize() == 1
+ message = client._write_queue.get(timeout=1)
+ assert isinstance(message, KeepAlive)
+ assert message.type == "KeepAlive"
+
+
def test_client_connect_with_webhook(mocker: MockFixture):
actual_url = None
actual_additional_headers = None
diff --git a/tests/unit/test_streaming_async.py b/tests/unit/test_streaming_async.py
index c06986e..81449f0 100644
--- a/tests/unit/test_streaming_async.py
+++ b/tests/unit/test_streaming_async.py
@@ -901,6 +901,49 @@ async def test_force_endpoint_enqueues_force_endpoint_frame(mocker: MockFixture)
await client.disconnect()
+async def test_keep_alive_before_connect_raises_runtime_error():
+ # Given: a client that has never connected
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+
+ # When/Then: keep_alive raises rather than silently dropping the frame
+ with pytest.raises(RuntimeError, match="not connected"):
+ await client.keep_alive()
+
+
+async def test_keep_alive_sends_keep_alive_frame(mocker: MockFixture):
+ # Given: a connected async client over a fake websocket
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ await client.connect(_default_params())
+
+ # When: keep_alive is called
+ await client.keep_alive()
+
+ # Then: a single KeepAlive frame is written to the websocket
+ for _ in range(100):  # poll for up to ~1s (100 x 10ms); presumably the write happens on another task -- confirm
+ keep_alive_frames = [
+ s for s in fake_ws.sent if isinstance(s, str) and "KeepAlive" in s  # text frames only, not audio bytes
+ ]
+ if keep_alive_frames:
+ break
+ await asyncio.sleep(0.01)
+
+ keep_alive_frames = [  # re-scan so the count below reflects the final contents of fake_ws.sent
+ s for s in fake_ws.sent if isinstance(s, str) and "KeepAlive" in s
+ ]
+ assert len(keep_alive_frames) == 1
+ payload = json.loads(keep_alive_frames[0])
+ assert payload["type"] == "KeepAlive"
+
+ await client.disconnect()
+
+
async def test_warning_event_dispatched_to_handler(mocker: MockFixture):
fake_ws = _FakeAsyncWebSocket()
_patch_connect(mocker, fake_ws)
diff --git a/tests/unit/test_streaming_dual_channel.py b/tests/unit/test_streaming_dual_channel.py
new file mode 100644
index 0000000..06e526e
--- /dev/null
+++ b/tests/unit/test_streaming_dual_channel.py
@@ -0,0 +1,466 @@
+"""Unit tests for client-side dual/multi-channel streaming.
+
+Covers the pure logic (energy VAD, VAD timeline, word/turn attribution,
+unknown-channel resolution, the mono mixer) and the sync/async coordinators
+against a fake client. No network I/O.
+"""
+
+import sys
+from array import array
+
+import pytest
+
+from assemblyai.streaming.v3 import (
+ AsyncChannelStreamer,
+ ChannelAttributionOptions,
+ ChannelStreamer,
+ DualChannelTurnEvent,
+ DualChannelWord,
+ EnergyVad,
+ StreamingEvents,
+)
+from assemblyai.streaming.v3.extras import (
+ UNKNOWN_CHANNEL,
+ VadFrame,
+ VadTimeline,
+ _ChannelMixer,
+ attribute_turn,
+ attribute_word,
+ resolve_unknown_channels_by_speaker_history,
+ resolve_unknown_channels_by_window,
+ roll_up_turn_channel,
+)
+
+SAMPLE_RATE = 16000
+
+
+# --------------------------------------------------------------------------- #
+# helpers
+# --------------------------------------------------------------------------- #
+def _word(start: int, end: int, speaker=None, text="x") -> DualChannelWord:
+ return DualChannelWord(
+ start=start,
+ end=end,
+ confidence=1.0,
+ text=text,
+ word_is_final=True,
+ speaker=speaker,
+ )
+
+
+def _turn(words) -> DualChannelTurnEvent:
+ return DualChannelTurnEvent(
+ type="Turn",
+ turn_order=0,
+ turn_is_formatted=False,
+ end_of_turn=True,
+ transcript=" ".join(w.text for w in words),
+ end_of_turn_confidence=0.9,
+ words=list(words),
+ )
+
+
+def _pcm(value: int, n: int) -> bytes:
+ """``n`` samples of constant int16 ``value`` as little-endian PCM bytes."""
+ samples = array("h", [value] * n)  # "h" = signed short, held in the host's native byte order
+ if sys.byteorder == "big":
+ samples.byteswap()  # big-endian host: swap so tobytes() still yields little-endian
+ return samples.tobytes()
+
+
+def _active_frame(channel: str, ts: float, rms: float = 0.5) -> VadFrame:
+ return VadFrame(ts=ts, channel=channel, active=True, rms=rms)
+
+
+# --------------------------------------------------------------------------- #
+# EnergyVad
+# --------------------------------------------------------------------------- #
+def test_energy_vad_silence_is_inactive():
+ vad = EnergyVad()
+ result = vad.process([0.0] * 320)
+ assert result.active is False
+ assert result.energy == 0.0
+
+
+def test_energy_vad_detects_speech_above_noise():
+ vad = EnergyVad()
+ # First a quiet frame to seat the (low) noise floor, then a loud one.
+ vad.process([0.0] * 320)
+ result = vad.process([0.3] * 320)
+ assert result.active is True
+ assert result.energy == pytest.approx(0.3, rel=1e-6)
+
+
+def test_energy_vad_hangover_keeps_active_after_speech():
+ vad = EnergyVad(hangover_frames=3)
+ vad.process([0.3] * 320) # speech -> sets hangover
+ # Three quiet frames stay active by hangover, the fourth goes inactive.
+ assert vad.process([0.0] * 320).active is True
+ assert vad.process([0.0] * 320).active is True
+ assert vad.process([0.0] * 320).active is True
+ assert vad.process([0.0] * 320).active is False
+
+
+def test_energy_vad_reset():
+ vad = EnergyVad(hangover_frames=5)
+ vad.process([0.5] * 320)
+ vad.reset()
+ # After reset, a silent frame should be inactive (no lingering hangover).
+ assert vad.process([0.0] * 320).active is False
+
+
+# --------------------------------------------------------------------------- #
+# VadTimeline
+# --------------------------------------------------------------------------- #
+def test_timeline_returns_frames_in_window():
+ tl = VadTimeline(window_ms=30_000)
+ for ts in (0, 100, 200, 300, 400):
+ tl.push_frame(_active_frame("mic", ts))
+ frames = tl.frames_in_window(100, 300)
+ assert [f.ts for f in frames] == [100, 200, 300]
+
+
+def test_timeline_prunes_beyond_window():
+ tl = VadTimeline(window_ms=200)
+ for ts in (0, 100, 200, 1000):
+ tl.push_frame(_active_frame("mic", ts))
+ # Frames older than ts(1000) - 200 = 800 are pruned from the active head.
+ assert all(f.ts >= 800 for f in tl.frames_in_window(0, 2000))
+
+
+# --------------------------------------------------------------------------- #
+# word / turn attribution
+# --------------------------------------------------------------------------- #
+def test_attribute_word_unknown_when_no_energy():
+ tl = VadTimeline(30_000)
+ assert attribute_word(_word(0, 100), tl, dominance_ratio=4.0) == UNKNOWN_CHANNEL
+
+
+def test_attribute_word_single_channel_wins():
+ tl = VadTimeline(30_000)
+ for ts in (10, 30, 50):
+ tl.push_frame(_active_frame("mic", ts))
+ assert attribute_word(_word(0, 100), tl, dominance_ratio=4.0) == "mic"
+
+
+def test_attribute_word_dominant_channel_by_ratio():
+ tl = VadTimeline(30_000)
+ tl.push_frame(_active_frame("mic", 10, rms=1.0))
+ tl.push_frame(_active_frame("mic", 20, rms=1.0))
+ tl.push_frame(_active_frame("system", 30, rms=0.1))
+ # mic total 2.0 vs system 0.1 -> beats 4x ratio.
+ assert attribute_word(_word(0, 100), tl, dominance_ratio=4.0) == "mic"
+
+
+def test_attribute_word_close_scores_are_unknown():
+ tl = VadTimeline(30_000)
+ tl.push_frame(_active_frame("mic", 10, rms=1.0))
+ tl.push_frame(_active_frame("system", 20, rms=0.9))
+ # Too close to clear the 4x ratio -> unknown (resolution may back-fill).
+ assert attribute_word(_word(0, 100), tl, dominance_ratio=4.0) == UNKNOWN_CHANNEL
+
+
+def test_attribute_word_dominance_ratio_is_a_real_knob():
+ tl = VadTimeline(30_000)
+ tl.push_frame(_active_frame("mic", 10, rms=1.0))
+ tl.push_frame(_active_frame("system", 20, rms=0.9))
+ # Same frames, looser ratio -> mic now clears the bar and wins.
+ assert attribute_word(_word(0, 100), tl, dominance_ratio=1.1) == "mic"
+
+
+def test_attribute_word_exact_tie_is_unknown():
+ tl = VadTimeline(30_000)
+ tl.push_frame(_active_frame("mic", 10, rms=0.5))
+ tl.push_frame(_active_frame("system", 20, rms=0.5))
+ assert attribute_word(_word(0, 100), tl, dominance_ratio=4.0) == UNKNOWN_CHANNEL
+
+
+def test_roll_up_turn_channel_duration_weighted():
+ words = [_word(0, 1000), _word(1000, 1100)]
+ words[0].channel = "mic"
+ words[1].channel = "system"
+ # mic spans 1000ms vs system 100ms.
+ assert roll_up_turn_channel(words) == "mic"
+
+
+def test_roll_up_turn_channel_tie_is_unknown():
+ words = [_word(0, 100), _word(100, 200)]
+ words[0].channel = "mic"
+ words[1].channel = "system"
+ assert roll_up_turn_channel(words) == UNKNOWN_CHANNEL
+
+
+def test_attribute_turn_sets_word_and_turn_channels():
+ tl = VadTimeline(30_000)
+ for ts in (10, 30, 50, 110, 130):
+ tl.push_frame(_active_frame("system", ts))
+ turn = _turn([_word(0, 100), _word(100, 200)])
+ attribute_turn(turn, tl, dominance_ratio=4.0)
+ assert turn.words[0].channel == "system"
+ assert turn.words[1].channel == "system"
+ assert turn.channel == "system"
+
+
+# --------------------------------------------------------------------------- #
+# unknown-channel resolution
+# --------------------------------------------------------------------------- #
+def test_resolve_unknown_by_window_fills_from_neighbors():
+ turn = _turn([_word(0, 50), _word(60, 110), _word(120, 170)])
+ turn.words[0].channel = "mic"
+ turn.words[1].channel = UNKNOWN_CHANNEL
+ turn.words[2].channel = "mic"
+ resolve_unknown_channels_by_window(turn, resolution_window_words=2)
+ assert turn.words[1].channel == "mic"
+ assert turn.words[1].channel_resolved is True
+ assert turn.channel == "mic"
+
+
+def test_resolve_unknown_by_window_leaves_isolated_unknown():
+ turn = _turn([_word(0, 50)])
+ turn.words[0].channel = UNKNOWN_CHANNEL
+ resolve_unknown_channels_by_window(turn, resolution_window_words=2)
+ assert turn.words[0].channel == UNKNOWN_CHANNEL
+
+
+def test_resolve_unknown_by_speaker_history():
+ tl = VadTimeline(30_000)
+ # Speaker A clearly on mic during 0-100ms.
+ for ts in (10, 30, 50, 70):
+ tl.push_frame(_active_frame("mic", ts, rms=1.0))
+ history = {}
+ # First turn: word resolves to mic via VAD, building speaker A's history.
+ turn1 = _turn([_word(0, 100, speaker="A")])
+ attribute_turn(turn1, tl, dominance_ratio=4.0)
+ resolve_unknown_channels_by_speaker_history(
+ turn1, tl, history, min_rms_evidence=0.5, dominance_ratio=3.0
+ )
+ assert turn1.words[0].channel == "mic"
+
+ # Second turn: same speaker, but no VAD frames in the word window -> unknown,
+ # then back-filled from accumulated speaker history.
+ turn2 = _turn([_word(5000, 5100, speaker="A")])
+ attribute_turn(turn2, tl, dominance_ratio=4.0)
+ assert turn2.words[0].channel == UNKNOWN_CHANNEL
+ resolve_unknown_channels_by_speaker_history(
+ turn2, tl, history, min_rms_evidence=0.5, dominance_ratio=3.0
+ )
+ assert turn2.words[0].channel == "mic"
+ assert turn2.words[0].channel_resolved is True
+
+
+# --------------------------------------------------------------------------- #
+# mixer
+# --------------------------------------------------------------------------- #
+def _mixer(channels=("mic", "system")) -> _ChannelMixer:
+ return _ChannelMixer(list(channels), SAMPLE_RATE, ChannelAttributionOptions())
+
+
+def test_mixer_averages_channels_to_mono():
+    mixer = _mixer()
+    n = 800  # 50ms @ 16kHz -> meets the min-chunk floor
+    mixer.ingest("mic", _pcm(1000, n))
+    mixer.ingest("system", _pcm(2000, n))
+    chunks = mixer.drain()
+    assert len(chunks) == 1
+    # Compare raw bytes against ``_pcm`` (explicitly little-endian) instead of
+    # decoding with a native-endian ``array``: the check then also holds on
+    # big-endian hosts, and equality pins both the length and every sample.
+    assert chunks[0] == _pcm(1500, n)  # round((1000 + 2000) / 2)
+
+
+def test_mixer_gates_on_shorter_buffer():
+ mixer = _mixer()
+ mixer.ingest("mic", _pcm(1000, 1600)) # 100ms
+ mixer.ingest("system", _pcm(2000, 800)) # 50ms
+ chunks = mixer.drain()
+ # Only the 800 overlapping samples can be mixed; the rest stays buffered.
+ assert sum(len(b) // 2 for b in chunks) == 800
+
+
+def test_mixer_min_chunk_floor_holds_small_buffers():
+ mixer = _mixer()
+ mixer.ingest("mic", _pcm(1000, 100)) # < 50ms
+ mixer.ingest("system", _pcm(2000, 100))
+ assert mixer.drain() == [] # below the floor
+ # force=True bypasses the floor (final flush).
+ assert len(mixer.drain(force=True)) == 1
+
+
+def test_mixer_caps_chunks_at_max_duration():
+ mixer = _mixer()
+ n = 16000 # 1 second -> must split into >=5 chunks of <=200ms (3200 samples)
+ mixer.ingest("mic", _pcm(1000, n))
+ mixer.ingest("system", _pcm(1000, n))
+ chunks = mixer.drain()
+ assert len(chunks) >= 5
+ assert all(len(b) // 2 <= 3200 for b in chunks)
+ assert sum(len(b) // 2 for b in chunks) == n
+
+
+def test_mixer_rejects_odd_length_pcm():
+ mixer = _mixer()
+ with pytest.raises(ValueError):
+ mixer.ingest("mic", b"\x00\x01\x02")
+
+
+def test_mixer_force_flush_pads_missing_channel():
+ # Regression: a channel that produced nothing must not gate the flush and
+ # swallow the other channel's audio.
+ mixer = _mixer()
+ mixer.ingest("mic", _pcm(1000, 800)) # system gets nothing
+ chunks = mixer.drain(force=True)
+ assert sum(len(b) // 2 for b in chunks) == 800 # mic's audio still sent
+
+
+def test_mixer_close_channel_lets_survivor_drain():
+ mixer = _mixer()
+ mixer.ingest("mic", _pcm(1000, 1600)) # 100ms
+ mixer.ingest("system", _pcm(2000, 800)) # 50ms
+ assert sum(len(b) // 2 for b in mixer.drain()) == 800 # aligned overlap only
+ # system's source ends -> the remaining mic audio drains without waiting.
+ mixer.close_channel("system")
+ assert sum(len(b) // 2 for b in mixer.drain()) == 800
+
+
+# --------------------------------------------------------------------------- #
+# coordinators
+# --------------------------------------------------------------------------- #
+class _FakeSyncClient:
+ def __init__(self):
+ self._handlers = {event: [] for event in StreamingEvents}  # one handler list per event type
+ self.sent = []  # every payload passed to stream(), in call order
+
+ def on(self, event, handler):
+ self._handlers[event].append(handler)
+
+ def stream(self, data):
+ self.sent.append(data)  # record instead of sending; tests inspect ``sent``
+
+ def dispatch_turn(self, turn):  # test hook: call the registered Turn handlers with ``turn``
+ for handler in self._handlers[StreamingEvents.Turn]:
+ handler(self, turn)
+
+
+class _FakeAsyncClient(_FakeSyncClient):  # same fake, but ``stream`` is a coroutine so callers can await it
+ async def stream(self, data): # type: ignore[override]
+ self.sent.append(data)
+
+
+def test_channel_streamer_validates_channels():
+ client = _FakeSyncClient()
+ with pytest.raises(ValueError):
+ ChannelStreamer(client, channels=["mic"], sample_rate=SAMPLE_RATE)
+ with pytest.raises(ValueError):
+ ChannelStreamer(client, channels=["mic", "mic"], sample_rate=SAMPLE_RATE)
+
+
+def test_channel_streamer_unknown_channel_raises():
+ client = _FakeSyncClient()
+ mixer = ChannelStreamer(client, ["mic", "system"], sample_rate=SAMPLE_RATE)
+ with pytest.raises(ValueError):
+ mixer.stream("speaker", _pcm(1, 800))
+
+
+def test_invalid_resolve_method_raises_at_construction():
+ # Must fail fast at construction, not silently inside the swallowed handler.
+ with pytest.raises(ValueError):
+ ChannelAttributionOptions(resolve_unknown_channels_method="bogus")
+
+
+def test_channel_streamer_close_channel_drains_survivor():
+ client = _FakeSyncClient()
+ mixer = ChannelStreamer(client, ["mic", "system"], sample_rate=SAMPLE_RATE)
+ mixer.stream("mic", _pcm(1000, 1600)) # 100ms
+ mixer.stream("system", _pcm(2000, 800)) # 50ms
+ before = sum(len(b) // 2 for b in client.sent)
+ mixer.close_channel("system") # system's source ended
+ after = sum(len(b) // 2 for b in client.sent)
+ assert after > before # surviving mic audio flushed instead of stalling
+
+
+def test_channel_streamer_sends_mixed_mono():
+    client = _FakeSyncClient()
+    mixer = ChannelStreamer(client, ["mic", "system"], sample_rate=SAMPLE_RATE)
+    mixer.stream("mic", _pcm(1000, 800))
+    mixer.stream("system", _pcm(2000, 800))
+    assert client.sent, "expected a mixed mono chunk to be sent"
+    # Byte-compare with ``_pcm`` (little-endian) so the check is host-endian-safe.
+    mixed = b"".join(client.sent)
+    assert mixed == _pcm(1500, len(mixed) // 2)
+
+
+def test_channel_streamer_enriches_turn_on_mixer_handler():
+ client = _FakeSyncClient()
+ mixer = ChannelStreamer(client, ["mic", "system"], sample_rate=SAMPLE_RATE)
+
+ seen = {}
+
+ def on_turn(_client, turn):
+ # Handler receives the enriched DualChannelTurnEvent.
+ seen["type"] = type(turn).__name__
+ seen["channel"] = turn.channel
+ seen["word_channel"] = turn.words[0].channel
+
+ mixer.on(StreamingEvents.Turn, on_turn)
+
+ # Drive enough mic audio to register active VAD frames in the word window.
+ mixer.stream("mic", _pcm(8000, 1600))
+ mixer.stream("system", _pcm(0, 1600))
+
+ # The client emits a base TurnEvent; the mixer enriches + dispatches it.
+ client.dispatch_turn(_turn([_word(0, 80)]))
+ assert seen["type"] == "DualChannelTurnEvent"
+ assert seen["channel"] == "mic"
+ assert seen["word_channel"] == "mic"
+
+
+def test_channel_streamer_none_method_leaves_unknown():
+ client = _FakeSyncClient()
+ opts = ChannelAttributionOptions(resolve_unknown_channels_method="none")
+ mixer = ChannelStreamer(
+ client, ["mic", "system"], sample_rate=SAMPLE_RATE, attribution=opts
+ )
+ captured = {}
+ mixer.on(StreamingEvents.Turn, lambda c, t: captured.update(ch=t.words[0].channel))
+ # No audio ingested -> no VAD -> word is unknown and stays unknown.
+ client.dispatch_turn(_turn([_word(0, 80)]))
+ assert captured["ch"] == UNKNOWN_CHANNEL
+
+
+def test_channel_streamer_forwards_non_turn_events_to_client():
+ client = _FakeSyncClient()
+ mixer = ChannelStreamer(client, ["mic", "system"], sample_rate=SAMPLE_RATE)
+
+ def on_error(_c, _e):
+ pass
+
+ mixer.on(StreamingEvents.Error, on_error)
+ # Non-Turn handlers are registered straight through on the underlying client.
+ assert on_error in client._handlers[StreamingEvents.Error]
+
+
+def test_channel_streamer_invokes_vad_callback():
+ client = _FakeSyncClient()
+ frames = []
+ mixer = ChannelStreamer(
+ client,
+ ["mic", "system"],
+ sample_rate=SAMPLE_RATE,
+ on_vad=frames.append,
+ )
+ mixer.stream("mic", _pcm(5000, 320)) # exactly one 20ms VAD frame
+ assert len(frames) == 1
+ assert frames[0].channel == "mic"
+
+
+@pytest.mark.asyncio
+async def test_async_channel_streamer_sends_mixed_mono():
+    client = _FakeAsyncClient()
+    mixer = AsyncChannelStreamer(client, ["mic", "system"], sample_rate=SAMPLE_RATE)
+    await mixer.stream("mic", _pcm(1000, 800))
+    await mixer.stream("system", _pcm(2000, 800))
+    await mixer.flush()
+    assert client.sent
+    # Byte-compare with ``_pcm`` (little-endian) so the check is host-endian-safe.
+    mixed = b"".join(client.sent)
+    assert mixed == _pcm(1500, len(mixed) // 2)