Skip to content

Instantly share code, notes, and snippets.

@mseri
Created March 18, 2026 08:54
Show Gist options
  • Select an option

  • Save mseri/df3208da68011a4eced69e5f3d2b97cd to your computer and use it in GitHub Desktop.

Select an option

Save mseri/df3208da68011a4eced69e5f3d2b97cd to your computer and use it in GitHub Desktop.
Pure PyTorch implementation of Granite Speech and a simple CLI (using VAD to reduce the memory footprint). Implemented with substantial help from Copilot.
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "torch>=2.1.0",
# "torchaudio>=2.1.0",
# "transformers>=4.40.0",
# "soundfile>=0.12.1",
# "silero-vad>=5.1.2",
# ]
# ///
# Usage: uv run granite-speech.py --help
# usage: granite-speech.py [-h] -i INPUT [-o OUTPUT] [--max-new-tokens MAX_NEW_TOKENS] [--debug] [--vad-min-silence-ms VAD_MIN_SILENCE_MS]
# [--vad-speech-pad-ms VAD_SPEECH_PAD_MS] [--vad-max-segment-secs VAD_MAX_SEGMENT_SECS] [--output-segment-timestamps]
# prompt
#
# Transcribe audio with IBM Granite Speech using VAD-based segmentation.
#
# positional arguments:
# prompt Instruction prompt, e.g. 'Transcribe the audio verbatim.'
#
# options:
# -h, --help show this help message and exit
# -i INPUT, --input INPUT
# Path to audio file (WAV/MP3/FLAC).
# -o OUTPUT, --output OUTPUT
# Path to save transcript (.md appended if absent).
# --max-new-tokens MAX_NEW_TOKENS
# Maximum new tokens to generate per VAD segment (default: 512).
# --debug Print raw generation diagnostics for prompt wiring and returned tokens.
# --vad-min-silence-ms VAD_MIN_SILENCE_MS
# Minimum silence duration in ms for VAD segmentation (default: 300).
# --vad-speech-pad-ms VAD_SPEECH_PAD_MS
# Padding in ms added around VAD speech segments (default: 200).
# --vad-max-segment-secs VAD_MAX_SEGMENT_SECS
# Maximum duration in seconds for a single VAD speech segment after splitting (default: 30.0).
# --output-segment-timestamps
# Include per-segment timestamps in the saved transcript output.
import argparse
import os
import sys
import warnings
import soundfile as sf
import torch
import torchaudio.functional as F
from silero_vad import get_speech_timestamps, load_silero_vad
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
# Silence a noisy but benign UserWarning about tensor resizing; the filter is
# matched by message text so other UserWarnings still surface.
# NOTE(review): presumably emitted during torchaudio resampling — confirm.
warnings.filterwarnings(
    "ignore",
    message="An output with one or more elements was resized",
    category=UserWarning,
)
# Granite Speech operates on 16 kHz mono audio; everything is resampled to this.
SAMPLE_RATE = 16_000
def load_audio(path: str) -> torch.Tensor:
    """Read an audio file as mono float32 at 16 kHz, shaped [1, N].

    Exits the process with status 1 (after printing to stderr) when the
    file does not exist.
    """
    if not os.path.exists(path):
        print(f"Error: audio file not found: {path}", file=sys.stderr)
        sys.exit(1)
    samples, native_rate = sf.read(path, dtype="float32", always_2d=True)
    waveform = torch.from_numpy(samples.T)
    # Collapse multi-channel audio into a single averaged channel.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Resample only when the source rate differs from the model's 16 kHz.
    if native_rate != SAMPLE_RATE:
        waveform = F.resample(waveform, orig_freq=native_rate, new_freq=SAMPLE_RATE)
    return waveform
def segment_audio_with_vad(
    wav: torch.Tensor,
    *,
    min_silence_ms: int = 300,
    speech_pad_ms: int = 200,
    max_segment_secs: float = 30.0,
) -> list[tuple[int, int]]:
    """
    Run Silero VAD over `wav` and return [start_sample, end_sample) spans.

    Speech regions longer than `max_segment_secs` are chopped into
    consecutive fixed-size pieces so each transcription call stays
    memory-bounded. Falls back to a single span covering the whole clip
    when no speech is detected.
    """
    mono = wav[0].cpu()
    detector = load_silero_vad()
    speech_regions = get_speech_timestamps(
        mono,
        detector,
        sampling_rate=SAMPLE_RATE,
        min_silence_duration_ms=min_silence_ms,
        speech_pad_ms=speech_pad_ms,
        return_seconds=False,
    )
    if not speech_regions:
        return [(0, int(mono.shape[0]))]
    limit = int(max_segment_secs * SAMPLE_RATE)
    spans: list[tuple[int, int]] = []
    for region in speech_regions:
        begin = int(region["start"])
        finish = int(region["end"])
        # Emit full-size chunks until the remaining tail fits within `limit`.
        while finish - begin > limit:
            spans.append((begin, begin + limit))
            begin += limit
        if finish > begin:
            spans.append((begin, finish))
    return spans
def format_timestamp(seconds: float) -> str:
    """Render *seconds* as HH:MM:SS.m (tenth-of-second precision)."""
    # Work in integer tenths of a second to avoid float formatting drift.
    total_tenths = int(round(seconds * 10))
    whole_seconds, tenths = divmod(total_tenths, 10)
    total_minutes, secs = divmod(whole_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{tenths}"
def merge_transcript_segments(segments: list[str], max_check: int = 400) -> str:
    """Concatenate segment texts, dropping text duplicated across a join.

    Adjacent segments are glued with a single space unless the end of the
    accumulated text repeats the start of the next segment, in which case
    the repeated prefix (up to `max_check` characters) is removed.
    """

    def overlap_length(left: str, right: str) -> int:
        # Largest k such that the last k chars of `left` equal the first k of `right`.
        if not left or not right:
            return 0
        for size in range(min(len(left), len(right), max_check), 0, -1):
            if left.endswith(right[:size]):
                return size
        return 0

    result = ""
    for raw in segments:
        text = raw.strip()
        if not text:
            continue
        if not result:
            result = text
            continue
        k = overlap_length(result, text)
        result = (result + text[k:]) if k > 0 else (result + " " + text)
    return result.strip()
def transcribe_audio_segment(
    *,
    model,
    processor,
    tokenizer,
    audio_token_id: int,
    prefix_ids: torch.Tensor,
    prefix_mask: torch.Tensor,
    suffix_ids: torch.Tensor,
    suffix_mask: torch.Tensor,
    wav_segment: torch.Tensor,
    device: str,
    dtype: torch.dtype,
    max_new_tokens: int,
    debug: bool,
    segment_label: str,
) -> str:
    """
    Transcribe a single VAD speech segment by encoding its audio features and
    merging them into a placeholder-token sequence.

    The prompt is assembled as ``prefix_ids`` + one placeholder token per
    audio-embedding frame + ``suffix_ids``; the placeholder embeddings are then
    overwritten with the projected audio embeddings before ``model.generate``.
    Returns the decoded text with special tokens stripped, or "" when the
    model returned no tokens. In non-debug mode the text is also streamed to
    stdout as it is produced per segment.
    """
    # The processor requires prompt text, but only the audio features are used
    # here; the textual prompt ids come from prefix/suffix instead.
    seg_inputs = processor(text="<|audio|>", audio=wav_segment, return_tensors="pt")
    seg_features = seg_inputs["input_features"].to(device=device, dtype=dtype)
    with torch.no_grad():
        audio_out = model.get_audio_features(seg_features)
    audio_embed = audio_out.pooler_output  # [1, A, H]
    audio_len = audio_embed.shape[1]
    # One placeholder token per audio embedding frame, to be spliced between
    # the text prefix and suffix of the chat prompt.
    audio_placeholder_ids = torch.full(
        (1, audio_len),
        audio_token_id,
        dtype=prefix_ids.dtype,
        device=device,
    )
    audio_placeholder_mask = torch.ones(
        (1, audio_len),
        dtype=prefix_mask.dtype,
        device=device,
    )
    seg_input_ids = torch.cat([prefix_ids, audio_placeholder_ids, suffix_ids], dim=1)
    seg_mask = torch.cat([prefix_mask, audio_placeholder_mask, suffix_mask], dim=1)
    embed_fn = model.get_input_embeddings()
    # NOTE(review): placeholder positions are looked up as token id 0 and then
    # overwritten below — presumably to keep the lookup within the text
    # embedding table's range; confirm against the model's vocab layout.
    llm_lookup_ids = torch.where(
        seg_input_ids == audio_token_id,
        torch.zeros_like(seg_input_ids),
        seg_input_ids,
    )
    seg_embeds = embed_fn(llm_lookup_ids)
    # Scatter the audio embeddings into every placeholder position.
    special_audio_mask = (
        (seg_input_ids == audio_token_id).unsqueeze(-1).expand_as(seg_embeds)
    )
    seg_embeds = seg_embeds.masked_scatter(
        special_audio_mask,
        audio_embed.to(device=seg_embeds.device, dtype=seg_embeds.dtype),
    )
    if debug:
        print(
            f"[debug] {segment_label}: audio_embed_len={audio_len} "
            f"prompt_len={seg_input_ids.shape[1]} attention_mask_sum={int(seg_mask.sum().item())}"
        )
    # Greedy decoding; the n-gram block guards against degenerate repetition.
    with torch.no_grad():
        output = model.generate(
            inputs_embeds=seg_embeds,
            attention_mask=seg_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,
            num_beams=1,
            no_repeat_ngram_size=5,
            return_dict_in_generate=True,
            output_scores=debug,
        )
    sequences = output.sequences
    if sequences.shape[1] == 0:
        if debug:
            print(f"[debug] {segment_label}: empty returned sequence")
        return ""
    returned_ids = sequences[0].to(dtype=torch.int32).tolist()
    returned_text_raw = tokenizer.decode(returned_ids, skip_special_tokens=False)
    returned_text = tokenizer.decode(returned_ids, skip_special_tokens=True).strip()
    if debug:
        print(f"[debug] {segment_label}: returned_token_count={len(returned_ids)}")
        print(f"[debug] {segment_label}: first_token_ids={returned_ids[:32]}")
        print(f"[debug] {segment_label}: last_token_ids={returned_ids[-32:]}")
        print(f"[debug] {segment_label}: raw_decoded={returned_text_raw!r}")
        if hasattr(output, "scores") and output.scores:
            print(
                f"[debug] {segment_label}: generated_steps_from_scores={len(output.scores)}"
            )
    elif returned_text:
        # Non-debug mode: stream segment text to stdout as it is produced.
        print(returned_text, flush=True, end=" ")
    return returned_text
def main() -> None:
    """CLI entry point: parse args, load model and audio, transcribe per VAD segment."""
    parser = argparse.ArgumentParser(
        description="Transcribe audio with IBM Granite Speech using VAD-based segmentation."
    )
    parser.add_argument(
        "prompt",
        help="Instruction prompt, e.g. 'Transcribe the audio verbatim.'",
    )
    parser.add_argument(
        "-i",
        "--input",
        required=True,
        help="Path to audio file (WAV/MP3/FLAC).",
    )
    parser.add_argument(
        "-o",
        "--output",
        default=None,
        help="Path to save transcript (.md appended if absent).",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=512,
        help="Maximum new tokens to generate per VAD segment (default: 512).",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print raw generation diagnostics for prompt wiring and returned tokens.",
    )
    parser.add_argument(
        "--vad-min-silence-ms",
        type=int,
        default=300,
        help="Minimum silence duration in ms for VAD segmentation (default: 300).",
    )
    parser.add_argument(
        "--vad-speech-pad-ms",
        type=int,
        default=200,
        help="Padding in ms added around VAD speech segments (default: 200).",
    )
    parser.add_argument(
        "--vad-max-segment-secs",
        type=float,
        default=30.0,
        help="Maximum duration in seconds for a single VAD speech segment after splitting (default: 30.0).",
    )
    parser.add_argument(
        "--output-segment-timestamps",
        action="store_true",
        help="Include per-segment timestamps in the saved transcript output.",
    )
    args = parser.parse_args()
    # Device preference: Apple-silicon MPS, then CUDA, then CPU fallback.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        print("Warning: neither MPS nor CUDA available, falling back to CPU.")
        device = "cpu"
    dtype = torch.bfloat16
    model_name = "ibm-granite/granite-4.0-1b-speech"
    audio_token = "<|audio|>"
    print("Loading processor …")
    processor = AutoProcessor.from_pretrained(model_name)
    tokenizer = processor.tokenizer
    print("Loading model weights …")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_name,
        dtype=dtype,
    )
    model.to(device)
    model.eval()
    print(f"Loading audio: {args.input}")
    wav = load_audio(args.input)
    duration = wav.shape[-1] / SAMPLE_RATE
    print(f"Loaded {wav.shape[-1]} samples ({duration:.1f}s)")
    # Ensure the audio placeholder appears in the user turn (prepended if absent).
    user_content = (
        args.prompt if audio_token in args.prompt else f"{audio_token}{args.prompt}"
    )
    chat = [{"role": "user", "content": user_content}]
    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    if args.debug:
        print("\n[debug] prompt text:")
        print(prompt)
    audio_token_id: int = tokenizer.convert_tokens_to_ids(audio_token)
    text_inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = text_inputs["input_ids"].to(device)
    attention_mask = text_inputs["attention_mask"].to(device)
    # Locate the placeholder so the prompt can be split into prefix/suffix
    # around the per-segment audio embeddings.
    audio_positions = (input_ids[0] == audio_token_id).nonzero(as_tuple=True)[0]
    if len(audio_positions) == 0:
        print(
            "Warning: <|audio|> token not found in prompt; audio will not be injected.",
            file=sys.stderr,
        )
    # With no placeholder found, pos points past the end: the prefix becomes
    # the whole prompt and the suffix is empty.
    pos = int(audio_positions[0]) if len(audio_positions) > 0 else input_ids.shape[1]
    if args.debug:
        print(
            f"[debug] audio token id: {audio_token_id} | "
            f"audio token positions in prompt ids: {audio_positions.tolist()}"
        )
    # Split around the placeholder; the single placeholder token itself is
    # dropped and re-created per segment with one token per audio frame.
    prefix_ids = input_ids[:, :pos]
    prefix_mask = attention_mask[:, :pos]
    suffix_ids = input_ids[:, pos + 1 :]
    suffix_mask = attention_mask[:, pos + 1 :]
    if args.debug:
        prefix_text = tokenizer.decode(
            prefix_ids[0].to(dtype=torch.int32).tolist(),
            skip_special_tokens=False,
        )
        suffix_text = tokenizer.decode(
            suffix_ids[0].to(dtype=torch.int32).tolist(),
            skip_special_tokens=False,
        )
        print(f"[debug] prefix text: {prefix_text!r}")
        print(f"[debug] suffix text: {suffix_text!r}")
        print(
            f"[debug] prefix ids len: {prefix_ids.shape[1]} | "
            f"suffix ids len: {suffix_ids.shape[1]}"
        )
    vad_segments = segment_audio_with_vad(
        wav,
        min_silence_ms=args.vad_min_silence_ms,
        speech_pad_ms=args.vad_speech_pad_ms,
        max_segment_secs=args.vad_max_segment_secs,
    )
    total_vad_duration = sum(end - start for start, end in vad_segments) / SAMPLE_RATE
    print(
        f"\nVAD: {len(vad_segments)} speech segment(s) covering "
        f"{total_vad_duration:.1f}s of audio"
    )
    all_segments: list[str] = []
    timestamped_segments: list[tuple[float, float, str]] = []
    for i, (seg_start, seg_end) in enumerate(vad_segments, start=1):
        seg_wav = wav[:, seg_start:seg_end]
        seg_start_s = seg_start / SAMPLE_RATE
        seg_end_s = seg_end / SAMPLE_RATE
        if args.debug:
            print(
                f"\n--- Segment {i}/{len(vad_segments)} "
                f"({seg_start_s:.1f}–{seg_end_s:.1f}s) ---",
                flush=True,
            )
        seg_text = transcribe_audio_segment(
            model=model,
            processor=processor,
            tokenizer=tokenizer,
            audio_token_id=audio_token_id,
            prefix_ids=prefix_ids,
            prefix_mask=prefix_mask,
            suffix_ids=suffix_ids,
            suffix_mask=suffix_mask,
            wav_segment=seg_wav,
            device=device,
            dtype=dtype,
            max_new_tokens=args.max_new_tokens,
            debug=args.debug,
            segment_label=f"segment {i}",
        )
        # Empty segments are skipped entirely (not saved, not timestamped).
        if seg_text:
            all_segments.append(seg_text)
            timestamped_segments.append((seg_start_s, seg_end_s, seg_text))
    # Terminate the streamed (end=" ") transcript line.
    print()
    print(f"\nGeneration complete ({len(all_segments)} transcribed segment(s)).")
    if args.output:
        generated_text = merge_transcript_segments(all_segments)
        out_path = args.output if args.output.endswith(".md") else args.output + ".md"
        transcript_body = generated_text
        if args.output_segment_timestamps:
            # Timestamped mode replaces the merged transcript with one
            # bracketed line per segment.
            transcript_body = "\n\n".join(
                f"[{format_timestamp(start)} - {format_timestamp(end)}] {text}"
                for start, end, text in timestamped_segments
            )
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(f"# Prompt\n\n{args.prompt}\n\n# Transcript\n\n{transcript_body}\n")
        print(f"Saved to: {out_path}")
# Run the CLI only when executed as a script (not on import).
if __name__ == "__main__":
    main()
"""
PyTorch reference implementation of GraniteSpeech.
The state-dict key hierarchy is meant to be *identical* to the HuggingFace
implementation, allowing weights to be loaded with:
model.load_state_dict(hf_state_dict, strict=True)
GraniteSpeechForConditionalGeneration
├── encoder : GraniteSpeechCTCEncoder (Conformer)
├── projector: GraniteSpeechEncoderProjector (Blip2 Q-Former + linear)
└── language_model: GraniteForCausalLM (LLaMA-style + Granite)
"""
import math
from dataclasses import dataclass, field
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class GraniteSpeechEncoderConfig:
    """Hyperparameters of the Conformer CTC speech encoder."""

    # Per-frame input feature size (presumably mel-filterbank features — confirm).
    input_dim: int = 160
    # Number of stacked Conformer blocks.
    num_layers: int = 16
    # Width of the hidden representation throughout the encoder.
    hidden_dim: int = 1024
    # Feed-forward expansion factor (inner dim = hidden_dim * feedforward_mult).
    feedforward_mult: int = 4
    # Attention heads per block.
    num_heads: int = 8
    # Per-head dimension (inner attention dim = num_heads * dim_head).
    dim_head: int = 128
    # Size of the mid-stack projection used by GraniteSpeechCTCEncoder.out.
    output_dim: int = 348
    # Block length for block-wise self-attention.
    context_size: int = 200
    # Relative-position table covers distances in [-max_pos_emb, max_pos_emb].
    max_pos_emb: int = 512
    # Dropout probability used across the encoder submodules.
    dropout: float = 0.1
    # Kernel size of the depthwise convolution in the conv module.
    conv_kernel_size: int = 15
    # Channel expansion factor inside the conv module.
    conv_expansion_factor: int = 2
@dataclass
class Blip2QFormerConfig:
    """Configuration of the BLIP-2 Q-Former used as the audio projector."""

    hidden_size: int = 1024
    num_hidden_layers: int = 2
    num_attention_heads: int = 16
    intermediate_size: int = 4096
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    attention_probs_dropout_prob: float = 0.1
    max_position_embeddings: int = 2048
    layer_norm_eps: float = 1e-12
    # A layer gets cross-attention when layer_idx % cross_attention_frequency == 0.
    cross_attention_frequency: int = 1
    # Width of the encoder states consumed by cross-attention key/value projections.
    encoder_hidden_size: int = 1024
    chunk_size_feed_forward: int = 0
    # False for GraniteSpeech: only the query path of the Q-Former is used.
    use_qformer_text_input: bool = False
@dataclass
class GraniteTextConfig:
    """Configuration of the Granite (LLaMA-style) causal language model."""

    vocab_size: int = 100353
    hidden_size: int = 2048
    intermediate_size: int = 4096
    num_hidden_layers: int = 40
    num_attention_heads: int = 16
    # Fewer KV heads than attention heads — presumably grouped-query attention;
    # confirm against the language-model implementation.
    num_key_value_heads: int = 4
    hidden_act: str = "silu"
    max_position_embeddings: int = 4096
    rms_norm_eps: float = 1e-5
    attention_bias: bool = False
    attention_dropout: float = 0.0
    mlp_bias: bool = False
    rope_theta: float = 10000.0
    # Granite-specific scaling knobs (embeddings / logits / residual stream /
    # attention) — values presumably mirror the published checkpoint config.
    embedding_multiplier: float = 12.0
    logits_scaling: float = 8.0
    residual_multiplier: float = 0.22
    attention_multiplier: float = 0.0078125
    pad_token_id: int = 100256
@dataclass
class GraniteSpeechConfig:
    """Top-level config bundling the encoder, projector and LLM sub-configs."""

    encoder_config: GraniteSpeechEncoderConfig = field(
        default_factory=GraniteSpeechEncoderConfig
    )
    projector_config: Blip2QFormerConfig = field(default_factory=Blip2QFormerConfig)
    text_config: GraniteTextConfig = field(default_factory=GraniteTextConfig)
    # Vocabulary id of the "<|audio|>" placeholder token.
    audio_token_id: int = 100352
    # NOTE(review): temporal downsampling between encoder and projector —
    # confirm where it is applied, usage is outside this file's visible code.
    downsample_rate: int = 5
    window_size: int = 15
    initializer_range: float = 0.02
    tie_word_embeddings: bool = False
    has_lora_adapter: bool = False
def _config_get(config: Any, key: str, default: Any = None) -> Any:
    """Fetch *key* from either a mapping or an attribute-style config object."""
    if not isinstance(config, dict):
        return getattr(config, key, default)
    return config.get(key, default)
def _strip_model_type(config_dict: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *config_dict* without the HF "model_type" marker key."""
    cleaned = dict(config_dict)
    cleaned.pop("model_type", None)
    return cleaned
def speech_config_from_dict(config_dict: dict[str, Any]) -> GraniteSpeechConfig:
    """Build a GraniteSpeechConfig from a plain (JSON-style) nested dict."""
    # Sub-configs may carry an HF "model_type" marker that the dataclasses
    # do not accept; strip it before unpacking.
    sub = {
        name: _strip_model_type(config_dict.get(name, {}))
        for name in ("encoder_config", "projector_config", "text_config")
    }
    token_id = config_dict.get(
        "audio_token_id", config_dict.get("audio_token_index", 100352)
    )
    return GraniteSpeechConfig(
        encoder_config=GraniteSpeechEncoderConfig(**sub["encoder_config"]),
        projector_config=Blip2QFormerConfig(**sub["projector_config"]),
        text_config=GraniteTextConfig(**sub["text_config"]),
        audio_token_id=token_id,
        downsample_rate=config_dict.get("downsample_rate", 5),
        window_size=config_dict.get("window_size", 15),
        initializer_range=config_dict.get("initializer_range", 0.02),
        tie_word_embeddings=config_dict.get("tie_word_embeddings", False),
        has_lora_adapter=config_dict.get("has_lora_adapter", False),
    )
def speech_config_from_hf_config(config: Any) -> GraniteSpeechConfig:
    """Convert a HuggingFace config object (or dict) into a GraniteSpeechConfig."""

    def as_plain_dict(value: Any) -> dict[str, Any]:
        # HF sub-config objects expose to_dict(); plain dicts pass through.
        if hasattr(value, "to_dict"):
            value = value.to_dict()
        return dict(value)

    merged = {
        "encoder_config": as_plain_dict(_config_get(config, "encoder_config", {})),
        "projector_config": as_plain_dict(_config_get(config, "projector_config", {})),
        "text_config": as_plain_dict(_config_get(config, "text_config", {})),
        "audio_token_id": _config_get(
            config, "audio_token_id", _config_get(config, "audio_token_index", 100352)
        ),
        "downsample_rate": _config_get(config, "downsample_rate", 5),
        "window_size": _config_get(config, "window_size", 15),
        "initializer_range": _config_get(config, "initializer_range", 0.02),
        "tie_word_embeddings": _config_get(config, "tie_word_embeddings", False),
        "has_lora_adapter": _config_get(config, "has_lora_adapter", False),
    }
    return speech_config_from_dict(merged)
class GraniteSpeechConformerFeedForward(nn.Module):
    """Feed-forward block of the Conformer: norm → up-projection → SiLU → down-projection.

    Submodule attribute names are part of the state-dict key contract and
    must not be renamed.
    """

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        expanded_dim = config.hidden_dim * config.feedforward_mult
        self.pre_norm = nn.LayerNorm(config.hidden_dim)
        self.up_proj = nn.Linear(config.hidden_dim, expanded_dim)
        self.silu = nn.SiLU()
        self.dropout = nn.Dropout(config.dropout)
        self.down_proj = nn.Linear(expanded_dim, config.hidden_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        normed = self.pre_norm(hidden_states)
        activated = self.dropout(self.silu(self.up_proj(normed)))
        return self.dropout(self.down_proj(activated))
class GraniteSpeechConformerAttention(nn.Module):
    """
    Block-wise self-attention with Shaw-style relative position biases.

    The time axis is split into non-overlapping blocks of `context_size`
    frames and attention is computed within each block only; a learned
    per-distance bias (from `rel_pos_emb`) is added to the attention scores.
    """

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        inner_dim = config.dim_head * config.num_heads
        self.max_pos_emb = config.max_pos_emb
        self.context_size = config.context_size
        self.num_heads = config.num_heads
        self.dim_head = config.dim_head
        self.scale = self.dim_head**-0.5
        self.pre_norm = nn.LayerNorm(config.hidden_dim)
        self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
        # Keys and values come from a single fused projection, split in forward.
        self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, config.hidden_dim)
        # One learned embedding per clamped relative distance in
        # [-max_pos_emb, max_pos_emb] (indexed via precomputed attention_dists).
        self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head)
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states)
        bsz, num_features, _ = hidden_states.shape
        num_blocks = math.ceil(num_features / self.context_size)
        remainder = num_features % self.context_size
        if remainder > 0:
            # Zero-pad the time axis so the sequence splits evenly into blocks.
            hidden_states = F.pad(
                hidden_states, (0, 0, 0, self.context_size - remainder)
            )
        query_states = self.to_q(hidden_states)
        key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)
        # (bsz, num_blocks, num_heads, context_size, dim_head)
        query_states = query_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
        key_states = key_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
        value_states = value_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
        # Shaw's relative positional embeddings via einsum — O(C²·D) memory.
        # query_states : (bsz, num_blocks, num_heads, context_size, dim_head)
        # rel_pos_emb : (context_size, context_size, dim_head)
        rel_pos_emb = self.rel_pos_emb(attention_dists)
        pos_attn = (
            torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb)
            * self.scale
        )
        if remainder > 0:
            # Mask out the padding region in the last block.
            mask = torch.ones(
                self.context_size,
                self.context_size,
                dtype=torch.bool,
                device=hidden_states.device,
            )
            mask[:remainder, :remainder] = False
            mask_value = -torch.finfo(pos_attn.dtype).max
            pos_attn[:, -1, :].masked_fill_(mask, mask_value)
        # The position bias (and padding mask) enters SDPA as an additive attn_mask.
        out = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=pos_attn,
            scale=self.scale,
        )
        # Merge blocks back into one padded sequence, then drop the padded tail.
        out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
        out = self.to_out(out[:, :num_features, :])
        return self.dropout(out)
class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
    """Depthwise 1-D convolution with manual padding that preserves sequence length.

    For even kernel sizes one padding element is dropped on the right so the
    output length still equals the input length.
    """

    def __init__(self, chan_in: int, chan_out: int, kernel_size: int):
        super().__init__()
        left = kernel_size // 2
        right = left - (kernel_size + 1) % 2
        self.padding = (left, right)
        self.conv = nn.Conv1d(
            chan_in, chan_out, kernel_size, groups=chan_in, bias=False
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.conv(F.pad(hidden_states, self.padding))
class GraniteSpeechConformerConvModule(nn.Module):
    """Conformer convolution module: pointwise + GLU → depthwise conv → BN → SiLU → pointwise.

    Submodule attribute names are part of the state-dict key contract and
    must not be renamed.
    """

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        expanded = config.hidden_dim * config.conv_expansion_factor
        self.norm = nn.LayerNorm(config.hidden_dim)
        # Doubled output channels feed the GLU gate on the channel axis.
        self.up_conv = nn.Conv1d(config.hidden_dim, expanded * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.depth_conv = GraniteSpeechConformerDepthWiseConv1d(
            expanded, expanded, kernel_size=config.conv_kernel_size
        )
        self.silu = nn.SiLU()
        self.batch_norm = nn.BatchNorm1d(expanded)
        self.down_conv = nn.Conv1d(expanded, config.hidden_dim, 1)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Conv1d wants channel-first tensors, so transpose around the conv stack.
        features = self.norm(hidden_states)
        features = self.up_conv(features.permute(0, 2, 1))
        features = self.glu(features)
        features = self.depth_conv(features)
        features = self.silu(self.batch_norm(features))
        features = self.down_conv(features).permute(0, 2, 1)
        return self.dropout(features)
class GraniteSpeechConformerBlock(nn.Module):
    """One Conformer layer: half-step FF, attention, conv module, half-step FF, final norm."""

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        self.ff1 = GraniteSpeechConformerFeedForward(config)
        self.attn = GraniteSpeechConformerAttention(config)
        self.conv = GraniteSpeechConformerConvModule(config)
        self.ff2 = GraniteSpeechConformerFeedForward(config)
        self.post_norm = nn.LayerNorm(config.hidden_dim)

    def forward(
        self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
    ) -> torch.Tensor:
        # Residual connections throughout; the feed-forwards contribute half-steps.
        states = hidden_states + 0.5 * self.ff1(hidden_states)
        states = states + self.attn(states, attention_dists=attention_dists)
        states = states + self.conv(states)
        states = states + 0.5 * self.ff2(states)
        return self.post_norm(states)
class GraniteSpeechCTCEncoder(nn.Module):
    """
    Conformer encoder stack with a mid-stack CTC-style auxiliary branch.

    Halfway through the layer stack the hidden states are projected to
    `output_dim`, softmaxed, projected back to `hidden_dim` and added
    residually before the remaining layers run.
    """

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        self.config = config
        # Precompute clamped pairwise relative distances for one attention
        # block, shifted by max_pos_emb so they index rel_pos_emb directly.
        seq = torch.arange(config.context_size)
        relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
        attention_dists = (
            torch.clamp(relpos_dist, -config.context_size, config.context_size)
            + config.max_pos_emb
        )
        # Non-persistent: recomputed on construction, never stored in checkpoints.
        self.register_buffer("attention_dists", attention_dists, persistent=False)
        self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True)
        self.layers = nn.ModuleList(
            [GraniteSpeechConformerBlock(config) for _ in range(config.num_layers)]
        )
        self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True)
        self.out_mid = nn.Linear(config.output_dim, config.hidden_dim, bias=True)
        self.num_layers = config.num_layers

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.input_linear(hidden_states)
        for idx, layer in enumerate(self.layers, start=1):
            hidden_states = layer(hidden_states, attention_dists=self.attention_dists)
            if idx == self.num_layers // 2:
                hidden_states_mid = self.out(hidden_states.clone())
                # Inline softmax+projection — no persistent self.softmax submodule,
                # matching HF exactly so state-dict keys are identical.
                hidden_states = hidden_states + self.out_mid(
                    nn.Softmax(dim=-1)(hidden_states_mid)
                )
        return hidden_states
class _Blip2QFormerMultiHeadAttention(nn.Module):
    """
    BERT-style multi-head attention supporting self- and cross-attention.

    In cross-attention mode (selected per-call by passing encoder states) the
    key/value projections read from the encoder hidden states, whose width is
    `encoder_hidden_size`; queries always come from the Q-Former hidden
    states. Returns ``(context, attention_probs)``.
    """

    def __init__(self, config: Blip2QFormerConfig, is_cross_attention: bool = False):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Keys/values project from the (possibly different-width) encoder states.
            self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
            self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def _transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (B, S, H*D) -> (B, H, S, D)
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(new_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ):
        # Presence of encoder states selects cross-attention for this call.
        is_cross = encoder_hidden_states is not None
        if is_cross:
            key_layer = self._transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self._transpose_for_scores(self.value(encoder_hidden_states))
            active_mask = encoder_attention_mask
        else:
            key_layer = self._transpose_for_scores(self.key(hidden_states))
            value_layer = self._transpose_for_scores(self.value(hidden_states))
            active_mask = attention_mask
        query_layer = self._transpose_for_scores(self.query(hidden_states))
        scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        if active_mask is not None:
            # Masks are additive biases (0 keeps a position, large negative drops it).
            scores = scores + active_mask
        probs = nn.Softmax(dim=-1)(scores)
        # Dropout on attention probabilities, as in the original BERT code.
        probs_dropped = self.dropout(probs).to(value_layer.dtype)
        ctx = torch.matmul(probs_dropped, value_layer)
        # (B, H, S, D) -> (B, S, H*D)
        ctx = ctx.permute(0, 2, 1, 3).contiguous()
        ctx = ctx.view(ctx.size()[:-2] + (self.all_head_size,))
        return ctx, probs
class _Blip2QFormerSelfOutput(nn.Module):
    """Post-attention projection with residual add and LayerNorm.

    The `LayerNorm` attribute keeps its HF capitalization so state-dict keys match.
    """

    def __init__(self, config: Blip2QFormerConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class _Blip2QFormerAttention(nn.Module):
    """Attention plus output projection, usable as self- or cross-attention."""

    def __init__(self, config: Blip2QFormerConfig, is_cross_attention: bool = False):
        super().__init__()
        self.attention = _Blip2QFormerMultiHeadAttention(config, is_cross_attention)
        self.output = _Blip2QFormerSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The attention probabilities (second element) are discarded here.
        context, _probs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        return self.output(context, hidden_states)
class _Blip2QFormerIntermediate(nn.Module):
    """Feed-forward expansion with GELU activation."""

    def __init__(self, config: Blip2QFormerConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # HF resolves ACT2FN["gelu"] to the same operation as nn.GELU().
        self.intermediate_act_fn = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)
class _Blip2QFormerOutput(nn.Module):
    """Feed-forward contraction with residual add and LayerNorm.

    The `LayerNorm` attribute keeps its HF capitalization so state-dict keys match.
    """

    def __init__(self, config: Blip2QFormerConfig):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
class _Blip2QFormerLayer(nn.Module):
    """
    Single Q-Former layer: self-attention over the query tokens, optional
    cross-attention into the encoder states, then a feed-forward applied to
    the query path.
    """

    def __init__(self, config: Blip2QFormerConfig, layer_idx: int):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = _Blip2QFormerAttention(config)
        self.layer_idx = layer_idx
        # Cross-attention appears every `cross_attention_frequency` layers.
        self.has_cross_attention = layer_idx % config.cross_attention_frequency == 0
        if self.has_cross_attention:
            self.crossattention = _Blip2QFormerAttention(
                config, is_cross_attention=True
            )
        # `use_qformer_text_input` is False for GraniteSpeech — only query path.
        self.intermediate_query = _Blip2QFormerIntermediate(config)
        self.output_query = _Blip2QFormerOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        query_length: int = 0,
    ) -> torch.Tensor:
        attention_output = self.attention(hidden_states, attention_mask=attention_mask)
        if query_length > 0:
            # Only the first `query_length` positions are query tokens; any
            # trailing positions would be text tokens (unused for GraniteSpeech).
            query_attn_out = attention_output[:, :query_length, :]
            if self.has_cross_attention:
                if encoder_hidden_states is None:
                    raise ValueError(
                        "encoder_hidden_states required for cross-attention"
                    )
                query_attn_out = self.crossattention(
                    query_attn_out,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                )
            layer_output = self._ff_chunk_query(query_attn_out)
            if attention_output.shape[1] > query_length:
                # text tokens present — not used in GraniteSpeech inference
                layer_output = torch.cat(
                    [layer_output, attention_output[:, query_length:, :]], dim=1
                )
        else:
            layer_output = self._ff_chunk_query(attention_output)
        return layer_output

    def _ff_chunk_query(self, attention_output: torch.Tensor) -> torch.Tensor:
        # Feed-forward (expand + contract with residual) over the query path.
        intermediate = self.intermediate_query(attention_output)
        return self.output_query(intermediate, attention_output)
class _Blip2QFormerEncoder(nn.Module):
    """Sequential stack of Q-Former layers (attribute `layer` matches HF keys)."""

    def __init__(self, config: Blip2QFormerConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [_Blip2QFormerLayer(config, idx) for idx in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        query_length: int = 0,
    ) -> torch.Tensor:
        states = hidden_states
        for block in self.layer:
            states = block(
                states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                query_length=query_length,
            )
        return states
class Blip2QFormerModel(nn.Module):
    """
    Drop-in replacement for transformers.Blip2QFormerModel.

    Child-module names (``layernorm``, ``dropout``, ``encoder``) deliberately
    match the HF implementation so that state-dict keys are identical.
    """

    def __init__(self, config: "Blip2QFormerConfig"):
        super().__init__()
        self.config = config
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.encoder = _Blip2QFormerEncoder(config)

    # Helpers mirroring Blip2QFormerModel.get_extended_attention_mask and
    # invert_attention_mask from transformers.
    @staticmethod
    def _get_extended_attention_mask(
        attention_mask: torch.Tensor, dtype: torch.dtype
    ) -> torch.Tensor:
        """(B, S) or (B, 1, S, S) 0/1 mask → additive mask broadcastable to (B, H, S, S)."""
        ndim = attention_mask.dim()
        if ndim == 2:
            extended = attention_mask[:, None, None, :]
        elif ndim == 3:
            extended = attention_mask[:, None, :, :]
        else:
            extended = attention_mask
        # HF convention for the self-attention mask: masked positions get -10000
        # (not the dtype minimum, which is used for the cross-attention mask).
        return (1.0 - extended.to(dtype=dtype)) * -10000.0

    @staticmethod
    def _invert_attention_mask(
        encoder_attention_mask: torch.Tensor, dtype: torch.dtype
    ) -> torch.Tensor:
        """Convert a 0/1 key-padding mask to an additive bias (masked → dtype min)."""
        ndim = encoder_attention_mask.dim()
        if ndim == 3:
            inverted = encoder_attention_mask[:, None, :, :]
        elif ndim == 2:
            inverted = encoder_attention_mask[:, None, None, :]
        else:
            inverted = encoder_attention_mask
        return (1.0 - inverted.to(dtype=dtype)) * torch.finfo(dtype).min

    def forward(
        self,
        query_embeds: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        **_kwargs,
    ) -> "Blip2QFormerOutput":
        """Cross-attend the learned queries over the encoder hidden states."""
        # The Q-Former runs in the layernorm's parameter dtype (fp32 by HF convention).
        dtype = self.layernorm.weight.dtype
        embedded = self.dropout(self.layernorm(query_embeds.to(dtype)))
        bsz, n_queries = embedded.shape[:2]

        # Self-attention: every query may attend to every other query.
        self_mask = self._get_extended_attention_mask(
            torch.ones((bsz, n_queries), device=embedded.device), dtype
        )

        enc_hs = None
        cross_mask = None
        if encoder_hidden_states is not None:
            enc_hs = encoder_hidden_states.to(dtype)
            if encoder_attention_mask is None:
                # No padding info supplied: treat every encoder token as valid.
                encoder_attention_mask = torch.ones(
                    enc_hs.shape[:2], device=embedded.device
                )
            cross_mask = self._invert_attention_mask(encoder_attention_mask, dtype)

        sequence_output = self.encoder(
            embedded,
            attention_mask=self_mask,
            encoder_hidden_states=enc_hs,
            encoder_attention_mask=cross_mask,
            query_length=n_queries,
        )
        return Blip2QFormerOutput(last_hidden_state=sequence_output)
class Blip2QFormerOutput:
    """Minimal output container (mirrors BaseModelOutputWithPoolingAndCrossAttentions).

    Only ``last_hidden_state`` is consumed downstream, so the full HF
    ModelOutput machinery is intentionally omitted; ``__slots__`` keeps the
    instance lightweight.
    """

    __slots__ = ("last_hidden_state",)

    def __init__(self, last_hidden_state: torch.Tensor):
        self.last_hidden_state = last_hidden_state

    def __repr__(self) -> str:
        # Added for debuggability; shape-only to keep the repr readable.
        return (
            f"{type(self).__name__}"
            f"(last_hidden_state=<tensor shape={tuple(self.last_hidden_state.shape)}>)"
        )
class GraniteSpeechEncoderProjector(nn.Module):
    """Window the CTC-encoder output and compress each window with a Q-Former.

    Each ``window_size`` slice of encoder frames is distilled into
    ``window_size // downsample_rate`` learned query vectors, which are then
    projected into the language model's hidden size.
    """

    def __init__(self, config: "GraniteSpeechConfig"):
        super().__init__()
        self.downsample_rate = config.downsample_rate
        self.window_size = config.window_size
        self.num_queries = config.window_size // config.downsample_rate
        # normal_(mean=0, std=1) initialisation matches the HF reference.
        self.query = nn.Parameter(
            torch.empty(
                1, self.num_queries, config.projector_config.hidden_size
            ).normal_(0.0, 1.0)
        )
        self.qformer = Blip2QFormerModel(config.projector_config)
        self.linear = nn.Linear(
            config.projector_config.hidden_size, config.text_config.hidden_size
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """(B, S, D) encoder frames → (B, ceil(S/window) * num_queries, text_hidden)."""
        bsz, seq_len, dim = hidden_states.size()
        nblocks = math.ceil(seq_len / self.window_size)
        # Zero-pad the tail so the sequence splits evenly into windows.
        padding = nblocks * self.window_size - seq_len
        padded = F.pad(hidden_states, (0, 0, 0, padding), "constant", 0)
        windows = padded.view(bsz * nblocks, self.window_size, dim)
        qformer_out = self.qformer(
            query_embeds=self.query,
            encoder_hidden_states=windows,
            encoder_attention_mask=None,
        )
        # Re-assemble per-window queries into one sequence per batch element.
        compressed = qformer_out.last_hidden_state.view(
            bsz, nblocks * self.window_size // self.downsample_rate, -1
        )
        return self.linear(compressed)
class GraniteRMSNorm(nn.Module):
    """Root-mean-square LayerNorm (no mean subtraction, no bias), computed in fp32."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise over the last dim in fp32, cast back, then scale by weight."""
        orig_dtype = hidden_states.dtype
        x = hidden_states.float()
        mean_square = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * x.to(orig_dtype)
class GraniteRotaryEmbedding(nn.Module):
    """Standard (unscaled) rotary position embedding for the Granite decoder."""

    def __init__(self, config: "GraniteTextConfig"):
        super().__init__()
        head_dim = config.hidden_size // config.num_attention_heads
        exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
        inv_freq = 1.0 / (config.rope_theta ** exponents)
        # Non-persistent: recomputed from config, never saved in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(
        self, x: torch.Tensor, position_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables of shape (batch, seq_len, head_dim) in x's dtype."""
        inv_freq = self.inv_freq
        assert isinstance(inv_freq, torch.Tensor)
        bsz = position_ids.shape[0]
        freq_col = (
            inv_freq.unsqueeze(0)
            .unsqueeze(-1)
            .float()
            .expand(bsz, -1, 1)
            .to(x.device)
        )
        pos_row = position_ids[:, None, :].float()
        # Compute the outer product in fp32 even under autocast.
        with torch.autocast(device_type=x.device.type, enabled=False):
            freqs = (freq_col.float() @ pos_row.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(x.dtype), sin.to(x.dtype)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.unsqueeze(1) # (B, 1, S, D)
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (_rotate_half(q) * sin)
k_embed = (k * cos) + (_rotate_half(k) * sin)
return q_embed, k_embed
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""Expand GQA key/value heads to match the number of query heads."""
if n_rep == 1:
return hidden_states
batch, num_kv_heads, slen, head_dim = hidden_states.shape
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_kv_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
class GraniteAttention(nn.Module):
    """Multi-head attention with grouped-query KV heads and Granite's fixed scale.

    Uses an explicit matmul/softmax implementation (no SDPA/flash kernels) so
    results are reproducible across devices.
    """

    def __init__(self, config: "GraniteTextConfig", layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # Granite replaces 1/sqrt(head_dim) with a config-driven multiplier.
        self.scaling = config.attention_multiplier
        self.attention_dropout = config.attention_dropout
        bias = config.attention_bias
        self.q_proj = nn.Linear(
            config.hidden_size, self.num_heads * self.head_dim, bias=bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, self.num_key_value_heads * self.head_dim, bias=bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, self.num_key_value_heads * self.head_dim, bias=bias
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, config.hidden_size, bias=bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional["KVCache"] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, None]:
        """Return (attn_output, None); the trailing None keeps the HF tuple shape."""
        bsz, seq, _ = hidden_states.shape
        split_shape = (bsz, seq, -1, self.head_dim)
        q = self.q_proj(hidden_states).view(split_shape).transpose(1, 2)
        k = self.k_proj(hidden_states).view(split_shape).transpose(1, 2)
        v = self.v_proj(hidden_states).view(split_shape).transpose(1, 2)

        cos, sin = position_embeddings
        q, k = _apply_rotary_pos_emb(q, k, cos, sin)

        if past_key_values is not None:
            # Append this step's keys/values to the running cache for this layer.
            k, v = past_key_values.update(k, v, self.layer_idx, cache_position)

        k = _repeat_kv(k, self.num_key_value_groups)
        v = _repeat_kv(v, self.num_key_value_groups)

        scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
        if attention_mask is not None:
            scores = scores + attention_mask
        # Softmax in fp32 for numerical stability, then back to the compute dtype.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        probs = F.dropout(probs, p=self.attention_dropout, training=self.training)

        context = torch.matmul(probs, v)
        context = context.transpose(1, 2).contiguous().reshape(bsz, seq, -1)
        return self.o_proj(context), None
class GraniteMLP(nn.Module):
    """Gated SiLU feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(self, config: "GraniteTextConfig"):
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        bias = config.mlp_bias
        self.gate_proj = nn.Linear(hidden, inner, bias=bias)
        self.up_proj = nn.Linear(hidden, inner, bias=bias)
        self.down_proj = nn.Linear(inner, hidden, bias=bias)
        # Granite's hidden_act is always "silu".
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class GraniteDecoderLayer(nn.Module):
    """Pre-norm transformer block with Granite's scaled residual connections."""

    def __init__(self, config: "GraniteTextConfig", layer_idx: int):
        super().__init__()
        self.self_attn = GraniteAttention(config, layer_idx=layer_idx)
        self.mlp = GraniteMLP(config)
        self.input_layernorm = GraniteRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = GraniteRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        # Both residual branches are scaled by this factor (Granite architecture).
        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_values: Optional["KVCache"] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-block (pre-norm).
        attn_in = self.input_layernorm(hidden_states)
        attn_out, _ = self.self_attn(
            attn_in,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            cache_position=cache_position,
        )
        hidden_states = hidden_states + attn_out * self.residual_multiplier

        # Feed-forward sub-block (pre-norm).
        ff_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + ff_out * self.residual_multiplier
class KVCache:
    """Simple dynamic key-value cache, interface-compatible with the HF DynamicCache."""

    def __init__(self):
        # layer_idx -> cached key/value tensors, concatenated along dim 2 (sequence).
        self._keys: dict[int, torch.Tensor] = {}
        self._values: dict[int, torch.Tensor] = {}

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Append new states along the sequence dim; return the full cache for this layer.

        ``cache_position`` is accepted for interface parity but unused: states
        are always appended at the end.
        """
        existing_keys = self._keys.get(layer_idx)
        if existing_keys is None:
            self._keys[layer_idx] = key_states
            self._values[layer_idx] = value_states
        else:
            self._keys[layer_idx] = torch.cat((existing_keys, key_states), dim=2)
            self._values[layer_idx] = torch.cat(
                (self._values[layer_idx], value_states), dim=2
            )
        return self._keys[layer_idx], self._values[layer_idx]

    def get_seq_length(self) -> int:
        """Sequence length of the first cached layer, or 0 when the cache is empty."""
        for cached in self._keys.values():
            return cached.shape[2]
        return 0
def _make_causal_mask(
seq_len: int,
past_len: int,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
"""
Upper-triangular additive causal mask of shape (1, 1, seq_len, seq_len + past_len).
Positions that should be masked get -inf; attended positions get 0.
"""
total_len = seq_len + past_len
mask = torch.full((seq_len, total_len), torch.finfo(dtype).min, device=device)
# Each query position i can attend to keys 0 .. (past_len + i)
attend_cols = torch.arange(total_len, device=device)
query_rows = torch.arange(past_len, past_len + seq_len, device=device)
causal = attend_cols[None, :] <= query_rows[:, None]
mask = mask.masked_fill(causal, 0.0)
return mask.unsqueeze(0).unsqueeze(0) # (1, 1, S, S+past)
class GraniteModel(nn.Module):
    """Granite decoder stack: scaled token embeddings → N decoder layers → final RMSNorm."""

    def __init__(self, config: GraniteTextConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=self.padding_idx
        )
        self.layers = nn.ModuleList(
            [GraniteDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.norm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = GraniteRotaryEmbedding(config)
        # Granite multiplies token embeddings by this factor before layer 0.
        self.embedding_multiplier = config.embedding_multiplier

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding table."""
        return self.embed_tokens

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[KVCache] = None,
        cache_position: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return the final hidden states.

        At least one of input_ids / inputs_embeds must be given; when both are
        passed, inputs_embeds takes precedence. attention_mask, if provided as
        a 2-D 0/1 tensor, is treated as a padding mask over the full
        (past + current) key length and folded into the causal mask.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either input_ids or inputs_embeds must be provided")
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None
        inputs_embeds = inputs_embeds * self.embedding_multiplier
        bsz, seq_len, _ = inputs_embeds.shape
        past_len = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        # Absolute positions of the current tokens, continuing past the cache.
        if cache_position is None:
            cache_position = torch.arange(
                past_len, past_len + seq_len, device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        # Build additive causal mask, then fold in any padding mask
        causal_mask = _make_causal_mask(
            seq_len, past_len, inputs_embeds.dtype, inputs_embeds.device
        )
        if attention_mask is not None and attention_mask.dim() == 2:
            # attention_mask is 1 for real tokens, 0 for padding (shape: B, total_len)
            # Expand to (B, 1, seq_len, total_len) additive form
            pad_mask = (
                1.0 - attention_mask[:, None, None, :].to(inputs_embeds.dtype)
            ) * torch.finfo(inputs_embeds.dtype).min
            # pad_mask covers [0 .. total_len]; causal_mask covers the same range
            causal_mask = causal_mask + pad_mask
        # (cos, sin) tables shared by all layers for this forward pass.
        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                past_key_values=past_key_values,
                cache_position=cache_position,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
class GraniteForCausalLM(nn.Module):
    """Granite decoder with an LM head; logits are divided by config.logits_scaling."""

    def __init__(self, config: "GraniteTextConfig"):
        super().__init__()
        self.config = config
        self.model = GraniteModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def get_input_embeddings(self) -> nn.Embedding:
        """Expose the underlying token embedding table."""
        return self.model.get_input_embeddings()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional["KVCache"] = None,
        cache_position: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
        """Return {"loss", "logits", "past_key_values"}; loss is None without labels."""
        hidden = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            cache_position=cache_position,
            position_ids=position_ids,
        )
        # Granite scales logits down by a config-driven constant.
        logits = self.lm_head(hidden) / self.config.logits_scaling

        loss = None
        if labels is not None:
            # Standard next-token shift: position t predicts token t + 1.
            shifted_logits = logits[..., :-1, :].contiguous()
            shifted_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shifted_logits.view(-1, self.vocab_size),
                shifted_labels.view(-1).to(shifted_logits.device),
            )
        return {"loss": loss, "logits": logits, "past_key_values": past_key_values}
# Main class for inference
class GraniteSpeechForConditionalGeneration(nn.Module):
    """Speech-to-text model: CTC audio encoder → Q-Former projector → Granite LLM.

    Audio placeholder tokens in ``input_ids`` are replaced, in embedding space,
    by the projected audio features before the language model runs.
    """

    def __init__(self, config: GraniteSpeechConfig):
        super().__init__()
        self.config = config
        # Token id that marks audio placeholder positions in input_ids.
        self.audio_token_index = config.audio_token_id
        self.encoder = GraniteSpeechCTCEncoder(config.encoder_config)
        self.projector = GraniteSpeechEncoderProjector(config)
        self.language_model = GraniteForCausalLM(config.text_config)

    def get_input_embeddings(self) -> nn.Embedding:
        """Expose the LLM token-embedding table (used when merging audio embeds)."""
        return self.language_model.get_input_embeddings()

    def get_audio_features(self, input_features: torch.Tensor) -> torch.Tensor:
        """Encode audio features and project them into the LLM hidden space."""
        encoder_embeds = self.encoder(input_features)
        return self.projector(encoder_embeds)

    def get_merged_audio_embeddings(
        self,
        input_ids: torch.Tensor,
        audio_features: torch.Tensor,
        input_features_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and overwrite audio-token positions with audio features.

        NOTE(review): masked_scatter fills masked positions from audio_features
        in order; the caller is presumably responsible for making the number of
        placeholder tokens match the number of selected feature vectors —
        confirm against the processor.
        """
        is_audio_idx = input_ids == self.audio_token_index
        # Clamp audio positions to 0 so the LLM embedding lookup never goes OOV
        llm_input_ids = torch.where(
            is_audio_idx, torch.zeros_like(input_ids), input_ids
        )
        inputs_embeds = self.get_input_embeddings()(llm_input_ids)
        # Match the text embeddings' device/dtype before scattering.
        audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
        if input_features_mask is not None:
            # Keep only the feature rows selected by the mask.
            audio_features = audio_features[input_features_mask]
        special_audio_mask = is_audio_idx.unsqueeze(-1).expand_as(inputs_embeds)
        inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        input_features: Optional[torch.Tensor] = None,
        input_features_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[KVCache] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
        """Run the full speech LM; returns the language model's output dict."""
        model_dtype = next(self.parameters()).dtype
        if input_features is not None:
            input_features = input_features.to(model_dtype)
            audio_embeds = self.get_audio_features(input_features)
            inputs_embeds = self.get_merged_audio_embeddings(
                input_ids=input_ids,
                audio_features=audio_embeds,
                input_features_mask=input_features_mask,
            )
        else:
            # Text-only path: no audio to merge, plain token embeddings.
            inputs_embeds = self.get_input_embeddings()(input_ids)
        return self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            labels=labels,
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment