diff --git a/README.md b/README.md index b148540..d8c6d91 100644 --- a/README.md +++ b/README.md @@ -785,6 +785,70 @@ finally: +
+ Dual-channel: mic + system audio in one session + +For note-taker apps that capture two live sources (microphone **and** system/speaker output) but want them handled as **one** streaming session — while still knowing which source each word came from — wrap the client in a `ChannelStreamer`. + +You declare named channels and feed each channel's PCM separately. The SDK runs per-channel energy VAD, mixes the channels into a single mono stream over one websocket, and — for handlers registered on the coordinator — delivers an enriched `DualChannelTurnEvent` whose words/turn carry their originating channel (`turn.channel` and per-word `word.channel`). The base `Word` / `TurnEvent` stay unchanged, so single-stream payloads aren't affected. Attribution is fully client-side and model-agnostic, so it composes with `speaker_labels`, multilingual, and `u3-rt-pro`. It is a **separate dimension from diarization** — `word.channel` (physical source) is independent of `word.speaker` (voice): two people on the same `system` channel get distinct speaker labels, while one person heard on two channels keeps a single speaker label. + +Unlike a browser sample, the SDK does not capture audio — you supply 16-bit PCM for each channel (from `sounddevice`, `pyaudio`, a loopback device, files, …). + +```python +from assemblyai.streaming.v3 import ( + ChannelStreamer, StreamingClient, StreamingClientOptions, + StreamingEvents, StreamingParameters, +) + +def on_turn(client, event): # event is a DualChannelTurnEvent + print(f"[{event.channel}] {event.transcript}") + for w in event.words: + print(f" {w.text!r} -> channel={w.channel} speaker={w.speaker}") + +client = StreamingClient(StreamingClientOptions(api_key="")) + +# Declare the channels and the session sample rate (must be pcm_s16le). +mixer = ChannelStreamer(client, channels=["mic", "system"], sample_rate=16000) +# Register handlers on the mixer: Turn handlers receive the enriched event, +# other events (Begin/Error/…) are forwarded to the client. +mixer.on(StreamingEvents.Turn, on_turn) +client.connect(StreamingParameters( + sample_rate=16000, speech_model="u3-rt-pro", speaker_labels=True, +)) + +# Feed each source separately — e.g. from two capture callbacks. Send +# continuous PCM for every channel (silence as zeros), at the same rate. +mixer.stream("mic", mic_pcm) +mixer.stream("system", system_pcm) + +mixer.flush() # push trailing buffered audio +client.disconnect(terminate=True) +``` + +`AsyncChannelStreamer` is the asyncio-native equivalent (`await mixer.stream(...)` / `await mixer.close_channel(...)` / `await mixer.flush()`); register handlers the same way with `mixer.on(...)`. + +**Sources that end mid-session.** Mixing keeps channels aligned by consuming the shortest buffer, so it assumes every channel keeps delivering PCM (send silence as zeros, don't omit it). When a source genuinely ends (file EOF, screen share stopped, device removed), call `mixer.close_channel(name)` so the session degrades to the surviving channel(s) instead of stalling — the ended channel is then padded with silence. + +**Swappable VAD.** The default detector is the built-in energy-based `EnergyVad`. Supply your own (e.g. a DNN VAD such as Silero) via `ChannelAttributionOptions.create_vad`, which is called once per channel with the channel name; subclass `VadDetector` (`process(frame) -> VadResult`, `reset()`). Pass `on_vad=callback` to observe raw per-frame activity (e.g. a live "who's talking" meter). Tune the default with `EnergyVad(threshold_ratio=3.0, noise_floor_alpha=0.05, hangover_frames=10)` — `threshold_ratio` below ~2 is too sensitive, above ~6 misses quiet onsets/offsets. + +**Resolving unknown channels.** A word is `"unknown"` when no channel was clearly dominant in its window — silence, or two channels too close to call (the top must beat the runner-up by `dominance_ratio`, default 4). `ChannelAttributionOptions.resolve_unknown_channels_method` back-fills these: + +- `"window"` (default) — from the dominant non-`"unknown"` channel among ±`resolution_window_words` neighbor words. +- `"speaker-history"` — from the speaker's session-wide channel evidence (requires `speaker_labels`). +- `"none"` — leave `"unknown"` as-is. + +Back-filled words are flagged `word.channel_resolved = True`; confident per-word decisions are never overwritten. The method is validated at construction, so a typo raises immediately rather than silently disabling resolution. + +**Caveats.** + +- Requires 16-bit PCM (`pcm_s16le`, the default) — linear mixing is invalid for `pcm_mulaw`. +- Capturing the system/speaker output is platform-specific: macOS needs a loopback driver (e.g. BlackHole); Windows uses WASAPI loopback; Linux a PulseAudio/PipeWire monitor source. +- If the mic physically picks up the speakers, that bleed can pull attribution toward `mic`. Apply acoustic echo cancellation at capture (`getUserMedia({ audio: { echoCancellation: true } })` in browser front-ends, or an AEC-capable native path) — the SDK only receives already-captured PCM, so it can't apply AEC itself. Transcription quality is unaffected; only the `channel` field. + +See [`examples/streaming_dual_channel.py`](./examples/streaming_dual_channel.py) for a complete runnable demo. + +
+
Stream a local file (async) diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index 0baa0e8..bffaa07 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.64.20" +__version__ = "0.64.21" diff --git a/assemblyai/streaming/v3/__init__.py b/assemblyai/streaming/v3/__init__.py index bd9ca87..f72aa0f 100644 --- a/assemblyai/streaming/v3/__init__.py +++ b/assemblyai/streaming/v3/__init__.py @@ -1,5 +1,17 @@ from .async_client import AsyncStreamingClient from .client import StreamingClient +from .extras import ( + AsyncChannelStreamer, + ChannelAttributionOptions, + ChannelStreamer, + DualChannelTurnEvent, + DualChannelWord, + EnergyVad, + VadDetector, + VadFrame, + VadResult, + attribute_turn, +) from .models import ( BeginEvent, Encoding, @@ -27,8 +39,14 @@ ) __all__ = [ + "AsyncChannelStreamer", "AsyncStreamingClient", "BeginEvent", + "ChannelAttributionOptions", + "ChannelStreamer", + "DualChannelTurnEvent", + "DualChannelWord", + "EnergyVad", "Encoding", "EventMessage", "LLMGatewayResponseEvent", @@ -50,6 +68,10 @@ "StreamingSessionParameters", "TerminationEvent", "TurnEvent", + "VadDetector", + "VadFrame", + "VadResult", "WarningEvent", "Word", + "attribute_turn", ] diff --git a/assemblyai/streaming/v3/async_client.py b/assemblyai/streaming/v3/async_client.py index f04c9b6..03f3ef0 100644 --- a/assemblyai/streaming/v3/async_client.py +++ b/assemblyai/streaming/v3/async_client.py @@ -37,6 +37,7 @@ ErrorEvent, EventMessage, ForceEndpoint, + KeepAlive, OperationMessage, StreamingClientOptions, StreamingError, @@ -80,10 +81,10 @@ class AsyncStreamingClient(_BaseStreamingClient): Behavioral notes vs. the sync ``StreamingClient``: - - ``stream`` / ``set_params`` / ``force_endpoint`` raise ``RuntimeError`` - when called before ``connect()`` — silent drop would diverge from the - sync client (which buffers pre-connect data) in a way that's easy to - miss. After the connection has closed, the same calls are silent + - ``stream`` / ``set_params`` / ``force_endpoint`` / ``keep_alive`` raise + ``RuntimeError`` when called before ``connect()`` — silent drop would + diverge from the sync client (which buffers pre-connect data) in a way + that's easy to miss. After the connection has closed, the same calls are silent no-ops so cleanup paths don't need defensive try/except. - ``disconnect(terminate=True)`` waits at most 2.0s for the write task to drain the ``TerminateSession`` frame before forcing teardown. The sync @@ -288,6 +289,12 @@ async def force_endpoint(self) -> None: return await write_queue.put(ForceEndpoint()) + async def keep_alive(self) -> None: + write_queue, stop_event = self._ensure_connected("keep_alive") + if stop_event.is_set(): + return + await write_queue.put(KeepAlive()) + def _ensure_connected( self, method: str ) -> "tuple[asyncio.Queue[OperationMessage], asyncio.Event]": diff --git a/assemblyai/streaming/v3/client.py b/assemblyai/streaming/v3/client.py index 8dc5084..64740c3 100644 --- a/assemblyai/streaming/v3/client.py +++ b/assemblyai/streaming/v3/client.py @@ -24,6 +24,7 @@ ErrorEvent, EventMessage, ForceEndpoint, + KeepAlive, OperationMessage, StreamingClientOptions, StreamingError, @@ -194,6 +195,10 @@ def force_endpoint(self): message = ForceEndpoint() self._write_queue.put(message) + def keep_alive(self): + message = KeepAlive() + self._write_queue.put(message) + def _write_message(self) -> None: while True: if not self._websocket: diff --git a/assemblyai/streaming/v3/extras.py b/assemblyai/streaming/v3/extras.py new file mode 100644 index 0000000..f29006c --- /dev/null +++ b/assemblyai/streaming/v3/extras.py @@ -0,0 +1,742 @@ +"""Client-side dual / multi-channel support for streaming v3. + +Note-taker use case: capture two live sources (microphone + system audio) as +one streaming session while still knowing which physical source each word came +from. + +- Each named channel's PCM is fed in separately via ``ChannelStreamer.stream``. +- Per-channel energy VAD records which channel was acoustically active when. +- The channels are summed into one mono stream sent over the existing single + websocket session. +- On every ``Turn``, words are attributed back to a channel by matching their + server timestamps against the per-channel VAD timeline; the enriched + ``TurnEvent`` (``turn.channel`` + per-word ``word.channel``) is delivered to + the handler registered on the coordinator. + +Attribution is purely client-side, so any ``speech_model`` works unchanged and +channel (physical source) stays independent of diarization ``speaker``. +""" + +import inspect +import logging +import math +import sys +import threading +from array import array +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Union, +) + +from .models import StreamingEvents, TurnEvent, Word + +if TYPE_CHECKING: # avoid an import cycle; only used for type hints + from .async_client import AsyncStreamingClient + from .client import StreamingClient + +logger = logging.getLogger(__name__) + +UNKNOWN_CHANNEL = "unknown" + +# 20 ms VAD frames, and the server's per-message audio duration limits. +_VAD_FRAME_MS = 20 +_MIN_CHUNK_MS = 50 +_MAX_CHUNK_MS = 200 + + +# Enriched event types subclass the base streaming models and add the +# client-side channel fields, so the base ``Word`` / ``TurnEvent`` payloads stay +# clean for single-stream users. A coordinator builds these from the base events +# and delivers them to handlers registered on the coordinator. +class DualChannelWord(Word): + """A ``Word`` enriched with channel attribution (independent of ``speaker``).""" + + # Physical input channel (e.g. "mic" / "system"), or "unknown" if no channel + # was clearly dominant during the word's window. + channel: Optional[str] = None + # True when ``channel`` was inferred by an unknown-resolution strategy rather + # than measured directly by VAD. + channel_resolved: Optional[bool] = None + + +class DualChannelTurnEvent(TurnEvent): + """A ``TurnEvent`` enriched with per-word and turn-level channel attribution.""" + + words: List[DualChannelWord] = [] + # Duration-weighted majority channel across ``words``, or "unknown". + channel: Optional[str] = None + + +@dataclass +class VadResult: + """One frame's voice-activity decision: whether speech is present and the + frame's RMS energy (used to weight per-channel attribution).""" + + active: bool + energy: float + + +@dataclass +class VadFrame: + """A per-channel, per-frame VAD observation. + + ``ts`` is stream-relative milliseconds from the channel's own sample counter + — the same reference frame as ``Word.start`` / ``Word.end``. + """ + + ts: float + channel: str + active: bool + rms: float + + +class VadDetector: + """Pluggable per-channel voice-activity detector. + + A separate instance is held per channel. ``process`` receives a fixed-size + frame of float samples in ``[-1.0, 1.0]`` at the session's sample rate. + Subclass / duck-type this to drop in a DNN-backed detector for noisy + environments. + """ + + def process(self, frame: Sequence[float]) -> VadResult: # pragma: no cover + raise NotImplementedError + + def reset(self) -> None: # pragma: no cover + raise NotImplementedError + + +class EnergyVad(VadDetector): + """Energy-based VAD with adaptive noise-floor tracking and hangover. + + Pure Python, no dependencies. Suited to "which physical channel is speaking" + since the channels are already physically separated at capture; hand the + harder speech-vs-noise problem to a custom ``VadDetector`` via + ``ChannelAttributionOptions.create_vad``. + + Tuning: ``threshold_ratio`` below 2 is over-sensitive, above 6 misses quiet + onsets/offsets; ``noise_floor_alpha`` above 0.1 adapts to non-stationary + background faster but risks creeping up onto a sustained quiet voice. + """ + + def __init__( + self, + threshold_ratio: float = 3.0, + noise_floor_alpha: float = 0.05, + hangover_frames: int = 10, + initial_noise_floor: float = 1e-4, + ): + self._threshold_ratio = threshold_ratio + self._noise_floor_alpha = noise_floor_alpha + self._hangover_frames = hangover_frames + self._initial_noise_floor = initial_noise_floor + self._noise_floor = initial_noise_floor + self._hangover_remaining = 0 + + def process(self, frame: Sequence[float]) -> VadResult: + n = len(frame) + sum_sq = 0.0 + for s in frame: + sum_sq += s * s + rms = math.sqrt(sum_sq / n) if n > 0 else 0.0 + + threshold = self._noise_floor * self._threshold_ratio + active = rms > threshold + + if active: + self._hangover_remaining = self._hangover_frames + elif self._hangover_remaining > 0: + self._hangover_remaining -= 1 + active = True + # In hangover, don't update the floor — RMS may still be tail energy. + else: + self._noise_floor = ( + self._noise_floor * (1 - self._noise_floor_alpha) + + rms * self._noise_floor_alpha + ) + + return VadResult(active=active, energy=rms) + + def reset(self) -> None: + self._noise_floor = self._initial_noise_floor + self._hangover_remaining = 0 + + +class VadTimeline: + """Append-only ring buffer of ``VadFrame``s in stream-relative ms order. + + ``push_frame`` is amortized O(1); ``frames_in_window`` is O(n) over kept + frames, fine for the per-word lookups done here. + """ + + def __init__(self, window_ms: int): + self._window_ms = window_ms + self._frames: List[VadFrame] = [] + self._head = 0 + # The threaded ``StreamingClient`` runs ``frames_in_window`` on the read + # thread while the user thread runs ``push_frame``; compaction swaps + # ``_frames`` / ``_head`` non-atomically. The lock keeps push/read/compact + # mutually exclusive. Uncontended on the async / single-threaded paths. + self._lock = threading.Lock() + + def push_frame(self, frame: VadFrame) -> None: + with self._lock: + self._frames.append(frame) + cutoff = frame.ts - self._window_ms + while ( + self._head < len(self._frames) and self._frames[self._head].ts < cutoff + ): + self._head += 1 + # Compact occasionally so the list doesn't grow without bound. + if self._head > 1024 and self._head * 2 > len(self._frames): + self._frames = self._frames[self._head :] + self._head = 0 + + def frames_in_window(self, start_ms: float, end_ms: float) -> List[VadFrame]: + out: List[VadFrame] = [] + with self._lock: + for i in range(self._head, len(self._frames)): + f = self._frames[i] + if f.ts < start_ms: + continue + if f.ts > end_ms: + break + out.append(f) + return out + + def clear(self) -> None: + with self._lock: + self._frames = [] + self._head = 0 + + +def _score_channels(frames: Iterable[VadFrame]) -> Dict[str, float]: + """Sum active-frame RMS per channel. Channels with no active energy are + omitted from the result.""" + scores: Dict[str, float] = {} + for f in frames: + if not f.active: + continue + scores[f.channel] = scores.get(f.channel, 0.0) + f.rms + return scores + + +def _top_by_ratio(scores: Dict[str, float], dominance_ratio: float) -> Optional[str]: + """Winner of a per-channel score map: the sole channel if only one had + energy, else the top channel iff it beats the runner-up by + ``dominance_ratio``. ``None`` when there's no clear winner (tie / too close) + or no scores. + + The ratio is a real knob — raising it yields more ``None`` (i.e. "unknown", + which a resolution strategy may back-fill). This diverges from the Node + reference, whose absolute-winner fallback makes the equivalent ratio a no-op. + """ + if not scores: + return None + ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True) + if len(ranked) == 1: + return ranked[0][0] + (top_name, top_score), (_, runner_score) = ranked[0], ranked[1] + if runner_score <= 0 or top_score >= dominance_ratio * runner_score: + return top_name + return None + + +def attribute_word( + word: Word, + timeline: VadTimeline, + dominance_ratio: float, +) -> str: + """Channel that dominated the word's ``[start, end]`` window, or "unknown" + if no channel had energy there or none was clearly dominant.""" + scores = _score_channels(timeline.frames_in_window(word.start, word.end)) + winner = _top_by_ratio(scores, dominance_ratio) + return winner if winner is not None else UNKNOWN_CHANNEL + + +def roll_up_turn_channel(words: Sequence[DualChannelWord]) -> str: + """Duration-weighted majority of per-word channels. "unknown" if there are + no resolved words or two channels tie exactly.""" + totals: Dict[str, float] = {} + for w in words: + if not w.channel or w.channel == UNKNOWN_CHANNEL: + continue + dur = max(0, w.end - w.start) + totals[w.channel] = totals.get(w.channel, 0.0) + dur + if not totals: + return UNKNOWN_CHANNEL + ranked = sorted(totals.items(), key=lambda kv: kv[1], reverse=True) + if len(ranked) == 1: + return ranked[0][0] + (top_name, top_ms), (_, runner_ms) = ranked[0], ranked[1] + if top_ms == runner_ms: + return UNKNOWN_CHANNEL + return top_name + + +def attribute_turn( + turn: DualChannelTurnEvent, + timeline: VadTimeline, + dominance_ratio: float, +) -> None: + """Write ``turn.words[i].channel`` for every word and set ``turn.channel`` + to the duration-weighted rollup. Mutates the enriched ``turn`` in place.""" + for w in turn.words: + w.channel = attribute_word(w, timeline, dominance_ratio) + turn.channel = roll_up_turn_channel(turn.words) + + +def resolve_unknown_channels_by_window( + turn: DualChannelTurnEvent, + resolution_window_words: int, +) -> None: + """Back-fill "unknown" words from the dominant non-"unknown" channel among + +/-N neighbor words in the same turn. Words with no resolved neighbors stay + "unknown"; confident decisions are never modified; resolved words are tagged + ``channel_resolved = True``.""" + words = turn.words + mutated = False + for i, target in enumerate(words): + if target.channel != UNKNOWN_CHANNEL: + continue + tally: Dict[str, int] = {} + lo = max(0, i - resolution_window_words) + hi = min(len(words) - 1, i + resolution_window_words) + for j in range(lo, hi + 1): + if j == i: + continue + ch = words[j].channel + if not ch or ch == UNKNOWN_CHANNEL: + continue + tally[ch] = tally.get(ch, 0) + 1 + if not tally: + continue + top: Optional[str] = None + top_count = 0 + tied = False + for name, count in tally.items(): + if count > top_count: + top, top_count, tied = name, count, False + elif count == top_count: + tied = True + if top is not None and not tied: + target.channel = top + target.channel_resolved = True + mutated = True + if mutated: + turn.channel = roll_up_turn_channel(words) + + +def resolve_unknown_channels_by_speaker_history( + turn: DualChannelTurnEvent, + timeline: VadTimeline, + speaker_history: Dict[str, Dict[str, float]], + min_rms_evidence: float, + dominance_ratio: float, +) -> None: + """Back-fill "unknown" words using each speaker's session-wide channel + evidence: per speaker, sum active VAD-frame RMS per channel across all their + words so far (accumulated into ``speaker_history``). A speaker is resolvable + once their total evidence clears ``min_rms_evidence`` and their top channel + beats the runner-up by ``dominance_ratio``. Only touches "unknown" words; + ``speaker`` is never modified.""" + # 1. Accumulate evidence from this turn's words. + for w in turn.words: + if not w.speaker: + continue + entry = speaker_history.setdefault(w.speaker, {}) + for f in timeline.frames_in_window(w.start, w.end): + if not f.active: + continue + entry[f.channel] = entry.get(f.channel, 0.0) + f.rms + + # 2. Fill unknown words whose speakers have dominant evidence. + mutated = False + for w in turn.words: + if w.channel != UNKNOWN_CHANNEL or not w.speaker: + continue + entry_or_none = speaker_history.get(w.speaker) + if not entry_or_none or sum(entry_or_none.values()) < min_rms_evidence: + continue + winner = _top_by_ratio(entry_or_none, dominance_ratio) + if winner is not None: + w.channel = winner + w.channel_resolved = True + mutated = True + if mutated: + turn.channel = roll_up_turn_channel(turn.words) + + +RESOLVE_UNKNOWN_METHODS = ("none", "window", "speaker-history") + + +@dataclass +class ChannelAttributionOptions: + """Tuning for client-side channel attribution. All fields have sane defaults + matching the Node SDK; override only when needed.""" + + # Per-word energy ratio above which a channel is declared dominant. + dominance_ratio: float = 4.0 + # How far back the VAD timeline retains frames for per-word lookups. + timeline_window_ms: int = 30_000 + # Factory for the per-channel detector; called once per channel with its + # name. Defaults to a fresh ``EnergyVad`` per channel. + create_vad: Optional[Callable[[str], VadDetector]] = None + # How to fill words VAD couldn't attribute: "none" | "window" | + # "speaker-history". + resolve_unknown_channels_method: str = "window" + # "window": +/-N neighbor words consulted to fill an unknown word. + resolution_window_words: int = 2 + # "speaker-history": minimum cumulative active-RMS before a speaker's + # channel is considered established. + speaker_history_min_rms_evidence: float = 0.5 + # "speaker-history": top channel must beat the runner-up by this ratio. + speaker_history_dominance_ratio: float = 3.0 + + def __post_init__(self) -> None: + # Fail fast at construction; an invalid method would otherwise be + # swallowed inside the Turn handler, silently disabling resolution. + if self.resolve_unknown_channels_method not in RESOLVE_UNKNOWN_METHODS: + raise ValueError( + "resolve_unknown_channels_method must be one of " + f"{RESOLVE_UNKNOWN_METHODS}; got " + f"{self.resolve_unknown_channels_method!r}." + ) + + def _make_vad(self, channel: str) -> VadDetector: + if self.create_vad is not None: + return self.create_vad(channel) + return EnergyVad() + + +_BIG_ENDIAN = sys.byteorder == "big" + + +def _pcm16_to_array(data: bytes) -> array: + """Parse little-endian 16-bit PCM bytes into a signed-short ``array``.""" + if len(data) % 2 != 0: + raise ValueError( + f"PCM data length must be even (16-bit samples); got {len(data)} bytes." + ) + samples = array("h") + samples.frombytes(bytes(data)) + if _BIG_ENDIAN: + samples.byteswap() # interpret the bytes as little-endian + return samples + + +def _array_to_pcm16(samples: array) -> bytes: + """Serialize a signed-short ``array`` back to little-endian 16-bit PCM.""" + if _BIG_ENDIAN: + samples = samples[:] + samples.byteswap() + return samples.tobytes() + + +class _ChannelMixer: + """Owns per-channel PCM buffers, per-channel VAD, the shared VAD timeline, + and the mono mixing math. Runtime-agnostic — both coordinators drive it.""" + + def __init__( + self, + channels: Sequence[str], + sample_rate: int, + options: ChannelAttributionOptions, + ): + self.channels = list(channels) + self.sample_rate = sample_rate + self.timeline = VadTimeline(options.timeline_window_ms) + + self._vad_frame_samples = max(1, round(sample_rate * _VAD_FRAME_MS / 1000)) + self._min_chunk_samples = max(1, round(sample_rate * _MIN_CHUNK_MS / 1000)) + self._max_chunk_samples = max( + self._min_chunk_samples, round(sample_rate * _MAX_CHUNK_MS / 1000) + ) + + self._buffers: Dict[str, array] = {n: array("h") for n in self.channels} + self._vad_frame: Dict[str, List[float]] = {n: [] for n in self.channels} + self._received: Dict[str, int] = {n: 0 for n in self.channels} + self._vads: Dict[str, VadDetector] = { + n: options._make_vad(n) for n in self.channels + } + # Channels closed via ``close_channel`` — treated as silence so their + # absence no longer gates mixing for the survivors. + self._ended: Set[str] = set() + + def close_channel(self, channel: str) -> None: + """Mark a channel finished (source ended). Subsequent ``drain`` calls + stop waiting on it and pad it with silence.""" + self._ended.add(channel) + + def ingest( + self, + channel: str, + data: bytes, + on_vad: Optional[Callable[[VadFrame], None]] = None, + ) -> None: + samples = _pcm16_to_array(data) + self._buffers[channel].extend(samples) + + vad = self._vads[channel] + frame_buf = self._vad_frame[channel] + received = self._received[channel] + for s in samples: + frame_buf.append(s / 0x8000) + received += 1 + if len(frame_buf) == self._vad_frame_samples: + result = vad.process(frame_buf) + frame = VadFrame( + ts=received / self.sample_rate * 1000, + channel=channel, + active=result.active, + rms=result.energy, + ) + self.timeline.push_frame(frame) + if on_vad is not None: + on_vad(frame) + frame_buf.clear() + self._received[channel] = received + + def drain(self, force: bool = False) -> List[bytes]: + """Mix buffered audio into mono PCM chunks, each clamped to + ``[_MIN_CHUNK_MS, _MAX_CHUNK_MS]`` (the ``_MIN_CHUNK_MS`` floor applies + only while not ``force``). + + While every channel is still feeding, mixing gates on the shortest live + buffer to keep channels time-aligned. Once a channel is closed + (``close_channel``) or on the final ``force`` flush, shorter/ended + buffers are zero-padded up to the longest instead of gating on them, so + a terminated source degrades to the survivors rather than stalling the + session and dropping everything accumulated since. + """ + bufs = [self._buffers[n] for n in self.channels] + divisor = len(bufs) + # Pad (don't gate) once any channel has ended, or on the final flush. + pad = force or bool(self._ended) + out_chunks: List[bytes] = [] + while True: + if pad: + mix_len = max((len(b) for b in bufs), default=0) + else: + live = [ + len(self._buffers[n]) for n in self.channels if n not in self._ended + ] + mix_len = min(live) if live else 0 + if mix_len == 0: + break + if not force and mix_len < self._min_chunk_samples: + break + if mix_len > self._max_chunk_samples: + mix_len = self._max_chunk_samples + out = array("h", bytes(2 * mix_len)) + for i in range(mix_len): + total = 0 + for b in bufs: + total += b[i] if i < len(b) else 0 # pad past buffer end + avg = round(total / divisor) + out[i] = -32768 if avg < -32768 else (32767 if avg > 32767 else avg) + for b in bufs: + del b[: min(mix_len, len(b))] + out_chunks.append(_array_to_pcm16(out)) + return out_chunks + + +def _validate_channels(channels: Sequence[str]) -> List[str]: + names = list(channels) + if len(names) < 2: + raise ValueError("channels must declare at least 2 channel names.") + if len(set(names)) != len(names): + raise ValueError("channels names must be unique.") + if any(not isinstance(n, str) or not n for n in names): + raise ValueError("channel names must be non-empty strings.") + return names + + +class _BaseChannelStreamer: + """Shared dual/multi-channel coordination independent of the wrapped + client's sync/async I/O. Channel config lives here, never on + ``StreamingParameters`` (it must not reach the websocket URL); the wrapped + client streams ordinary mono audio and is otherwise untouched. + """ + + def __init__( + self, + channels: Sequence[str], + sample_rate: int, + options: Optional[ChannelAttributionOptions], + on_vad: Optional[Callable[[VadFrame], None]], + ): + self.channels = _validate_channels(channels) + self._options = options or ChannelAttributionOptions() + self._on_vad = on_vad + self._mixer = _ChannelMixer(self.channels, sample_rate, self._options) + self._speaker_history: Dict[str, Dict[str, float]] = {} + self._turn_handlers: List[Callable] = [] + # Set by each subclass in __init__ (the concrete sync/async client). + self._client: Union["StreamingClient", "AsyncStreamingClient"] + + def on(self, event: StreamingEvents, handler: Callable) -> None: + """Register an event handler. ``Turn`` events are delivered as an + enriched ``DualChannelTurnEvent``; all other events are forwarded to the + underlying client unchanged.""" + if event == StreamingEvents.Turn: + self._turn_handlers.append(handler) + else: + self._client.on(event, handler) + + def _check_channel(self, channel: str) -> None: + if channel not in self._mixer._buffers: + raise ValueError( + f'Unknown channel "{channel}"; declared channels: ' + f"{', '.join(self.channels)}." + ) + + def _enrich(self, base_turn: TurnEvent) -> DualChannelTurnEvent: + """Build a ``DualChannelTurnEvent`` from the base turn (left untouched) + and run channel attribution + the configured unknown-resolution.""" + # Dump the base event keeping None fields (pydantic v1/v2). + data = ( + base_turn.model_dump() + if hasattr(base_turn, "model_dump") + else base_turn.dict() + ) + enriched = DualChannelTurnEvent(**data) + attribute_turn(enriched, self._mixer.timeline, self._options.dominance_ratio) + method = self._options.resolve_unknown_channels_method + if method == "window": + resolve_unknown_channels_by_window( + enriched, self._options.resolution_window_words + ) + elif method == "speaker-history": + resolve_unknown_channels_by_speaker_history( + enriched, + self._mixer.timeline, + self._speaker_history, + self._options.speaker_history_min_rms_evidence, + self._options.speaker_history_dominance_ratio, + ) + # method == "none": leave "unknown" words as-is (validated at construction). + return enriched + + @staticmethod + def _as_chunks(data: Union[bytes, Iterable[bytes]]) -> Iterable[bytes]: + if isinstance(data, (bytes, bytearray, memoryview)): + return [bytes(data)] + return data + + +class ChannelStreamer(_BaseChannelStreamer): + """Dual/multi-channel coordinator for the threaded ``StreamingClient``. + + Feed each named channel's 16-bit little-endian PCM via ``stream(channel, + data)``; the channels are summed into one mono stream over the client's + single session. Register handlers on the coordinator (``mixer.on(...)``): + ``Turn`` handlers receive an enriched ``DualChannelTurnEvent``; all other + events are forwarded to the wrapped client, keeping the base payloads clean + for single-stream use. + + Requires ``pcm_s16le`` (linear mixing is invalid for ``pcm_mulaw``). Feed + every channel continuous PCM at the same ``sample_rate`` (silence as zeros); + call ``close_channel(name)`` when a source ends mid-session so the session + degrades to the survivors, and ``flush()`` before + ``client.disconnect(terminate=True)`` to push the trailing audio. + """ + + def __init__( + self, + client: "StreamingClient", + channels: Sequence[str], + sample_rate: int, + attribution: Optional[ChannelAttributionOptions] = None, + on_vad: Optional[Callable[[VadFrame], None]] = None, + ): + super().__init__(channels, sample_rate, attribution, on_vad) + self._client = client + client.on(StreamingEvents.Turn, self._handle_turn) + + def _handle_turn(self, client: object, base_turn: TurnEvent) -> None: + enriched = self._enrich(base_turn) + for handler in self._turn_handlers: + try: + handler(client, enriched) + except Exception: + logger.exception("dual-channel on_turn handler raised") + + def stream(self, channel: str, data: Union[bytes, Iterable[bytes]]) -> None: + """Ingest PCM for ``channel``, then send whatever can now be mixed.""" + self._check_channel(channel) + for chunk in self._as_chunks(data): + self._mixer.ingest(channel, chunk, on_vad=self._on_vad) + for mixed in self._mixer.drain(): + self._client.stream(mixed) + + def close_channel(self, channel: str) -> None: + """Signal that ``channel``'s source has ended; the session keeps + streaming the survivors, sending any newly mixable audio immediately.""" + self._check_channel(channel) + self._mixer.close_channel(channel) + for mixed in self._mixer.drain(): + self._client.stream(mixed) + + def flush(self) -> None: + """Mix and send any remaining buffered audio. Call before + ``client.disconnect``.""" + for mixed in self._mixer.drain(force=True): + self._client.stream(mixed) + + +class AsyncChannelStreamer(_BaseChannelStreamer): + """Asyncio-native counterpart to ``ChannelStreamer`` (wraps + ``AsyncStreamingClient``); ``stream`` / ``close_channel`` / ``flush`` are + coroutines. ``Turn`` handlers may be sync or ``async`` (awaited inline on the + read task). See ``ChannelStreamer`` for requirements. + """ + + def __init__( + self, + client: "AsyncStreamingClient", + channels: Sequence[str], + sample_rate: int, + attribution: Optional[ChannelAttributionOptions] = None, + on_vad: Optional[Callable[[VadFrame], None]] = None, + ): + super().__init__(channels, sample_rate, attribution, on_vad) + self._client: "AsyncStreamingClient" = client + client.on(StreamingEvents.Turn, self._handle_turn) + + async def _handle_turn(self, client: object, base_turn: TurnEvent) -> None: + enriched = self._enrich(base_turn) + for handler in self._turn_handlers: + try: + result = handler(client, enriched) + if inspect.isawaitable(result): + await result + except Exception: + logger.exception("dual-channel on_turn handler raised") + + async def stream(self, channel: str, data: Union[bytes, Iterable[bytes]]) -> None: + self._check_channel(channel) + for chunk in self._as_chunks(data): + self._mixer.ingest(channel, chunk, on_vad=self._on_vad) + for mixed in self._mixer.drain(): + await self._client.stream(mixed) + + async def close_channel(self, channel: str) -> None: + """Signal that ``channel``'s source has ended; the session keeps + streaming the survivors instead of stalling.""" + self._check_channel(channel) + self._mixer.close_channel(channel) + for mixed in self._mixer.drain(): + await self._client.stream(mixed) + + async def flush(self) -> None: + for mixed in self._mixer.drain(force=True): + await self._client.stream(mixed) diff --git a/assemblyai/streaming/v3/models.py b/assemblyai/streaming/v3/models.py index 089a98f..f64dede 100644 --- a/assemblyai/streaming/v3/models.py +++ b/assemblyai/streaming/v3/models.py @@ -128,6 +128,10 @@ class ForceEndpoint(BaseModel): type: Literal["ForceEndpoint"] = "ForceEndpoint" +class KeepAlive(BaseModel): + type: Literal["KeepAlive"] = "KeepAlive" + + class StreamingSessionParameters(BaseModel): end_of_turn_confidence_threshold: Optional[float] = None min_end_of_turn_silence_when_confident: Optional[int] = ( @@ -287,6 +291,7 @@ class UpdateConfiguration(StreamingSessionParameters): bytes, TerminateSession, ForceEndpoint, + KeepAlive, UpdateConfiguration, ] diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py index fc84951..a1ad588 100644 --- a/tests/unit/test_streaming.py +++ b/tests/unit/test_streaming.py @@ -27,7 +27,10 @@ Word, ) from assemblyai.streaming.v3._base import _build_uri -from assemblyai.streaming.v3.models import TerminateSession +from assemblyai.streaming.v3.models import ( + KeepAlive, + TerminateSession, +) def _disable_rw_threads(mocker: MockFixture): @@ -558,6 +561,33 @@ def mocked_websocket_connect( assert isinstance(client._write_queue.get(timeout=1), bytes) +def test_client_keep_alive_enqueues_keep_alive_message(mocker: MockFixture): + # Given: a connected client with read/write threads disabled + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + return_value=None, + ) + _disable_rw_threads(mocker) + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.connect( + StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.universal_streaming_english, + ) + ) + + # When: keep_alive is called + client.keep_alive() + + # Then: a KeepAlive message is enqueued for the write thread + assert client._write_queue.qsize() == 1 + message = client._write_queue.get(timeout=1) + assert isinstance(message, KeepAlive) + assert message.type == "KeepAlive" + + def test_client_connect_with_webhook(mocker: MockFixture): actual_url = None actual_additional_headers = None diff --git a/tests/unit/test_streaming_async.py b/tests/unit/test_streaming_async.py index c06986e..81449f0 100644 --- a/tests/unit/test_streaming_async.py +++ b/tests/unit/test_streaming_async.py @@ -901,6 +901,49 @@ async def test_force_endpoint_enqueues_force_endpoint_frame(mocker: MockFixture) await client.disconnect() +async def test_keep_alive_before_connect_raises_runtime_error(): + # Given: a client that has never connected + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + + # When/Then: keep_alive raises rather than silently dropping the frame + with pytest.raises(RuntimeError, match="not connected"): + await client.keep_alive() + + +async def test_keep_alive_sends_keep_alive_frame(mocker: MockFixture): + # Given: a connected async client over a fake websocket + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + # When: keep_alive is called + await client.keep_alive() + + # Then: a single KeepAlive frame is written to the websocket + for _ in range(100): + keep_alive_frames = [ + s for s in fake_ws.sent if isinstance(s, str) and "KeepAlive" in s + ] + if keep_alive_frames: + break + await asyncio.sleep(0.01) + + keep_alive_frames = [ + s for s in fake_ws.sent if isinstance(s, str) and "KeepAlive" in s + ] + assert len(keep_alive_frames) == 1 + payload = json.loads(keep_alive_frames[0]) + assert payload["type"] == "KeepAlive" + + await client.disconnect() + + async def test_warning_event_dispatched_to_handler(mocker: MockFixture): fake_ws = _FakeAsyncWebSocket() _patch_connect(mocker, fake_ws) diff --git a/tests/unit/test_streaming_dual_channel.py b/tests/unit/test_streaming_dual_channel.py new file mode 100644 index 0000000..06e526e --- /dev/null +++ b/tests/unit/test_streaming_dual_channel.py @@ -0,0 +1,466 @@ +"""Unit tests for client-side dual/multi-channel streaming. + +Covers the pure logic (energy VAD, VAD timeline, word/turn attribution, +unknown-channel resolution, the mono mixer) and the sync/async coordinators +against a fake client. No network I/O. +""" + +import sys +from array import array + +import pytest + +from assemblyai.streaming.v3 import ( + AsyncChannelStreamer, + ChannelAttributionOptions, + ChannelStreamer, + DualChannelTurnEvent, + DualChannelWord, + EnergyVad, + StreamingEvents, +) +from assemblyai.streaming.v3.extras import ( + UNKNOWN_CHANNEL, + VadFrame, + VadTimeline, + _ChannelMixer, + attribute_turn, + attribute_word, + resolve_unknown_channels_by_speaker_history, + resolve_unknown_channels_by_window, + roll_up_turn_channel, +) + +SAMPLE_RATE = 16000 + + +# --------------------------------------------------------------------------- # +# helpers +# --------------------------------------------------------------------------- # +def _word(start: int, end: int, speaker=None, text="x") -> DualChannelWord: + return DualChannelWord( + start=start, + end=end, + confidence=1.0, + text=text, + word_is_final=True, + speaker=speaker, + ) + + +def _turn(words) -> DualChannelTurnEvent: + return DualChannelTurnEvent( + type="Turn", + turn_order=0, + turn_is_formatted=False, + end_of_turn=True, + transcript=" ".join(w.text for w in words), + end_of_turn_confidence=0.9, + words=list(words), + ) + + +def _pcm(value: int, n: int) -> bytes: + """``n`` samples of constant int16 ``value`` as little-endian PCM bytes.""" + samples = array("h", [value] * n) + if sys.byteorder == "big": + samples.byteswap() + return samples.tobytes() + + +def _active_frame(channel: str, ts: float, rms: float = 0.5) -> VadFrame: + return VadFrame(ts=ts, channel=channel, active=True, rms=rms) + + +# --------------------------------------------------------------------------- # +# EnergyVad +# --------------------------------------------------------------------------- # +def test_energy_vad_silence_is_inactive(): + vad = EnergyVad() + result = vad.process([0.0] * 320) + assert result.active is False + assert result.energy == 0.0 + + +def test_energy_vad_detects_speech_above_noise(): + vad = EnergyVad() + # First a quiet frame to seat the (low) noise floor, then a loud one. + vad.process([0.0] * 320) + result = vad.process([0.3] * 320) + assert result.active is True + assert result.energy == pytest.approx(0.3, rel=1e-6) + + +def test_energy_vad_hangover_keeps_active_after_speech(): + vad = EnergyVad(hangover_frames=3) + vad.process([0.3] * 320) # speech -> sets hangover + # Three quiet frames stay active by hangover, the fourth goes inactive. + assert vad.process([0.0] * 320).active is True + assert vad.process([0.0] * 320).active is True + assert vad.process([0.0] * 320).active is True + assert vad.process([0.0] * 320).active is False + + +def test_energy_vad_reset(): + vad = EnergyVad(hangover_frames=5) + vad.process([0.5] * 320) + vad.reset() + # After reset, a silent frame should be inactive (no lingering hangover). + assert vad.process([0.0] * 320).active is False + + +# --------------------------------------------------------------------------- # +# VadTimeline +# --------------------------------------------------------------------------- # +def test_timeline_returns_frames_in_window(): + tl = VadTimeline(window_ms=30_000) + for ts in (0, 100, 200, 300, 400): + tl.push_frame(_active_frame("mic", ts)) + frames = tl.frames_in_window(100, 300) + assert [f.ts for f in frames] == [100, 200, 300] + + +def test_timeline_prunes_beyond_window(): + tl = VadTimeline(window_ms=200) + for ts in (0, 100, 200, 1000): + tl.push_frame(_active_frame("mic", ts)) + # Frames older than ts(1000) - 200 = 800 are pruned from the active head. + assert all(f.ts >= 800 for f in tl.frames_in_window(0, 2000)) + + +# --------------------------------------------------------------------------- # +# word / turn attribution +# --------------------------------------------------------------------------- # +def test_attribute_word_unknown_when_no_energy(): + tl = VadTimeline(30_000) + assert attribute_word(_word(0, 100), tl, dominance_ratio=4.0) == UNKNOWN_CHANNEL + + +def test_attribute_word_single_channel_wins(): + tl = VadTimeline(30_000) + for ts in (10, 30, 50): + tl.push_frame(_active_frame("mic", ts)) + assert attribute_word(_word(0, 100), tl, dominance_ratio=4.0) == "mic" + + +def test_attribute_word_dominant_channel_by_ratio(): + tl = VadTimeline(30_000) + tl.push_frame(_active_frame("mic", 10, rms=1.0)) + tl.push_frame(_active_frame("mic", 20, rms=1.0)) + tl.push_frame(_active_frame("system", 30, rms=0.1)) + # mic total 2.0 vs system 0.1 -> beats 4x ratio. + assert attribute_word(_word(0, 100), tl, dominance_ratio=4.0) == "mic" + + +def test_attribute_word_close_scores_are_unknown(): + tl = VadTimeline(30_000) + tl.push_frame(_active_frame("mic", 10, rms=1.0)) + tl.push_frame(_active_frame("system", 20, rms=0.9)) + # Too close to clear the 4x ratio -> unknown (resolution may back-fill). + assert attribute_word(_word(0, 100), tl, dominance_ratio=4.0) == UNKNOWN_CHANNEL + + +def test_attribute_word_dominance_ratio_is_a_real_knob(): + tl = VadTimeline(30_000) + tl.push_frame(_active_frame("mic", 10, rms=1.0)) + tl.push_frame(_active_frame("system", 20, rms=0.9)) + # Same frames, looser ratio -> mic now clears the bar and wins. + assert attribute_word(_word(0, 100), tl, dominance_ratio=1.1) == "mic" + + +def test_attribute_word_exact_tie_is_unknown(): + tl = VadTimeline(30_000) + tl.push_frame(_active_frame("mic", 10, rms=0.5)) + tl.push_frame(_active_frame("system", 20, rms=0.5)) + assert attribute_word(_word(0, 100), tl, dominance_ratio=4.0) == UNKNOWN_CHANNEL + + +def test_roll_up_turn_channel_duration_weighted(): + words = [_word(0, 1000), _word(1000, 1100)] + words[0].channel = "mic" + words[1].channel = "system" + # mic spans 1000ms vs system 100ms. + assert roll_up_turn_channel(words) == "mic" + + +def test_roll_up_turn_channel_tie_is_unknown(): + words = [_word(0, 100), _word(100, 200)] + words[0].channel = "mic" + words[1].channel = "system" + assert roll_up_turn_channel(words) == UNKNOWN_CHANNEL + + +def test_attribute_turn_sets_word_and_turn_channels(): + tl = VadTimeline(30_000) + for ts in (10, 30, 50, 110, 130): + tl.push_frame(_active_frame("system", ts)) + turn = _turn([_word(0, 100), _word(100, 200)]) + attribute_turn(turn, tl, dominance_ratio=4.0) + assert turn.words[0].channel == "system" + assert turn.words[1].channel == "system" + assert turn.channel == "system" + + +# --------------------------------------------------------------------------- # +# unknown-channel resolution +# --------------------------------------------------------------------------- # +def test_resolve_unknown_by_window_fills_from_neighbors(): + turn = _turn([_word(0, 50), _word(60, 110), _word(120, 170)]) + turn.words[0].channel = "mic" + turn.words[1].channel = UNKNOWN_CHANNEL + turn.words[2].channel = "mic" + resolve_unknown_channels_by_window(turn, resolution_window_words=2) + assert turn.words[1].channel == "mic" + assert turn.words[1].channel_resolved is True + assert turn.channel == "mic" + + +def test_resolve_unknown_by_window_leaves_isolated_unknown(): + turn = _turn([_word(0, 50)]) + turn.words[0].channel = UNKNOWN_CHANNEL + resolve_unknown_channels_by_window(turn, resolution_window_words=2) + assert turn.words[0].channel == UNKNOWN_CHANNEL + + +def test_resolve_unknown_by_speaker_history(): + tl = VadTimeline(30_000) + # Speaker A clearly on mic during 0-100ms. + for ts in (10, 30, 50, 70): + tl.push_frame(_active_frame("mic", ts, rms=1.0)) + history = {} + # First turn: word resolves to mic via VAD, building speaker A's history. + turn1 = _turn([_word(0, 100, speaker="A")]) + attribute_turn(turn1, tl, dominance_ratio=4.0) + resolve_unknown_channels_by_speaker_history( + turn1, tl, history, min_rms_evidence=0.5, dominance_ratio=3.0 + ) + assert turn1.words[0].channel == "mic" + + # Second turn: same speaker, but no VAD frames in the word window -> unknown, + # then back-filled from accumulated speaker history. + turn2 = _turn([_word(5000, 5100, speaker="A")]) + attribute_turn(turn2, tl, dominance_ratio=4.0) + assert turn2.words[0].channel == UNKNOWN_CHANNEL + resolve_unknown_channels_by_speaker_history( + turn2, tl, history, min_rms_evidence=0.5, dominance_ratio=3.0 + ) + assert turn2.words[0].channel == "mic" + assert turn2.words[0].channel_resolved is True + + +# --------------------------------------------------------------------------- # +# mixer +# --------------------------------------------------------------------------- # +def _mixer(channels=("mic", "system")) -> _ChannelMixer: + return _ChannelMixer(list(channels), SAMPLE_RATE, ChannelAttributionOptions()) + + +def test_mixer_averages_channels_to_mono(): + mixer = _mixer() + n = 800 # 50ms @ 16kHz -> meets the min-chunk floor + mixer.ingest("mic", _pcm(1000, n)) + mixer.ingest("system", _pcm(2000, n)) + chunks = mixer.drain() + assert len(chunks) == 1 + out = array("h") + out.frombytes(chunks[0]) + assert len(out) == n + assert all(s == 1500 for s in out) # round((1000 + 2000) / 2) + + +def test_mixer_gates_on_shorter_buffer(): + mixer = _mixer() + mixer.ingest("mic", _pcm(1000, 1600)) # 100ms + mixer.ingest("system", _pcm(2000, 800)) # 50ms + chunks = mixer.drain() + # Only the 800 overlapping samples can be mixed; the rest stays buffered. + assert sum(len(b) // 2 for b in chunks) == 800 + + +def test_mixer_min_chunk_floor_holds_small_buffers(): + mixer = _mixer() + mixer.ingest("mic", _pcm(1000, 100)) # < 50ms + mixer.ingest("system", _pcm(2000, 100)) + assert mixer.drain() == [] # below the floor + # force=True bypasses the floor (final flush). + assert len(mixer.drain(force=True)) == 1 + + +def test_mixer_caps_chunks_at_max_duration(): + mixer = _mixer() + n = 16000 # 1 second -> must split into >=5 chunks of <=200ms (3200 samples) + mixer.ingest("mic", _pcm(1000, n)) + mixer.ingest("system", _pcm(1000, n)) + chunks = mixer.drain() + assert len(chunks) >= 5 + assert all(len(b) // 2 <= 3200 for b in chunks) + assert sum(len(b) // 2 for b in chunks) == n + + +def test_mixer_rejects_odd_length_pcm(): + mixer = _mixer() + with pytest.raises(ValueError): + mixer.ingest("mic", b"\x00\x01\x02") + + +def test_mixer_force_flush_pads_missing_channel(): + # Regression: a channel that produced nothing must not gate the flush and + # swallow the other channel's audio. + mixer = _mixer() + mixer.ingest("mic", _pcm(1000, 800)) # system gets nothing + chunks = mixer.drain(force=True) + assert sum(len(b) // 2 for b in chunks) == 800 # mic's audio still sent + + +def test_mixer_close_channel_lets_survivor_drain(): + mixer = _mixer() + mixer.ingest("mic", _pcm(1000, 1600)) # 100ms + mixer.ingest("system", _pcm(2000, 800)) # 50ms + assert sum(len(b) // 2 for b in mixer.drain()) == 800 # aligned overlap only + # system's source ends -> the remaining mic audio drains without waiting. + mixer.close_channel("system") + assert sum(len(b) // 2 for b in mixer.drain()) == 800 + + +# --------------------------------------------------------------------------- # +# coordinators +# --------------------------------------------------------------------------- # +class _FakeSyncClient: + def __init__(self): + self._handlers = {event: [] for event in StreamingEvents} + self.sent = [] + + def on(self, event, handler): + self._handlers[event].append(handler) + + def stream(self, data): + self.sent.append(data) + + def dispatch_turn(self, turn): + for handler in self._handlers[StreamingEvents.Turn]: + handler(self, turn) + + +class _FakeAsyncClient(_FakeSyncClient): + async def stream(self, data): # type: ignore[override] + self.sent.append(data) + + +def test_channel_streamer_validates_channels(): + client = _FakeSyncClient() + with pytest.raises(ValueError): + ChannelStreamer(client, channels=["mic"], sample_rate=SAMPLE_RATE) + with pytest.raises(ValueError): + ChannelStreamer(client, channels=["mic", "mic"], sample_rate=SAMPLE_RATE) + + +def test_channel_streamer_unknown_channel_raises(): + client = _FakeSyncClient() + mixer = ChannelStreamer(client, ["mic", "system"], sample_rate=SAMPLE_RATE) + with pytest.raises(ValueError): + mixer.stream("speaker", _pcm(1, 800)) + + +def test_invalid_resolve_method_raises_at_construction(): + # Must fail fast at construction, not silently inside the swallowed handler. + with pytest.raises(ValueError): + ChannelAttributionOptions(resolve_unknown_channels_method="bogus") + + +def test_channel_streamer_close_channel_drains_survivor(): + client = _FakeSyncClient() + mixer = ChannelStreamer(client, ["mic", "system"], sample_rate=SAMPLE_RATE) + mixer.stream("mic", _pcm(1000, 1600)) # 100ms + mixer.stream("system", _pcm(2000, 800)) # 50ms + before = sum(len(b) // 2 for b in client.sent) + mixer.close_channel("system") # system's source ended + after = sum(len(b) // 2 for b in client.sent) + assert after > before # surviving mic audio flushed instead of stalling + + +def test_channel_streamer_sends_mixed_mono(): + client = _FakeSyncClient() + mixer = ChannelStreamer(client, ["mic", "system"], sample_rate=SAMPLE_RATE) + mixer.stream("mic", _pcm(1000, 800)) + mixer.stream("system", _pcm(2000, 800)) + assert client.sent, "expected a mixed mono chunk to be sent" + out = array("h") + out.frombytes(b"".join(client.sent)) + assert all(s == 1500 for s in out) + + +def test_channel_streamer_enriches_turn_on_mixer_handler(): + client = _FakeSyncClient() + mixer = ChannelStreamer(client, ["mic", "system"], sample_rate=SAMPLE_RATE) + + seen = {} + + def on_turn(_client, turn): + # Handler receives the enriched DualChannelTurnEvent. + seen["type"] = type(turn).__name__ + seen["channel"] = turn.channel + seen["word_channel"] = turn.words[0].channel + + mixer.on(StreamingEvents.Turn, on_turn) + + # Drive enough mic audio to register active VAD frames in the word window. + mixer.stream("mic", _pcm(8000, 1600)) + mixer.stream("system", _pcm(0, 1600)) + + # The client emits a base TurnEvent; the mixer enriches + dispatches it. + client.dispatch_turn(_turn([_word(0, 80)])) + assert seen["type"] == "DualChannelTurnEvent" + assert seen["channel"] == "mic" + assert seen["word_channel"] == "mic" + + +def test_channel_streamer_none_method_leaves_unknown(): + client = _FakeSyncClient() + opts = ChannelAttributionOptions(resolve_unknown_channels_method="none") + mixer = ChannelStreamer( + client, ["mic", "system"], sample_rate=SAMPLE_RATE, attribution=opts + ) + captured = {} + mixer.on(StreamingEvents.Turn, lambda c, t: captured.update(ch=t.words[0].channel)) + # No audio ingested -> no VAD -> word is unknown and stays unknown. + client.dispatch_turn(_turn([_word(0, 80)])) + assert captured["ch"] == UNKNOWN_CHANNEL + + +def test_channel_streamer_forwards_non_turn_events_to_client(): + client = _FakeSyncClient() + mixer = ChannelStreamer(client, ["mic", "system"], sample_rate=SAMPLE_RATE) + + def on_error(_c, _e): + pass + + mixer.on(StreamingEvents.Error, on_error) + # Non-Turn handlers are registered straight through on the underlying client. + assert on_error in client._handlers[StreamingEvents.Error] + + +def test_channel_streamer_invokes_vad_callback(): + client = _FakeSyncClient() + frames = [] + mixer = ChannelStreamer( + client, + ["mic", "system"], + sample_rate=SAMPLE_RATE, + on_vad=frames.append, + ) + mixer.stream("mic", _pcm(5000, 320)) # exactly one 20ms VAD frame + assert len(frames) == 1 + assert frames[0].channel == "mic" + + +@pytest.mark.asyncio +async def test_async_channel_streamer_sends_mixed_mono(): + client = _FakeAsyncClient() + mixer = AsyncChannelStreamer(client, ["mic", "system"], sample_rate=SAMPLE_RATE) + await mixer.stream("mic", _pcm(1000, 800)) + await mixer.stream("system", _pcm(2000, 800)) + await mixer.flush() + assert client.sent + out = array("h") + out.frombytes(b"".join(client.sent)) + assert all(s == 1500 for s in out)