Skip to content

Instantly share code, notes, and snippets.

@cryos
Created February 18, 2020 14:38
Show Gist options
  • Select an option

  • Save cryos/4e10febd17ef6f8f1b4ab1a93e10a6d7 to your computer and use it in GitHub Desktop.

Select an option

Save cryos/4e10febd17ef6f8f1b4ab1a93e10a6d7 to your computer and use it in GitHub Desktop.
try:
import numpy as np
import tomopy
import h5py
import matplotlib.pylab as plt
except ImportError:
pass
def find_nearest(data, value):
    """Return the (flattened) index of the element of ``data`` closest to ``value``.

    On ties the first matching index wins, as with ``numpy.argmin``.
    """
    distances = np.abs(np.asarray(data) - value)
    return np.argmin(distances)
class IndexTracker:
    """Display one 2D frame of a 3D stack on a matplotlib Axes and step through frames.

    ``X`` must be 3-dimensional and is indexed as ``X[frame, :, :]``; the Axes
    title shows "frame <n> of <total>" for the frame currently drawn.
    """

    def __init__(self, ax, X):
        self.ax = ax
        self._indx_txt = ax.set_title(' ', loc='center')
        self.X = X
        # Unpacking into three names enforces a 3D stack.
        self.slices, _, _ = X.shape
        # Open on the middle frame of the stack.
        self.ind = self.slices // 2
        self.im = ax.imshow(self.X[self.ind, :, :], cmap='gray')
        self.update()

    def onscroll(self, event):
        """Move forward on scroll-up and backward otherwise, wrapping at both ends."""
        step = 1 if event.button == 'up' else -1
        self.ind = (self.ind + step) % self.slices
        self.update()

    def update(self):
        """Push the current frame and its title text to the canvas and redraw."""
        frame = self.X[self.ind, :, :]
        self.im.set_data(frame)
        self._indx_txt.set_text(f"frame {self.ind + 1} of {self.slices}")
        self.im.axes.figure.canvas.draw()
def image_scrubber(data, *, ax=None):
    """Show a 3D stack in a scroll-to-change-frame viewer and return its IndexTracker.

    If ``ax`` is None a new figure and Axes are created with ``plt.subplots()``;
    otherwise the Axes' own figure is used.
    """
    fig, ax = plt.subplots() if ax is None else (ax.figure, ax)
    tracker = IndexTracker(ax, data)
    # Keep a strong reference on the figure so the tracker (and its callback)
    # is not garbage collected while the figure is open.
    fig._tracker = tracker
    fig.canvas.mpl_connect('scroll_event', tracker.onscroll)
    return tracker
def find_rot(fn, thresh=0.05):
    """
    Estimate the rotation center from the first projection and the one nearest 180 deg later.

    Reads 'img_bkg_avg', 'angle' (degrees) and 'img_tomo' from the HDF5 file
    ``fn``. Both projections are flat-field divided, the 180-deg one is
    mirrored left-right, and both are converted to -log attenuation. Values
    below ``thresh`` and non-finite values are zeroed, and a 5x5 median filter
    is applied. The horizontal shift between the two images then gives the
    center as ``ncols/2 + shift/2 - 1``.

    Two estimates are printed:
    - one from ``dftregistration``, which is the value returned;
    - one from pystackreg translation registration, for comparison.

    Inputs:
    --------
    fn: string
        filename of the scan
    thresh: float
        attenuation values below this are set to 0 before registration

    Returns
    -------
    rot_cen: float
    """
    from pystackreg import StackReg
    from scipy.signal import medfilt2d

    with h5py.File(fn, 'r') as f:
        img_bkg = np.squeeze(np.array(f['img_bkg_avg']))
        ang = np.array(f['angle'])
        # Projection whose angle is closest to the first angle + 180 degrees.
        idx180 = np.abs(ang - ang[0] - 180).argmin()
        img0 = np.array(f['img_tomo'][0])
        img180_raw = np.array(f['img_tomo'][idx180])

    img0 = img0 / img_bkg
    # Mirror the 180-deg view so both images share the same orientation.
    img180 = (img180_raw / img_bkg)[:, ::-1]
    n_cols = img0.shape[1]

    with np.errstate(divide='ignore', invalid='ignore'):
        im1 = -np.log(img0)
        im2 = -np.log(img180)
    for im in (im1, im2):
        # log(0) gives +inf and log(<0) gives NaN; both would corrupt the FFT below.
        im[~np.isfinite(im)] = 0
        im[im < thresh] = 0
    im1 = medfilt2d(im1, 5)
    im2 = medfilt2d(im2, 5)

    # NOTE(review): dftregistration is neither defined nor imported in this file;
    # it must be provided by the surrounding environment. Element [3] is used as
    # the column shift.
    col_shift = dftregistration(np.fft.fft2(im1), np.fft.fft2(im2))[3]
    rot_cen = n_cols / 2 + col_shift / 2 - 1

    tmat = StackReg(StackReg.TRANSLATION).register(im1, im2)
    cshft = -tmat[0, 2]
    rot_cen0 = n_cols / 2 + cshft / 2 - 1
    print(f'rot_cen = {rot_cen} or {rot_cen0}')
    return rot_cen
def rotcen_test(fn, start=None, stop=None, steps=None, sli=0, block_list=None, return_flag=0, print_flag=1, bkg_level=0, txm_normed_flag=0, denoise_flag=0, denoise_level=9):
    """
    Reconstruct a single detector row at a series of trial rotation centers.

    Reads 'img_tomo' and 'angle' (degrees) from the HDF5 file ``fn``. Unless
    ``txm_normed_flag`` is set, it also reads 'img_bkg_avg' and 'img_dark_avg'
    for flat/dark correction. The projections are then:
    - converted to -log attenuation;
    - stripe-corrected with ``tomopy.prep.stripe.remove_stripe_fw``;
    - optionally Wiener-denoised;
    - reconstructed with gridrec for every center in
      ``np.linspace(start, stop, steps)``.

    The trial images and centers are written to 'center_test.h5' in the
    current directory. The images are circularly masked and shown in a
    scrollable viewer.

    Inputs:
    --------
    fn: string
        filename of the scan
    start, stop, steps:
        trial centers; if any is None, 26 centers spanning +/-50 px around the
        middle column are used
    sli: int
        detector row to reconstruct; 0 selects the middle row
    block_list: list
        projection indices to exclude (projections beyond first angle + 180 deg
        are always excluded)
    return_flag: bool
        if true, return (img, cen)
    print_flag: bool
        print each trial center
    bkg_level: float
        constant subtracted from the -log projections
    txm_normed_flag: bool
        projections in the file are already normalized
    denoise_flag: bool
        Wiener-filter each projection; 100 extra rows around ``sli`` are
        loaded so the filter has context
    denoise_level: int
        ``level`` passed to remove_stripe_fw
    """
    import tomopy

    addition_slice = 100 if denoise_flag else 0
    with h5py.File(fn, 'r') as f:
        n_rows = f['img_tomo'].shape[1]
        if sli == 0:
            sli = n_rows // 2
        row_lo = max(0, sli - addition_slice // 2)
        row_hi = min(sli + addition_slice // 2 + 1, n_rows)
        theta = np.array(f['angle']) / 180.0 * np.pi
        img_tomo = np.array(f['img_tomo'][:, row_lo:row_hi, :])
        if txm_normed_flag:
            prj = img_tomo
        else:
            img_bkg = np.array(f['img_bkg_avg'][:, row_lo:row_hi, :])
            img_dark = np.array(f['img_dark_avg'][:, row_lo:row_hi, :])
            with np.errstate(divide='ignore', invalid='ignore'):
                prj = (img_tomo - img_dark) / (img_bkg - img_dark)

    with np.errstate(divide='ignore', invalid='ignore'):
        prj_norm = -np.log(prj)
    prj_norm[~np.isfinite(prj_norm)] = 0
    prj_norm[prj_norm < 0] = 0
    prj_norm -= bkg_level
    prj_norm = tomopy.prep.stripe.remove_stripe_fw(prj_norm, level=denoise_level, wname='db5', sigma=1, pad=True)
    if denoise_flag:
        wiener_param = {'psf': np.ones([2, 2]) / 4, 'reg': None, 'balance': 0.3, 'is_real': True, 'clip': True}
        prj_norm = wiener_denoise(prj_norm, wiener_param, denoise_flag)

    # Drop projections past first angle + 180 deg, plus any user-blocked ones.
    pos = find_nearest(theta, theta[0] + np.pi)
    blocked = set(range(pos + 1, len(theta)))
    if block_list is not None:
        blocked.update(block_list)
    keep = [i for i in range(len(prj_norm)) if i not in blocked]
    prj_norm = prj_norm[keep]
    theta = theta[keep]

    n_cols = prj_norm.shape[2]
    if start is None or stop is None or steps is None:
        start = int(n_cols / 2 - 50)
        stop = int(n_cols / 2 + 50)
        steps = 26
    cen = np.linspace(start, stop, steps)

    # Position of `sli` inside the loaded slab: 0 without denoising, otherwise
    # the number of extra rows loaded above it (clipped at the detector edge).
    row = sli - row_lo
    sino = prj_norm[:, row:row + 1]
    img = np.zeros([len(cen), n_cols, n_cols])
    for i, center in enumerate(cen):
        if print_flag:
            print('{}: rotcen {}'.format(i + 1, center))
        img[i] = tomopy.recon(sino, theta, center=center, algorithm='gridrec')

    fout = 'center_test.h5'
    with h5py.File(fout, 'w') as hf:
        hf.create_dataset('img', data=img)
        hf.create_dataset('rot_cen', data=cen)
    img = tomopy.circ_mask(img, axis=0, ratio=0.8)
    image_scrubber(img)
    if return_flag:
        return img, cen
def img_variance(img):
    """
    Return one sharpness score per slice of a reconstructed stack.

    The stack is first masked with ``tomopy.circ_mask(axis=0, ratio=0.8)``.
    Each slice is then 5x5 median-filtered. Its score is the standard
    deviation (population, ddof=0) of its strictly positive pixels. Slices
    with no positive pixel get NaN.

    NOTE(review): slices are median-filtered in place on the array returned by
    circ_mask; depending on tomopy's dtype handling this may be the caller's
    array itself -- confirm if the input must stay untouched.

    Inputs:
    --------
    img: ndarray
        stack of 2D slices, indexed as img[i]

    Returns
    -------
    variance: 1D ndarray, one value per slice
    """
    import tomopy
    from scipy.signal import medfilt2d

    n_slices = img.shape[0]
    variance = np.zeros(n_slices)
    img = tomopy.circ_mask(img, axis=0, ratio=0.8)
    for i in range(n_slices):
        img[i] = medfilt2d(img[i], 5)
        vals = img[i][img[i] > 0]
        # Equivalent to the former sqrt(sum((x - mean)**2) / N); N, not N-1.
        variance[i] = np.std(vals) if vals.size else np.nan
    return variance
def _recon_slice_range(sli, n_rows):
    """Validate a row spec for `recon`; return ([start, stop), file-name suffix)."""
    sli = [] if sli is None else list(sli)
    if len(sli) == 1 and 0 <= sli[0] < n_rows:
        return [sli[0], sli[0] + 1], '_slice_{}'.format(sli[0])
    if len(sli) == 2 and 0 <= sli[0] < sli[1] <= n_rows:
        return [sli[0], sli[1]], '_slice_{}_{}'.format(sli[0], sli[1])
    if len(sli):
        print('non valid slice id, will take reconstruction for the whole object')
    return [0, n_rows], ''


def recon(dataset, rot_cen, sli=None, binning=None, zero_flag=0, block_list=None, bkg_level=0, txm_normed_flag=0, read_full_memory=0, denoise_flag=0, denoise_level=9):
    '''
    Reconstruct a 3D tomogram from a tomviz dataset, chunk by chunk.

    Inputs come from the dataset:
    - projections: ``dataset.active_scalars``, transposed with axis order
      [2, 1, 0] so that the first axis indexes projections (NOTE(review):
      assumes tomviz stores scalars as (cols, rows, angle));
    - angles: ``dataset.tilt_angles`` (degrees);
    - flat/dark correction: ``dataset.white`` / ``dataset.dark``, applied by
      ``proj_normalize``.

    Detector rows are processed in chunks of up to 40 rows, or all at once
    with ``read_full_memory``. The last chunk absorbs any remainder. Each chunk
    is reconstructed with tomopy gridrec. The result is also written to
    ``/tmp/recon_scan_<scan_id><slice_info>_bin<binning>.h5``.

    Inputs:
    --------
    dataset: tomviz dataset
        provides active_scalars, tilt_angles, dark, white
    rot_cen: float
        rotation center in unbinned pixels
    sli: list
        rows to reconstruct: [] or None for all rows, [i] for one row,
        [start, stop] for a range (stop exclusive), e.g. [100, 300].
        Invalid specs fall back to all rows.
    binning: int
        binning factor applied to rows and columns (None means 1)
    zero_flag: bool
        if 1: set negative pixel value to 0
        if 0: keep negative pixel value
    block_list: list
        indices of projections to exclude; projections past first angle +
        180 deg are always excluded
    bkg_level: float
        constant subtracted from the normalized projections
    txm_normed_flag: bool
        projections are already normalized
    read_full_memory: bool
        reconstruct all selected rows in a single chunk
    denoise_flag: bool
        Wiener-filter the projections; up to 20 extra rows of context are
        loaded on each side of a chunk
    denoise_level: int
        level passed to the stripe-removal step

    Returns
    -------
    rec: float32 ndarray, shape (rows // binning, cols // binning, cols // binning)
    '''
    print("active_scalars.shape: ", dataset.active_scalars.shape)
    first_proj = np.transpose(dataset.active_scalars, [2, 1, 0])[0]
    print("tmp.shape: ", first_proj.shape)
    n_rows, n_cols = first_proj.shape
    sli, slice_info = _recon_slice_range(sli, n_rows)

    # The tomviz dataset carries no scan id / energy; placeholders for the output file.
    scan_id = 'tomviz-kitware-result'
    eng = '42'

    theta = np.array(dataset.tilt_angles) / 180.0 * np.pi
    pos = find_nearest(theta, theta[0] + np.pi)
    blocked = set(range(pos + 1, len(theta)))
    if block_list is not None:
        blocked.update(block_list)
    allow_list = [i for i in range(len(theta)) if i not in blocked]
    theta = theta[allow_list]

    binning = binning if binning else 1
    rot_cen = rot_cen * 1.0 / binning
    n_total = sli[1] - sli[0]
    sli_step = n_total if read_full_memory else min(40, n_total)
    n_steps = max(n_total // sli_step, 1)
    # Chunk boundaries in detector rows; the last chunk absorbs any remainder.
    bounds = [sli[0] + i * sli_step for i in range(n_steps)] + [sli[1]]
    rows_out = [(hi - lo) // binning for lo, hi in zip(bounds[:-1], bounds[1:])]

    if denoise_flag:
        add_slice = min(sli_step // 2, 20)
        wiener_param = {
            'psf': np.ones([2, 2]) / 4,
            'reg': None,
            'balance': 0.3,
            'is_real': True,
            'clip': True,
        }
    else:
        add_slice = 0
        wiener_param = {}

    try:
        rec = np.zeros([sum(rows_out), n_cols // binning, n_cols // binning], dtype=np.float32)
    except MemoryError:
        print('Cannot allocate memory')
        raise
    print("rec.shape: ", rec.shape)

    offset = 0
    for i, (lo, hi) in enumerate(zip(bounds[:-1], bounds[1:])):
        print(f'recon {i+1}/{n_steps}: sli = [{lo}, {hi}] ... ')
        # Extra context rows for denoising, clipped to the detector and rounded
        # down to a multiple of `binning` so the trim below stays aligned.
        pad_lo = min(add_slice, lo) // binning * binning
        pad_hi = min(add_slice, n_rows - hi) // binning * binning
        prj_norm = proj_normalize(dataset, [lo - pad_lo, hi + pad_hi], txm_normed_flag, binning, allow_list, bkg_level, denoise_level)
        prj_norm = wiener_denoise(prj_norm, wiener_param, denoise_flag)
        first = pad_lo // binning
        prj_norm = prj_norm[:, first:first + rows_out[i]]
        rec_sub = tomopy.recon(prj_norm, theta, center=rot_cen, algorithm='gridrec')
        print("rec_sub.shape: ", rec_sub.shape)
        rec[offset:offset + rec_sub.shape[0]] = rec_sub
        offset += rec_sub.shape[0]

    if zero_flag:
        rec[rec < 0] = 0
    bin_info = f'_bin{int(binning)}'
    fout = f'recon_scan_{str(scan_id)}{str(slice_info)}{str(bin_info)}'
    fout_h5 = f'/tmp/{fout}.h5'
    with h5py.File(fout_h5, 'w') as hf:
        hf.create_dataset('img', data=np.array(rec, dtype=np.float32))
        hf.create_dataset('scan_id', data=scan_id)
        hf.create_dataset('X_eng', data=eng)
        hf.create_dataset('rot_cen', data=rot_cen)
        hf.create_dataset('binning', data=binning)
    print(f'{fout} is saved.')
    return rec
def wiener_denoise(prj_norm, wiener_param, denoise_flag):
    """
    Wiener-deconvolve every 2D slice ``prj_norm[j]`` in place with skimage.

    Returns ``prj_norm`` unchanged (same object) when ``denoise_flag`` is
    falsy or ``wiener_param`` is empty. In that case skimage is not imported,
    so callers that never denoise do not need it installed.

    Inputs:
    --------
    prj_norm: ndarray
        stack of 2D images, modified in place
    wiener_param: dict
        keys 'psf', 'reg', 'balance', 'is_real', 'clip', forwarded to
        skimage.restoration.wiener
    denoise_flag: bool
        enable denoising

    Returns
    -------
    prj_norm: the same array object
    """
    if not denoise_flag or not len(wiener_param):
        return prj_norm
    # Imported lazily: only needed when denoising is actually requested.
    import skimage.restoration as skr

    kwargs = {key: wiener_param[key] for key in ('psf', 'reg', 'balance', 'is_real', 'clip')}
    for j in range(prj_norm.shape[0]):
        prj_norm[j] = skr.wiener(prj_norm[j], **kwargs)
    return prj_norm
def proj_normalize(dataset, sli, txm_normed_flag, binning, allow_list=None, bkg_level=0, denoise_level=9):
    """
    Load detector rows ``sli[0]:sli[1]`` of a tomviz dataset as -log projections.

    ``dataset.active_scalars`` is transposed with axis order [2, 1, 0] and
    sliced on axis 1. NOTE(review): this assumes the transposed layout is
    (projection, row, col); confirm against tomviz. ``dataset.white`` and
    ``dataset.dark`` are sliced the same way and used for flat/dark
    correction. If either is missing or cannot be sliced, or
    ``txm_normed_flag`` is set, the raw projections are used instead.

    Processing steps:
    1. Bin rows and columns by ``binning`` (mean).
    2. Take -log; NaN/inf and negative values become 0.
    3. Keep only the projections in ``allow_list`` (None or empty keeps all).
    4. Remove stripes with ``remove_stripe_fw``.
    5. Subtract ``bkg_level``.

    Returns
    -------
    prj_norm: ndarray (projection, rows // binning, cols // binning)
    """
    order = [2, 1, 0]
    tmp = np.transpose(dataset.active_scalars, order)
    print("active_scalars.shape: ", dataset.active_scalars.shape)
    print("tmp.shape: ", tmp.shape)
    img_tomo = np.array(tmp[:, sli[0]:sli[1], :])
    print("img_tomo.shape: ", img_tomo.shape)

    def _reference_rows(name):
        # Same row window of dataset.<name>, or None if absent/unusable/empty.
        ref = getattr(dataset, name, None)
        if ref is None:
            return None
        try:
            rows = np.array(np.transpose(ref, order)[:, sli[0]:sli[1]])
        except (TypeError, ValueError, IndexError):
            return None
        return rows if len(rows) else None

    img_bkg = _reference_rows('white')
    img_dark = _reference_rows('dark')
    if img_dark is not None:
        print("dark.shape: ", dataset.dark.shape)

    if img_bkg is None or img_dark is None or txm_normed_flag:
        prj = img_tomo
    else:
        with np.errstate(divide='ignore', invalid='ignore'):
            prj = (img_tomo - img_dark) / (img_bkg - img_dark)
    s = prj.shape
    prj = bin_ndarray(prj, (s[0], int(s[1] / binning), int(s[2] / binning)), 'mean')
    with np.errstate(divide='ignore', invalid='ignore'):
        prj_norm = -np.log(prj)
    prj_norm[~np.isfinite(prj_norm)] = 0
    prj_norm[prj_norm < 0] = 0
    if allow_list is not None and len(allow_list):
        prj_norm = prj_norm[allow_list]
    prj_norm = tomopy.prep.stripe.remove_stripe_fw(prj_norm, level=denoise_level, wname='db5', sigma=1, pad=True)
    prj_norm -= bkg_level
    return prj_norm
def bin_ndarray(ndarray, new_shape=None, operation='mean'):
    """
    Bins an ndarray in all axes based on the target shape, by summing or
    averaging.

    Number of output dimensions must match number of input dimensions and
    new axes must divide old ones. When ``new_shape`` is None, every axis is
    halved.

    Parameters
    ----------
    ndarray : numpy.ndarray
    new_shape : sequence of int (tuple, list or 1-D array), optional
    operation : 'sum' or 'mean' (case-insensitive)

    Raises
    ------
    ValueError
        unsupported operation, rank mismatch, or a new axis length that does
        not evenly divide the old one.

    Example
    -------
    >>> m = np.arange(0,100,1).reshape((10,10))
    >>> n = bin_ndarray(m, new_shape=(5,5), operation='sum')
    >>> print(n)
    [[ 22  30  38  46  54]
     [102 110 118 126 134]
     [182 190 198 206 214]
     [262 270 278 286 294]
     [342 350 358 366 374]]
    """
    if new_shape is None:
        new_shape = np.int32(np.array(ndarray.shape) / 2)
    new_shape = tuple(int(d) for d in new_shape)
    operation = operation.lower()
    if operation not in ('sum', 'mean'):
        raise ValueError("Operation not supported.")
    if ndarray.ndim != len(new_shape):
        raise ValueError("Shape mismatch: {} -> {}".format(ndarray.shape, new_shape))
    for new, old in zip(new_shape, ndarray.shape):
        if new <= 0 or old % new:
            raise ValueError("Shape mismatch: {} -> {} (each new axis must evenly divide the old one)".format(ndarray.shape, new_shape))
    # Split every axis into (new, factor); e.g. (10, 10) -> (5, 2, 5, 2).
    split = [n for new, old in zip(new_shape, ndarray.shape) for n in (new, old // new)]
    binned = ndarray.reshape(split)
    # Collapse the factor axes from last to first (same order as before, so
    # floating-point means are unchanged).
    for axis in range(2 * len(new_shape) - 1, 0, -2):
        binned = getattr(binned, operation)(axis=axis)
    return binned
def show_image_slice(fn, sli=0):
    """
    Show image ``sli`` of an HDF5 scan file in a new matplotlib figure.

    Tries the 'img_tomo' dataset first, then 'img_xanes'. If neither can be
    read at index ``sli``, prints 'cannot display image'. The file is always
    closed.
    """
    with h5py.File(fn, 'r') as f:
        for key in ('img_tomo', 'img_xanes'):
            try:
                img = np.squeeze(np.array(f[key][sli]))
            except (KeyError, IndexError, ValueError, TypeError):
                # Missing dataset or out-of-range index: try the next one.
                continue
            plt.figure()
            plt.imshow(img)
            return
    print('cannot display image')
def transform(dataset):
    """
    tomviz operator entry point: reconstruct ``dataset`` and return it as a child dataset.

    The dataset is passed around instead of a file name; it carries ``dark``,
    ``white`` and ``tilt_angles`` attributes. The reconstruction uses a fixed
    rotation center of 650 over all rows, without binning. The result is
    returned under the key 'reconstruction_fxi'.
    """
    reconstruction = recon(
        dataset,
        rot_cen=650,
        sli=[],
        binning=None,
        zero_flag=0,
        block_list=[],
        txm_normed_flag=0,
        read_full_memory=0,
    )
    child = dataset.create_child_dataset()
    child.active_scalars = reconstruction
    return {'reconstruction_fxi': child}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment