Last active
January 13, 2024 06:52
-
-
Save skim0119/ee01a11b965342a9491a1e48a2e756b8 to your computer and use it in GitHub Desktop.
Run batch spike detection using MiV-OS
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Main is called from other runfile.
This script parses the configuration yaml file and pushes the task into some `func`.
"""
import pathlib
import os
from omegaconf import DictConfig, OmegaConf
import hydra
from hydra import compose, initialize
from miv_os_contrib.mpi_logging import *
def get_cfg(config_name, config_path="config"):
    """Compose and return the hydra configuration named ``config_name``.

    ``config_path`` is resolved by hydra relative to this module.
    """
    with initialize(version_base=None, config_path=config_path):
        return compose(config_name=config_name)
def parse_cfg(cfg, create_directory=True):
    """Unpack run settings from a configuration mapping.

    Returns a 5-tuple ``(verbose, skip_if_folder_exist, result_dir,
    cache_dir, infos)``.  When ``create_directory`` is true, the result
    and cache directories are created if they do not already exist.
    """
    settings = (
        "verbose",
        "skip_if_folder_exist",
        "result_folder",
        "cache_folder",
        "recording_information",
    )
    verbose, skip_existing, result_dir, cache_dir, infos = (cfg[key] for key in settings)
    if create_directory:
        for folder in (result_dir, cache_dir):
            os.makedirs(folder, exist_ok=True)
    return verbose, skip_existing, result_dir, cache_dir, infos
def main(func, config_name, index=None):
    """Dispatch ``func`` over every recording in the configuration and wait.

    Parameters
    ----------
    func :
        Callable invoked as
        ``func(result_dir, cache_dir, skip_if_folder_exist, info, verbose, cfg)``.
        May return ``None`` or a future-like object exposing ``.result()``.
    config_name :
        Name of the hydra configuration file passed to ``get_cfg``.
    index :
        Optional collection of job indices to wait on; when ``None`` every
        submitted job is awaited.
    """
    cfg = get_cfg(config_name)
    verbose, skip_if_folder_exist, result_dir, cache_dir, infos = parse_cfg(cfg)
    jobs = []
    for info in infos:
        job = func(result_dir, cache_dir, skip_if_folder_exist, info, verbose, cfg)
        if job is not None:
            jobs.append(job)
    # Block until the selected jobs finish.  A plain loop replaces the
    # original list comprehension, which was used only for its side effects.
    for i, job in enumerate(jobs):
        if index is None or i in index:
            job.result()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import sys | |
| current = os.path.dirname(os.path.realpath(__file__)) | |
| sys.path.append(os.path.dirname(current)) | |
| from main import main | |
| from miv_os_contrib.mpi_logging import * | |
| import parsl | |
| from parsl import bash_app, python_app | |
| from parsl.config import Config | |
| from parsl.channels import LocalChannel | |
| from parsl.providers import SlurmProvider | |
| from parsl.executors import HighThroughputExecutor | |
| from parsl.launchers import SimpleLauncher | |
| from parsl.addresses import address_by_hostname | |
# Shell commands run on each worker node before it starts accepting tasks
# (module loads, environment activation, ...).  Empty for this deployment.
worker_init = ""

# Parsl configuration: one HighThroughputExecutor whose blocks are
# provisioned through SLURM (single node, 56 workers per node).
config = Config(
    executors=[
        HighThroughputExecutor(
            label='frontera_htex',
            address=address_by_hostname(),
            # This option sets our 1 manager running on the lead node of the job
            # to spin up enough workers to concurrently invoke `ibrun <mpi_app>` calls
            max_workers=56,
            cores_per_worker=1,
            # Set the heartbeat params to avoid faults from periods of network unavailability
            # Addresses network drop concern from older Claire communication
            heartbeat_period=60,
            heartbeat_threshold=300,
            provider=SlurmProvider(
                partition='development',
                channel=LocalChannel(),
                cmd_timeout=60,
                nodes_per_block=1,
                walltime="02:00:00",
                # Set scaling limits
                init_blocks=1,
                min_blocks=0,
                max_blocks=1,
                # Specify number of ranks
                launcher=SimpleLauncher(),
                worker_init=worker_init,
                exclusive=False,
            ),
        ),
    ],
)

# Register the configuration; required before any @python_app is executed.
parsl.load(config)
@python_app
def run_spike_detection(result_dir, cache_dir, skip_if_folder_exist, info, verbose, cfg):
    """Run band-pass filtering and threshold spike detection on one recording.

    Executed remotely as a parsl ``python_app``, hence the function-scope
    imports (the body must be self-contained on the worker).

    Parameters
    ----------
    result_dir, cache_dir :
        Base folders for outputs and caches; nested per ``info["tag"]``.
    skip_if_folder_exist :
        Accepted for interface compatibility with ``main``; not read here.
    info :
        Recording description with keys ``path``, ``tag``,
        ``recording_device``, ``memo``, and (OpenEphys only) ``data_index``.
    verbose :
        Forwarded to ``Pipeline.run``.
    cfg :
        Full configuration object; not read here, kept for interface parity.

    Raises
    ------
    NotImplementedError
        If ``info["recording_device"]`` is not one of the supported readers.
    """
    import pathlib
    import logging

    from miv.core.pipeline import Pipeline
    from miv.signal.filter import ButterBandpass
    from miv.signal.spike import ThresholdCutoff

    logging.basicConfig(level=logging.INFO)

    folder_path = info["path"]
    result_path = pathlib.Path(result_dir)
    cache_path = pathlib.Path(cache_dir)
    # Nest outputs under each tag in order: <result_dir>/<tag0>/<tag1>/...
    for tag in info["tag"]:
        result_path = result_path / tag
        cache_path = cache_path / tag
    device = info["recording_device"]
    result_path.mkdir(parents=True, exist_ok=True)
    cache_path.mkdir(parents=True, exist_ok=True)

    # 0. Save the free-form memo alongside the results for provenance.
    (result_path / "memo.txt").write_text(info["memo"])

    # 1. Open the recording with the reader matching the acquisition device.
    if device == "OpenEphys":
        from miv.io.openephys import DataManager

        data = DataManager(folder_path)[info["data_index"]]
    elif device == "Intan":
        from miv.io.intan import DataIntan

        data = DataIntan(folder_path)
    elif device == "H5":
        from miv.io.file.import_signal import ImportSignal

        data = ImportSignal(folder_path)
    else:
        raise NotImplementedError(f"Unsupported recording device: {device!r}")

    # 2. Filter -> spike-detection pipeline (400-1500 Hz band-pass, order 4).
    bandpass_filter = ButterBandpass(lowcut=400, highcut=1500, order=4)
    spike_detection = ThresholdCutoff()
    data >> bandpass_filter >> spike_detection
    Pipeline(spike_detection).run(result_path, cache_path, verbose=verbose)
if __name__ == "__main__":
    import logging

    # Log at INFO so driver-side pipeline progress is visible.
    logging.basicConfig(level=logging.INFO)
    # Hydra configuration describing the recordings and output folders;
    # loaded by `main` via `get_cfg` (resolved against its `config/` path).
    config_name = "config.yaml"
    main(run_spike_detection, config_name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment