import numpy as np
import threading
import logging
import re
from . import csi
class BacklogFilter(object):
    """
    Base class for CSI backlog filters.

    A filter decides, frame by frame, whether a clustered CSI frame is admitted
    to the backlog. Concrete filters override :meth:`matches`.
    """

    def matches(self, clustered_csi):
        """
        Decide whether ``clustered_csi`` should be admitted to the backlog.

        :param clustered_csi: Clustered CSI frame to inspect
        :return: True to admit the frame, False to drop it (in subclasses)
        :raises NotImplementedError: Always, in this base class
        """
        message = "BacklogFilter subclasses must implement matches()"
        raise NotImplementedError(message)
class MacFilter(BacklogFilter):
    """
    Backlog filter that admits frames whose source MAC address matches a
    regular expression.

    The pattern is applied with :meth:`re.Pattern.match`, so it must match at
    the beginning of the MAC string; it is not implicitly anchored at the end.

    :param filter_regex: Regular expression applied to the source MAC string
    """

    def __init__(self, filter_regex):
        # Keep the user-supplied pattern for introspection, compile it once for matching
        self.filter_regex = filter_regex
        self._compiled_regex = re.compile(filter_regex)

    def matches(self, clustered_csi):
        """
        :param clustered_csi: Clustered CSI frame to inspect
        :return: True if the frame's source MAC matches the configured pattern
        """
        source_mac = clustered_csi.get_source_mac()
        # Match objects are always truthy, a failed match yields None
        return bool(self._compiled_regex.match(source_mac))
class Exclude11bFilter(BacklogFilter):
    """
    Backlog filter that rejects 802.11b frames, which do not carry CSI.
    """

    def matches(self, clustered_csi):
        """
        :param clustered_csi: Clustered CSI frame to inspect
        :return: False for 802.11b frames, True for everything else
        """
        is_11b_frame = clustered_csi.is_11b()
        return not is_11b_frame
class CSIBacklog(object):
    """
    CSI backlog class. Stores CSI data in a ringbuffer for processing when needed.

    :param pool: CSI pool object to collect CSI data from
    :param fields: List of fields to store (default: all), e.g., ["lltf", "ht20", "ht40", "he20", "rssi", "rx_gain",
        "fft_gain", "cfo", "lltf_8bit_mode", "rfswitch_state", "timestamp", "host_timestamp", "mac",
        "radar_tx_timestamp", "radar_tx_index", "radar_tx_power", "radar_tx_rfswitch_state"].
        Names that are not keys of :attr:`DATA_FORMATS` are ignored and a warning is logged.
    :param calibrate: Apply calibration to CSI data (default: True)
    :param cb_predicate: A function that defines the conditions under which clustered CSI is regarded as completed and thus added to the backlog.
        See :meth:`espargos.pool.Pool.add_csi_callback` for more details.
    :param size: Size of the ringbuffer (default: 100), must be at least 1
    :raises ValueError: If ``size`` is smaller than 1
    """

    DATA_FORMATS = {
        "lltf": {
            "shape": (csi.LEGACY_COEFFICIENTS_PER_CHANNEL,),
            "per_antenna": True,
            "dtype": np.complex64,
        },
        "ht20": {
            "shape": (csi.HT_COEFFICIENTS_PER_CHANNEL,),
            "per_antenna": True,
            "dtype": np.complex64,
        },
        "ht40": {
            "shape": (csi.HT_COEFFICIENTS_PER_CHANNEL + csi.HT40_GAP_SUBCARRIERS + csi.HT_COEFFICIENTS_PER_CHANNEL,),
            "per_antenna": True,
            "dtype": np.complex64,
        },
        "he20": {
            "shape": (csi.HE20_COEFFICIENTS_PER_CHANNEL,),
            "per_antenna": True,
            "dtype": np.complex64,
        },
        "rssi": {"shape": (), "per_antenna": True, "dtype": np.float32},
        "rx_gain": {"shape": (), "per_antenna": True, "dtype": np.float32},
        "fft_gain": {"shape": (), "per_antenna": True, "dtype": np.float32},
        "cfo": {"shape": (), "per_antenna": True, "dtype": np.float32},
        "lltf_8bit_mode": {"shape": (), "per_antenna": True, "dtype": np.bool_},
        "rfswitch_state": {"shape": (), "per_antenna": True, "dtype": np.uint8},
        "timestamp": {"shape": (), "per_antenna": True, "dtype": np.float64},
        "host_timestamp": {"shape": (), "per_antenna": False, "dtype": np.float64},
        "mac": {"shape": (6,), "per_antenna": False, "dtype": np.uint8},
        "radar_tx_timestamp": {"shape": (), "per_antenna": False, "dtype": np.float64},
        "radar_tx_index": {"shape": (), "per_antenna": False, "dtype": np.int16},
        "radar_tx_power": {"shape": (), "per_antenna": False, "dtype": np.int16},
        "radar_tx_rfswitch_state": {"shape": (), "per_antenna": False, "dtype": np.uint8},
    }

    # (storage key, ClusteredCSI getter name) for fields copied verbatim from every frame
    _METADATA_FIELDS = (
        ("timestamp", "get_sensor_timestamps"),
        ("host_timestamp", "get_host_timestamp"),
        ("rssi", "get_rssi"),
        ("rx_gain", "get_rx_gain"),
        ("fft_gain", "get_fft_gain"),
        ("cfo", "get_cfo"),
        ("lltf_8bit_mode", "get_lltf_8bit_mode"),
        ("rfswitch_state", "get_rfswitch_state"),
    )

    # (storage key, availability check, deserializer, calibration method, log label)
    # for each kind of channel estimate a ClusteredCSI frame may carry
    _CSI_FIELDS = (
        ("lltf", "has_lltf", "deserialize_csi_lltf", "apply_lltf", "LLTF"),
        ("ht40", "has_ht40ltf", "deserialize_csi_ht40ltf", "apply_ht40", "HT40"),
        ("ht20", "has_ht20ltf", "deserialize_csi_ht20ltf", "apply_ht20", "HT20"),
        ("he20", "has_he20ltf", "deserialize_csi_he20ltf", "apply_he20", "HE20"),
    )

    def __init__(self, pool, fields=None, calibrate=True, cb_predicate=None, size=100):
        self.logger = logging.getLogger("pyespargos.backlog")
        self.pool = pool
        self.calibrate = calibrate
        self.storage_mutex = threading.Lock()
        self.storage = None
        self.head = 0
        self.latest = None
        self.filllevel = 0
        self.filter_mutex = threading.Lock()
        self.filters = []
        # Created before registering with the pool: _on_new_csi iterates this list,
        # and the pool might deliver a frame as soon as the callback is registered.
        self.callbacks = []
        self.thread = None
        self._initialize_storage(
            size=size,
            fields=self.DATA_FORMATS.keys() if fields is None else fields,
        )
        self.running = True
        self.pool.add_csi_callback(self._on_new_csi, cb_predicate=cb_predicate)

    @staticmethod
    def _missing_value(dtype):
        """Return the fill value that marks "no data" for a storage dtype."""
        if np.issubdtype(dtype, np.signedinteger):
            return -1
        if np.issubdtype(dtype, np.inexact):
            return np.nan
        # bool and unsigned integer fields have no dedicated "missing" value
        return 0

    def _initialize_storage(self, size=None, fields=None):
        """
        Initialize or reinitialize storage arrays.

        If storage already exists, the newest entries (up to the new size) are preserved
        for every field that is kept. For those preserved entries, newly added fields hold
        their "missing" value (NaN, -1 or 0 depending on dtype).

        :param size: New size of the ringbuffer (default: keep current size), must be at least 1
        :param fields: New set of fields to store (default: keep current fields).
            Unknown field names are ignored and a warning is logged.
        :raises ValueError: If ``size`` is smaller than 1
        """
        if size is not None and size < 1:
            raise ValueError(f"Backlog size must be at least 1, got {size}")
        if fields is not None:
            fields = set(fields)
            unknown = fields.difference(self.DATA_FORMATS)
            if unknown:
                # Unknown names would otherwise end up in self.fields without storage
                # and cause KeyErrors on later reads or resizes.
                self.logger.warning("Ignoring unknown backlog fields: %s", ", ".join(sorted(map(str, unknown))))
                fields = fields - unknown

        with self.storage_mutex:
            # Back up old data (oldest first) before the arrays are replaced
            old_storage = None
            if self.storage:
                old_storage = {key: self._read(key) for key in self.storage}

            if size is not None:
                self.size = size
            if fields is not None:
                self.fields = fields

            self.storage = dict()
            for key, meta in self.DATA_FORMATS.items():
                if key not in self.fields:
                    continue
                dtype = meta["dtype"]
                if meta["per_antenna"]:
                    full_shape = (self.size,) + tuple(self.pool.get_shape()) + meta["shape"]
                else:
                    full_shape = (self.size,) + meta["shape"]
                self.storage[key] = np.full(full_shape, self._missing_value(dtype), dtype=dtype)

            # Reset ringbuffer state
            self.head = 0
            self.latest = None
            self.filllevel = 0

            # Re-insert the newest entries that fit into the (possibly smaller) buffer
            if old_storage:
                num_old = next(iter(old_storage.values())).shape[0]
                keep = min(num_old, self.size)
                for key, old_data in old_storage.items():
                    if key in self.storage:
                        self.storage[key][:keep] = old_data[num_old - keep :]
                if keep > 0:
                    self.head = keep % self.size
                    self.latest = keep - 1
                    self.filllevel = keep

    def _require_calibration(self):
        """Return the pool's calibration, raising if it is not available."""
        calibration = self.pool.get_calibration()
        if calibration is None:
            raise RuntimeError("CSI calibration is enabled, but the pool provides no calibration")
        return calibration

    def _on_new_csi(self, clustered_csi):
        """
        Pool callback: store one clustered CSI frame in the ringbuffer.

        The frame is dropped if any active backlog filter rejects it. Callbacks registered
        via :meth:`add_update_callback` run after the frame is stored, outside the storage
        lock, so they may read the backlog without deadlocking.

        :raises RuntimeError: If calibration is enabled but the pool has none
        :raises ValueError: If the source MAC address cannot be parsed into 6 bytes
        """
        with self.filter_mutex:
            filters = tuple(self.filters)
        if not all(backlog_filter.matches(clustered_csi) for backlog_filter in filters):
            return

        with self.storage_mutex:
            head = self.head
            fields = self.fields
            storage = self.storage

            # Scalar / per-antenna metadata copied as-is
            for key, getter_name in self._METADATA_FIELDS:
                if key in fields:
                    storage[key][head] = getattr(clustered_csi, getter_name)()

            # Channel estimates; NaN marks a frame that lacks the requested training field
            for key, has_name, deserialize_name, apply_name, label in self._CSI_FIELDS:
                if key not in fields:
                    continue
                if getattr(clustered_csi, has_name)():
                    values = getattr(clustered_csi, deserialize_name)()
                    if self.calibrate:
                        values = getattr(self._require_calibration(), apply_name)(values)
                    storage[key][head] = values
                else:
                    storage[key][head] = np.nan
                    self.logger.debug("Received non-%s frame even though %s is enabled", label, label)

            # Source MAC: a hex string without separators, e.g. "00:11:22:33:44:55" -> "001122334455"
            if "mac" in fields:
                mac_str = clustered_csi.get_source_mac()
                mac = np.frombuffer(bytes.fromhex(mac_str), dtype=np.uint8)
                if mac.shape != (6,):
                    raise ValueError(f"Unexpected source MAC address format: {mac_str!r}")
                storage["mac"][head] = mac

            # Radar TX metadata. These are packet-wide fields: the TX timestamp is
            # sensor-local, and tx_index is flattened over the pool layout.
            # Without a report, the "missing" markers NaN / -1 / -1 / 0 are stored.
            report = clustered_csi.get_radar_tx_info() if clustered_csi.has_radar_tx_report() else None
            if "radar_tx_timestamp" in fields:
                storage["radar_tx_timestamp"][head] = np.nan if report is None else report.get_hardware_tx_timestamp_ns() / 1e9
            if "radar_tx_index" in fields:
                storage["radar_tx_index"][head] = -1 if report is None else clustered_csi.get_radar_tx_index()
            if "radar_tx_power" in fields:
                storage["radar_tx_power"][head] = -1 if report is None else report.tx_power
            if "radar_tx_rfswitch_state" in fields:
                storage["radar_tx_rfswitch_state"][head] = 0 if report is None else report.rfswitch_state

            # Advance ringbuffer head
            self.latest = head
            self.head = (head + 1) % self.size
            self.filllevel = min(self.filllevel + 1, self.size)
            callbacks = tuple(self.callbacks)

        # Outside the (non-reentrant) lock: callbacks typically call get()/get_multiple()
        for cb in callbacks:
            cb()

    def add_update_callback(self, cb):
        """Add a callback that is called (without arguments) when new CSI data is added to the backlog"""
        self.callbacks.append(cb)

    def _read(self, key):
        """
        Return a new array with the stored entries of ``key``, oldest first.

        Must be called with ``storage_mutex`` held.
        """
        buffer = self.storage[key]
        # The oldest stored entry sits ``filllevel`` slots behind ``head`` (cyclically)
        indices = (self.head - self.filllevel + np.arange(self.filllevel)) % len(buffer)
        return buffer[indices]

    def get(self, key):
        """
        Retrieve data from the ringbuffer

        :param key: Key of the data to retrieve (e.g., "lltf", "ht40", "rssi", etc.)
        :return: Data corresponding to the key, oldest first (a copy)
        :raises ValueError: If ``key`` is not a stored field
        """
        with self.storage_mutex:
            if key not in self.fields:
                raise ValueError(f"Requested key '{key}' not in backlog fields")
            return self._read(key)

    def get_multiple(self, keys):
        """
        Retrieve multiple data fields from the ringbuffer.

        You must use get_multiple to ensure consistency of data across multiple keys.

        :param keys: Iterable of keys of the data to retrieve (e.g., ["lltf", "ht40", "rssi"], etc.)
        :return: Tuple of data arrays corresponding to the keys (in same order), contents are oldest first
        :raises ValueError: If any key is not a stored field
        """
        keys = list(keys)  # keys are traversed twice, so materialize generators
        with self.storage_mutex:
            for key in keys:
                if key not in self.fields:
                    raise ValueError(f"Requested key '{key}' not in backlog fields")
            return tuple(self._read(key) for key in keys)

    def clear(self):
        """
        Clear all stored CSI datapoints from the ringbuffer.

        This is useful after changing calibration or other receiver settings:
        entries already in the backlog were stored with the old interpretation
        and must not be mixed with freshly received CSI.
        """
        with self.storage_mutex:
            self.head = 0
            self.latest = None
            self.filllevel = 0

    def count_valid_datapoints(self, key: str, allow_incomplete: bool = False) -> int:
        """
        Count datapoints for which a particular stored field (key) is valid (finite).

        The first axis of every backlog field is the datapoint axis. If
        ``allow_incomplete`` is false, only datapoints with all finite values are
        counted. If true, datapoints with at least one finite value are counted.
        Non-floating fields are always finite, so each stored entry counts.

        :param key: Key of the data to inspect.
        :param allow_incomplete: Whether partially valid datapoints count.
        :return: Number of valid datapoints
        :raises ValueError: If ``key`` is not a stored field
        """
        with self.storage_mutex:
            if key not in self.fields:
                raise ValueError(f"Requested key '{key}' not in backlog fields")
            data = self._read(key)
        if data.size == 0:
            return 0
        finite = np.isfinite(data.reshape(data.shape[0], -1))
        valid = np.any(finite, axis=1) if allow_incomplete else np.all(finite, axis=1)
        return int(np.count_nonzero(valid))

    def get_latest(self, key):
        """
        Retrieve the latest value for a key in the ringbuffer.

        :param key: Key of the data to retrieve
        :return: Latest value (a copy), or None if no data is available
        :raises ValueError: If data is available but ``key`` is not a stored field
        """
        with self.storage_mutex:
            if self.latest is None:
                return None
            if key not in self.fields:
                raise ValueError(f"Requested key '{key}' not in backlog fields")
            return np.copy(self.storage[key][self.latest])

    def nonempty(self):
        """
        Check if the backlog is nonempty

        :return: True if the backlog is nonempty
        """
        return self.latest is not None

    def start(self):
        """
        Start the CSI backlog thread, must be called before using the backlog.

        May be called again after :meth:`stop` to restart processing.
        """
        self.running = True
        self.thread = threading.Thread(target=self.__run)
        self.thread.start()
        self.logger.info("Started CSI backlog thread")

    def stop(self):
        """
        Stop the CSI backlog thread and wait for it to finish.

        Safe to call if the thread was never started, and from within the backlog
        thread itself (e.g. in an update callback), in which case it does not wait.
        """
        self.running = False
        thread = self.thread
        if thread is not None and thread is not threading.current_thread():
            thread.join()

    def add_filter(self, backlog_filter):
        """
        Add a filter to the backlog.

        :param backlog_filter: Instance of :class:`BacklogFilter`
        :raises TypeError: If ``backlog_filter`` is not a :class:`BacklogFilter`
        """
        if not isinstance(backlog_filter, BacklogFilter):
            raise TypeError("backlog_filter must be an instance of BacklogFilter")
        with self.filter_mutex:
            if backlog_filter not in self.filters:
                self.filters.append(backlog_filter)

    def remove_filter(self, backlog_filter):
        """
        Remove a previously added filter from the backlog.

        :param backlog_filter: Instance of :class:`BacklogFilter`
        """
        with self.filter_mutex:
            if backlog_filter in self.filters:
                self.filters.remove(backlog_filter)

    def clear_filters(self):
        """
        Remove all filters from the backlog.
        """
        with self.filter_mutex:
            self.filters.clear()

    def get_filters(self):
        """
        Get the list of currently active backlog filters.

        :return: List of :class:`BacklogFilter` instances
        """
        with self.filter_mutex:
            return list(self.filters)

    def get_size(self):
        """
        Get the size of the backlog ringbuffer

        :return: Size of the backlog ringbuffer
        """
        return self.size

    def set_size(self, new_size):
        """
        Resize the backlog ringbuffer.

        If there are existing entries, the newest ones are preserved up to the new size.

        :param new_size: New size of the backlog ringbuffer, must be at least 1
        :raises ValueError: If ``new_size`` is smaller than 1
        """
        self._initialize_storage(size=new_size)

    def set_fields(self, new_fields):
        """
        Set the fields to be stored in the backlog.

        Existing data will be preserved for fields that are still present.
        Unknown field names are ignored and a warning is logged.

        :param new_fields: New list of fields to store
        """
        self._initialize_storage(fields=new_fields)

    def get_fields(self):
        """
        Get the fields currently stored in the backlog.

        :return: Set of field names (a copy; modifying it does not affect the backlog)
        """
        return set(self.fields)

    def __run(self):
        """
        CSI backlog thread main loop, do not call directly.

        This function runs in a separate thread and repeatedly calls the pool's
        ``run()`` until :meth:`stop` clears ``self.running``.
        """
        while self.running:
            self.pool.run()