import logging
|
import math
|
import struct
|
import usb.util
|
|
def to_hex(val):
    """Render an iterable of byte values as space-separated two-digit hex.

    ``to_hex(b'\\x01\\xab')`` returns ``'01 ab'``; an empty input returns
    ``''``. Each element must be an integer in ``range(256)``; anything else
    raises the same error that ``bytes([element])`` would.
    """
    octets = []
    for octet in val:
        # Converting through bytes keeps the range check (0..255) on each element.
        octets.append(bytes((octet,)).hex())
    return ' '.join(octets)
|
|
class ControlVaultCommunicator:
    """Exchange framed packets with a ControlVault device over USB bulk endpoints.

    Every frame, in both directions, is a 4-byte header followed by the
    payload. The header is one tag byte, one byte named ``spi_crc`` and a
    big-endian 16-bit payload length. Outgoing frames are tagged with
    ``spi_master``. Incoming frames must start with ``spi_slave`` followed by
    ``spi_crc``.
    """

    # Size of each individual bulk read request, in bytes.
    READ_CHUNK_SIZE = 64
    # Framing bytes (tag, spi_crc, 16-bit length) that precede every payload.
    HEADER_SIZE = 4
    # bInterfaceClass value of the interface this class talks to.
    VENDOR_INTERFACE_CLASS = 0xff

    def __init__(self, device, spi_master=0x01, spi_slave=0x00, spi_crc=0x00):
        """Bind to *device* and locate its vendor-specific bulk endpoints.

        :param device: presumably a pyusb ``usb.core.Device``; it must provide
            ``get_active_configuration()`` and ``ctrl_transfer()``.
        :param spi_master: tag byte placed at the start of every outgoing frame.
        :param spi_slave: tag byte every incoming frame must start with.
        :param spi_crc: second header byte, used in both directions.
        :raises Exception: if the vendor-specific interface, or exactly one
            BULK IN and one BULK OUT endpoint on it, cannot be found.
        """
        self.logger = logging.getLogger(__name__)
        self.device = device
        self.bulk_in, self.bulk_out = self._find_endpoints()

        self.spi_master = spi_master
        self.spi_slave = spi_slave
        self.spi_crc = spi_crc
        # First two bytes every frame received from the device must carry.
        self.spi_slave_prefix = struct.pack('>BB', self.spi_slave, self.spi_crc)

    def ctrl_transfer(self, *args, **kwargs):
        """Log and forward a control transfer to the underlying device."""
        self.logger.debug('Control: %s %s', args, kwargs)
        return self.device.ctrl_transfer(*args, **kwargs)

    def write(self, *args, **kwargs):
        """Forward to the BULK OUT endpoint's ``write`` and return its result."""
        return self.bulk_out.write(*args, **kwargs)

    def read(self, *args, **kwargs):
        """Forward to the BULK IN endpoint's ``read`` and return its result."""
        return self.bulk_in.read(*args, **kwargs)

    def send_packet(self, payload):
        """Frame *payload* with the master header and write it to BULK OUT.

        :param payload: raw payload bytes.
        :raises struct.error: if ``len(payload)`` does not fit the 16-bit
            length field.
        """
        header = struct.pack('>BBH', self.spi_master, self.spi_crc, len(payload))
        packet = header + payload
        # Only pay for hex formatting when the message will actually be emitted.
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug('Put: %s', to_hex(packet))
        self.write(packet)

    def recv_packet(self, timeout=5000):
        """Read one complete frame from BULK IN and return its payload.

        The header's 16-bit length counts only the payload, so the complete
        frame occupies ``HEADER_SIZE + length`` bytes. Chunks are read until
        at least that many bytes have arrived.

        :param timeout: timeout passed to every underlying ``read`` call
            (milliseconds for pyusb).
        :returns: every byte received after the 4-byte header.
        :raises Exception: if the frame does not start with
            ``spi_slave_prefix``, the header is truncated, or the device
            returns an empty read before the frame is complete.
        """
        packet = self.read(self.READ_CHUNK_SIZE, timeout=timeout).tobytes()

        tag = packet[0:2]
        if tag != self.spi_slave_prefix:
            raise Exception('Unknown tag: {}'.format(tag.hex()))
        if len(packet) < self.HEADER_SIZE:
            raise Exception('Truncated header: {}'.format(packet.hex()))

        length = struct.unpack('>H', packet[2:4])[0]

        # Count the header too: a 61..64-byte payload already spills into a
        # second chunk, and any unread remainder would corrupt the next frame.
        expected = self.HEADER_SIZE + length
        while len(packet) < expected:
            chunk = self.read(self.READ_CHUNK_SIZE, timeout=timeout).tobytes()
            if not chunk:
                # Guard against looping forever on zero-length reads.
                raise Exception('Empty read after {} of {} bytes'.format(
                    len(packet), expected))
            packet += chunk

        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug('Got: %s', to_hex(packet))

        return packet[self.HEADER_SIZE:]

    def talk(self, exchange):
        """Send each hex-encoded packet in *exchange* and read its response.

        :param exchange: iterable of hex strings accepted by ``bytes.fromhex``.
        """
        for hex_packet in exchange:
            self.send_packet(bytes.fromhex(hex_packet))
            response = self.recv_packet()

            # When the second payload byte is 0x61 one more frame is read and
            # discarded. NOTE(review): the meaning of 0x61 is not documented
            # here; confirm it against the device protocol.
            if len(response) > 1 and response[1] == 0x61:
                self.recv_packet()

    def _find_endpoints(self):
        """Locate the BULK IN and BULK OUT endpoints of the vendor interface.

        The active configuration must contain exactly one interface whose
        ``bInterfaceClass`` is 0xff. That interface must expose exactly one
        bulk endpoint in each direction.

        :returns: ``(bulk_in, bulk_out)`` endpoint objects.
        :raises Exception: if the interface or endpoints are missing or
            ambiguous.
        """
        self.logger.debug('Enumerating interfaces...')
        configuration = self.device.get_active_configuration()

        bcm_interface = None
        for interface in configuration:
            if interface.bInterfaceClass != self.VENDOR_INTERFACE_CLASS:
                continue
            if bcm_interface is not None:
                raise Exception('More than one vendor-specific interface found!')
            bcm_interface = interface

        if bcm_interface is None:
            raise Exception('Cannot find vendor-specific interface')

        self.logger.debug('Interface found: %s', bcm_interface._str())

        self.logger.debug('Enumerating endpoints...')
        bulk_in = None
        bulk_out = None
        for endpoint in bcm_interface:
            # Public pyusb helpers instead of the private _ENDPOINT_*_MASK constants.
            if usb.util.endpoint_type(endpoint.bmAttributes) != usb.util.ENDPOINT_TYPE_BULK:
                continue
            if usb.util.endpoint_direction(endpoint.bEndpointAddress) == usb.util.ENDPOINT_IN:
                if bulk_in is not None:
                    raise Exception('More than one BULK IN endpoint found!')
                bulk_in = endpoint
                self.logger.debug('BULK IN found: %s', bulk_in._str())
            else:
                # The direction bit has only two values, so this is OUT.
                if bulk_out is not None:
                    raise Exception('More than one BULK OUT endpoint found!')
                bulk_out = endpoint
                self.logger.debug('BULK OUT found: %s', bulk_out._str())

        if bulk_in is None:
            raise Exception('BULK IN endpoint not found!')
        if bulk_out is None:
            raise Exception('BULK OUT endpoint not found!')

        self.logger.debug('Endpoint discovery successful.')
        return bulk_in, bulk_out
|