diff --git a/farsi_transcriber/models/whisper_transcriber.py b/farsi_transcriber/models/whisper_transcriber.py
new file mode 100644
index 0000000..8310cca
--- /dev/null
+++ b/farsi_transcriber/models/whisper_transcriber.py
@@ -0,0 +1,226 @@
+"""
+Whisper Transcriber Module
+
+Handles Farsi audio/video transcription using OpenAI's Whisper model.
+"""
+
+import os
+import warnings
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import torch
+import whisper
+
+
+class FarsiTranscriber:
+    """
+    Wrapper around Whisper model for Farsi transcription.
+
+    Supports both audio and video files, with word-level timestamp extraction.
+    """
+
+    # Supported audio formats
+    AUDIO_FORMATS = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac", ".wma"}
+
+    # Supported video formats
+    VIDEO_FORMATS = {".mp4", ".mkv", ".mov", ".webm", ".avi", ".flv", ".wmv"}
+
+    # Language code for Farsi/Persian
+    FARSI_LANGUAGE = "fa"
+
+    def __init__(self, model_name: str = "medium", device: Optional[str] = None):
+        """
+        Initialize Farsi Transcriber.
+
+        Args:
+            model_name: Whisper model size ('tiny', 'base', 'small', 'medium', 'large')
+            device: Device to use ('cuda', 'cpu'). Auto-detect if None.
+        """
+        self.model_name = model_name
+
+        # Auto-detect device
+        if device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
+
+        print(f"Using device: {self.device}")
+
+        # Load model
+        print(f"Loading Whisper model: {model_name}...")
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            self.model = whisper.load_model(model_name, device=self.device)
+
+        print("Model loaded successfully")
+
+    def transcribe(
+        self,
+        file_path: str,
+        language: str = FARSI_LANGUAGE,
+        verbose: bool = False,
+    ) -> Dict:
+        """
+        Transcribe an audio or video file in Farsi.
+
+        Args:
+            file_path: Path to audio or video file
+            language: Language code (default: 'fa' for Farsi)
+            verbose: Whether to print progress
+
+        Returns:
+            Dictionary with transcription results including word-level segments
+        """
+        file_path = Path(file_path)
+
+        # Validate file exists
+        if not file_path.exists():
+            raise FileNotFoundError(f"File not found: {file_path}")
+
+        # Check format is supported
+        if not self._is_supported_format(file_path):
+            raise ValueError(
+                f"Unsupported format: {file_path.suffix}. "
+                f"Supported: {self.AUDIO_FORMATS | self.VIDEO_FORMATS}"
+            )
+
+        # Perform transcription
+        print(f"Transcribing: {file_path.name}")
+
+        result = self.model.transcribe(
+            str(file_path),
+            language=language,
+            verbose=verbose,
+        )
+
+        # Enhance result with word-level segments
+        enhanced_result = self._enhance_with_word_segments(result)
+
+        return enhanced_result
+
+    def _is_supported_format(self, file_path: Path) -> bool:
+        """Check if file format is supported."""
+        suffix = file_path.suffix.lower()
+        return suffix in (self.AUDIO_FORMATS | self.VIDEO_FORMATS)
+
+    def _enhance_with_word_segments(self, result: Dict) -> Dict:
+        """
+        Enhance transcription result with word-level timing information.
+
+        Args:
+            result: Whisper transcription result
+
+        Returns:
+            Enhanced result with word-level segments
+        """
+        enhanced_segments = []
+
+        for segment in result.get("segments", []):
+            # Extract word-level timing if available
+            word_segments = self._extract_word_segments(segment)
+
+            enhanced_segment = {
+                "id": segment.get("id"),
+                "start": segment.get("start"),
+                "end": segment.get("end"),
+                "text": segment.get("text", ""),
+                "words": word_segments,
+            }
+            enhanced_segments.append(enhanced_segment)
+
+        result["segments"] = enhanced_segments
+        return result
+
+    def _extract_word_segments(self, segment: Dict) -> List[Dict]:
+        """
+        Extract word-level timing from a segment.
+
+        Args:
+            segment: Whisper segment with text
+
+        Returns:
+            List of word dictionaries with timing information
+        """
+        text = segment.get("text", "").strip()
+        if not text:
+            return []
+
+        # For now, return simple word list
+        # Whisper v3 includes word-level details in some configurations
+        start_time = segment.get("start", 0)
+        end_time = segment.get("end", 0)
+        duration = end_time - start_time
+
+        words = text.split()
+        if not words:
+            return []
+
+        # Distribute time evenly across words (simple approach)
+        # More sophisticated timing can be extracted from Whisper's internal data
+        word_duration = duration / len(words) if words else 0
+
+        word_segments = []
+        for i, word in enumerate(words):
+            word_start = start_time + (i * word_duration)
+            word_end = word_start + word_duration
+
+            word_segments.append(
+                {
+                    "word": word,
+                    "start": word_start,
+                    "end": word_end,
+                }
+            )
+
+        return word_segments
+
+    def format_result_for_display(
+        self, result: Dict, include_timestamps: bool = True
+    ) -> str:
+        """
+        Format transcription result for display in UI.
+
+        Args:
+            result: Transcription result
+            include_timestamps: Whether to include timestamps
+
+        Returns:
+            Formatted text string
+        """
+        lines = []
+
+        for segment in result.get("segments", []):
+            text = segment.get("text", "").strip()
+            if not text:
+                continue
+
+            if include_timestamps:
+                start = segment.get("start", 0)
+                end = segment.get("end", 0)
+                timestamp = f"[{self._format_time(start)} - {self._format_time(end)}]"
+                lines.append(f"{timestamp}\n{text}\n")
+            else:
+                lines.append(text)
+
+        return "\n".join(lines)
+
+    @staticmethod
+    def _format_time(seconds: float) -> str:
+        """Format seconds to HH:MM:SS.mmm format."""
+        hours = int(seconds // 3600)
+        minutes = int((seconds % 3600) // 60)
+        secs = int(seconds % 60)
+        milliseconds = int((seconds % 1) * 1000)
+
+        return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"
+
+    def get_device_info(self) -> str:
+        """Get information about current device and model."""
+        return (
+            f"Model: {self.model_name} | "
+            f"Device: {self.device.upper()} | "
+            f"VRAM: {torch.cuda.get_device_properties(self.device).total_memory / 1e9:.1f}GB"
+            if self.device == "cuda"
+            else f"Model: {self.model_name} | Device: {self.device.upper()}"
        )
diff --git a/farsi_transcriber/ui/main_window.py b/farsi_transcriber/ui/main_window.py
index 63d2941..e8ad3cd 100644
--- a/farsi_transcriber/ui/main_window.py
+++ b/farsi_transcriber/ui/main_window.py
@@ -22,6 +22,8 @@ from PyQt6.QtWidgets import (
 )
 from PyQt6.QtGui import QFont
 
+from farsi_transcriber.models.whisper_transcriber import FarsiTranscriber
+
 
 class TranscriptionWorker(QThread):
     """Worker thread for transcription to prevent UI freezing"""
@@ -35,22 +37,33 @@
         super().__init__()
         self.file_path = file_path
         self.model_name = model_name
+        self.transcriber = None
 
     def run(self):
         """Run transcription in background thread"""
         try:
-            # TODO: Import and use Whisper model
-            # This will be implemented in Phase 3
+            # Initialize Whisper transcriber
             self.progress_update.emit("Loading Whisper model...")
-            self.progress_update.emit(f"Transcribing: {Path(self.file_path).name}")
-            self.progress_update.emit("Transcription complete!")
+            self.transcriber = FarsiTranscriber(model_name=self.model_name)
 
-            # Placeholder result structure (will be replaced with real data in Phase 3)
-            result = {
-                "text": "نتایج تجزیه و تحلیل صوتی اینجا نمایش داده خواهند شد",
-                "segments": [],
-            }
-            self.transcription_complete.emit(result)
+            # Perform transcription
+            self.progress_update.emit(f"Transcribing: {Path(self.file_path).name}")
+            result = self.transcriber.transcribe(self.file_path)
+
+            # Format result for display with timestamps
+            display_text = self.transcriber.format_result_for_display(result)
+
+            # Add full text for export
+            result["full_text"] = result.get("text", "")
+
+            self.progress_update.emit("Transcription complete!")
+            self.transcription_complete.emit(
+                {
+                    "text": display_text,
+                    "segments": result.get("segments", []),
+                    "full_text": result.get("text", ""),
+                }
+            )
         except Exception as e:
             self.error_occurred.emit(f"Error: {str(e)}")
 
@@ -70,6 +83,7 @@
         super().__init__()
         self.selected_file = None
         self.transcription_worker = None
+        self.last_result = None
         self.init_ui()
 
     def init_ui(self):
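
Reviewer note: a minimal usage sketch of the FarsiTranscriber API introduced above, for exercising the model outside the Qt worker. The file path and model size below are placeholders and not part of this patch; only methods defined in whisper_transcriber.py are called.

    from farsi_transcriber.models.whisper_transcriber import FarsiTranscriber

    # Placeholder inputs: "sample.mp3" and the "small" model are assumptions for a quick local check.
    transcriber = FarsiTranscriber(model_name="small")
    result = transcriber.transcribe("sample.mp3")  # raises FileNotFoundError / ValueError on bad input

    # Timestamped text, as the worker passes it to the UI
    print(transcriber.format_result_for_display(result))

    # Word-level segments (timings are evenly distributed across each segment's words)
    for segment in result["segments"]:
        for word in segment["words"]:
            print(f"{word['start']:.2f}-{word['end']:.2f} {word['word']}")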