diff --git a/dlclivegui/gui/main_window.py b/dlclivegui/gui/main_window.py index 894fc4d..915af1d 100644 --- a/dlclivegui/gui/main_window.py +++ b/dlclivegui/gui/main_window.py @@ -445,7 +445,7 @@ def _build_dlc_group(self) -> QGroupBox: # Processor selection processor_path_layout = QHBoxLayout() self.processor_folder_edit = QLineEdit() - self.processor_folder_edit.setText(default_processors_dir()) + self.processor_folder_edit.setText(self._settings_store.get_processor_folder(default=default_processors_dir())) processor_path_layout.addWidget(self.processor_folder_edit) self.browse_processor_folder_button = QPushButton("Browse...") @@ -1081,10 +1081,11 @@ def _action_browse_directory(self) -> None: def _action_browse_processor_folder(self) -> None: """Browse for processor folder.""" - current_path = self.processor_folder_edit.text() or default_processors_dir() + current_path = self.processor_folder_edit.text().strip() or default_processors_dir() directory = QFileDialog.getExistingDirectory(self, "Select processor folder", current_path) if directory: self.processor_folder_edit.setText(directory) + self._settings_store.set_processor_folder(directory) self._refresh_processors() def _action_open_recording_folder(self) -> None: @@ -1138,10 +1139,17 @@ def _refresh_processors(self) -> None: self.processor_combo.addItem("No Processor", None) selected_folder = self.processor_folder_edit.text().strip() - if Path(selected_folder).exists(): - self._scanned_processors = scan_processor_folder(selected_folder) + selected_path = Path(selected_folder).expanduser() if selected_folder else None + + if selected_path is not None and selected_path.is_dir(): + resolved_folder = str(selected_path.resolve()) + self._settings_store.set_processor_folder(resolved_folder) + self._scanned_processors = scan_processor_folder(resolved_folder) + source_text = resolved_folder else: self._scanned_processors = scan_processor_package("dlclivegui.processors") + source_text = "package dlclivegui.processors" + self._processor_keys = list(self._scanned_processors.keys()) for key in self._processor_keys: @@ -1150,9 +1158,7 @@ def _refresh_processors(self) -> None: self.processor_combo.addItem(display_name, key) self.processor_combo.update_shrink_width() - self.statusBar().showMessage( - f"Found {len(self._processor_keys)} processor(s) in package dlclivegui.processors", 3000 - ) + self.statusBar().showMessage(f"Found {len(self._processor_keys)} processor(s) in {source_text}", 3000) # ------------------------------------------------------------------ # Recording path preview and session name persistence @@ -1728,24 +1734,28 @@ def _update_inference_buttons(self) -> None: def _update_dlc_controls_enabled(self) -> None: """Enable/disable DLC settings based on inference state.""" allow_changes = not self._dlc_active - processor_controls = allow_changes and self._processor_control_enabled() widgets = [ self.model_path_edit, self.browse_model_button, self.dlc_camera_combo, - # self.additional_options_edit, ] + processor_widgets = [ self.processor_folder_edit, self.browse_processor_folder_button, self.refresh_processors_button, self.processor_combo, ] + for widget in widgets: widget.setEnabled(allow_changes) + for widget in processor_widgets: - widget.setEnabled(processor_controls) + widget.setEnabled(allow_changes) + + if hasattr(self, "allow_processor_ctrl_checkbox"): + self.allow_processor_ctrl_checkbox.setEnabled(allow_changes) def _update_camera_controls_enabled(self) -> None: multi_cam_recording = self._rec_manager.is_active @@ -2151,6 +2161,9 @@ def closeEvent(self, event: QCloseEvent) -> None: # pragma: no cover - GUI beha # Remember model path on exit self._model_path_store.save_if_valid(self.model_path_edit.text().strip()) + # Remember processor folder on exit + if hasattr(self, "processor_folder_edit"): + self._settings_store.set_processor_folder(self.processor_folder_edit.text().strip()) # Close the window super().closeEvent(event) diff --git a/dlclivegui/processors/PLUGIN_SYSTEM.md b/dlclivegui/processors/PLUGIN_SYSTEM.md index 9e975e0..e6a1436 100644 --- a/dlclivegui/processors/PLUGIN_SYSTEM.md +++ b/dlclivegui/processors/PLUGIN_SYSTEM.md @@ -16,7 +16,8 @@ Processors are Python classes (typically subclasses of `dlclive.Processor`) that ### Useful files -- `dlclivegui/processors/dlc_processor_socket.py` — Example socket-based processor base class + examples +- `dlclivegui/processors/dlc_processor_socket.py` — Example socket-based processor base class +- `dlclivegui/processors/examples.py` — Example processor implementations (e.g., One-Euro filter) - `dlclivegui/processors/processor_utils.py` — Scanning + instantiation helpers used by the GUI --- @@ -204,12 +205,7 @@ The built-in `BaseProcessorSocket` (in `dlc_processor_socket.py`) demonstrates a ```python from dlclive import Processor - -PROCESSOR_REGISTRY = {} - -def register_processor(cls): - PROCESSOR_REGISTRY[getattr(cls, "PROCESSOR_ID", cls.__name__)] = cls - return cls +from dlclivegui.processors import register_processor, PROCESSOR_REGISTRY @register_processor class MyNewProcessor(Processor): diff --git a/dlclivegui/processors/__init__.py b/dlclivegui/processors/__init__.py new file mode 100644 index 0000000..8e77171 --- /dev/null +++ b/dlclivegui/processors/__init__.py @@ -0,0 +1,3 @@ +from .registry import PROCESSOR_REGISTRY, register_processor + +__all__ = ["register_processor", "PROCESSOR_REGISTRY"] diff --git a/dlclivegui/processors/dlc_processor_socket.py b/dlclivegui/processors/dlc_processor_socket.py index 8ded010..594512c 100644 --- a/dlclivegui/processors/dlc_processor_socket.py +++ b/dlclivegui/processors/dlc_processor_socket.py @@ -7,14 +7,17 @@ import sys import time from collections import deque -from math import acos, atan2, copysign, degrees, pi, sqrt from multiprocessing.connection import Client, Listener from pathlib import Path from threading import Event, Thread import numpy as np import pandas as pd -from dlclive import Processor # type: ignore + +try: + from dlclive.processor import Processor # type: ignore +except ImportError: + Processor = object # Fallback for type checking if dlclive is not installed logger = logging.getLogger("dlc_processor_socket") @@ -24,59 +27,6 @@ _handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")) logger.addHandler(_handler) -# Registry for GUI discovery -PROCESSOR_REGISTRY = {} - - -def register_processor(cls): - registry_key = getattr(cls, "PROCESSOR_ID", cls.__name__) - if registry_key in PROCESSOR_REGISTRY: - raise ValueError( - f"Duplicate processor registration key '{registry_key}': " - f"{PROCESSOR_REGISTRY[registry_key].__name__} vs {cls.__name__}" - ) - PROCESSOR_REGISTRY[registry_key] = cls - return cls - - -class OneEuroFilter: # pragma: no cover - def __init__(self, t0, x0, dx0=None, min_cutoff=1.0, beta=0.0, d_cutoff=1.0): - self.min_cutoff = min_cutoff - self.beta = beta - self.d_cutoff = d_cutoff - self.x_prev = x0 - if dx0 is None: - dx0 = np.zeros_like(x0) - self.dx_prev = dx0 - self.t_prev = t0 - - @staticmethod - def smoothing_factor(t_e, cutoff): - r = 2 * pi * cutoff * t_e - return r / (r + 1) - - @staticmethod - def exponential_smoothing(alpha, x, x_prev): - return alpha * x + (1 - alpha) * x_prev - - def __call__(self, t, x): - t_e = t - self.t_prev - if t_e <= 0: - return x - a_d = self.smoothing_factor(t_e, self.d_cutoff) - dx = (x - self.x_prev) / t_e - dx_hat = self.exponential_smoothing(a_d, dx, self.dx_prev) - - cutoff = self.min_cutoff + self.beta * abs(dx_hat) - a = self.smoothing_factor(t_e, cutoff) - x_hat = self.exponential_smoothing(a, x, self.x_prev) - - self.x_prev = x_hat - self.dx_prev = dx_hat - self.t_prev = t - - return x_hat - # pragma: cover class BaseProcessorSocket(Processor): @@ -474,375 +424,3 @@ def get_data(self): if self.dlc_cfg is not None: save_dict["dlc_cfg"] = self.dlc_cfg return save_dict - - -@register_processor -class ExampleProcessorSocketCalculateMousePose(BaseProcessorSocket): # pragma: no cover - """ - DLC Processor with pose calculations (center, heading, head angle) and optional filtering. - - Calculates: - - center: Weighted average of head keypoints - - heading: Body orientation (degrees) - - head_angle: Head rotation relative to body (radians) - - Broadcasts: [timestamp, center_x, center_y, heading, head_angle] - """ - - PROCESSOR_NAME = "Example Experiment Pose Processor" - PROCESSOR_DESCRIPTION = "Calculates mouse center, heading, and head angle with optional One-Euro filtering" - PROCESSOR_PARAMS = { - "bind": { - "type": "tuple", - "default": ("127.0.0.1", 6000), - "description": "Server address (host, port)", - }, - "authkey": { - "type": "bytes", - "default": b"secret password", - "description": "Authentication key for clients", - }, - "use_perf_counter": { - "type": "bool", - "default": False, - "description": "Use time.perf_counter() instead of time.time()", - }, - "use_filter": { - "type": "bool", - "default": False, - "description": "Apply One-Euro filter to calculated values", - }, - "filter_kwargs": { - "type": "dict", - "default": {"min_cutoff": 1.0, "beta": 0.02, "d_cutoff": 1.0}, - "description": "One-Euro filter parameters (min_cutoff, beta, d_cutoff)", - }, - "save_original": { - "type": "bool", - "default": False, - "description": "Save raw pose arrays for analysis", - }, - } - - def __init__( - self, - bind=("127.0.0.1", 6000), - authkey=b"secret password", - use_perf_counter=False, - use_filter=False, - filter_kwargs: dict | None = None, - save_original=False, - ): - super().__init__( - bind=bind, - authkey=authkey, - use_perf_counter=use_perf_counter, - save_original=save_original, - ) - - self.center_x = deque() - self.center_y = deque() - self.heading_direction = deque() - self.head_angle = deque() - - self.use_filter = use_filter - self.filter_kwargs = filter_kwargs if filter_kwargs is not None else {} - self.filters = None - - def _clear_data_queues(self): - super()._clear_data_queues() - self.center_x.clear() - self.center_y.clear() - self.heading_direction.clear() - self.head_angle.clear() - - def _initialize_filters(self, vals): - t0 = self.timing_func() - self.filters = { - "center_x": OneEuroFilter(t0, vals[0], **self.filter_kwargs), - "center_y": OneEuroFilter(t0, vals[1], **self.filter_kwargs), - "heading": OneEuroFilter(t0, vals[2], **self.filter_kwargs), - "head_angle": OneEuroFilter(t0, vals[3], **self.filter_kwargs), - } - logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}") - - def process(self, pose, **kwargs): - # Extract keypoints and confidence - xy = pose[:, :2] - conf = pose[:, 2] - - # Calculate weighted center from head keypoints - head_xy = xy[[0, 1, 2, 3, 4, 5, 6, 26], :] - head_conf = conf[[0, 1, 2, 3, 4, 5, 6, 26]] - center = np.average(head_xy, axis=0, weights=head_conf) - - # Calculate body axis (tail_base -> neck) - body_axis = xy[7] - xy[13] - body_axis /= sqrt(np.sum(body_axis**2)) - - # Calculate head axis (neck -> nose) - head_axis = xy[0] - xy[7] - head_axis /= sqrt(np.sum(head_axis**2)) - - # Calculate head angle relative to body - cross = body_axis[0] * head_axis[1] - head_axis[0] * body_axis[1] - sign = copysign(1, cross) # Positive when looking left - sign = copysign(1, cross) - try: - head_angle = acos(body_axis @ head_axis) * sign - except ValueError: - head_angle = 0 - - # Calculate heading (body orientation) - heading = degrees(atan2(body_axis[1], body_axis[0])) - - # Raw values (heading unwrapped for filtering) - vals = [center[0], center[1], heading, head_angle] - - # Apply filtering if enabled - curr_time = self.timing_func() - if self.use_filter: - if self.filters is None: - self._initialize_filters(vals) - - vals = [ - self.filters["center_x"](curr_time, vals[0]), - self.filters["center_y"](curr_time, vals[1]), - self.filters["heading"](curr_time, vals[2]), - self.filters["head_angle"](curr_time, vals[3]), - ] - - # Wrap heading to [0, 360) after filtering - vals[2] = vals[2] % 360 - # Update step counter - self.curr_step = self.curr_step + 1 - - # Store processed data (only if recording) - if self.recording: - if self.save_original and self.original_pose is not None: - self.original_pose.append(pose.copy()) - self.center_x.append(vals[0]) - self.center_y.append(vals[1]) - self.heading_direction.append(vals[2]) - self.head_angle.append(vals[3]) - self.time_stamp.append(curr_time) - self.step.append(self.curr_step) - self.frame_time.append(kwargs.get("frame_time", -1)) - if "pose_time" in kwargs: - self.pose_time.append(kwargs["pose_time"]) - - payload = [curr_time, vals[0], vals[1], vals[2], vals[3]] - self.broadcast(payload) - return pose - - def get_data(self): - save_dict = super().get_data() - save_dict["x_pos"] = np.array(self.center_x) - save_dict["y_pos"] = np.array(self.center_y) - save_dict["heading_direction"] = np.array(self.heading_direction) - save_dict["head_angle"] = np.array(self.head_angle) - save_dict["use_filter"] = self.use_filter - save_dict["filter_kwargs"] = self.filter_kwargs - return save_dict - - -@register_processor -class ExampleProcessorSocketFilterKeypoints(BaseProcessorSocket): # pragma: no cover - PROCESSOR_NAME = "Mouse Pose with less keypoints" - PROCESSOR_DESCRIPTION = "Calculates mouse center, heading, and head angle with optional One-Euro filtering" - PROCESSOR_PARAMS = { - "bind": { - "type": "tuple", - "default": ("127.0.0.1", 6000), - "description": "Server address (host, port)", - }, - "authkey": { - "type": "bytes", - "default": b"secret password", - "description": "Authentication key for clients", - }, - "use_perf_counter": { - "type": "bool", - "default": False, - "description": "Use time.perf_counter() instead of time.time()", - }, - "use_filter": { - "type": "bool", - "default": False, - "description": "Apply One-Euro filter to calculated values", - }, - "filter_kwargs": { - "type": "dict", - "default": {"min_cutoff": 1.0, "beta": 0.02, "d_cutoff": 1.0}, - "description": "One-Euro filter parameters (min_cutoff, beta, d_cutoff)", - }, - "save_original": { - "type": "bool", - "default": True, - "description": "Save raw pose arrays for analysis", - }, - } - - def __init__( - self, - bind=("127.0.0.1", 6000), - authkey=b"secret password", - use_perf_counter=False, - use_filter=False, - filter_kwargs: dict | None = None, - save_original=True, - p_cutoff=0.4, - ): - super().__init__( - bind=bind, - authkey=authkey, - use_perf_counter=use_perf_counter, - save_original=save_original, - ) - - self.center_x = deque() - self.center_y = deque() - self.heading_direction = deque() - self.head_angle = deque() - - self.p_cutoff = p_cutoff - - self.use_filter = use_filter - self.filter_kwargs = filter_kwargs if filter_kwargs is not None else {} - self.filters = None - - def _clear_data_queues(self): - super()._clear_data_queues() - self.center_x.clear() - self.center_y.clear() - self.heading_direction.clear() - self.head_angle.clear() - - def _initialize_filters(self, vals): - t0 = self.timing_func() - self.filters = { - "center_x": OneEuroFilter(t0, vals[0], **self.filter_kwargs), - "center_y": OneEuroFilter(t0, vals[1], **self.filter_kwargs), - "heading": OneEuroFilter(t0, vals[2], **self.filter_kwargs), - "head_angle": OneEuroFilter(t0, vals[3], **self.filter_kwargs), - } - logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}") - - def process(self, pose, **kwargs): - # Extract keypoints and confidence - xy = pose[:, :2] - conf = pose[:, 2] - - # Calculate weighted center from head keypoints - head_xy = xy[[0, 1, 2, 3, 5, 6, 7], :] - head_conf = conf[[0, 1, 2, 3, 5, 6, 7]] - # set low confidence keypoints to zero weight - head_conf = np.where(head_conf < self.p_cutoff, 0, head_conf) - try: - center = np.average(head_xy, axis=0, weights=head_conf) - except ZeroDivisionError: - # If all keypoints have zero weight, return without processing - return pose - - neck = np.average(xy[[2, 3, 6, 7], :], axis=0, weights=conf[[2, 3, 6, 7]]) - - # Calculate body axis (tail_base -> neck) - body_axis = neck - xy[9] - body_axis /= sqrt(np.sum(body_axis**2)) - - # Calculate head axis (neck -> nose) - head_axis = xy[0] - neck - head_axis /= sqrt(np.sum(head_axis**2)) - - # Calculate head angle relative to body - cross = body_axis[0] * head_axis[1] - head_axis[0] * body_axis[1] - sign = copysign(1, cross) # Positive when looking left - sign = copysign(1, cross) - try: - head_angle = acos(body_axis @ head_axis) * sign - except ValueError: - head_angle = 0 - - # Calculate heading (body orientation) - heading = degrees(atan2(body_axis[1], body_axis[0])) - vals = [center[0], center[1], heading, head_angle] - - curr_time = self.timing_func() - if self.use_filter: - if self.filters is None: - self._initialize_filters(vals) - - vals = [ - self.filters["center_x"](curr_time, vals[0]), - self.filters["center_y"](curr_time, vals[1]), - self.filters["heading"](curr_time, vals[2]), - self.filters["head_angle"](curr_time, vals[3]), - ] - - # Wrap heading to [0, 360) after filtering - vals[2] = vals[2] % 360 - # Update step counter - self.curr_step = self.curr_step + 1 - - # Store processed data (only if recording) - if self.recording: - if self.save_original and self.original_pose is not None: - self.original_pose.append(pose.copy()) - self.center_x.append(vals[0]) - self.center_y.append(vals[1]) - self.heading_direction.append(vals[2]) - self.head_angle.append(vals[3]) - self.time_stamp.append(curr_time) - self.step.append(self.curr_step) - self.frame_time.append(kwargs.get("frame_time", -1)) - if "pose_time" in kwargs: - self.pose_time.append(kwargs["pose_time"]) - - payload = [curr_time, vals[0], vals[1], vals[2], vals[3]] - self.broadcast(payload) - return pose - - def get_data(self): - save_dict = super().get_data() - save_dict["x_pos"] = np.array(self.center_x) - save_dict["y_pos"] = np.array(self.center_y) - save_dict["heading_direction"] = np.array(self.heading_direction) - save_dict["head_angle"] = np.array(self.head_angle) - save_dict["use_filter"] = self.use_filter - save_dict["filter_kwargs"] = self.filter_kwargs - return save_dict - - -def get_available_processors(): - """ - Get list of available processor classes. - - Returns: - dict: Dictionary mapping registry keys to processor info. - """ - return { - name: { - "class": cls, - "name": getattr(cls, "PROCESSOR_NAME", name), - "description": getattr(cls, "PROCESSOR_DESCRIPTION", ""), - "params": getattr(cls, "PROCESSOR_PARAMS", {}), - } - for name, cls in PROCESSOR_REGISTRY.items() - } - - -def instantiate_processor(class_name, **kwargs): - """ - Instantiate a processor by class name with given parameters. - - Args: - class_name: Registry key (e.g., "MyProcessorSocket") - **kwargs: Constructor kwargs - - Raises: - ValueError: If class_name is not in registry - """ - if class_name not in PROCESSOR_REGISTRY: - available = ", ".join(PROCESSOR_REGISTRY.keys()) - raise ValueError(f"Unknown processor '{class_name}'. Available: {available}") - return PROCESSOR_REGISTRY[class_name](**kwargs) diff --git a/dlclivegui/processors/examples.py b/dlclivegui/processors/examples.py new file mode 100644 index 0000000..7ed7691 --- /dev/null +++ b/dlclivegui/processors/examples.py @@ -0,0 +1,391 @@ +from __future__ import annotations + +import logging +from collections import deque +from math import acos, atan2, copysign, degrees, pi, sqrt + +import numpy as np + +from dlclivegui.processors import register_processor +from dlclivegui.processors.dlc_processor_socket import BaseProcessorSocket + +logger = logging.getLogger(__name__) + + +class OneEuroFilter: # pragma: no cover + def __init__(self, t0, x0, dx0=None, min_cutoff=1.0, beta=0.0, d_cutoff=1.0): + self.min_cutoff = min_cutoff + self.beta = beta + self.d_cutoff = d_cutoff + self.x_prev = x0 + if dx0 is None: + dx0 = np.zeros_like(x0) + self.dx_prev = dx0 + self.t_prev = t0 + + @staticmethod + def smoothing_factor(t_e, cutoff): + r = 2 * pi * cutoff * t_e + return r / (r + 1) + + @staticmethod + def exponential_smoothing(alpha, x, x_prev): + return alpha * x + (1 - alpha) * x_prev + + def __call__(self, t, x): + t_e = t - self.t_prev + if t_e <= 0: + return x + a_d = self.smoothing_factor(t_e, self.d_cutoff) + dx = (x - self.x_prev) / t_e + dx_hat = self.exponential_smoothing(a_d, dx, self.dx_prev) + + cutoff = self.min_cutoff + self.beta * abs(dx_hat) + a = self.smoothing_factor(t_e, cutoff) + x_hat = self.exponential_smoothing(a, x, self.x_prev) + + self.x_prev = x_hat + self.dx_prev = dx_hat + self.t_prev = t + + return x_hat + + +@register_processor +class ExampleProcessorSocketCalculateMousePose(BaseProcessorSocket): # pragma: no cover + """ + DLC Processor with pose calculations (center, heading, head angle) and optional filtering. + + Calculates: + - center: Weighted average of head keypoints + - heading: Body orientation (degrees) + - head_angle: Head rotation relative to body (radians) + + Broadcasts: [timestamp, center_x, center_y, heading, head_angle] + """ + + PROCESSOR_NAME = "Example Experiment Pose Processor" + PROCESSOR_DESCRIPTION = "Calculates mouse center, heading, and head angle with optional One-Euro filtering" + PROCESSOR_PARAMS = { + "bind": { + "type": "tuple", + "default": ("127.0.0.1", 6000), + "description": "Server address (host, port)", + }, + "authkey": { + "type": "bytes", + "default": b"secret password", + "description": "Authentication key for clients", + }, + "use_perf_counter": { + "type": "bool", + "default": False, + "description": "Use time.perf_counter() instead of time.time()", + }, + "use_filter": { + "type": "bool", + "default": False, + "description": "Apply One-Euro filter to calculated values", + }, + "filter_kwargs": { + "type": "dict", + "default": {"min_cutoff": 1.0, "beta": 0.02, "d_cutoff": 1.0}, + "description": "One-Euro filter parameters (min_cutoff, beta, d_cutoff)", + }, + "save_original": { + "type": "bool", + "default": False, + "description": "Save raw pose arrays for analysis", + }, + } + + def __init__( + self, + bind=("127.0.0.1", 6000), + authkey=b"secret password", + use_perf_counter=False, + use_filter=False, + filter_kwargs: dict | None = None, + save_original=False, + ): + super().__init__( + bind=bind, + authkey=authkey, + use_perf_counter=use_perf_counter, + save_original=save_original, + ) + + self.center_x = deque() + self.center_y = deque() + self.heading_direction = deque() + self.head_angle = deque() + + self.use_filter = use_filter + self.filter_kwargs = filter_kwargs if filter_kwargs is not None else {} + self.filters = None + + def _clear_data_queues(self): + super()._clear_data_queues() + self.center_x.clear() + self.center_y.clear() + self.heading_direction.clear() + self.head_angle.clear() + + def _initialize_filters(self, vals): + t0 = self.timing_func() + self.filters = { + "center_x": OneEuroFilter(t0, vals[0], **self.filter_kwargs), + "center_y": OneEuroFilter(t0, vals[1], **self.filter_kwargs), + "heading": OneEuroFilter(t0, vals[2], **self.filter_kwargs), + "head_angle": OneEuroFilter(t0, vals[3], **self.filter_kwargs), + } + logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}") + + def process(self, pose, **kwargs): + # Extract keypoints and confidence + xy = pose[:, :2] + conf = pose[:, 2] + + # Calculate weighted center from head keypoints + head_xy = xy[[0, 1, 2, 3, 4, 5, 6, 26], :] + head_conf = conf[[0, 1, 2, 3, 4, 5, 6, 26]] + try: + center = np.average(head_xy, axis=0, weights=head_conf) + except ZeroDivisionError: + center = np.zeros(2) + + # Calculate body axis (tail_base -> neck) + body_axis = xy[7] - xy[13] + body_axis /= sqrt(np.sum(body_axis**2)) + + # Calculate head axis (neck -> nose) + head_axis = xy[0] - xy[7] + head_axis /= sqrt(np.sum(head_axis**2)) + + # Calculate head angle relative to body + cross = body_axis[0] * head_axis[1] - head_axis[0] * body_axis[1] + sign = copysign(1, cross) # Positive when looking left + + try: + head_angle = acos(body_axis @ head_axis) * sign + except ValueError: + head_angle = 0 + + # Calculate heading (body orientation) + heading = degrees(atan2(body_axis[1], body_axis[0])) + + # Raw values (heading unwrapped for filtering) + vals = [center[0], center[1], heading, head_angle] + + # Apply filtering if enabled + curr_time = self.timing_func() + if self.use_filter: + if self.filters is None: + self._initialize_filters(vals) + + vals = [ + self.filters["center_x"](curr_time, vals[0]), + self.filters["center_y"](curr_time, vals[1]), + self.filters["heading"](curr_time, vals[2]), + self.filters["head_angle"](curr_time, vals[3]), + ] + + # Wrap heading to [0, 360) after filtering + vals[2] = vals[2] % 360 + # Update step counter + self.curr_step = self.curr_step + 1 + + # Store processed data (only if recording) + if self.recording: + if self.save_original and self.original_pose is not None: + self.original_pose.append(pose.copy()) + self.center_x.append(vals[0]) + self.center_y.append(vals[1]) + self.heading_direction.append(vals[2]) + self.head_angle.append(vals[3]) + self.time_stamp.append(curr_time) + self.step.append(self.curr_step) + self.frame_time.append(kwargs.get("frame_time", -1)) + if "pose_time" in kwargs: + self.pose_time.append(kwargs["pose_time"]) + + payload = [curr_time, vals[0], vals[1], vals[2], vals[3]] + self.broadcast(payload) + return pose + + def get_data(self): + save_dict = super().get_data() + save_dict["x_pos"] = np.array(self.center_x) + save_dict["y_pos"] = np.array(self.center_y) + save_dict["heading_direction"] = np.array(self.heading_direction) + save_dict["head_angle"] = np.array(self.head_angle) + save_dict["use_filter"] = self.use_filter + save_dict["filter_kwargs"] = self.filter_kwargs + return save_dict + + +@register_processor +class ExampleProcessorSocketFilterKeypoints(BaseProcessorSocket): # pragma: no cover + PROCESSOR_NAME = "Mouse Pose with less keypoints" + PROCESSOR_DESCRIPTION = "Calculates mouse center, heading, and head angle with optional One-Euro filtering" + PROCESSOR_PARAMS = { + "bind": { + "type": "tuple", + "default": ("127.0.0.1", 6000), + "description": "Server address (host, port)", + }, + "authkey": { + "type": "bytes", + "default": b"secret password", + "description": "Authentication key for clients", + }, + "use_perf_counter": { + "type": "bool", + "default": False, + "description": "Use time.perf_counter() instead of time.time()", + }, + "use_filter": { + "type": "bool", + "default": False, + "description": "Apply One-Euro filter to calculated values", + }, + "filter_kwargs": { + "type": "dict", + "default": {"min_cutoff": 1.0, "beta": 0.02, "d_cutoff": 1.0}, + "description": "One-Euro filter parameters (min_cutoff, beta, d_cutoff)", + }, + "save_original": { + "type": "bool", + "default": True, + "description": "Save raw pose arrays for analysis", + }, + } + + def __init__( + self, + bind=("127.0.0.1", 6000), + authkey=b"secret password", + use_perf_counter=False, + use_filter=False, + filter_kwargs: dict | None = None, + save_original=True, + p_cutoff=0.4, + ): + super().__init__( + bind=bind, + authkey=authkey, + use_perf_counter=use_perf_counter, + save_original=save_original, + ) + + self.center_x = deque() + self.center_y = deque() + self.heading_direction = deque() + self.head_angle = deque() + + self.p_cutoff = p_cutoff + + self.use_filter = use_filter + self.filter_kwargs = filter_kwargs if filter_kwargs is not None else {} + self.filters = None + + def _clear_data_queues(self): + super()._clear_data_queues() + self.center_x.clear() + self.center_y.clear() + self.heading_direction.clear() + self.head_angle.clear() + + def _initialize_filters(self, vals): + t0 = self.timing_func() + self.filters = { + "center_x": OneEuroFilter(t0, vals[0], **self.filter_kwargs), + "center_y": OneEuroFilter(t0, vals[1], **self.filter_kwargs), + "heading": OneEuroFilter(t0, vals[2], **self.filter_kwargs), + "head_angle": OneEuroFilter(t0, vals[3], **self.filter_kwargs), + } + logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}") + + def process(self, pose, **kwargs): + # Extract keypoints and confidence + xy = pose[:, :2] + conf = pose[:, 2] + + # Calculate weighted center from head keypoints + head_xy = xy[[0, 1, 2, 3, 5, 6, 7], :] + head_conf = conf[[0, 1, 2, 3, 5, 6, 7]] + # set low confidence keypoints to zero weight + head_conf = np.where(head_conf < self.p_cutoff, 0, head_conf) + try: + center = np.average(head_xy, axis=0, weights=head_conf) + except ZeroDivisionError: + # If all keypoints have zero weight, return without processing + return pose + + neck = np.average(xy[[2, 3, 6, 7], :], axis=0, weights=conf[[2, 3, 6, 7]]) + + # Calculate body axis (tail_base -> neck) + body_axis = neck - xy[9] + body_axis /= sqrt(np.sum(body_axis**2)) + + # Calculate head axis (neck -> nose) + head_axis = xy[0] - neck + head_axis /= sqrt(np.sum(head_axis**2)) + + # Calculate head angle relative to body + cross = body_axis[0] * head_axis[1] - head_axis[0] * body_axis[1] + sign = copysign(1, cross) # Positive when looking left + + try: + head_angle = acos(body_axis @ head_axis) * sign + except ValueError: + head_angle = 0 + + # Calculate heading (body orientation) + heading = degrees(atan2(body_axis[1], body_axis[0])) + vals = [center[0], center[1], heading, head_angle] + + curr_time = self.timing_func() + if self.use_filter: + if self.filters is None: + self._initialize_filters(vals) + + vals = [ + self.filters["center_x"](curr_time, vals[0]), + self.filters["center_y"](curr_time, vals[1]), + self.filters["heading"](curr_time, vals[2]), + self.filters["head_angle"](curr_time, vals[3]), + ] + + # Wrap heading to [0, 360) after filtering + vals[2] = vals[2] % 360 + # Update step counter + self.curr_step = self.curr_step + 1 + + # Store processed data (only if recording) + if self.recording: + if self.save_original and self.original_pose is not None: + self.original_pose.append(pose.copy()) + self.center_x.append(vals[0]) + self.center_y.append(vals[1]) + self.heading_direction.append(vals[2]) + self.head_angle.append(vals[3]) + self.time_stamp.append(curr_time) + self.step.append(self.curr_step) + self.frame_time.append(kwargs.get("frame_time", -1)) + if "pose_time" in kwargs: + self.pose_time.append(kwargs["pose_time"]) + + payload = [curr_time, vals[0], vals[1], vals[2], vals[3]] + self.broadcast(payload) + return pose + + def get_data(self): + save_dict = super().get_data() + save_dict["x_pos"] = np.array(self.center_x) + save_dict["y_pos"] = np.array(self.center_y) + save_dict["heading_direction"] = np.array(self.heading_direction) + save_dict["head_angle"] = np.array(self.head_angle) + save_dict["use_filter"] = self.use_filter + save_dict["filter_kwargs"] = self.filter_kwargs + return save_dict diff --git a/dlclivegui/processors/processor_utils.py b/dlclivegui/processors/processor_utils.py index b32445c..467792b 100644 --- a/dlclivegui/processors/processor_utils.py +++ b/dlclivegui/processors/processor_utils.py @@ -17,6 +17,65 @@ def default_processors_dir() -> str: return str(path) +def _processor_base_class(): + from dlclive.processor import Processor + + return Processor + + +def _is_processor_subclass(obj, *, include_base: bool = False) -> bool: + """Return True for dlclive.Processor subclasses, including indirect subclasses.""" + if not inspect.isclass(obj): + return False + + try: + processor_base = _processor_base_class() + except Exception: + logger.exception("Could not import dlclive.Processor") + return False + + try: + if obj is processor_base: + return bool(include_base) + return issubclass(obj, processor_base) + except Exception: + logger.exception(f"Error checking if {obj} is a subclass of dlclive.Processor") + return False + + +def _processor_info_from_class(cls, fallback_name: str) -> dict: + return { + "class": cls, + "name": getattr(cls, "PROCESSOR_NAME", fallback_name), + "description": getattr(cls, "PROCESSOR_DESCRIPTION", ""), + "params": getattr(cls, "PROCESSOR_PARAMS", {}), + } + + +def discover_processor_classes(module, *, only_defined_in_module: bool = True) -> dict[str, dict]: + """Discover dlclive.Processor subclasses in a module. + + Includes indirect subclasses of Processor. + + Args: + module: Imported Python module. + only_defined_in_module: If True, ignore Processor subclasses imported + from other modules to avoid duplicate registry entries. + """ + processors: dict[str, dict] = {} + + for name, obj in inspect.getmembers(module, inspect.isclass): + if only_defined_in_module and getattr(obj, "__module__", None) != module.__name__: + continue + + if not _is_processor_subclass(obj): + continue + + processors[name] = _processor_info_from_class(obj, name) + + return processors + + def scan_processor_folder(folder_path): all_processors = {} folder = Path(folder_path) @@ -39,11 +98,9 @@ def scan_processor_folder(folder_path): return all_processors -def scan_processor_package(package_name: str = "dlclivegui.processors") -> dict[str | dict]: +def scan_processor_package(package_name: str = "dlclivegui.processors") -> dict[str, dict]: """ Discover and load processor classes from a package namespace. - Returns a dict keyed as 'module.py::ClassName' with the same - structure you use today. """ all_processors: dict[str, dict] = {} @@ -59,28 +116,16 @@ def scan_processor_package(package_name: str = "dlclivegui.processors") -> dict[ continue try: mod = import_module(mod_name) + # Skip dlc_processor_socket.py as it's the base class and registry + if mod.__name__.endswith("dlc_processor_socket"): + continue # Prefer module-level registry function if present if hasattr(mod, "get_available_processors"): processors = mod.get_available_processors() else: # Fallback: scan for dlclive.Processor subclasses - from dlclive import Processor - - processors = {} - for attr_name in dir(mod): - obj = getattr(mod, attr_name) - try: - if isinstance(obj, type) and obj is not Processor and issubclass(obj, Processor): - processors[attr_name] = { - "class": obj, - "name": getattr(obj, "PROCESSOR_NAME", attr_name), - "description": getattr(obj, "PROCESSOR_DESCRIPTION", ""), - "params": getattr(obj, "PROCESSOR_PARAMS", {}), - } - except Exception: - # Non-class or weird metaclass; ignore - pass + processors = discover_processor_classes(mod) # Normalize into your “file::class” shape module_file = mod.__name__.split(".")[-1] + ".py" @@ -131,26 +176,7 @@ def load_processors_from_file(file_path: str | Path): return processors # Fallback path: discover subclasses of dlclive.Processor - from dlclive import Processor - - processors: dict[str, dict] = {} - for name, obj in inspect.getmembers(module, inspect.isclass): - if obj is Processor: - continue - # Guard: module might define other classes; only include Processor subclasses - try: - if issubclass(obj, Processor): - processors[name] = { - "class": obj, - "name": getattr(obj, "PROCESSOR_NAME", name), - "description": getattr(obj, "PROCESSOR_DESCRIPTION", ""), - "params": getattr(obj, "PROCESSOR_PARAMS", {}), - } - except Exception: - # Some "classes" can fail issubclass checks; ignore safely - continue - - return processors + return discover_processor_classes(module) except Exception: # Full traceback helps a ton when a plugin fails to import diff --git a/dlclivegui/processors/registry.py b/dlclivegui/processors/registry.py new file mode 100644 index 0000000..2889297 --- /dev/null +++ b/dlclivegui/processors/registry.py @@ -0,0 +1,53 @@ +import logging + +logger = logging.getLogger(__name__) + +# Registry for GUI discovery +PROCESSOR_REGISTRY = {} + + +def register_processor(cls): + registry_key = getattr(cls, "PROCESSOR_ID", cls.__name__) + if registry_key in PROCESSOR_REGISTRY: + msg = ( + f"Duplicate processor registration key '{registry_key}': " + f"{PROCESSOR_REGISTRY[registry_key].__name__} vs {cls.__name__}" + ) + logger.warning(msg) + PROCESSOR_REGISTRY[registry_key] = cls + return cls + + +def get_available_processors(): + """ + Get list of available processor classes. + + Returns: + dict: Dictionary mapping registry keys to processor info. + """ + return { + name: { + "class": cls, + "name": getattr(cls, "PROCESSOR_NAME", name), + "description": getattr(cls, "PROCESSOR_DESCRIPTION", ""), + "params": getattr(cls, "PROCESSOR_PARAMS", {}), + } + for name, cls in PROCESSOR_REGISTRY.items() + } + + +def instantiate_processor(class_name, **kwargs): + """ + Instantiate a processor by class name with given parameters. + + Args: + class_name: Registry key (e.g., "MyProcessorSocket") + **kwargs: Constructor kwargs + + Raises: + ValueError: If class_name is not in registry + """ + if class_name not in PROCESSOR_REGISTRY: + available = ", ".join(PROCESSOR_REGISTRY.keys()) + raise ValueError(f"Unknown processor '{class_name}'. Available: {available}") + return PROCESSOR_REGISTRY[class_name](**kwargs) diff --git a/dlclivegui/temp/engine.py b/dlclivegui/temp/engine.py index a6bb225..85c4755 100644 --- a/dlclivegui/temp/engine.py +++ b/dlclivegui/temp/engine.py @@ -6,7 +6,7 @@ # or if we update dlclive.Engine to have these methods and use that instead of a separate enum here. # The latter would be more cohesive but also creates a dependency from utils to dlclive, # pending release of dlclive -class Engine(Enum): +class Engine(str, Enum): TENSORFLOW = "tensorflow" PYTORCH = "pytorch" @@ -26,6 +26,12 @@ def is_tensorflow_model_dir_path(model_path: str | Path) -> bool: @classmethod def from_model_type(cls, model_type: str) -> "Engine": + if not isinstance(model_type, str): + try: + model_type = getattr(model_type, "value", str(model_type)) + except Exception as e: + raise ValueError(f"Could not convert model_type to string: {model_type}") from e + if model_type.lower() == "pytorch": return cls.PYTORCH elif model_type.lower() in ("tensorflow", "base", "tensorrt", "lite"): diff --git a/dlclivegui/utils/settings_store.py b/dlclivegui/utils/settings_store.py index a0c5677..0107afb 100644 --- a/dlclivegui/utils/settings_store.py +++ b/dlclivegui/utils/settings_store.py @@ -57,6 +57,42 @@ def get_fast_encoding(self, default: bool = False) -> bool: return value return str(value).strip().lower() in {"1", "true", "yes", "on"} + def get_processor_folder(self, default: str = "") -> str: + """ + Return the persisted processor folder if it still exists and is a directory. + Otherwise return default. + """ + value = self._s.value("dlc/processor_folder", default) + value = str(value).strip() if value is not None else "" + + if not value: + return default + + try: + path = Path(value).expanduser() + if path.is_dir(): + return str(path.resolve()) + except Exception: + logger.debug("Persisted processor folder is invalid: %s", value, exc_info=True) + + return default + + def set_processor_folder(self, folder: str) -> None: + """ + Persist processor folder only if it exists and is a directory. + Invalid folders are ignored. + """ + folder = str(folder).strip() if folder is not None else "" + if not folder: + return + + try: + path = Path(folder).expanduser() + if path.is_dir(): + self._s.setValue("dlc/processor_folder", str(path.resolve())) + except Exception: + logger.debug("Failed to persist processor folder: %s", folder, exc_info=True) + def set_fast_encoding(self, enabled: bool) -> None: self._s.setValue("recording/fast_encoding", bool(enabled)) diff --git a/tests/custom_processors/test_base_processor.py b/tests/custom_processors/test_base_processor.py index d38749b..94dabab 100644 --- a/tests/custom_processors/test_base_processor.py +++ b/tests/custom_processors/test_base_processor.py @@ -13,15 +13,21 @@ def _mock_dlclive(monkeypatch): - """Provide a dummy dlclive.Processor so the module can import in tests.""" - fake = types.ModuleType("dlclive") - class Processor: def __init__(self, *args, **kwargs): pass - fake.Processor = Processor - monkeypatch.setitem(sys.modules, "dlclive", fake) + def process(self, pose, **kwargs): + return pose + + dlclive_mod = types.ModuleType("dlclive") + processor_mod = types.ModuleType("dlclive.processor") + + dlclive_mod.Processor = Processor + processor_mod.Processor = Processor + + monkeypatch.setitem(sys.modules, "dlclive", dlclive_mod) + monkeypatch.setitem(sys.modules, "dlclive.processor", processor_mod) @pytest.fixture @@ -37,6 +43,19 @@ def socket_mod(monkeypatch): return importlib.import_module(mod_name) +@pytest.fixture +def example_processor_mod(monkeypatch): + """ + Import the example processor module with dlclive mocked. + Adjust module name if your file lives elsewhere. + """ + _mock_dlclive(monkeypatch) + mod_name = "dlclivegui.processors.examples" + if mod_name in sys.modules: + del sys.modules[mod_name] + return importlib.import_module(mod_name) + + def _module_data_dir(socket_mod) -> Path: """Compute the data/ directory where save() writes artifacts.""" return Path(socket_mod.__file__).parent.parent.parent / "data" @@ -233,12 +252,14 @@ def test_save_ignores_pre_recording_original_pose_frames(socket_mod): ("ExampleProcessorSocketFilterKeypoints", 10), ], ) -def test_subclass_save_ignores_pre_recording_original_pose_frames(socket_mod, class_name, n_keypoints): +def test_subclass_save_ignores_pre_recording_original_pose_frames( + socket_mod, example_processor_mod, class_name, n_keypoints +): """ Concrete processors must keep original_pose aligned with recorded metadata even when process() is called before recording starts. """ - processor_class = getattr(socket_mod, class_name) + processor_class = getattr(example_processor_mod, class_name) proc = processor_class(bind=("127.0.0.1", 0), save_original=True) try: