#!/usr/bin/env python3
"""
Barionet M44 TCP Command API Server
Production-grade TCP server with command handling, IO mapping, and state subscriptions.
"""

import json
import logging
import logging.handlers
import os
import re
import signal
import socket
import subprocess
import sys
import threading
import time
import struct
from typing import Optional, Set, Dict

# ---------------------------------------------------------------------------
# Constants & IO Address Definitions
# ---------------------------------------------------------------------------

# Configuration files, tried in this order by load_config().
CONFIG_FILE = "config.json"
DEFAULT_CONFIG_FILE = "default_config.json"

# NOTE(review): MAX_MSG_LEN and STARTUP_DELAY_S are not referenced in this
# part of the file; presumably a message-length cap and a boot delay in
# seconds -- confirm against their users.
MAX_MSG_LEN = 256
STARTUP_DELAY_S = 3
# Pause between IoMapping availability checks, and the default overall
# time budget for those checks (see IoMappingService.wait_until_ready).
IOMAPPING_RETRY_INTERVAL_S = 1
IOMAPPING_RETRY_MAX_S = 10

# UX8 detection addresses
UX8_DETECT_ADDRS = [60007, 60008, 60009, 60010]

# Built-in local relay addresses (for LocalIO subscription)
LOCAL_RELAY_ADDRS = [1, 2, 3, 4]
# Built-in local digital input addresses (for LocalIO subscription)
LOCAL_DI_ADDRS = [201, 202, 203, 204]

# All writable 1-bit address ranges (baseline, without UX8 context)
# Format: list of (start, end) inclusive
_WRITABLE_1BIT_BASE = [
    (1, 10),        # Relays 1-4, 5-8 (reserved), RS232 RTS, Virtual IO 10
    (43, 200),      # Virtual IO bits (incl. 101-108 digital outputs, 109-200 virtual)
    (210, 210),     # Virtual IO 210
    (243, 300),     # Virtual IO 243-300
    (309, 400),     # Virtual IO 309-400
    (1207, 1207),   # USB Control
]

# Writable multi-bit address ranges: (start, end, bit_size), bounds inclusive
_WRITABLE_MULTIBIT = [
    (401, 410, 32),   # DI Counters + Virtual 32bit
    (443, 500, 32),   # Virtual 32bit registers
    (509, 510, 16),   # Virtual 16bit registers
    (543, 600, 16),   # Virtual IO 16bit registers
    (751, 1200, 16),  # Virtual IO 16bit registers
    (1208, 1211, 4),  # LED color/brightness
    (1212, 1244, 1),  # Pull-ups for UX8 (1-bit, but plain write only)
]

# UX8 relay/input/counter address blocks: one (start, end) inclusive pair per
# UX8 slot, in slot order (list index 0 is UX8 #1).
UX8_RELAY_BLOCKS = [
    (11, 18),   # UX8 #1 relays
    (19, 26),   # UX8 #2 relays
    (27, 34),   # UX8 #3 relays
    (35, 42),   # UX8 #4 relays
]
UX8_INPUT_BLOCKS = [
    (211, 218),  # UX8 #1 digital inputs
    (219, 226),  # UX8 #2 digital inputs
    (227, 234),  # UX8 #3 digital inputs
    (235, 242),  # UX8 #4 digital inputs
]
UX8_COUNTER_BLOCKS = [
    (411, 418),  # UX8 #1 DI counters
    (419, 426),  # UX8 #2 DI counters
    (427, 434),  # UX8 #3 DI counters
    (435, 442),  # UX8 #4 DI counters
]

# All readable addresses (complete table)
# NOTE(review): this table is not referenced in this part of the file;
# _is_readable() decides readability through _addr_bit_size() instead.
_READABLE_RANGES = [
    (1, 400),
    (401, 500),    # counters + virtual 32bit
    (501, 600),    # analog inputs (read-only) + virtual 16bit
    (601, 750),    # temperature + sensor addresses (read-only)
    (751, 1211),   # virtual + device info (1201-1206 are read-only)
    (1207, 1244),  # USB control + LED + pullups (overlaps the previous range at 1207-1211)
    (60001, 60010),
]

# Addresses that are strictly read-only (no write allowed regardless of UX8)
_READ_ONLY_RANGES = [
    (201, 209),    # Digital inputs 201-204, 205-208 reserved, 209 RS232 CTS
    (501, 508),    # Analog inputs (read-only)
    (511, 542),    # UX8 analog inputs (read-only)
    (601, 750),    # Temperature sensors (read-only)
    (1201, 1206),  # Device info (read-only)
    (60001, 60010),# System info (read-only)
]

# Addresses where timed/toggle apply (writable 1-bit, excluding pullup multi-write 1212-1244)
# NOTE(review): (243, 308) here is wider than (243, 300) in _WRITABLE_1BIT_BASE,
# so 301-308 count as timed/toggle capable although they are not in the
# writable base table -- confirm which bound is intended.
_TIMED_TOGGLE_1BIT_BASE = [
    (1, 10),
    (43, 200),
    (210, 210),
    (243, 308),
    (309, 400),
    (1207, 1207),
]

# ---------------------------------------------------------------------------
# Logging setup
# ---------------------------------------------------------------------------

logger = logging.getLogger("tcp_api")

# Console output goes to stdout with a timestamped, level-tagged format.
_console_handler = logging.StreamHandler(stream=sys.stdout)
_console_handler.setFormatter(
    logging.Formatter(fmt="%(asctime)s [%(levelname)s] %(message)s")
)

# The logger itself lets everything through; handlers receive all levels.
logger.setLevel(logging.DEBUG)
logger.addHandler(_console_handler)


def setup_syslog(address: str) -> None:
    """Attach a LOCAL0 syslog handler to the module logger.

    ``address`` is expected as ``host:port``; it is split on the last colon.
    Any failure (malformed address, handler creation error) is logged and
    otherwise ignored, so the application keeps running with console logging.
    """
    try:
        host, port_text = address.rsplit(":", 1)
        port = int(port_text)

        syslog_handler = logging.handlers.SysLogHandler(
            address=(host, port),
            facility=logging.handlers.SysLogHandler.LOG_LOCAL0,
        )
        syslog_format = logging.Formatter("tcp_api: %(levelname)s %(message)s")
        syslog_handler.setFormatter(syslog_format)

        logger.addHandler(syslog_handler)
        logger.info("Syslog configured: %s:%d", host, port)
    except Exception as exc:
        logger.error("Failed to configure syslog (%s): %s", address, exc)


# ---------------------------------------------------------------------------
# Config loader
# ---------------------------------------------------------------------------

def load_config() -> dict:
    """Return the ``AppParam`` section of the first usable config file.

    ``CONFIG_FILE`` is tried before ``DEFAULT_CONFIG_FILE``. A file that
    exists but cannot be read or parsed is logged and skipped. When no file
    yields a result, an empty dict is returned.
    """
    for candidate in (CONFIG_FILE, DEFAULT_CONFIG_FILE):
        if not os.path.exists(candidate):
            continue
        try:
            with open(candidate, "r") as handle:
                parsed = json.load(handle)
            logger.info("Loaded config from %s", candidate)
            return parsed.get("AppParam", {})
        except Exception as exc:
            # Fall through to the next candidate file.
            logger.error("Error reading %s: %s", candidate, exc)

    logger.warning("No config file found; using built-in defaults")
    return {}


# ---------------------------------------------------------------------------
# IO Address helpers
# ---------------------------------------------------------------------------

def _in_ranges(addr: int, ranges) -> bool:
    return any(lo <= addr <= hi for lo, hi in ranges)


def _addr_bit_size(addr: int, ux8_count: int) -> Optional[int]:
    """Look up the register width, in bits, of an IO address.

    ``ux8_count`` is accepted for signature parity with the other address
    helpers; the width returned here does not depend on it.

    Returns ``None`` when the address is not part of any known range.
    """
    # Single-bit IOs: the writable base table plus the UX8 relay and
    # digital-input blocks.
    one_bit_tables = (_WRITABLE_1BIT_BASE, UX8_RELAY_BLOCKS, UX8_INPUT_BLOCKS)
    if any(_in_ranges(addr, table) for table in one_bit_tables):
        return 1

    # UX8 counters are 32-bit.
    if _in_ranges(addr, UX8_COUNTER_BLOCKS):
        return 32

    # Writable multi-bit registers carry their width in the table.
    for first, last, width in _WRITABLE_MULTIBIT:
        if first <= addr <= last:
            return width

    # Read-only ranges still need a width so that getio can serve them.
    read_only_widths = (
        (201, 209, 1),
        (501, 542, 16),
        (601, 650, 16),
        (651, 750, 32),
        (1201, 1203, 16),
        (1204, 1204, 32),
        (1205, 1206, 16),
        (60001, 60006, 16),
        (60007, 60010, 1),
    )
    for first, last, width in read_only_widths:
        if first <= addr <= last:
            return width

    return None


def _is_readable(addr: int, ux8_count: int) -> bool:
    """Tell whether ``addr`` can be read, i.e. has a known bit width."""
    width = _addr_bit_size(addr, ux8_count)
    return width is not None


def _is_writable(addr: int, ux8_count: int) -> bool:
    """Decide whether ``addr`` accepts writes for the given UX8 count.

    The checks run in a fixed order: read-only ranges win, then the
    UX8-dependent blocks, then the plain writable tables.
    """
    if _in_ranges(addr, _READ_ONLY_RANGES):
        return False

    # UX8 relay addresses accept writes whether or not a UX8 is attached.
    if _in_ranges(addr, UX8_RELAY_BLOCKS):
        return True

    # A UX8 digital-input block is writable only while its slot is empty.
    for slot, (first, last) in enumerate(UX8_INPUT_BLOCKS, start=1):
        if first <= addr <= last:
            return slot > ux8_count

    # UX8 counter addresses accept writes regardless of UX8 presence.
    if _in_ranges(addr, UX8_COUNTER_BLOCKS):
        return True

    # Pull-up addresses are grouped eight per slot and need that slot's UX8.
    if 1212 <= addr <= 1244:
        zero_based_slot = (addr - 1212) // 8
        return zero_based_slot + 1 <= ux8_count

    if _in_ranges(addr, _WRITABLE_1BIT_BASE):
        return True

    # Multi-bit registers, with the analog-input addresses kept read-only.
    in_multibit = any(first <= addr <= last for first, last, _ in _WRITABLE_MULTIBIT)
    is_analog_input = (501 <= addr <= 508) or (511 <= addr <= 542)
    return in_multibit and not is_analog_input


def _is_1bit_timed_toggle(addr: int, ux8_count: int) -> bool:
    """Tell whether ``addr`` supports the timed and toggle setio functions."""
    # Base single-bit IOs and UX8 relay addresses always qualify.
    if _in_ranges(addr, _TIMED_TOGGLE_1BIT_BASE) or _in_ranges(addr, UX8_RELAY_BLOCKS):
        return True

    # UX8 digital-input addresses qualify only while their slot has no UX8.
    return any(
        first <= addr <= last and slot > ux8_count
        for slot, (first, last) in enumerate(UX8_INPUT_BLOCKS, start=1)
    )


def _max_value_for_addr(addr: int, ux8_count: int) -> int:
    """Return the largest value that fits the register at ``addr``.

    Addresses with no known width are treated as 1-bit.
    """
    width = _addr_bit_size(addr, ux8_count)
    if not width:
        width = 1
    return (1 << width) - 1


def _get_all_writable_1bit_addrs(ux8_count: int) -> list:
    """Build the ascending list of writable 1-bit addresses used for initialization.

    UX8 digital-input blocks are included only for slots beyond ``ux8_count``.
    Addresses 5-8, 101-200 and 1207 (USB power) are left out.
    """
    blocks = list(_WRITABLE_1BIT_BASE)
    blocks.extend(UX8_RELAY_BLOCKS)
    blocks.extend(
        block
        for slot, block in enumerate(UX8_INPUT_BLOCKS, start=1)
        if slot > ux8_count
    )

    candidates = {
        addr
        for first, last in blocks
        for addr in range(first, last + 1)
    }
    skipped = set(range(5, 9)) | set(range(101, 201)) | {1207}
    return sorted(candidates - skipped)


# ---------------------------------------------------------------------------
# IoMapping wrapper with retry
# ---------------------------------------------------------------------------

class IoMappingService:
    """Lock-guarded wrapper around IoMapping with startup retry.

    Every call into the underlying service happens while holding one lock.
    Until ``wait_until_ready`` succeeds, all operations are no-ops that
    return their failure value (``None`` / ``False``).
    """

    def __init__(self):
        self._service = None
        self._lock = threading.Lock()
        self._ready = False

    def wait_until_ready(self, timeout_s: int = IOMAPPING_RETRY_MAX_S) -> bool:
        """Block until IoMapping reports alive or ``timeout_s`` elapses.

        Returns True once the service is usable, False on timeout.
        """
        # Use the monotonic clock: a wall-clock step (for example a time sync
        # during boot) must neither shorten nor stretch the retry window.
        deadline = time.monotonic() + timeout_s
        attempt = 0
        while time.monotonic() < deadline:
            attempt += 1
            try:
                # Imported lazily so a missing module is a retryable error.
                from iomapping import IoMapping  # type: ignore
                svc = IoMapping()
                alive = svc.is_alive()
                if alive:
                    with self._lock:
                        self._service = svc
                        self._ready = True
                    logger.info("IoMapping service ready (attempt %d)", attempt)
                    return True
                else:
                    logger.warning("IoMapping not alive yet (attempt %d), retrying...", attempt)
            except Exception as exc:
                logger.warning("IoMapping init error (attempt %d): %s", attempt, exc)
            time.sleep(IOMAPPING_RETRY_INTERVAL_S)
        logger.error("IoMapping service not available after %ds; continuing without it", timeout_s)
        return False

    def read_value(self, addr: int) -> Optional[int]:
        """Read one address; returns None when unavailable or on error."""
        with self._lock:
            if not self._ready or self._service is None:
                return None
            try:
                return self._service.read_value(addr)
            except Exception as exc:
                logger.error("Error reading address %d: %s", addr, exc)
                return None

    def write_value(self, addr: int, value: int) -> bool:
        """Write one address; returns True only when the write succeeded."""
        with self._lock:
            if not self._ready or self._service is None:
                return False
            try:
                self._service.write_value(addr, value)
                return True
            except Exception as exc:
                logger.error("Error writing address %d value %d: %s", addr, value, exc)
                return False

    def read_values(self, addrs: list) -> Optional[list]:
        """Read several addresses; returns None when unavailable or on error."""
        with self._lock:
            if not self._ready or self._service is None:
                return None
            try:
                return self._service.read_values(addrs)
            except Exception as exc:
                logger.error("Error reading addresses %s: %s", addrs, exc)
                return None

    def enable_notifications(self, addr: int, callback) -> bool:
        """Register ``callback`` for one address; returns True on success."""
        with self._lock:
            if not self._ready or self._service is None:
                return False
            try:
                self._service.enable_notifications(addr, callback)
                return True
            except Exception as exc:
                logger.error("Error enabling notifications for %d: %s", addr, exc)
                return False

    def enable_notifications_range(self, addrs: list, callback) -> bool:
        """Register ``callback`` for several addresses; returns True on success."""
        with self._lock:
            if not self._ready or self._service is None:
                return False
            try:
                self._service.enable_notifications_range(addrs, callback)
                return True
            except Exception as exc:
                logger.error("Error enabling notifications for range %s: %s", addrs, exc)
                return False

    def disable_notifications(self, addr: int) -> None:
        """Unregister notifications for one address; errors are only logged."""
        with self._lock:
            if not self._ready or self._service is None:
                return
            try:
                self._service.disable_notifications(addr)
            except Exception as exc:
                logger.error("Error disabling notifications for %d: %s", addr, exc)

    def disable_notifications_range(self, addrs: list) -> None:
        """Unregister notifications for several addresses; errors are only logged."""
        with self._lock:
            if not self._ready or self._service is None:
                return
            try:
                self._service.disable_notifications_range(addrs)
            except Exception as exc:
                logger.error("Error disabling notifications for range %s: %s", addrs, exc)

    def run_once(self) -> None:
        """Call service.run(False) to process pending notifications without blocking."""
        with self._lock:
            if not self._ready or self._service is None:
                return
            try:
                self._service.run(False)
            except Exception as exc:
                logger.error("IoMapping run error: %s", exc)


# ---------------------------------------------------------------------------
# Timer manager for timed setio
# ---------------------------------------------------------------------------

class TimerManager:
    """Manages per-address timed output resets.

    At most one timer is registered per address. A timer that has already
    started firing cannot be stopped by ``Timer.cancel()``, so every firing
    timer re-checks under the lock that it is still the registered one
    before writing; a superseded or cancelled timer does nothing.
    """

    def __init__(self, io_svc: "IoMappingService"):
        self._io = io_svc
        self._timers: Dict[int, threading.Timer] = {}
        self._lock = threading.Lock()

    def schedule(self, addr: int, delay_s: float) -> None:
        """Schedule addr to be set to 0 after delay_s seconds.

        Any timer already registered for ``addr`` is replaced.
        """
        with self._lock:
            previous = self._timers.pop(addr, None)
            if previous is not None:
                previous.cancel()

            def _fire() -> None:
                # ``timer`` is bound below, before the timer can run.
                self._reset_addr(addr, timer)

            timer = threading.Timer(delay_s, _fire)
            timer.daemon = True
            self._timers[addr] = timer
            timer.start()

    def cancel(self, addr: int) -> None:
        """Cancel any active timer for addr."""
        with self._lock:
            existing = self._timers.pop(addr, None)
            if existing is not None:
                existing.cancel()

    def cancel_all_and_reset(self) -> None:
        """Cancel all timers and write 0 to each address (graceful shutdown)."""
        with self._lock:
            addrs = list(self._timers)
            for pending in self._timers.values():
                pending.cancel()
            self._timers.clear()
        # Writes happen outside the lock so a slow IO call cannot block
        # schedule()/cancel() callers.
        for addr in addrs:
            logger.info("Graceful shutdown: resetting timed address %d to 0", addr)
            self._io.write_value(addr, 0)

    def _reset_addr(self, addr: int, timer: Optional[threading.Timer] = None) -> None:
        """Timer callback: write 0 to ``addr`` if the firing timer is current.

        ``timer`` identifies the timer that fired. When it is given and is no
        longer the one registered for ``addr`` (it was replaced or cancelled
        after it started firing), nothing is written and the registered timer
        is left untouched. Without ``timer`` the reset is unconditional.
        """
        with self._lock:
            if timer is not None and self._timers.get(addr) is not timer:
                return
            self._timers.pop(addr, None)
        logger.debug("Timer expired: writing 0 to address %d", addr)
        self._io.write_value(addr, 0)


# ---------------------------------------------------------------------------
# TCP Subscription / Notification manager
# ---------------------------------------------------------------------------

class SubscriptionManager:
    """Manages statechange notification subscriptions and delivery to TCP peer.

    Subscription bookkeeping only records addresses whose notifications were
    actually enabled: when enabling fails, the address is removed again so a
    later getio/setio or reconnect can retry.
    """

    def __init__(self, io_svc: "IoMappingService", ux8_count: int):
        self._io = io_svc
        self._ux8_count = ux8_count
        self._lock = threading.Lock()
        # Set of addresses currently subscribed for statechange
        self._subscribed: Set[int] = set()
        # Set of relay addresses recently written by setio (suppress statechange)
        self._setio_written: Set[int] = set()
        # TCP send callback (set when connection is active)
        self._send_cb = None
        # Config
        self._initial_sub = "None"
        self._add_sub = "None"

    def configure(self, initial_sub: str, add_sub: str) -> None:
        """Store the initial and additional subscription modes."""
        self._initial_sub = initial_sub
        self._add_sub = add_sub

    def set_send_callback(self, cb) -> None:
        """Install the callable used to deliver messages to the peer."""
        with self._lock:
            self._send_cb = cb

    def clear_send_callback(self) -> None:
        """Remove the delivery callable; later messages are dropped."""
        with self._lock:
            self._send_cb = None

    def on_connection_established(self) -> None:
        """Called when TCP peer connects. Sends initial IO dump and sets up subscriptions."""
        if self._initial_sub == "LocalIO":
            self._setup_local_io_subscriptions()
            self._send_initial_dump()

    def on_connection_closed(self) -> None:
        """Called when TCP peer disconnects. Remove dynamic subscriptions."""
        self._teardown_dynamic_subscriptions()
        with self._lock:
            self._setio_written.clear()

    def notify_setio_write(self, addr: int) -> None:
        """Mark that setio just wrote to this relay address."""
        with self._lock:
            self._setio_written.add(addr)

    def clear_setio_flag(self, addr: int) -> None:
        """Clear the setio-written flag for an address."""
        with self._lock:
            self._setio_written.discard(addr)

    def on_getio_or_setio(self, addr: int) -> None:
        """Subscribe addr for statechange if 'With getio/setio' mode."""
        if self._add_sub != "With getio/setio":
            return
        # Only subscribe valid notification addresses
        if not self._is_notifiable(addr):
            return
        with self._lock:
            if addr in self._subscribed:
                return
            self._subscribed.add(addr)
        if not self._io.enable_notifications(addr, self._on_state_change):
            # Enabling failed: forget the address so the next command retries.
            with self._lock:
                self._subscribed.discard(addr)

    def _is_notifiable(self, addr: int) -> bool:
        """Check if address is eligible for statechange notifications."""
        # Built-in DI and relays
        if 1 <= addr <= 4 or 201 <= addr <= 204:
            return True
        # UX8 relays 11-42
        if 11 <= addr <= 42:
            return True
        # UX8 digital inputs 211-242 (virtual if no UX8)
        if 211 <= addr <= 242:
            return True
        # Virtual IOs (1-bit writable)
        return _in_ranges(addr, _TIMED_TOGGLE_1BIT_BASE)

    def _setup_local_io_subscriptions(self) -> None:
        """Subscribe LocalIO addresses for statechange."""
        addrs = LOCAL_RELAY_ADDRS + LOCAL_DI_ADDRS
        with self._lock:
            new_addrs = [a for a in addrs if a not in self._subscribed]
            self._subscribed.update(new_addrs)
        if not new_addrs:
            return
        if not self._io.enable_notifications_range(new_addrs, self._on_state_change):
            # Enabling failed: do not pretend these addresses are subscribed.
            with self._lock:
                self._subscribed.difference_update(new_addrs)

    def _teardown_dynamic_subscriptions(self) -> None:
        """Disable notifications for every subscribed address and forget them."""
        with self._lock:
            addrs = list(self._subscribed)
            self._subscribed.clear()
        if addrs:
            self._io.disable_notifications_range(addrs)

    def _send_initial_dump(self) -> None:
        """Send initial state dump for LocalIO addresses."""
        all_addrs = LOCAL_RELAY_ADDRS + LOCAL_DI_ADDRS
        values = self._io.read_values(all_addrs)
        if values is None:
            logger.error("Failed to read initial IO state for dump")
            return
        # One message carrying a "statechange" record per address.
        self._send("".join(
            f"statechange,{addr},{val}\r" for addr, val in zip(all_addrs, values)
        ))

    def _on_state_change(self, addr: int, value: int) -> None:
        """IoMapping callback for state changes."""
        if self._initial_sub == "LocalIO" and self._add_sub == "None":
            if addr not in LOCAL_DI_ADDRS and addr not in LOCAL_RELAY_ADDRS:
                return

        # For all addresses: suppress once if recently written by setio
        with self._lock:
            if addr in self._setio_written:
                self._setio_written.discard(addr)
                return

        self._send(f"statechange,{addr},{value}\r")

    def _send(self, msg: str) -> None:
        """Deliver ``msg`` through the send callback, if one is installed."""
        with self._lock:
            cb = self._send_cb
        # Call outside the lock so a slow peer cannot block other callers.
        if cb is not None:
            try:
                cb(msg)
            except Exception as exc:
                logger.error("Error sending statechange notification: %s", exc)


# ---------------------------------------------------------------------------
# Command handlers
# ---------------------------------------------------------------------------

def _run_system_cmd(cmd: list, timeout: int = 5) -> Optional[str]:
    """Run a system command and return stdout, or None on error."""
    try:
        result = subprocess.run(
            cmd, capture_output=True, text=True, timeout=timeout
        )
        if result.returncode != 0:
            logger.error("Command %s failed (rc=%d): %s", cmd, result.returncode, result.stderr)
            return None
        return result.stdout.strip()
    except subprocess.TimeoutExpired:
        logger.error("Command %s timed out", cmd)
        return None
    except Exception as exc:
        logger.error("Command %s error: %s", cmd, exc)
        return None


def cmd_version() -> str:
    """Handle 'version' command.

    Returns ``version,<product> <image> <firmware>\\r`` on success,
    ``command failed\\r`` when either system command fails, and
    ``version not available\\r`` when the device-info JSON cannot be used.
    """
    spi_out = _run_system_cmd(["qiba-spi-get-info"])
    ver_out = _run_system_cmd(["cat", "/barix/info/VERSION"])

    if spi_out is None or ver_out is None:
        return "command failed\r"

    try:
        data = json.loads(spi_out)
        product_name = data["HW_DEVICE"]["Product_Name"]
        image_name = data["IMAGE"]["Name"]
    except (json.JSONDecodeError, KeyError, TypeError) as exc:
        # TypeError covers JSON that parses but is not the expected nested
        # object (for example a list, null, or a non-dict section).
        logger.error("Failed to parse qiba-spi-get-info output: %s", exc)
        return "version not available\r"

    # ver_out is already stripped by _run_system_cmd.
    return f"version,{product_name} {image_name} {ver_out}\r"


def cmd_c65535() -> str:
    """Handle the 'c=65535' command by reporting the configured DHCP name.

    The name is taken from ``uci show network.eth0.dhcpname`` output, which
    is expected to look like ``network.eth0.dhcpname='BarionetM44'``.
    """
    failure = "command failed\r"

    uci_text = _run_system_cmd(["uci", "show", "network.eth0.dhcpname"])
    if uci_text is None:
        return failure

    found = re.search(r"network\.eth0\.dhcpname='?([^'\n]+)'?", uci_text)
    if found is None:
        logger.error("Unexpected uci output: %s", uci_text)
        return failure

    hostname = found.group(1).strip().strip("'")
    return "<BARIONET><n>" + hostname + "</n></BARIONET>\r"


def cmd_getio(args: str, io_svc: IoMappingService, ux8_count: int,
              sub_mgr: Optional["SubscriptionManager"]) -> str:
    """Handle 'getio,A': read address A and answer ``state,A,value``.

    ``args`` is the full command text, including the ``getio`` keyword.
    Any malformed, unreadable or failed request yields ``cmderr``.
    """
    failure = "cmderr\r"

    fields = args.split(",")
    if len(fields) != 2:
        return failure
    try:
        addr = int(fields[1])
    except ValueError:
        return failure

    if not _is_readable(addr, ux8_count):
        return failure

    current = io_svc.read_value(addr)
    if current is None:
        return failure

    # A successful read may also subscribe the address for statechange.
    if sub_mgr is not None:
        sub_mgr.on_getio_or_setio(addr)

    return f"state,{addr},{current}\r"


def _setio_write_1bit(addr: int, value: int, io_svc: IoMappingService,
                      sub_mgr: Optional["SubscriptionManager"]) -> bool:
    """Write a 1-bit value on behalf of setio, suppressing its own echo.

    The address is subscribed first and then flagged so that the statechange
    caused by this very write is not reported back to the peer. If the write
    fails, the flag is cleared again; otherwise it would swallow the next
    genuine state change of that address.
    """
    if sub_mgr is not None:
        sub_mgr.on_getio_or_setio(addr)    # subscribe first
        sub_mgr.notify_setio_write(addr)   # then flag the upcoming write
    if io_svc.write_value(addr, value):
        return True
    if sub_mgr is not None:
        sub_mgr.clear_setio_flag(addr)
    return False


def cmd_setio(args: str, io_svc: IoMappingService, ux8_count: int,
              timer_mgr: TimerManager,
              sub_mgr: Optional["SubscriptionManager"]) -> str:
    """Handle 'setio,A,V' command.

    ``args`` is the full command text, including the ``setio`` keyword.

    For addresses that support timed/toggle functions, V means:
      * 0 or 1      -- write that value
      * 999         -- toggle the current value
      * 2..9999     -- (except 999) write 1 now and reset to 0 after V/10 s
    Every one of these explicit writes first cancels a pending timed reset
    for the address, so an older pulse cannot switch the output off later.

    For all other writable addresses V is written as-is after a range check
    against the register width.

    Returns ``state,A,<value>\\r`` on success and ``cmderr\\r`` otherwise.
    """
    parts = args.split(",")
    if len(parts) != 3:
        return "cmderr\r"
    try:
        addr = int(parts[1])
        val = int(parts[2])
    except ValueError:
        return "cmderr\r"

    if not _is_writable(addr, ux8_count):
        return "cmderr\r"

    if not _is_1bit_timed_toggle(addr, ux8_count):
        # Plain write: validate against the register width.
        if val < 0 or val > _max_value_for_addr(addr, ux8_count):
            return "cmderr\r"
        if not io_svc.write_value(addr, val):
            return "cmderr\r"
        if sub_mgr is not None:
            sub_mgr.on_getio_or_setio(addr)
        return f"state,{addr},{val}\r"

    # Work out what to write and whether a timed reset follows.
    pulse_s = None
    if val == 999:
        current = io_svc.read_value(addr)
        if current is None:
            return "cmderr\r"
        new_val = 1 if current == 0 else 0
    elif val in (0, 1):
        new_val = val
    elif 2 <= val <= 9999:
        # Timed function: V is in units of 100 ms.
        new_val = 1
        pulse_s = val / 10.0
    else:
        # Negative values and values >= 10000 are invalid.
        return "cmderr\r"

    # An explicit write supersedes any reset still pending from an earlier pulse.
    timer_mgr.cancel(addr)
    if not _setio_write_1bit(addr, new_val, io_svc, sub_mgr):
        return "cmderr\r"
    if pulse_s is not None:
        # The reset to 0 is written by the timer; it is not flagged, so its
        # statechange is reported normally.
        timer_mgr.schedule(addr, pulse_s)
    return f"state,{addr},{new_val}\r"


def cmd_iolist(io_svc: IoMappingService) -> str:
    """Handle 'iolist': report IO counts plus the number of 1-wire devices.

    The five values come from addresses 60006, 60004, 60005, 60003 and 60002
    (in that order); when they cannot be read they are reported as 0.
    """
    readings = io_svc.read_values([60006, 60004, 60005, 60003, 60002])
    if readings is None:
        readings = [0] * 5
    v60006, v60004, v60005, v60003, v60002 = readings

    # Count 1-wire devices, ignoring the bus-master entries.
    w1_count = 0
    try:
        w1_dir = "/sys/bus/w1/devices"
        if os.path.isdir(w1_dir):
            devices = [
                entry for entry in os.listdir(w1_dir)
                if not entry.startswith("w1_bus_master")
            ]
            w1_count = len(devices)
    except Exception as exc:
        logger.error("Error counting 1-wire devices: %s", exc)

    return f"io,{v60006},{v60004},{v60005},{v60003},0,{v60002},{w1_count}\r"


# ---------------------------------------------------------------------------
# Command dispatcher
# ---------------------------------------------------------------------------

# Command keywords of the TCP protocol.
# NOTE(review): not referenced by dispatch_single_command() in this part of
# the file, which matches the keywords literally -- confirm it is used elsewhere.
VALID_COMMANDS = {"version", "c=65535", "getio", "setio", "iolist"}


def dispatch_single_command(raw_cmd: str, io_svc: IoMappingService,
                            ux8_count: int, timer_mgr: TimerManager,
                            sub_mgr: Optional[SubscriptionManager]) -> str:
    """Route one command to its handler and return the reply.

    The reply carries no trailing carriage return so that several replies can
    be concatenated by the caller. Empty, non-lowercase or unknown commands
    yield ``cmderr``.
    """
    cmd = raw_cmd.strip()

    # Reject empty input and anything that is not entirely lowercase.
    if not cmd or cmd != cmd.lower():
        return "cmderr"

    if cmd == "version":
        reply = cmd_version()
    elif cmd == "c=65535":
        reply = cmd_c65535()
    elif cmd == "iolist":
        reply = cmd_iolist(io_svc)
    elif cmd.startswith("getio,"):
        reply = cmd_getio(cmd, io_svc, ux8_count, sub_mgr)
    elif cmd.startswith("setio,"):
        reply = cmd_setio(cmd, io_svc, ux8_count, timer_mgr, sub_mgr)
    else:
        return "cmderr"

    return reply.rstrip("\r")


def process_message(message: str, io_svc: IoMappingService, ux8_count: int,
                    timer_mgr: TimerManager, sub_mgr: Optional[SubscriptionManager],
                    password: str) -> str:
    """
    Process a full received message (possibly concatenated with &).

    When ``password`` is non-empty the message must start with
    ``a=<password>&``; otherwise ``operation not allowed`` is returned.
    The remaining text is split on ``&`` and each part is dispatched as one
    command; the replies are joined with ``&``.

    Returns full response string with trailing \\r.
    """
    import hmac  # local import: only needed for the password comparison

    msg = message.strip()

    # Never write the credential to the log: the log may be forwarded to a
    # remote syslog server.
    if msg.startswith("a="):
        first_amp = msg.find("&")
        loggable = "a=<redacted>" + (msg[first_amp:] if first_amp != -1 else "")
    else:
        loggable = msg
    logger.info("Processing message: %r", loggable)

    # Password handling
    if password:
        if not msg.startswith("a="):
            logger.warning("Message received without password prefix")
            return "operation not allowed\r"
        provided_pw, separator, remainder = msg[2:].partition("&")
        if not separator:
            logger.warning("Password-protected message missing command part")
            return "operation not allowed\r"
        # Constant-time comparison so response timing does not reveal how
        # much of a guessed password is correct. Bytes are compared because
        # compare_digest() rejects non-ASCII str arguments.
        supplied = provided_pw.encode("utf-8", "surrogatepass")
        expected = str(password).encode("utf-8", "surrogatepass")
        if not hmac.compare_digest(supplied, expected):
            logger.warning("Invalid password provided")
            return "operation not allowed\r"
        msg = remainder  # strip password prefix

    # Split on & for command concatenation
    responses = [
        dispatch_single_command(part.strip(), io_svc, ux8_count, timer_mgr, sub_mgr)
        for part in msg.split("&")
    ]
    return "&".join(responses) + "\r"


# ---------------------------------------------------------------------------
# IP allow-list validation
# ---------------------------------------------------------------------------

# Accepts an empty string or a comma-separated list of dotted-quad IPv4
# addresses, with optional whitespace around the commas.
_IP_LIST_PATTERN = re.compile(
    r'^$|^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}'
    r'(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)'
    r'(\s*,\s*((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}'
    r'(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))*$'
)


def parse_allowed_ips(raw: str) -> Optional[Set[str]]:
    """Turn the ``allowed_ips`` setting into a set of address strings.

    Returns ``None`` -- meaning "no filtering" -- when the setting is blank
    or does not match the expected list format (the latter is logged).
    """
    cleaned = raw.strip()
    if cleaned == "":
        return None

    if _IP_LIST_PATTERN.match(cleaned) is None:
        logger.error("Invalid allowed_ips format: %r — no IP filter applied", cleaned)
        return None

    trimmed = (piece.strip() for piece in cleaned.split(","))
    return {piece for piece in trimmed if piece}


def is_ip_allowed(client_ip: str, allowed: Optional[Set[str]]) -> bool:
    """Check a client address against the allow-list; ``None`` admits everyone."""
    return allowed is None or client_ip in allowed


# ---------------------------------------------------------------------------
# TCP Server
# ---------------------------------------------------------------------------

class TCPServer:
    """Single-client TCP front end for the command API.

    A background thread accepts connections on the configured port. Peers that
    fail the IP allow-list are dropped, and while one client is connected any
    further connection is refused. The accepted client is served on its own
    thread: received bytes are split into messages on CR, LF or NUL, each
    message is passed to ``process_message`` and the reply is written back on
    the same socket.
    """

    def __init__(self, config: dict, io_svc: IoMappingService,
                 ux8_count: int, timer_mgr: TimerManager,
                 sub_mgr: SubscriptionManager):
        """Read the server settings from *config* and store the collaborators.

        Config keys used: ``TCP_port`` (default 12301), ``TCP_timeout`` in
        seconds (default 10; a value <= 0 disables the socket timeout),
        ``password`` (default empty) and ``allowed_ips`` (default empty).
        No socket is created here; that happens in the accept thread.
        """
        self._port = int(config.get("TCP_port", 12301))
        self._timeout = int(config.get("TCP_timeout", 10))
        self._password = config.get("password", "")
        allowed_raw = config.get("allowed_ips", "")
        self._allowed_ips = parse_allowed_ips(allowed_raw)
        self._io = io_svc
        self._ux8_count = ux8_count
        self._timer_mgr = timer_mgr
        self._sub_mgr = sub_mgr
        self._server_sock: Optional[socket.socket] = None
        # Guards _active_conn, which implements the one-client-at-a-time rule.
        self._active_conn_lock = threading.Lock()
        self._active_conn: Optional[socket.socket] = None
        self._running = False
        self._thread: Optional[threading.Thread] = None

    @staticmethod
    def _close_quietly(sock: socket.socket) -> None:
        """Close *sock*, ignoring errors from an already-closed/broken socket."""
        try:
            sock.close()
        except OSError:
            pass

    def start(self) -> None:
        """Launch the accept loop on a daemon thread and return immediately.

        The "started" log line is emitted before the accept thread has bound
        the port; a bind failure is reported separately by the accept loop.
        """
        self._running = True
        self._thread = threading.Thread(target=self._accept_loop, daemon=True, name="TCPServer")
        self._thread.start()
        logger.info("TCP server started on port %d", self._port)

    def stop(self) -> None:
        """Stop accepting, disconnect the active client and join the accept thread.

        The accept thread is joined for at most 5 seconds.
        """
        self._running = False
        if self._server_sock is not None:
            self._close_quietly(self._server_sock)
        with self._active_conn_lock:
            active = self._active_conn
            if active is not None:
                # close() alone does not reliably wake a recv() blocked in the
                # client thread; shutdown() makes that recv() return at once.
                try:
                    active.shutdown(socket.SHUT_RDWR)
                except OSError:
                    pass
                self._close_quietly(active)
        if self._thread:
            self._thread.join(timeout=5)

    def _accept_loop(self) -> None:
        """Bind the listening socket and hand accepted clients to worker threads.

        Runs until ``self._running`` is cleared. The listening socket is always
        closed on exit, including after a bind failure.
        """
        server_sock: Optional[socket.socket] = None
        try:
            server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self._server_sock = server_sock
            server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server_sock.bind(("0.0.0.0", self._port))
            server_sock.listen(1)
            # Short accept timeout so the loop re-checks self._running regularly.
            server_sock.settimeout(1.0)
        except Exception as exc:
            logger.critical("Failed to bind TCP server on port %d: %s", self._port, exc)
            if server_sock is not None:
                self._close_quietly(server_sock)
            return

        try:
            while self._running:
                try:
                    conn, addr = server_sock.accept()
                except socket.timeout:
                    continue
                except Exception as exc:
                    if not self._running:
                        # stop() closed the listening socket under us.
                        break
                    # A transient accept failure must not end the server;
                    # back off briefly so a persistent error cannot spin.
                    logger.error("Accept error: %s", exc)
                    time.sleep(0.5)
                    continue

                client_ip = addr[0]

                # IP allow-list check
                if not is_ip_allowed(client_ip, self._allowed_ips):
                    logger.warning("Connection from non-allowed IP %s — dropping", client_ip)
                    self._close_quietly(conn)
                    continue

                # Single-connection enforcement
                with self._active_conn_lock:
                    if self._active_conn is not None:
                        logger.warning("Connection attempt from %s refused: only one connection at a time is supported", client_ip)
                        self._reject_connection(conn, client_ip)
                        continue
                    self._active_conn = conn

                logger.info("TCP connection established from %s", client_ip)
                worker = threading.Thread(
                    target=self._client_thread,
                    args=(conn, client_ip),
                    daemon=True
                )
                worker.start()
        finally:
            self._close_quietly(server_sock)

    def _send_to_client(self, sock: socket.socket, data: str) -> bool:
        """Send *data* UTF-8 encoded on *sock*; return False (and log) on failure."""
        try:
            sock.sendall(data.encode("utf-8", errors="replace"))
            return True
        except Exception as exc:
            logger.error("Send error: %s", exc)
            return False

    def _client_thread(self, conn: socket.socket, client_ip: str) -> None:
        """Thread target: serve one client, then free the single-connection slot."""
        try:
            self._handle_client(conn, client_ip)
        finally:
            with self._active_conn_lock:
                self._active_conn = None

    def _handle_client(self, conn: socket.socket, client_ip: str) -> None:
        """Serve *conn* until it closes, times out or errors.

        Registers a send callback with the subscription manager for the
        lifetime of the connection and always closes *conn* on the way out,
        even if setup or subscription cleanup raises.
        """
        # Register send callback for subscriptions
        def send_cb(msg: str):
            self._send_to_client(conn, msg)

        try:
            if self._timeout > 0:
                conn.settimeout(float(self._timeout))
            else:
                conn.settimeout(None)

            self._sub_mgr.set_send_callback(send_cb)
            self._sub_mgr.on_connection_established()
            try:
                self._receive_loop(conn, client_ip)
            finally:
                logger.info("Closing connection from %s", client_ip)
                self._sub_mgr.on_connection_closed()
                self._sub_mgr.clear_send_callback()
        finally:
            self._close_quietly(conn)

    def _receive_loop(self, conn: socket.socket, client_ip: str) -> None:
        """Read from *conn* and dispatch each terminated message.

        Messages end at CR, LF or NUL. Only printable ASCII (0x20-0x7E) is
        kept; other bytes are dropped. A message longer than ``MAX_MSG_LEN``
        is processed truncated and the rest of it, up to the next terminator,
        is discarded so the tail is never interpreted as a separate command.
        """
        buffer = bytearray()
        discarding = False  # True while skipping the tail of an oversized message

        while True:
            try:
                chunk = conn.recv(256)
            except socket.timeout:
                logger.info("TCP connection from %s timed out (inactivity)", client_ip)
                return
            except Exception as exc:
                logger.error("Recv error from %s: %s", client_ip, exc)
                return

            if not chunk:
                logger.info("TCP connection closed by peer %s", client_ip)
                return

            for b in chunk:
                if b in (0x0A, 0x0D, 0x00):
                    # Terminator — process buffer
                    if buffer:
                        self._process_buffer(buffer, conn, client_ip)
                        buffer.clear()
                    discarding = False
                elif discarding:
                    continue
                elif 0x20 <= b <= 0x7E:
                    if len(buffer) < MAX_MSG_LEN:
                        buffer.append(b)
                    else:
                        logger.warning("Message from %s exceeded max length (%d bytes); truncating",
                                       client_ip, MAX_MSG_LEN)
                        self._process_buffer(buffer, conn, client_ip)
                        buffer.clear()
                        discarding = True
                # else: silently discard invalid byte

    def _process_buffer(self, buf: bytearray, conn: socket.socket, client_ip: str) -> None:
        """Decode one buffered message, run it and send the response to *conn*."""
        # errors="replace" cannot raise, so no decode error path is needed.
        msg = buf.decode("ascii", errors="replace").strip()
        if not msg:
            return

        # NOTE(review): the message is logged verbatim; when a password is
        # configured the client-supplied password prefix ends up in the log —
        # confirm this is acceptable, especially with remote syslog enabled.
        logger.info("Received from %s: %r", client_ip, msg)

        response = process_message(
            msg, self._io, self._ux8_count,
            self._timer_mgr, self._sub_mgr, self._password
        )

        logger.info("Response to %s: %r", client_ip, response)
        self._send_to_client(conn, response)

    def _reject_connection(self, conn: socket.socket, client_ip: str) -> None:
        """Immediately shut down a refused connection.

        SO_LINGER is enabled with a zero timeout before closing; per the log
        message below the intent is that the peer sees a reset.
        """
        try:
            conn.setsockopt(
                socket.SOL_SOCKET,
                socket.SO_LINGER,
                struct.pack('ii', 1, 0)
            )
        except Exception as exc:
            logger.error("Error setting SO_LINGER on rejected connection from %s: %s", client_ip, exc)
        try:
            conn.shutdown(socket.SHUT_RDWR)
        except Exception:
            pass
        try:
            conn.close()
        except Exception:
            pass
        logger.warning("Rejected connection from %s with RST: connection already active", client_ip)


# ---------------------------------------------------------------------------
# UX8 detection with retry
# ---------------------------------------------------------------------------

def _probe_ux8_units(io_svc: IoMappingService, attempt: int) -> Optional[int]:
    """Run one detection pass; return the unit count, or None on a read error."""
    count = 0
    for addr in UX8_DETECT_ADDRS:
        val = io_svc.read_value(addr)
        if val is None:
            logger.warning("Error reading UX8 detection address %d (attempt %d)", addr, attempt)
            return None
        if val != 1:
            # UX8 units are contiguous — the first empty slot ends the count
            break
        count += 1
    return count


def detect_ux8(io_svc: IoMappingService) -> int:
    """
    Detect how many UX8 units are connected.

    Reads the addresses in ``UX8_DETECT_ADDRS`` in order and counts leading
    slots that read 1. A failed read or an exception makes the whole pass
    retry every ``IOMAPPING_RETRY_INTERVAL_S`` seconds for up to
    ``IOMAPPING_RETRY_MAX_S`` seconds.

    Returns:
        The number of units found (0 to ``len(UX8_DETECT_ADDRS)``), or 0 if
        no pass succeeded within the retry window.
    """
    # Monotonic clock: a wall-clock step (e.g. time sync right after boot)
    # must not cut the retry window short or stretch it.
    deadline = time.monotonic() + IOMAPPING_RETRY_MAX_S
    attempt = 0

    while time.monotonic() < deadline:
        attempt += 1
        try:
            count = _probe_ux8_units(io_svc, attempt)
            if count is not None:
                logger.info("UX8 detection complete: %d unit(s) found", count)
                return count
        except Exception as exc:
            logger.warning("UX8 detection error (attempt %d): %s", attempt, exc)

        time.sleep(IOMAPPING_RETRY_INTERVAL_S)

    logger.error("UX8 detection failed after %ds; assuming 0 units", IOMAPPING_RETRY_MAX_S)
    return 0


# ---------------------------------------------------------------------------
# Initialize all writable 1-bit addresses to 0
# ---------------------------------------------------------------------------

def initialize_outputs(io_svc: IoMappingService, ux8_count: int) -> None:
    """Drive every writable 1-bit address to 0 at startup.

    The address list comes from ``_get_all_writable_1bit_addrs(ux8_count)``.
    Every address is attempted; failed writes are only counted and reported
    in a single warning afterwards.
    """
    targets = _get_all_writable_1bit_addrs(ux8_count)
    logger.info("Initializing %d writable 1-bit addresses to 0", len(targets))

    # One write per address, in list order; a falsy result counts as a failure.
    failed = sum(1 for address in targets if not io_svc.write_value(address, 0))

    if not failed:
        logger.info("All writable 1-bit addresses initialized to 0")
        return
    logger.warning("Failed to initialize %d address(es) to 0", failed)


# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------

# Shutdown flag: set to True by _signal_handler and polled by main()'s loop.
_EXIT = False
# Process-wide handles; main() assigns them and uses them again on shutdown.
_server: Optional[TCPServer] = None
_timer_mgr: Optional[TimerManager] = None


def _signal_handler(signum, frame):
    """Signal callback: log the signal and ask the main loop to exit.

    Only raises the module-level ``_EXIT`` flag; the actual cleanup is done
    by ``main()`` once its polling loop observes the flag.
    """
    global _EXIT
    del frame  # part of the handler signature, not needed here
    logger.info("Signal %d received — initiating graceful shutdown", signum)
    _EXIT = True


def main():
    """Start the TCP command API and run until a shutdown signal arrives.

    Sequence: install SIGTERM/SIGINT handlers, load the config, optionally
    set up syslog, wait ``STARTUP_DELAY_S`` seconds, bring up IoMapping,
    detect UX8 units and zero the writable 1-bit outputs (both only when
    IoMapping is ready), create the timer and subscription managers, start
    the TCP server, then poll ``io_svc.run_once()`` every 10 ms until
    ``_EXIT`` is set. On the way out the timers are reset and the server is
    stopped. Returns early without starting anything when ``TCP_api_enable``
    is false in the config.
    """
    global _EXIT, _server, _timer_mgr

    signal.signal(signal.SIGTERM, _signal_handler)
    signal.signal(signal.SIGINT, _signal_handler)

    # Load config
    config = load_config()

    # Configure syslog if enabled
    if config.get("enable_syslog", False):
        syslog_addr = config.get("syslog_address", "").strip()
        if syslog_addr:
            setup_syslog(syslog_addr)
        else:
            logger.warning("enable_syslog=true but syslog_address is empty — syslog not configured")

    if not config.get("TCP_api_enable", True):
        logger.info("TCP API disabled in config; exiting")
        return

    # Startup delay
    logger.info("Startup delay: waiting %ds for system IO to be ready...", STARTUP_DELAY_S)
    time.sleep(STARTUP_DELAY_S)

    # Initialize IoMapping with retry
    io_svc = IoMappingService()
    iomapping_ready = io_svc.wait_until_ready()

    if not iomapping_ready:
        logger.error("IoMapping unavailable — some commands may fail")

    # Detect UX8 units
    ux8_count = 0
    if iomapping_ready:
        ux8_count = detect_ux8(io_svc)
    else:
        logger.warning("Skipping UX8 detection (IoMapping not ready)")

    # Initialize all writable 1-bit addresses to 0
    if iomapping_ready:
        initialize_outputs(io_svc, ux8_count)

    # Set up timer manager
    _timer_mgr = TimerManager(io_svc)

    # Set up subscription manager
    initial_sub = config.get("TCP_initial_state_subscriptions", "None")
    add_sub = config.get("TCP_add_IO_state_subscriptions", "None")
    sub_mgr = SubscriptionManager(io_svc, ux8_count)
    sub_mgr.configure(initial_sub, add_sub)

    # Start TCP server in its own thread
    _server = TCPServer(config, io_svc, ux8_count, _timer_mgr, sub_mgr)
    _server.start()

    # Coerce the port exactly as TCPServer does: a string-valued TCP_port
    # would otherwise break the %d formatting of this log record.
    logger.info(
        "TCP Command API running | port=%d | ux8=%d | initial_sub=%s | add_sub=%s",
        int(config.get("TCP_port", 12301)), ux8_count, initial_sub, add_sub
    )

    # Main loop: drive IoMapping notifications
    try:
        while not _EXIT:
            io_svc.run_once()
            time.sleep(0.01)  # 10ms polling interval
    except Exception as exc:
        logger.error("Main loop error: %s", exc)
    finally:
        logger.info("Shutting down...")

        try:
            # Graceful shutdown: reset all active timers
            if _timer_mgr:
                _timer_mgr.cancel_all_and_reset()
        finally:
            # Stop TCP server — must run even if the timer reset raised,
            # otherwise the listening and client sockets are left open.
            if _server:
                _server.stop()

        logger.info("Shutdown complete")


# Run the service only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
