557 lines
18 KiB
Plaintext
557 lines
18 KiB
Plaintext
--:Minify:--
|
|
-- Supports:
|
|
-- AF_UNIX - local IPC via /var/run/*.sock paths
|
|
-- AF_INET - network sockets with three backends:
|
|
-- rednet://0.0.B.C or rednet+PROTO://0.0.B.C -> CC rednet (computer B*256+C)
|
|
-- modem://0.0.B.C -> raw CC modem frames
|
|
-- http://host/path or https://... -> HTTP via CC http API
|
|
-- A.B.C.D (dotted quad, non-zero A) -> HTTP
|
|
--
|
|
-- Socket lifecycle:
|
|
-- fd = syscall.socket(domain, socktype) -- "unix"/"inet", "stream"/"dgram"
|
|
-- syscall.bind(fd, address) -- server: claim address
|
|
-- syscall.listen(fd, backlog) -- server: mark as listening
|
|
-- cfd = syscall.accept(fd) -- server: get connected client fd (blocking poll)
|
|
-- syscall.connect(fd, address) -- client: connect to server
|
|
-- syscall.send(fd, data) -- send bytes
|
|
-- syscall.recv(fd, len[, timeout_ms]) -- receive bytes (polls until data, close or timeout; returns "" if nothing arrived)
|
|
-- syscall.sockshutdown(fd) -- shut the socket down (currently a full close via sockClose, not a half-close)
|
|
-- -- normal vfs.close(fd) closes the socket
|
|
|
|
-- Kernel handle: the first argument this chunk is loaded with.
local kernel = ...

-- Module-wide state:
--   sockets    every live socket record, keyed by socket id
--   unixSocks  bound AF_UNIX sockets, keyed by path
--   nextSockId the id the next allocation will hand out
local sockets, unixSocks, nextSockId = {}, {}, 1
|
|
|
|
--- Hand out the next unused socket id.
-- Each call returns the current value of `nextSockId` and then advances the
-- counter by one, so ids are never reused within a session.
-- @treturn number the freshly allocated id
local function allocSockId()
  local allocated = nextSockId
  nextSockId = allocated + 1
  return allocated
end
|
|
|
|
--- Parse a socket address string into a backend descriptor table.
--
-- Recognised forms and their results:
--   "/path" or "unix:/path"         -> { backend="unix", path=... }
--   "rednet://A.B.C.D"              -> { backend="rednet", compId=C*256+D, protocol="hyperion" }
--   "rednet+PROTO://A.B.C.D"        -> same, with protocol=PROTO
--   "modem://A.B.C.D[:PORT]"        -> { backend="modem", compId=C*256+D, port=PORT or 0 }
--   "http://..." / "https://..."    -> { backend="http"|"https", url=addr }
--   "A.B.C.D..." with non-zero A    -> { backend="http", url="http://" .. addr }
--
-- @tparam string addr address to parse
-- @treturn table backend descriptor as listed above
-- @raise "EINVAL" (optionally with detail) when addr is not a string or does
--        not match any known form
local function parseAddress(addr)
  if type(addr) ~= "string" then error("EINVAL") end

  -- AF_UNIX: absolute path, optionally prefixed with "unix:".
  if addr:sub(1, 5) == "unix:" then
    return { backend = "unix", path = addr:sub(6) }
  elseif addr:sub(1, 1) == "/" then
    return { backend = "unix", path = addr }
  end

  -- rednet://A.B.C.D or rednet+PROTO://A.B.C.D
  local rproto, raddr = addr:match("^rednet%+?([^:/]*)://(.+)$")
  if raddr then
    local c, d = raddr:match("^%d+%.%d+%.(%d+)%.(%d+)$")
    if not c then error("EINVAL: bad rednet address " .. raddr) end
    return {
      backend = "rednet",
      compId = tonumber(c) * 256 + tonumber(d),
      protocol = (rproto ~= "" and rproto or "hyperion"),
    }
  end

  -- modem://A.B.C.D or modem://A.B.C.D:PORT
  local maddr = addr:match("^modem://(.+)$")
  if maddr then
    -- The quad used to be matched with a trailing "$", which rejected every
    -- address carrying a ":PORT" suffix and left the port stuck at 0. Capture
    -- whatever follows the quad and validate it separately instead.
    local c, d, suffix = maddr:match("^%d+%.%d+%.(%d+)%.(%d+)(.*)$")
    local port = 0
    if c and suffix ~= "" then
      port = tonumber(suffix:match("^:(%d+)$"))
    end
    -- Modem channels are limited to 0..65535.
    if not c or not port or port > 65535 then
      error("EINVAL: bad modem address " .. maddr)
    end
    return {
      backend = "modem",
      compId = tonumber(c) * 256 + tonumber(d),
      port = port,
    }
  end

  -- Explicit http:// or https:// URL: passed through untouched.
  local scheme = addr:match("^(https?)://.+$")
  if scheme then
    return { backend = scheme, url = addr }
  end

  -- Bare dotted quad with a non-zero first octet (a path/port may follow).
  local first = addr:match("^(%d+)%.%d+%.%d+%.%d+")
  if first and tonumber(first) ~= 0 then
    return { backend = "http", url = "http://" .. addr }
  end

  error("EINVAL: unrecognised address format: " .. tostring(addr))
end
|
|
|
|
-- Latched once rednet has been opened on at least one modem (or once it is
-- known that no peripheral API exists to enumerate modems with).
local rednetOpen = false

--- Open rednet on every attached modem, once.
-- Individual rednet.open failures are tolerated (best effort). The latch is
-- only set when at least one modem was actually opened, so a modem attached
-- after an earlier, fruitless call is still picked up on the next call.
-- @raise "ENODEV: no rednet API available" when kernel.apis.rednet is missing
local function ensureRednet()
  if rednetOpen then return end

  local rn = kernel.apis and kernel.apis.rednet
  if not rn then error("ENODEV: no rednet API available") end

  local peripheral = kernel.apis.peripheral
  if not peripheral then
    -- Without a peripheral API there is nothing to enumerate, now or later.
    rednetOpen = true
    return
  end

  local names = peripheral.getNames and peripheral.getNames() or {}
  local opened = false
  for _, name in ipairs(names) do
    if peripheral.getType(name) == "modem" and pcall(rn.open, name) then
      opened = true
    end
  end
  rednetOpen = opened
end
|
|
|
|
--- Locate the first attached modem peripheral that can be wrapped.
-- @treturn table the wrapped modem
-- @treturn string the peripheral name it was wrapped from
-- @raise "ENODEV" when no peripheral API is available, or
--        "ENODEV: no modem peripheral found" when no modem could be wrapped
local function getModem()
  local apis = kernel.apis
  local peripheral = apis and apis.peripheral
  if not peripheral then error("ENODEV") end

  local names = peripheral.getNames and peripheral.getNames() or {}
  for _, name in ipairs(names) do
    -- `wrapped` is false for non-modems and nil when wrapping fails.
    local wrapped = peripheral.getType(name) == "modem" and peripheral.wrap(name)
    if wrapped then return wrapped, name end
  end

  error("ENODEV: no modem peripheral found")
end
|
|
|
|
--- Drain the machine event queue, delivering rednet messages to bound sockets.
--
-- NOTE(review): no caller of this function is visible in this file; the
-- syscalls use pumpAll() instead. It is kept for compatibility.
--
-- Previously only the first event field was kept, so the protocol lookup
-- evaluated `tostring()` with no argument (a runtime error) as soon as a
-- rednet message arrived while a rednet socket was bound, and the matching
-- branch was empty, so every drained event was silently discarded. The whole
-- event tuple is now captured and matching messages are queued on the socket.
--
-- A bound rednet socket receives a message when its protocol equals the
-- message protocol (compared as a string), or when the socket uses the
-- "hyperion" protocol. Events of any other type are consumed and dropped.
local function pumpEvents()
  local ev = table.pack(kernel.computer:getMachineEvent())
  while ev[1] do
    if ev[1] == "rednet_message" then
      local senderId, message = ev[2], ev[3]
      local protocol = tostring(ev[4])
      for _, sock in pairs(sockets) do
        if sock.backend == "rednet" and sock.bound then
          local wanted = sock.address.protocol
          if wanted == protocol or wanted == "hyperion" then
            sock.rxbuf[#sock.rxbuf + 1] = { from = senderId, data = message }
          end
        end
      end
    end
    ev = table.pack(kernel.computer:getMachineEvent())
  end
end
|
|
|
|
--- Fetch one machine event without blocking.
-- @treturn table|nil the event fields packed with table.pack (so `.n` holds
--          the field count), or nil when no event was returned
local function pollEvent()
  local packed = table.pack(kernel.computer:getMachineEvent())
  if packed.n ~= 0 and packed[1] ~= nil then
    return packed
  end
  return nil
end
|
|
|
|
--- Route one packed machine event to the sockets that are waiting for it.
--
-- Handled event types (anything else is ignored):
--   rednet_message: ev[2] sender, ev[3] message, ev[4] protocol (defaults to
--       "hyperion"); queued on listening/connected rednet sockets whose
--       address protocol matches.
--   modem_message:  ev[3] is matched against sock.modemChannel; ev[4] is
--       stored as `from` and ev[5] as `data`.
--   http_success:   ev[2] url, ev[3] response handle; the body is queued on
--       every http/https socket whose pendingUrl equals the url.
--   http_failure:   ev[2] url, ev[3] error; recorded as sock.error on every
--       http/https socket whose pendingUrl equals the url.
--
-- @tparam table|nil ev event as returned by pollEvent(); nil is a no-op
local function dispatchEvent(ev)
  if not ev then return end
  local evtype = ev[1]

  if evtype == "rednet_message" then
    local senderId, message = ev[2], ev[3]
    local protocol = ev[4] or "hyperion"
    for _, sock in pairs(sockets) do
      if sock.backend == "rednet" and (sock.listening or sock.connected)
          and sock.address and sock.address.protocol == protocol then
        table.insert(sock.rxbuf, { from = senderId, data = message })
      end
    end

  elseif evtype == "modem_message" then
    local channel, fromCh, msg = ev[3], ev[4], ev[5]
    for _, sock in pairs(sockets) do
      if sock.backend == "modem" and sock.modemChannel == channel then
        table.insert(sock.rxbuf, { from = fromCh, data = msg })
      end
    end

  elseif evtype == "http_success" then
    local url, handle = ev[2], ev[3]
    -- Read and close the handle exactly once, before fanning the body out.
    -- Doing it per matching socket re-read an already closed handle when two
    -- sockets waited on the same URL, and leaked the handle when none did.
    local body = ""
    if handle then
      body = handle.readAll and handle.readAll() or ""
      if handle.close then handle.close() end
    end
    for _, sock in pairs(sockets) do
      if (sock.backend == "http" or sock.backend == "https")
          and sock.pendingUrl == url then
        table.insert(sock.rxbuf, { data = body, done = true })
        sock.pendingUrl = nil
        sock.connected = true
      end
    end

  elseif evtype == "http_failure" then
    local url, err = ev[2], ev[3]
    for _, sock in pairs(sockets) do
      if (sock.backend == "http" or sock.backend == "https")
          and sock.pendingUrl == url then
        sock.error = err
        sock.pendingUrl = nil
      end
    end
  end
end
|
|
|
|
--- Drain every pending machine event and dispatch each one to the sockets.
-- Returns as soon as pollEvent() reports that the queue is empty.
local function pumpAll()
  while true do
    local ev = pollEvent()
    if not ev then return end
    dispatchEvent(ev)
  end
end
|
|
|
|
--- Create a socket record and register it in `sockets` under a fresh id.
-- @tparam string domain   "unix" | "inet"
-- @tparam string socktype "stream" | "dgram"
-- @treturn table the new socket record
local function newSocket(domain, socktype)
  local sock = {}

  sock.id = allocSockId()
  sock.domain = domain
  sock.socktype = socktype
  sock.state = "idle" -- idle | bound | listening | connected | closed

  -- Queues: received items, pending transmit data, un-accepted connections.
  sock.rxbuf = {}
  sock.txbuf = {}
  sock.backlog = {}

  sock.bound = false
  sock.listening = false
  sock.connected = false

  -- Fields that start out absent (nil) and are filled in by bind/connect or
  -- the event dispatcher: backend, address, peer, modemChannel, modem,
  -- pendingUrl, error.

  sockets[sock.id] = sock
  return sock
end
|
|
|
|
-- Forward declarations: the descriptor handle built by socketToFd() calls
-- these, but their bodies are assigned further down in the file.
local sockSend
local sockClose
|
|
|
|
--- Wrap a socket record in a file-descriptor object.
-- The returned table carries the socket id plus a `handle` with `read`,
-- `write` and `close` closures bound to `sock`.
-- @tparam table sock socket record
-- @treturn table fd object (isSocket, sockId, mode, meta, type, refcount, handle)
local function socketToFd(sock)
  local handle = {}

  -- Non-blocking read: pump events once, then pop at most one queued item.
  -- When `count` is given and the item is longer, the remainder is pushed
  -- back to the front of the queue. Returns "" when nothing is queued.
  function handle.read(count)
    pumpAll()
    if #sock.rxbuf == 0 then return "" end

    local item = table.remove(sock.rxbuf, 1)
    local data
    if type(item) == "table" then
      data = item.data or ""
    else
      data = tostring(item)
    end

    if count and #data > count then
      local rest = { data = data:sub(count + 1), from = item.from }
      table.insert(sock.rxbuf, 1, rest)
      return data:sub(1, count)
    end
    return data
  end

  -- Send through the socket's backend; refuses once the socket is closed.
  function handle.write(data)
    if sock.state == "closed" then error("EBADF") end
    return sockSend(sock, data)
  end

  function handle.close()
    sockClose(sock)
  end

  local fdobj = {
    isSocket = true,
    sockId = sock.id,
    mode = "rw",
    meta = { etype = 0, owner = 0, group = 0, perms = 0x1FF, cmeta = "" },
    type = "socket",
    refcount = 1,
    handle = handle,
  }
  return fdobj
end
|
|
|
|
--- Transmit `data` through whichever backend the socket uses.
--   unix:       appended to the peer's receive queue
--   rednet:     rednet.send to the address' computer id and protocol
--   modem:      modem.transmit on the address port, reply channel
--               sock.modemChannel (or 0)
--   http/https: asynchronous http.request with `data` as the body; the URL is
--               remembered in sock.pendingUrl until the response event arrives
-- @tparam table sock socket record
-- @param data payload
-- @treturn number #data
-- @raise "ENOTCONN", "ENODEV: no http API", "ENETDOWN: ..." or
--        "EPROTONOSUPPORT" (socket has no known backend)
sockSend = function(sock, data)
  local backend = sock.backend

  if backend == "unix" then
    local peer = sock.peer
    if not peer then error("ENOTCONN") end
    table.insert(peer.rxbuf, { data = data })
    return #data
  end

  if backend == "rednet" then
    ensureRednet()
    kernel.apis.rednet.send(sock.address.compId, data, sock.address.protocol)
    return #data
  end

  if backend == "modem" then
    local modem = sock.modem
    if not modem then error("ENOTCONN") end
    modem.transmit(sock.address.port, sock.modemChannel or 0, data)
    return #data
  end

  if backend == "http" or backend == "https" then
    local http = kernel.apis and kernel.apis.http
    if not http then error("ENODEV: no http API") end

    local url = sock.address.url
    local headers = { ["Content-Type"] = "application/octet-stream" }
    local ok, err = pcall(http.request, url, data, headers)
    if not ok then error("ENETDOWN: " .. tostring(err)) end

    sock.pendingUrl = url
    return #data
  end

  error("EPROTONOSUPPORT")
end
|
|
|
|
--- Tear a socket down and remove it from the registry.
--
-- Effects:
--   * marks the socket "closed";
--   * closes every connection still waiting in its accept backlog, so the
--     connecting side sees the close instead of waiting for a timeout;
--   * unix: detaches the peer and flags it "closed", and releases the bound
--     path if this socket is the one registered for it;
--   * modem: closes the socket's modem channel (best effort);
--   * removes the record from `sockets`.
--
-- Calling it again on a socket that is already closed and unregistered is a
-- no-op.
-- @tparam table sock socket record
sockClose = function(sock)
  -- A unix socket can be flagged "closed" by its peer while its own
  -- descriptor is still open. Returning early on the state alone meant such a
  -- record was never removed from `sockets`; only skip the teardown once the
  -- registry entry is gone as well.
  if sock.state == "closed" and sockets[sock.id] ~= sock then return end
  sock.state = "closed"

  -- Connections that were queued but never accepted have no descriptor of
  -- their own, so nobody else would ever close them.
  local backlog = sock.backlog
  if backlog then
    for i = #backlog, 1, -1 do
      sockClose(table.remove(backlog, i))
    end
  end

  if sock.backend == "unix" then
    local peer = sock.peer
    if peer then
      peer.peer = nil
      peer.state = "closed"
    end
    local path = sock.bound and sock.address and sock.address.path
    -- Only release the path if this socket owns the registration; the
    -- socket's address may name a path that another socket is bound to.
    if path and unixSocks[path] == sock then
      unixSocks[path] = nil
    end

  elseif sock.backend == "modem" and sock.modem and sock.modemChannel then
    pcall(sock.modem.close, sock.modemChannel)
  end
  -- rednet sockets hold no per-socket resource that needs releasing here.

  sockets[sock.id] = nil
end
|
|
|
|
--- syscall socket(domain, socktype): create a socket and return its fd.
-- @tparam[opt="inet"] string domain     "unix" | "inet"
-- @tparam[opt="stream"] string socktype "stream" | "dgram"
-- @return the descriptor produced by kernel.vfs.newfd
-- @raise "EAFNOSUPPORT" for an unknown domain, "EPROTOTYPE" for an unknown type
kernel.syscalls["socket"] = function(domain, socktype)
  local dom = domain or "inet"
  local kind = socktype or "stream"

  if not (dom == "unix" or dom == "inet") then error("EAFNOSUPPORT") end
  if not (kind == "stream" or kind == "dgram") then error("EPROTOTYPE") end

  local fdobj = socketToFd(newSocket(dom, kind))
  -- Parenthesised so that exactly one value is returned.
  return (kernel.vfs.newfd(fdobj))
end
|
|
|
|
--- syscall bind(fd, address): claim a local address for a socket.
--
-- Supported address kinds (see parseAddress):
--   unix   registers the socket under the path in `unixSocks`; a stale entry
--          whose socket is already closed is replaced
--   rednet makes sure rednet is open
--   modem  opens the address' port on the first available modem
-- Any other backend is rejected.
--
-- The socket record is only modified after every fallible step succeeded, so
-- a failed bind (e.g. modem.open raising) leaves the socket unbound and free
-- to be bound again, instead of permanently marked as bound.
--
-- @param fd descriptor of a socket owned by the current task
-- @tparam string address address to bind to
-- @raise "ENOTSOCK", "EBADF", "EINVAL" (already bound), "EADDRINUSE",
--        "EOPNOTSUPP: ...", plus anything parseAddress/getModem raise
kernel.syscalls["bind"] = function(fd, address)
  local fdobj = kernel.currentTask.fd[fd]
  if not fdobj or not fdobj.isSocket then error("ENOTSOCK") end
  local sock = sockets[fdobj.sockId]
  if not sock then error("EBADF") end
  if sock.bound then error("EINVAL") end

  local parsed = parseAddress(address)
  local backend = parsed.backend

  if backend == "unix" then
    local existing = unixSocks[parsed.path]
    if existing and existing.state ~= "closed" then
      error("EADDRINUSE")
    end
    unixSocks[parsed.path] = sock

  elseif backend == "rednet" then
    ensureRednet()

  elseif backend == "modem" then
    local modem = getModem()
    -- Open the channel first: if this raises, nothing has been committed yet.
    modem.open(parsed.port)
    sock.modem = modem
    sock.modemChannel = parsed.port

  else
    error("EOPNOTSUPP: cannot bind to " .. parsed.backend .. " address")
  end

  sock.backend = backend
  sock.address = parsed
  sock.bound = true
  sock.state = "bound"
end
|
|
|
|
--- syscall listen(fd, backlog): mark a bound socket as accepting connections.
-- @param fd descriptor of a socket owned by the current task
-- @tparam[opt=5] number backlog maximum number of un-accepted connections
-- @raise "ENOTSOCK", "EBADF", or "EDESTADDRREQ" when the socket is not bound
kernel.syscalls["listen"] = function(fd, backlog)
  local fdobj = kernel.currentTask.fd[fd]
  if not (fdobj and fdobj.isSocket) then error("ENOTSOCK") end

  local sock = sockets[fdobj.sockId]
  if not sock then error("EBADF") end
  if not sock.bound then error("EDESTADDRREQ") end

  sock.maxBacklog = backlog or 5
  sock.state = "listening"
  sock.listening = true
end
|
|
|
|
--- syscall accept(fd): wait for a queued connection and return its fd.
-- Polls the backlog, pumping events and yielding between checks, for up to
-- 30000 units of kernel.computer:time() (presumably milliseconds — confirm).
-- @param fd descriptor of a listening socket owned by the current task
-- @return descriptor for the accepted connection, from kernel.vfs.newfd
-- @raise "ENOTSOCK", "EBADF", "EINVAL" (not listening) or "ETIMEDOUT"
kernel.syscalls["accept"] = function(fd)
  local fdobj = kernel.currentTask.fd[fd]
  if not (fdobj and fdobj.isSocket) then error("ENOTSOCK") end

  local sock = sockets[fdobj.sockId]
  if not sock then error("EBADF") end
  if not sock.listening then error("EINVAL") end

  local deadline = kernel.computer:time() + 30000
  while true do
    if #sock.backlog ~= 0 then break end
    pumpAll()
    if kernel.computer:time() > deadline then error("ETIMEDOUT") end
    coroutine.yield()
  end

  local accepted = table.remove(sock.backlog, 1)
  -- Parenthesised so that exactly one value is returned.
  return (kernel.vfs.newfd(socketToFd(accepted)))
end
|
|
|
|
--- syscall connect(fd, address): connect a socket to a remote address.
--
-- Per backend (see parseAddress):
--   unix        the path must belong to a live, listening socket with room in
--               its backlog; a server-side peer socket is created, linked to
--               this socket and queued on the server's backlog
--   rednet      makes sure rednet is open
--   modem       opens a random reply channel in 1024..65534 on the first
--               available modem
--   http/https  nothing to set up; requests are issued on send
--
-- The socket's address/backend/connected/state fields are only updated after
-- every step that can fail has succeeded, so a refused or failed connect
-- leaves the socket exactly as it was and it can be retried.
--
-- @param fd descriptor of a socket owned by the current task
-- @tparam string address address to connect to
-- @raise "ENOTSOCK", "EBADF", "EISCONN", "ECONNREFUSED", "EAFNOSUPPORT",
--        plus anything parseAddress/ensureRednet/getModem raise
kernel.syscalls["connect"] = function(fd, address)
  local fdobj = kernel.currentTask.fd[fd]
  if not fdobj or not fdobj.isSocket then error("ENOTSOCK") end
  local sock = sockets[fdobj.sockId]
  if not sock then error("EBADF") end
  if sock.connected then error("EISCONN") end

  local parsed = parseAddress(address)
  local backend = parsed.backend

  if backend == "unix" then
    local server = unixSocks[parsed.path]
    if not server or server.state == "closed" or not server.listening then
      error("ECONNREFUSED")
    end
    if #server.backlog >= (server.maxBacklog or 5) then
      error("ECONNREFUSED")
    end

    -- Server-side end of the connection; accept() hands it to the server.
    local serverPeer = newSocket("unix", sock.socktype)
    serverPeer.backend = "unix"
    serverPeer.connected = true
    serverPeer.state = "connected"
    serverPeer.peer = sock

    sock.peer = serverPeer
    table.insert(server.backlog, serverPeer)

  elseif backend == "rednet" then
    ensureRednet()

  elseif backend == "modem" then
    local modem = getModem()
    local replyChannel = math.random(1024, 65534)
    -- Open before committing, so a failure leaves the socket untouched.
    modem.open(replyChannel)
    sock.modem = modem
    sock.modemChannel = replyChannel

  elseif backend == "http" or backend == "https" then
    -- No handshake: the request itself is sent by sockSend.

  else
    error("EAFNOSUPPORT")
  end

  sock.address = parsed
  sock.backend = backend
  sock.connected = true
  sock.state = "connected"
end
|
|
|
|
--- syscall send(fd, data): send data through a socket.
-- @param fd descriptor of a socket owned by the current task
-- @param data payload handed to sockSend
-- @return whatever sockSend returns (#data)
-- @raise "ENOTSOCK", "EBADF", plus anything sockSend raises
kernel.syscalls["send"] = function(fd, data)
  local fdobj = kernel.currentTask.fd[fd]
  if not (fdobj and fdobj.isSocket) then error("ENOTSOCK") end

  local sock = sockets[fdobj.sockId]
  if sock then
    return sockSend(sock, data)
  end
  error("EBADF")
end
|
|
|
|
--- syscall recv(fd, maxlen, timeout_ms): receive one queued chunk.
--
-- Pumps events and yields until something is queued, the socket is closed,
-- an error is pending, or the timeout elapses. When the queued item is longer
-- than `maxlen`, the remainder is pushed back to the front of the queue.
--
-- A pending socket error (set by an http_failure event) is reported once and
-- then cleared. It used to stay set forever, so after a single failed request
-- every later recv on a drained queue raised ECONNRESET again, even when
-- newer requests had succeeded.
--
-- @param fd descriptor of a socket owned by the current task
-- @tparam[opt] number maxlen maximum length to return
-- @tparam[opt=10000] number timeout_ms how long to wait, in the units of
--        kernel.computer:time()
-- @return the received data, or "" on timeout or when the socket is closed
-- @raise "ENOTSOCK", "EBADF", or "ECONNRESET: <reason>"
kernel.syscalls["recv"] = function(fd, maxlen, timeout_ms)
  local fdobj = kernel.currentTask.fd[fd]
  if not fdobj or not fdobj.isSocket then error("ENOTSOCK") end
  local sock = sockets[fdobj.sockId]
  if not sock then error("EBADF") end

  local deadline = kernel.computer:time() + (timeout_ms or 10000)
  while #sock.rxbuf == 0 do
    pumpAll()
    if #sock.rxbuf > 0 then break end

    if sock.error then
      -- Report the failure once, then clear it so it cannot poison later calls.
      local reason = sock.error
      sock.error = nil
      error("ECONNRESET: " .. tostring(reason))
    end
    if sock.state == "closed" then return "" end
    if kernel.computer:time() > deadline then return "" end

    coroutine.yield()
  end

  local item = table.remove(sock.rxbuf, 1)
  local data = type(item) == "table" and (item.data or "") or tostring(item)
  if maxlen and #data > maxlen then
    -- Keep the unread tail at the front of the queue for the next call.
    table.insert(sock.rxbuf, 1, { data = data:sub(maxlen + 1), from = item and item.from })
    data = data:sub(1, maxlen)
  end
  return data
end
|
|
|
|
--- syscall sockshutdown(fd): shut a socket down.
-- This performs a full close through sockClose; it is a no-op when the
-- descriptor's socket is no longer registered.
-- @param fd descriptor of a socket owned by the current task
-- @raise "ENOTSOCK" when fd is not a socket descriptor
kernel.syscalls["sockshutdown"] = function(fd)
  local fdobj = kernel.currentTask.fd[fd]
  if not (fdobj and fdobj.isSocket) then error("ENOTSOCK") end

  local sock = sockets[fdobj.sockId]
  if not sock then return end
  sockClose(sock)
end
|
|
|
|
--- syscall getpeername(fd): return the address stored on a connected socket.
-- @param fd descriptor of a socket owned by the current task
-- @treturn table|nil the socket's parsed address table, or nil if it has none
-- @raise "ENOTSOCK", or "ENOTCONN" when the socket is unknown or not connected
kernel.syscalls["getpeername"] = function(fd)
  local fdobj = kernel.currentTask.fd[fd]
  if not (fdobj and fdobj.isSocket) then error("ENOTSOCK") end

  local sock = sockets[fdobj.sockId]
  if not (sock and sock.connected) then error("ENOTCONN") end

  return sock.address or nil
end
|
|
|
|
--- syscall getsockname(fd): return the address stored on a socket.
-- @param fd descriptor of a socket owned by the current task
-- @treturn table|nil the socket's parsed address table (nil before bind/connect)
-- @raise "ENOTSOCK" or "EBADF"
kernel.syscalls["getsockname"] = function(fd)
  local fdobj = kernel.currentTask.fd[fd]
  if not (fdobj and fdobj.isSocket) then error("ENOTSOCK") end

  local sock = sockets[fdobj.sockId]
  if sock then
    return sock.address
  end
  error("EBADF")
end
|
|
|
|
--- syscall httpget(url, headers): issue an HTTP request and wait for the body.
-- Starts an asynchronous http.request with no body, then polls one event per
-- scheduler turn: the matching http_success/http_failure event ends the wait,
-- every other event is forwarded to dispatchEvent. Gives up after 15000 units
-- of kernel.computer:time() (presumably milliseconds — confirm).
-- @tparam string url URL to request
-- @tparam[opt] table headers request headers
-- @treturn string response body ("" when the handle has no readAll)
-- @raise "ENODEV: no http API", "ENETDOWN: ...", "ECONNREFUSED: ..." or
--        "ETIMEDOUT"
kernel.syscalls["httpget"] = function(url, headers)
  local http = kernel.apis and kernel.apis.http
  if not http then error("ENODEV: no http API") end

  local started, reason = pcall(http.request, url, nil, headers)
  if not started then error("ENETDOWN: " .. tostring(reason)) end

  local deadline = kernel.computer:time() + 15000
  repeat
    local ev = pollEvent()
    if ev then
      local kind, evUrl = ev[1], ev[2]
      if kind == "http_success" and evUrl == url then
        local handle = ev[3]
        local body = handle.readAll and handle.readAll() or ""
        handle.close()
        return body
      elseif kind == "http_failure" and evUrl == url then
        error("ECONNREFUSED: " .. tostring(ev[3]))
      else
        -- Not ours: let the socket layer have it.
        dispatchEvent(ev)
      end
    end

    if kernel.computer:time() > deadline then error("ETIMEDOUT") end
    coroutine.yield()
  until false
end
|
|
|
|
--- syscall resolve(hostname): resolve a host name to a dotted-quad address.
--
-- A dotted quad is returned unchanged. Anything else is looked up through the
-- Cloudflare DNS-over-HTTPS JSON endpoint using the httpget syscall, and the
-- first A record found in the response is returned.
--
-- Changes from the previous version: the second dotted-quad check (for
-- addresses starting with 0.0.) could never run because the first check
-- already returned every dotted quad, so it was removed; the host name is now
-- validated and percent-encoded before being placed in the query string, so
-- characters such as "&" or "#" cannot alter the request.
--
-- @tparam string hostname name or dotted quad to resolve
-- @treturn string dotted-quad address
-- @raise "EINVAL" for a non-string or empty name, "ENODEV: no http API for DNS",
--        "ENOENT: could not resolve ...", plus anything httpget raises
kernel.syscalls["resolve"] = function(hostname)
  if type(hostname) ~= "string" or hostname == "" then error("EINVAL") end

  -- Already an address (this includes the 0.0.B.C rednet/modem form).
  if hostname:match("^%d+%.%d+%.%d+%.%d+$") then return hostname end

  local http = kernel.apis and kernel.apis.http
  if not http then error("ENODEV: no http API for DNS") end

  -- Percent-encode everything outside the URL "unreserved" set. Ordinary
  -- host names pass through unchanged.
  local encoded = hostname:gsub("[^%w%-%._~]", function(ch)
    return string.format("%%%02X", ch:byte())
  end)

  local url = "https://cloudflare-dns.com/dns-query?name=" .. encoded .. "&type=A"
  local body = kernel.syscalls["httpget"](url, {
    ["Accept"] = "application/dns-json"
  })

  -- First record object with "type":1 (A) that carries a "data" address.
  local ip = body:match('"type":1[^}]*"data":"([%d%.]+)"')
  if not ip then error("ENOENT: could not resolve " .. hostname) end
  return ip
end
|
|
|
|
-- Expose the registries on the kernel object: bound AF_UNIX sockets by path,
-- and every live socket by id.
kernel.unixSockets = unixSocks
kernel.sockets = sockets

kernel.log("Loaded socket module")
|