Last active
February 22, 2024 16:43
-
-
Save francescopapaleo/cc952bceb113a8a49d35b655d9616adc to your computer and use it in GitHub Desktop.
A class to generate various types of audio signals directly as PyTorch tensors.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Torch Signal Generator Class | |
| Copyright (C) 2024 Francesco Papaleo | |
| This program is free software: you can redistribute it and/or modify | |
| it under the terms of the GNU General Public License as published by | |
| the Free Software Foundation, either version 3 of the License, or | |
| (at your option) any later version. | |
| This program is distributed in the hope that it will be useful, | |
| but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| GNU General Public License for more details. | |
| You should have received a copy of the GNU General Public License | |
| along with this program. If not, see <https://www.gnu.org/licenses/>. | |
| """ | |
| import torch | |
| import torchaudio | |
| import logging | |
| import sounddevice as sd | |
| from typing import List, Optional, Union | |
def dbfs_to_amp(dbfs: float = -9) -> float:
    """
    Map a level expressed in dBFS onto a linear amplitude factor.

    The conversion follows the usual amplitude relation
    ``amplitude = 10 ** (dbfs / 20)``, so 0 dBFS maps to 1.0 and every
    -20 dB step divides the amplitude by ten.

    Parameters
    ----------
    dbfs : float
        Level in decibels relative to full scale. Defaults to -9.

    Returns
    -------
    float
        The linear amplitude corresponding to ``dbfs``.
    """
    # Amplitude ratios use a divisor of 20 (power ratios would use 10).
    exponent = dbfs / 20
    return 10 ** exponent
class TorchSignalGenerator:
    """
    A class to generate various types of audio signals directly as PyTorch tensors.

    Each ``gen_*`` method renders a one-dimensional waveform; the
    ``auto_place_signal`` decorator then embeds that waveform in a silent
    buffer via ``place_signal``. The decorated generators therefore also
    accept the keyword arguments ``start_time`` (default 0) and
    ``total_duration`` (default 5), both in seconds.

    Attributes
    ----------
    sample_rate : int
        The sample rate of the generated signals.
    channel_mode : str
        The channel mode of the generated signals, either "mono" or "stereo".
        Any other value leaves the placed signal one-dimensional.
    """

    def __init__(
        self,
        sample_rate: int = 48000,
        channel_mode: str = "mono",
    ):
        self.sample_rate = sample_rate
        self.channel_mode = channel_mode

    def place_signal(
        self,
        signal: torch.Tensor,
        start_time: float,
        signal_duration: float,
        total_duration: float,
    ) -> torch.Tensor:
        """
        Places a generated signal within a specified total duration with silence before and after.

        The result always holds exactly ``int(sample_rate * total_duration)``
        samples per channel. A signal that would run past the end of the
        buffer is truncated and a warning is logged.

        Parameters
        ----------
        signal : torch.Tensor
            The generated signal to place. It is flattened to one dimension.
        start_time : float
            The start time in seconds to place the signal within the total duration.
        signal_duration : float
            The nominal duration of the signal in seconds. Kept for interface
            compatibility; the actual length of ``signal`` is authoritative.
        total_duration : float
            The total duration of the resulting signal including silence in seconds.

        Returns
        -------
        torch.Tensor
            The signal placed within the total duration, adjusted for the
            specified channel mode: ``[1, num_samples]`` for "mono",
            ``[2, num_samples]`` for "stereo", ``[num_samples]`` otherwise.

        Raises
        ------
        ValueError
            If ``start_time`` or ``total_duration`` is negative.
        """
        if start_time < 0 or total_duration < 0:
            raise ValueError("start_time and total_duration must be non-negative")

        total_samples = int(self.sample_rate * total_duration)
        start_samples = int(self.sample_rate * start_time)

        # Flatten so that one-sample (0-dim) and [1, N] inputs concatenate
        # cleanly with the one-dimensional silence buffers.
        signal = signal.reshape(-1)
        num_signal_samples = signal.shape[0]

        # Ensure the signal is placed within the total duration
        overflow = start_samples + num_signal_samples - total_samples
        if overflow > 0:
            logging.warning(
                "Signal duration exceeds total duration by %s seconds",
                overflow / self.sample_rate,
            )
            # Keep only the part that fits instead of requesting a
            # negative-sized silence tensor.
            signal = signal[: max(num_signal_samples - overflow, 0)]

        # Place the signal within the total duration
        initial_silence = torch.zeros(
            min(start_samples, total_samples), device=signal.device
        )
        final_silence = torch.zeros(
            total_samples - initial_silence.shape[0] - signal.shape[0],
            device=signal.device,
        )
        placed_signal = torch.cat([initial_silence, signal, final_silence], dim=0)

        # Adjust the signal based on the channel mode
        if self.channel_mode == "mono":  # [1, num_samples]
            placed_signal = placed_signal.unsqueeze(0)
        elif self.channel_mode == "stereo":  # [2, num_samples]
            # Duplicate the signal for both channels
            placed_signal = torch.stack([placed_signal, placed_signal], dim=0)
        return placed_signal

    def auto_place_signal(func):
        """
        Decorator that routes a generator's output through ``place_signal``.

        ``start_time`` and ``total_duration`` are consumed from the keyword
        arguments; everything else, including ``signal_duration``, is
        forwarded to the generator so that the rendered length and the
        placement always agree.
        """
        # Local import keeps the class self-contained without touching the
        # module-level import block.
        import functools

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Extract or default the placement parameters
            start_time = kwargs.pop("start_time", 0)
            total_duration = kwargs.pop("total_duration", 5)

            signal = func(self, *args, **kwargs).reshape(-1)  # Generate the signal
            # Derive the duration from what was actually rendered rather than
            # from a separately tracked default.
            signal_duration = signal.shape[0] / self.sample_rate
            return self.place_signal(
                signal, start_time, signal_duration, total_duration
            )

        return wrapper

    def _time_axis(self, signal_duration: float) -> torch.Tensor:
        """Return sample instants ``k / sample_rate`` covering ``signal_duration`` seconds."""
        num_samples = int(self.sample_rate * signal_duration)
        # arange / sample_rate keeps the spacing at exactly one sample period;
        # linspace would include the end point and stretch the spacing.
        return torch.arange(num_samples) / self.sample_rate

    @auto_place_signal
    def gen_impulse(
        self, signal_duration: float = 0.001, amplitude_dbfs: float = -1
    ) -> torch.Tensor:
        """
        Generates an impulse signal.

        Parameters
        ----------
        signal_duration : float
            Length in seconds of the impulse segment (impulse followed by zeros).
        amplitude_dbfs : float
            Level of the impulse sample in dBFS.

        Returns
        -------
        torch.Tensor
            The impulse placed in silence (see ``place_signal``).
        """
        amplitude = dbfs_to_amp(amplitude_dbfs)
        # An impulse needs at least one sample, even for sub-sample durations.
        num_samples = max(int(self.sample_rate * signal_duration), 1)
        impulse = torch.zeros(num_samples)
        impulse[0] = amplitude  # Set the first sample to the specified amplitude
        return impulse

    @auto_place_signal
    def gen_sine(
        self,
        frequency: float = 440,
        signal_duration: float = 1,
        amplitude_dbfs: float = -1,
    ) -> torch.Tensor:
        """
        Generates a sine wave signal.

        Parameters
        ----------
        frequency : float
            Frequency in Hz.
        signal_duration : float
            Duration of the tone in seconds.
        amplitude_dbfs : float
            Peak level in dBFS.

        Returns
        -------
        torch.Tensor
            The sine wave placed in silence (see ``place_signal``).
        """
        amplitude = dbfs_to_amp(amplitude_dbfs)
        t = self._time_axis(signal_duration)
        return amplitude * torch.sin(2 * torch.pi * frequency * t)

    @auto_place_signal
    def gen_sawtooth(
        self,
        frequency: float = 440,
        signal_duration: float = 1,
        amplitude_dbfs: float = -1,
    ) -> torch.Tensor:
        """
        Generates a sawtooth wave signal.

        Parameters
        ----------
        frequency : float
            Frequency in Hz.
        signal_duration : float
            Duration of the wave in seconds.
        amplitude_dbfs : float
            Peak level in dBFS.

        Returns
        -------
        torch.Tensor
            The sawtooth wave placed in silence (see ``place_signal``).
        """
        amplitude = dbfs_to_amp(amplitude_dbfs)
        t = self._time_axis(signal_duration)
        # Fractional position inside the current period, in [0, 1).
        phase = (t * frequency) % 1.0
        # Map the ramp from [0, 1) to [-amplitude, amplitude).
        return 2 * amplitude * (phase - 0.5)

    @auto_place_signal
    def gen_square(
        self,
        frequency: float = 440,
        signal_duration: float = 1,
        amplitude_dbfs: float = -1,
        duty_cycle: float = 0.5,
    ) -> torch.Tensor:
        """
        Generates a square wave signal with correct duty cycle logic.

        Parameters
        ----------
        frequency : float
            Frequency in Hz.
        signal_duration : float
            Duration of the wave in seconds.
        amplitude_dbfs : float
            Peak level in dBFS.
        duty_cycle : float
            Fraction of each period spent at ``+amplitude``.

        Returns
        -------
        torch.Tensor
            The square wave placed in silence (see ``place_signal``).
        """
        amplitude = dbfs_to_amp(amplitude_dbfs)
        t = self._time_axis(signal_duration)
        # Compute the phase of the wave, which is used to create the square wave based on the duty cycle
        phase = (t * frequency) % 1.0
        # Tensor operands keep torch.where usable on torch versions that do
        # not accept two Python scalars.
        high = torch.full_like(phase, amplitude)
        return torch.where(phase < duty_cycle, high, -high)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment