Skip to content

Instantly share code, notes, and snippets.

@skim0119
Last active January 13, 2024 06:52
Show Gist options
  • Select an option

  • Save skim0119/ee01a11b965342a9491a1e48a2e756b8 to your computer and use it in GitHub Desktop.

Select an option

Save skim0119/ee01a11b965342a9491a1e48a2e756b8 to your computer and use it in GitHub Desktop.
Run batch spike detection using MiV-OS
"""
`main` is called from a separate run script.
It parses the configuration YAML file and passes each recording's task to the supplied `func`.
"""
import pathlib
import os
from omegaconf import DictConfig, OmegaConf
import hydra
from hydra import compose, initialize
from miv_os_contrib.mpi_logging import *
def get_cfg(config_name, config_path="config"):
    """Compose and return the Hydra configuration named *config_name*.

    Hydra is initialised only for the duration of the ``compose`` call,
    via the ``initialize`` context manager.

    Parameters
    ----------
    config_name : str
        Name of the configuration to compose (e.g. ``"config.yaml"``).
    config_path : str
        Configuration directory handed to Hydra's ``initialize``
        (interpreted by Hydra relative to the calling module).

    Returns
    -------
    The composed configuration object returned by ``hydra.compose``.
    """
    hydra_context = initialize(version_base=None, config_path=config_path)
    with hydra_context:
        composed = compose(config_name=config_name)
    return composed
def parse_cfg(cfg, create_directory=True):
    """Pull the run settings out of *cfg* and optionally create output dirs.

    Parameters
    ----------
    cfg : mapping
        Configuration supporting ``cfg[key]`` for the keys ``verbose``,
        ``skip_if_folder_exist``, ``result_folder``, ``cache_folder`` and
        ``recording_information``. A missing key raises ``KeyError``.
    create_directory : bool
        When true, create the result and cache folders (no error if they
        already exist).

    Returns
    -------
    tuple
        ``(verbose, skip_if_folder_exist, result_dir, cache_dir, infos)``.
    """
    wanted = (
        "verbose",
        "skip_if_folder_exist",
        "result_folder",
        "cache_folder",
        "recording_information",
    )
    verbose, skip_if_folder_exist, result_dir, cache_dir, infos = (
        cfg[key] for key in wanted
    )
    if create_directory:
        for folder in (result_dir, cache_dir):
            os.makedirs(folder, exist_ok=True)
    return verbose, skip_if_folder_exist, result_dir, cache_dir, infos
def main(func, config_name, index=None, config_path="config"):
    """Compose the configuration, dispatch one task per recording, and wait.

    Parameters
    ----------
    func : callable
        Called as
        ``func(result_dir, cache_dir, skip_if_folder_exist, info, verbose, cfg)``
        for every selected entry of ``recording_information``. It may return
        an object exposing ``result()`` (e.g. a future), which is waited on,
        or ``None`` when there is nothing to wait for.
    config_name : str
        Name of the configuration passed to ``get_cfg``.
    index : container of int, optional
        Positions in ``recording_information`` to process. ``None`` (the
        default) processes every recording.
    config_path : str, optional
        Configuration directory forwarded to ``get_cfg``.

    Returns
    -------
    list
        ``job.result()`` for every dispatched job that returned a non-None
        handle, in recording order.
    """
    cfg = get_cfg(config_name, config_path=config_path)
    verbose, skip_if_folder_exist, result_dir, cache_dir, infos = parse_cfg(cfg)

    # Select on the *recording* position before dispatching, so unselected
    # recordings are never submitted, and so `index` keeps referring to
    # recordings even when `func` returns None for some of them.
    jobs = []
    for idx, info in enumerate(infos):
        if index is not None and idx not in index:
            continue
        job = func(result_dir, cache_dir, skip_if_folder_exist, info, verbose, cfg)
        if job is not None:
            jobs.append(job)

    # Block until every dispatched job has finished.
    return [job.result() for job in jobs]
import os
import sys
# Put the parent of this script's directory on the import path so that
# `from main import main` below can find a `main` module there.
current = os.path.realpath(__file__)
current = os.path.dirname(current)
sys.path.append(os.path.split(current)[0])
from main import main
from miv_os_contrib.mpi_logging import *
import parsl
from parsl import bash_app, python_app
from parsl.config import Config
from parsl.channels import LocalChannel
from parsl.providers import SlurmProvider
from parsl.executors import HighThroughputExecutor
from parsl.launchers import SimpleLauncher
from parsl.addresses import address_by_hostname
# Shell commands run before the worker starts; none are needed here.
worker_init = ""
config = Config(
    executors=[
        HighThroughputExecutor(
            label='frontera_htex',
            # Address the workers use to reach this (submitting) process.
            address=address_by_hostname(),
            # Up to 56 single-core Python workers on the node.
            cores_per_worker=1,
            max_workers=56,
            # Tolerant heartbeat settings so short network outages do not
            # get the workers declared lost.
            heartbeat_threshold=300,
            heartbeat_period=60,
            provider=SlurmProvider(
                partition='development',
                channel=LocalChannel(),
                walltime="02:00:00",
                # One node per block; start with one block, never exceed
                # one, and allow scaling down to zero when idle.
                nodes_per_block=1,
                max_blocks=1,
                min_blocks=0,
                init_blocks=1,
                # Timeout (seconds) for the provider's Slurm commands.
                cmd_timeout=60,
                exclusive=False,
                # SimpleLauncher starts the worker command as-is (no
                # srun/mpirun wrapping).
                launcher=SimpleLauncher(),
                worker_init=worker_init,
            ),
        ),
    ],
)
parsl.load(config)
@python_app
def run_spike_detection(result_dir, cache_dir, skip_if_folder_exist, info, verbose, cfg):
    """Band-pass filter one recording and run threshold spike detection.

    Executed as a Parsl ``python_app``. All imports live inside the function
    so they are resolved in the process that actually runs it.

    Parameters
    ----------
    result_dir, cache_dir : str or os.PathLike
        Root result/cache directories. Each entry of ``info["tag"]`` is
        appended as a nested sub-directory.
    skip_if_folder_exist : bool
        When true and the recording's result directory already exists, the
        recording is skipped: nothing is created, written or computed.
    info : mapping
        One entry of ``recording_information``. Keys read: ``"path"``,
        ``"tag"`` (a string or a sequence of strings),
        ``"recording_device"``, ``"memo"``, plus ``"data_index"`` for
        OpenEphys recordings.
    verbose : bool
        Forwarded to ``Pipeline.run``.
    cfg
        Full configuration. Not read here; accepted to match the call
        convention used by ``main``.

    Raises
    ------
    NotImplementedError
        If ``info["recording_device"]`` is not "OpenEphys", "Intan" or "H5".
    """
    import logging
    import os
    import pathlib

    from miv.core.pipeline import Pipeline
    from miv.signal.filter import ButterBandpass
    from miv.signal.spike import ThresholdCutoff

    logging.basicConfig(level=logging.INFO)

    folder_path = info["path"]

    # A bare string tag would otherwise be iterated character by character.
    tags = info["tag"]
    if isinstance(tags, str):
        tags = [tags]
    result_path = pathlib.Path(result_dir).joinpath(*tags)
    cache_path = pathlib.Path(cache_dir).joinpath(*tags)
    device = info["recording_device"]

    # Honour the `skip_if_folder_exist` option (it was previously ignored).
    if skip_if_folder_exist and result_path.exists():
        logging.info("Skipping %s: %s already exists", folder_path, result_path)
        return None

    os.makedirs(result_path, exist_ok=True)
    os.makedirs(cache_path, exist_ok=True)

    # 0. Store the free-text memo next to the results.
    (result_path / "memo.txt").write_text(info["memo"], encoding="utf-8")

    # 1. Load the recording with the loader matching its device.
    # NOTE: MPI-distributed loading (mpi4py / DataIntanMPI) is not used here.
    if device == "OpenEphys":
        from miv.io.openephys import DataManager
        data = DataManager(folder_path)[info["data_index"]]
    elif device == "Intan":
        from miv.io.intan import DataIntan
        data = DataIntan(folder_path)
    elif device == "H5":
        from miv.io.file.import_signal import ImportSignal
        data = ImportSignal(folder_path)
    else:
        raise NotImplementedError(f"Unsupported recording_device: {device!r}")

    # 2. Band-pass filter (400-1500 Hz, 4th order) then threshold detection.
    bandpass_filter = ButterBandpass(lowcut=400, highcut=1500, order=4)
    spike_detection = ThresholdCutoff()
    data >> bandpass_filter >> spike_detection
    Pipeline(spike_detection).run(result_path, cache_path, verbose=verbose)
    return None
if __name__ == "__main__":
    # Script entry point: enable INFO logging in the submitting process,
    # then dispatch spike detection for every configured recording.
    config_name = "config.yaml"
    import logging

    logging.basicConfig(level=logging.INFO)
    main(run_spike_detection, config_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment