Compare commits

...

13 Commits

Author SHA1 Message Date
Anuken
63b92e6dd6 Better caching 2024-01-28 12:54:16 -05:00
Anuken
3be67aa52a No main executor call 2024-01-28 11:31:03 -05:00
Anuken
a73b783d98 SRV support 2024-01-27 17:46:52 -05:00
Anuken
db4c861dde AAAA record support 2024-01-27 12:50:26 -05:00
Anuken
c7a35ae789 IPv6 address support 2024-01-27 11:15:58 -05:00
Anuken
3e5ad07e8c Merge branch 'master' of https://github.com/Anuken/Mindustry into async-ping
 Conflicts:
	core/src/mindustry/net/ArcNetProvider.java
	gradle.properties
2024-01-27 10:11:09 -05:00
Anuken
584b22300d a few comments 2023-05-09 16:45:46 -04:00
Anuken
6eb049c419 Merge branch 'master' of https://github.com/Anuken/Mindustry into async-ping 2023-05-09 16:45:42 -04:00
Anuken
6be497878c cleanup 2023-02-26 15:44:51 -05:00
Anuken
8e9b409b63 cleanup 2023-02-26 15:41:49 -05:00
Anuken
b496e8457c it works minus SRV 2023-02-26 15:32:35 -05:00
Anuken
41b87b9345 static 2023-02-26 12:06:06 -05:00
Anuken
84e52bdee3 async UDP ping 2023-02-26 11:40:53 -05:00
6 changed files with 636 additions and 35 deletions

View File

@@ -0,0 +1,231 @@
package mindustry.net;
/** Utility class for parsing IPv4/IPv6. Taken from IPAddressUtil in the JDK. */
public class Addresses{
/** @return the IPv4 or IPv6 address of the string, or null if the address is not valid. */
public static byte[] getAddress(String src){
byte[] ipv4 = getIpv4(src);
if(ipv4 != null) return ipv4;
return getIpv6(src);
}
/** @return the IPv4 address of the string, or null if the address is not valid. Exotic formats (such as decimal) are allowed. */
public static byte[] getIpv4(String src){
byte[] res = new byte[4];
long tmpValue = 0L;
int currByte = 0;
boolean newOctet = true;
int len = src.length();
if(len != 0 && len <= 15){
for(int i = 0; i < len; ++i){
char c = src.charAt(i);
if(c == '.'){
if(newOctet || tmpValue < 0L || tmpValue > 255L || currByte == 3){
return null;
}
res[currByte++] = (byte)((int)(tmpValue & 255L));
tmpValue = 0L;
newOctet = true;
}else{
int digit = digit(c, 10);
if(digit < 0){
return null;
}
tmpValue *= 10L;
tmpValue += (long)digit;
newOctet = false;
}
}
if(!newOctet && tmpValue >= 0L && tmpValue < 1L << (4 - currByte) * 8){
switch(currByte){
case 0:
res[0] = (byte)((int)(tmpValue >> 24 & 255L));
case 1:
res[1] = (byte)((int)(tmpValue >> 16 & 255L));
case 2:
res[2] = (byte)((int)(tmpValue >> 8 & 255L));
case 3:
res[3] = (byte)((int)(tmpValue >> 0 & 255L));
default:
return res;
}
}else{
return null;
}
}else{
return null;
}
}
/** @return the IPv6 address of the string, or null if the address is not valid. */
public static byte[] getIpv6(String src){
if(src.length() < 2){
return null;
}else{
char[] srcb = src.toCharArray();
byte[] dst = new byte[16];
int srcb_length = srcb.length;
int pc = src.indexOf(37);
if(pc == srcb_length - 1){
return null;
}else{
if(pc != -1){
srcb_length = pc;
}
int colonp = -1;
int i = 0;
int j = 0;
if(srcb[i] == ':'){
++i;
if(srcb[i] != ':'){
return null;
}
}
int curtok = i;
boolean saw_xdigit = false;
int val = 0;
while(true){
int n;
while(i < srcb_length){
char ch = srcb[i++];
n = digit(ch, 16);
if(n != -1){
val <<= 4;
val |= n;
if(val > 65535){
return null;
}
saw_xdigit = true;
}else{
if(ch != ':'){
if(ch == '.' && j + 4 <= 16){
String ia4 = src.substring(curtok, srcb_length);
int dot_count = 0;
for(int index = 0; (index = ia4.indexOf(46, index)) != -1; ++index){
++dot_count;
}
if(dot_count != 3){
return null;
}
byte[] v4addr = getIpv4(ia4);
if(v4addr == null){
return null;
}
for(int k = 0; k < 4; ++k){
dst[j++] = v4addr[k];
}
saw_xdigit = false;
break;
}
return null;
}
curtok = i;
if(!saw_xdigit){
if(colonp != -1){
return null;
}
colonp = j;
}else{
if(i == srcb_length){
return null;
}
if(j + 2 > 16){
return null;
}
dst[j++] = (byte)(val >> 8 & 255);
dst[j++] = (byte)(val & 255);
saw_xdigit = false;
val = 0;
}
}
}
if(saw_xdigit){
if(j + 2 > 16){
return null;
}
dst[j++] = (byte)(val >> 8 & 255);
dst[j++] = (byte)(val & 255);
}
if(colonp != -1){
n = j - colonp;
if(j == 16){
return null;
}
for(i = 1; i <= n; ++i){
dst[16 - i] = dst[colonp + n - i];
dst[colonp + n - i] = 0;
}
j = 16;
}
if(j != 16){
return null;
}
byte[] newdst = convertFromIPv4MappedAddress(dst);
if(newdst != null){
return newdst;
}
return dst;
}
}
}
}
private static boolean isIPv4MappedAddress(byte[] addr){
if(addr.length < 16){
return false;
}else{
return addr[0] == 0 && addr[1] == 0 && addr[2] == 0 && addr[3] == 0 && addr[4] == 0 && addr[5] == 0 && addr[6] == 0 && addr[7] == 0 && addr[8] == 0 && addr[9] == 0 && addr[10] == -1 && addr[11] == -1;
}
}
private static byte[] convertFromIPv4MappedAddress(byte[] addr){
if(isIPv4MappedAddress(addr)){
byte[] newAddr = new byte[4];
System.arraycopy(addr, 12, newAddr, 0, 4);
return newAddr;
}else{
return null;
}
}
private static int digit(char ch, int radix){
return parseAsciiDigit(ch, radix);
}
private static int parseAsciiDigit(char c, int radix){
if(radix == 16){
char c1 = Character.toLowerCase(c);
return c1 >= 'a' && c1 <= 'f' ? c1 - 97 + 10 : parseAsciiDigit(c1, 10);
}else{
int val = c - 48;
return val >= 0 && val < radix ? val : -1;
}
}
}

View File

@@ -5,7 +5,6 @@ import arc.func.*;
import arc.math.*; import arc.math.*;
import arc.net.*; import arc.net.*;
import arc.net.FrameworkMessage.*; import arc.net.FrameworkMessage.*;
import arc.net.Server.*;
import arc.net.dns.*; import arc.net.dns.*;
import arc.struct.*; import arc.struct.*;
import arc.util.*; import arc.util.*;
@@ -165,7 +164,7 @@ public class ArcNetProvider implements NetProvider{
} }
@Override @Override
public @Nullable ServerConnectFilter getConnectFilter(){ public @Nullable Server.ServerConnectFilter getConnectFilter(){
return server.getConnectFilter(); return server.getConnectFilter();
} }
@@ -225,39 +224,47 @@ public class ArcNetProvider implements NetProvider{
@Override @Override
public void pingHost(String address, int port, Cons<Host> valid, Cons<Exception> invalid){ public void pingHost(String address, int port, Cons<Host> valid, Cons<Exception> invalid){
try{ //TODO: main executor or not?
var host = pingHostImpl(address, port); //mainExecutor.submit(() -> {
Core.app.post(() -> valid.get(host));
}catch(IOException e){ pingHostImpl(address, port, host -> Core.app.post(() -> valid.get(host)), e -> {
if(port == Vars.port){ //raw IP addresses can't have SRV records, so don't bother checking
for(var record : ArcDns.getSrvRecords("_mindustry._tcp." + address)){ if(port == Vars.port && Addresses.getAddress(address) == null){
try{ Dns.resolveSrv("_mindustry._tcp." + address, records -> {
var host = pingHostImpl(record.target, record.port); records.sort();
Core.app.post(() -> valid.get(host)); pingRecords(records, 0, host1 -> Core.app.post(() -> valid.get(host1)), srvError -> Core.app.post(() -> invalid.get(e)));
return; }, srvError -> Core.app.post(() -> invalid.get(e)));
}catch(IOException ignored){ }else{
} Core.app.post(() -> invalid.get(e));
}
} }
Core.app.post(() -> invalid.get(e)); });
}
//});
} }
private Host pingHostImpl(String address, int port) throws IOException{ private void pingRecords(Seq<SRVRecord> records, int index, Cons<Host> valid, Cons<Exception> invalid){
try(DatagramSocket socket = new DatagramSocket()){ if(index >= records.size){
long time = Time.millis(); invalid.get(new UnknownHostException());
return;
socket.send(new DatagramPacket(new byte[]{-2, 1}, 2, InetAddress.getByName(address), port));
socket.setSoTimeout(2000);
DatagramPacket packet = packetSupplier.get();
socket.receive(packet);
ByteBuffer buffer = ByteBuffer.wrap(packet.getData());
Host host = NetworkIO.readServerData((int)Time.timeSinceMillis(time), packet.getAddress().getHostAddress(), buffer);
host.port = port;
return host;
} }
var record = records.get(index);
pingHostImpl(record.target, record.port, valid, error -> pingRecords(records, index + 1, valid, invalid));
}
private void pingHostImpl(String address, int port, Cons<Host> valid, Cons<Exception> error){
long time = Time.millis();
Dns.resolveAddress(address, inetaddr -> {
var socket = new InetSocketAddress(inetaddr, port);
AsyncUdp.send(socket, 2000, 512, ByteBuffer.wrap(new byte[]{-2, 1}), data -> {
Host host = NetworkIO.readServerData((int)Time.timeSinceMillis(time), socket.getAddress().getHostAddress(), data);
host.port = port;
valid.get(host);
}, error);
}, error);
} }
@Override @Override

View File

@@ -0,0 +1,158 @@
package mindustry.net;
import arc.func.*;
import arc.util.*;
import java.io.*;
import java.net.*;
import java.nio.*;
import java.nio.channels.*;
import java.util.*;
import java.util.concurrent.*;
public class AsyncUdp{
static Selector selector;
static DelayQueue<Request> removals = new DelayQueue<>();
static TaskQueue tasks = new TaskQueue();
static int emptySelects;
static{
try{
selector = Selector.open();
//handle requests and tasks
Threads.daemon("AsyncUDP", () -> {
while(true){
try{
long startTime = Time.millis();
int selected = selector.select(0);
tasks.run();
if(selected == 0){
//prevent hogging the CPU due to empty selects as per Kryonet implementation
if(emptySelects++ >= 100){
emptySelects = 0;
long elapsedTime = System.currentTimeMillis() - startTime;
if(elapsedTime < 25) Threads.sleep(25 - elapsedTime);
}
continue;
}
var keys = selector.selectedKeys();
for(Iterator<SelectionKey> iter = keys.iterator(); iter.hasNext(); ){
var key = iter.next();
iter.remove();
if(key.isReadable() && key.isValid()){
var request = (Request)key.attachment();
try{
var channel = (DatagramChannel)key.channel();
var buffer = ByteBuffer.allocate(request.bufferSize);
channel.receive(buffer);
buffer.position(0);
buffer.limit(buffer.capacity());
request.received.get(buffer);
request.close();
}catch(Exception error){
request.fail(error);
}
}
}
}catch(Throwable e){
//should not happen
Log.err(e);
}
}
});
//remove requests with the delay queue
Threads.daemon("AsyncUDP-Delay", () -> {
while(true){
try{
var request = removals.take();
tasks.post(() -> request.fail(new TimeoutException()));
selector.wakeup();
}catch(InterruptedException ignored){}
}
});
}catch(IOException e){
throw new ArcRuntimeException(e);
}
}
public static void send(InetSocketAddress address, int timeout, int bufferSize, ByteBuffer data, Cons<ByteBuffer> received, Cons<Exception> failed){
tasks.post(() -> {
try{
DatagramChannel channel = selector.provider().openDatagramChannel();
channel.configureBlocking(false);
channel.connect(address);
channel.send(data, address);
SelectionKey key = channel.register(selector, SelectionKey.OP_READ);
Request req = new Request(address, timeout, bufferSize, data, channel, key, received, failed);
key.attach(req);
removals.offer(req);
}catch(Exception e){
failed.get(e);
}
});
selector.wakeup();
}
static class Request implements Delayed{
final InetSocketAddress address;
final long timeout, connectStartMs;
final int bufferSize;
final ByteBuffer data;
final Cons<ByteBuffer> received;
final Cons<Exception> failed;
final DatagramChannel channel;
final SelectionKey key;
boolean closed = false;
public Request(InetSocketAddress address, long timeout, int bufferSize, ByteBuffer data, DatagramChannel channel, SelectionKey key, Cons<ByteBuffer> received, Cons<Exception> failed){
this.address = address;
this.timeout = timeout;
this.bufferSize = bufferSize;
this.data = data;
this.received = received;
this.failed = failed;
this.channel = channel;
this.key = key;
this.connectStartMs = Time.millis();
}
void close(){
try{
closed = true;
key.cancel();
channel.close();
}catch(Exception close){
close.printStackTrace();
}
}
void fail(Exception error){
if(!closed){
failed.get(error);
close();
}
}
@Override
public long getDelay(TimeUnit unit){
return unit.convert(timeout - Time.timeSinceMillis(connectStartMs), TimeUnit.MILLISECONDS);
}
@Override
public int compareTo(Delayed o){
return Long.compare(getDelay(TimeUnit.MILLISECONDS), o.getDelay(TimeUnit.MILLISECONDS));
}
}
}

View File

@@ -0,0 +1,207 @@
package mindustry.net;
import arc.func.*;
import arc.math.*;
import arc.net.dns.*;
import arc.struct.*;
import arc.util.*;
import java.net.*;
import java.nio.*;
import java.nio.channels.*;
import java.util.concurrent.*;
public class Dns{
private static final int aRecord = 1, aaaaRecord = 28, srvRecord = 33;
private static IntMap<ObjectMap<String, Seq<?>>> cache = new IntMap<>(); //TODO remove this cache?
private static ConcurrentHashMap<String, InetAddress> domainToIp = new ConcurrentHashMap<>();
static <T> void resolve(int type, String domain, Func<ByteBuffer, T> reader, Cons<Seq<T>> result, Cons<Exception> error){
ObjectMap<String, Seq<?>> map;
synchronized(cache){
map = cache.get(type, ObjectMap::new);
//TODO timeout
if(map.containsKey(domain)){
result.get((Seq<T>)map.get(domain));
return;
}
}
send(ArcDns.getNameservers(), 0, type, domain, reader, records -> {
synchronized(cache){
//cache the records
map.put(domain, records);
}
result.get(records);
}, error);
}
static void resolveSrv(String domain, Cons<Seq<SRVRecord>> result, Cons<Exception> error){
resolve(srvRecord, domain, bytes -> {
int priority = bytes.getShort() & 0xFFFF;
int weight = bytes.getShort() & 0xFFFF;
int port = bytes.getShort() & 0xFFFF;
int len;
StringBuilder builder = new StringBuilder();
while((len = bytes.get()) != 0){
for(int j = 0; j < len; j++) builder.append((char)bytes.get());
builder.append('.');
}
builder.delete(builder.length() - 1, builder.length());
return new SRVRecord(0, priority, weight, port, builder.toString());
}, records -> {
if(records.size > 0){
result.get(records);
}else{
//no SRV records, just call it an error
error.get(new UnknownHostException());
}
}, error);
}
static void resolveAddress(String domain, Cons<InetAddress> result, Cons<Exception> error){
//since parsing the address may be slow, check the cache first.
var cachedIp = domainToIp.get(domain);
if(cachedIp != null){
try{
result.get(cachedIp);
}catch(Exception e){
error.get(e);
}
return;
}
//attempt to resolve ipv4 or ipv6 address
byte[] rawAddress = Addresses.getAddress(domain);
if(rawAddress != null){
try{
var address = InetAddress.getByAddress(domain, rawAddress);
domainToIp.put(domain, address);
result.get(address);
}catch(Exception e){
error.get(e);
}
return;
}
resolve(aRecord, domain, bytes -> {
byte[] address = new byte[4];
bytes.get(address);
return address;
}, addresses -> {
try{
if(addresses.size > 0){
var address = InetAddress.getByAddress(addresses.get(0));
domainToIp.put(domain, address);
result.get(address);
}else{
//there are no records found - try AAAA instead.
resolve(aaaaRecord, domain, bytes -> {
byte[] address = new byte[16];
bytes.get(address);
return address;
}, addresses2 -> {
try{
if(addresses2.size > 0){
var address = InetAddress.getByAddress(addresses2.get(0));
domainToIp.put(domain, address);
result.get(address);
}else{
//there are no records found
error.get(new UnresolvedAddressException());
}
}catch(UnknownHostException unknown){
error.get(unknown);
}
}, error);
}
}catch(UnknownHostException unknown){
error.get(unknown);
}
}, error);
}
static <T> void send(Seq<InetSocketAddress> addresses, int index, int type, String domain, Func<ByteBuffer, T> reader, Cons<Seq<T>> recordResult, Cons<Exception> error){
short id = (short)new Rand().nextInt(Short.MAX_VALUE);
ByteBuffer buffer = ByteBuffer.allocate(512);
buffer.putShort(id); // Id
buffer.putShort((short) 0x0100); // Flags (recursion enabled)
buffer.putShort((short) 1); // Questions
buffer.putShort((short) 0); // Answers
buffer.putShort((short) 0); // Authority
buffer.putShort((short) 0); // Additional
// Domain
for(String part : domain.split("\\.")){
buffer.put((byte) part.length());
buffer.put(part.getBytes(Strings.utf8));
}
buffer.put((byte) 0);
buffer.putShort((short) type); // Type
buffer.putShort((short) 1); // Class (Internet)
buffer.flip();
AsyncUdp.send(addresses.get(index), 2000, 512, buffer, result -> {
short responseId = result.getShort();
if(responseId != id) {
throw new ArcRuntimeException("Invalid response ID");
}
result.getShort();
result.getShort();
int answers = result.getShort() & 0xFFFF;
result.getShort();
result.getShort();
byte len;
while((len = result.get()) != 0) {
result.position(result.position() + len);
}
result.getShort();
result.getShort();
var records = new Seq<T>(answers);
for(int i = 0; i < answers; i++) {
result.getShort(); // OFFSET
int answerType = result.getShort() & 0xFFFF; // Type
result.getShort(); // Class
result.getInt(); // TTL
int length = result.getShort() & 0xFFFF; // Data length
// Optionally CNAME results will be returned with the A results, skip those
if(answerType != type){
result.position(result.position() + length);
continue;
}
int position = result.position();
records.add(reader.get(result));
result.position(position + length);
}
recordResult.get(records);
}, e -> {
if(index >= addresses.size - 1){
error.get(e);
}else{
send(addresses, index + 1, type, domain, reader, recordResult, error);
}
});
}
}

View File

@@ -15,7 +15,6 @@ import net.jpountz.lz4.*;
import java.io.*; import java.io.*;
import java.nio.*; import java.nio.*;
import java.nio.channels.*; import java.nio.channels.*;
import java.util.concurrent.*;
import static arc.util.Log.*; import static arc.util.Log.*;
import static mindustry.Vars.*; import static mindustry.Vars.*;
@@ -35,7 +34,6 @@ public class Net{
private final ObjectMap<Class<?>, Cons> clientListeners = new ObjectMap<>(); private final ObjectMap<Class<?>, Cons> clientListeners = new ObjectMap<>();
private final ObjectMap<Class<?>, Cons2<NetConnection, Object>> serverListeners = new ObjectMap<>(); private final ObjectMap<Class<?>, Cons2<NetConnection, Object>> serverListeners = new ObjectMap<>();
private final IntMap<StreamBuilder> streams = new IntMap<>(); private final IntMap<StreamBuilder> streams = new IntMap<>();
private final ExecutorService pingExecutor = OS.isWindows && !OS.is64Bit ? Threads.boundedExecutor("Ping Servers", 5) : Threads.unboundedExecutor();
private final NetProvider provider; private final NetProvider provider;
@@ -340,7 +338,7 @@ public class Net{
* If the port is the default mindustry port, SRV records are checked too. * If the port is the default mindustry port, SRV records are checked too.
*/ */
public void pingHost(String address, int port, Cons<Host> valid, Cons<Exception> failed){ public void pingHost(String address, int port, Cons<Host> valid, Cons<Exception> failed){
pingExecutor.submit(() -> provider.pingHost(address, port, valid, failed)); provider.pingHost(address, port, valid, failed);
} }
/** /**

View File

@@ -25,4 +25,4 @@ org.gradle.caching=true
#used for slow jitpack builds; TODO see if this actually works #used for slow jitpack builds; TODO see if this actually works
org.gradle.internal.http.socketTimeout=100000 org.gradle.internal.http.socketTimeout=100000
org.gradle.internal.http.connectionTimeout=100000 org.gradle.internal.http.connectionTimeout=100000
archash=137d14855f archash=008e73aa33