Created
March 18, 2026 08:54
-
-
Save mseri/df3208da68011a4eced69e5f3d2b97cd to your computer and use it in GitHub Desktop.
Pure PyTorch implementation of Granite Speech and a simple CLI (using VAD to reduce the memory footprint). Implemented with substantial help from Copilot.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.9" | |
| # dependencies = [ | |
| # "torch>=2.1.0", | |
| # "torchaudio>=2.1.0", | |
| # "transformers>=4.40.0", | |
| # "soundfile>=0.12.1", | |
| # "silero-vad>=5.1.2", | |
| # ] | |
| # /// | |
| # Usage: uv run granite-speech.py --help | |
| # usage: granite-speech.py [-h] -i INPUT [-o OUTPUT] [--max-new-tokens MAX_NEW_TOKENS] [--debug] [--vad-min-silence-ms VAD_MIN_SILENCE_MS] | |
| # [--vad-speech-pad-ms VAD_SPEECH_PAD_MS] [--vad-max-segment-secs VAD_MAX_SEGMENT_SECS] [--output-segment-timestamps] | |
| # prompt | |
| # | |
| # Transcribe audio with IBM Granite Speech using VAD-based segmentation. | |
| # | |
| # positional arguments: | |
| # prompt Instruction prompt, e.g. 'Transcribe the audio verbatim.' | |
| # | |
| # options: | |
| # -h, --help show this help message and exit | |
| # -i INPUT, --input INPUT | |
| # Path to audio file (WAV/MP3/FLAC). | |
| # -o OUTPUT, --output OUTPUT | |
| # Path to save transcript (.md appended if absent). | |
| # --max-new-tokens MAX_NEW_TOKENS | |
| # Maximum new tokens to generate per VAD segment (default: 512). | |
| # --debug Print raw generation diagnostics for prompt wiring and returned tokens. | |
| # --vad-min-silence-ms VAD_MIN_SILENCE_MS | |
| # Minimum silence duration in ms for VAD segmentation (default: 300). | |
| # --vad-speech-pad-ms VAD_SPEECH_PAD_MS | |
| # Padding in ms added around VAD speech segments (default: 200). | |
| # --vad-max-segment-secs VAD_MAX_SEGMENT_SECS | |
| # Maximum duration in seconds for a single VAD speech segment after splitting (default: 30.0). | |
| # --output-segment-timestamps | |
| # Include per-segment timestamps in the saved transcript output. | |
| import argparse | |
| import os | |
| import sys | |
| import warnings | |
| import soundfile as sf | |
| import torch | |
| import torchaudio.functional as F | |
| from silero_vad import get_speech_timestamps, load_silero_vad | |
| from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor | |
# Silence a benign PyTorch "output was resized" UserWarning that fires during
# resampling; it is noise for end users of the CLI.
warnings.filterwarnings(
    "ignore",
    message="An output with one or more elements was resized",
    category=UserWarning,
)

# All audio is converted to this sample rate before VAD and transcription.
SAMPLE_RATE = 16_000
def load_audio(path: str) -> torch.Tensor:
    """Read an audio file as a mono 16 kHz float waveform of shape [1, N].

    Exits the process with an error message when *path* does not exist.
    """
    if not os.path.exists(path):
        print(f"Error: audio file not found: {path}", file=sys.stderr)
        sys.exit(1)
    samples, src_rate = sf.read(path, dtype="float32", always_2d=True)
    waveform = torch.from_numpy(samples.T)
    # Collapse multi-channel audio to mono by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Resample only when the source rate differs from the 16 kHz target.
    if src_rate != SAMPLE_RATE:
        waveform = F.resample(waveform, orig_freq=src_rate, new_freq=SAMPLE_RATE)
    return waveform
def segment_audio_with_vad(
    wav: torch.Tensor,
    *,
    min_silence_ms: int = 300,
    speech_pad_ms: int = 200,
    max_segment_secs: float = 30.0,
) -> list[tuple[int, int]]:
    """
    Run Silero VAD over *wav* and return [start_sample, end_sample) regions.

    Speech regions longer than *max_segment_secs* are chopped into
    consecutive fixed-size chunks so each transcription call stays
    memory-bounded.
    """
    mono = wav[0].cpu()
    detector = load_silero_vad()
    speech_regions = get_speech_timestamps(
        mono,
        detector,
        sampling_rate=SAMPLE_RATE,
        min_silence_duration_ms=min_silence_ms,
        speech_pad_ms=speech_pad_ms,
        return_seconds=False,
    )
    # No speech detected: fall back to transcribing the whole clip.
    if not speech_regions:
        return [(0, int(mono.shape[0]))]
    chunk_limit = int(max_segment_secs * SAMPLE_RATE)
    regions: list[tuple[int, int]] = []
    for region in speech_regions:
        begin, finish = int(region["start"]), int(region["end"])
        # Emit fixed-size chunks until the remainder fits within the limit.
        while finish - begin > chunk_limit:
            regions.append((begin, begin + chunk_limit))
            begin += chunk_limit
        if finish > begin:
            regions.append((begin, finish))
    return regions
def format_timestamp(seconds: float) -> str:
    """Render *seconds* as an HH:MM:SS.m string (tenth-of-a-second precision)."""
    tenths_total = int(round(seconds * 10))
    hh, rem = divmod(tenths_total, 36_000)
    mm, rem = divmod(rem, 600)
    ss, tenth = divmod(rem, 10)
    return f"{hh:02d}:{mm:02d}:{ss:02d}.{tenth}"
def merge_transcript_segments(segments: list[str], max_check: int = 400) -> str:
    """Concatenate text segments, removing duplicated text at the seams.

    When the tail of the running transcript matches the head of the next
    segment (checking up to *max_check* characters), the duplicate prefix
    of the next segment is dropped; otherwise segments are joined with a
    single space.
    """

    def overlap_length(left: str, right: str) -> int:
        # Largest k such that the last k chars of *left* equal the first
        # k chars of *right*.
        if left and right:
            for k in range(min(len(left), len(right), max_check), 0, -1):
                if left[-k:] == right[:k]:
                    return k
        return 0

    transcript = ""
    for piece in segments:
        piece = piece.strip()
        if not piece:
            continue
        if not transcript:
            transcript = piece
            continue
        k = overlap_length(transcript, piece)
        transcript += piece[k:] if k else " " + piece
    return transcript.strip()
def transcribe_audio_segment(
    *,
    model,
    processor,
    tokenizer,
    audio_token_id: int,
    prefix_ids: torch.Tensor,
    prefix_mask: torch.Tensor,
    suffix_ids: torch.Tensor,
    suffix_mask: torch.Tensor,
    wav_segment: torch.Tensor,
    device: str,
    dtype: torch.dtype,
    max_new_tokens: int,
    debug: bool,
    segment_label: str,
) -> str:
    """
    Transcribe a single VAD speech segment by encoding its audio features and
    merging them into a placeholder-token sequence.

    The prompt is rebuilt as ``prefix + <|audio|> * A + suffix``, where A is
    the number of audio frames the encoder produces for this segment; the
    audio embeddings are then scattered into the placeholder positions before
    generation.

    Args:
        model/processor/tokenizer: loaded Granite Speech components.
        audio_token_id: token id of the ``<|audio|>`` placeholder.
        prefix_ids/prefix_mask: prompt tokens (and mask) before the placeholder.
        suffix_ids/suffix_mask: prompt tokens (and mask) after the placeholder.
        wav_segment: waveform slice for this segment (assumed [1, N] mono at
            16 kHz, as produced by ``load_audio``).
        device/dtype: device and dtype for the audio feature computation.
        max_new_tokens: generation budget for this segment.
        debug: when True, print wiring/token diagnostics instead of streaming.
        segment_label: human-readable label used in debug output.

    Returns:
        The decoded transcript text for this segment ("" when empty).
    """
    # Extract input features for just this segment's audio.
    seg_inputs = processor(text="<|audio|>", audio=wav_segment, return_tensors="pt")
    seg_features = seg_inputs["input_features"].to(device=device, dtype=dtype)
    with torch.no_grad():
        audio_out = model.get_audio_features(seg_features)
    audio_embed = audio_out.pooler_output  # [1, A, H]
    audio_len = audio_embed.shape[1]
    # One <|audio|> placeholder token per encoded audio frame.
    audio_placeholder_ids = torch.full(
        (1, audio_len),
        audio_token_id,
        dtype=prefix_ids.dtype,
        device=device,
    )
    audio_placeholder_mask = torch.ones(
        (1, audio_len),
        dtype=prefix_mask.dtype,
        device=device,
    )
    seg_input_ids = torch.cat([prefix_ids, audio_placeholder_ids, suffix_ids], dim=1)
    seg_mask = torch.cat([prefix_mask, audio_placeholder_mask, suffix_mask], dim=1)
    embed_fn = model.get_input_embeddings()
    # Look up embedding id 0 at placeholder positions — presumably because the
    # placeholder id has no meaningful text embedding (or is out of range);
    # those vectors are overwritten by audio embeddings below anyway.
    llm_lookup_ids = torch.where(
        seg_input_ids == audio_token_id,
        torch.zeros_like(seg_input_ids),
        seg_input_ids,
    )
    seg_embeds = embed_fn(llm_lookup_ids)
    # Scatter the encoder's audio embeddings into the placeholder positions.
    special_audio_mask = (
        (seg_input_ids == audio_token_id).unsqueeze(-1).expand_as(seg_embeds)
    )
    seg_embeds = seg_embeds.masked_scatter(
        special_audio_mask,
        audio_embed.to(device=seg_embeds.device, dtype=seg_embeds.dtype),
    )
    if debug:
        print(
            f"[debug] {segment_label}: audio_embed_len={audio_len} "
            f"prompt_len={seg_input_ids.shape[1]} attention_mask_sum={int(seg_mask.sum().item())}"
        )
    # Greedy decoding; no_repeat_ngram_size guards against degenerate loops.
    with torch.no_grad():
        output = model.generate(
            inputs_embeds=seg_embeds,
            attention_mask=seg_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,
            num_beams=1,
            no_repeat_ngram_size=5,
            return_dict_in_generate=True,
            output_scores=debug,
        )
    # NOTE(review): with inputs_embeds (no input_ids), generate() appears to
    # return only the newly generated tokens — the whole returned sequence is
    # decoded as transcript. Confirm for the installed transformers version.
    sequences = output.sequences
    if sequences.shape[1] == 0:
        if debug:
            print(f"[debug] {segment_label}: empty returned sequence")
        return ""
    returned_ids = sequences[0].to(dtype=torch.int32).tolist()
    returned_text_raw = tokenizer.decode(returned_ids, skip_special_tokens=False)
    returned_text = tokenizer.decode(returned_ids, skip_special_tokens=True).strip()
    if debug:
        print(f"[debug] {segment_label}: returned_token_count={len(returned_ids)}")
        print(f"[debug] {segment_label}: first_token_ids={returned_ids[:32]}")
        print(f"[debug] {segment_label}: last_token_ids={returned_ids[-32:]}")
        print(f"[debug] {segment_label}: raw_decoded={returned_text_raw!r}")
        if hasattr(output, "scores") and output.scores:
            print(
                f"[debug] {segment_label}: generated_steps_from_scores={len(output.scores)}"
            )
    elif returned_text:
        # Stream the segment text to stdout as it is produced.
        print(returned_text, flush=True, end=" ")
    return returned_text
def main() -> None:
    """CLI entry point: parse args, load model, VAD-segment, transcribe, save."""
    # ---- Argument parsing --------------------------------------------------
    parser = argparse.ArgumentParser(
        description="Transcribe audio with IBM Granite Speech using VAD-based segmentation."
    )
    parser.add_argument(
        "prompt",
        help="Instruction prompt, e.g. 'Transcribe the audio verbatim.'",
    )
    parser.add_argument(
        "-i",
        "--input",
        required=True,
        help="Path to audio file (WAV/MP3/FLAC).",
    )
    parser.add_argument(
        "-o",
        "--output",
        default=None,
        help="Path to save transcript (.md appended if absent).",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=512,
        help="Maximum new tokens to generate per VAD segment (default: 512).",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print raw generation diagnostics for prompt wiring and returned tokens.",
    )
    parser.add_argument(
        "--vad-min-silence-ms",
        type=int,
        default=300,
        help="Minimum silence duration in ms for VAD segmentation (default: 300).",
    )
    parser.add_argument(
        "--vad-speech-pad-ms",
        type=int,
        default=200,
        help="Padding in ms added around VAD speech segments (default: 200).",
    )
    parser.add_argument(
        "--vad-max-segment-secs",
        type=float,
        default=30.0,
        help="Maximum duration in seconds for a single VAD speech segment after splitting (default: 30.0).",
    )
    parser.add_argument(
        "--output-segment-timestamps",
        action="store_true",
        help="Include per-segment timestamps in the saved transcript output.",
    )
    args = parser.parse_args()
    # ---- Device selection: prefer Apple MPS, then CUDA, else CPU -----------
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        print("Warning: neither MPS nor CUDA available, falling back to CPU.")
        device = "cpu"
    dtype = torch.bfloat16
    model_name = "ibm-granite/granite-4.0-1b-speech"
    audio_token = "<|audio|>"
    # ---- Model / processor loading -----------------------------------------
    print("Loading processor …")
    processor = AutoProcessor.from_pretrained(model_name)
    tokenizer = processor.tokenizer
    print("Loading model weights …")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_name,
        dtype=dtype,
    )
    model.to(device)
    model.eval()
    # ---- Audio loading -----------------------------------------------------
    print(f"Loading audio: {args.input}")
    wav = load_audio(args.input)
    duration = wav.shape[-1] / SAMPLE_RATE
    print(f"Loaded {wav.shape[-1]} samples ({duration:.1f}s)")
    # ---- Prompt construction: ensure an <|audio|> placeholder is present ---
    user_content = (
        args.prompt if audio_token in args.prompt else f"{audio_token}{args.prompt}"
    )
    chat = [{"role": "user", "content": user_content}]
    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    if args.debug:
        print("\n[debug] prompt text:")
        print(prompt)
    # Split the tokenized prompt around the first audio placeholder so each
    # segment can splice its own audio embeddings between prefix and suffix.
    audio_token_id: int = tokenizer.convert_tokens_to_ids(audio_token)
    text_inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = text_inputs["input_ids"].to(device)
    attention_mask = text_inputs["attention_mask"].to(device)
    audio_positions = (input_ids[0] == audio_token_id).nonzero(as_tuple=True)[0]
    if len(audio_positions) == 0:
        print(
            "Warning: <|audio|> token not found in prompt; audio will not be injected.",
            file=sys.stderr,
        )
    # With no placeholder, pos points past the end, yielding an empty suffix.
    pos = int(audio_positions[0]) if len(audio_positions) > 0 else input_ids.shape[1]
    if args.debug:
        print(
            f"[debug] audio token id: {audio_token_id} | "
            f"audio token positions in prompt ids: {audio_positions.tolist()}"
        )
    prefix_ids = input_ids[:, :pos]
    prefix_mask = attention_mask[:, :pos]
    suffix_ids = input_ids[:, pos + 1 :]
    suffix_mask = attention_mask[:, pos + 1 :]
    if args.debug:
        prefix_text = tokenizer.decode(
            prefix_ids[0].to(dtype=torch.int32).tolist(),
            skip_special_tokens=False,
        )
        suffix_text = tokenizer.decode(
            suffix_ids[0].to(dtype=torch.int32).tolist(),
            skip_special_tokens=False,
        )
        print(f"[debug] prefix text: {prefix_text!r}")
        print(f"[debug] suffix text: {suffix_text!r}")
        print(
            f"[debug] prefix ids len: {prefix_ids.shape[1]} | "
            f"suffix ids len: {suffix_ids.shape[1]}"
        )
    # ---- VAD segmentation --------------------------------------------------
    vad_segments = segment_audio_with_vad(
        wav,
        min_silence_ms=args.vad_min_silence_ms,
        speech_pad_ms=args.vad_speech_pad_ms,
        max_segment_secs=args.vad_max_segment_secs,
    )
    total_vad_duration = sum(end - start for start, end in vad_segments) / SAMPLE_RATE
    print(
        f"\nVAD: {len(vad_segments)} speech segment(s) covering "
        f"{total_vad_duration:.1f}s of audio"
    )
    # ---- Per-segment transcription -----------------------------------------
    all_segments: list[str] = []
    timestamped_segments: list[tuple[float, float, str]] = []
    for i, (seg_start, seg_end) in enumerate(vad_segments, start=1):
        seg_wav = wav[:, seg_start:seg_end]
        seg_start_s = seg_start / SAMPLE_RATE
        seg_end_s = seg_end / SAMPLE_RATE
        if args.debug:
            print(
                f"\n--- Segment {i}/{len(vad_segments)} "
                f"({seg_start_s:.1f}–{seg_end_s:.1f}s) ---",
                flush=True,
            )
        seg_text = transcribe_audio_segment(
            model=model,
            processor=processor,
            tokenizer=tokenizer,
            audio_token_id=audio_token_id,
            prefix_ids=prefix_ids,
            prefix_mask=prefix_mask,
            suffix_ids=suffix_ids,
            suffix_mask=suffix_mask,
            wav_segment=seg_wav,
            device=device,
            dtype=dtype,
            max_new_tokens=args.max_new_tokens,
            debug=args.debug,
            segment_label=f"segment {i}",
        )
        if seg_text:
            all_segments.append(seg_text)
            timestamped_segments.append((seg_start_s, seg_end_s, seg_text))
    # Terminate the streamed transcript line before the summary.
    print()
    print(f"\nGeneration complete ({len(all_segments)} transcribed segment(s)).")
    # ---- Optional transcript saving ----------------------------------------
    if args.output:
        generated_text = merge_transcript_segments(all_segments)
        out_path = args.output if args.output.endswith(".md") else args.output + ".md"
        transcript_body = generated_text
        if args.output_segment_timestamps:
            # Timestamped mode replaces the merged transcript with one
            # "[start - end] text" paragraph per segment.
            transcript_body = "\n\n".join(
                f"[{format_timestamp(start)} - {format_timestamp(end)}] {text}"
                for start, end, text in timestamped_segments
            )
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(f"# Prompt\n\n{args.prompt}\n\n# Transcript\n\n{transcript_body}\n")
        print(f"Saved to: {out_path}")
# Script entry point.
if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| PyTorch reference implementation of GraniteSpeech. | |
| The state-dict key hierarchy is meant to be *identical* to the HuggingFace | |
| implementation, allowing weights to be loaded with: | |
| model.load_state_dict(hf_state_dict, strict=True) | |
| GraniteSpeechForConditionalGeneration | |
| ├── encoder : GraniteSpeechCTCEncoder (Conformer) | |
| ├── projector: GraniteSpeechEncoderProjector (Blip2 Q-Former + linear) | |
| └── language_model: GraniteForCausalLM (LLaMA-style + Granite) | |
| """ | |
| import math | |
| from dataclasses import dataclass, field | |
| from typing import Any, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
@dataclass
class GraniteSpeechEncoderConfig:
    """Hyperparameters for the Conformer CTC audio encoder.

    Defaults presumably match the ibm-granite Granite Speech checkpoint —
    verify against the published config before relying on them.
    """

    input_dim: int = 160  # per-frame input feature size (fed to input_linear)
    num_layers: int = 16  # number of stacked Conformer blocks
    hidden_dim: int = 1024  # model width throughout the encoder
    feedforward_mult: int = 4  # feed-forward expansion factor (hidden_dim * mult)
    num_heads: int = 8  # attention heads per block
    dim_head: int = 128  # per-head dimension
    output_dim: int = 348  # size of the mid-stack `out` projection
    context_size: int = 200  # block length for block-wise attention
    max_pos_emb: int = 512  # half-range of the relative-position embedding table
    dropout: float = 0.1
    conv_kernel_size: int = 15  # depth-wise conv kernel in the conv module
    conv_expansion_factor: int = 2  # channel expansion in the conv module
@dataclass
class Blip2QFormerConfig:
    """Hyperparameters for the Blip-2 Q-Former used as the audio projector.

    Defaults presumably match the Granite Speech checkpoint — verify against
    the published config before relying on them.
    """

    hidden_size: int = 1024  # Q-Former width
    num_hidden_layers: int = 2
    num_attention_heads: int = 16
    intermediate_size: int = 4096  # feed-forward expansion width
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    attention_probs_dropout_prob: float = 0.1
    max_position_embeddings: int = 2048
    layer_norm_eps: float = 1e-12
    # presumably: insert cross-attention every N layers (HF Blip-2 convention)
    # — confirm against the layer construction code.
    cross_attention_frequency: int = 1
    encoder_hidden_size: int = 1024  # width of the encoder states cross-attended to
    chunk_size_feed_forward: int = 0
    use_qformer_text_input: bool = False
@dataclass
class GraniteTextConfig:
    """Hyperparameters for the Granite (LLaMA-style) language model.

    The Granite-specific ``*_multiplier`` / ``logits_scaling`` fields are
    consumed by the language-model implementation elsewhere in this file.
    Defaults presumably match the Granite Speech checkpoint.
    """

    vocab_size: int = 100353
    hidden_size: int = 2048
    intermediate_size: int = 4096
    num_hidden_layers: int = 40
    num_attention_heads: int = 16
    # Fewer KV heads than query heads — presumably grouped-query attention;
    # confirm against the attention implementation.
    num_key_value_heads: int = 4
    hidden_act: str = "silu"
    max_position_embeddings: int = 4096
    rms_norm_eps: float = 1e-5
    attention_bias: bool = False
    attention_dropout: float = 0.0
    mlp_bias: bool = False
    rope_theta: float = 10000.0
    embedding_multiplier: float = 12.0
    logits_scaling: float = 8.0
    residual_multiplier: float = 0.22
    attention_multiplier: float = 0.0078125
    pad_token_id: int = 100256
@dataclass
class GraniteSpeechConfig:
    """Top-level config bundling the encoder, projector and text-model configs."""

    encoder_config: GraniteSpeechEncoderConfig = field(
        default_factory=GraniteSpeechEncoderConfig
    )
    projector_config: Blip2QFormerConfig = field(default_factory=Blip2QFormerConfig)
    text_config: GraniteTextConfig = field(default_factory=GraniteTextConfig)
    audio_token_id: int = 100352  # id of the <|audio|> placeholder token
    downsample_rate: int = 5
    window_size: int = 15
    initializer_range: float = 0.02
    tie_word_embeddings: bool = False
    has_lora_adapter: bool = False
def _config_get(config: Any, key: str, default: Any = None) -> Any:
    """Read *key* from a dict-style or attribute-style config, with a fallback."""
    if not isinstance(config, dict):
        return getattr(config, key, default)
    return config.get(key, default)
def _strip_model_type(config_dict: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *config_dict* without the ``model_type`` entry."""
    cleaned = dict(config_dict)
    cleaned.pop("model_type", None)
    return cleaned
def speech_config_from_dict(config_dict: dict[str, Any]) -> GraniteSpeechConfig:
    """Build a GraniteSpeechConfig from a plain (HF-style) config dict."""
    # Sub-configs may carry a "model_type" marker that the dataclasses reject.
    sub_configs = {
        name: _strip_model_type(config_dict.get(name, {}))
        for name in ("encoder_config", "projector_config", "text_config")
    }
    # Older configs use "audio_token_index" for the same value.
    audio_token_id = config_dict.get(
        "audio_token_id", config_dict.get("audio_token_index", 100352)
    )
    return GraniteSpeechConfig(
        encoder_config=GraniteSpeechEncoderConfig(**sub_configs["encoder_config"]),
        projector_config=Blip2QFormerConfig(**sub_configs["projector_config"]),
        text_config=GraniteTextConfig(**sub_configs["text_config"]),
        audio_token_id=audio_token_id,
        downsample_rate=config_dict.get("downsample_rate", 5),
        window_size=config_dict.get("window_size", 15),
        initializer_range=config_dict.get("initializer_range", 0.02),
        tie_word_embeddings=config_dict.get("tie_word_embeddings", False),
        has_lora_adapter=config_dict.get("has_lora_adapter", False),
    )
def speech_config_from_hf_config(config: Any) -> GraniteSpeechConfig:
    """Convert a HuggingFace config object (or dict) into a GraniteSpeechConfig."""

    def as_plain_dict(sub: Any) -> dict[str, Any]:
        # HF sub-configs expose to_dict(); plain mappings pass straight through.
        if hasattr(sub, "to_dict"):
            sub = sub.to_dict()
        return dict(sub)

    merged = {
        "encoder_config": as_plain_dict(_config_get(config, "encoder_config", {})),
        "projector_config": as_plain_dict(_config_get(config, "projector_config", {})),
        "text_config": as_plain_dict(_config_get(config, "text_config", {})),
        "audio_token_id": _config_get(
            config, "audio_token_id", _config_get(config, "audio_token_index", 100352)
        ),
        "downsample_rate": _config_get(config, "downsample_rate", 5),
        "window_size": _config_get(config, "window_size", 15),
        "initializer_range": _config_get(config, "initializer_range", 0.02),
        "tie_word_embeddings": _config_get(config, "tie_word_embeddings", False),
        "has_lora_adapter": _config_get(config, "has_lora_adapter", False),
    }
    return speech_config_from_dict(merged)
class GraniteSpeechConformerFeedForward(nn.Module):
    """Conformer feed-forward module: LayerNorm -> expand -> SiLU -> contract."""

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        expanded_dim = config.hidden_dim * config.feedforward_mult
        self.pre_norm = nn.LayerNorm(config.hidden_dim)
        self.up_proj = nn.Linear(config.hidden_dim, expanded_dim)
        self.silu = nn.SiLU()
        self.dropout = nn.Dropout(config.dropout)
        self.down_proj = nn.Linear(expanded_dim, config.hidden_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        features = self.pre_norm(hidden_states)
        features = self.dropout(self.silu(self.up_proj(features)))
        return self.dropout(self.down_proj(features))
class GraniteSpeechConformerAttention(nn.Module):
    """Block-wise self-attention with a learned relative-position bias.

    The time axis is split into fixed blocks of ``context_size`` frames;
    attention is computed independently within each block, and a
    Shaw-style relative-position term is injected through SDPA's
    ``attn_mask`` argument.
    """

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        inner_dim = config.dim_head * config.num_heads
        self.max_pos_emb = config.max_pos_emb
        self.context_size = config.context_size
        self.num_heads = config.num_heads
        self.dim_head = config.dim_head
        self.scale = self.dim_head**-0.5
        self.pre_norm = nn.LayerNorm(config.hidden_dim)
        self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
        # Keys and values come from one fused projection, split in forward().
        self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, config.hidden_dim)
        # One embedding per clamped relative distance in [-max_pos_emb, max_pos_emb].
        self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head)
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
    ) -> torch.Tensor:
        """Apply block-wise attention.

        Args:
            hidden_states: [batch, frames, hidden_dim] input features.
            attention_dists: [context_size, context_size] precomputed indices
                into ``rel_pos_emb`` (built by the encoder).
        """
        hidden_states = self.pre_norm(hidden_states)
        bsz, num_features, _ = hidden_states.shape
        num_blocks = math.ceil(num_features / self.context_size)
        remainder = num_features % self.context_size
        # Right-pad the time axis so it divides evenly into context blocks.
        if remainder > 0:
            hidden_states = F.pad(
                hidden_states, (0, 0, 0, self.context_size - remainder)
            )
        query_states = self.to_q(hidden_states)
        key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)
        # (bsz, num_blocks, num_heads, context_size, dim_head)
        query_states = query_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
        key_states = key_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
        value_states = value_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
        # Shaw's relative positional embeddings via einsum — O(C²·D) memory.
        # query_states : (bsz, num_blocks, num_heads, context_size, dim_head)
        # rel_pos_emb : (context_size, context_size, dim_head)
        rel_pos_emb = self.rel_pos_emb(attention_dists)
        pos_attn = (
            torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb)
            * self.scale
        )
        if remainder > 0:
            # Mask out the padding region in the last block: only the
            # remainder×remainder top-left corner holds real frames.
            mask = torch.ones(
                self.context_size,
                self.context_size,
                dtype=torch.bool,
                device=hidden_states.device,
            )
            mask[:remainder, :remainder] = False
            # Large negative bias so padded positions get ~zero attention weight.
            mask_value = -torch.finfo(pos_attn.dtype).max
            pos_attn[:, -1, :].masked_fill_(mask, mask_value)
        # SDPA adds pos_attn (the relative-position bias) to the QK scores.
        out = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=pos_attn,
            scale=self.scale,
        )
        # Fold blocks back into a single (padded) time axis, then drop padding.
        out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
        out = self.to_out(out[:, :num_features, :])
        return self.dropout(out)
class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
    """Depth-wise 1-D convolution with length-preserving asymmetric padding."""

    def __init__(self, chan_in: int, chan_out: int, kernel_size: int):
        super().__init__()
        # For even kernels the right pad is one shorter, keeping output length
        # equal to the input length.
        left = kernel_size // 2
        right = left - (kernel_size + 1) % 2
        self.padding = (left, right)
        self.conv = nn.Conv1d(
            chan_in, chan_out, kernel_size, groups=chan_in, bias=False
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        padded = F.pad(hidden_states, self.padding)
        return self.conv(padded)
class GraniteSpeechConformerConvModule(nn.Module):
    """Conformer convolution module: pointwise+GLU, depth-wise conv, BN, SiLU."""

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        inner_dim = config.hidden_dim * config.conv_expansion_factor
        self.norm = nn.LayerNorm(config.hidden_dim)
        self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.depth_conv = GraniteSpeechConformerDepthWiseConv1d(
            inner_dim, inner_dim, kernel_size=config.conv_kernel_size
        )
        self.silu = nn.SiLU()
        self.batch_norm = nn.BatchNorm1d(inner_dim)
        self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        normed = self.norm(hidden_states)
        # Conv1d layers operate channel-first: [B, T, C] -> [B, C, T].
        channels_first = normed.permute(0, 2, 1)
        gated = self.glu(self.up_conv(channels_first))
        convolved = self.silu(self.batch_norm(self.depth_conv(gated)))
        channels_last = self.down_conv(convolved).permute(0, 2, 1)
        return self.dropout(channels_last)
class GraniteSpeechConformerBlock(nn.Module):
    """One Conformer block: half-FF, attention, conv, half-FF, final norm."""

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        self.ff1 = GraniteSpeechConformerFeedForward(config)
        self.attn = GraniteSpeechConformerAttention(config)
        self.conv = GraniteSpeechConformerConvModule(config)
        self.ff2 = GraniteSpeechConformerFeedForward(config)
        self.post_norm = nn.LayerNorm(config.hidden_dim)

    def forward(
        self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
    ) -> torch.Tensor:
        # Macaron-style residuals: the two feed-forwards contribute at half
        # weight, sandwiching attention and convolution.
        hidden_states = hidden_states + 0.5 * self.ff1(hidden_states)
        hidden_states = hidden_states + self.attn(
            hidden_states, attention_dists=attention_dists
        )
        hidden_states = hidden_states + self.conv(hidden_states)
        hidden_states = hidden_states + 0.5 * self.ff2(hidden_states)
        return self.post_norm(hidden_states)
class GraniteSpeechCTCEncoder(nn.Module):
    """Stack of Conformer blocks with a mid-stack projection fed back in."""

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        self.config = config
        # Pairwise relative distances within one attention block, clamped and
        # shifted so they index the rel_pos_emb table (non-negative).
        positions = torch.arange(config.context_size)
        distances = positions.unsqueeze(1) - positions.unsqueeze(0)
        distances = distances.clamp(-config.context_size, config.context_size)
        self.register_buffer(
            "attention_dists", distances + config.max_pos_emb, persistent=False
        )
        self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True)
        self.layers = nn.ModuleList(
            GraniteSpeechConformerBlock(config) for _ in range(config.num_layers)
        )
        self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True)
        self.out_mid = nn.Linear(config.output_dim, config.hidden_dim, bias=True)
        self.num_layers = config.num_layers

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.input_linear(hidden_states)
        for layer_num, layer in enumerate(self.layers, start=1):
            hidden_states = layer(hidden_states, attention_dists=self.attention_dists)
            # Halfway through the stack, project through the output head and
            # re-inject its softmax into the trunk as a residual.
            if layer_num == self.num_layers // 2:
                mid_logits = self.out(hidden_states.clone())
                # Inline softmax — no persistent self.softmax submodule,
                # matching HF exactly so state-dict keys are identical.
                hidden_states = hidden_states + self.out_mid(
                    nn.Softmax(dim=-1)(mid_logits)
                )
        return hidden_states
class _Blip2QFormerMultiHeadAttention(nn.Module):
    """BERT-style multi-head attention supporting self- and cross-attention.

    Returns both the attention output and the (pre-dropout) attention
    probabilities.
    """

    def __init__(self, config: Blip2QFormerConfig, is_cross_attention: bool = False):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        # Keys/values read the encoder width when cross-attending.
        kv_width = (
            config.encoder_hidden_size if is_cross_attention else config.hidden_size
        )
        self.key = nn.Linear(kv_width, self.all_head_size)
        self.value = nn.Linear(kv_width, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def _transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # [B, T, H*D] -> [B, H, T, D]
        split_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        return x.view(split_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ):
        if encoder_hidden_states is not None:
            kv_source, mask = encoder_hidden_states, encoder_attention_mask
        else:
            kv_source, mask = hidden_states, attention_mask
        key_layer = self._transpose_for_scores(self.key(kv_source))
        value_layer = self._transpose_for_scores(self.value(kv_source))
        query_layer = self._transpose_for_scores(self.query(hidden_states))
        scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        if mask is not None:
            scores = scores + mask
        probs = nn.Softmax(dim=-1)(scores)
        # Dropout on the probabilities, cast back to the value dtype.
        weighted = torch.matmul(self.dropout(probs).to(value_layer.dtype), value_layer)
        weighted = weighted.permute(0, 2, 1, 3).contiguous()
        weighted = weighted.view(weighted.size()[:-2] + (self.all_head_size,))
        return weighted, probs
class _Blip2QFormerSelfOutput(nn.Module):
    """Projection + dropout followed by a residual LayerNorm (post-attention)."""

    def __init__(self, config: Blip2QFormerConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class _Blip2QFormerAttention(nn.Module):
    """Attention sub-block: multi-head attention plus the residual output layer."""

    def __init__(self, config: Blip2QFormerConfig, is_cross_attention: bool = False):
        super().__init__()
        self.attention = _Blip2QFormerMultiHeadAttention(config, is_cross_attention)
        self.output = _Blip2QFormerSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The attention module returns (context, probs); only context is used.
        context, _probs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        # Residual projection back onto the query stream.
        return self.output(context, hidden_states)
class _Blip2QFormerIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate_size with GELU."""

    def __init__(self, config: Blip2QFormerConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # HF uses ACT2FN["gelu"], which matches nn.GELU().
        self.intermediate_act_fn = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)
class _Blip2QFormerOutput(nn.Module):
    """Feed-forward contraction: intermediate_size -> hidden_size with
    dropout, residual connection and LayerNorm (post-norm)."""

    def __init__(self, config: Blip2QFormerConfig):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        contracted = self.dropout(self.dense(hidden_states))
        # Residual add against the block input, then normalize.
        return self.LayerNorm(contracted + input_tensor)
class _Blip2QFormerLayer(nn.Module):
    """One Q-Former transformer layer: self-attention over the queries,
    optional cross-attention into the audio-encoder states, then the
    query-path feed-forward block."""

    def __init__(self, config: Blip2QFormerConfig, layer_idx: int):
        super().__init__()
        # Kept for parity with the HF layer; chunked feed-forward is not used here.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = _Blip2QFormerAttention(config)
        self.layer_idx = layer_idx
        # Cross-attention is inserted every `cross_attention_frequency` layers.
        self.has_cross_attention = layer_idx % config.cross_attention_frequency == 0
        if self.has_cross_attention:
            self.crossattention = _Blip2QFormerAttention(
                config, is_cross_attention=True
            )
        # `use_qformer_text_input` is False for GraniteSpeech — only query path.
        self.intermediate_query = _Blip2QFormerIntermediate(config)
        self.output_query = _Blip2QFormerOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        query_length: int = 0,
    ) -> torch.Tensor:
        """Run one layer.

        Args:
            hidden_states: (B, S, H) query embeddings, possibly followed by
                text tokens (unused in GraniteSpeech inference).
            attention_mask: additive self-attention bias.
            encoder_hidden_states: encoder outputs for cross-attention.
            encoder_attention_mask: additive bias over encoder tokens.
            query_length: number of leading query tokens in `hidden_states`.

        Raises:
            ValueError: cross-attention layer invoked without encoder states.
        """
        attention_output = self.attention(hidden_states, attention_mask=attention_mask)
        if query_length > 0:
            # Only the leading query tokens take cross-attention and the query FFN.
            query_attn_out = attention_output[:, :query_length, :]
            if self.has_cross_attention:
                if encoder_hidden_states is None:
                    raise ValueError(
                        "encoder_hidden_states required for cross-attention"
                    )
                query_attn_out = self.crossattention(
                    query_attn_out,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                )
            layer_output = self._ff_chunk_query(query_attn_out)
            if attention_output.shape[1] > query_length:
                # text tokens present — not used in GraniteSpeech inference
                layer_output = torch.cat(
                    [layer_output, attention_output[:, query_length:, :]], dim=1
                )
        else:
            layer_output = self._ff_chunk_query(attention_output)
        return layer_output

    def _ff_chunk_query(self, attention_output: torch.Tensor) -> torch.Tensor:
        # Query-path feed-forward: expand + GELU, then project back with residual.
        intermediate = self.intermediate_query(attention_output)
        return self.output_query(intermediate, attention_output)
class _Blip2QFormerEncoder(nn.Module):
    """Sequential stack of `num_hidden_layers` Q-Former layers."""

    def __init__(self, config: Blip2QFormerConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [_Blip2QFormerLayer(config, idx) for idx in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        query_length: int = 0,
    ) -> torch.Tensor:
        """Thread the hidden states through every layer in order."""
        states = hidden_states
        for qformer_layer in self.layer:
            states = qformer_layer(
                states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                query_length=query_length,
            )
        return states
class Blip2QFormerModel(nn.Module):
    """
    Drop-in replacement for transformers.Blip2QFormerModel.
    The module name and all child module names match the HF implementation so
    that state-dict keys are identical.
    """

    def __init__(self, config: Blip2QFormerConfig):
        super().__init__()
        self.config = config
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.encoder = _Blip2QFormerEncoder(config)

    # Helpers that mirror Blip2QFormerModel.get_extended_attention_mask
    # and invert_attention_mask from transformers.
    @staticmethod
    def _get_extended_attention_mask(
        attention_mask: torch.Tensor, dtype: torch.dtype
    ) -> torch.Tensor:
        """(B, S) or (B, 1, S, S) → additive mask broadcastable to (B, H, S, S)."""
        if attention_mask.dim() == 2:
            extended = attention_mask[:, None, None, :]
        elif attention_mask.dim() == 3:
            extended = attention_mask[:, None, :, :]
        else:
            extended = attention_mask
        extended = extended.to(dtype=dtype)
        # NOTE(review): this helper fills masked positions with -10000 while
        # _invert_attention_mask below uses torch.finfo(dtype).min; the
        # asymmetry mirrors the transformers implementation — confirm before
        # "fixing" it.
        extended = (1.0 - extended) * -10000.0
        return extended

    @staticmethod
    def _invert_attention_mask(
        encoder_attention_mask: torch.Tensor, dtype: torch.dtype
    ) -> torch.Tensor:
        """Convert a 0/1 key-padding mask to an additive bias (0 keep, min mask)."""
        if encoder_attention_mask.dim() == 3:
            inverted = encoder_attention_mask[:, None, :, :]
        elif encoder_attention_mask.dim() == 2:
            inverted = encoder_attention_mask[:, None, None, :]
        else:
            inverted = encoder_attention_mask
        inverted = inverted.to(dtype=dtype)
        return (1.0 - inverted) * torch.finfo(dtype).min

    def forward(
        self,
        query_embeds: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        **_kwargs,
    ) -> "Blip2QFormerOutput":
        """Attend the learned queries over the encoder states.

        Args:
            query_embeds: (B, Q, H) learned query embeddings.
            encoder_hidden_states: optional encoder outputs for cross-attention.
            encoder_attention_mask: optional 0/1 mask over encoder tokens;
                defaults to all-ones when encoder states are given.

        Returns:
            Blip2QFormerOutput whose last_hidden_state has shape (B, Q, H).
        """
        # Q-Former is kept in fp32 by HF convention.
        dtype = self.layernorm.weight.dtype
        query_embeds = query_embeds.to(dtype)
        embedding_output = self.dropout(self.layernorm(query_embeds))
        batch_size, query_length = embedding_output.shape[:2]
        device = embedding_output.device
        # Self-attention mask: all queries attend to all queries.
        self_attn_mask = torch.ones((batch_size, query_length), device=device)
        extended_attn_mask = self._get_extended_attention_mask(self_attn_mask, dtype)
        # Cross-attention mask over encoder tokens.
        if encoder_hidden_states is not None:
            enc_hs = encoder_hidden_states.to(dtype)
            if encoder_attention_mask is None:
                enc_bsz, enc_len = enc_hs.shape[:2]
                encoder_attention_mask = torch.ones((enc_bsz, enc_len), device=device)
            encoder_extended_mask = self._invert_attention_mask(
                encoder_attention_mask, dtype
            )
        else:
            enc_hs = None
            encoder_extended_mask = None
        sequence_output = self.encoder(
            embedding_output,
            attention_mask=extended_attn_mask,
            encoder_hidden_states=enc_hs,
            encoder_attention_mask=encoder_extended_mask,
            query_length=query_length,
        )
        return Blip2QFormerOutput(last_hidden_state=sequence_output)
class Blip2QFormerOutput:
    """Minimal output container (mirrors BaseModelOutputWithPoolingAndCrossAttentions).

    Only `.last_hidden_state` is consumed by callers, so that is the sole field.
    """

    # __slots__ keeps the container lightweight (no per-instance __dict__).
    __slots__ = ("last_hidden_state",)

    def __init__(self, last_hidden_state: torch.Tensor):
        self.last_hidden_state = last_hidden_state
class GraniteSpeechEncoderProjector(nn.Module):
    """Projects audio-encoder features into the LLM embedding space.

    The encoder sequence is split into fixed-size windows; a Q-Former pools
    each window down to ``window_size // downsample_rate`` query vectors, and
    a final linear layer maps those to the text model's hidden size.
    """

    def __init__(self, config: GraniteSpeechConfig):
        super().__init__()
        self.downsample_rate = config.downsample_rate
        self.window_size = config.window_size
        # Queries per window after temporal downsampling.
        self.num_queries = config.window_size // config.downsample_rate
        # Initialisation with normal_(mean=0, std=1) to match HF
        self.query = nn.Parameter(
            torch.empty(
                1, self.num_queries, config.projector_config.hidden_size
            ).normal_(0.0, 1.0)
        )
        self.qformer = Blip2QFormerModel(config.projector_config)
        self.linear = nn.Linear(
            config.projector_config.hidden_size, config.text_config.hidden_size
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """(B, S, D) encoder features → (B, ceil(S/window) * num_queries, text_hidden)."""
        batch_size, seq_len, dim = hidden_states.size()
        # Zero-pad the time axis so the sequence splits into whole windows.
        nblocks = math.ceil(seq_len / self.window_size)
        pad = nblocks * self.window_size - seq_len
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0)
        # Fold windows into the batch axis: (B * nblocks, window_size, D).
        hidden_states = hidden_states.view(batch_size * nblocks, self.window_size, dim)
        # The (1, num_queries, H) query parameter broadcasts over the folded batch.
        query_output = self.qformer(
            query_embeds=self.query,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=None,
        )
        # Unfold back to (B, nblocks * num_queries, H), then project to LLM width.
        query_proj = self.linear(
            query_output.last_hidden_state.view(
                batch_size, nblocks * self.window_size // self.downsample_rate, -1
            )
        )
        return query_proj
class GraniteRMSNorm(nn.Module):
    """Root-mean-square LayerNorm (no mean subtraction), computed in fp32."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for stability, then cast back to the input dtype.
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(-1, keepdim=True)
        normalized = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
class GraniteRotaryEmbedding(nn.Module):
    """Computes rotary position-embedding cos/sin tables for GraniteAttention."""

    def __init__(self, config: GraniteTextConfig):
        super().__init__()
        head_dim = config.hidden_size // config.num_attention_heads
        # Standard RoPE inverse frequencies: rope_theta ** (-2i / head_dim).
        inv_freq = 1.0 / (
            config.rope_theta
            ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
        )
        # Non-persistent: derived from config, never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(
        self, x: torch.Tensor, position_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables of shape (B, seq_len, head_dim), in x's dtype.

        `x` is consulted only for its device and dtype.
        """
        # position_ids: (batch, seq_len)
        inv_freq = self.inv_freq
        assert isinstance(inv_freq, torch.Tensor)
        inv_freq_expanded = (
            inv_freq.unsqueeze(0)
            .unsqueeze(-1)
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Disable autocast so the angle products stay in full fp32 precision.
        with torch.autocast(device_type=x.device.type, enabled=False):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
        # Duplicate the angles so the table spans the full head dimension.
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos.to(x.dtype), sin.to(x.dtype)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dim, negating the second: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
| def _apply_rotary_pos_emb( | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| cos: torch.Tensor, | |
| sin: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| cos = cos.unsqueeze(1) # (B, 1, S, D) | |
| sin = sin.unsqueeze(1) | |
| q_embed = (q * cos) + (_rotate_half(q) * sin) | |
| k_embed = (k * cos) + (_rotate_half(k) * sin) | |
| return q_embed, k_embed | |
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand GQA key/value heads to match the number of query heads.

    (B, n_kv, S, D) -> (B, n_kv * n_rep, S, D), duplicating each kv head
    `n_rep` times. Returns the input unchanged when n_rep == 1.
    """
    if n_rep == 1:
        return hidden_states
    bsz, n_kv, seq, head_dim = hidden_states.shape
    expanded = hidden_states.unsqueeze(2).expand(bsz, n_kv, n_rep, seq, head_dim)
    return expanded.reshape(bsz, n_kv * n_rep, seq, head_dim)
class GraniteAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings, GQA and KV caching.

    Uses Granite's fixed ``attention_multiplier`` as the score scale instead
    of the usual 1/sqrt(head_dim).
    """

    def __init__(self, config: GraniteTextConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx  # this layer's slot in the KV cache
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # GQA: how many query heads share each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.scaling = config.attention_multiplier
        self.attention_dropout = config.attention_dropout
        self.q_proj = nn.Linear(
            config.hidden_size,
            self.num_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional["KVCache"] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, None]:
        """Attend over cached + current positions.

        Args:
            hidden_states: (B, S, hidden_size) activations.
            position_embeddings: (cos, sin) tables from GraniteRotaryEmbedding.
            attention_mask: additive bias broadcastable to
                (B, heads, S, total_len), or None.
            past_key_values: optional KV cache; new keys/values are appended.
            cache_position: forwarded to the cache (this KVCache ignores it).

        Returns:
            (attn_output, None) — the None mirrors HF's attention-weights slot.
        """
        bsz, q_len, _ = hidden_states.shape
        hidden_shape = (bsz, q_len, -1, self.head_dim)
        # (B, S, H*D) -> (B, heads, S, D)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        cos, sin = position_embeddings
        query_states, key_states = _apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )
        if past_key_values is not None:
            # The cache stores rotated keys and returns the full past+current run.
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, cache_position
            )
        # Duplicate KV heads so their count matches the query heads (GQA).
        key_states = _repeat_kv(key_states, self.num_key_value_groups)
        value_states = _repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
        )
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        # Softmax in fp32 for numerical stability, then cast back.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )
        attn_weights = F.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, None
class GraniteMLP(nn.Module):
    """SwiGLU feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(self, config: GraniteTextConfig):
        super().__init__()
        self.gate_proj = nn.Linear(
            config.hidden_size, config.intermediate_size, bias=config.mlp_bias
        )
        self.up_proj = nn.Linear(
            config.hidden_size, config.intermediate_size, bias=config.mlp_bias
        )
        self.down_proj = nn.Linear(
            config.intermediate_size, config.hidden_size, bias=config.mlp_bias
        )
        self.act_fn = nn.SiLU()  # config.hidden_act == "silu" for Granite

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class GraniteDecoderLayer(nn.Module):
    """Pre-norm decoder layer with Granite's scaled residual connections."""

    def __init__(self, config: GraniteTextConfig, layer_idx: int):
        super().__init__()
        self.self_attn = GraniteAttention(config, layer_idx=layer_idx)
        self.mlp = GraniteMLP(config)
        self.input_layernorm = GraniteRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = GraniteRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        # Residual branches are scaled by this factor before being added.
        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_values: Optional["KVCache"] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-block: pre-norm, then scaled residual add.
        attn_out, _ = self.self_attn(
            self.input_layernorm(hidden_states),
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            cache_position=cache_position,
        )
        hidden_states = hidden_states + attn_out * self.residual_multiplier
        # MLP sub-block: pre-norm, then scaled residual add.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out * self.residual_multiplier
class KVCache:
    """Simple dynamic key-value cache, interface-compatible with the HF DynamicCache."""

    def __init__(self):
        # Per-layer tensors, concatenated along the sequence axis (dim 2).
        self._keys: dict[int, torch.Tensor] = {}
        self._values: dict[int, torch.Tensor] = {}

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Append new states for `layer_idx`; return the full cached sequence.

        `cache_position` is accepted for interface parity and ignored —
        states are always appended at the end.
        """
        existing = self._keys.get(layer_idx)
        if existing is None:
            self._keys[layer_idx] = key_states
            self._values[layer_idx] = value_states
        else:
            self._keys[layer_idx] = torch.cat([existing, key_states], dim=2)
            self._values[layer_idx] = torch.cat(
                [self._values[layer_idx], value_states], dim=2
            )
        return self._keys[layer_idx], self._values[layer_idx]

    def get_seq_length(self) -> int:
        """Number of cached positions (0 when the cache is empty)."""
        for cached in self._keys.values():
            return cached.shape[2]
        return 0
def _make_causal_mask(
    seq_len: int,
    past_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """
    Additive causal mask of shape (1, 1, seq_len, seq_len + past_len).

    Positions that must not be attended hold torch.finfo(dtype).min; attended
    positions hold 0. Query row i corresponds to absolute position
    past_len + i and may attend keys 0 .. past_len + i.
    """
    total_len = seq_len + past_len
    # Fix: pass dtype to torch.full — previously the mask was always created
    # as float32 regardless of the requested dtype, silently promoting the
    # attention-score addition in lower-precision runs.
    mask = torch.full(
        (seq_len, total_len), torch.finfo(dtype).min, dtype=dtype, device=device
    )
    # Each query position i can attend to keys 0 .. (past_len + i)
    attend_cols = torch.arange(total_len, device=device)
    query_rows = torch.arange(past_len, past_len + seq_len, device=device)
    causal = attend_cols[None, :] <= query_rows[:, None]
    mask = mask.masked_fill(causal, 0.0)
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, S, S+past)
class GraniteModel(nn.Module):
    """Granite decoder stack: token embeddings, N decoder layers, final RMSNorm."""

    def __init__(self, config: GraniteTextConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=self.padding_idx
        )
        self.layers = nn.ModuleList(
            [GraniteDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.norm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = GraniteRotaryEmbedding(config)
        # Granite scales input embeddings by a fixed multiplier.
        self.embedding_multiplier = config.embedding_multiplier

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed_tokens

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[KVCache] = None,
        cache_position: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return final hidden states (B, S, H).

        Exactly one of `input_ids` / `inputs_embeds` should be supplied.
        A 2-D `attention_mask` is interpreted as a 0/1 padding mask over the
        full past+current key length and folded into the causal mask.

        Raises:
            ValueError: neither `input_ids` nor `inputs_embeds` given.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either input_ids or inputs_embeds must be provided")
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None
        inputs_embeds = inputs_embeds * self.embedding_multiplier
        bsz, seq_len, _ = inputs_embeds.shape
        past_len = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        # Absolute positions of the current tokens (continues after the cache).
        if cache_position is None:
            cache_position = torch.arange(
                past_len, past_len + seq_len, device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        # Build additive causal mask, then fold in any padding mask
        causal_mask = _make_causal_mask(
            seq_len, past_len, inputs_embeds.dtype, inputs_embeds.device
        )
        if attention_mask is not None and attention_mask.dim() == 2:
            # attention_mask is 1 for real tokens, 0 for padding (shape: B, total_len)
            # Expand to (B, 1, seq_len, total_len) additive form
            pad_mask = (
                1.0 - attention_mask[:, None, None, :].to(inputs_embeds.dtype)
            ) * torch.finfo(inputs_embeds.dtype).min
            # pad_mask covers [0 .. total_len]; causal_mask covers the same range
            causal_mask = causal_mask + pad_mask
        # cos/sin tables are shared by every layer for this forward pass.
        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                past_key_values=past_key_values,
                cache_position=cache_position,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
class GraniteForCausalLM(nn.Module):
    """Granite decoder with a language-modeling head and optional training loss."""

    def __init__(self, config: GraniteTextConfig):
        super().__init__()
        self.config = config
        self.model = GraniteModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.get_input_embeddings()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[KVCache] = None,
        cache_position: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
        """Run the decoder and project to vocabulary logits.

        Returns a dict with "loss" (None unless `labels` is given), "logits"
        (divided by config.logits_scaling), and "past_key_values".
        """
        hidden = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            cache_position=cache_position,
            position_ids=position_ids,
        )
        # Granite divides the raw logits by a fixed scaling factor.
        logits = self.lm_head(hidden) / self.config.logits_scaling
        loss = None
        if labels is not None:
            # Next-token shift: position t predicts token t+1.
            shifted_logits = logits[..., :-1, :].contiguous()
            shifted_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shifted_logits.view(-1, self.vocab_size),
                shifted_labels.view(-1).to(shifted_logits.device),
            )
        return {"loss": loss, "logits": logits, "past_key_values": past_key_values}
| # Main class for inference | |
class GraniteSpeechForConditionalGeneration(nn.Module):
    """Multimodal inference model: CTC audio encoder + Q-Former projector + Granite LLM.

    Audio placeholder tokens in `input_ids` (id == config.audio_token_id) are
    replaced in embedding space by projected audio features.
    """

    def __init__(self, config: GraniteSpeechConfig):
        super().__init__()
        self.config = config
        self.audio_token_index = config.audio_token_id
        self.encoder = GraniteSpeechCTCEncoder(config.encoder_config)
        self.projector = GraniteSpeechEncoderProjector(config)
        self.language_model = GraniteForCausalLM(config.text_config)

    def get_input_embeddings(self) -> nn.Embedding:
        # The LLM's token-embedding table.
        return self.language_model.get_input_embeddings()

    def get_audio_features(self, input_features: torch.Tensor) -> torch.Tensor:
        """Encode raw audio features and project them into the LLM embedding space."""
        encoder_embeds = self.encoder(input_features)
        return self.projector(encoder_embeds)

    def get_merged_audio_embeddings(
        self,
        input_ids: torch.Tensor,
        audio_features: torch.Tensor,
        input_features_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, substituting audio-token slots with audio features.

        NOTE(review): masked_scatter assumes the number of audio placeholder
        tokens in `input_ids` equals the number of (mask-selected) audio
        feature vectors — presumably the processor guarantees this; confirm.
        """
        is_audio_idx = input_ids == self.audio_token_index
        # Clamp audio positions to 0 so the LLM embedding lookup never goes OOV
        llm_input_ids = torch.where(
            is_audio_idx, torch.zeros_like(input_ids), input_ids
        )
        inputs_embeds = self.get_input_embeddings()(llm_input_ids)
        audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
        if input_features_mask is not None:
            # Keep only the feature rows selected by the mask.
            audio_features = audio_features[input_features_mask]
        special_audio_mask = is_audio_idx.unsqueeze(-1).expand_as(inputs_embeds)
        inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        input_features: Optional[torch.Tensor] = None,
        input_features_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[KVCache] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
        """Run the multimodal model.

        When `input_features` is given, the audio is encoded, projected and
        merged into the text embeddings; otherwise this is a pure-text pass.
        Returns the language model's output dict (loss/logits/past_key_values).
        """
        model_dtype = next(self.parameters()).dtype
        if input_features is not None:
            input_features = input_features.to(model_dtype)
            audio_embeds = self.get_audio_features(input_features)
            inputs_embeds = self.get_merged_audio_embeddings(
                input_ids=input_ids,
                audio_features=audio_embeds,
                input_features_mask=input_features_mask,
            )
        else:
            inputs_embeds = self.get_input_embeddings()(input_ids)
        return self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            labels=labels,
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment