forked from Hyperion/HyperionOS
Potential windows case insensitive filesystem issue fix
This commit is contained in:
556
Src/Hyperion-kernel/lib/modules/hyperion/20_socket.kmod
Normal file
556
Src/Hyperion-kernel/lib/modules/hyperion/20_socket.kmod
Normal file
@@ -0,0 +1,556 @@
|
||||
-- :Minify:--
|
||||
-- Supports:
|
||||
-- AF_UNIX - local IPC via /var/run/*.sock paths
|
||||
-- AF_INET - network sockets with three backends:
|
||||
-- rednet://0.0.B.C or rednet+PROTO://0.0.B.C -> CC rednet (computer B*256+C)
|
||||
-- modem://0.0.B.C -> raw CC modem frames
|
||||
-- http://host/path or https://... -> HTTP via CC http API
|
||||
-- A.B.C.D (dotted quad, non-zero A) -> HTTP
|
||||
--
|
||||
-- Socket lifecycle:
|
||||
-- fd = syscall.socket(domain, socktype) -- "unix"/"inet", "stream"/"dgram"
|
||||
-- syscall.bind(fd, address) -- server: claim address
|
||||
-- syscall.listen(fd, backlog) -- server: mark as listening
|
||||
-- cfd = syscall.accept(fd) -- server: get connected client fd (blocking poll)
|
||||
-- syscall.connect(fd, address) -- client: connect to server
|
||||
-- syscall.send(fd, data) -- send bytes
|
||||
-- syscall.recv(fd, len) -- receive bytes (blocking poll, returns "" on nothing)
|
||||
-- syscall.sockshutdown(fd) -- half-close send side
|
||||
-- -- normal vfs.close(fd) closes the socket
|
||||
|
||||
local kernel = ...
|
||||
|
||||
local sockets = {}
|
||||
local unixSocks = {}
|
||||
local nextSockId = 1
|
||||
|
||||
local function allocSockId()
|
||||
local id = nextSockId
|
||||
nextSockId = nextSockId + 1
|
||||
return id
|
||||
end
|
||||
|
||||
local function parseAddress(addr)
|
||||
if not addr then error("EINVAL") end
|
||||
|
||||
if addr:sub(1,1) == "/" or addr:sub(1,5) == "unix:" then
|
||||
local path = addr:sub(1,5) == "unix:" and addr:sub(6) or addr
|
||||
return { backend="unix", path=path }
|
||||
end
|
||||
|
||||
local rproto, raddr = addr:match("^rednet%+?([^:/]*)://(.+)$")
|
||||
if raddr then
|
||||
local a,b,c,d = raddr:match("^(%d+)%.(%d+)%.(%d+)%.(%d+)$")
|
||||
if not a then error("EINVAL: bad rednet address " .. raddr) end
|
||||
local compId = tonumber(c)*256 + tonumber(d)
|
||||
return { backend="rednet", compId=compId,
|
||||
protocol=(rproto ~= "" and rproto or "hyperion") }
|
||||
end
|
||||
|
||||
local maddr = addr:match("^modem://(.+)$")
|
||||
if maddr then
|
||||
local a,b,c,d = maddr:match("^(%d+)%.(%d+)%.(%d+)%.(%d+)$")
|
||||
if not a then error("EINVAL: bad modem address " .. maddr) end
|
||||
local compId = tonumber(c)*256 + tonumber(d)
|
||||
local port = tonumber(maddr:match(":(%d+)$")) or 0
|
||||
return { backend="modem", compId=compId, port=port }
|
||||
end
|
||||
|
||||
local scheme, rest = addr:match("^(https?)://(.+)$")
|
||||
if scheme then
|
||||
return { backend=scheme, url=addr }
|
||||
end
|
||||
|
||||
local a,b,c,d = addr:match("^(%d+)%.(%d+)%.(%d+)%.(%d+)")
|
||||
if a and tonumber(a) ~= 0 then
|
||||
return { backend="http", url="http://" .. addr }
|
||||
end
|
||||
|
||||
error("EINVAL: unrecognised address format: " .. tostring(addr))
|
||||
end
|
||||
|
||||
local rednetOpen = false
|
||||
local function ensureRednet()
|
||||
if rednetOpen then return end
|
||||
local rn = kernel.apis and kernel.apis.rednet
|
||||
if not rn then error("ENODEV: no rednet API available") end
|
||||
local peripheral = kernel.apis.peripheral
|
||||
if peripheral then
|
||||
for _, name in ipairs(peripheral.getNames and peripheral.getNames() or {}) do
|
||||
if peripheral.getType(name) == "modem" then
|
||||
pcall(rn.open, name)
|
||||
end
|
||||
end
|
||||
end
|
||||
rednetOpen = true
|
||||
end
|
||||
|
||||
local function getModem()
|
||||
local peripheral = kernel.apis and kernel.apis.peripheral
|
||||
if not peripheral then error("ENODEV") end
|
||||
for _, name in ipairs(peripheral.getNames and peripheral.getNames() or {}) do
|
||||
if peripheral.getType(name) == "modem" then
|
||||
local m = peripheral.wrap(name)
|
||||
if m then return m, name end
|
||||
end
|
||||
end
|
||||
error("ENODEV: no modem peripheral found")
|
||||
end
|
||||
|
||||
local function pumpEvents()
|
||||
local ev = kernel.computer:getMachineEvent()
|
||||
while ev do
|
||||
if ev == "rednet_message" then
|
||||
for _, sock in pairs(sockets) do
|
||||
if sock.backend == "rednet" and sock.bound then
|
||||
if sock.address.protocol == tostring(select(4, table.unpack({ev}))) or
|
||||
sock.address.protocol == "hyperion" then
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
ev = kernel.computer:getMachineEvent()
|
||||
end
|
||||
end
|
||||
|
||||
local function pollEvent()
|
||||
local results = table.pack(kernel.computer:getMachineEvent())
|
||||
if results.n == 0 or results[1] == nil then return nil end
|
||||
return results
|
||||
end
|
||||
|
||||
local function dispatchEvent(ev)
|
||||
if not ev then return end
|
||||
local evtype = ev[1]
|
||||
|
||||
if evtype == "rednet_message" then
|
||||
local senderId = ev[2]
|
||||
local message = ev[3]
|
||||
local protocol = ev[4] or "hyperion"
|
||||
for _, sock in pairs(sockets) do
|
||||
if sock.backend == "rednet" and (sock.listening or sock.connected) then
|
||||
if sock.address and sock.address.protocol == protocol then
|
||||
table.insert(sock.rxbuf, { from=senderId, data=message })
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
elseif evtype == "modem_message" then
|
||||
local channel = ev[3]
|
||||
local msg = ev[5]
|
||||
local fromCh = ev[4]
|
||||
for _, sock in pairs(sockets) do
|
||||
if sock.backend == "modem" and sock.modemChannel == channel then
|
||||
table.insert(sock.rxbuf, { from=fromCh, data=msg })
|
||||
end
|
||||
end
|
||||
|
||||
elseif evtype == "http_success" then
|
||||
local url = ev[2]
|
||||
local handle = ev[3]
|
||||
for _, sock in pairs(sockets) do
|
||||
if sock.backend == "http" or sock.backend == "https" then
|
||||
if sock.pendingUrl == url then
|
||||
local body = handle.readAll and handle.readAll() or ""
|
||||
handle.close()
|
||||
table.insert(sock.rxbuf, { data=body, done=true })
|
||||
sock.pendingUrl = nil
|
||||
sock.connected = true
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
elseif evtype == "http_failure" then
|
||||
local url = ev[2]
|
||||
local err = ev[3]
|
||||
for _, sock in pairs(sockets) do
|
||||
if (sock.backend == "http" or sock.backend == "https") and
|
||||
sock.pendingUrl == url then
|
||||
sock.error = err
|
||||
sock.pendingUrl = nil
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function pumpAll()
|
||||
local ev = pollEvent()
|
||||
while ev do
|
||||
dispatchEvent(ev)
|
||||
ev = pollEvent()
|
||||
end
|
||||
end
|
||||
|
||||
local function newSocket(domain, socktype)
|
||||
local sock = {
|
||||
id = allocSockId(),
|
||||
domain = domain, -- "unix" | "inet"
|
||||
socktype = socktype, -- "stream" | "dgram"
|
||||
backend = nil,
|
||||
state = "idle", -- idle | bound | listening | connected | closed
|
||||
rxbuf = {},
|
||||
txbuf = {},
|
||||
backlog = {},
|
||||
address = nil,
|
||||
peer = nil,
|
||||
modemChannel = nil,
|
||||
modem = nil,
|
||||
pendingUrl = nil,
|
||||
bound = false,
|
||||
listening = false,
|
||||
connected = false,
|
||||
error = nil,
|
||||
}
|
||||
sockets[sock.id] = sock
|
||||
return sock
|
||||
end
|
||||
|
||||
local sockSend, sockClose
|
||||
|
||||
local function socketToFd(sock)
|
||||
return {
|
||||
isSocket = true,
|
||||
sockId = sock.id,
|
||||
mode = "rw",
|
||||
meta = { etype=0, owner=0, group=0, perms=0x1FF, cmeta="" },
|
||||
type = "socket",
|
||||
refcount = 1,
|
||||
handle = {
|
||||
read = function(count)
|
||||
pumpAll()
|
||||
if #sock.rxbuf == 0 then return "" end
|
||||
local item = table.remove(sock.rxbuf, 1)
|
||||
local data = type(item) == "table" and (item.data or "") or tostring(item)
|
||||
if count and #data > count then
|
||||
table.insert(sock.rxbuf, 1, { data=data:sub(count+1), from=item.from })
|
||||
data = data:sub(1, count)
|
||||
end
|
||||
return data
|
||||
end,
|
||||
write = function(data)
|
||||
if sock.state == "closed" then error("EBADF") end
|
||||
return sockSend(sock, data)
|
||||
end,
|
||||
close = function()
|
||||
sockClose(sock)
|
||||
end,
|
||||
}
|
||||
}
|
||||
end
|
||||
|
||||
sockSend = function(sock, data)
|
||||
if sock.backend == "unix" then
|
||||
local peer = sock.peer
|
||||
if not peer then error("ENOTCONN") end
|
||||
table.insert(peer.rxbuf, { data=data })
|
||||
return #data
|
||||
|
||||
elseif sock.backend == "rednet" then
|
||||
ensureRednet()
|
||||
local rn = kernel.apis.rednet
|
||||
rn.send(sock.address.compId, data, sock.address.protocol)
|
||||
return #data
|
||||
|
||||
elseif sock.backend == "modem" then
|
||||
local modem = sock.modem
|
||||
if not modem then error("ENOTCONN") end
|
||||
modem.transmit(sock.address.port, sock.modemChannel or 0, data)
|
||||
return #data
|
||||
|
||||
elseif sock.backend == "http" or sock.backend == "https" then
|
||||
local http = kernel.apis and kernel.apis.http
|
||||
if not http then error("ENODEV: no http API") end
|
||||
local url = sock.address.url
|
||||
local ok, err = pcall(http.request, url, data, {
|
||||
["Content-Type"] = "application/octet-stream"
|
||||
})
|
||||
if not ok then error("ENETDOWN: " .. tostring(err)) end
|
||||
sock.pendingUrl = url
|
||||
return #data
|
||||
end
|
||||
error("EPROTONOSUPPORT")
|
||||
end
|
||||
|
||||
sockClose = function(sock)
|
||||
if sock.state == "closed" then return end
|
||||
sock.state = "closed"
|
||||
|
||||
if sock.backend == "unix" then
|
||||
if sock.peer then
|
||||
sock.peer.peer = nil
|
||||
sock.peer.state = "closed"
|
||||
end
|
||||
if sock.bound and sock.address and sock.address.path then
|
||||
unixSocks[sock.address.path] = nil
|
||||
end
|
||||
|
||||
elseif sock.backend == "modem" and sock.modem and sock.modemChannel then
|
||||
pcall(sock.modem.close, sock.modemChannel)
|
||||
|
||||
elseif sock.backend == "rednet" then
|
||||
end
|
||||
|
||||
sockets[sock.id] = nil
|
||||
end
|
||||
|
||||
kernel.syscalls["socket"] = function(domain, socktype)
|
||||
domain = domain or "inet"
|
||||
socktype = socktype or "stream"
|
||||
if domain ~= "unix" and domain ~= "inet" then error("EAFNOSUPPORT") end
|
||||
if socktype ~= "stream" and socktype ~= "dgram" then error("EPROTOTYPE") end
|
||||
|
||||
local sock = newSocket(domain, socktype)
|
||||
local fdobj = socketToFd(sock)
|
||||
local fd = kernel.vfs.newfd(fdobj)
|
||||
return fd
|
||||
end
|
||||
|
||||
kernel.syscalls["bind"] = function(fd, address)
|
||||
local task = kernel.currentTask
|
||||
local fdobj = task.fd[fd]
|
||||
if not fdobj or not fdobj.isSocket then error("ENOTSOCK") end
|
||||
local sock = sockets[fdobj.sockId]
|
||||
if not sock then error("EBADF") end
|
||||
if sock.bound then error("EINVAL") end
|
||||
|
||||
local parsed = parseAddress(address)
|
||||
|
||||
if parsed.backend == "unix" then
|
||||
local existing = unixSocks[parsed.path]
|
||||
if existing then
|
||||
if existing.state == "closed" then
|
||||
unixSocks[parsed.path] = nil
|
||||
else
|
||||
error("EADDRINUSE")
|
||||
end
|
||||
end
|
||||
sock.backend = "unix"
|
||||
sock.address = parsed
|
||||
sock.bound = true
|
||||
sock.state = "bound"
|
||||
unixSocks[parsed.path] = sock
|
||||
|
||||
elseif parsed.backend == "rednet" then
|
||||
ensureRednet()
|
||||
sock.backend = "rednet"
|
||||
sock.address = parsed
|
||||
sock.bound = true
|
||||
sock.state = "bound"
|
||||
|
||||
elseif parsed.backend == "modem" then
|
||||
local modem, side = getModem()
|
||||
sock.backend = "modem"
|
||||
sock.address = parsed
|
||||
sock.modem = modem
|
||||
sock.modemChannel = parsed.port
|
||||
sock.bound = true
|
||||
sock.state = "bound"
|
||||
modem.open(parsed.port)
|
||||
|
||||
else
|
||||
error("EOPNOTSUPP: cannot bind to " .. parsed.backend .. " address")
|
||||
end
|
||||
end
|
||||
|
||||
kernel.syscalls["listen"] = function(fd, backlog)
|
||||
local task = kernel.currentTask
|
||||
local fdobj = task.fd[fd]
|
||||
if not fdobj or not fdobj.isSocket then error("ENOTSOCK") end
|
||||
local sock = sockets[fdobj.sockId]
|
||||
if not sock then error("EBADF") end
|
||||
if not sock.bound then error("EDESTADDRREQ") end
|
||||
sock.listening = true
|
||||
sock.state = "listening"
|
||||
sock.maxBacklog = backlog or 5
|
||||
end
|
||||
|
||||
kernel.syscalls["accept"] = function(fd)
|
||||
local task = kernel.currentTask
|
||||
local fdobj = task.fd[fd]
|
||||
if not fdobj or not fdobj.isSocket then error("ENOTSOCK") end
|
||||
local sock = sockets[fdobj.sockId]
|
||||
if not sock then error("EBADF") end
|
||||
if not sock.listening then error("EINVAL") end
|
||||
|
||||
local deadline = kernel.computer:time() + 30000
|
||||
while #sock.backlog == 0 do
|
||||
pumpAll()
|
||||
if kernel.computer:time() > deadline then error("ETIMEDOUT") end
|
||||
coroutine.yield()
|
||||
end
|
||||
|
||||
local clientSock = table.remove(sock.backlog, 1)
|
||||
local cfdobj = socketToFd(clientSock)
|
||||
local newfd = kernel.vfs.newfd(cfdobj)
|
||||
return newfd
|
||||
end
|
||||
|
||||
kernel.syscalls["connect"] = function(fd, address)
|
||||
local task = kernel.currentTask
|
||||
local fdobj = task.fd[fd]
|
||||
if not fdobj or not fdobj.isSocket then error("ENOTSOCK") end
|
||||
local sock = sockets[fdobj.sockId]
|
||||
if not sock then error("EBADF") end
|
||||
if sock.connected then error("EISCONN") end
|
||||
|
||||
local parsed = parseAddress(address)
|
||||
sock.address = parsed
|
||||
sock.backend = parsed.backend
|
||||
|
||||
if parsed.backend == "unix" then
|
||||
local server = unixSocks[parsed.path]
|
||||
if not server then error("ECONNREFUSED") end
|
||||
if not server.listening then error("ECONNREFUSED") end
|
||||
if #server.backlog >= (server.maxBacklog or 5) then error("ECONNREFUSED") end
|
||||
|
||||
local serverPeer = newSocket("unix", sock.socktype)
|
||||
serverPeer.backend = "unix"
|
||||
serverPeer.connected = true
|
||||
serverPeer.state = "connected"
|
||||
serverPeer.peer = sock
|
||||
|
||||
sock.peer = serverPeer
|
||||
sock.connected = true
|
||||
sock.state = "connected"
|
||||
|
||||
table.insert(server.backlog, serverPeer)
|
||||
|
||||
elseif parsed.backend == "rednet" then
|
||||
ensureRednet()
|
||||
sock.connected = true
|
||||
sock.state = "connected"
|
||||
|
||||
elseif parsed.backend == "modem" then
|
||||
local modem, side = getModem()
|
||||
local replyChannel = math.random(1024, 65534)
|
||||
sock.modem = modem
|
||||
sock.modemChannel = replyChannel
|
||||
sock.connected = true
|
||||
sock.state = "connected"
|
||||
modem.open(replyChannel)
|
||||
|
||||
elseif parsed.backend == "http" or parsed.backend == "https" then
|
||||
sock.connected = true
|
||||
sock.state = "connected"
|
||||
|
||||
else
|
||||
error("EAFNOSUPPORT")
|
||||
end
|
||||
end
|
||||
|
||||
kernel.syscalls["send"] = function(fd, data)
|
||||
local task = kernel.currentTask
|
||||
local fdobj = task.fd[fd]
|
||||
if not fdobj or not fdobj.isSocket then error("ENOTSOCK") end
|
||||
local sock = sockets[fdobj.sockId]
|
||||
if not sock then error("EBADF") end
|
||||
return sockSend(sock, data)
|
||||
end
|
||||
|
||||
kernel.syscalls["recv"] = function(fd, maxlen, timeout_ms)
|
||||
local task = kernel.currentTask
|
||||
local fdobj = task.fd[fd]
|
||||
if not fdobj or not fdobj.isSocket then error("ENOTSOCK") end
|
||||
local sock = sockets[fdobj.sockId]
|
||||
if not sock then error("EBADF") end
|
||||
|
||||
local deadline = kernel.computer:time() + (timeout_ms or 10000)
|
||||
while #sock.rxbuf == 0 do
|
||||
pumpAll()
|
||||
if #sock.rxbuf > 0 then break end
|
||||
if sock.state == "closed" or sock.error then
|
||||
if sock.error then error("ECONNRESET: " .. tostring(sock.error)) end
|
||||
return ""
|
||||
end
|
||||
if kernel.computer:time() > deadline then return "" end
|
||||
coroutine.yield()
|
||||
end
|
||||
|
||||
local item = table.remove(sock.rxbuf, 1)
|
||||
local data = type(item) == "table" and (item.data or "") or tostring(item)
|
||||
if maxlen and #data > maxlen then
|
||||
table.insert(sock.rxbuf, 1, { data=data:sub(maxlen+1), from=item and item.from })
|
||||
data = data:sub(1, maxlen)
|
||||
end
|
||||
return data
|
||||
end
|
||||
|
||||
kernel.syscalls["sockshutdown"] = function(fd)
|
||||
local task = kernel.currentTask
|
||||
local fdobj = task.fd[fd]
|
||||
if not fdobj or not fdobj.isSocket then error("ENOTSOCK") end
|
||||
local sock = sockets[fdobj.sockId]
|
||||
if sock then sockClose(sock) end
|
||||
end
|
||||
|
||||
kernel.syscalls["getpeername"] = function(fd)
|
||||
local task = kernel.currentTask
|
||||
local fdobj = task.fd[fd]
|
||||
if not fdobj or not fdobj.isSocket then error("ENOTSOCK") end
|
||||
local sock = sockets[fdobj.sockId]
|
||||
if not sock or not sock.connected then error("ENOTCONN") end
|
||||
if sock.address then return sock.address end
|
||||
return nil
|
||||
end
|
||||
|
||||
kernel.syscalls["getsockname"] = function(fd)
|
||||
local task = kernel.currentTask
|
||||
local fdobj = task.fd[fd]
|
||||
if not fdobj or not fdobj.isSocket then error("ENOTSOCK") end
|
||||
local sock = sockets[fdobj.sockId]
|
||||
if not sock then error("EBADF") end
|
||||
return sock.address
|
||||
end
|
||||
|
||||
kernel.syscalls["httpget"] = function(url, headers)
|
||||
local http = kernel.apis and kernel.apis.http
|
||||
if not http then error("ENODEV: no http API") end
|
||||
|
||||
local ok, err = pcall(http.request, url, nil, headers)
|
||||
if not ok then error("ENETDOWN: " .. tostring(err)) end
|
||||
|
||||
local deadline = kernel.computer:time() + 15000
|
||||
while true do
|
||||
local ev = pollEvent()
|
||||
if ev then
|
||||
if ev[1] == "http_success" and ev[2] == url then
|
||||
local handle = ev[3]
|
||||
local body = handle.readAll and handle.readAll() or ""
|
||||
handle.close()
|
||||
return body
|
||||
elseif ev[1] == "http_failure" and ev[2] == url then
|
||||
error("ECONNREFUSED: " .. tostring(ev[3]))
|
||||
else
|
||||
dispatchEvent(ev)
|
||||
end
|
||||
end
|
||||
if kernel.computer:time() > deadline then error("ETIMEDOUT") end
|
||||
coroutine.yield()
|
||||
end
|
||||
end
|
||||
|
||||
kernel.syscalls["resolve"] = function(hostname)
|
||||
if hostname:match("^%d+%.%d+%.%d+%.%d+$") then return hostname end
|
||||
|
||||
local a,b,c,d = hostname:match("^(%d+)%.(%d+)%.(%d+)%.(%d+)$")
|
||||
if a and tonumber(a) == 0 and tonumber(b) == 0 then
|
||||
return hostname
|
||||
end
|
||||
|
||||
local http = kernel.apis and kernel.apis.http
|
||||
if not http then error("ENODEV: no http API for DNS") end
|
||||
|
||||
local url = "https://cloudflare-dns.com/dns-query?name=" .. hostname .. "&type=A"
|
||||
local body = kernel.syscalls["httpget"](url, {
|
||||
["Accept"] = "application/dns-json"
|
||||
})
|
||||
|
||||
local ip = body:match('"type":1[^}]*"data":"([%d%.]+)"')
|
||||
if not ip then error("ENOENT: could not resolve " .. hostname) end
|
||||
return ip
|
||||
end
|
||||
|
||||
kernel.sockets = sockets
|
||||
kernel.unixSockets = unixSocks
|
||||
|
||||
kernel.log("Loaded socket module")
|
||||
Reference in New Issue
Block a user