# #########################################################################
# Copyright (c) 2022, UChicago Argonne, LLC. All rights reserved. #
# #
# Copyright 2022. UChicago Argonne, LLC. This software was produced #
# under U.S. Government contract DE-AC02-06CH11357 for Argonne National #
# Laboratory (ANL), which is operated by UChicago Argonne, LLC for the #
# U.S. Department of Energy. The U.S. Government has rights to use, #
# reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR #
# UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR #
# ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is #
# modified to produce derivative works, such modified software should #
# be clearly marked, so as not to confuse it with the version available #
# from ANL. #
# #
# Additionally, redistribution and use in source and binary forms, with #
# or without modification, are permitted provided that the following #
# conditions are met: #
# #
# * Redistributions of source code must retain the above copyright #
# notice, this list of conditions and the following disclaimer. #
# #
# * Redistributions in binary form must reproduce the above copyright #
# notice, this list of conditions and the following disclaimer in #
# the documentation and/or other materials provided with the #
# distribution. #
# #
# * Neither the name of UChicago Argonne, LLC, Argonne National #
# Laboratory, ANL, the U.S. Government, nor the names of its #
# contributors may be used to endorse or promote products derived #
# from this software without specific prior written permission. #
# #
# THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS #
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT #
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS #
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago #
# Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, #
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, #
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; #
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT #
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN #
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE #
# POSSIBILITY OF SUCH DAMAGE. #
# #########################################################################
import os
import h5py
import dxchange
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from tile import log
from tile import fileio
__all__ = ['stitching']
def write_meta(file_name, fid):
try: # trying to copy meta
import meta
mp = meta.read_meta.Hdf5MetadataReader(file_name)
meta_dict = mp.readMetadata()
mp.close()
with h5py.File(file_name, 'r') as f:
print(f" *** copy meta data from {file_name}")#
for key, value in meta_dict.items():
print(key, value)
if key.find('exchange') != 1:
dset = fid.create_dataset(
key, data=value[0], dtype=f[key].dtype, shape=(1,))
if value[1] is not None:
s = value[1]
utf8_type = h5py.string_dtype('utf-8', len(s)+1)
dset.attrs['units'] = np.array(
s.encode("utf-8"), dtype=utf8_type)
except:
log.error('write_meta() error: Skip copying meta')
pass
[docs]
def stitching(args):
"""Stitching projection tiles in horizontal direction"""
log.info('Run stitching')
# read files grid and retrieve data sizes
meta_dict, grid, data_shape, data_type, _, _ = fileio.tile(args)
data_type='float32'
# check if flip is needed for having tile[0,0] as the left one and at sample_x=0
sample_x = args.sample_x
if args.reverse_step=='True':
step = -1
else:
step = 1
x_shifts = np.fromstring(args.x_shifts[1:-1], sep=',', dtype='int')
log.info(f'Relative shifts {x_shifts}')
if args.end_proj == -1:
args.end_proj = data_shape[0]
# total size in x direction, multiple of 4 for faster ffts in reconstruction
size = int(np.ceil(
(data_shape[2]+np.sum(np.sum(x_shifts)))/16)*16)
tile_path = os.path.join(args.folder_name, 'tile')
if not os.path.exists(tile_path):
os.makedirs(tile_path)
tile_file_name = os.path.join(tile_path, args.tile_file_name)
theta = np.zeros(1,dtype='float32')
for itile in range(grid.shape[1]):
with h5py.File(grid[0, itile], 'r') as fid:
if(len(fid['/exchange/theta'][:])>len(theta)):
theta = fid['/exchange/theta'][:]
# load external flat field basis if provided
use_flats_basis = bool(args.flats_file)
if use_flats_basis:
with h5py.File(args.flats_file, 'r') as fb:
basis_flats = fb['/exchange/data_white'][:, :, ::step].astype('float32')
log.info(f'Loaded flat basis: {basis_flats.shape[0]} frames from {args.flats_file}')
# design matrix: mean horizontal profile of each basis frame, shape (H, n_basis)
basis_profs = np.mean(basis_flats, axis=2).T.astype('float64') # (H, n_basis)
# pre-compute per-tile flat/dark correction arrays (read once, reused for all chunks)
tile_dark = []
if args.flat_linear == 'True':
tile_flat_p0 = []
tile_flat_p1 = []
else:
tile_flat = []
for itile in range(grid.shape[1]):
if args.reverse_grid=='True':
iitile=grid.shape[1]-itile-1
else:
iitile=itile
with h5py.File(grid[0, ::-step][iitile],'r') as fidin:
flat = fidin['/exchange/data_white'][:]
dark = fidin['/exchange/data_dark'][:]
tile_dark.append(np.mean(dark[:, :, ::step], axis=0))
if args.flat_linear == 'True':
n = flat.shape[0]
tile_flat_p0.append(np.mean(flat[:n//2, :, ::step], axis=0))
tile_flat_p1.append(np.mean(flat[n//2:, :, ::step], axis=0))
else:
tile_flat.append(np.mean(flat[:, :, ::step], axis=0))
os.system(f'rm -rf {tile_file_name}')
with h5py.File(tile_file_name, 'w') as fid:
# flat/dark correction applied per tile before stitching; store 1 and 0 as placeholders
data_all = fid.create_dataset('/exchange/data', (args.end_proj-args.start_proj,
data_shape[1], size), dtype=data_type, chunks=(1, data_shape[1], size))
fid.create_dataset('/exchange/data_white',
data=np.ones((1, data_shape[1], size), dtype=data_type))
fid.create_dataset('/exchange/data_dark',
data=np.zeros((1, data_shape[1], size), dtype=data_type))
fid.create_dataset('/exchange/theta', data=theta[args.start_proj:args.end_proj])
write_meta(grid[0, itile],fid)
def process_chunk(ichunk):
st_chunk = args.start_proj + ichunk * args.nproj_per_chunk
end_chunk = min(st_chunk + args.nproj_per_chunk, args.end_proj)
chunk_len = end_chunk - st_chunk
log.info(f'Stitching projections {st_chunk} - {end_chunk}')
chunk_buf = np.zeros((chunk_len, data_shape[1], size), dtype=data_type)
ref_overlap_mean = None
for itile in range(grid.shape[1]):
if args.reverse_grid == 'True':
iitile = grid.shape[1] - itile - 1
else:
iitile = itile
with h5py.File(grid[0, ::-step][iitile], 'r') as fidin:
uids = fidin['/defaults/NDArrayUniqueId'][:]
hdf_location = fidin['/defaults/HDF5FrameLocation']
proj_ids = uids[hdf_location[:] == b'/exchange/data'] - 1
proj_ids = proj_ids[(proj_ids >= st_chunk) * (proj_ids < end_chunk)]
if len(proj_ids) != end_chunk - st_chunk:
log.warning('There are missing projection in the current tile, setting them to 0')
data = fidin['/exchange/data'][proj_ids]
st = np.sum(x_shifts[:itile + 1])
end = min(st + data_shape[2], size)
vv = np.ones(data_shape[2])
if itile < grid.shape[1] - 1:
v = np.linspace(1, 0, data_shape[2] - x_shifts[itile + 1], endpoint=False)
v = v**5 * (126 - 420*v + 540*v**2 - 315*v**3 + 70*v**4)
vv[x_shifts[itile + 1]:] = v
if itile > 0:
v = np.linspace(1, 0, data_shape[2] - x_shifts[itile], endpoint=False)
v = v**5 * (126 - 420*v + 540*v**2 - 315*v**3 + 70*v**4)
vv[:data_shape[2] - x_shifts[itile]] = 1 - v
# correct each tile before stitching using that tile's flat/dark
data_f = data[:, :, ::step].copy()
dark_mean = tile_dark[itile]
if use_flats_basis:
from scipy.optimize import nnls
for li in range(len(proj_ids)):
proj_prof = np.mean(data_f[li], axis=1).astype('float64')
w, _ = nnls(basis_profs, proj_prof)
#log.info(f'proj {proj_ids[li]:4d} tile {itile} coeffs: {np.round(w, 4).tolist()}')
flat_i = np.einsum('k,khw->hw', w, basis_flats)
data_f[li] = (data_f[li] - dark_mean) / (flat_i - dark_mean+1e-3)
elif args.flat_linear == 'True':
for li, gi in enumerate(proj_ids):
t = gi / max(data_shape[0] - 1, 1)
flat_i = (1 - t) * tile_flat_p0[itile] + t * tile_flat_p1[itile]
data_f[li] = (data_f[li] - dark_mean) / (flat_i - dark_mean+ 1e-3)
else:
data_f = (data_f - dark_mean) / (tile_flat[itile] - dark_mean+ 1e-3)
np.nan_to_num(data_f, nan=1.0, posinf=1.0, neginf=1.0, copy=False)
if args.zinger_level > 0:
from scipy.ndimage import median_filter
kernel = (min(5, data_f.shape[0]), 1, 1)
med = median_filter(data_f, size=kernel)
mask = data_f > med * (1 + args.zinger_level)
data_f[mask] = med[mask]
# intensity scale calibration using overlap with previous tile (per projection)
if itile > 0 and ref_overlap_mean is not None:
overlap_cols = data_shape[2] - x_shifts[itile]
cur_means = np.mean(data_f[:, :, :overlap_cols], axis=(1, 2)) # (n_proj,)
ref = ref_overlap_mean[proj_ids - st_chunk]
valid = cur_means > 1e-6
scales = np.where(valid, ref / np.where(valid, cur_means, 1.0), 1.0)
data_f *= scales[:, np.newaxis, np.newaxis]
if itile < grid.shape[1] - 1:
ref_overlap_mean = np.zeros(chunk_len, dtype='float64')
ref_overlap_mean[proj_ids - st_chunk] = np.mean(
data_f[:, :, x_shifts[itile + 1]:], axis=(1, 2))
chunk_buf[proj_ids - st_chunk, :, st:end] += data_f[:, :, :end - st] * vv[:end - st]
if itile == grid.shape[1] - 1:
chunk_buf[:, :, end:] = np.tile(chunk_buf[:, :, end - 1:end], (1, 1, size - end))
np.nan_to_num(chunk_buf, nan=1.0, posinf=1.0, neginf=1.0, copy=False)
return st_chunk, end_chunk, chunk_buf
n_chunks = int(np.ceil((args.end_proj - args.start_proj) / args.nproj_per_chunk))
pending = {}
next_write = args.start_proj
with ThreadPoolExecutor(max_workers=args.max_workers) as pool:
futures = {pool.submit(process_chunk, i): i for i in range(n_chunks)}
for fut in as_completed(futures):
st_chunk, end_chunk, chunk_buf = fut.result()
pending[st_chunk] = (end_chunk, chunk_buf)
while next_write in pending:
ep, buf = pending.pop(next_write)
data_all[next_write - args.start_proj:ep - args.start_proj] = buf
next_write = ep
log.info(f'Output file {tile_file_name}')
log.info(f'Reconstruct {tile_file_name} with tomocupy:')
log.info(f'tomocupy recon --file-name {tile_file_name} --rotation-axis <found rotation axis> --reconstruction-type full --file-type double_fov --remove-stripe-method fw --binning <select binning> --nsino-per-chunk 2 ')
log.info(f'Reconstruct {tile_file_name} with tomopy:')
log.info(f'tomopy recon --file-name {tile_file_name} --rotation-axis <found rotation axis> --reconstruction-type full --file-type double_fov --remove-stripe-method fw --binning <select binning> --nsino-per-chunk 8 --rotation-axis-auto manual')