mirror of
https://github.com/ryujinx-mirror/ryujinx.git
synced 2024-12-13 17:01:07 +01:00
248 lines
7.2 KiB
C#
248 lines
7.2 KiB
C#
|
using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
|
|||
|
using Ryujinx.HLE.HOS.Services.Ssl.Types;
|
|||
|
using System;
|
|||
|
using System.IO;
|
|||
|
using System.Net.Security;
|
|||
|
using System.Net.Sockets;
|
|||
|
using System.Security.Authentication;
|
|||
|
|
|||
|
namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
|
|||
|
{
|
|||
|
/// <summary>
/// <see cref="ISslConnectionBase"/> implementation that runs TLS over an existing socket
/// by wrapping it in a .NET <see cref="SslStream"/>.
/// </summary>
/// <remarks>
/// The wrapped <see cref="ISocket"/> is cast to <c>ManagedSocket</c> during <see cref="Handshake"/>,
/// so only that socket implementation is supported.
/// The saved blocking state and read timeout live in instance fields, so overlapping
/// operations on the same instance would overwrite each other's saved state.
/// </remarks>
class SslManagedSocketConnection : ISslConnectionBase
{
    /// <summary>File descriptor of the wrapped socket, as handed to the constructor.</summary>
    public int SocketFd { get; }

    /// <summary>The socket the TLS stream is layered on.</summary>
    public ISocket Socket { get; }

    private readonly BsdContext _bsdContext;
    private readonly SslVersion _sslVersion;

    // Created by Handshake; null until then.
    private SslStream _stream;

    // Blocking state of the socket captured by the most recent StartSslOperation call.
    private bool _isBlockingSocket;

    // Read timeout of the stream captured by the most recent StartSslReadOperation call.
    private int _previousReadTimeout;

    /// <summary>
    /// Creates a connection wrapper. No TLS state is created until <see cref="Handshake"/> is called.
    /// </summary>
    /// <param name="bsdContext">Context used to close <paramref name="socketFd"/> on dispose.</param>
    /// <param name="sslVersion">Requested protocol version(s) used for the handshake.</param>
    /// <param name="socketFd">File descriptor of <paramref name="socket"/>.</param>
    /// <param name="socket">Socket the TLS stream will be layered on.</param>
    public SslManagedSocketConnection(BsdContext bsdContext, SslVersion sslVersion, int socketFd, ISocket socket)
    {
        _bsdContext = bsdContext;
        _sslVersion = sslVersion;

        SocketFd = socketFd;
        Socket = socket;
    }

    /// <summary>
    /// Remembers the socket's current blocking state and forces it to blocking for the
    /// duration of an <see cref="SslStream"/> operation.
    /// </summary>
    private void StartSslOperation()
    {
        // Save blocking state
        _isBlockingSocket = Socket.Blocking;

        // Force blocking for SslStream
        Socket.Blocking = true;
    }

    /// <summary>
    /// Restores the blocking state saved by <see cref="StartSslOperation"/>.
    /// </summary>
    private void EndSslOperation()
    {
        // Restore blocking state
        Socket.Blocking = _isBlockingSocket;
    }

    /// <summary>
    /// Same as <see cref="StartSslOperation"/>, and additionally, when the socket was
    /// non-blocking, saves the stream's read timeout and shortens it to 1 millisecond so
    /// that a read on the temporarily blocking socket returns almost immediately.
    /// </summary>
    private void StartSslReadOperation()
    {
        StartSslOperation();

        if (!_isBlockingSocket)
        {
            _previousReadTimeout = _stream.ReadTimeout;

            _stream.ReadTimeout = 1;
        }
    }

    /// <summary>
    /// Undoes <see cref="StartSslReadOperation"/>: restores the read timeout (if it was
    /// changed) and then the blocking state.
    /// </summary>
    private void EndSslReadOperation()
    {
        if (!_isBlockingSocket)
        {
            _stream.ReadTimeout = _previousReadTimeout;
        }

        EndSslOperation();
    }

    /// <summary>
    /// Maps the guest-requested <see cref="SslVersion"/> to the matching <see cref="SslProtocols"/> flags.
    /// </summary>
    /// <param name="version">Requested version; only the bits in <c>SslVersion.VersionMask</c> are considered.</param>
    /// <returns>The protocol flags to pass to <see cref="SslStream"/>.</returns>
    /// <exception cref="NotImplementedException">The masked version is not one of the handled values.</exception>
    private static SslProtocols TranslateSslVersion(SslVersion version)
    {
        switch (version & SslVersion.VersionMask)
        {
            case SslVersion.Auto:
                return SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13;
            case SslVersion.TlsV10:
                return SslProtocols.Tls;
            case SslVersion.TlsV11:
                return SslProtocols.Tls11;
            case SslVersion.TlsV12:
                return SslProtocols.Tls12;
            case SslVersion.TlsV13:
                return SslProtocols.Tls13;
            default:
                throw new NotImplementedException(version.ToString());
        }
    }

    /// <summary>
    /// Creates the <see cref="SslStream"/> over the socket and performs the client-side TLS handshake.
    /// </summary>
    /// <param name="hostName">Target host name passed to <see cref="SslStream.AuthenticateAsClient(string)"/>.</param>
    /// <returns><c>ResultCode.Success</c> when the handshake completes; failures surface as exceptions.</returns>
    public ResultCode Handshake(string hostName)
    {
        StartSslOperation();

        try
        {
            // The NetworkStream does not own the socket (ownsSocket: false); the SslStream owns the NetworkStream.
            _stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null);
            _stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false);
        }
        finally
        {
            // Always restore the caller's blocking mode, even when authentication throws.
            EndSslOperation();
        }

        return ResultCode.Success;
    }

    /// <summary>
    /// Peeking is not supported; always reports that the operation would block.
    /// </summary>
    /// <param name="peekCount">Always set to -1.</param>
    /// <param name="buffer">Unused.</param>
    /// <returns>Always <c>ResultCode.WouldBlock</c>.</returns>
    public ResultCode Peek(out int peekCount, Memory<byte> buffer)
    {
        // NOTE: Peeking cannot be supported on top of the .NET SSL API.
        // Nintendo's curl implementation checks whether a connection is alive via Peek,
        // so we report that it would block to let it know the connection is still alive.
        peekCount = -1;

        return ResultCode.WouldBlock;
    }

    /// <summary>
    /// Number of decrypted bytes pending; unsupported, so always 0.
    /// </summary>
    public int Pending()
    {
        // Unsupported
        return 0;
    }

    /// <summary>
    /// Translates a WinSock error into a service <c>ResultCode</c> when a mapping exists.
    /// </summary>
    /// <param name="isBlocking">Whether the socket was blocking; selects Timeout vs WouldBlock for timeouts.</param>
    /// <param name="error">Error to translate.</param>
    /// <param name="resultCode">Translated code; <c>ResultCode.Success</c> when no mapping exists.</param>
    /// <returns><c>true</c> if <paramref name="error"/> was translated; otherwise <c>false</c>.</returns>
    private static bool TryTranslateWinSockError(bool isBlocking, WsaError error, out ResultCode resultCode)
    {
        switch (error)
        {
            case WsaError.WSAETIMEDOUT:
                resultCode = isBlocking ? ResultCode.Timeout : ResultCode.WouldBlock;
                return true;
            case WsaError.WSAECONNABORTED:
                resultCode = ResultCode.ConnectionAbort;
                return true;
            case WsaError.WSAECONNRESET:
                resultCode = ResultCode.ConnectionReset;
                return true;
            default:
                resultCode = ResultCode.Success;
                return false;
        }
    }

    /// <summary>
    /// Reads decrypted data from the TLS stream into <paramref name="buffer"/>.
    /// </summary>
    /// <param name="readCount">Number of bytes read, or -1 when nothing was read because of WouldBlock or an error.</param>
    /// <param name="buffer">Destination buffer.</param>
    /// <returns>
    /// <c>ResultCode.Success</c> on a completed read, <c>ResultCode.WouldBlock</c> when the socket has
    /// nothing to read, or a translated socket error code.
    /// </returns>
    /// <exception cref="SocketException">The stream failed with a socket error that has no translation.</exception>
    /// <exception cref="IOException">The stream failed for a reason other than a socket error.</exception>
    public ResultCode Read(out int readCount, Memory<byte> buffer)
    {
        if (!Socket.Poll(0, SelectMode.SelectRead))
        {
            readCount = -1;

            return ResultCode.WouldBlock;
        }

        StartSslReadOperation();

        try
        {
            readCount = _stream.Read(buffer.Span);
        }
        catch (IOException exception)
        {
            readCount = -1;

            if (exception.InnerException is SocketException socketException)
            {
                WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;

                if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
                {
                    return result;
                }

                // Surface the underlying socket error without discarding its original stack trace.
                System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(socketException).Throw();
            }

            // Rethrow the IOException as-is, keeping its stack trace.
            throw;
        }
        finally
        {
            EndSslReadOperation();
        }

        return ResultCode.Success;
    }

    /// <summary>
    /// Encrypts and writes <paramref name="buffer"/> to the TLS stream.
    /// </summary>
    /// <param name="writtenCount">
    /// <c>buffer.Length</c> on success, 0 when the socket is not writable, or -1 on a translated error.
    /// </param>
    /// <param name="buffer">Data to send.</param>
    /// <returns>
    /// <c>ResultCode.Success</c> on a completed write, <c>ResultCode.WouldBlock</c> when the socket is
    /// not writable, or a translated socket error code.
    /// </returns>
    /// <exception cref="SocketException">The stream failed with a socket error that has no translation.</exception>
    /// <exception cref="IOException">The stream failed for a reason other than a socket error.</exception>
    public ResultCode Write(out int writtenCount, ReadOnlyMemory<byte> buffer)
    {
        if (!Socket.Poll(0, SelectMode.SelectWrite))
        {
            writtenCount = 0;

            return ResultCode.WouldBlock;
        }

        StartSslOperation();

        try
        {
            _stream.Write(buffer.Span);
        }
        catch (IOException exception)
        {
            writtenCount = -1;

            if (exception.InnerException is SocketException socketException)
            {
                WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;

                if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
                {
                    return result;
                }

                // Surface the underlying socket error without discarding its original stack trace.
                System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(socketException).Throw();
            }

            // Rethrow the IOException as-is, keeping its stack trace.
            throw;
        }
        finally
        {
            EndSslOperation();
        }

        // .NET API doesn't provide the size written, assume all written.
        writtenCount = buffer.Length;

        return ResultCode.Success;
    }

    /// <summary>
    /// Copies the raw data of the remote certificate into <paramref name="certificates"/>.
    /// </summary>
    /// <param name="hostname">Unused by this implementation.</param>
    /// <param name="certificates">Destination buffer for the raw certificate data.</param>
    /// <param name="storageSize">Size in bytes of the raw certificate data (set even when the buffer is too small).</param>
    /// <param name="certificateCount">Always set to 1.</param>
    /// <returns>
    /// <c>ResultCode.Success</c> when the data was copied, or <c>ResultCode.CertBufferTooSmall</c> when
    /// <paramref name="certificates"/> cannot hold it.
    /// </returns>
    public ResultCode GetServerCertificate(string hostname, Span<byte> certificates, out uint storageSize, out uint certificateCount)
    {
        // NOTE(review): _stream is only assigned in Handshake, so this dereferences null if called
        // before a handshake; RemoteCertificate is presumably null until authentication succeeds — confirm callers.
        byte[] rawCertData = _stream.RemoteCertificate.GetRawCertData();

        storageSize = (uint)rawCertData.Length;
        certificateCount = 1;

        if (rawCertData.Length > certificates.Length)
        {
            return ResultCode.CertBufferTooSmall;
        }

        rawCertData.CopyTo(certificates);

        return ResultCode.Success;
    }

    /// <summary>
    /// Releases the TLS stream (if a handshake created one) and closes the socket's file descriptor.
    /// </summary>
    public void Dispose()
    {
        // The inner NetworkStream was created with ownsSocket: false, so this does not close the
        // socket itself; that is done through the BSD context below.
        _stream?.Dispose();

        _bsdContext.CloseFileDescriptor(SocketFd);
    }
}
|
|||
|
}
|