"""Code to load XENON-specific correction maps
Adapted from https://github.com/XENONnT/straxen/blob/master/straxen/itp_map.py
"""
import logging
import gzip
import json
import re
import numpy as np
from scipy.spatial import cKDTree
import flamedisx as fd
export, __all__ = fd.exporter()
[docs]@export
class InterpolateAndExtrapolate:
"""Linearly interpolate- and extrapolate using inverse-distance
weighted averaging between nearby points.
"""
def __init__(self, points, values, neighbours_to_use=None, array_valued=False):
"""
:param points: array (n_points, n_dims) of coordinates
:param values: array (n_points) of values
:param neighbours_to_use: Number of neighbouring points to use for
averaging. Default is 2 * dimensions of points.
"""
self.kdtree = cKDTree(points)
self.values = values.astype(np.float)
if neighbours_to_use is None:
neighbours_to_use = points.shape[1] * 2
self.neighbours_to_use = neighbours_to_use
self.array_valued = array_valued
if array_valued:
self.n_dim = self.values.shape[-1]
def __call__(self, points):
distances, indices = self.kdtree.query(points, self.neighbours_to_use)
result = np.ones(len(points)) * float('nan')
if self.array_valued:
result = np.repeat(result.reshape(-1, 1), self.n_dim, axis=1)
# If one of the coordinates is NaN, the neighbour-query fails.
# If we don't filter these out, it would result in an IndexError
# as the kdtree returns an invalid index if it can't find neighbours.
valid = (distances < float('inf')).max(axis=-1)
values = self.values[indices[valid]]
weights = 1 / np.clip(distances[valid], 1e-6, float('inf'))
if self.array_valued:
weights = np.repeat(weights, self.n_dim).reshape(values.shape)
result[valid] = np.average(values, weights=weights,
axis=-2 if self.array_valued else -1)
return result
[docs]@export
class InterpolatingMap:
"""Correction map that computes values using inverse-weighted distance
interpolation.
The map must be specified as a json translating to a dictionary like this:
'coordinate_system' : [[x1, y1], [x2, y2], [x3, y3], [x4, y4], ...],
'map' : [value1, value2, value3, value4, ...]
'another_map' : idem
'name': 'Nice file with maps',
'description': 'Say what the maps are, who you are, etc',
'timestamp': unix epoch seconds timestamp
with the straightforward generalization to 1d and 3d.
Alternatively, a grid coordinate system can be specified as follows:
'coordinate_system' : [['x', [x_min, x_max, n_x]], [['y', [y_min, y_max, n_y]]
Alternatively, an N-vector-valued map can be specified by an array with
last dimension N in 'map'.
The default map name is 'map', I'd recommend you use that.
For a 0d placeholder map, use
'points': [],
'map': 42,
etc
"""
data_field_names = ['timestamp', 'description', 'coordinate_system',
'name', 'irregular']
def __init__(self, data):
if isinstance(data, bytes):
data = gzip.decompress(data).decode()
if isinstance(data, (str, bytes)):
data = json.loads(data)
assert isinstance(data, dict), f"Expected dictionary data, got {type(data)}"
self.data = data
# Decompress / dequantize the map
# TODO: support multiple map names
if 'compressed' in self.data:
try:
import strax
except ImportError:
print("You must install strax to use compressed maps!\n")
raise
compressor, dtype, shape = self.data['compressed']
self.data['map'] = np.frombuffer(
strax.io.COMPRESSORS[compressor]['decompress'](self.data['map']),
dtype=dtype).reshape(*shape)
del self.data['compressed']
if 'quantized' in self.data:
self.data['map'] = self.data['quantized'] * self.data['map'].astype(np.float32)
del self.data['quantized']
cs = self.data['coordinate_system']
if not cs:
self.dimensions = 0
elif isinstance(cs[0], list) and isinstance(cs[0][0], str):
# Support for specifying coordinate system as a gridspec
grid = [np.linspace(left, right, points)
for _, (left, right, points) in cs]
cs = np.array(np.meshgrid(*grid, indexing='ij'))
cs = np.transpose(cs, np.roll(np.arange(len(grid)+1), -1))
cs = np.array(cs).reshape((-1, len(grid)))
self.dimensions = len(grid)
else:
self.dimensions = len(cs[0])
self.coordinate_system = cs
self.interpolators = {}
self.map_names = sorted([k for k in self.data.keys()
if k not in self.data_field_names])
log = logging.getLogger('InterpolatingMap')
log.debug('Map name: %s' % self.data['name'])
log.debug('Map description:\n ' +
re.sub(r'\n', r'\n ', self.data['description']))
log.debug("Map names found: %s" % self.map_names)
for map_name in self.map_names:
map_data = np.array(self.data[map_name])
array_valued = len(map_data.shape) == self.dimensions + 1
if self.dimensions == 0:
# 0 D -- placeholder maps which take no arguments
# and always return a single value
def itp_fun(positions, _data=map_data):
return np.array([_data])
else:
if array_valued:
map_data = map_data.reshape((-1, map_data.shape[-1]))
itp_fun = InterpolateAndExtrapolate(points=np.array(cs),
values=np.array(map_data),
array_valued=array_valued)
self.interpolators[map_name] = itp_fun
def __call__(self, positions, map_name='map'):
"""Returns the value of the map at the position given by coordinates
:param positions: array (n_dim) or (n_points, n_dim) of positions
:param map_name: Name of the map to use. Default is 'map'.
"""
return self.interpolators[map_name](positions)