diff --git a/src/core/hle/service/soc_u.cpp b/src/core/hle/service/soc_u.cpp index d3e5d4bca..a7404085b 100644 --- a/src/core/hle/service/soc_u.cpp +++ b/src/core/hle/service/soc_u.cpp @@ -373,14 +373,18 @@ static void Bind(Service::Interface* self) { u32* cmd_buffer = Kernel::GetCommandBuffer(); u32 socket_handle = cmd_buffer[1]; u32 len = cmd_buffer[2]; - CTRSockAddr* ctr_sock_addr = reinterpret_cast(Memory::GetPointer(cmd_buffer[6])); - if (ctr_sock_addr == nullptr) { + // Virtual address of the sock_addr structure + VAddr sock_addr_addr = cmd_buffer[6]; + if (!Memory::IsValidVirtualAddress(sock_addr_addr)) { cmd_buffer[1] = -1; // TODO(Subv): Correct code return; } - sockaddr sock_addr = CTRSockAddr::ToPlatform(*ctr_sock_addr); + CTRSockAddr ctr_sock_addr; + Memory::ReadBlock(sock_addr_addr, reinterpret_cast(&ctr_sock_addr), sizeof(CTRSockAddr)); + + sockaddr sock_addr = CTRSockAddr::ToPlatform(ctr_sock_addr); int res = ::bind(socket_handle, &sock_addr, std::max(sizeof(sock_addr), len)); @@ -496,7 +500,7 @@ static void Accept(Service::Interface* self) { result = TranslateError(GET_ERRNO); } else { CTRSockAddr ctr_addr = CTRSockAddr::FromPlatform(addr); - Memory::WriteBlock(cmd_buffer[0x104 >> 2], (const u8*)&ctr_addr, max_addr_len); + Memory::WriteBlock(cmd_buffer[0x104 >> 2], &ctr_addr, sizeof(ctr_addr)); } cmd_buffer[0] = IPC::MakeHeader(4, 2, 2); @@ -547,20 +551,31 @@ static void SendTo(Service::Interface* self) { u32 flags = cmd_buffer[3]; u32 addr_len = cmd_buffer[4]; - u8* input_buff = Memory::GetPointer(cmd_buffer[8]); - CTRSockAddr* ctr_dest_addr = reinterpret_cast(Memory::GetPointer(cmd_buffer[10])); - - if (ctr_dest_addr == nullptr) { + VAddr input_buff_address = cmd_buffer[8]; + if (!Memory::IsValidVirtualAddress(input_buff_address)) { cmd_buffer[1] = -1; // TODO(Subv): Find the right error code return; } + // Memory address of the dest_addr structure + VAddr dest_addr_addr = cmd_buffer[10]; + if (!Memory::IsValidVirtualAddress(dest_addr_addr)) { + cmd_buffer[1] = -1; // TODO(Subv): Find the right error code + return; + } + + std::vector input_buff(len); + Memory::ReadBlock(input_buff_address, input_buff.data(), input_buff.size()); + + CTRSockAddr ctr_dest_addr; + Memory::ReadBlock(dest_addr_addr, reinterpret_cast(&ctr_dest_addr), sizeof(ctr_dest_addr)); + int ret = -1; if (addr_len > 0) { - sockaddr dest_addr = CTRSockAddr::ToPlatform(*ctr_dest_addr); - ret = ::sendto(socket_handle, (const char*)input_buff, len, flags, &dest_addr, sizeof(dest_addr)); + sockaddr dest_addr = CTRSockAddr::ToPlatform(ctr_dest_addr); + ret = ::sendto(socket_handle, reinterpret_cast(input_buff.data()), len, flags, &dest_addr, sizeof(dest_addr)); } else { - ret = ::sendto(socket_handle, (const char*)input_buff, len, flags, nullptr, 0); + ret = ::sendto(socket_handle, reinterpret_cast(input_buff.data()), len, flags, nullptr, 0); } int result = 0; @@ -591,14 +606,24 @@ static void RecvFrom(Service::Interface* self) { std::memcpy(&buffer_parameters, &cmd_buffer[64], sizeof(buffer_parameters)); - u8* output_buff = Memory::GetPointer(buffer_parameters.output_buffer_addr); + if (!Memory::IsValidVirtualAddress(buffer_parameters.output_buffer_addr)) { + cmd_buffer[1] = -1; // TODO(Subv): Find the right error code + return; + } + + if (!Memory::IsValidVirtualAddress(buffer_parameters.output_src_address_buffer)) { + cmd_buffer[1] = -1; // TODO(Subv): Find the right error code + return; + } + + std::vector output_buff(len); sockaddr src_addr; socklen_t src_addr_len = sizeof(src_addr); - int ret = ::recvfrom(socket_handle, (char*)output_buff, len, flags, &src_addr, &src_addr_len); + int ret = ::recvfrom(socket_handle, reinterpret_cast(output_buff.data()), len, flags, &src_addr, &src_addr_len); if (ret >= 0 && buffer_parameters.output_src_address_buffer != 0 && src_addr_len > 0) { - CTRSockAddr* ctr_src_addr = reinterpret_cast(Memory::GetPointer(buffer_parameters.output_src_address_buffer)); - *ctr_src_addr = CTRSockAddr::FromPlatform(src_addr); + CTRSockAddr ctr_src_addr = CTRSockAddr::FromPlatform(src_addr); + Memory::WriteBlock(buffer_parameters.output_src_address_buffer, reinterpret_cast(&ctr_src_addr), sizeof(ctr_src_addr)); } int result = 0; @@ -606,6 +631,9 @@ static void RecvFrom(Service::Interface* self) { if (ret == SOCKET_ERROR_VALUE) { result = TranslateError(GET_ERRNO); total_received = 0; + } else { + // Write only the data we received to avoid overwriting parts of the buffer with zeros + Memory::WriteBlock(buffer_parameters.output_buffer_addr, output_buff.data(), total_received); } cmd_buffer[1] = result; @@ -617,18 +645,28 @@ static void Poll(Service::Interface* self) { u32* cmd_buffer = Kernel::GetCommandBuffer(); u32 nfds = cmd_buffer[1]; int timeout = cmd_buffer[2]; - CTRPollFD* input_fds = reinterpret_cast(Memory::GetPointer(cmd_buffer[6])); - CTRPollFD* output_fds = reinterpret_cast(Memory::GetPointer(cmd_buffer[0x104 >> 2])); + + VAddr input_fds_addr = cmd_buffer[6]; + VAddr output_fds_addr = cmd_buffer[0x104 >> 2]; + if (!Memory::IsValidVirtualAddress(input_fds_addr) || !Memory::IsValidVirtualAddress(output_fds_addr)) { + cmd_buffer[1] = -1; // TODO(Subv): Find correct error code. + return; + } + + std::vector ctr_fds(nfds); + Memory::ReadBlock(input_fds_addr, reinterpret_cast(ctr_fds.data()), nfds * sizeof(CTRPollFD)); // The 3ds_pollfd and the pollfd structures may be different (Windows/Linux have different sizes) // so we have to copy the data std::vector platform_pollfd(nfds); - std::transform(input_fds, input_fds + nfds, platform_pollfd.begin(), CTRPollFD::ToPlatform); + std::transform(ctr_fds.begin(), ctr_fds.end(), platform_pollfd.begin(), CTRPollFD::ToPlatform); const int ret = ::poll(platform_pollfd.data(), nfds, timeout); // Now update the output pollfd structure - std::transform(platform_pollfd.begin(), platform_pollfd.end(), output_fds, CTRPollFD::FromPlatform); + std::transform(platform_pollfd.begin(), platform_pollfd.end(), ctr_fds.begin(), CTRPollFD::FromPlatform); + + Memory::WriteBlock(output_fds_addr, reinterpret_cast(ctr_fds.data()), nfds * sizeof(CTRPollFD)); int result = 0; if (ret == SOCKET_ERROR_VALUE) @@ -643,14 +681,16 @@ static void GetSockName(Service::Interface* self) { u32 socket_handle = cmd_buffer[1]; socklen_t ctr_len = cmd_buffer[2]; - CTRSockAddr* ctr_dest_addr = reinterpret_cast(Memory::GetPointer(cmd_buffer[0x104 >> 2])); + // Memory address of the ctr_dest_addr structure + VAddr ctr_dest_addr_addr = cmd_buffer[0x104 >> 2]; sockaddr dest_addr; socklen_t dest_addr_len = sizeof(dest_addr); int ret = ::getsockname(socket_handle, &dest_addr, &dest_addr_len); - if (ctr_dest_addr != nullptr) { - *ctr_dest_addr = CTRSockAddr::FromPlatform(dest_addr); + if (ctr_dest_addr_addr != 0 && Memory::IsValidVirtualAddress(ctr_dest_addr_addr)) { + CTRSockAddr ctr_dest_addr = CTRSockAddr::FromPlatform(dest_addr); + Memory::WriteBlock(ctr_dest_addr_addr, reinterpret_cast(&ctr_dest_addr), sizeof(ctr_dest_addr)); } else { cmd_buffer[1] = -1; // TODO(Subv): Verify error return; @@ -682,14 +722,16 @@ static void GetPeerName(Service::Interface* self) { u32 socket_handle = cmd_buffer[1]; socklen_t len = cmd_buffer[2]; - CTRSockAddr* ctr_dest_addr = reinterpret_cast(Memory::GetPointer(cmd_buffer[0x104 >> 2])); + // Memory address of the ctr_dest_addr structure + VAddr ctr_dest_addr_addr = cmd_buffer[0x104 >> 2]; sockaddr dest_addr; socklen_t dest_addr_len = sizeof(dest_addr); int ret = ::getpeername(socket_handle, &dest_addr, &dest_addr_len); - if (ctr_dest_addr != nullptr) { - *ctr_dest_addr = CTRSockAddr::FromPlatform(dest_addr); + if (ctr_dest_addr_addr != 0 && Memory::IsValidVirtualAddress(ctr_dest_addr_addr)) { + CTRSockAddr ctr_dest_addr = CTRSockAddr::FromPlatform(dest_addr); + Memory::WriteBlock(ctr_dest_addr_addr, reinterpret_cast(&ctr_dest_addr), sizeof(ctr_dest_addr)); } else { cmd_buffer[1] = -1; return; @@ -711,13 +753,17 @@ static void Connect(Service::Interface* self) { u32 socket_handle = cmd_buffer[1]; socklen_t len = cmd_buffer[2]; - CTRSockAddr* ctr_input_addr = reinterpret_cast(Memory::GetPointer(cmd_buffer[6])); - if (ctr_input_addr == nullptr) { + // Memory address of the ctr_input_addr structure + VAddr ctr_input_addr_addr = cmd_buffer[6]; + if (!Memory::IsValidVirtualAddress(ctr_input_addr_addr)) { cmd_buffer[1] = -1; // TODO(Subv): Verify error return; } - sockaddr input_addr = CTRSockAddr::ToPlatform(*ctr_input_addr); + CTRSockAddr ctr_input_addr; + Memory::ReadBlock(ctr_input_addr_addr, reinterpret_cast(&ctr_input_addr), sizeof(ctr_input_addr)); + + sockaddr input_addr = CTRSockAddr::ToPlatform(ctr_input_addr); int ret = ::connect(socket_handle, &input_addr, sizeof(input_addr)); int result = 0; if (ret != 0)