Skip to content

Instantly share code, notes, and snippets.

@francescopapaleo
Last active February 22, 2024 16:43
Show Gist options
  • Select an option

  • Save francescopapaleo/cc952bceb113a8a49d35b655d9616adc to your computer and use it in GitHub Desktop.

Select an option

Save francescopapaleo/cc952bceb113a8a49d35b655d9616adc to your computer and use it in GitHub Desktop.
A class to generate various types of audio signals directly as PyTorch tensors.
"""
Torch Signal Generator Class
Copyright (C) 2024 Francesco Papaleo
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import functools
import logging
from typing import List, Optional, Union

import sounddevice as sd
import torch
import torchaudio
def dbfs_to_amp(dbfs: float = -9) -> float:
    """
    Map a level in decibels relative to full scale (dBFS) onto a linear gain.

    Uses the amplitude convention of 20 dB per decade, so 0 dBFS maps to 1.0
    and each further -20 dB divides the amplitude by ten.

    Parameters
    ----------
    dbfs : float
        Level in dBFS. Defaults to -9.

    Returns
    -------
    float
        The corresponding linear amplitude factor.
    """
    decades = dbfs / 20
    return 10 ** decades
class TorchSignalGenerator:
    """
    A class to generate various types of audio signals directly as PyTorch tensors.

    Each ``gen_*`` method is wrapped by ``auto_place_signal``. The raw
    waveform is generated first. It is then placed inside a buffer of
    ``total_duration`` seconds, with silence before and after, and shaped
    according to ``channel_mode``.

    Attributes
    ----------
    sample_rate : int
        The sample rate of the generated signals.
    channel_mode : str
        The channel mode of the generated signals, either "mono" (output shape
        ``[1, num_samples]``) or "stereo" (``[2, num_samples]``). Any other
        value leaves the placed signal one-dimensional.
    """

    def __init__(self, sample_rate: int = 48000, channel_mode: str = "mono"):
        self.sample_rate = sample_rate
        self.channel_mode = channel_mode

    def place_signal(
        self,
        signal: torch.Tensor,
        start_time: float,
        signal_duration: float,
        total_duration: float,
    ) -> torch.Tensor:
        """
        Place a 1-D signal within a buffer of ``total_duration`` seconds.

        The signal is padded with silence before and after it. If the signal
        would run past the end of the buffer, a warning is logged and the
        signal is truncated. The result therefore always contains exactly
        ``int(sample_rate * total_duration)`` samples per channel.

        Parameters
        ----------
        signal : torch.Tensor
            The generated 1-D signal to place.
        start_time : float
            Start time in seconds of the signal within the buffer. Must be
            non-negative.
        signal_duration : float
            Nominal duration of the signal in seconds. Kept for backward
            compatibility only: the actual length of ``signal`` determines
            placement, so a mismatch can no longer produce a buffer of the
            wrong length.
        total_duration : float
            Total duration in seconds of the resulting buffer, including
            silence.

        Returns
        -------
        torch.Tensor
            The placed signal, reshaped for ``channel_mode``. The silence has
            the same dtype and device as ``signal``.
        """
        total_samples = int(self.sample_rate * total_duration)
        start_samples = int(self.sample_rate * start_time)

        # Samples by which the signal would overrun the buffer end.
        excess = start_samples + signal.shape[0] - total_samples
        if excess > 0:
            logging.warning(
                "Signal duration exceeds total duration by %s seconds; truncating signal",
                excess / self.sample_rate,
            )

        # A start beyond the end of the buffer leaves no room for the signal.
        start_samples = min(start_samples, total_samples)
        available = total_samples - start_samples
        signal = signal[:available]

        # new_zeros matches the signal's dtype and device, so GPU tensors work.
        initial_silence = signal.new_zeros(start_samples)
        final_silence = signal.new_zeros(available - signal.shape[0])
        placed_signal = torch.cat([initial_silence, signal, final_silence], dim=0)

        if self.channel_mode == "mono":  # [1, num_samples]
            placed_signal = placed_signal.unsqueeze(0)
        elif self.channel_mode == "stereo":  # [2, num_samples]
            # Duplicate the signal for both channels.
            placed_signal = torch.stack([placed_signal, placed_signal], dim=0)
        return placed_signal

    def auto_place_signal(func):
        """
        Decorator: generate a raw signal with ``func`` and place it in time.

        The wrapper consumes the keyword arguments ``start_time`` (default 0 s)
        and ``total_duration`` (default 5 s). Every other argument, including
        ``signal_duration``, is forwarded to ``func``. The generator therefore
        honours the requested duration, and uses its own default when none is
        given.
        """

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            start_time = kwargs.pop("start_time", 0)
            total_duration = kwargs.pop("total_duration", 5)
            signal = func(self, *args, **kwargs)
            # squeeze() drops singleton dims. atleast_1d stops a one-sample
            # signal from collapsing to a 0-d tensor, which torch.cat rejects.
            signal = torch.atleast_1d(signal.squeeze())
            signal_duration = signal.shape[0] / self.sample_rate
            return self.place_signal(signal, start_time, signal_duration, total_duration)

        return wrapper

    def _time_axis(self, signal_duration: float) -> torch.Tensor:
        """Return float64 sample times ``n / sample_rate`` for the given duration."""
        num_samples = int(self.sample_rate * signal_duration)
        # Exact 1/sample_rate spacing. linspace(0, duration, N) would include
        # the endpoint and stretch the grid by N/(N-1), detuning every
        # waveform. float64 keeps the phase accurate for long signals.
        return torch.arange(num_samples, dtype=torch.float64) / self.sample_rate

    @auto_place_signal
    def gen_impulse(self, signal_duration: float = 0.001, amplitude_dbfs: float = -1) -> torch.Tensor:
        """Generate an impulse: the first sample at ``amplitude_dbfs``, the rest zeros."""
        amplitude = dbfs_to_amp(amplitude_dbfs)
        # Always allocate at least one sample so the impulse itself fits.
        num_samples = max(1, int(self.sample_rate * signal_duration))
        impulse = torch.zeros(num_samples)
        impulse[0] = amplitude
        return impulse

    @auto_place_signal
    def gen_sine(
        self, frequency: float = 440, signal_duration: float = 1, amplitude_dbfs: float = -1
    ) -> torch.Tensor:
        """Generate a sine wave of ``frequency`` Hz with peak level ``amplitude_dbfs``."""
        amplitude = dbfs_to_amp(amplitude_dbfs)
        t = self._time_axis(signal_duration)
        sine_wave = amplitude * torch.sin(2 * torch.pi * frequency * t)
        # Computed in float64; return in the default dtype as before.
        return sine_wave.to(torch.get_default_dtype())

    @auto_place_signal
    def gen_sawtooth(
        self, frequency: float = 440, signal_duration: float = 1, amplitude_dbfs: float = -1
    ) -> torch.Tensor:
        """Generate a rising sawtooth wave spanning ``[-amplitude, amplitude)``."""
        amplitude = dbfs_to_amp(amplitude_dbfs)
        # Fractional position within each period, in [0, 1).
        phase = (self._time_axis(signal_duration) * frequency) % 1.0
        sawtooth_wave = 2 * amplitude * (phase - 0.5)
        return sawtooth_wave.to(torch.get_default_dtype())

    @auto_place_signal
    def gen_square(
        self,
        frequency: float = 440,
        signal_duration: float = 1,
        amplitude_dbfs: float = -1,
        duty_cycle: float = 0.5,
    ) -> torch.Tensor:
        """
        Generate a square wave.

        Each period is ``+amplitude`` for the first ``duty_cycle`` fraction
        and ``-amplitude`` for the remainder.
        """
        amplitude = dbfs_to_amp(amplitude_dbfs)
        # Fractional position within each period, in [0, 1).
        phase = (self._time_axis(signal_duration) * frequency) % 1.0
        # Map the boolean "high" mask to +1/-1. This avoids torch.where with
        # two Python scalars, which older torch versions reject.
        polarity = (phase < duty_cycle).to(phase.dtype) * 2 - 1
        square_wave = amplitude * polarity
        return square_wave.to(torch.get_default_dtype())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment