Created February 18, 2020 14:38
-
-
Save cryos/4e10febd17ef6f8f1b4ab1a93e10a6d7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| try: | |
| import numpy as np | |
| import tomopy | |
| import h5py | |
| import matplotlib.pylab as plt | |
| except ImportError: | |
| pass | |
def find_nearest(data, value):
    """Return the index of the element of ``data`` closest to ``value``.

    ``data`` may be any array-like.  Ties resolve to the first matching
    index, following ``numpy.argmin``.
    """
    distance = np.abs(np.asarray(data) - value)
    return distance.argmin()
class IndexTracker(object):
    """Scroll-wheel viewer that shows one 2-D frame of a 3-D stack.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes the frames are drawn into.
    X : ndarray
        3-D stack indexed as ``X[frame, row, col]``; scrolling steps along
        the first axis.  A non-3-D input raises ``ValueError``.
    """

    def __init__(self, ax, X):
        self.ax = ax
        # Title artist reused to show "frame i of N".
        self._indx_txt = ax.set_title(' ', loc='center')
        self.X = X
        # Unpacking (rather than X.shape[0]) keeps the 3-D shape check.
        self.slices, _, _ = X.shape
        # Start in the middle of the stack.
        self.ind = self.slices // 2
        self.im = ax.imshow(self.X[self.ind, :, :], cmap='gray')
        self.update()

    def onscroll(self, event):
        """Step one frame forward on scroll-up, backward otherwise (wrapping)."""
        step = 1 if event.button == 'up' else -1
        self.ind = (self.ind + step) % self.slices
        self.update()

    def update(self):
        """Show frame ``self.ind`` and refresh the title."""
        self.im.set_data(self.X[self.ind, :, :])
        self._indx_txt.set_text(f"frame {self.ind + 1} of {self.slices}")
        # draw_idle schedules a redraw instead of blocking on a full render
        # for every scroll tick.
        self.im.axes.figure.canvas.draw_idle()
def image_scrubber(data, *, ax=None):
    """Show a 3-D stack as an image whose frame changes with the scroll wheel.

    Parameters
    ----------
    data : ndarray
        Stack indexed as ``data[frame, row, col]``.
    ax : matplotlib Axes, optional
        Axes to draw into; a new figure and axes are created when omitted.

    Returns
    -------
    IndexTracker
        The tracker driving the display.
    """
    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = plt.subplots()
    tracker = IndexTracker(ax, data)
    # Attach the tracker to the figure so it stays alive as long as the
    # figure does.
    fig._tracker = tracker
    fig.canvas.mpl_connect('scroll_event', tracker.onscroll)
    return tracker
def find_rot(fn, thresh=0.05):
    """Estimate the rotation center from the 0 and ~180 degree projections.

    The projection whose angle is closest to ``angle[0] + 180`` is mirrored
    left-right and registered against the first projection.  Half of the
    column shift, added to the detector center, gives the rotation axis.

    Parameters
    ----------
    fn : str
        HDF5 scan file containing ``img_tomo``, ``img_bkg_avg`` and
        ``angle`` (degrees).
    thresh : float
        Attenuation values (``-log`` of the flat-corrected transmission)
        below this are set to 0 before registration.

    Returns
    -------
    float
        Rotation center in pixels from the ``dftregistration`` estimate.
        The ``pystackreg`` estimate is printed for comparison only.
    """
    from pystackreg import StackReg
    from scipy.signal import medfilt2d

    sr = StackReg(StackReg.TRANSLATION)
    with h5py.File(fn, 'r') as f:
        img_bkg = np.squeeze(np.array(f['img_bkg_avg']))
        ang = np.array(f['angle'])
        # index of the projection closest to the first angle + 180 degrees
        idx180 = np.abs(ang - ang[0] - 180).argmin()
        img0 = np.array(f['img_tomo'][0])
        img180_raw = np.array(f['img_tomo'][idx180])

    def _attenuation(img):
        # -log transmission.  Non-finite values (from zero or negative
        # transmission) and values below `thresh` are zeroed, then a 5x5
        # median filter suppresses isolated outliers.
        att = -np.log(img)
        att[~np.isfinite(att)] = 0
        att[att < thresh] = 0
        return medfilt2d(att, 5)

    img0 = img0 / img_bkg
    # Mirror the 180-degree view left-right so it overlays the 0-degree view.
    img180 = (img180_raw / img_bkg)[:, ::-1]
    n_cols = img0.shape[1]
    im1 = _attenuation(img0)
    im2 = _attenuation(img180)

    # NOTE(review): dftregistration is neither defined nor imported in this
    # file; it must be available in the module namespace at call time.
    results = dftregistration(np.fft.fft2(im1), np.fft.fft2(im2))
    col_shift = results[3]
    rot_cen = n_cols / 2 + col_shift / 2 - 1

    # Cross-check with pystackreg's translation (column shift is -tmat[0, 2]).
    tmat = sr.register(im1, im2)
    rot_cen0 = n_cols / 2 - tmat[0, 2] / 2 - 1
    print(f'rot_cen = {rot_cen} or {rot_cen0}')
    return rot_cen
def rotcen_test(fn, start=None, stop=None, steps=None, sli=0, block_list=(), return_flag=0, print_flag=1, bkg_level=0, txm_normed_flag=0, denoise_flag=0, denoise_level=9):
    """Reconstruct one detector row of a scan at a series of trial rotation centers.

    Parameters
    ----------
    fn : str
        HDF5 scan file with ``img_tomo``, ``angle`` (degrees) and, unless
        ``txm_normed_flag`` is set, ``img_bkg_avg`` / ``img_dark_avg``.
    start, stop, steps : optional
        Trial centers are ``np.linspace(start, stop, steps)``.  If any of
        them is None, 26 centers spanning +/-50 pixels around the detector
        middle are used.
    sli : int
        Detector row to reconstruct; 0 selects the middle row.
    block_list : iterable of int
        Projection indices to exclude.  Projections past the one closest to
        ``angle[0] + 180`` are always excluded.
    return_flag : bool
        Return ``(img, cen)`` when true; otherwise return None.
    print_flag : bool
        Print each trial center.
    bkg_level : float
        Constant subtracted from the attenuation data.
    txm_normed_flag : bool
        ``img_tomo`` is already normalized; skip dark/flat correction.
    denoise_flag : bool
        Wiener-filter each projection.  Up to 50 extra rows on each side of
        ``sli`` are loaded for this.
    denoise_level : int
        ``level`` passed to ``tomopy.prep.stripe.remove_stripe_fw``.

    The trial reconstructions and centers are written to ``center_test.h5``
    in the working directory.  They are then circular-masked and shown with
    a scroll-wheel viewer.
    """
    import tomopy

    if denoise_flag:
        import skimage.restoration as skr
        # Extra rows loaded around `sli` so the Wiener filter has context.
        addition_slice = 100
        psf = np.ones([2, 2]) / 4
    else:
        addition_slice = 0

    with h5py.File(fn, 'r') as f:
        n_rows = f['img_tomo'].shape[1]
        if sli == 0:
            sli = n_rows // 2
        row_lo = max(0, sli - addition_slice // 2)
        row_hi = min(sli + addition_slice // 2 + 1, n_rows)
        theta = np.array(f['angle']) / 180.0 * np.pi
        img_tomo = np.array(f['img_tomo'][:, row_lo:row_hi, :])
        if txm_normed_flag:
            prj = img_tomo
        else:
            # Read the dark as float so raw integer counts cannot wrap
            # around in the subtractions below.
            img_dark = np.array(f['img_dark_avg'][:, row_lo:row_hi, :], dtype=np.float64)
            img_bkg = np.array(f['img_bkg_avg'][:, row_lo:row_hi, :])
            prj = (img_tomo - img_dark) / (img_bkg - img_dark)

    prj_norm = -np.log(prj)
    prj_norm[~np.isfinite(prj_norm)] = 0
    prj_norm[prj_norm < 0] = 0
    prj_norm -= bkg_level
    prj_norm = tomopy.prep.stripe.remove_stripe_fw(prj_norm, level=denoise_level, wname='db5', sigma=1, pad=True)
    if denoise_flag:  # denoise using wiener filter
        for i in range(prj_norm.shape[0]):
            prj_norm[i] = skr.wiener(prj_norm[i], psf=psf, reg=None, balance=0.3, is_real=True, clip=True)
    if prj_norm.ndim == 2:
        prj_norm = prj_norm[:, np.newaxis, :]
    n_cols = prj_norm.shape[2]

    pos = find_nearest(theta, theta[0] + np.pi)
    block_list = list(block_list) + list(range(pos + 1, len(theta)))
    if block_list:
        allow_list = sorted(set(range(len(prj_norm))) - set(block_list))
        prj_norm = prj_norm[allow_list]
        theta = theta[allow_list]

    if start is None or stop is None or steps is None:
        start = int(n_cols / 2 - 50)
        stop = int(n_cols / 2 + 50)
        steps = 26
    cen = np.linspace(start, stop, steps)
    img = np.zeros([len(cen), n_cols, n_cols])
    # Position of the requested row inside the loaded slab.  The slab starts
    # at row_lo, which may be clipped at the detector edge.
    row = sli - row_lo
    for i, center in enumerate(cen):
        if print_flag:
            print('{}: rotcen {}'.format(i + 1, center))
        img[i] = tomopy.recon(prj_norm[:, row:row + 1], theta, center=center, algorithm='gridrec')

    fout = 'center_test.h5'
    with h5py.File(fout, 'w') as hf:
        hf.create_dataset('img', data=img)
        hf.create_dataset('rot_cen', data=cen)
    img = tomopy.circ_mask(img, axis=0, ratio=0.8)
    image_scrubber(img)
    if return_flag:
        return img, cen
def img_variance(img):
    """Return the spread of the positive pixels of each image in a stack.

    Each image is circular-masked (``tomopy.circ_mask``, ratio 0.8) and
    5x5 median filtered.  The sample standard deviation (``N - 1`` divisor)
    of the strictly positive pixels that remain is then reported.  Images
    with fewer than two positive pixels yield ``nan``.

    Parameters
    ----------
    img : ndarray
        Stack indexed as ``img[i, row, col]``.

    Returns
    -------
    ndarray of shape ``(img.shape[0],)``
    """
    import tomopy
    from scipy.signal import medfilt2d

    n_images = img.shape[0]
    variance = np.zeros(n_images)
    masked = tomopy.circ_mask(img, axis=0, ratio=0.8)
    for i in range(n_images):
        filtered = medfilt2d(masked[i], 5)
        values = filtered[filtered > 0]
        # Bessel-corrected (N - 1) standard deviation; undefined for N < 2.
        variance[i] = np.std(values, ddof=1) if values.size > 1 else np.nan
    return variance
def _resolve_slice_range(sli, n_rows):
    """Turn recon's ``sli`` argument into ``([start, stop], file-name suffix)``."""
    if sli is None or len(sli) == 0:
        return [0, n_rows], ''
    if len(sli) == 1 and 0 <= sli[0] < n_rows:
        return [sli[0], sli[0] + 1], '_slice_{}'.format(sli[0])
    if len(sli) == 2 and 0 <= sli[0] < sli[1] <= n_rows:
        return [sli[0], sli[1]], '_slice_{}_{}'.format(sli[0], sli[1])
    print('non valid slice id, will take reconstruction for the whole object')
    return [0, n_rows], ''


def _recon_chunks(first, last, sli_step, add_slice):
    """Split rows ``[first, last)`` into reconstruction chunks.

    Returns a list of ``(sli_sub, current_sli, padded)`` tuples:
    - ``sli_sub`` is the row range the chunk reconstructs;
    - ``current_sli`` is the row range loaded for it;
    - ``padded`` is True for interior chunks, whose loaded range extends
      ``add_slice`` rows past ``sli_sub`` on both sides.
    Chunks are ``sli_step`` rows long, and the last one also takes the
    remainder.  At least one chunk is always returned.
    """
    sli_step = max(1, sli_step)
    n_steps = max(1, (last - first) // sli_step)
    chunks = []
    for i in range(n_steps):
        lo = first + i * sli_step
        hi = last if i == n_steps - 1 else lo + sli_step
        padded = 0 < i < n_steps - 1
        current = [lo - add_slice, hi + add_slice] if padded else [lo, hi]
        chunks.append(([lo, hi], current, padded))
    return chunks


def recon(dataset, rot_cen, sli=None, binning=None, zero_flag=0, block_list=None, bkg_level=0, txm_normed_flag=0, read_full_memory=0, denoise_flag=0, denoise_level=9):
    '''
    reconstruct 3D tomography from a tomviz dataset

    Inputs:
    --------
    dataset: tomviz dataset
        provides active_scalars (the projections), tilt_angles (degrees)
        and optional dark / white reference images (see proj_normalize)
    rot_cen: float
        rotation center in unbinned pixels
    sli: list
        rows to reconstruct: [] or None for all rows, [k] for row k,
        [start, stop] for rows start .. stop-1; anything else falls back to
        all rows
    binning: int
        binning factor for the reconstruction (None means 1)
    zero_flag: bool
        if 1: set negative pixel value to 0
        if 0: keep negative pixel value
    block_list: list
        indices of projections that will not be used in the reconstruction
    bkg_level, txm_normed_flag, denoise_level:
        forwarded to proj_normalize
    read_full_memory: bool
        if 1: reconstruct all rows at once instead of in 40-row chunks
    denoise_flag: bool
        if 1: Wiener-filter the projections before reconstruction

    Returns:
    --------
    rec: float32 ndarray, one reconstructed slice per (binned) row; it is
        also written to /tmp/recon_scan_<scan_id><slice_info><bin_info>.h5
    '''
    print("active_scalars.shape: ", dataset.active_scalars.shape)
    order = [2, 1, 0]
    # One transposed projection, used only to learn the detector size.
    tmp = np.transpose(dataset.active_scalars, order)[0]
    print("tmp.shape: ", tmp.shape)
    n_rows, n_cols = tmp.shape
    sli, slice_info = _resolve_slice_range(sli, n_rows)

    scan_id = 'tomviz-kitware-result'
    # Hard-coded stand-in for the X_eng value the file-based version read.
    eng = '42'

    theta = np.array(dataset.tilt_angles) / 180.0 * np.pi
    # Drop every projection past the one closest to first angle + 180 deg.
    pos = find_nearest(theta, theta[0] + np.pi)
    blocked = set() if block_list is None else set(block_list)
    blocked.update(range(pos + 1, len(theta)))
    allow_list = sorted(set(range(len(theta))) - blocked)
    theta = theta[allow_list]

    binning = binning if binning else 1
    bin_info = f'_bin{int(binning)}'
    rot_cen = rot_cen * 1.0 / binning

    sli_step = (sli[1] - sli[0]) if read_full_memory else 40
    if denoise_flag:
        add_slice = min(sli_step // 2, 20)
        psf = 2
        wiener_param = {
            'psf': np.ones([psf, psf]) / (psf**2),
            'reg': None,
            'balance': 0.3,
            'is_real': True,
            'clip': True,
        }
    else:
        add_slice = 0
        wiener_param = []

    chunks = _recon_chunks(sli[0], sli[1], sli_step, add_slice)
    n_steps = len(chunks)
    # Each chunk is binned on its own, so the output holds the sum of the
    # per-chunk binned row counts.
    n_out = sum((sub[1] - sub[0]) // binning for sub, _, _ in chunks)
    try:
        rec = np.zeros([n_out, n_cols // binning, n_cols // binning], dtype=np.float32)
    except MemoryError:
        print('Cannot allocate memory')
        raise
    print("rec.shape: ", rec.shape)

    offset = 0
    for i, (sli_sub, current_sli, padded) in enumerate(chunks):
        print(f'recon {i+1}/{n_steps}: sli = [{sli_sub[0]}, {sli_sub[1]}] ... ')
        prj_norm = proj_normalize(dataset, current_sli, txm_normed_flag, binning, allow_list, bkg_level, denoise_level)
        prj_norm = wiener_denoise(prj_norm, wiener_param, denoise_flag)
        if padded:
            # Drop the rows that were only loaded as context for the filter.
            lo = add_slice // binning
            prj_norm = prj_norm[:, lo:lo + (sli_sub[1] - sli_sub[0]) // binning]
        rec_sub = tomopy.recon(prj_norm, theta, center=rot_cen, algorithm='gridrec')
        print("rec_sub.shape: ", rec_sub.shape)
        rec[offset:offset + rec_sub.shape[0]] = rec_sub
        offset += rec_sub.shape[0]

    fout = f'recon_scan_{str(scan_id)}{str(slice_info)}{str(bin_info)}'
    if zero_flag:
        rec[rec < 0] = 0
    fout_h5 = f'/tmp/{fout}.h5'
    with h5py.File(fout_h5, 'w') as hf:
        hf.create_dataset('img', data=np.array(rec, dtype=np.float32))
        hf.create_dataset('scan_id', data=scan_id)
        hf.create_dataset('X_eng', data=eng)
        hf.create_dataset('rot_cen', data=rot_cen)
        hf.create_dataset('binning', data=binning)
    print(f'{fout} is saved.')
    return rec
def wiener_denoise(prj_norm, wiener_param, denoise_flag):
    """Wiener-filter every projection of ``prj_norm`` in place.

    Parameters
    ----------
    prj_norm : ndarray
        Projection stack; the filter runs on each ``prj_norm[j]``.
    wiener_param : mapping
        Keys ``psf``, ``reg``, ``balance``, ``is_real`` and ``clip`` are
        passed to ``skimage.restoration.wiener``.  An empty container
        disables filtering.
    denoise_flag : bool
        Filtering is skipped when false.

    Returns
    -------
    ndarray
        ``prj_norm`` itself, filtered in place if filtering ran.
    """
    if not denoise_flag or not len(wiener_param):
        return prj_norm
    # Imported lazily so callers that never denoise do not need scikit-image.
    import skimage.restoration as skr

    options = {key: wiener_param[key] for key in ('psf', 'reg', 'balance', 'is_real', 'clip')}
    for j in range(prj_norm.shape[0]):
        prj_norm[j] = skr.wiener(prj_norm[j], **options)
    return prj_norm
def proj_normalize(dataset, sli, txm_normed_flag, binning, allow_list=None, bkg_level=0, denoise_level=9):
    """Load, normalize, bin and clean the projections for detector rows ``sli``.

    Parameters
    ----------
    dataset : tomviz dataset
        ``active_scalars`` holds the projections.  Optional ``dark`` /
        ``white`` reference images are used for flat-field correction.  Each
        array is transposed with axes ``[2, 1, 0]``, and the second axis of
        the result is sliced by ``sli``.
    sli : [start, stop]
        Detector rows to load.
    txm_normed_flag : bool
        When 1, the projections are used as-is (no dark/white correction).
    binning : int
        Rows and columns are averaged in ``binning`` x ``binning`` blocks.
    allow_list : sequence of int, optional
        Projection indices to keep; None or empty keeps every projection.
    bkg_level : float
        Constant subtracted after stripe removal.
    denoise_level : int
        ``level`` passed to ``tomopy.prep.stripe.remove_stripe_fw``.

    Returns
    -------
    ndarray
        ``-log`` of the (corrected) transmission, with non-finite and
        negative values set to 0, stripe-filtered, minus ``bkg_level``.
    """
    order = [2, 1, 0]

    def _reference_rows(ref):
        # Missing or malformed reference images fall back to [] (no correction).
        if ref is None:
            return []
        try:
            return np.array(np.transpose(ref, order)[:, sli[0]:sli[1]])
        except (TypeError, ValueError, IndexError):
            return []

    tmp = np.transpose(dataset.active_scalars, order)
    print("active_scalars.shape: ", dataset.active_scalars.shape)
    print("tmp.shape: ", tmp.shape)
    img_tomo = np.array(tmp[:, sli[0]:sli[1], :])
    print("img_tomo.shape: ", img_tomo.shape)

    dark = getattr(dataset, 'dark', None)
    print("dark.shape: ", getattr(dark, 'shape', None))
    img_bkg = _reference_rows(getattr(dataset, 'white', None))
    img_dark = _reference_rows(dark)

    if len(img_dark) == 0 or len(img_bkg) == 0 or txm_normed_flag == 1:
        prj = img_tomo
    else:
        # Promote integer counts to floating point so the dark subtraction
        # cannot wrap around; float inputs keep their precision.
        work = np.result_type(img_tomo, img_dark, img_bkg, np.float32)
        img_dark = img_dark.astype(work)
        prj = (img_tomo.astype(work) - img_dark) / (img_bkg.astype(work) - img_dark)

    s = prj.shape
    prj = bin_ndarray(prj, (s[0], int(s[1] / binning), int(s[2] / binning)), 'mean')
    prj_norm = -np.log(prj)
    prj_norm[~np.isfinite(prj_norm)] = 0
    prj_norm[prj_norm < 0] = 0
    if allow_list is not None and len(allow_list):
        prj_norm = prj_norm[allow_list]
    prj_norm = tomopy.prep.stripe.remove_stripe_fw(prj_norm, level=denoise_level, wname='db5', sigma=1, pad=True)
    prj_norm -= bkg_level
    return prj_norm
def bin_ndarray(ndarray, new_shape=None, operation='mean'):
    """
    Bins an ndarray in all axes based on the target shape, by summing or
    averaging.

    Number of output dimensions must match number of input dimensions and
    new axes must divide old ones.  When ``new_shape`` is None, every axis
    is halved.  ``operation`` is case-insensitive: 'sum' or 'mean'.

    Raises
    ------
    ValueError
        Unsupported operation, dimension mismatch, or an output axis that
        is not a positive divisor of the corresponding input axis.

    Example
    -------
    >>> m = np.arange(0,100,1).reshape((10,10))
    >>> n = bin_ndarray(m, new_shape=(5,5), operation='sum')
    >>> print(n)
    [[ 22  30  38  46  54]
     [102 110 118 126 134]
     [182 190 198 206 214]
     [262 270 278 286 294]
     [342 350 358 366 374]]
    """
    if new_shape is None:
        new_shape = tuple(dim // 2 for dim in ndarray.shape)
    operation = operation.lower()
    if operation not in ('sum', 'mean'):
        raise ValueError("Operation not supported.")
    if ndarray.ndim != len(new_shape):
        raise ValueError("Shape mismatch: {} -> {}".format(ndarray.shape,
                                                           new_shape))
    for new_dim, old_dim in zip(new_shape, ndarray.shape):
        if new_dim <= 0 or old_dim % new_dim:
            raise ValueError("Axis of length {} cannot be binned to {}".format(old_dim, new_dim))
    # Split each axis into (bins, bin_size) pairs ...
    compression_pairs = [(d, c // d) for d, c in zip(new_shape, ndarray.shape)]
    flattened = [n for pair in compression_pairs for n in pair]
    ndarray = ndarray.reshape(flattened)
    # ... then reduce the bin_size axes one at a time, from the last inward.
    for i in range(len(new_shape)):
        ndarray = getattr(ndarray, operation)(-1 * (i + 1))
    return ndarray
def show_image_slice(fn, sli=0):
    """Plot image ``sli`` of a scan file in a new figure.

    The image is read from the ``img_tomo`` dataset, falling back to
    ``img_xanes``.  If neither dataset can supply index ``sli``, a message
    is printed instead.
    """
    with h5py.File(fn, 'r') as f:
        for key in ('img_tomo', 'img_xanes'):
            try:
                img = np.squeeze(np.array(f[key][sli]))
            except (KeyError, IndexError, ValueError):
                # dataset missing or index out of range: try the next one
                continue
            plt.figure()
            plt.imshow(img)
            return
    print('cannot display image')
def transform(dataset, rot_cen=650):
    """tomviz entry point: reconstruct ``dataset`` and return it as a child dataset.

    The dataset is passed around instead of a file name; it carries
    ``dark`` and ``white`` attributes along with ``tilt_angles``.

    Parameters
    ----------
    dataset : tomviz dataset
    rot_cen : float
        Rotation center in detector pixels.  The default of 650 is the value
        this operator previously hard-coded.

    Returns
    -------
    dict
        ``{'reconstruction_fxi': child}``, where ``child`` holds the
        reconstructed volume as its active scalars.
    """
    result = recon(dataset, rot_cen=rot_cen, sli=[], binning=None, zero_flag=0, block_list=[], txm_normed_flag=0, read_full_memory=0)
    child = dataset.create_child_dataset()
    child.active_scalars = result
    return {'reconstruction_fxi': child}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.