import argparse
import glob
import logging
import time
from os import makedirs as os_makedirs
import warnings
import numpy as np
from qsonic import QsonicException
import qsonic.catalog
import qsonic.io
from qsonic.masks import BALMask
from qsonic.mpi_utils import mpi_parse, mpi_fnc_bcast, MPISaver
from qsonic.picca_continuum import VarLSSFitter
from qsonic.spectrum import add_wave_region_parser
[docs]
def get_parser(add_help=True):
"""Constructs the parser needed for the script.
Arguments
---------
add_help: bool, default: True
Add help to parser.
Returns
-------
parser: argparse.ArgumentParser
"""
parser = argparse.ArgumentParser(
add_help=add_help,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
epilog='Note: Not every option is used in this script.')
iogroup = parser.add_argument_group(
'Input/output parameters and selections')
iogroup.add_argument(
"--input-dir", '-i', required=True,
help="Input directory.")
iogroup.add_argument(
"--outdir", '-o', required=True,
help="Output directory to save files.")
iogroup.add_argument("--fbase", help="Basename", default="qsonic-eta-fits")
iogroup.add_argument(
"--catalog",
help="Catalog filename to enable catalog related removals.")
iogroup.add_argument(
"--mock-analysis", action="store_true",
help="Input folder is mock. Uses nside=16")
iogroup.add_argument(
"--keep-surveys", nargs='*',
help="Surveys to keep. Empty keeps all.")
iogroup.add_argument(
"--remove-bal-qsos", action="store_true",
help="Removes BAL sightlines in the catalog option.")
iogroup.add_argument(
"--remove-targetid-list",
help="Text file with TARGETIDs to exclude from analysis.")
vargroup = parser.add_argument_group(
'Variance fitting parameters')
vargroup.add_argument(
"--nvarbins", help="Number of variance bins (logarithmically spaced).",
default=100, type=int)
vargroup.add_argument(
"--var-use-cov", action="store_true",
help="Use covariance in varlss-eta fitting.")
vargroup.add_argument(
"--nwbins", default=None, type=int,
help="Number of wavelength bins. None creates bins with 120 A spacing")
vargroup.add_argument(
"--var1", help="Lower variance bin.", default=1e-4, type=float)
vargroup.add_argument(
"--var2", help="Upper variance bin.", default=20., type=float)
vargroup.add_argument(
"--min-snr", help="Minimum SNR of the forest.",
default=0, type=float)
vargroup.add_argument(
"--max-snr", help="Maximum SNR of the forest.",
default=100, type=float)
parser = add_wave_region_parser(parser)
return parser
[docs]
def mpi_set_targetid_list_to_remove(args, comm=None, mpi_rank=0):
""" Return a ndarray of TARGETIDs to remove from the sample.
Can be used without MPI by passing ``comm=None`` (which is the default.)
Arguments
---------
args: argparse.Namespace
Options passed to script.
comm: MPI.COMM_WORLD or None, default: None
Communication object broadcast data.
mpi_rank: int, default: 0
Rank of the MPI process.
Returns
-------
ids_to_remove: :external+numpy:py:class:`ndarray <numpy.ndarray>`
TARGETIDs to remove from the sample.
Raises
------
QsonicException
If error occurs while reading ``args.remove_targetid_list`` or if
``--remove_bal_qsos`` is passed but the input catalog is missing BAL
columns .
"""
ids_to_remove = np.array([], dtype=int)
if args.remove_targetid_list:
ids_to_remove = mpi_fnc_bcast(
np.loadtxt, comm, mpi_rank,
"Error while reading remove_targetid_list.",
args.remove_targetid_list, dtype=int)
if args.catalog:
catalog = qsonic.catalog.mpi_read_quasar_catalog(
args.catalog, comm, mpi_rank, args.mock_analysis)
else:
catalog = None
if catalog is not None and args.remove_bal_qsos:
logging.info("Checking BAL mask.")
BALMask.check_catalog(catalog)
sel_ai = (catalog['VMIN_CIV_450'] > 0) & (catalog['VMAX_CIV_450'] > 0)
sel_bi = (
(catalog['VMIN_CIV_2000'] > 0) & (catalog['VMAX_CIV_2000'] > 0))
sel_bal = np.any(sel_ai, axis=1) | np.any(sel_bi, axis=1)
bal_targetids = catalog['TARGETID'][sel_bal]
ids_to_remove = np.concatenate((ids_to_remove, bal_targetids))
logging.info(f"Removing {bal_targetids.size} BAL sightlines")
if catalog is not None and args.keep_surveys:
sel_survey = np.isin(catalog['SURVEY'], args.keep_surveys)
remove_sur_tids = catalog['TARGETID'][~sel_survey]
logging.info(f"Removing {remove_sur_tids.size} non survey sightlines")
ids_to_remove = np.concatenate((ids_to_remove, remove_sur_tids))
return ids_to_remove
[docs]
def mpi_read_all_deltas(args, comm=None, mpi_rank=0, mpi_size=1):
start_time = time.time()
logging.info("Reading deltas.")
all_delta_files = mpi_fnc_bcast(
glob.glob, comm, mpi_rank,
f"Delta files are not found in {args.input_dir}.",
f"{args.input_dir}/delta-*.fits*")
ndelta_all = len(all_delta_files)
logging.info(f"There are {ndelta_all} delta files.")
if mpi_size > ndelta_all:
warnings.warn(
"There are more MPI processes then number of delta files.")
nfiles_per_rank = max(1, ndelta_all // mpi_size)
i1 = nfiles_per_rank * mpi_rank
i2 = min(ndelta_all, i1 + nfiles_per_rank)
files_this_rank = all_delta_files[i1:i2]
deltas_list = [qsonic.io.read_deltas(fname) for fname in files_this_rank]
etime = (time.time() - start_time) / 60 # min
logging.info(f"Rank{mpi_rank} read {i2-i1} deltas in {etime:.1f} mins.")
return deltas_list
[docs]
def mpi_stack_fluxes(args, comm, deltas_list):
dwave = deltas_list[0].header['DELTA_LAMBDA']
nwaveobs = int((args.wave2 - args.wave1) / dwave) + 1
waveobs = np.linspace(args.wave1, args.wave2, nwaveobs)
stacked_flux = np.zeros(nwaveobs)
weights = np.zeros(nwaveobs)
for delta in deltas_list:
flux = 1 + delta.delta
idx = np.round((delta.wave - args.wave1) / dwave).astype(int)
w = (idx >= 0) & (idx < nwaveobs)
stacked_flux[idx[w]] += flux[w] * delta.weight[w]
weights[idx[w]] += delta.weight[w]
# Save stacked_flux to buffer, then used stacked_flux to store reduced
# weights. Place them properly in the end.
buf = np.zeros(nwaveobs)
comm.Allreduce(stacked_flux, buf)
stacked_flux *= 0
comm.Allreduce(weights, stacked_flux)
weights = stacked_flux
stacked_flux = buf
w = weights > 0
stacked_flux[w] /= weights[w]
stacked_flux[~w] = 0
return waveobs, stacked_flux
[docs]
def mpi_run_all(comm, mpi_rank, mpi_size):
args = mpi_parse(get_parser(), comm, mpi_rank)
if mpi_rank == 0:
os_makedirs(args.outdir, exist_ok=True)
varfitter = VarLSSFitter(
args.wave1, args.wave2, args.nwbins,
args.var1, args.var2, args.nvarbins,
use_cov=args.var_use_cov, comm=comm)
ids_to_remove = mpi_set_targetid_list_to_remove(args, comm, mpi_rank)
def _is_kept(delta):
return (
(delta.targetid not in ids_to_remove)
and (delta.mean_snr > args.min_snr)
and (delta.mean_snr < args.max_snr)
)
deltas_list = mpi_read_all_deltas(args, comm, mpi_rank, mpi_size)
# Flatten this list of lists and remove quasars
deltas_list = [x for alist in deltas_list for x in alist if _is_kept(x)]
for delta in deltas_list:
varfitter.add(delta.wave, delta.delta, delta.ivar)
logging.info("Fitting var_lss and eta")
fit_results = np.ones((varfitter.nwbins, 2))
fit_results[:, 0] = 0.1
fit_results, std_results = varfitter.fit(fit_results)
# Save variance stats to file
logging.info("Saving variance stats to files")
suffix = f"snr{args.min_snr:.1f}-{args.max_snr:.1f}"
tmpfilename = f"{args.outdir}/{args.fbase}-{suffix}-variance-stats.fits"
mpi_saver = MPISaver(tmpfilename, mpi_rank)
varfitter.write(mpi_saver, args.min_snr, args.max_snr)
logging.info(f"Variance stats saved in {tmpfilename}.")
# Save fits results as well
mpi_saver.write([
varfitter.waveobs, fit_results[:, 0], std_results[:, 0],
fit_results[:, 1], std_results[:, 1]],
names=['lambda', 'var_lss', 'e_var_lss', 'eta', 'e_eta'],
extname="VAR_FUNC"
)
waveobs, stacked_flux = mpi_stack_fluxes(args, comm, deltas_list)
mpi_saver.write(
[waveobs, stacked_flux],
names=["lambda", "stacked_flux"],
extname="STACKED_FLUX")
mpi_saver.close()
[docs]
def main():
from mpi4py import MPI
comm = MPI.COMM_WORLD
mpi_rank = comm.Get_rank()
mpi_size = comm.Get_size()
logging.basicConfig(
format='%(asctime)s - %(levelname)s: %(message)s',
datefmt='%Y/%m/%d %I:%M:%S %p',
level=logging.DEBUG if mpi_rank == 0 else logging.CRITICAL)
try:
mpi_run_all(comm, mpi_rank, mpi_size)
except QsonicException as e:
logging.exception(e)
exit(1)
except Exception as e:
logging.critical(
f"Unexpected error on Rank{mpi_rank}: {e}. Abort.", exc_info=True)
comm.Abort()