# /// script
# requires-python = ">=3.9"
# dependencies = [
#     "torch>=2.1.0",
#     "torchaudio>=2.1.0",
#     "transformers>=4.40.0",
#     "soundfile>=0.12.1",
#     "silero-vad>=5.1.2",
# ]
# ///
# Usage: uv run granite-speech.py --help
#
# usage: granite-speech.py [-h] -i INPUT [-o OUTPUT] [--max-new-tokens MAX_NEW_TOKENS] [--debug] [--vad-min-silence-ms VAD_MIN_SILENCE_MS]
#                          [--vad-speech-pad-ms VAD_SPEECH_PAD_MS] [--vad-max-segment-secs VAD_MAX_SEGMENT_SECS] [--output-segment-timestamps]
#                          prompt
#
# Transcribe audio with IBM Granite Speech using VAD-based segmentation.
#
# positional arguments:
#   prompt                Instruction prompt, e.g. 'Transcribe the audio verbatim.'
#
# options:
#   -h, --help            show this help message and exit
#   -i INPUT, --input INPUT
#                         Path to audio file (WAV/MP3/FLAC).
#   -o OUTPUT, --output OUTPUT
#                         Path to save transcript (.md appended if absent).
#   --max-new-tokens MAX_NEW_TOKENS
#                         Maximum new tokens to generate per VAD segment (default: 512).
#   --debug               Print raw generation diagnostics for prompt wiring and returned tokens.
#   --vad-min-silence-ms VAD_MIN_SILENCE_MS
#                         Minimum silence duration in ms for VAD segmentation (default: 300).
#   --vad-speech-pad-ms VAD_SPEECH_PAD_MS
#                         Padding in ms added around VAD speech segments (default: 200).
#   --vad-max-segment-secs VAD_MAX_SEGMENT_SECS
#                         Maximum duration in seconds for a single VAD speech segment after splitting (default: 30.0).
#   --output-segment-timestamps
#                         Include per-segment timestamps in the saved transcript output.
"""Transcribe an audio file with IBM Granite Speech.

The audio is segmented into speech regions with Silero VAD, each region's
audio features are injected into the LLM prompt in place of the ``<|audio|>``
placeholder token, and the per-segment transcripts are merged into a single
document that can optionally be saved as Markdown.
"""

import argparse
import os
import sys
import warnings

import soundfile as sf
import torch
import torchaudio.functional as F
from silero_vad import get_speech_timestamps, load_silero_vad
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

# Silence a known-noisy PyTorch resize warning so per-segment progress output
# stays readable.
warnings.filterwarnings(
    "ignore",
    message="An output with one or more elements was resized",
    category=UserWarning,
)

# All audio is resampled to this rate before VAD and feature extraction.
SAMPLE_RATE = 16_000


def load_audio(path: str) -> torch.Tensor:
    """Load audio, convert to mono, resample to 16 kHz, return shape [1, N].

    Exits the process with status 1 if ``path`` does not exist.
    """
    if not os.path.exists(path):
        print(f"Error: audio file not found: {path}", file=sys.stderr)
        sys.exit(1)
    # soundfile returns [frames, channels]; transpose to [channels, frames].
    data, sr = sf.read(path, dtype="float32", always_2d=True)
    wav = torch.from_numpy(data.T)
    if wav.shape[0] > 1:
        # Downmix multi-channel audio to mono by averaging channels.
        wav = wav.mean(dim=0, keepdim=True)
    if sr != SAMPLE_RATE:
        wav = F.resample(wav, orig_freq=sr, new_freq=SAMPLE_RATE)
    return wav


def segment_audio_with_vad(
    wav: torch.Tensor,
    *,
    min_silence_ms: int = 300,
    speech_pad_ms: int = 200,
    max_segment_secs: float = 30.0,
) -> list[tuple[int, int]]:
    """
    Segment audio with Silero VAD and return [start_sample, end_sample) regions.

    Long speech regions are split to keep transcription memory-bounded.

    Args:
        wav: Mono waveform of shape [1, N] at ``SAMPLE_RATE``.
        min_silence_ms: Minimum silence (ms) required to end a speech region.
        speech_pad_ms: Padding (ms) added around each detected region.
        max_segment_secs: Hard cap on a single region's duration; longer
            regions are chopped into consecutive chunks of this size.

    Returns:
        Sample-index ranges; when VAD detects no speech at all, the whole
        clip is returned as a single region so transcription still runs.
    """
    audio = wav[0].cpu()
    vad_model = load_silero_vad()
    timestamps = get_speech_timestamps(
        audio,
        vad_model,
        sampling_rate=SAMPLE_RATE,
        min_silence_duration_ms=min_silence_ms,
        speech_pad_ms=speech_pad_ms,
        return_seconds=False,
    )
    if not timestamps:
        # No speech detected: fall back to transcribing the entire clip.
        return [(0, int(audio.shape[0]))]
    max_segment_samples = int(max_segment_secs * SAMPLE_RATE)
    segments: list[tuple[int, int]] = []
    for item in timestamps:
        start = int(item["start"])
        end = int(item["end"])
        # Split over-long speech regions into max_segment_samples chunks.
        while end - start > max_segment_samples:
            split_end = start + max_segment_samples
            segments.append((start, split_end))
            start = split_end
        if end > start:
            segments.append((start, end))
    return segments


def format_timestamp(seconds: float) -> str:
    """Format seconds as HH:MM:SS.m (rounded to the nearest tenth)."""
    total_tenths = int(round(seconds * 10))
    hours = total_tenths // 36_000
    minutes = (total_tenths % 36_000) // 600
    secs = (total_tenths % 600) // 10
    tenths = total_tenths % 10
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{tenths}"


def merge_transcript_segments(segments: list[str], max_check: int = 400) -> str:
    """Merge adjacent text segments with simple suffix/prefix overlap removal.

    Consecutive VAD segments can transcribe overlapping audio (due to
    padding); when one segment's suffix equals the next segment's prefix,
    the duplicate text is dropped. ``max_check`` bounds the overlap search
    to keep the comparison cheap.
    """

    def longest_suffix_prefix(a: str, b: str) -> int:
        # Longest k <= max_check such that a endswith b[:k].
        if not a or not b:
            return 0
        limit = min(len(a), len(b), max_check)
        for size in range(limit, 0, -1):
            if a[-size:] == b[:size]:
                return size
        return 0

    merged = ""
    for segment in segments:
        segment = segment.strip()
        if not segment:
            continue
        if not merged:
            merged = segment
            continue
        overlap = longest_suffix_prefix(merged, segment)
        if overlap > 0:
            merged += segment[overlap:]
        else:
            merged += " " + segment
    return merged.strip()


def transcribe_audio_segment(
    *,
    model,
    processor,
    tokenizer,
    audio_token_id: int,
    prefix_ids: torch.Tensor,
    prefix_mask: torch.Tensor,
    suffix_ids: torch.Tensor,
    suffix_mask: torch.Tensor,
    wav_segment: torch.Tensor,
    device: str,
    dtype: torch.dtype,
    max_new_tokens: int,
    debug: bool,
    segment_label: str,
) -> str:
    """
    Transcribe a single VAD speech segment by encoding its audio features and merging them into a placeholder-token sequence.

    Args:
        model: Loaded Granite Speech model (provides ``get_audio_features``,
            ``get_input_embeddings`` and ``generate``).
        processor: HF processor used to turn raw audio into input features.
        tokenizer: Tokenizer used for decoding and pad/eos token ids.
        audio_token_id: Token id of the ``<|audio|>`` placeholder.
        prefix_ids / prefix_mask: Prompt token ids/mask before the placeholder.
        suffix_ids / suffix_mask: Prompt token ids/mask after the placeholder.
        wav_segment: Mono waveform slice of shape [1, N] for this segment.
        device / dtype: Target device and compute dtype for the features.
        max_new_tokens: Generation cap for this segment.
        debug: When True, print wiring diagnostics instead of streaming text.
        segment_label: Label used to tag debug lines.

    Returns:
        The decoded transcript text (may be empty).
    """
    seg_inputs = processor(text="<|audio|>", audio=wav_segment, return_tensors="pt")
    seg_features = seg_inputs["input_features"].to(device=device, dtype=dtype)
    with torch.no_grad():
        audio_out = model.get_audio_features(seg_features)
    # NOTE(review): assumes pooler_output is [1, A, H] (one embedding row per
    # audio "token") — confirm against the Granite Speech model docs.
    audio_embed = audio_out.pooler_output  # [1, A, H]
    audio_len = audio_embed.shape[1]
    # Expand the single <|audio|> placeholder into audio_len placeholder
    # tokens so the prompt length matches the number of audio embeddings.
    audio_placeholder_ids = torch.full(
        (1, audio_len),
        audio_token_id,
        dtype=prefix_ids.dtype,
        device=device,
    )
    audio_placeholder_mask = torch.ones(
        (1, audio_len),
        dtype=prefix_mask.dtype,
        device=device,
    )
    seg_input_ids = torch.cat([prefix_ids, audio_placeholder_ids, suffix_ids], dim=1)
    seg_mask = torch.cat([prefix_mask, audio_placeholder_mask, suffix_mask], dim=1)
    embed_fn = model.get_input_embeddings()
    # Replace placeholder ids with 0 before the embedding lookup; the
    # embeddings at those positions are overwritten below, presumably this
    # keeps the lookup index in-range for the text embedding table.
    llm_lookup_ids = torch.where(
        seg_input_ids == audio_token_id,
        torch.zeros_like(seg_input_ids),
        seg_input_ids,
    )
    seg_embeds = embed_fn(llm_lookup_ids)
    # Scatter the audio embeddings into the placeholder positions.
    special_audio_mask = (
        (seg_input_ids == audio_token_id).unsqueeze(-1).expand_as(seg_embeds)
    )
    seg_embeds = seg_embeds.masked_scatter(
        special_audio_mask,
        audio_embed.to(device=seg_embeds.device, dtype=seg_embeds.dtype),
    )
    if debug:
        print(
            f"[debug] {segment_label}: audio_embed_len={audio_len} "
            f"prompt_len={seg_input_ids.shape[1]} attention_mask_sum={int(seg_mask.sum().item())}"
        )
    with torch.no_grad():
        # Greedy decode (do_sample=False, num_beams=1); no_repeat_ngram_size
        # guards against degenerate repetition loops.
        output = model.generate(
            inputs_embeds=seg_embeds,
            attention_mask=seg_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,
            num_beams=1,
            no_repeat_ngram_size=5,
            return_dict_in_generate=True,
            output_scores=debug,
        )
    # NOTE(review): decoding sequences[0] wholesale assumes generate() returns
    # only newly generated tokens when called with inputs_embeds — confirm
    # against the transformers generation docs for the pinned version.
    sequences = output.sequences
    if sequences.shape[1] == 0:
        if debug:
            print(f"[debug] {segment_label}: empty returned sequence")
        return ""
    returned_ids = sequences[0].to(dtype=torch.int32).tolist()
    returned_text_raw = tokenizer.decode(returned_ids, skip_special_tokens=False)
    returned_text = tokenizer.decode(returned_ids, skip_special_tokens=True).strip()
    if debug:
        print(f"[debug] {segment_label}: returned_token_count={len(returned_ids)}")
        print(f"[debug] {segment_label}: first_token_ids={returned_ids[:32]}")
        print(f"[debug] {segment_label}: last_token_ids={returned_ids[-32:]}")
        print(f"[debug] {segment_label}: raw_decoded={returned_text_raw!r}")
        if hasattr(output, "scores") and output.scores:
            print(
                f"[debug] {segment_label}: generated_steps_from_scores={len(output.scores)}"
            )
    elif returned_text:
        # Non-debug mode: stream transcript text to stdout as segments finish.
        print(returned_text, flush=True, end=" ")
    return returned_text


def main() -> None:
    """CLI entry point: parse args, load model and audio, transcribe, save."""
    parser = argparse.ArgumentParser(
        description="Transcribe audio with IBM Granite Speech using VAD-based segmentation."
    )
    parser.add_argument(
        "prompt",
        help="Instruction prompt, e.g. 'Transcribe the audio verbatim.'",
    )
    parser.add_argument(
        "-i",
        "--input",
        required=True,
        help="Path to audio file (WAV/MP3/FLAC).",
    )
    parser.add_argument(
        "-o",
        "--output",
        default=None,
        help="Path to save transcript (.md appended if absent).",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=512,
        help="Maximum new tokens to generate per VAD segment (default: 512).",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print raw generation diagnostics for prompt wiring and returned tokens.",
    )
    parser.add_argument(
        "--vad-min-silence-ms",
        type=int,
        default=300,
        help="Minimum silence duration in ms for VAD segmentation (default: 300).",
    )
    parser.add_argument(
        "--vad-speech-pad-ms",
        type=int,
        default=200,
        help="Padding in ms added around VAD speech segments (default: 200).",
    )
    parser.add_argument(
        "--vad-max-segment-secs",
        type=float,
        default=30.0,
        help="Maximum duration in seconds for a single VAD speech segment after splitting (default: 30.0).",
    )
    parser.add_argument(
        "--output-segment-timestamps",
        action="store_true",
        help="Include per-segment timestamps in the saved transcript output.",
    )
    args = parser.parse_args()

    # Device preference: Apple MPS first, then CUDA, else CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        print("Warning: neither MPS nor CUDA available, falling back to CPU.")
        device = "cpu"
    dtype = torch.bfloat16
    model_name = "ibm-granite/granite-4.0-1b-speech"
    audio_token = "<|audio|>"

    print("Loading processor …")
    processor = AutoProcessor.from_pretrained(model_name)
    tokenizer = processor.tokenizer
    print("Loading model weights …")
    # NOTE(review): the `dtype=` kwarg to from_pretrained is recent; older
    # transformers releases spelled it `torch_dtype=` — confirm against the
    # ">=4.40.0" pin in the script header.
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_name,
        dtype=dtype,
    )
    model.to(device)
    model.eval()

    print(f"Loading audio: {args.input}")
    wav = load_audio(args.input)
    duration = wav.shape[-1] / SAMPLE_RATE
    print(f"Loaded {wav.shape[-1]} samples ({duration:.1f}s)")

    # Prepend the audio placeholder if the user prompt doesn't contain it.
    user_content = (
        args.prompt if audio_token in args.prompt else f"{audio_token}{args.prompt}"
    )
    chat = [{"role": "user", "content": user_content}]
    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    if args.debug:
        print("\n[debug] prompt text:")
        print(prompt)

    # Tokenize the prompt once and split it at the <|audio|> position; the
    # prefix/suffix pieces are reused for every segment.
    audio_token_id: int = tokenizer.convert_tokens_to_ids(audio_token)
    text_inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = text_inputs["input_ids"].to(device)
    attention_mask = text_inputs["attention_mask"].to(device)
    audio_positions = (input_ids[0] == audio_token_id).nonzero(as_tuple=True)[0]
    if len(audio_positions) == 0:
        print(
            "Warning: <|audio|> token not found in prompt; audio will not be injected.",
            file=sys.stderr,
        )
    # With no placeholder found, pos lands past the end: prefix is the whole
    # prompt and suffix is empty.
    pos = int(audio_positions[0]) if len(audio_positions) > 0 else input_ids.shape[1]
    if args.debug:
        print(
            f"[debug] audio token id: {audio_token_id} | "
            f"audio token positions in prompt ids: {audio_positions.tolist()}"
        )
    prefix_ids = input_ids[:, :pos]
    prefix_mask = attention_mask[:, :pos]
    suffix_ids = input_ids[:, pos + 1 :]
    suffix_mask = attention_mask[:, pos + 1 :]
    if args.debug:
        prefix_text = tokenizer.decode(
            prefix_ids[0].to(dtype=torch.int32).tolist(),
            skip_special_tokens=False,
        )
        suffix_text = tokenizer.decode(
            suffix_ids[0].to(dtype=torch.int32).tolist(),
            skip_special_tokens=False,
        )
        print(f"[debug] prefix text: {prefix_text!r}")
        print(f"[debug] suffix text: {suffix_text!r}")
        print(
            f"[debug] prefix ids len: {prefix_ids.shape[1]} | "
            f"suffix ids len: {suffix_ids.shape[1]}"
        )

    vad_segments = segment_audio_with_vad(
        wav,
        min_silence_ms=args.vad_min_silence_ms,
        speech_pad_ms=args.vad_speech_pad_ms,
        max_segment_secs=args.vad_max_segment_secs,
    )
    total_vad_duration = sum(end - start for start, end in vad_segments) / SAMPLE_RATE
    print(
        f"\nVAD: {len(vad_segments)} speech segment(s) covering "
        f"{total_vad_duration:.1f}s of audio"
    )

    # Transcribe each VAD segment independently, collecting both plain text
    # and (start, end, text) tuples for the optional timestamped output.
    all_segments: list[str] = []
    timestamped_segments: list[tuple[float, float, str]] = []
    for i, (seg_start, seg_end) in enumerate(vad_segments, start=1):
        seg_wav = wav[:, seg_start:seg_end]
        seg_start_s = seg_start / SAMPLE_RATE
        seg_end_s = seg_end / SAMPLE_RATE
        if args.debug:
            print(
                f"\n--- Segment {i}/{len(vad_segments)} "
                f"({seg_start_s:.1f}–{seg_end_s:.1f}s) ---",
                flush=True,
            )
        seg_text = transcribe_audio_segment(
            model=model,
            processor=processor,
            tokenizer=tokenizer,
            audio_token_id=audio_token_id,
            prefix_ids=prefix_ids,
            prefix_mask=prefix_mask,
            suffix_ids=suffix_ids,
            suffix_mask=suffix_mask,
            wav_segment=seg_wav,
            device=device,
            dtype=dtype,
            max_new_tokens=args.max_new_tokens,
            debug=args.debug,
            segment_label=f"segment {i}",
        )
        if seg_text:
            all_segments.append(seg_text)
            timestamped_segments.append((seg_start_s, seg_end_s, seg_text))
    print()
    print(f"\nGeneration complete ({len(all_segments)} transcribed segment(s)).")

    if args.output:
        generated_text = merge_transcript_segments(all_segments)
        out_path = args.output if args.output.endswith(".md") else args.output + ".md"
        transcript_body = generated_text
        if args.output_segment_timestamps:
            # Timestamped output keeps segments separate (no overlap merge).
            transcript_body = "\n\n".join(
                f"[{format_timestamp(start)} - {format_timestamp(end)}] {text}"
                for start, end, text in timestamped_segments
            )
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(f"# Prompt\n\n{args.prompt}\n\n# Transcript\n\n{transcript_body}\n")
        print(f"Saved to: {out_path}")


if __name__ == "__main__":
    main()