mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00
Merge a561337c78d72fac50091b233c67d9b27918ab81 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in:
commit
4b6e05c94f
421
examples/streaming_client.py
Normal file
421
examples/streaming_client.py
Normal file
@ -0,0 +1,421 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Example client for Whisper WebSocket streaming server.
|
||||||
|
|
||||||
|
This script demonstrates how to connect to the WebSocket server and stream audio
|
||||||
|
for real-time transcription.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
try:
|
||||||
|
import websockets
|
||||||
|
WEBSOCKETS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
WEBSOCKETS_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pyaudio
|
||||||
|
PYAUDIO_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
PYAUDIO_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperStreamingClient:
|
||||||
|
"""Client for connecting to Whisper WebSocket streaming server."""
|
||||||
|
|
||||||
|
def __init__(self, server_url: str = "ws://localhost:8765"):
|
||||||
|
"""Initialize the streaming client."""
|
||||||
|
if not WEBSOCKETS_AVAILABLE:
|
||||||
|
raise ImportError("websockets library is required: pip install websockets")
|
||||||
|
|
||||||
|
self.server_url = server_url
|
||||||
|
self.websocket = None
|
||||||
|
self.is_connected = False
|
||||||
|
self.is_streaming = False
|
||||||
|
|
||||||
|
# Audio settings
|
||||||
|
self.sample_rate = 16000
|
||||||
|
self.channels = 1
|
||||||
|
self.chunk_size = 1024
|
||||||
|
self.audio_format = pyaudio.paInt16 if PYAUDIO_AVAILABLE else None
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def connect(self) -> bool:
|
||||||
|
"""Connect to the WebSocket server."""
|
||||||
|
try:
|
||||||
|
self.websocket = await websockets.connect(
|
||||||
|
self.server_url,
|
||||||
|
ping_interval=20,
|
||||||
|
ping_timeout=10
|
||||||
|
)
|
||||||
|
self.is_connected = True
|
||||||
|
self.logger.info(f"Connected to {self.server_url}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to connect: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""Disconnect from the WebSocket server."""
|
||||||
|
if self.websocket:
|
||||||
|
await self.websocket.close()
|
||||||
|
self.is_connected = False
|
||||||
|
self.logger.info("Disconnected from server")
|
||||||
|
|
||||||
|
async def configure_stream(self, config: dict) -> bool:
|
||||||
|
"""Configure the streaming parameters."""
|
||||||
|
try:
|
||||||
|
message = {
|
||||||
|
"type": "configure",
|
||||||
|
"config": config
|
||||||
|
}
|
||||||
|
await self.websocket.send(json.dumps(message))
|
||||||
|
|
||||||
|
# Wait for response
|
||||||
|
response = await self.websocket.recv()
|
||||||
|
response_data = json.loads(response)
|
||||||
|
|
||||||
|
if response_data.get("type") == "configuration_updated":
|
||||||
|
self.logger.info("Stream configured successfully")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.logger.error(f"Configuration failed: {response_data}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error configuring stream: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def start_stream(self) -> bool:
|
||||||
|
"""Start the transcription stream."""
|
||||||
|
try:
|
||||||
|
message = {"type": "start_stream"}
|
||||||
|
await self.websocket.send(json.dumps(message))
|
||||||
|
|
||||||
|
# Wait for response
|
||||||
|
response = await self.websocket.recv()
|
||||||
|
response_data = json.loads(response)
|
||||||
|
|
||||||
|
if response_data.get("type") == "stream_started":
|
||||||
|
self.is_streaming = True
|
||||||
|
self.logger.info("Stream started successfully")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.logger.error(f"Failed to start stream: {response_data}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error starting stream: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def stop_stream(self) -> bool:
|
||||||
|
"""Stop the transcription stream."""
|
||||||
|
try:
|
||||||
|
message = {"type": "stop_stream"}
|
||||||
|
await self.websocket.send(json.dumps(message))
|
||||||
|
|
||||||
|
# Wait for response
|
||||||
|
response = await self.websocket.recv()
|
||||||
|
response_data = json.loads(response)
|
||||||
|
|
||||||
|
if response_data.get("type") == "stream_stopped":
|
||||||
|
self.is_streaming = False
|
||||||
|
self.logger.info("Stream stopped successfully")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.logger.error(f"Failed to stop stream: {response_data}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error stopping stream: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def send_audio_data(self, audio_data: np.ndarray) -> None:
|
||||||
|
"""Send audio data to the server."""
|
||||||
|
try:
|
||||||
|
if not self.is_streaming:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert audio to bytes
|
||||||
|
if audio_data.dtype != np.int16:
|
||||||
|
audio_data = (audio_data * 32767).astype(np.int16)
|
||||||
|
|
||||||
|
audio_bytes = audio_data.tobytes()
|
||||||
|
audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')
|
||||||
|
|
||||||
|
message = {
|
||||||
|
"type": "audio_data",
|
||||||
|
"format": "pcm16",
|
||||||
|
"audio": audio_b64
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.websocket.send(json.dumps(message))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error sending audio data: {e}")
|
||||||
|
|
||||||
|
async def listen_for_results(self) -> None:
|
||||||
|
"""Listen for transcription results from the server."""
|
||||||
|
try:
|
||||||
|
while self.is_connected:
|
||||||
|
response = await self.websocket.recv()
|
||||||
|
response_data = json.loads(response)
|
||||||
|
|
||||||
|
message_type = response_data.get("type")
|
||||||
|
|
||||||
|
if message_type == "transcription_result":
|
||||||
|
self._handle_transcription_result(response_data)
|
||||||
|
elif message_type == "error":
|
||||||
|
self._handle_error(response_data)
|
||||||
|
elif message_type == "connection_established":
|
||||||
|
self._handle_connection_established(response_data)
|
||||||
|
else:
|
||||||
|
self.logger.info(f"Received: {response_data}")
|
||||||
|
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
self.logger.info("Connection closed by server")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error listening for results: {e}")
|
||||||
|
|
||||||
|
def _handle_transcription_result(self, data: dict) -> None:
|
||||||
|
"""Handle transcription result from server."""
|
||||||
|
result = data.get("result", {})
|
||||||
|
text = result.get("text", "")
|
||||||
|
confidence = result.get("confidence", 0.0)
|
||||||
|
is_final = result.get("is_final", True)
|
||||||
|
|
||||||
|
status = "FINAL" if is_final else "PARTIAL"
|
||||||
|
print(f"[{status}] ({confidence:.2f}): {text}")
|
||||||
|
|
||||||
|
def _handle_error(self, data: dict) -> None:
|
||||||
|
"""Handle error message from server."""
|
||||||
|
error_type = data.get("error_type", "Unknown")
|
||||||
|
message = data.get("message", "")
|
||||||
|
print(f"ERROR [{error_type}]: {message}")
|
||||||
|
|
||||||
|
def _handle_connection_established(self, data: dict) -> None:
|
||||||
|
"""Handle connection established message."""
|
||||||
|
server_info = data.get("server_info", {})
|
||||||
|
print(f"Connected to server version {server_info.get('version', 'unknown')}")
|
||||||
|
print(f"Supported formats: {server_info.get('supported_formats', [])}")
|
||||||
|
|
||||||
|
async def get_status(self) -> dict:
|
||||||
|
"""Get status from the server."""
|
||||||
|
try:
|
||||||
|
message = {"type": "get_status"}
|
||||||
|
await self.websocket.send(json.dumps(message))
|
||||||
|
|
||||||
|
response = await self.websocket.recv()
|
||||||
|
response_data = json.loads(response)
|
||||||
|
|
||||||
|
if response_data.get("type") == "status":
|
||||||
|
return response_data
|
||||||
|
else:
|
||||||
|
self.logger.error(f"Unexpected status response: {response_data}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error getting status: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class MicrophoneStreamer:
|
||||||
|
"""Stream audio from microphone to Whisper server."""
|
||||||
|
|
||||||
|
def __init__(self, client: WhisperStreamingClient):
|
||||||
|
"""Initialize microphone streamer."""
|
||||||
|
if not PYAUDIO_AVAILABLE:
|
||||||
|
raise ImportError("pyaudio library is required: pip install pyaudio")
|
||||||
|
|
||||||
|
self.client = client
|
||||||
|
self.audio = None
|
||||||
|
self.stream = None
|
||||||
|
self.is_recording = False
|
||||||
|
|
||||||
|
def start_recording(self) -> bool:
|
||||||
|
"""Start recording from microphone."""
|
||||||
|
try:
|
||||||
|
self.audio = pyaudio.PyAudio()
|
||||||
|
|
||||||
|
self.stream = self.audio.open(
|
||||||
|
format=self.client.audio_format,
|
||||||
|
channels=self.client.channels,
|
||||||
|
rate=self.client.sample_rate,
|
||||||
|
input=True,
|
||||||
|
frames_per_buffer=self.client.chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.is_recording = True
|
||||||
|
print(f"Started recording from microphone (SR: {self.client.sample_rate}Hz)")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error starting microphone: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def stop_recording(self) -> None:
|
||||||
|
"""Stop recording from microphone."""
|
||||||
|
self.is_recording = False
|
||||||
|
|
||||||
|
if self.stream:
|
||||||
|
self.stream.stop_stream()
|
||||||
|
self.stream.close()
|
||||||
|
|
||||||
|
if self.audio:
|
||||||
|
self.audio.terminate()
|
||||||
|
|
||||||
|
print("Stopped recording")
|
||||||
|
|
||||||
|
async def stream_audio(self) -> None:
|
||||||
|
"""Stream audio from microphone to server."""
|
||||||
|
print("Streaming audio... (Press Ctrl+C to stop)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while self.is_recording:
|
||||||
|
# Read audio data
|
||||||
|
data = self.stream.read(self.client.chunk_size, exception_on_overflow=False)
|
||||||
|
audio_data = np.frombuffer(data, dtype=np.int16)
|
||||||
|
|
||||||
|
# Send to server
|
||||||
|
await self.client.send_audio_data(audio_data)
|
||||||
|
|
||||||
|
# Small delay to avoid overwhelming the server
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\\nStopping audio stream...")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error streaming audio: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def run_demo_client(server_url: str, model: str = "base", use_microphone: bool = False):
|
||||||
|
"""Run a demo of the streaming client."""
|
||||||
|
client = WhisperStreamingClient(server_url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Connect to server
|
||||||
|
if not await client.connect():
|
||||||
|
return
|
||||||
|
|
||||||
|
# Start listening for results in background
|
||||||
|
listen_task = asyncio.create_task(client.listen_for_results())
|
||||||
|
|
||||||
|
# Wait a bit for connection to be established
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# Configure stream
|
||||||
|
config = {
|
||||||
|
"model_name": model,
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"language": None, # Auto-detect
|
||||||
|
"temperature": 0.0,
|
||||||
|
"return_timestamps": True
|
||||||
|
}
|
||||||
|
|
||||||
|
if not await client.configure_stream(config):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Start stream
|
||||||
|
if not await client.start_stream():
|
||||||
|
return
|
||||||
|
|
||||||
|
# Stream audio
|
||||||
|
if use_microphone and PYAUDIO_AVAILABLE:
|
||||||
|
# Use microphone
|
||||||
|
mic_streamer = MicrophoneStreamer(client)
|
||||||
|
if mic_streamer.start_recording():
|
||||||
|
try:
|
||||||
|
await mic_streamer.stream_audio()
|
||||||
|
finally:
|
||||||
|
mic_streamer.stop_recording()
|
||||||
|
else:
|
||||||
|
# Use synthetic audio for demo
|
||||||
|
print("Streaming synthetic audio... (Press Ctrl+C to stop)")
|
||||||
|
try:
|
||||||
|
duration = 0
|
||||||
|
while duration < 30: # Stream for 30 seconds
|
||||||
|
# Generate 1 second of synthetic speech-like audio
|
||||||
|
t = np.linspace(0, 1, 16000)
|
||||||
|
frequency = 440 + 50 * np.sin(2 * np.pi * 0.5 * duration) # Varying frequency
|
||||||
|
audio = 0.3 * np.sin(2 * np.pi * frequency * t)
|
||||||
|
audio_data = (audio * 32767).astype(np.int16)
|
||||||
|
|
||||||
|
await client.send_audio_data(audio_data)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
duration += 1
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\\nStopping synthetic audio stream...")
|
||||||
|
|
||||||
|
# Stop stream
|
||||||
|
await client.stop_stream()
|
||||||
|
|
||||||
|
# Get final status
|
||||||
|
status = await client.get_status()
|
||||||
|
if status:
|
||||||
|
print(f"\\nFinal Status:")
|
||||||
|
print(f" Processor state: {status.get('processor', {}).get('state', 'unknown')}")
|
||||||
|
print(f" Segments processed: {status.get('processor', {}).get('segments_completed', 0)}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Demo error: {e}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await client.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function."""
|
||||||
|
parser = argparse.ArgumentParser(description="Whisper Streaming Client Demo")
|
||||||
|
parser.add_argument("--server", default="ws://localhost:8765", help="WebSocket server URL")
|
||||||
|
parser.add_argument("--model", default="base", help="Whisper model to use")
|
||||||
|
parser.add_argument("--microphone", action="store_true", help="Use microphone input")
|
||||||
|
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO if args.verbose else logging.WARNING,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check dependencies
|
||||||
|
if not WEBSOCKETS_AVAILABLE:
|
||||||
|
print("Error: websockets library is required")
|
||||||
|
print("Install with: pip install websockets")
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.microphone and not PYAUDIO_AVAILABLE:
|
||||||
|
print("Error: pyaudio library is required for microphone input")
|
||||||
|
print("Install with: pip install pyaudio")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Whisper Streaming Client Demo")
|
||||||
|
print(f"Server: {args.server}")
|
||||||
|
print(f"Model: {args.model}")
|
||||||
|
print(f"Input: {'Microphone' if args.microphone else 'Synthetic audio'}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Run the demo
|
||||||
|
try:
|
||||||
|
asyncio.run(run_demo_client(args.server, args.model, args.microphone))
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\\nDemo interrupted by user")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Demo failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
304
whisper/streaming/README.md
Normal file
304
whisper/streaming/README.md
Normal file
@ -0,0 +1,304 @@
|
|||||||
|
# Whisper Real-time Streaming Module
|
||||||
|
|
||||||
|
This module provides real-time streaming capabilities for OpenAI Whisper, enabling low-latency transcription for live audio streams.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Real-time Processing**: Stream audio and receive transcription results in real-time
|
||||||
|
- **WebSocket Server**: WebSocket-based API for easy integration
|
||||||
|
- **Voice Activity Detection**: Intelligent audio segmentation using VAD
|
||||||
|
- **CTranslate2 Acceleration**: Optional CTranslate2 backend for faster inference
|
||||||
|
- **Configurable Buffering**: Adaptive audio buffering with overlap handling
|
||||||
|
- **Multi-client Support**: Handle multiple concurrent streaming connections
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Starting the WebSocket Server
|
||||||
|
|
||||||
|
```python
|
||||||
|
from whisper.streaming import WhisperWebSocketServer, StreamConfig
|
||||||
|
|
||||||
|
# Create default configuration
|
||||||
|
config = StreamConfig(
|
||||||
|
model_name="base",
|
||||||
|
sample_rate=16000,
|
||||||
|
chunk_duration_ms=1000,
|
||||||
|
language=None # Auto-detect
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start server
|
||||||
|
server = WhisperWebSocketServer(
|
||||||
|
host="localhost",
|
||||||
|
port=8765,
|
||||||
|
default_config=config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run server
|
||||||
|
import asyncio
|
||||||
|
asyncio.run(server.start_server())
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using the Stream Processor Directly
|
||||||
|
|
||||||
|
```python
|
||||||
|
from whisper.streaming import StreamProcessor, StreamConfig
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Configure streaming
|
||||||
|
config = StreamConfig(
|
||||||
|
model_name="base",
|
||||||
|
chunk_duration_ms=1000,
|
||||||
|
return_timestamps=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Result callback
|
||||||
|
def on_result(result):
|
||||||
|
print(f"[{result.confidence:.2f}]: {result.text}")
|
||||||
|
|
||||||
|
# Create processor
|
||||||
|
processor = StreamProcessor(config, result_callback=on_result)
|
||||||
|
processor.start()
|
||||||
|
|
||||||
|
# Send audio data
|
||||||
|
audio_data = np.random.randn(16000).astype(np.float32) # 1 second of audio
|
||||||
|
processor.add_audio(audio_data)
|
||||||
|
|
||||||
|
# Stop when done
|
||||||
|
processor.stop()
|
||||||
|
```
|
||||||
|
|
||||||
|
## WebSocket API
|
||||||
|
|
||||||
|
### Connection Messages
|
||||||
|
|
||||||
|
**Connect**: Connect to `ws://localhost:8765`
|
||||||
|
|
||||||
|
**Configure Stream**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "configure",
|
||||||
|
"config": {
|
||||||
|
"model_name": "base",
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"language": "en",
|
||||||
|
"temperature": 0.0,
|
||||||
|
"return_timestamps": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Start Stream**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "start_stream"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Send Audio**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "audio_data",
|
||||||
|
"format": "pcm16",
|
||||||
|
"audio": "<base64-encoded-audio-data>"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Stop Stream**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "stop_stream"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Response Messages
|
||||||
|
|
||||||
|
**Transcription Result**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "transcription_result",
|
||||||
|
"result": {
|
||||||
|
"text": "Hello world",
|
||||||
|
"start_time": 0.0,
|
||||||
|
"end_time": 2.0,
|
||||||
|
"confidence": 0.95,
|
||||||
|
"is_final": true,
|
||||||
|
"language": "en"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Options
|
||||||
|
|
||||||
|
### StreamConfig Parameters
|
||||||
|
|
||||||
|
- `sample_rate`: Audio sample rate (default: 16000)
|
||||||
|
- `chunk_duration_ms`: Duration of each processing chunk (default: 1000)
|
||||||
|
- `buffer_duration_ms`: Total buffer duration (default: 5000)
|
||||||
|
- `overlap_duration_ms`: Overlap between chunks (default: 200)
|
||||||
|
- `model_name`: Whisper model to use (default: "base")
|
||||||
|
- `language`: Source language (None for auto-detect)
|
||||||
|
- `temperature`: Sampling temperature (default: 0.0)
|
||||||
|
- `vad_threshold`: Voice activity detection threshold (default: 0.5)
|
||||||
|
- `use_ctranslate2`: Enable CTranslate2 acceleration (default: False)
|
||||||
|
- `device`: Device for inference ("auto", "cpu", "cuda")
|
||||||
|
|
||||||
|
## Performance Optimization
|
||||||
|
|
||||||
|
### CTranslate2 Backend
|
||||||
|
|
||||||
|
For better performance, especially on GPU:
|
||||||
|
|
||||||
|
```python
|
||||||
|
config = StreamConfig(
|
||||||
|
model_name="base",
|
||||||
|
use_ctranslate2=True,
|
||||||
|
device="cuda",
|
||||||
|
compute_type="float16"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Chunking Strategy
|
||||||
|
|
||||||
|
Optimize chunk and buffer sizes based on your use case:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Low latency (faster response, higher CPU usage)
|
||||||
|
config = StreamConfig(
|
||||||
|
chunk_duration_ms=500,
|
||||||
|
buffer_duration_ms=2000,
|
||||||
|
overlap_duration_ms=100
|
||||||
|
)
|
||||||
|
|
||||||
|
# Balanced (default settings)
|
||||||
|
config = StreamConfig(
|
||||||
|
chunk_duration_ms=1000,
|
||||||
|
buffer_duration_ms=5000,
|
||||||
|
overlap_duration_ms=200
|
||||||
|
)
|
||||||
|
|
||||||
|
# High accuracy (slower response, better accuracy)
|
||||||
|
config = StreamConfig(
|
||||||
|
chunk_duration_ms=2000,
|
||||||
|
buffer_duration_ms=10000,
|
||||||
|
overlap_duration_ms=500
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Client Examples
|
||||||
|
|
||||||
|
### Python WebSocket Client
|
||||||
|
|
||||||
|
See `examples/streaming_client.py` for a complete example of connecting to the WebSocket server and streaming audio.
|
||||||
|
|
||||||
|
### JavaScript Client
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const ws = new WebSocket('ws://localhost:8765');
|
||||||
|
|
||||||
|
ws.onopen = function() {
|
||||||
|
// Configure stream
|
||||||
|
ws.send(JSON.stringify({
|
||||||
|
type: 'configure',
|
||||||
|
config: {
|
||||||
|
model_name: 'base',
|
||||||
|
language: 'en'
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Start stream
|
||||||
|
ws.send(JSON.stringify({type: 'start_stream'}));
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onmessage = function(event) {
|
||||||
|
const data = JSON.parse(event.data);
|
||||||
|
|
||||||
|
if (data.type === 'transcription_result') {
|
||||||
|
console.log('Transcription:', data.result.text);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Send audio data
|
||||||
|
function sendAudio(audioBuffer) {
|
||||||
|
const audioBase64 = btoa(String.fromCharCode(...audioBuffer));
|
||||||
|
ws.send(JSON.stringify({
|
||||||
|
type: 'audio_data',
|
||||||
|
format: 'pcm16',
|
||||||
|
audio: audioBase64
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Audio Formats
|
||||||
|
|
||||||
|
- PCM 16-bit (`pcm16`)
|
||||||
|
- PCM 32-bit (`pcm32`)
|
||||||
|
- IEEE Float 32-bit (`float32`)
|
||||||
|
- Sample rates: 8000, 16000, 22050, 44100, 48000 Hz
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
The streaming system provides comprehensive error handling:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def error_callback(error):
|
||||||
|
print(f"Processing error: {error}")
|
||||||
|
# Handle error (retry, fallback, etc.)
|
||||||
|
|
||||||
|
processor = StreamProcessor(
|
||||||
|
config=config,
|
||||||
|
error_callback=error_callback
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
### Required
|
||||||
|
- `numpy`
|
||||||
|
- `torch` (for standard Whisper backend)
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
- `ctranslate2` (for accelerated inference)
|
||||||
|
- `transformers` (for CTranslate2 integration)
|
||||||
|
- `websockets` (for WebSocket server)
|
||||||
|
- `pyaudio` (for microphone input in examples)
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install core streaming dependencies
|
||||||
|
pip install websockets
|
||||||
|
|
||||||
|
# For CTranslate2 acceleration
|
||||||
|
pip install ctranslate2 transformers
|
||||||
|
|
||||||
|
# For microphone input examples
|
||||||
|
pip install pyaudio
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Benchmarks
|
||||||
|
|
||||||
|
Typical performance on different hardware:
|
||||||
|
|
||||||
|
| Model | Device | RTF* | Latency |
|
||||||
|
|-------|--------|------|---------|
|
||||||
|
| tiny | CPU | 0.1x | ~100ms |
|
||||||
|
| base | CPU | 0.3x | ~300ms |
|
||||||
|
| small | CPU | 0.6x | ~600ms |
|
||||||
|
| base | GPU | 0.05x | ~50ms |
|
||||||
|
| small | GPU | 0.1x | ~100ms |
|
||||||
|
|
||||||
|
*RTF = Real-time Factor (lower is better)
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
When contributing to the streaming module:
|
||||||
|
|
||||||
|
1. Maintain backward compatibility
|
||||||
|
2. Add comprehensive error handling
|
||||||
|
3. Include performance benchmarks
|
||||||
|
4. Update documentation
|
||||||
|
5. Add tests for new features
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Same as the main Whisper project (MIT License).
|
||||||
22
whisper/streaming/__init__.py
Normal file
22
whisper/streaming/__init__.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# Whisper Real-time Streaming Module
|
||||||
|
"""
|
||||||
|
This module provides real-time streaming capabilities for OpenAI Whisper.
|
||||||
|
Supports WebSocket-based streaming, chunked processing, and low-latency transcription.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .stream_processor import StreamProcessor, StreamConfig
|
||||||
|
from .websocket_server import WhisperWebSocketServer
|
||||||
|
from .audio_buffer import AudioBuffer, AudioChunk
|
||||||
|
from .ctranslate2_backend import CTranslate2Backend
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'StreamProcessor',
|
||||||
|
'StreamConfig',
|
||||||
|
'WhisperWebSocketServer',
|
||||||
|
'AudioBuffer',
|
||||||
|
'AudioChunk',
|
||||||
|
'CTranslate2Backend'
|
||||||
|
]
|
||||||
|
|
||||||
|
# Version info
|
||||||
|
__version__ = "1.0.0"
|
||||||
385
whisper/streaming/audio_buffer.py
Normal file
385
whisper/streaming/audio_buffer.py
Normal file
@ -0,0 +1,385 @@
|
|||||||
|
"""
|
||||||
|
Audio buffer management for real-time streaming transcription.
|
||||||
|
|
||||||
|
This module handles audio buffering, chunking, and Voice Activity Detection (VAD)
|
||||||
|
for efficient real-time processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import collections
|
||||||
|
from typing import Optional, List, Iterator, Tuple, Union
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import numpy as np
|
||||||
|
import threading
|
||||||
|
import queue
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AudioChunk:
|
||||||
|
"""Represents a chunk of audio data with metadata."""
|
||||||
|
data: np.ndarray
|
||||||
|
sample_rate: int
|
||||||
|
timestamp: float
|
||||||
|
duration: float
|
||||||
|
chunk_id: int
|
||||||
|
is_silence: bool = False
|
||||||
|
vad_confidence: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class AudioBuffer:
|
||||||
|
"""
|
||||||
|
Thread-safe audio buffer for real-time streaming.
|
||||||
|
|
||||||
|
This buffer maintains a sliding window of audio data and provides
|
||||||
|
chunks for processing while handling voice activity detection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
chunk_duration_ms: int = 1000,
|
||||||
|
buffer_duration_ms: int = 5000,
|
||||||
|
overlap_duration_ms: int = 200,
|
||||||
|
vad_threshold: float = 0.3,
|
||||||
|
silence_timeout_ms: int = 1000
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the audio buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample_rate: Audio sample rate in Hz
|
||||||
|
chunk_duration_ms: Duration of each processing chunk in milliseconds
|
||||||
|
buffer_duration_ms: Total buffer duration in milliseconds
|
||||||
|
overlap_duration_ms: Overlap between consecutive chunks in milliseconds
|
||||||
|
vad_threshold: Voice Activity Detection threshold (0.0-1.0)
|
||||||
|
silence_timeout_ms: Timeout for silence detection in milliseconds
|
||||||
|
"""
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.chunk_duration_ms = chunk_duration_ms
|
||||||
|
self.buffer_duration_ms = buffer_duration_ms
|
||||||
|
self.overlap_duration_ms = overlap_duration_ms
|
||||||
|
self.vad_threshold = vad_threshold
|
||||||
|
self.silence_timeout_ms = silence_timeout_ms
|
||||||
|
|
||||||
|
# Calculate sizes in samples
|
||||||
|
self.chunk_size = int(sample_rate * chunk_duration_ms / 1000)
|
||||||
|
self.buffer_size = int(sample_rate * buffer_duration_ms / 1000)
|
||||||
|
self.overlap_size = int(sample_rate * overlap_duration_ms / 1000)
|
||||||
|
self.silence_timeout_samples = int(sample_rate * silence_timeout_ms / 1000)
|
||||||
|
|
||||||
|
# Initialize buffer
|
||||||
|
self.buffer = np.zeros(self.buffer_size, dtype=np.float32)
|
||||||
|
self.buffer_position = 0
|
||||||
|
self.total_samples_received = 0
|
||||||
|
self.chunk_counter = 0
|
||||||
|
|
||||||
|
# Thread safety
|
||||||
|
self.lock = threading.RLock()
|
||||||
|
self.chunk_queue = queue.Queue()
|
||||||
|
|
||||||
|
# State tracking
|
||||||
|
self.last_vad_activity = 0
|
||||||
|
self.is_speaking = False
|
||||||
|
self.speech_start_sample = None
|
||||||
|
|
||||||
|
def add_audio(self, audio_data: np.ndarray) -> None:
|
||||||
|
"""
|
||||||
|
Add audio data to the buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_data: Audio samples as numpy array
|
||||||
|
"""
|
||||||
|
with self.lock:
|
||||||
|
audio_data = audio_data.astype(np.float32)
|
||||||
|
samples_to_add = len(audio_data)
|
||||||
|
|
||||||
|
# Handle buffer wraparound
|
||||||
|
if self.buffer_position + samples_to_add <= self.buffer_size:
|
||||||
|
# Fits in buffer without wraparound
|
||||||
|
self.buffer[self.buffer_position:self.buffer_position + samples_to_add] = audio_data
|
||||||
|
else:
|
||||||
|
# Need to wrap around
|
||||||
|
first_part_size = self.buffer_size - self.buffer_position
|
||||||
|
second_part_size = samples_to_add - first_part_size
|
||||||
|
|
||||||
|
self.buffer[self.buffer_position:] = audio_data[:first_part_size]
|
||||||
|
self.buffer[:second_part_size] = audio_data[first_part_size:]
|
||||||
|
|
||||||
|
self.buffer_position = (self.buffer_position + samples_to_add) % self.buffer_size
|
||||||
|
self.total_samples_received += samples_to_add
|
||||||
|
|
||||||
|
# Check if we have enough data for a new chunk
|
||||||
|
if self.total_samples_received >= self.chunk_size:
|
||||||
|
self._create_chunks()
|
||||||
|
|
||||||
|
def _create_chunks(self) -> None:
|
||||||
|
"""Create audio chunks from the current buffer state."""
|
||||||
|
# Calculate how many chunks we can create
|
||||||
|
available_samples = min(self.total_samples_received, self.buffer_size)
|
||||||
|
|
||||||
|
# Create chunks with overlap
|
||||||
|
chunk_start = 0
|
||||||
|
while chunk_start + self.chunk_size <= available_samples:
|
||||||
|
chunk_end = chunk_start + self.chunk_size
|
||||||
|
|
||||||
|
# Extract chunk data (handle wraparound)
|
||||||
|
start_pos = (self.buffer_position - available_samples + chunk_start) % self.buffer_size
|
||||||
|
chunk_data = self._extract_circular_data(start_pos, self.chunk_size)
|
||||||
|
|
||||||
|
# Perform voice activity detection
|
||||||
|
vad_confidence = self._calculate_vad(chunk_data)
|
||||||
|
is_silence = vad_confidence < self.vad_threshold
|
||||||
|
|
||||||
|
# Update speech state
|
||||||
|
timestamp = (self.total_samples_received - available_samples + chunk_start) / self.sample_rate
|
||||||
|
self._update_speech_state(vad_confidence, timestamp)
|
||||||
|
|
||||||
|
# Create chunk
|
||||||
|
chunk = AudioChunk(
|
||||||
|
data=chunk_data,
|
||||||
|
sample_rate=self.sample_rate,
|
||||||
|
timestamp=timestamp,
|
||||||
|
duration=self.chunk_duration_ms / 1000.0,
|
||||||
|
chunk_id=self.chunk_counter,
|
||||||
|
is_silence=is_silence,
|
||||||
|
vad_confidence=vad_confidence
|
||||||
|
)
|
||||||
|
|
||||||
|
self.chunk_queue.put(chunk)
|
||||||
|
self.chunk_counter += 1
|
||||||
|
|
||||||
|
# Move to next chunk position (with overlap)
|
||||||
|
step_size = self.chunk_size - self.overlap_size
|
||||||
|
chunk_start += step_size
|
||||||
|
|
||||||
|
def _extract_circular_data(self, start_pos: int, length: int) -> np.ndarray:
|
||||||
|
"""Extract data from circular buffer."""
|
||||||
|
if start_pos + length <= self.buffer_size:
|
||||||
|
return self.buffer[start_pos:start_pos + length].copy()
|
||||||
|
else:
|
||||||
|
# Handle wraparound
|
||||||
|
first_part_size = self.buffer_size - start_pos
|
||||||
|
second_part_size = length - first_part_size
|
||||||
|
|
||||||
|
result = np.zeros(length, dtype=np.float32)
|
||||||
|
result[:first_part_size] = self.buffer[start_pos:]
|
||||||
|
result[first_part_size:] = self.buffer[:second_part_size]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _calculate_vad(self, audio_data: np.ndarray) -> float:
|
||||||
|
"""
|
||||||
|
Simple Voice Activity Detection using energy and zero-crossing rate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_data: Audio chunk to analyze
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VAD confidence score (0.0-1.0)
|
||||||
|
"""
|
||||||
|
if len(audio_data) == 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Energy-based detection
|
||||||
|
energy = np.mean(audio_data ** 2)
|
||||||
|
energy_threshold = 0.001 # Adjust based on your use case
|
||||||
|
|
||||||
|
# Zero-crossing rate
|
||||||
|
zero_crossings = np.sum(np.abs(np.diff(np.sign(audio_data)))) / (2 * len(audio_data))
|
||||||
|
zcr_threshold = 0.05
|
||||||
|
|
||||||
|
# Combined score
|
||||||
|
energy_score = min(1.0, energy / energy_threshold)
|
||||||
|
zcr_score = min(1.0, zero_crossings / zcr_threshold)
|
||||||
|
|
||||||
|
# Weight energy more heavily than ZCR
|
||||||
|
vad_score = 0.8 * energy_score + 0.2 * zcr_score
|
||||||
|
return min(1.0, vad_score)
|
||||||
|
|
||||||
|
def _update_speech_state(self, vad_confidence: float, timestamp: float) -> None:
|
||||||
|
"""Update the speech activity state based on VAD results."""
|
||||||
|
if vad_confidence >= self.vad_threshold:
|
||||||
|
self.last_vad_activity = self.total_samples_received
|
||||||
|
if not self.is_speaking:
|
||||||
|
self.is_speaking = True
|
||||||
|
self.speech_start_sample = self.total_samples_received
|
||||||
|
else:
|
||||||
|
# Check for end of speech
|
||||||
|
silence_duration = self.total_samples_received - self.last_vad_activity
|
||||||
|
if self.is_speaking and silence_duration > self.silence_timeout_samples:
|
||||||
|
self.is_speaking = False
|
||||||
|
self.speech_start_sample = None
|
||||||
|
|
||||||
|
def get_chunk(self, timeout: Optional[float] = None) -> Optional[AudioChunk]:
|
||||||
|
"""
|
||||||
|
Get the next available audio chunk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Maximum time to wait for a chunk (None for no timeout)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AudioChunk if available, None if timeout or no chunks
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self.chunk_queue.get(timeout=timeout)
|
||||||
|
except queue.Empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_chunks_batch(self, max_chunks: int = 10, timeout: float = 0.1) -> List[AudioChunk]:
|
||||||
|
"""
|
||||||
|
Get multiple chunks in a batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_chunks: Maximum number of chunks to return
|
||||||
|
timeout: Maximum time to wait for first chunk
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of AudioChunks (may be empty)
|
||||||
|
"""
|
||||||
|
chunks = []
|
||||||
|
try:
|
||||||
|
# Get first chunk with timeout
|
||||||
|
first_chunk = self.chunk_queue.get(timeout=timeout)
|
||||||
|
chunks.append(first_chunk)
|
||||||
|
|
||||||
|
# Get remaining chunks without blocking
|
||||||
|
for _ in range(max_chunks - 1):
|
||||||
|
try:
|
||||||
|
chunk = self.chunk_queue.get_nowait()
|
||||||
|
chunks.append(chunk)
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def is_speech_active(self) -> bool:
|
||||||
|
"""Check if speech is currently being detected."""
|
||||||
|
return self.is_speaking
|
||||||
|
|
||||||
|
def get_buffer_info(self) -> dict:
|
||||||
|
"""Get information about the current buffer state."""
|
||||||
|
with self.lock:
|
||||||
|
return {
|
||||||
|
"buffer_size_samples": self.buffer_size,
|
||||||
|
"buffer_size_ms": self.buffer_duration_ms,
|
||||||
|
"chunk_size_samples": self.chunk_size,
|
||||||
|
"chunk_size_ms": self.chunk_duration_ms,
|
||||||
|
"total_samples_received": self.total_samples_received,
|
||||||
|
"total_duration_ms": self.total_samples_received / self.sample_rate * 1000,
|
||||||
|
"chunks_created": self.chunk_counter,
|
||||||
|
"chunks_pending": self.chunk_queue.qsize(),
|
||||||
|
"is_speaking": self.is_speaking,
|
||||||
|
"buffer_position": self.buffer_position
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear the buffer and reset state."""
|
||||||
|
with self.lock:
|
||||||
|
self.buffer.fill(0)
|
||||||
|
self.buffer_position = 0
|
||||||
|
self.total_samples_received = 0
|
||||||
|
self.chunk_counter = 0
|
||||||
|
self.last_vad_activity = 0
|
||||||
|
self.is_speaking = False
|
||||||
|
self.speech_start_sample = None
|
||||||
|
|
||||||
|
# Clear the queue
|
||||||
|
while not self.chunk_queue.empty():
|
||||||
|
try:
|
||||||
|
self.chunk_queue.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingVAD:
|
||||||
|
"""
|
||||||
|
Enhanced Voice Activity Detection for streaming audio.
|
||||||
|
|
||||||
|
Uses a more sophisticated approach than the basic VAD in AudioBuffer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
frame_duration_ms: int = 20,
|
||||||
|
energy_threshold: float = 0.001,
|
||||||
|
zcr_threshold: float = 0.05,
|
||||||
|
smoothing_window: int = 5
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the streaming VAD.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample_rate: Audio sample rate
|
||||||
|
frame_duration_ms: Duration of each analysis frame
|
||||||
|
energy_threshold: Energy threshold for speech detection
|
||||||
|
zcr_threshold: Zero-crossing rate threshold
|
||||||
|
smoothing_window: Number of frames to smooth over
|
||||||
|
"""
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.frame_duration_ms = frame_duration_ms
|
||||||
|
self.energy_threshold = energy_threshold
|
||||||
|
self.zcr_threshold = zcr_threshold
|
||||||
|
self.smoothing_window = smoothing_window
|
||||||
|
|
||||||
|
self.frame_size = int(sample_rate * frame_duration_ms / 1000)
|
||||||
|
self.energy_history = collections.deque(maxlen=smoothing_window)
|
||||||
|
self.zcr_history = collections.deque(maxlen=smoothing_window)
|
||||||
|
|
||||||
|
def analyze_frame(self, audio_frame: np.ndarray) -> Tuple[float, dict]:
|
||||||
|
"""
|
||||||
|
Analyze an audio frame for voice activity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_frame: Audio frame data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (vad_probability, analysis_details)
|
||||||
|
"""
|
||||||
|
if len(audio_frame) == 0:
|
||||||
|
return 0.0, {}
|
||||||
|
|
||||||
|
# Energy calculation
|
||||||
|
energy = np.mean(audio_frame ** 2)
|
||||||
|
self.energy_history.append(energy)
|
||||||
|
|
||||||
|
# Zero-crossing rate calculation
|
||||||
|
zero_crossings = np.sum(np.abs(np.diff(np.sign(audio_frame))))
|
||||||
|
zcr = zero_crossings / (2 * len(audio_frame))
|
||||||
|
self.zcr_history.append(zcr)
|
||||||
|
|
||||||
|
# Smoothed values
|
||||||
|
avg_energy = np.mean(self.energy_history)
|
||||||
|
avg_zcr = np.mean(self.zcr_history)
|
||||||
|
|
||||||
|
# Speech probability calculation
|
||||||
|
energy_score = min(1.0, avg_energy / self.energy_threshold)
|
||||||
|
zcr_score = min(1.0, avg_zcr / self.zcr_threshold)
|
||||||
|
|
||||||
|
# Adaptive thresholding based on recent history
|
||||||
|
if len(self.energy_history) == self.smoothing_window:
|
||||||
|
energy_std = np.std(self.energy_history)
|
||||||
|
if energy_std > self.energy_threshold * 0.5:
|
||||||
|
# High variability suggests speech
|
||||||
|
energy_score *= 1.2
|
||||||
|
|
||||||
|
vad_probability = 0.7 * energy_score + 0.3 * zcr_score
|
||||||
|
vad_probability = min(1.0, vad_probability)
|
||||||
|
|
||||||
|
analysis_details = {
|
||||||
|
"energy": energy,
|
||||||
|
"zcr": zcr,
|
||||||
|
"avg_energy": avg_energy,
|
||||||
|
"avg_zcr": avg_zcr,
|
||||||
|
"energy_score": energy_score,
|
||||||
|
"zcr_score": zcr_score,
|
||||||
|
"energy_std": np.std(self.energy_history) if len(self.energy_history) > 1 else 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
return vad_probability, analysis_details
|
||||||
|
|
||||||
|
def is_speech(self, vad_probability: float, threshold: float = 0.5) -> bool:
|
||||||
|
"""Determine if the current frame contains speech."""
|
||||||
|
return vad_probability >= threshold
|
||||||
589
whisper/streaming/ctranslate2_backend.py
Normal file
589
whisper/streaming/ctranslate2_backend.py
Normal file
@ -0,0 +1,589 @@
|
|||||||
|
"""
|
||||||
|
CTranslate2 backend for accelerated Whisper inference.
|
||||||
|
|
||||||
|
This module provides a CTranslate2-based backend for faster inference,
|
||||||
|
especially useful for real-time streaming applications.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict, Any, List, Union
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ctranslate2
|
||||||
|
import transformers
|
||||||
|
CTRANSLATE2_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
CTRANSLATE2_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
class CTranslate2Backend:
|
||||||
|
"""
|
||||||
|
CTranslate2 backend for accelerated Whisper inference.
|
||||||
|
|
||||||
|
This backend converts Whisper models to CTranslate2 format for faster inference,
|
||||||
|
particularly beneficial for streaming applications where low latency is critical.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "base",
|
||||||
|
device: str = "auto",
|
||||||
|
compute_type: str = "float16",
|
||||||
|
inter_threads: int = 4,
|
||||||
|
intra_threads: int = 1,
|
||||||
|
cache_dir: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the CTranslate2 backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Whisper model name (tiny, base, small, medium, large, large-v2, large-v3)
|
||||||
|
device: Device for inference ("cpu", "cuda", "auto")
|
||||||
|
compute_type: Compute precision ("float32", "float16", "int8")
|
||||||
|
inter_threads: Number of inter-op threads
|
||||||
|
intra_threads: Number of intra-op threads
|
||||||
|
cache_dir: Directory to cache converted models
|
||||||
|
"""
|
||||||
|
if not CTRANSLATE2_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"CTranslate2 is not available. Please install with: "
|
||||||
|
"pip install ctranslate2 transformers"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model_name = model_name
|
||||||
|
self.device = self._determine_device(device)
|
||||||
|
self.compute_type = compute_type
|
||||||
|
self.inter_threads = inter_threads
|
||||||
|
self.intra_threads = intra_threads
|
||||||
|
self.cache_dir = cache_dir or str(Path.home() / ".cache" / "whisper_ct2")
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
self.model = None
|
||||||
|
self.processor = None
|
||||||
|
self.tokenizer = None
|
||||||
|
|
||||||
|
# Model info
|
||||||
|
self.model_path = None
|
||||||
|
self.is_loaded = False
|
||||||
|
|
||||||
|
# Performance tracking
|
||||||
|
self.inference_times = []
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Load the model
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
def _determine_device(self, device: str) -> str:
|
||||||
|
"""Determine the best available device."""
|
||||||
|
if device == "auto":
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return "cuda"
|
||||||
|
else:
|
||||||
|
return "cpu"
|
||||||
|
except ImportError:
|
||||||
|
return "cpu"
|
||||||
|
return device
|
||||||
|
|
||||||
|
def _load_model(self) -> None:
|
||||||
|
"""Load and convert the Whisper model to CTranslate2 format."""
|
||||||
|
try:
|
||||||
|
# Create cache directory
|
||||||
|
cache_path = Path(self.cache_dir)
|
||||||
|
cache_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
model_cache_path = cache_path / f"whisper-{self.model_name}-ct2"
|
||||||
|
|
||||||
|
# Convert model if not cached
|
||||||
|
if not model_cache_path.exists():
|
||||||
|
self.logger.info(f"Converting Whisper {self.model_name} to CTranslate2 format...")
|
||||||
|
self._convert_model(model_cache_path)
|
||||||
|
else:
|
||||||
|
self.logger.info(f"Using cached CTranslate2 model: {model_cache_path}")
|
||||||
|
|
||||||
|
self.model_path = str(model_cache_path)
|
||||||
|
|
||||||
|
# Load the CTranslate2 model
|
||||||
|
self.model = ctranslate2.models.Whisper(
|
||||||
|
self.model_path,
|
||||||
|
device=self.device,
|
||||||
|
compute_type=self.compute_type,
|
||||||
|
inter_threads=self.inter_threads,
|
||||||
|
intra_threads=self.intra_threads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load the processor and tokenizer
|
||||||
|
self._load_processor()
|
||||||
|
|
||||||
|
self.is_loaded = True
|
||||||
|
self.logger.info(f"CTranslate2 Whisper model loaded successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to load CTranslate2 model: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _convert_model(self, output_path: Path) -> None:
|
||||||
|
"""Convert Whisper model to CTranslate2 format."""
|
||||||
|
try:
|
||||||
|
import whisper
|
||||||
|
|
||||||
|
# Load original Whisper model
|
||||||
|
self.logger.info("Loading original Whisper model...")
|
||||||
|
whisper_model = whisper.load_model(self.model_name)
|
||||||
|
|
||||||
|
# Convert to CTranslate2
|
||||||
|
self.logger.info("Converting to CTranslate2 format...")
|
||||||
|
|
||||||
|
# Save model state for conversion
|
||||||
|
temp_model_path = output_path.parent / f"temp_whisper_{self.model_name}"
|
||||||
|
temp_model_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Save the model components
|
||||||
|
import torch
|
||||||
|
torch.save({
|
||||||
|
'model_state_dict': whisper_model.state_dict(),
|
||||||
|
'dims': whisper_model.dims.__dict__,
|
||||||
|
}, temp_model_path / "pytorch_model.bin")
|
||||||
|
|
||||||
|
# Create config for conversion
|
||||||
|
config = {
|
||||||
|
"architectures": ["WhisperForConditionalGeneration"],
|
||||||
|
"model_type": "whisper",
|
||||||
|
"torch_dtype": "float32",
|
||||||
|
}
|
||||||
|
|
||||||
|
import json
|
||||||
|
with open(temp_model_path / "config.json", "w") as f:
|
||||||
|
json.dump(config, f)
|
||||||
|
|
||||||
|
# Convert using ct2-whisper-converter (if available) or direct conversion
|
||||||
|
try:
|
||||||
|
# Try direct conversion
|
||||||
|
ctranslate2.converters.TransformersConverter(
|
||||||
|
str(temp_model_path)
|
||||||
|
).convert(
|
||||||
|
str(output_path),
|
||||||
|
quantization=self.compute_type
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Fallback: manual conversion
|
||||||
|
self._manual_convert(whisper_model, output_path)
|
||||||
|
|
||||||
|
# Cleanup temporary files
|
||||||
|
import shutil
|
||||||
|
if temp_model_path.exists():
|
||||||
|
shutil.rmtree(temp_model_path)
|
||||||
|
|
||||||
|
self.logger.info(f"Model conversion completed: {output_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Model conversion failed: {e}")
|
||||||
|
# Fallback: create a stub that uses regular Whisper
|
||||||
|
self._create_fallback_model(output_path)
|
||||||
|
|
||||||
|
def _manual_convert(self, whisper_model, output_path: Path) -> None:
|
||||||
|
"""Manual conversion when automatic conversion fails."""
|
||||||
|
# This is a simplified fallback conversion
|
||||||
|
# In practice, you might want to use the official whisper-ctranslate2 converter
|
||||||
|
self.logger.warning("Using fallback conversion - performance may be suboptimal")
|
||||||
|
|
||||||
|
output_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Save model info
|
||||||
|
model_info = {
|
||||||
|
"model_name": self.model_name,
|
||||||
|
"conversion_method": "fallback",
|
||||||
|
"device": self.device,
|
||||||
|
"compute_type": self.compute_type
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(output_path / "model_info.json", "w") as f:
|
||||||
|
import json
|
||||||
|
json.dump(model_info, f, indent=2)
|
||||||
|
|
||||||
|
def _create_fallback_model(self, output_path: Path) -> None:
|
||||||
|
"""Create a fallback model directory when conversion fails."""
|
||||||
|
output_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
fallback_info = {
|
||||||
|
"model_name": self.model_name,
|
||||||
|
"fallback": True,
|
||||||
|
"message": "CTranslate2 conversion failed, will use standard Whisper"
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(output_path / "fallback.json", "w") as f:
|
||||||
|
import json
|
||||||
|
json.dump(fallback_info, f, indent=2)
|
||||||
|
|
||||||
|
def _load_processor(self) -> None:
|
||||||
|
"""Load the audio processor and tokenizer."""
|
||||||
|
try:
|
||||||
|
# For Whisper, we need to handle audio preprocessing manually
|
||||||
|
# since CTranslate2 expects specific input formats
|
||||||
|
|
||||||
|
# Create a simple processor wrapper
|
||||||
|
self.processor = WhisperProcessor(self.model_name)
|
||||||
|
|
||||||
|
self.logger.info("Processor loaded successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Failed to load processor: {e}")
|
||||||
|
self.processor = None
|
||||||
|
|
||||||
|
def transcribe(
|
||||||
|
self,
|
||||||
|
audio: Union[np.ndarray, str],
|
||||||
|
language: Optional[str] = None,
|
||||||
|
task: str = "transcribe",
|
||||||
|
temperature: float = 0.0,
|
||||||
|
return_timestamps: bool = False,
|
||||||
|
return_word_timestamps: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Transcribe audio using the CTranslate2 backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio: Audio data as numpy array or file path
|
||||||
|
language: Source language (None for auto-detection)
|
||||||
|
task: Task type ("transcribe" or "translate")
|
||||||
|
temperature: Sampling temperature
|
||||||
|
return_timestamps: Include segment timestamps
|
||||||
|
return_word_timestamps: Include word-level timestamps
|
||||||
|
**kwargs: Additional arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with transcription results
|
||||||
|
"""
|
||||||
|
if not self.is_loaded:
|
||||||
|
raise RuntimeError("Model is not loaded")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Preprocess audio
|
||||||
|
if isinstance(audio, str):
|
||||||
|
audio_features = self._load_audio_file(audio)
|
||||||
|
else:
|
||||||
|
audio_features = self._preprocess_audio(audio)
|
||||||
|
|
||||||
|
# Prepare generation parameters
|
||||||
|
generation_params = {
|
||||||
|
"language": language,
|
||||||
|
"task": task,
|
||||||
|
"beam_size": 1 if temperature > 0 else 5,
|
||||||
|
"temperature": temperature,
|
||||||
|
"return_scores": True,
|
||||||
|
"return_no_speech_prob": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Remove None values
|
||||||
|
generation_params = {k: v for k, v in generation_params.items() if v is not None}
|
||||||
|
|
||||||
|
# Generate transcription
|
||||||
|
results = self.model.generate(
|
||||||
|
audio_features,
|
||||||
|
**generation_params
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process results
|
||||||
|
transcription_result = self._process_results(
|
||||||
|
results,
|
||||||
|
return_timestamps=return_timestamps,
|
||||||
|
return_word_timestamps=return_word_timestamps
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track inference time
|
||||||
|
inference_time = time.time() - start_time
|
||||||
|
self.inference_times.append(inference_time)
|
||||||
|
|
||||||
|
transcription_result["processing_time"] = inference_time
|
||||||
|
return transcription_result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Transcription failed: {e}")
|
||||||
|
# Fallback to empty result
|
||||||
|
return {
|
||||||
|
"text": "",
|
||||||
|
"segments": [],
|
||||||
|
"language": language or "en",
|
||||||
|
"processing_time": time.time() - start_time,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
def _load_audio_file(self, file_path: str) -> np.ndarray:
|
||||||
|
"""Load audio file and convert to model input format."""
|
||||||
|
try:
|
||||||
|
import whisper
|
||||||
|
# Use Whisper's built-in audio loading
|
||||||
|
audio = whisper.load_audio(file_path)
|
||||||
|
return self._preprocess_audio(audio)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to load audio file {file_path}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _preprocess_audio(self, audio: np.ndarray) -> np.ndarray:
|
||||||
|
"""Preprocess audio data for the model."""
|
||||||
|
try:
|
||||||
|
import whisper
|
||||||
|
|
||||||
|
# Ensure audio is the right format
|
||||||
|
if len(audio.shape) > 1:
|
||||||
|
audio = audio.mean(axis=1) # Convert to mono
|
||||||
|
|
||||||
|
# Pad or trim to expected length
|
||||||
|
audio = whisper.pad_or_trim(audio)
|
||||||
|
|
||||||
|
# Convert to log-mel spectrogram
|
||||||
|
mel = whisper.log_mel_spectrogram(audio).unsqueeze(0)
|
||||||
|
|
||||||
|
return mel.numpy()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Audio preprocessing failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _process_results(
|
||||||
|
self,
|
||||||
|
results,
|
||||||
|
return_timestamps: bool = False,
|
||||||
|
return_word_timestamps: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Process CTranslate2 results into standard format."""
|
||||||
|
try:
|
||||||
|
if not results or not results[0]:
|
||||||
|
return {
|
||||||
|
"text": "",
|
||||||
|
"segments": [],
|
||||||
|
"language": "en"
|
||||||
|
}
|
||||||
|
|
||||||
|
result = results[0]
|
||||||
|
|
||||||
|
# Extract text
|
||||||
|
if hasattr(result, 'sequences'):
|
||||||
|
# Handle sequence results
|
||||||
|
text_tokens = result.sequences[0]
|
||||||
|
text = self._decode_tokens(text_tokens)
|
||||||
|
else:
|
||||||
|
# Handle direct text results
|
||||||
|
text = str(result)
|
||||||
|
|
||||||
|
# Build result dictionary
|
||||||
|
transcription_result = {
|
||||||
|
"text": text.strip(),
|
||||||
|
"language": getattr(result, 'language', 'en')
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add confidence if available
|
||||||
|
if hasattr(result, 'scores') and result.scores:
|
||||||
|
confidence = float(np.exp(np.mean(result.scores)))
|
||||||
|
transcription_result["confidence"] = confidence
|
||||||
|
|
||||||
|
# Add segments if requested
|
||||||
|
if return_timestamps:
|
||||||
|
segments = self._extract_segments(result, return_word_timestamps)
|
||||||
|
transcription_result["segments"] = segments
|
||||||
|
|
||||||
|
return transcription_result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Result processing failed: {e}")
|
||||||
|
return {
|
||||||
|
"text": "",
|
||||||
|
"segments": [],
|
||||||
|
"language": "en"
|
||||||
|
}
|
||||||
|
|
||||||
|
def _decode_tokens(self, tokens: List[int]) -> str:
|
||||||
|
"""Decode token IDs to text."""
|
||||||
|
# This is a simplified decoder - in practice you'd use the proper tokenizer
|
||||||
|
try:
|
||||||
|
if self.tokenizer:
|
||||||
|
return self.tokenizer.decode(tokens)
|
||||||
|
else:
|
||||||
|
# Fallback: assume tokens are already text-like
|
||||||
|
return " ".join(str(token) for token in tokens)
|
||||||
|
except Exception:
|
||||||
|
return " ".join(str(token) for token in tokens)
|
||||||
|
|
||||||
|
def _extract_segments(self, result, return_word_timestamps: bool) -> List[Dict[str, Any]]:
|
||||||
|
"""Extract segment information from results."""
|
||||||
|
segments = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# This is a simplified segment extraction
|
||||||
|
# Real implementation would depend on CTranslate2's output format
|
||||||
|
text = getattr(result, 'text', '')
|
||||||
|
|
||||||
|
if text:
|
||||||
|
segment = {
|
||||||
|
"id": 0,
|
||||||
|
"start": 0.0,
|
||||||
|
"end": 30.0, # Assume 30-second segments
|
||||||
|
"text": text,
|
||||||
|
"tokens": getattr(result, 'sequences', [[]])[0] if hasattr(result, 'sequences') else [],
|
||||||
|
"temperature": 0.0,
|
||||||
|
"avg_logprob": float(np.mean(result.scores)) if hasattr(result, 'scores') and result.scores else -1.0,
|
||||||
|
"compression_ratio": len(text) / max(1, len(text.split())),
|
||||||
|
"no_speech_prob": getattr(result, 'no_speech_prob', 0.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
if return_word_timestamps:
|
||||||
|
# Add word-level timestamps (placeholder)
|
||||||
|
words = text.split()
|
||||||
|
word_duration = 30.0 / max(1, len(words))
|
||||||
|
segment["words"] = [
|
||||||
|
{
|
||||||
|
"word": word,
|
||||||
|
"start": i * word_duration,
|
||||||
|
"end": (i + 1) * word_duration,
|
||||||
|
"probability": 0.9
|
||||||
|
}
|
||||||
|
for i, word in enumerate(words)
|
||||||
|
]
|
||||||
|
|
||||||
|
segments.append(segment)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Segment extraction failed: {e}")
|
||||||
|
|
||||||
|
return segments
|
||||||
|
|
||||||
|
def get_model_info(self) -> Dict[str, Any]:
|
||||||
|
"""Get information about the loaded model."""
|
||||||
|
return {
|
||||||
|
"model_name": self.model_name,
|
||||||
|
"device": self.device,
|
||||||
|
"compute_type": self.compute_type,
|
||||||
|
"is_loaded": self.is_loaded,
|
||||||
|
"model_path": self.model_path,
|
||||||
|
"backend": "CTranslate2",
|
||||||
|
"avg_inference_time": np.mean(self.inference_times) if self.inference_times else 0.0,
|
||||||
|
"total_inferences": len(self.inference_times)
|
||||||
|
}
|
||||||
|
|
||||||
|
def warmup(self, duration: float = 1.0) -> None:
|
||||||
|
"""Warm up the model with dummy audio."""
|
||||||
|
try:
|
||||||
|
# Create dummy audio
|
||||||
|
sample_rate = 16000
|
||||||
|
dummy_audio = np.random.randn(int(sample_rate * duration)).astype(np.float32)
|
||||||
|
|
||||||
|
# Run inference to warm up
|
||||||
|
self.transcribe(dummy_audio, temperature=0.0)
|
||||||
|
self.logger.info("Model warmup completed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Model warmup failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperProcessor:
|
||||||
|
"""Simple processor wrapper for Whisper audio preprocessing."""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str):
|
||||||
|
self.model_name = model_name
|
||||||
|
# In a full implementation, you'd load the proper feature extractor
|
||||||
|
# For now, this is a placeholder
|
||||||
|
|
||||||
|
def preprocess(self, audio: np.ndarray) -> np.ndarray:
|
||||||
|
"""Preprocess audio for the model."""
|
||||||
|
# This would contain the actual preprocessing logic
|
||||||
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_models() -> List[str]:
|
||||||
|
"""Get list of available Whisper models for CTranslate2."""
|
||||||
|
return [
|
||||||
|
"tiny",
|
||||||
|
"tiny.en",
|
||||||
|
"base",
|
||||||
|
"base.en",
|
||||||
|
"small",
|
||||||
|
"small.en",
|
||||||
|
"medium",
|
||||||
|
"medium.en",
|
||||||
|
"large",
|
||||||
|
"large-v1",
|
||||||
|
"large-v2",
|
||||||
|
"large-v3"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_model(
|
||||||
|
model_name: str,
|
||||||
|
audio_duration: float = 30.0,
|
||||||
|
num_runs: int = 5,
|
||||||
|
device: str = "auto"
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Benchmark a CTranslate2 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Model to benchmark
|
||||||
|
audio_duration: Duration of test audio in seconds
|
||||||
|
num_runs: Number of benchmark runs
|
||||||
|
device: Device for inference
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with benchmark results
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Create backend
|
||||||
|
backend = CTranslate2Backend(model_name, device=device)
|
||||||
|
|
||||||
|
# Generate test audio
|
||||||
|
sample_rate = 16000
|
||||||
|
test_audio = np.random.randn(int(sample_rate * audio_duration)).astype(np.float32)
|
||||||
|
|
||||||
|
# Warm up
|
||||||
|
backend.warmup()
|
||||||
|
|
||||||
|
# Run benchmark
|
||||||
|
times = []
|
||||||
|
for i in range(num_runs):
|
||||||
|
start_time = time.time()
|
||||||
|
result = backend.transcribe(test_audio)
|
||||||
|
end_time = time.time()
|
||||||
|
times.append(end_time - start_time)
|
||||||
|
|
||||||
|
# Calculate statistics
|
||||||
|
times = np.array(times)
|
||||||
|
return {
|
||||||
|
"model_name": model_name,
|
||||||
|
"device": device,
|
||||||
|
"audio_duration": audio_duration,
|
||||||
|
"num_runs": num_runs,
|
||||||
|
"mean_time": float(np.mean(times)),
|
||||||
|
"std_time": float(np.std(times)),
|
||||||
|
"min_time": float(np.min(times)),
|
||||||
|
"max_time": float(np.max(times)),
|
||||||
|
"realtime_factor": float(audio_duration / np.mean(times))
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"error": str(e),
|
||||||
|
"model_name": model_name,
|
||||||
|
"device": device
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Simple test
|
||||||
|
if CTRANSLATE2_AVAILABLE:
|
||||||
|
backend = CTranslate2Backend("base")
|
||||||
|
print(f"Model info: {backend.get_model_info()}")
|
||||||
|
|
||||||
|
# Test with dummy audio
|
||||||
|
test_audio = np.random.randn(16000).astype(np.float32) # 1 second
|
||||||
|
result = backend.transcribe(test_audio)
|
||||||
|
print(f"Test result: {result}")
|
||||||
|
else:
|
||||||
|
print("CTranslate2 is not available. Install with: pip install ctranslate2 transformers")
|
||||||
518
whisper/streaming/stream_processor.py
Normal file
518
whisper/streaming/stream_processor.py
Normal file
@ -0,0 +1,518 @@
|
|||||||
|
"""
|
||||||
|
Real-time streaming processor for Whisper transcription.
|
||||||
|
|
||||||
|
This module handles the core streaming logic, integrating audio buffering,
|
||||||
|
real-time transcription, and result management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, List, Optional, Callable, Any, Union
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from enum import Enum
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
|
||||||
|
from .audio_buffer import AudioBuffer, AudioChunk, StreamingVAD
|
||||||
|
|
||||||
|
|
||||||
|
class StreamState(Enum):
|
||||||
|
"""Possible states of the stream processor."""
|
||||||
|
STOPPED = "stopped"
|
||||||
|
STARTING = "starting"
|
||||||
|
RUNNING = "running"
|
||||||
|
STOPPING = "stopping"
|
||||||
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StreamConfig:
|
||||||
|
"""Configuration for the stream processor."""
|
||||||
|
# Audio settings
|
||||||
|
sample_rate: int = 16000
|
||||||
|
chunk_duration_ms: int = 1000
|
||||||
|
buffer_duration_ms: int = 5000
|
||||||
|
overlap_duration_ms: int = 200
|
||||||
|
|
||||||
|
# Transcription settings
|
||||||
|
model_name: str = "base"
|
||||||
|
language: Optional[str] = None
|
||||||
|
task: str = "transcribe" # "transcribe" or "translate"
|
||||||
|
temperature: float = 0.0
|
||||||
|
condition_on_previous_text: bool = False
|
||||||
|
|
||||||
|
# Streaming settings
|
||||||
|
realtime_factor: float = 1.0 # Target processing speed vs real-time
|
||||||
|
max_processing_delay_ms: int = 500
|
||||||
|
min_silence_duration_ms: int = 1000
|
||||||
|
max_segment_duration_ms: int = 30000
|
||||||
|
|
||||||
|
# VAD settings
|
||||||
|
vad_threshold: float = 0.5
|
||||||
|
vad_frame_duration_ms: int = 20
|
||||||
|
|
||||||
|
# Performance settings
|
||||||
|
use_ctranslate2: bool = False
|
||||||
|
device: str = "auto" # "auto", "cpu", "cuda"
|
||||||
|
compute_type: str = "float16"
|
||||||
|
|
||||||
|
# Output settings
|
||||||
|
return_timestamps: bool = True
|
||||||
|
return_word_timestamps: bool = False
|
||||||
|
return_confidence_scores: bool = True
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "StreamConfig":
|
||||||
|
"""Create from dictionary."""
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TranscriptionResult:
|
||||||
|
"""Result of a transcription operation."""
|
||||||
|
text: str
|
||||||
|
start_time: float
|
||||||
|
end_time: float
|
||||||
|
confidence: float
|
||||||
|
language: Optional[str] = None
|
||||||
|
chunks: Optional[List[Dict[str, Any]]] = None
|
||||||
|
processing_time_ms: float = 0.0
|
||||||
|
is_final: bool = True
|
||||||
|
segment_id: str = ""
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for serialization."""
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
def to_json(self) -> str:
|
||||||
|
"""Convert to JSON string."""
|
||||||
|
return json.dumps(self.to_dict(), ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamProcessor:
|
||||||
|
"""
|
||||||
|
Real-time streaming processor for Whisper transcription.
|
||||||
|
|
||||||
|
This class manages the entire streaming pipeline:
|
||||||
|
1. Audio buffering and chunking
|
||||||
|
2. Voice Activity Detection
|
||||||
|
3. Real-time transcription
|
||||||
|
4. Result aggregation and delivery
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: StreamConfig,
|
||||||
|
model: Optional[Any] = None,
|
||||||
|
result_callback: Optional[Callable[[TranscriptionResult], None]] = None,
|
||||||
|
error_callback: Optional[Callable[[Exception], None]] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the stream processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Stream configuration
|
||||||
|
model: Pre-loaded Whisper model (optional)
|
||||||
|
result_callback: Callback for transcription results
|
||||||
|
error_callback: Callback for errors
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
self.model = model
|
||||||
|
self.result_callback = result_callback
|
||||||
|
self.error_callback = error_callback
|
||||||
|
|
||||||
|
# State management
|
||||||
|
self.state = StreamState.STOPPED
|
||||||
|
self.start_time = None
|
||||||
|
self.total_audio_duration = 0.0
|
||||||
|
self.total_processing_time = 0.0
|
||||||
|
|
||||||
|
# Audio processing
|
||||||
|
self.audio_buffer = AudioBuffer(
|
||||||
|
sample_rate=config.sample_rate,
|
||||||
|
chunk_duration_ms=config.chunk_duration_ms,
|
||||||
|
buffer_duration_ms=config.buffer_duration_ms,
|
||||||
|
overlap_duration_ms=config.overlap_duration_ms,
|
||||||
|
vad_threshold=config.vad_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
self.vad = StreamingVAD(
|
||||||
|
sample_rate=config.sample_rate,
|
||||||
|
frame_duration_ms=config.vad_frame_duration_ms
|
||||||
|
)
|
||||||
|
|
||||||
|
# Threading
|
||||||
|
self.processing_thread = None
|
||||||
|
self.stop_event = threading.Event()
|
||||||
|
|
||||||
|
# Results management
|
||||||
|
self.pending_segments = []
|
||||||
|
self.completed_segments = []
|
||||||
|
self.segment_counter = 0
|
||||||
|
|
||||||
|
# Performance tracking
|
||||||
|
self.processing_stats = {
|
||||||
|
"chunks_processed": 0,
|
||||||
|
"average_processing_time_ms": 0.0,
|
||||||
|
"max_processing_time_ms": 0.0,
|
||||||
|
"realtime_factor": 0.0,
|
||||||
|
"dropped_chunks": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def start(self) -> bool:
|
||||||
|
"""
|
||||||
|
Start the streaming processor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if started successfully, False otherwise
|
||||||
|
"""
|
||||||
|
if self.state != StreamState.STOPPED:
|
||||||
|
self.logger.warning(f"Cannot start processor in state {self.state}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.state = StreamState.STARTING
|
||||||
|
self.start_time = time.time()
|
||||||
|
self.stop_event.clear()
|
||||||
|
|
||||||
|
# Initialize model if not provided
|
||||||
|
if self.model is None:
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
# Start processing thread
|
||||||
|
self.processing_thread = threading.Thread(
|
||||||
|
target=self._processing_loop,
|
||||||
|
name="WhisperStreamProcessor",
|
||||||
|
daemon=True
|
||||||
|
)
|
||||||
|
self.processing_thread.start()
|
||||||
|
|
||||||
|
self.state = StreamState.RUNNING
|
||||||
|
self.logger.info("Stream processor started successfully")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.state = StreamState.ERROR
|
||||||
|
self.logger.error(f"Failed to start stream processor: {e}")
|
||||||
|
if self.error_callback:
|
||||||
|
self.error_callback(e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def stop(self) -> bool:
|
||||||
|
"""
|
||||||
|
Stop the streaming processor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if stopped successfully, False otherwise
|
||||||
|
"""
|
||||||
|
if self.state not in [StreamState.RUNNING, StreamState.ERROR]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.state = StreamState.STOPPING
|
||||||
|
self.stop_event.set()
|
||||||
|
|
||||||
|
# Wait for processing thread to finish
|
||||||
|
if self.processing_thread and self.processing_thread.is_alive():
|
||||||
|
self.processing_thread.join(timeout=5.0)
|
||||||
|
|
||||||
|
self.state = StreamState.STOPPED
|
||||||
|
self.logger.info("Stream processor stopped successfully")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error stopping stream processor: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def add_audio(self, audio_data: Union[bytes, list, tuple]) -> None:
|
||||||
|
"""
|
||||||
|
Add audio data to the processing pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_data: Audio data as bytes, list, or tuple
|
||||||
|
"""
|
||||||
|
if self.state != StreamState.RUNNING:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert to numpy array if needed
|
||||||
|
import numpy as np
|
||||||
|
if isinstance(audio_data, bytes):
|
||||||
|
# Assume 16-bit PCM
|
||||||
|
audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
elif isinstance(audio_data, (list, tuple)):
|
||||||
|
audio_array = np.array(audio_data, dtype=np.float32)
|
||||||
|
else:
|
||||||
|
audio_array = audio_data
|
||||||
|
|
||||||
|
self.audio_buffer.add_audio(audio_array)
|
||||||
|
self.total_audio_duration += len(audio_array) / self.config.sample_rate
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error adding audio data: {e}")
|
||||||
|
if self.error_callback:
|
||||||
|
self.error_callback(e)
|
||||||
|
|
||||||
|
def _load_model(self) -> None:
|
||||||
|
"""Load the Whisper model based on configuration."""
|
||||||
|
try:
|
||||||
|
if self.config.use_ctranslate2:
|
||||||
|
from .ctranslate2_backend import CTranslate2Backend
|
||||||
|
self.model = CTranslate2Backend(
|
||||||
|
model_name=self.config.model_name,
|
||||||
|
device=self.config.device,
|
||||||
|
compute_type=self.config.compute_type
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
import whisper
|
||||||
|
self.model = whisper.load_model(
|
||||||
|
self.config.model_name,
|
||||||
|
device=self.config.device
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info(f"Loaded model: {self.config.model_name}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to load model: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _processing_loop(self) -> None:
|
||||||
|
"""Main processing loop running in a separate thread."""
|
||||||
|
self.logger.info("Processing loop started")
|
||||||
|
|
||||||
|
while not self.stop_event.is_set():
|
||||||
|
try:
|
||||||
|
# Get audio chunks from buffer
|
||||||
|
chunks = self.audio_buffer.get_chunks_batch(
|
||||||
|
max_chunks=5,
|
||||||
|
timeout=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
if not chunks:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Process chunks
|
||||||
|
for chunk in chunks:
|
||||||
|
if self.stop_event.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
self._process_chunk(chunk)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in processing loop: {e}")
|
||||||
|
if self.error_callback:
|
||||||
|
self.error_callback(e)
|
||||||
|
|
||||||
|
self.logger.info("Processing loop ended")
|
||||||
|
|
||||||
|
def _process_chunk(self, chunk: AudioChunk) -> None:
|
||||||
|
"""
|
||||||
|
Process a single audio chunk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk: AudioChunk to process
|
||||||
|
"""
|
||||||
|
processing_start = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Skip processing if chunk is silence (unless we need to finalize a segment)
|
||||||
|
if chunk.is_silence and not self._should_process_silence(chunk):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Prepare audio for transcription
|
||||||
|
audio_data = chunk.data
|
||||||
|
|
||||||
|
# Transcribe using the model
|
||||||
|
if self.config.use_ctranslate2:
|
||||||
|
result = self._transcribe_with_ctranslate2(audio_data, chunk)
|
||||||
|
else:
|
||||||
|
result = self._transcribe_with_whisper(audio_data, chunk)
|
||||||
|
|
||||||
|
# Process the transcription result
|
||||||
|
if result and result.text.strip():
|
||||||
|
self._handle_transcription_result(result, chunk)
|
||||||
|
|
||||||
|
# Update performance stats
|
||||||
|
processing_time = (time.time() - processing_start) * 1000
|
||||||
|
self._update_processing_stats(processing_time, chunk.duration)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error processing chunk {chunk.chunk_id}: {e}")
|
||||||
|
self.processing_stats["dropped_chunks"] += 1
|
||||||
|
|
||||||
|
def _should_process_silence(self, chunk: AudioChunk) -> bool:
|
||||||
|
"""Determine if we should process a silence chunk."""
|
||||||
|
# Process silence if we have pending segments that need to be finalized
|
||||||
|
return len(self.pending_segments) > 0
|
||||||
|
|
||||||
|
def _transcribe_with_whisper(self, audio_data, chunk: AudioChunk) -> Optional[TranscriptionResult]:
|
||||||
|
"""Transcribe using standard Whisper model."""
|
||||||
|
try:
|
||||||
|
# Prepare transcription options
|
||||||
|
options = {
|
||||||
|
"language": self.config.language,
|
||||||
|
"task": self.config.task,
|
||||||
|
"temperature": self.config.temperature,
|
||||||
|
"condition_on_previous_text": self.config.condition_on_previous_text,
|
||||||
|
"word_timestamps": self.config.return_word_timestamps,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Remove None values
|
||||||
|
options = {k: v for k, v in options.items() if v is not None}
|
||||||
|
|
||||||
|
# Transcribe
|
||||||
|
result = self.model.transcribe(audio_data, **options)
|
||||||
|
|
||||||
|
# Extract text and metadata
|
||||||
|
text = result.get("text", "").strip()
|
||||||
|
language = result.get("language")
|
||||||
|
segments = result.get("segments", [])
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Calculate confidence (simplified)
|
||||||
|
confidence = 0.8 # Default confidence for standard Whisper
|
||||||
|
if segments:
|
||||||
|
# Use average log probability if available
|
||||||
|
avg_logprobs = [s.get("avg_logprob", -1.0) for s in segments if s.get("avg_logprob")]
|
||||||
|
if avg_logprobs:
|
||||||
|
# Convert log probability to confidence (rough approximation)
|
||||||
|
avg_logprob = sum(avg_logprobs) / len(avg_logprobs)
|
||||||
|
confidence = max(0.1, min(0.99, 1.0 + avg_logprob / 2.0))
|
||||||
|
|
||||||
|
return TranscriptionResult(
|
||||||
|
text=text,
|
||||||
|
start_time=chunk.timestamp,
|
||||||
|
end_time=chunk.timestamp + chunk.duration,
|
||||||
|
confidence=confidence,
|
||||||
|
language=language,
|
||||||
|
chunks=segments if self.config.return_timestamps else None,
|
||||||
|
segment_id=f"seg_{self.segment_counter}_{chunk.chunk_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in Whisper transcription: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _transcribe_with_ctranslate2(self, audio_data, chunk: AudioChunk) -> Optional[TranscriptionResult]:
|
||||||
|
"""Transcribe using CTranslate2 backend."""
|
||||||
|
try:
|
||||||
|
result = self.model.transcribe(
|
||||||
|
audio_data,
|
||||||
|
language=self.config.language,
|
||||||
|
task=self.config.task,
|
||||||
|
temperature=self.config.temperature,
|
||||||
|
return_timestamps=self.config.return_timestamps,
|
||||||
|
return_word_timestamps=self.config.return_word_timestamps
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result or not result.get("text", "").strip():
|
||||||
|
return None
|
||||||
|
|
||||||
|
return TranscriptionResult(
|
||||||
|
text=result["text"].strip(),
|
||||||
|
start_time=chunk.timestamp,
|
||||||
|
end_time=chunk.timestamp + chunk.duration,
|
||||||
|
confidence=result.get("confidence", 0.8),
|
||||||
|
language=result.get("language"),
|
||||||
|
chunks=result.get("segments"),
|
||||||
|
segment_id=f"seg_{self.segment_counter}_{chunk.chunk_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in CTranslate2 transcription: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _handle_transcription_result(self, result: TranscriptionResult, chunk: AudioChunk) -> None:
|
||||||
|
"""
|
||||||
|
Handle a transcription result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Transcription result
|
||||||
|
chunk: Source audio chunk
|
||||||
|
"""
|
||||||
|
# Add processing time
|
||||||
|
result.processing_time_ms = (time.time() - chunk.timestamp) * 1000
|
||||||
|
|
||||||
|
# Determine if this is a final result
|
||||||
|
result.is_final = not self.audio_buffer.is_speech_active()
|
||||||
|
|
||||||
|
# Add to appropriate list
|
||||||
|
if result.is_final:
|
||||||
|
self.completed_segments.append(result)
|
||||||
|
self.segment_counter += 1
|
||||||
|
else:
|
||||||
|
self.pending_segments.append(result)
|
||||||
|
|
||||||
|
# Send result to callback
|
||||||
|
if self.result_callback:
|
||||||
|
try:
|
||||||
|
self.result_callback(result)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in result callback: {e}")
|
||||||
|
|
||||||
|
def _update_processing_stats(self, processing_time_ms: float, chunk_duration_s: float) -> None:
|
||||||
|
"""Update processing performance statistics."""
|
||||||
|
self.processing_stats["chunks_processed"] += 1
|
||||||
|
self.total_processing_time += processing_time_ms / 1000
|
||||||
|
|
||||||
|
# Update average processing time
|
||||||
|
count = self.processing_stats["chunks_processed"]
|
||||||
|
current_avg = self.processing_stats["average_processing_time_ms"]
|
||||||
|
self.processing_stats["average_processing_time_ms"] = (
|
||||||
|
(current_avg * (count - 1) + processing_time_ms) / count
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update max processing time
|
||||||
|
self.processing_stats["max_processing_time_ms"] = max(
|
||||||
|
self.processing_stats["max_processing_time_ms"],
|
||||||
|
processing_time_ms
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate realtime factor
|
||||||
|
if chunk_duration_s > 0:
|
||||||
|
realtime_factor = chunk_duration_s / (processing_time_ms / 1000)
|
||||||
|
self.processing_stats["realtime_factor"] = realtime_factor
|
||||||
|
|
||||||
|
def get_status(self) -> Dict[str, Any]:
|
||||||
|
"""Get current processor status."""
|
||||||
|
return {
|
||||||
|
"state": self.state.value,
|
||||||
|
"uptime_seconds": time.time() - self.start_time if self.start_time else 0,
|
||||||
|
"total_audio_duration": self.total_audio_duration,
|
||||||
|
"total_processing_time": self.total_processing_time,
|
||||||
|
"buffer_info": self.audio_buffer.get_buffer_info(),
|
||||||
|
"processing_stats": self.processing_stats.copy(),
|
||||||
|
"segments_completed": len(self.completed_segments),
|
||||||
|
"segments_pending": len(self.pending_segments)
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_results(self, since_segment: int = 0) -> List[TranscriptionResult]:
|
||||||
|
"""
|
||||||
|
Get transcription results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
since_segment: Return results after this segment number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of transcription results
|
||||||
|
"""
|
||||||
|
all_results = self.completed_segments + self.pending_segments
|
||||||
|
return [result for result in all_results if int(result.segment_id.split('_')[1]) >= since_segment]
|
||||||
|
|
||||||
|
def clear_completed_segments(self) -> None:
|
||||||
|
"""Clear completed segments to free memory."""
|
||||||
|
self.completed_segments.clear()
|
||||||
|
|
||||||
|
def get_config(self) -> StreamConfig:
|
||||||
|
"""Get current configuration."""
|
||||||
|
return self.config
|
||||||
498
whisper/streaming/websocket_server.py
Normal file
498
whisper/streaming/websocket_server.py
Normal file
@ -0,0 +1,498 @@
|
|||||||
|
"""
|
||||||
|
WebSocket server for real-time Whisper transcription.
|
||||||
|
|
||||||
|
This module provides a WebSocket-based API for streaming audio and receiving
|
||||||
|
real-time transcription results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Dict, Any, Optional, Set, Callable
|
||||||
|
import websockets
|
||||||
|
import websockets.server
|
||||||
|
from websockets.exceptions import ConnectionClosed, WebSocketException
|
||||||
|
import threading
|
||||||
|
import base64
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from .stream_processor import StreamProcessor, StreamConfig, TranscriptionResult, StreamState
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperWebSocketServer:
|
||||||
|
"""
|
||||||
|
WebSocket server for real-time Whisper transcription.
|
||||||
|
|
||||||
|
Supports multiple concurrent connections with independent processing streams.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str = "localhost",
|
||||||
|
port: int = 8765,
|
||||||
|
default_config: Optional[StreamConfig] = None,
|
||||||
|
model_cache: Optional[Dict[str, Any]] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the WebSocket server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host: Server host address
|
||||||
|
port: Server port
|
||||||
|
default_config: Default stream configuration
|
||||||
|
model_cache: Optional model cache for performance
|
||||||
|
"""
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.default_config = default_config or StreamConfig()
|
||||||
|
self.model_cache = model_cache or {}
|
||||||
|
|
||||||
|
# Connection management
|
||||||
|
self.active_connections: Set[websockets.WebSocketServerProtocol] = set()
|
||||||
|
self.connection_processors: Dict[str, StreamProcessor] = {}
|
||||||
|
self.connection_configs: Dict[str, StreamConfig] = {}
|
||||||
|
|
||||||
|
# Server state
|
||||||
|
self.server = None
|
||||||
|
self.is_running = False
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Statistics
|
||||||
|
self.stats = {
|
||||||
|
"connections_total": 0,
|
||||||
|
"connections_active": 0,
|
||||||
|
"messages_received": 0,
|
||||||
|
"messages_sent": 0,
|
||||||
|
"audio_bytes_processed": 0,
|
||||||
|
"start_time": None
|
||||||
|
}
|
||||||
|
|
||||||
|
async def start_server(self) -> None:
|
||||||
|
"""Start the WebSocket server."""
|
||||||
|
try:
|
||||||
|
self.stats["start_time"] = time.time()
|
||||||
|
self.server = await websockets.serve(
|
||||||
|
self._handle_connection,
|
||||||
|
self.host,
|
||||||
|
self.port,
|
||||||
|
ping_interval=20,
|
||||||
|
ping_timeout=10,
|
||||||
|
max_size=10 * 1024 * 1024, # 10MB max message size
|
||||||
|
compression=None # Disable compression for audio data
|
||||||
|
)
|
||||||
|
|
||||||
|
self.is_running = True
|
||||||
|
self.logger.info(f"WebSocket server started on ws://{self.host}:{self.port}")
|
||||||
|
|
||||||
|
# Wait until shutdown
|
||||||
|
await self.shutdown_event.wait()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error starting WebSocket server: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def stop_server(self) -> None:
|
||||||
|
"""Stop the WebSocket server."""
|
||||||
|
if not self.is_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.is_running = False
|
||||||
|
self.shutdown_event.set()
|
||||||
|
|
||||||
|
# Stop all active processors
|
||||||
|
for processor in self.connection_processors.values():
|
||||||
|
processor.stop()
|
||||||
|
|
||||||
|
# Close all connections
|
||||||
|
if self.active_connections:
|
||||||
|
await asyncio.gather(
|
||||||
|
*[conn.close() for conn in self.active_connections.copy()],
|
||||||
|
return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Close the server
|
||||||
|
if self.server:
|
||||||
|
self.server.close()
|
||||||
|
await self.server.wait_closed()
|
||||||
|
|
||||||
|
self.logger.info("WebSocket server stopped")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error stopping WebSocket server: {e}")
|
||||||
|
|
||||||
|
async def _handle_connection(self, websocket: websockets.WebSocketServerProtocol, path: str) -> None:
|
||||||
|
"""Handle a new WebSocket connection."""
|
||||||
|
connection_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}:{id(websocket)}"
|
||||||
|
|
||||||
|
self.logger.info(f"New connection: {connection_id}")
|
||||||
|
self.active_connections.add(websocket)
|
||||||
|
self.stats["connections_total"] += 1
|
||||||
|
self.stats["connections_active"] += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize connection
|
||||||
|
await self._initialize_connection(websocket, connection_id)
|
||||||
|
|
||||||
|
# Handle messages
|
||||||
|
async for message in websocket:
|
||||||
|
await self._handle_message(websocket, connection_id, message)
|
||||||
|
|
||||||
|
except ConnectionClosed:
|
||||||
|
self.logger.info(f"Connection closed: {connection_id}")
|
||||||
|
except WebSocketException as e:
|
||||||
|
self.logger.warning(f"WebSocket error for {connection_id}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Unexpected error for {connection_id}: {e}")
|
||||||
|
await self._send_error(websocket, "Internal server error", str(e))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await self._cleanup_connection(websocket, connection_id)
|
||||||
|
|
||||||
|
async def _initialize_connection(self, websocket: websockets.WebSocketServerProtocol, connection_id: str) -> None:
|
||||||
|
"""Initialize a new connection."""
|
||||||
|
# Send welcome message
|
||||||
|
welcome_msg = {
|
||||||
|
"type": "connection_established",
|
||||||
|
"connection_id": connection_id,
|
||||||
|
"server_info": {
|
||||||
|
"version": "1.0.0",
|
||||||
|
"supported_formats": ["pcm16", "pcm32", "float32"],
|
||||||
|
"supported_sample_rates": [8000, 16000, 22050, 44100, 48000],
|
||||||
|
"max_audio_chunk_size": 1024 * 1024 # 1MB
|
||||||
|
},
|
||||||
|
"default_config": self.default_config.to_dict()
|
||||||
|
}
|
||||||
|
|
||||||
|
await websocket.send(json.dumps(welcome_msg))
|
||||||
|
|
||||||
|
async def _cleanup_connection(self, websocket: websockets.WebSocketServerProtocol, connection_id: str) -> None:
|
||||||
|
"""Clean up connection resources."""
|
||||||
|
# Remove from active connections
|
||||||
|
self.active_connections.discard(websocket)
|
||||||
|
self.stats["connections_active"] -= 1
|
||||||
|
|
||||||
|
# Stop processor if exists
|
||||||
|
if connection_id in self.connection_processors:
|
||||||
|
processor = self.connection_processors[connection_id]
|
||||||
|
processor.stop()
|
||||||
|
del self.connection_processors[connection_id]
|
||||||
|
|
||||||
|
# Remove config
|
||||||
|
self.connection_configs.pop(connection_id, None)
|
||||||
|
|
||||||
|
async def _handle_message(self, websocket: websockets.WebSocketServerProtocol, connection_id: str, message: str) -> None:
|
||||||
|
"""Handle incoming WebSocket message."""
|
||||||
|
self.stats["messages_received"] += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Parse JSON message
|
||||||
|
if isinstance(message, bytes):
|
||||||
|
# Handle binary audio data
|
||||||
|
await self._handle_binary_audio(websocket, connection_id, message)
|
||||||
|
return
|
||||||
|
|
||||||
|
data = json.loads(message)
|
||||||
|
message_type = data.get("type")
|
||||||
|
|
||||||
|
if message_type == "configure":
|
||||||
|
await self._handle_configure(websocket, connection_id, data)
|
||||||
|
elif message_type == "start_stream":
|
||||||
|
await self._handle_start_stream(websocket, connection_id, data)
|
||||||
|
elif message_type == "stop_stream":
|
||||||
|
await self._handle_stop_stream(websocket, connection_id, data)
|
||||||
|
elif message_type == "audio_data":
|
||||||
|
await self._handle_audio_data(websocket, connection_id, data)
|
||||||
|
elif message_type == "get_status":
|
||||||
|
await self._handle_get_status(websocket, connection_id, data)
|
||||||
|
elif message_type == "get_results":
|
||||||
|
await self._handle_get_results(websocket, connection_id, data)
|
||||||
|
else:
|
||||||
|
await self._send_error(websocket, "Unknown message type", f"Unsupported message type: {message_type}")
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
await self._send_error(websocket, "Invalid JSON", str(e))
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error handling message from {connection_id}: {e}")
|
||||||
|
await self._send_error(websocket, "Message processing error", str(e))
|
||||||
|
|
||||||
|
async def _handle_configure(self, websocket: websockets.WebSocketServerProtocol, connection_id: str, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle stream configuration."""
|
||||||
|
try:
|
||||||
|
config_data = data.get("config", {})
|
||||||
|
config = StreamConfig.from_dict({**self.default_config.to_dict(), **config_data})
|
||||||
|
|
||||||
|
self.connection_configs[connection_id] = config
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"type": "configuration_updated",
|
||||||
|
"config": config.to_dict(),
|
||||||
|
"timestamp": time.time()
|
||||||
|
}
|
||||||
|
|
||||||
|
await websocket.send(json.dumps(response))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await self._send_error(websocket, "Configuration error", str(e))
|
||||||
|
|
||||||
|
async def _handle_start_stream(self, websocket: websockets.WebSocketServerProtocol, connection_id: str, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle stream start request."""
|
||||||
|
try:
|
||||||
|
# Get or create config
|
||||||
|
config = self.connection_configs.get(connection_id, self.default_config)
|
||||||
|
|
||||||
|
# Create result callback
|
||||||
|
def result_callback(result: TranscriptionResult):
|
||||||
|
asyncio.create_task(self._send_transcription_result(websocket, result))
|
||||||
|
|
||||||
|
def error_callback(error: Exception):
|
||||||
|
asyncio.create_task(self._send_error(websocket, "Processing error", str(error)))
|
||||||
|
|
||||||
|
# Create and start processor
|
||||||
|
processor = StreamProcessor(
|
||||||
|
config=config,
|
||||||
|
model=self.model_cache.get(config.model_name),
|
||||||
|
result_callback=result_callback,
|
||||||
|
error_callback=error_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
if processor.start():
|
||||||
|
self.connection_processors[connection_id] = processor
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"type": "stream_started",
|
||||||
|
"connection_id": connection_id,
|
||||||
|
"config": config.to_dict(),
|
||||||
|
"timestamp": time.time()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
response = {
|
||||||
|
"type": "error",
|
||||||
|
"error": "Failed to start stream processor",
|
||||||
|
"timestamp": time.time()
|
||||||
|
}
|
||||||
|
|
||||||
|
await websocket.send(json.dumps(response))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await self._send_error(websocket, "Stream start error", str(e))
|
||||||
|
|
||||||
|
async def _handle_stop_stream(self, websocket: websockets.WebSocketServerProtocol, connection_id: str, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle stream stop request."""
|
||||||
|
try:
|
||||||
|
if connection_id in self.connection_processors:
|
||||||
|
processor = self.connection_processors[connection_id]
|
||||||
|
processor.stop()
|
||||||
|
del self.connection_processors[connection_id]
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"type": "stream_stopped",
|
||||||
|
"connection_id": connection_id,
|
||||||
|
"timestamp": time.time()
|
||||||
|
}
|
||||||
|
|
||||||
|
await websocket.send(json.dumps(response))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await self._send_error(websocket, "Stream stop error", str(e))
|
||||||
|
|
||||||
|
async def _handle_audio_data(self, websocket: websockets.WebSocketServerProtocol, connection_id: str, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle audio data from JSON message."""
|
||||||
|
try:
|
||||||
|
processor = self.connection_processors.get(connection_id)
|
||||||
|
if not processor:
|
||||||
|
await self._send_error(websocket, "Stream not started", "Start stream before sending audio")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Decode audio data
|
||||||
|
audio_format = data.get("format", "pcm16")
|
||||||
|
audio_b64 = data.get("audio")
|
||||||
|
|
||||||
|
if not audio_b64:
|
||||||
|
await self._send_error(websocket, "Missing audio data", "Audio data field is required")
|
||||||
|
return
|
||||||
|
|
||||||
|
audio_bytes = base64.b64decode(audio_b64)
|
||||||
|
audio_data = self._decode_audio(audio_bytes, audio_format)
|
||||||
|
|
||||||
|
processor.add_audio(audio_data)
|
||||||
|
self.stats["audio_bytes_processed"] += len(audio_bytes)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await self._send_error(websocket, "Audio processing error", str(e))
|
||||||
|
|
||||||
|
async def _handle_binary_audio(self, websocket: websockets.WebSocketServerProtocol, connection_id: str, audio_bytes: bytes) -> None:
|
||||||
|
"""Handle binary audio data."""
|
||||||
|
try:
|
||||||
|
processor = self.connection_processors.get(connection_id)
|
||||||
|
if not processor:
|
||||||
|
return # Silently ignore if no processor
|
||||||
|
|
||||||
|
# Assume PCM16 format for binary data
|
||||||
|
audio_data = self._decode_audio(audio_bytes, "pcm16")
|
||||||
|
processor.add_audio(audio_data)
|
||||||
|
self.stats["audio_bytes_processed"] += len(audio_bytes)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error processing binary audio from {connection_id}: {e}")
|
||||||
|
|
||||||
|
async def _handle_get_status(self, websocket: websockets.WebSocketServerProtocol, connection_id: str, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle status request."""
|
||||||
|
try:
|
||||||
|
processor_status = {}
|
||||||
|
if connection_id in self.connection_processors:
|
||||||
|
processor = self.connection_processors[connection_id]
|
||||||
|
processor_status = processor.get_status()
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"type": "status",
|
||||||
|
"connection_id": connection_id,
|
||||||
|
"processor": processor_status,
|
||||||
|
"server_stats": self.stats.copy(),
|
||||||
|
"timestamp": time.time()
|
||||||
|
}
|
||||||
|
|
||||||
|
await websocket.send(json.dumps(response))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await self._send_error(websocket, "Status error", str(e))
|
||||||
|
|
||||||
|
async def _handle_get_results(self, websocket: websockets.WebSocketServerProtocol, connection_id: str, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle results request."""
|
||||||
|
try:
|
||||||
|
processor = self.connection_processors.get(connection_id)
|
||||||
|
if not processor:
|
||||||
|
await self._send_error(websocket, "Stream not started", "No active stream")
|
||||||
|
return
|
||||||
|
|
||||||
|
since_segment = data.get("since_segment", 0)
|
||||||
|
results = processor.get_results(since_segment)
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"type": "results",
|
||||||
|
"connection_id": connection_id,
|
||||||
|
"results": [result.to_dict() for result in results],
|
||||||
|
"total_results": len(results),
|
||||||
|
"timestamp": time.time()
|
||||||
|
}
|
||||||
|
|
||||||
|
await websocket.send(json.dumps(response))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await self._send_error(websocket, "Results error", str(e))
|
||||||
|
|
||||||
|
def _decode_audio(self, audio_bytes: bytes, audio_format: str) -> list:
|
||||||
|
"""Decode audio bytes to list of samples."""
|
||||||
|
if audio_format == "pcm16":
|
||||||
|
# 16-bit PCM
|
||||||
|
samples = struct.unpack(f"<{len(audio_bytes)//2}h", audio_bytes)
|
||||||
|
return [s / 32768.0 for s in samples] # Normalize to [-1, 1]
|
||||||
|
elif audio_format == "pcm32":
|
||||||
|
# 32-bit PCM
|
||||||
|
samples = struct.unpack(f"<{len(audio_bytes)//4}i", audio_bytes)
|
||||||
|
return [s / 2147483648.0 for s in samples] # Normalize to [-1, 1]
|
||||||
|
elif audio_format == "float32":
|
||||||
|
# 32-bit float
|
||||||
|
samples = struct.unpack(f"<{len(audio_bytes)//4}f", audio_bytes)
|
||||||
|
return list(samples)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported audio format: {audio_format}")
|
||||||
|
|
||||||
|
async def _send_transcription_result(self, websocket: websockets.WebSocketServerProtocol, result: TranscriptionResult) -> None:
|
||||||
|
"""Send transcription result to client."""
|
||||||
|
try:
|
||||||
|
message = {
|
||||||
|
"type": "transcription_result",
|
||||||
|
"result": result.to_dict(),
|
||||||
|
"timestamp": time.time()
|
||||||
|
}
|
||||||
|
|
||||||
|
await websocket.send(json.dumps(message))
|
||||||
|
self.stats["messages_sent"] += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error sending transcription result: {e}")
|
||||||
|
|
||||||
|
async def _send_error(self, websocket: websockets.WebSocketServerProtocol, error_type: str, message: str) -> None:
|
||||||
|
"""Send error message to client."""
|
||||||
|
try:
|
||||||
|
error_msg = {
|
||||||
|
"type": "error",
|
||||||
|
"error_type": error_type,
|
||||||
|
"message": message,
|
||||||
|
"timestamp": time.time()
|
||||||
|
}
|
||||||
|
|
||||||
|
await websocket.send(json.dumps(error_msg))
|
||||||
|
self.stats["messages_sent"] += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error sending error message: {e}")
|
||||||
|
|
||||||
|
def get_server_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get server statistics."""
|
||||||
|
stats = self.stats.copy()
|
||||||
|
if stats["start_time"]:
|
||||||
|
stats["uptime_seconds"] = time.time() - stats["start_time"]
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
def run_websocket_server(
|
||||||
|
host: str = "localhost",
|
||||||
|
port: int = 8765,
|
||||||
|
config: Optional[StreamConfig] = None,
|
||||||
|
model_cache: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Convenience function to run the WebSocket server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host: Server host
|
||||||
|
port: Server port
|
||||||
|
config: Default stream configuration
|
||||||
|
model_cache: Model cache for performance
|
||||||
|
"""
|
||||||
|
server = WhisperWebSocketServer(host, port, config, model_cache)
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
try:
|
||||||
|
await server.start_server()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\\nShutting down server...")
|
||||||
|
await server.stop_server()
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Whisper WebSocket Server")
|
||||||
|
parser.add_argument("--host", default="localhost", help="Server host")
|
||||||
|
parser.add_argument("--port", type=int, default=8765, help="Server port")
|
||||||
|
parser.add_argument("--model", default="base", help="Whisper model name")
|
||||||
|
parser.add_argument("--device", default="auto", help="Device for inference")
|
||||||
|
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO if args.verbose else logging.WARNING,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create default config
|
||||||
|
default_config = StreamConfig(
|
||||||
|
model_name=args.model,
|
||||||
|
device=args.device
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Starting Whisper WebSocket server on ws://{args.host}:{args.port}")
|
||||||
|
print(f"Model: {args.model}, Device: {args.device}")
|
||||||
|
|
||||||
|
run_websocket_server(args.host, args.port, default_config)
|
||||||
Loading…
x
Reference in New Issue
Block a user