# Source code for flamedisx.lxe_blocks.detection

import typing as ty

import numpy as np
from scipy import stats
import tensorflow as tf
import tensorflow_probability as tfp

import flamedisx as fd
# `export` is used as a class decorator on this module's public blocks;
# `__all__` is the module's public-name list, presumably populated by
# `export` as classes are decorated — TODO confirm against fd.exporter.
export, __all__ = fd.exporter()
# Short alias for tf.newaxis, used to insert broadcast axes when indexing
# tensors (e.g. `[:, o, o]`).
o = tf.newaxis


class DetectPhotonsOrElectrons(fd.Block):
    """Common code for DetectPhotons and DetectElectrons.

    Detection of produced quanta is modelled as a binomial draw with a
    per-event detection efficiency (model function
    ``<quanta_name>_detection_eff``), followed by an acceptance factor
    (``<quanta_name>_acceptance``) evaluated at the number of detected
    quanta. For photons the efficiency is additionally multiplied by
    ``penning_quenching_eff`` evaluated at the number of produced photons.

    Subclasses set ``quanta_name`` ('photon' or 'electron') and provide
    the corresponding model functions.
    """

    model_attributes = ('check_efficiencies',)

    # Whether to check if all events have a positive detection efficiency.
    # As with check_acceptances in MakeFinalSignals, you may have to
    # turn this off, depending on your application.
    check_efficiencies = True

    # Set by subclasses. Used to build model-function names
    # ('<quanta_name>_detection_eff', '<quanta_name>_acceptance') and data
    # column names ('<quanta_name>s_produced', '<quanta_name>s_detected', ...).
    quanta_name: str

    # Prevent pycharm warnings:
    source: fd.Source
    gimme: ty.Callable
    gimme_numpy: ty.Callable

    def _compute(self, data_tensor, ptensor,
                 quanta_produced, quanta_detected):
        """Return P(quanta_detected | quanta_produced) times the acceptance.

        :param data_tensor: data tensor, passed through to ``gimme``
        :param ptensor: parameter tensor, passed through to ``gimme``
        :param quanta_produced: tensor of produced-quanta hypotheses;
            presumably shaped (n_events, n_produced, n_detected) — TODO confirm
        :param quanta_detected: tensor of detected-quanta hypotheses,
            same layout as ``quanta_produced`` (assumed)
        :returns: binomial probability of ``quanta_detected`` successes out
            of ``quanta_produced`` trials, times the acceptance evaluated
            at ``quanta_detected``.
        """
        # Per-event efficiency with two new trailing axes, so it broadcasts
        # against the per-event hypothesis tensors.
        p = self.gimme(self.quanta_name + '_detection_eff',
                       data_tensor=data_tensor, ptensor=ptensor)[:, o, o]

        if self.quanta_name == 'photon':
            # Note *= doesn't work, p will get reshaped
            p = p * self.gimme('penning_quenching_eff',
                               bonus_arg=quanta_produced,
                               data_tensor=data_tensor, ptensor=ptensor)

        result = tfp.distributions.Binomial(
                total_count=quanta_produced,
                probs=tf.cast(p, dtype=fd.float_type())
            ).prob(quanta_detected)
        acceptance = self.gimme(self.quanta_name + '_acceptance',
                                bonus_arg=quanta_detected,
                                data_tensor=data_tensor, ptensor=ptensor)
        return result * acceptance

    def _simulate(self, d):
        """Simulate detection for the events in dataframe ``d`` (in place).

        Adds column ``<quanta_name>s_detected``, drawn binomially from
        ``<quanta_name>s_produced``. Multiplies ``d['p_accepted']`` by the
        acceptance evaluated at the drawn numbers.
        """
        p = self.gimme_numpy(self.quanta_name + '_detection_eff')

        if self.quanta_name == 'photon':
            # Build a new product instead of `p *= ...`. p comes from
            # gimme_numpy and might be an array shared elsewhere (or 0-d /
            # integer-typed), where an in-place multiply would mutate
            # shared state or fail to broadcast/cast.
            p = p * self.gimme_numpy(
                'penning_quenching_eff', d['photons_produced'].values)

        d[self.quanta_name + 's_detected'] = stats.binom.rvs(
            n=d[self.quanta_name + 's_produced'],
            p=p)
        d['p_accepted'] *= self.gimme_numpy(
            self.quanta_name + '_acceptance',
            d[self.quanta_name + 's_detected'].values)

    def _annotate(self, d):
        """Add estimates/bounds of produced quanta to dataframe ``d`` (in place).

        Sets ``<quanta_name>s_produced_mle`` to
        ``<quanta_name>s_detected_mle`` divided by the efficiency. Sets
        ``<quanta_name>s_produced_min`` and ``<quanta_name>s_produced_max``
        to that estimate minus/plus ``self.source.max_sigma`` times an
        estimated spread. The bounds are rounded outward (floor/ceil),
        clipped at 0 and cast to int.

        :raises ValueError: if ``check_efficiencies`` is true and any event
            has a nonpositive detection efficiency.
        """
        def check(eff_values):
            if self.check_efficiencies and np.any(eff_values <= 0):
                raise ValueError(
                    f"Found event with nonpositive {self.quanta_name} "
                    "detection efficiency: did you apply and "
                    "configure your cuts correctly?")

        # Get efficiency
        eff = self.gimme_numpy(self.quanta_name + '_detection_eff')

        # Check the bare efficiency *before* dividing by it below. A zero
        # efficiency would make detected/eff inf or nan. The default
        # quenching (1. + 0. * nph) then yields nan, the combined
        # efficiency becomes nan, and `nan <= 0` is False. The check
        # would then silently pass.
        check(eff)

        if self.quanta_name == 'photon':
            # Quenching is evaluated at a rough produced-photon estimate
            # (detected MLE / bare efficiency). Not `*=`: avoid mutating
            # the array returned by gimme_numpy in place.
            eff = eff * self.gimme_numpy(
                'penning_quenching_eff',
                d['photons_detected_mle'].values / eff)
            check(eff)

        # Estimate produced quanta
        n_prod_mle = d[self.quanta_name + 's_produced_mle'] = \
            d[self.quanta_name + 's_detected_mle'] / eff

        # Estimating the spread in number of produced quanta is tricky since
        # the number of detected quanta is itself uncertain.
        # TODO: where did this derivation come from again?
        # Algebraically, _std is the positive root s of
        # s**2 = q * (n_prod_mle + s). That is a variance N*q (with
        # q = (1 - eff) / eff) evaluated at N = n_prod_mle + s.
        q = (1 - eff) / eff
        _std = (q + (q ** 2 + 4 * n_prod_mle * q) ** 0.5) / 2

        for bound, sign, intify in (('min', -1, np.floor),
                                    ('max', +1, np.ceil)):
            # Builtin int: np.int was only an alias for it and was removed
            # in NumPy 1.24.
            d[self.quanta_name + 's_produced_' + bound] = intify(
                n_prod_mle + sign * self.source.max_sigma * _std
            ).clip(0, None).astype(int)


@export
class DetectPhotons(DetectPhotonsOrElectrons):
    """Photon detection block.

    Uses the shared binomial detection model with a constant default
    detection efficiency (0.1), a step-function acceptance in the number
    of detected photons, and a trivial (unit) Penning quenching factor.
    """

    quanta_name = 'photon'

    dimensions = ('photons_produced', 'photons_detected')
    extra_dimensions = ()

    special_model_functions = ('photon_acceptance', 'penning_quenching_eff')
    model_functions = ('photon_detection_eff',) + special_model_functions

    # Default photon detection efficiency (a constant model function).
    photon_detection_eff = 0.1

    def photon_acceptance(self, photons_detected, min_photons=3):
        """Return 0 where fewer than ``min_photons`` photons are detected, else 1.

        The result is a float tensor (``fd.float_type()``) shaped like
        ``photons_detected``.
        """
        below_threshold = photons_detected < min_photons
        rejected = tf.zeros_like(photons_detected, dtype=fd.float_type())
        accepted = tf.ones_like(photons_detected, dtype=fd.float_type())
        return tf.where(below_threshold, rejected, accepted)

    @staticmethod
    def penning_quenching_eff(nph):
        """Return a quenching factor of 1 for every entry of ``nph``.

        Written as ``1. + 0. * nph`` so the output follows the shape of
        ``nph``. The base class calls it both through ``gimme`` and
        ``gimme_numpy``. Caveat: entries of ``nph`` that are inf/nan give
        nan here.
        """
        return 1. + 0. * nph

    def _compute(self, data_tensor, ptensor,
                 photons_produced, photons_detected):
        """Forward photon-named dimensions to the shared implementation."""
        return super()._compute(
            data_tensor=data_tensor,
            ptensor=ptensor,
            quanta_produced=photons_produced,
            quanta_detected=photons_detected)
@export
class DetectElectrons(DetectPhotonsOrElectrons):
    """Electron detection block.

    Uses the shared binomial detection model with an efficiency of
    ``extraction_eff * exp(-drift_time / elife)`` and a constant
    acceptance of 1.
    """

    quanta_name = 'electron'

    dimensions = ('electrons_produced', 'electrons_detected')
    extra_dimensions = ()

    special_model_functions = ('electron_acceptance',)
    model_functions = ('electron_detection_eff',) + special_model_functions

    # Constant acceptance: no dependence on the number of detected electrons.
    electron_acceptance = 1.

    @staticmethod
    def electron_detection_eff(drift_time, *, elife=452e3, extraction_eff=0.96):
        """Return ``extraction_eff * exp(-drift_time / elife)``.

        ``elife`` is presumably the electron lifetime, in the same units as
        ``drift_time`` — TODO confirm units.
        """
        survival = tf.exp(-drift_time / elife)
        return extraction_eff * survival

    def _compute(self, data_tensor, ptensor,
                 electrons_produced, electrons_detected):
        """Forward electron-named dimensions to the shared implementation."""
        return super()._compute(
            data_tensor=data_tensor,
            ptensor=ptensor,
            quanta_produced=electrons_produced,
            quanta_detected=electrons_detected)