tma: impl helper services, cleanup hostside packets

This commit is contained in:
Michael Scire 2018-11-06 20:20:07 -08:00
parent 46001263f8
commit 2572ae8378
10 changed files with 252 additions and 47 deletions

View File

@ -18,8 +18,6 @@ def main(argc, argv):
print 'Waiting for connection...'
c.wait_connected()
print 'Connected!'
while True:
c.send_packet('AAAAAAAA')
return 0
if __name__ == '__main__':

View File

@ -0,0 +1,116 @@
# Copyright (c) 2018 Atmosphere-NX
#
# This program is free software; you can redistribute it and/or modify it
# under the terms and conditions of the GNU General Public License,
# version 2, as published by the Free Software Foundation.
#
# This program is distributed in the hope it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import zlib
import ServiceId
from struct import unpack as up, pack as pk
HEADER_SIZE = 0x28
def crc32(s):
return zlib.crc32(s) & 0xFFFFFFFF
class Packet():
def __init__(self):
self.service = 0
self.task = 0
self.cmd = 0
self.continuation = 0
self.version = 0
self.body_len = 0
self.body = ''
self.offset = 0
def load_header(self, header):
assert len(header) == HEADER_SIZE
self.service, self.task, self.cmd, self.continuation, self.version, self.body_len, \
_, self.body_chk, self.hdr_chk = up('<IIHBBI16sII', header)
if crc32(header[:-4]) != self.hdr_chk:
raise ValueError('Invalid header checksum in received packet!')
def load_body(self, body):
assert len(body) == self.body_len
if crc32(body) != self.body_chk:
raise ValueError('Invalid body checksum in received packet!')
self.body = body
def get_data(self):
assert len(self.body) == self.body_len and self.body_len <= 0xE000
self.body_chk = crc32(self.body)
hdr = pk('<IIHBBIIIIII', self.service, self.task, self.cmd, self.continuation, self.version, self.body_len, 0, 0, 0, 0, self.body_chk)
self.hdr_chk = crc32(hdr)
hdr += pk('<I', self.hdr_chk)
return hdr + self.body
def set_service(self, srv):
if type(srv) is str:
self.service = ServiceId.hash(srv)
else:
self.service = srv
return self
def set_task(self, t):
self.task = t
return self
def set_cmd(self, x):
self.cmd = x
return self
def set_continuation(self, c):
self.continuation = c
return self
def set_version(self, v):
self.version = v
return self
def reset_offset(self):
self.offset = 0
return self
def write_str(self, s):
self.body += s
self.body_len += len(s)
return self
def write_u8(self, x):
self.body += pk('<B', x & 0xFF)
self.body_len += 1
return self
def write_u16(self, x):
self.body += pk('<H', x & 0xFFFF)
self.body_len += 2
return self
def write_u32(self, x):
self.body += pk('<I', x & 0xFFFFFFFF)
self.body_len += 4
return self
def write_u64(self, x):
self.body += pk('<Q', x & 0xFFFFFFFFFFFFFFFF)
self.body_len += 8
return self
def read_str(self):
s = ''
while self.body[self.offset] != '\x00' and self.offset < self.body_len:
s += self.body[self.offset]
self.offset += 1
def read_u8(self):
x, = up('<B', self.body[self.offset:self.offset+1])
self.offset += 1
return x
def read_u16(self):
x, = up('<H', self.body[self.offset:self.offset+2])
self.offset += 2
return x
def read_u32(self):
x, = up('<I', self.body[self.offset:self.offset+4])
self.offset += 4
return x
def read_u64(self):
x, = up('<Q', self.body[self.offset:self.offset+8])
self.offset += 8
return x
def read_struct(self, format, sz):
x = up(format, self.body[self.offset:self.offset+sz])
self.offset += sz
return x

View File

@ -0,0 +1,27 @@
# Copyright (c) 2018 Atmosphere-NX
#
# This program is free software; you can redistribute it and/or modify it
# under the terms and conditions of the GNU General Public License,
# version 2, as published by the Free Software Foundation.
#
# This program is distributed in the hope it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
def hash(s):
h = ord(s[0]) & 0xFFFFFFFF
for c in s:
h = ((1000003 * h) ^ ord(c)) & 0xFFFFFFFF
h ^= len(s)
return h
USB_QUERY_TARGET = hash("USBQueryTarget")
USB_SEND_HOST_INFO = hash("USBSendHostInfo")
USB_CONNECT = hash("USBConnect")
USB_DISCONNECT = hash("USBDisconnect")

View File

@ -15,6 +15,8 @@ from UsbInterface import UsbInterface
from threading import Thread, Condition
from collections import deque
import time
import ServiceId
from Packet import Packet
class UsbConnection(UsbInterface):
# Auto connect thread func.
@ -25,12 +27,6 @@ class UsbConnection(UsbInterface):
except ValueError as e:
continue
def recv_thread(connection):
if connection.is_connected():
try:
# If we've previously been connected, PyUSB will read garbage...
connection.recv_packet()
except ValueError:
pass
while connection.is_connected():
try:
connection.recv_packet()
@ -65,6 +61,7 @@ class UsbConnection(UsbInterface):
self.conn_thrd.start()
return self
def __exit__(self, type, value, traceback):
self.disconnect()
time.sleep(1)
print 'Closing!'
time.sleep(1)
@ -80,24 +77,43 @@ class UsbConnection(UsbInterface):
self.conn_lock.acquire()
assert not self.connected
self.intf = intf
self.connected = True
self.conn_lock.notify()
self.conn_lock.release()
self.recv_thrd = Thread(target=UsbConnection.recv_thread, args=(self,))
self.send_thrd = Thread(target=UsbConnection.send_thread, args=(self,))
self.recv_thrd.daemon = True
self.send_thrd.daemon = True
self.recv_thrd.start()
self.send_thrd.start()
try:
# Perform Query + Connection handshake
self.intf.send_packet(Packet().set_service(ServiceId.USB_QUERY_TARGET))
query_resp = self.intf.read_packet()
print 'Found Switch, Protocol version 0x%x' % query_resp.read_u32()
self.intf.send_packet(Packet().set_service(ServiceId.USB_SEND_HOST_INFO).write_u32(0).write_u32(0))
self.intf.send_packet(Packet().set_service(ServiceId.USB_CONNECT))
resp = self.intf.read_packet()
# Spawn threads
self.recv_thrd = Thread(target=UsbConnection.recv_thread, args=(self,))
self.send_thrd = Thread(target=UsbConnection.send_thread, args=(self,))
self.recv_thrd.daemon = True
self.send_thrd.daemon = True
self.recv_thrd.start()
self.send_thrd.start()
self.connected = True
finally:
# Finish connection.
self.conn_lock.notify()
self.conn_lock.release()
def disconnect(self):
self.conn_lock.acquire()
if self.connected:
self.connected = False
self.intf.send_packet(Packet().set_service(ServiceId.USB_DISCONNECT))
self.conn_lock.release()
def recv_packet(self):
hdr, body = self.intf.read_packet()
print('Got Packet: %s' % body.encode('hex'))
packet = self.intf.read_packet()
assert type(packet) is Packet
dat = packet.read_u64()
print('Got Packet: %08x' % dat)
def send_packet(self, packet):
assert type(packet) is Packet
self.send_lock.acquire()
if len(self.send_queue) == 0x40:
self.send_lock.wait()

View File

@ -11,11 +11,8 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import usb, zlib
from struct import unpack as up, pack as pk
def crc32(s):
return zlib.crc32(s) & 0xFFFFFFFF
import usb
import Packet
class UsbInterface():
def __init__(self):
@ -50,20 +47,16 @@ class UsbInterface():
def blocking_write(self, data):
self.ep_out.write(data, 0xFFFFFFFFFFFFFFFF)
def read_packet(self):
hdr = self.blocking_read(0x28)
_, _, _, body_size, _, _, _, _, body_chk, hdr_chk = up('<IIIIIIIIII', hdr)
if crc32(hdr[:-4]) != hdr_chk:
raise ValueError('Invalid header checksum in received packet!')
body = self.blocking_read(body_size)
if len(body) != body_size:
raise ValueError('Failed to receive packet body!')
elif crc32(body) != body_chk:
raise ValueError('Invalid body checksum in received packet!')
return (hdr, body)
def send_packet(self, body):
hdr = pk('<IIIIIIIII', 0, 0, 0, len(body), 0, 0, 0, 0, crc32(body))
hdr += pk('<I', crc32(hdr))
self.blocking_write(hdr)
self.blocking_write(body)
packet = Packet.Packet()
hdr = self.blocking_read(Packet.HEADER_SIZE)
packet.load_header(hdr)
if packet.body_len:
packet.load_body(self.blocking_read(packet.body_len))
return packet
def send_packet(self, packet):
data = packet.get_data()
self.blocking_write(data[:Packet.HEADER_SIZE])
if (len(data) > Packet.HEADER_SIZE):
self.blocking_write(data[Packet.HEADER_SIZE:])

View File

@ -190,7 +190,7 @@ class TmaPacket {
}
template<typename T>
TmaConnResult Read(const T &t) {
TmaConnResult Read(T &t) {
return Read(&t, sizeof(T));
}

View File

@ -34,5 +34,13 @@ static constexpr u32 HashServiceName(const char *name) {
enum class TmaService : u32 {
Invalid = 0,
/* Special nodes, for facilitating connection over USB. */
UsbQueryTarget = HashServiceName("USBQueryTarget"),
UsbSendHostInfo = HashServiceName("USBSendHostInfo"),
UsbConnect = HashServiceName("USBConnect"),
UsbDisconnect = HashServiceName("USBDisconnect"),
TestService = HashServiceName("AtmosphereTestService"), /* Temporary service, will be used to debug communications. */
};

View File

@ -76,18 +76,54 @@ void TmaUsbConnection::RecvThreadFunc(void *arg) {
this_ptr->SetConnected(true);
while (res == TmaConnResult::Success) {
if (!this_ptr->IsConnected()) {
break;
}
TmaPacket *packet = this_ptr->AllocateRecvPacket();
if (packet == nullptr) { std::abort(); }
res = TmaUsbComms::ReceivePacket(packet);
if (res == TmaConnResult::Success) {
TmaPacket *send_packet = this_ptr->AllocateSendPacket();
send_packet->Write<u64>(i++);
this_ptr->send_queue.Send(reinterpret_cast<uintptr_t>(send_packet));
switch (packet->GetServiceId()) {
case TmaService::UsbQueryTarget: {
this_ptr->SetConnected(false);
res = this_ptr->SendQueryReply(packet);
if (!this_ptr->has_woken_up) {
/* TODO: Cancel background work. */
}
}
break;
case TmaService::UsbSendHostInfo: {
struct {
u32 version;
u32 sleeping;
} host_info;
packet->Read<decltype(host_info)>(host_info);
if (!this_ptr->has_woken_up || !host_info.sleeping) {
/* TODO: Cancel background work. */
}
}
break;
case TmaService::UsbConnect: {
res = this_ptr->SendQueryReply(packet);
if (res == TmaConnResult::Success) {
this_ptr->SetConnected(true);
this_ptr->OnConnectionEvent(ConnectionEvent::Connected);
}
}
break;
case TmaService::UsbDisconnect: {
this_ptr->SetConnected(false);
this_ptr->OnDisconnected();
/* TODO: Cancel background work. */
}
break;
default:
break;
}
this_ptr->FreePacket(packet);
} else {
this_ptr->FreePacket(packet);
@ -153,3 +189,13 @@ TmaConnResult TmaUsbConnection::SendPacket(TmaPacket *packet) {
return TmaConnResult::Disconnected;
}
}
TmaConnResult TmaUsbConnection::SendQueryReply(TmaPacket *packet) {
packet->ClearOffset();
struct {
u32 version;
} target_info;
target_info.version = 0;
packet->Write<decltype(target_info)>(target_info);
return TmaUsbComms::SendPacket(packet);
}

View File

@ -29,6 +29,7 @@ class TmaUsbConnection : public TmaConnection {
static void SendThreadFunc(void *arg);
static void RecvThreadFunc(void *arg);
static void OnUsbStateChange(void *this_ptr, u32 state);
TmaConnResult SendQueryReply(TmaPacket *packet);
void ClearSendQueue();
void StartThreads();
void StopThreads();

View File

@ -436,7 +436,7 @@ TmaConnResult TmaUsbComms::SendPacket(TmaPacket *packet) {
res = TmaConnResult::GeneralFailure;
}
if (res == TmaConnResult::Success) {
if (res == TmaConnResult::Success && 0 < body_len) {
/* Copy body to send buffer. */
packet->CopyBodyTo(g_send_data_buf);