Skip to content

Instantly share code, notes, and snippets.

@alexrockhill
Created February 2, 2023 23:07
Show Gist options
  • Select an option

  • Save alexrockhill/152022ff81ae852e72b76a344753d85a to your computer and use it in GitHub Desktop.

Select an option

Save alexrockhill/152022ff81ae852e72b76a344753d85a to your computer and use it in GitHub Desktop.

Revisions

  1. alexrockhill created this gist Feb 2, 2023.
    118 changes: 118 additions & 0 deletions flat_map_brain.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,118 @@
    import os
    import os.path as op
    import numpy as np
    import mne
    import imageio

    misc_path = mne.datasets.misc.data_path()
    sample_path = mne.datasets.sample.data_path()
    subjects_dir = sample_path / 'subjects'
    subject = 'fsaverage'

    raw = mne.io.read_raw(sample_path / 'MEG' / 'sample' / \
    'sample_audvis_filt-0-40_raw.fif')
    trans = mne.coreg.estimate_head_mri_t(subject, subjects_dir)

    view_kwargs = dict(azimuth=120, elevation=100, distance=600,
    focalpoint=(0, 0, -15))

    surf_data = dict(lh=dict(), rh=dict())
    x_dir = np.array([1., 0., 0.])
    for hemi in ('lh', 'rh'):
    for surf in ('pial', 'inflated', 'curv', 'cortex.patch.flat'):
    for img in ('', '.T1', '.T2', ''):
    surf_fname = op.join(subjects_dir, subject, 'surf',
    f'{hemi}.{surf}')
    if op.isfile(surf_fname):
    break
    if surf == 'curv':
    surf_data[hemi]['curv'] = np.array(mne.surface.read_curvature(
    surf_fname, binary=False) > 0, np.int64)
    else:
    if surf.split('.')[-1] == 'flat':
    surf = 'flat'
    coords, faces, orig_faces = mne.surface._read_patch(surf_fname)
    # rotate 90 degrees to get to a more standard orientation
    # where X determines the distance between the hemis
    coords = coords[:, [1, 0, 2]]
    coords[:, 1] *= -1
    else:
    coords, faces = mne.surface.read_surface(surf_fname)
    if surf in ('inflated', 'flat'):
    x_ = coords @ x_dir
    coords -= (np.max(x_) if hemi == 'lh' else np.min(x_)) * x_dir
    surface = dict(rr=coords, tris=faces)
    mne.surface.complete_surface_info(
    surface, copy=False, verbose=False, do_neighbor_tri=False)
    surf_data[hemi][surf] = surface['rr'], surface['tris'], surface['nn']


    for hemi in ('lh', 'rh'):
    surf_data[hemi]['vectors'] = \
    surf_data[hemi]['inflated'][0] - surf_data[hemi]['pial'][0]
    surf_data[hemi]['normal_vectors'] = \
    surf_data[hemi]['inflated'][2] - surf_data[hemi]['pial'][2]
    surf_data[hemi]['vectors2'] = \
    surf_data[hemi]['flat'][0] - surf_data[hemi]['inflated'][0]
    surf_data[hemi]['normal_vectors2'] = \
    surf_data[hemi]['flat'][2] - surf_data[hemi]['inflated'][2]


    images = list()
    view_kwargs = dict(azimuth=120, elevation=90)
    brain = mne.viz.Brain(subject, subjects_dir=subjects_dir, surf='flat',
    cortex='low_contrast', alpha=1, background='white')
    brain._renderer.plotter.camera.focal_point = (0, 0, 0)

    # brain.add_annotation('aparc.a2009s', borders=False, alpha=0.5)

    images += [brain.screenshot()] * 10

    elevation_delta = 20
    azimuth_delta = 20
    n_steps = 201
    for t in np.linspace(0, 1, n_steps):

    for hemi in ('lh', 'rh'):
    coords, faces, nn = surf_data[hemi]['flat']
    coords = coords.copy()
    coords -= surf_data[hemi]['vectors2'] * t
    nn = nn.copy()
    nn -= surf_data[hemi]['normal_vectors2'] * t
    brain._renderer.plotter.update_coordinates(
    coords, brain._layered_meshes[hemi]._polydata, render=False)
    brain._layered_meshes[hemi]._polydata.point_data.active_normals = nn

    brain._renderer.plotter.camera.zoom(1 + 1 / n_steps)
    brain._renderer.plotter.camera.elevation = elevation_delta * t
    brain._renderer.plotter.camera.azimuth = azimuth_delta * t

    brain._renderer.plotter.update()
    images.append(brain.screenshot())

    for i in range(5):
    images.append(images[-1])

    n_steps = 51
    for t in np.linspace(0, 1, n_steps):

    for hemi in ('lh', 'rh'):
    coords, faces, nn = surf_data[hemi]['inflated']
    coords = coords.copy()
    coords -= surf_data[hemi]['vectors'] * t
    nn = nn.copy()
    nn -= surf_data[hemi]['normal_vectors'] * t
    brain._renderer.plotter.update_coordinates(
    coords, brain._layered_meshes[hemi]._polydata, render=False)
    brain._layered_meshes[hemi]._polydata.point_data.active_normals = nn
    brain._layered_meshes[hemi].update_overlay('curv', opacity=1 - t * 0.6)

    brain._renderer._update()
    images.append(brain.screenshot())

    for i in range(5):
    images.append(images[-1])

    brain.close()

    imageio.mimwrite('flat.mp4', images, fps=24)