edtris/vendor/jeejah/jeejah.lua

358 lines
12 KiB
Lua
Raw Permalink Normal View History

2021-02-07 21:56:19 +00:00
local socket = require "socket"
local serpent = require "serpent"
local bencode = require "bencode"
-- Loader used for eval'd code: the 5.1 string loader when it exists,
-- otherwise the stock `load` (which accepts strings on 5.2+).
local load = loadstring or load
-- Seconds used for socket timeouts and the accept loop's sleep.
local timeout = 1e-3
-- Debug logger: plain `print` when the DEBUG environment variable is set,
-- otherwise a function that discards its arguments.
local d
if os.getenv("DEBUG") then
  d = print
else
  d = function(_) end
end
-- Build a pretty-printer around the output function `p`: each value is
-- serialized with serpent.block (depth and element count capped, function
-- bodies omitted) and the resulting text is handed to `p`.
-- The returned printer itself returns nothing.
local serpent_pp = function(p)
  local function pretty(x)
    local serpent_opts = {maxlevel=8,maxnum=64,nocode=true}
    local text = serpent.block(x, serpent_opts)
    p(text)
  end
  return pretty
end
-- Registered sessions, keyed by session id.
local sessions = {}
-- Turn `msg` into a reply to `old_msg`: copy over the request's session and
-- id and set the namespace marker. `msg` is modified in place and returned.
local response_for = function(old_msg, msg)
  -- certain implementations break when the ns field is empty; see
  -- https://gitlab.com/technomancy/jeejah/issues/5
  msg.ns = ">"
  msg.id = old_msg.id
  msg.session = old_msg.session
  return msg
end
-- Bencode `msg` and write it to `conn`.
-- The message is encoded exactly once; the same payload is logged (when
-- debugging is on) and sent, instead of encoding it twice.
local send = function(conn, msg)
  local payload = bencode.encode(msg)
  d("Sending", payload)
  conn:send(payload)
end
-- Build an `io.write` replacement bound to the request `msg`: its arguments
-- are joined with tab characters and sent to `conn` as an "out" reply.
local write_for = function(conn, msg)
  local function writer(...)
    local text = table.concat({...}, "\t")
    send(conn, response_for(msg, {out = text}))
  end
  return writer
end
-- Build a `print` replacement that writes through `write`.
-- Every argument is converted with tostring, the results are joined with
-- single spaces and terminated by a newline, so `p("a", 1)` writes "a 1\n".
-- The argument count comes from select("#", ...) rather than ipairs, so nil
-- arguments (leading, interior or trailing) print as "nil" instead of being
-- dropped or making table.concat raise.
local print_for = function(write)
  return function(...)
    local n = select("#", ...)
    local args = {...}
    local parts = {}
    for i = 1, n do
      parts[i] = tostring(args[i])
    end
    write(table.concat(parts, " ") .. "\n")
  end
end
-- Build an `io.read` replacement for the session named in `msg`.
-- It asks the client for input with a "need-input" status, then yields the
-- running coroutine until a "stdin" op stores text in the session's `input`
-- field. That input is consumed (cleared) and returned.
-- If the session does not exist (never registered, or closed while waiting)
-- nil is returned, like io.read at end of input. Without this check the
-- read would fail with "attempt to index a nil value".
local read_for = function(conn, msg)
  return function()
    send(conn, response_for(msg, {status={"need-input"}}))
    local session = sessions[msg.session]
    while(session and not session.input) do
      coroutine.yield()
      d("yielded")
      session = sessions[msg.session]
    end
    if(not session) then return nil end
    local input = session.input
    session.input = nil
    return input
  end
end
-- Build the environment table that sandboxed code runs in: `io.write` and
-- `print` are routed through `write`, then every entry of
-- `provided_sandbox` is copied in (overriding those defaults on conflict).
local sandbox_for = function(write, provided_sandbox)
  local env = {}
  env.io = {write = write}
  env.print = print_for(write)
  for key, value in pairs(provided_sandbox) do
    env[key] = value
  end
  return env
end
-- for stuff that's shared between eval and load_file
-- Run the compiled `chunk` for `session` and format what it returns.
-- With a session sandbox, the chunk's environment is set to it (setfenv, so
-- Lua 5.1/LuaJIT only). Otherwise the global print, io.write and io.read are
-- temporarily swapped for the session's versions and restored afterwards.
-- Each returned value is passed through `pp` (default: serpent_pp, which
-- prints the value through the session's print and returns nothing).
-- Returns:
--   * the single non-nil pp result as-is, or nil when pp produced none;
--   * several non-nil pp results joined with ", ";
--   * nil plus "<error>\n<traceback>" when the chunk raised.
local execute_chunk = function(session, chunk, pp)
  local old_write, old_print, old_read = io.write, print, io.read
  if(session.sandbox) then
    setfenv(chunk, session.sandbox)
    pp = pp or serpent_pp(session.sandbox.print)
  else
    _G.print = print_for(session.write)
    _G.io.write, _G.io.read = session.write, session.read
    pp = pp or serpent_pp(_G.print)
  end
  -- Keep an explicit count: `#` is unreliable once a chunk returns nils
  -- (e.g. `return 1, nil, 3`).
  local function pack(...) return {n = select("#", ...), ...} end
  local trace, err
  local result = pack(xpcall(chunk, function(e)
    trace = debug.traceback()
    err = e
  end))
  _G.print, _G.io.write, _G.io.read = old_print, old_write, old_read
  if(not result[1]) then
    -- tostring so non-string error values such as error({}) can be reported.
    local message = err and tostring(err) or "Unknown error"
    return nil, message .. "\n" .. (trace or "")
  end
  -- pp is called once per returned value, and once with nil when the chunk
  -- returned nothing. The default pp returns nothing, so only non-nil
  -- results are collected; concatenating nil used to crash on `return 1, 2`.
  local formatted = {}
  for i = 2, math.max(result.n, 2) do
    local s = pp(result[i])
    if(s ~= nil) then formatted[#formatted + 1] = s end
  end
  if(#formatted <= 1) then return formatted[1] end
  for i, s in ipairs(formatted) do formatted[i] = tostring(s) end
  return table.concat(formatted, ", ")
end
-- Compile and run `code` for `session`.
-- The code is first compiled as an expression (prefixed with "return ") so
-- its value can be reported. If that fails to compile, it is compiled again
-- as a plain statement. Returns execute_chunk's results, or nil plus
-- "Compilation error: ..." when neither form compiles.
local eval = function(session, code, pp)
  local as_expr, expr_err = load("return " .. code, "*socket*")
  if(as_expr or not expr_err) then
    return execute_chunk(session, as_expr, pp)
  end
  local as_stmt, stmt_err = load(code, "*socket*")
  if(not as_stmt) then
    return nil, "Compilation error: " .. (stmt_err or "unknown")
  end
  return execute_chunk(session, as_stmt, pp)
end
-- Compile `file` with `loader` (default: loadfile) and run it for `session`.
-- Returns execute_chunk's results, or nil plus
-- "Compilation error in <file>: <reason>" when compilation fails.
-- (The original message was missing a `..`, so `file ": "` tried to call
-- the file name and raised instead of returning the message.)
local load_file = function(session, file, loader)
  local chunk, err = (loader or loadfile)(file)
  if(not chunk) then
    return nil, "Compilation error in " .. tostring(file) .. ": "
      .. tostring(err or "unknown")
  end
  return execute_chunk(session, chunk)
end
-- Create a session under a random numeric id (as a string) and store it
-- in `sessions`. The session records the connection, a writer and print
-- bound to `msg`, an optional sandbox built from `provided_sandbox`, and an
-- empty `coros` list. Returns the "new-session" reply for `msg`.
local register_session = function(conn, msg, provided_sandbox)
  local id = tostring(math.random(999999999))
  local write = write_for(conn, msg)
  local sandbox = provided_sandbox
  if sandbox then
    sandbox = sandbox_for(write, provided_sandbox)
  end
  local session = {
    id = id,
    conn = conn,
    write = write,
    print = print_for(write),
    sandbox = sandbox,
    coros = {},
  }
  sessions[id] = session
  return response_for(msg, {["new-session"] = id, status = {"done"}})
end
-- Forget the session named in `msg` and acknowledge with status "done".
local unregister_session = function(msg)
  local id = msg.session
  sessions[id] = nil
  local reply = response_for(msg, {status={"done"}})
  return reply
end
-- Answer the "describe" op: list the built-in ops followed by any extra ops
-- registered in `handlers`.
-- Handlers overriding a built-in (e.g. the fennel eval/stdin handlers) are
-- not listed twice. Extra ops are sorted so the reply is deterministic.
-- (The original `for op in handlers do` called the table and always raised.)
local describe = function(msg, handlers)
  local ops = { "clone", "close", "describe", "eval", "load-file",
                "ls-sessions", "complete", "stdin", "interrupt" }
  local seen = {}
  for _, op in ipairs(ops) do seen[op] = true end
  local extra = {}
  for op in pairs(handlers or {}) do
    if(type(op) == "string" and not seen[op]) then
      seen[op] = true
      extra[#extra + 1] = op
    end
  end
  table.sort(extra)
  for _, op in ipairs(extra) do ops[#ops + 1] = op end
  return response_for(msg, {ops=ops, status={"done"}})
end
-- Return the session named by `msg.session`. If there is none, a new one
-- is registered (with `sandbox`, if given).
-- The session's `write` and `read` are rebound to this request on every
-- call, so output and input prompts are tagged with the current message id.
-- register_session returns the "new-session" *reply message*, not the
-- session, so the stored table is looked up by the new id. Returning the
-- reply made session-less evals ignore the sandbox and store
-- write/read on a throwaway table.
local session_for = function(conn, msg, sandbox)
  local s = sessions[msg.session]
  if(not s) then
    local reply = register_session(conn, msg, sandbox)
    s = sessions[reply["new-session"]]
  end
  s.write = write_for(conn, msg)
  s.read = read_for(conn, msg)
  return s
end
-- Answer the "complete" op: suggest names that extend the last
-- dot-separated segment of `msg.input`.
-- Names are looked up in a shallow copy of `sandbox` (or _G when nil). The
-- copy also gets each module listed in `msg.libs`; those values have their
-- first and last characters (the quotes) stripped before `require`.
-- Earlier segments are followed as nested table lookups, so "string.fo"
-- completes the keys of `string` that start with "fo".
-- Only string keys are considered, and the prefix is compared literally.
-- Input containing pattern characters such as "(" or "%" therefore can't
-- raise a pattern error, and numeric keys can't crash the lookup.
local complete = function(msg, sandbox)
  local top_ctx = {}
  for k, v in pairs(sandbox or _G) do top_ctx[k] = v end
  for k, v in pairs(msg.libs or {}) do
    top_ctx[k] = require(v:sub(2, -2))
  end
  local input_parts = {}
  for part in string.gmatch(msg.input or "", "([^.%s]+)") do
    input_parts[#input_parts + 1] = part
  end
  local completions = {}
  local n = #input_parts
  if(n > 0) then
    -- Walk down through every segment except the last one.
    local ctx = top_ctx
    for i = 1, n - 1 do
      if(type(ctx) ~= "table") then break end
      ctx = ctx[input_parts[i]]
    end
    if(type(ctx) == "table") then
      local prefix = input_parts[n]
      for k in pairs(ctx) do
        if(type(k) == "string" and k:sub(1, #prefix) == prefix) then
          completions[#completions + 1] = k
        end
      end
    end
  end
  return response_for(msg, {completions = completions})
end
-- see https://github.com/clojure/tools.nrepl/blob/master/doc/ops.md
-- Dispatch one decoded message `msg` received on `conn`.
-- Custom ops in `handlers` take precedence over the built-ins. They are
-- called as handler(conn, msg, session, send, response_for).
-- Replies are written straight to `conn`. client_loop runs this inside its
-- own coroutine, so io.read (read_for) can yield while waiting for "stdin".
-- "interrupt" is accepted and ignored. Any other unknown op gets an
-- "unknown-op" status.
local handle = function(conn, handlers, sandbox, msg)
  if(handlers and handlers[msg.op]) then
    d("Custom op:", msg.op)
    handlers[msg.op](conn, msg, session_for(conn, msg, sandbox),
                     send, response_for)
  elseif(msg.op == "clone") then
    d("New session.")
    send(conn, register_session(conn, msg, sandbox))
  elseif(msg.op == "describe") then
    d("Describe.")
    send(conn, describe(msg, handlers))
  elseif(msg.op == "eval") then
    d("Evaluating", msg.code)
    local value, err = eval(session_for(conn, msg, sandbox), msg.code, msg.pp)
    d("Got", value, err)
    -- monroe bug means you have to send done status separately
    send(conn, response_for(msg, {value=value, ex=err}))
    send(conn, response_for(msg, {status={"done"}}))
  elseif(msg.op == "load-file") then
    d("Loading file", msg.file)
    local value, err = load_file(session_for(conn, msg, sandbox),
                                 msg.file, msg.loader)
    d("Got", value, err)
    send(conn, response_for(msg, {value=value, ex=err, status={"done"}}))
  elseif(msg.op == "ls-sessions") then
    d("List sessions")
    local session_ids = {}
    for id in pairs(sessions) do table.insert(session_ids, id) end
    send(conn, response_for(msg, {sessions=session_ids, status={"done"}}))
  elseif(msg.op == "complete") then
    d("Complete", msg.input)
    local session_sandbox = session_for(conn, msg, sandbox).sandbox
    send(conn, complete(msg, session_sandbox))
  elseif(msg.op == "stdin") then
    d("Stdin", serpent.block(msg))
    local session = sessions[msg.session]
    if(session) then
      session.input = msg.stdin
      send(conn, response_for(msg, {status={"done"}}))
    else
      -- Unknown or already-closed session: tell the client instead of
      -- raising "attempt to index a nil value" inside the handler.
      send(conn, response_for(msg,
                              {status={"error", "unknown-session", "done"}}))
    end
  elseif(msg.op ~= "interrupt") then -- silently ignore interrupt
    send(conn, response_for(msg, {status={"unknown-op"}}))
    print(" | Unknown op", serpent.block(msg))
  end
end
-- Handler coroutines that were still suspended after their first resume
-- (e.g. a handler blocked in read_for waiting for a "stdin" message).
-- Every call to `receive` resumes each of them once and drops the ones that
-- are no longer suspended.
local handler_coros = {}
-- Read the bytes currently available on `conn`, one byte per socket call.
-- `partial` is the data accumulated so far (nil when nothing has arrived).
--   * a byte arrived: recurse to append it and keep reading;
--   * timeout with nothing buffered: yield the calling (connection)
--     coroutine, then start over;
--   * timeout with data buffered: return the buffered string;
--   * any other error (e.g. "closed"): return nil, err.
-- NOTE(review): byte-at-a-time reads plus repeated concatenation are
-- quadratic in message size; fine for small REPL traffic, worth revisiting
-- for large payloads.
local function receive(conn, partial)
local s, err = conn:receive(1) -- wow this is primitive
-- iterate backwards so we can safely remove
for i=#handler_coros, 1, -1 do
local ok, err2 = coroutine.resume(handler_coros[i])
if(coroutine.status(handler_coros[i]) ~= "suspended") then
if(not ok) then print(" | Handler error", err2) end
table.remove(handler_coros, i)
end
end
if(s) then
return receive(conn, (partial or "") .. s)
elseif(err == "timeout" and partial == nil) then
coroutine.yield()
return receive(conn)
elseif(err == "timeout") then
return partial
else
return nil, err
end
end
-- Service one client connection: read, decode and dispatch bencoded
-- messages until the socket fails or a "close" op arrives.
-- `partial` holds bytes left over after the previously decoded message
-- (the start of the next one), or nil.
-- Returns the receive error (e.g. "closed") on socket failure, or the result
-- of send() after a "close" op.
-- NOTE(review): when bencode.decode fails, the buffered input is dropped
-- (partial = nil) and a decoding error is printed. A message split across
-- reads more than `timeout` apart may be lost this way; confirm whether
-- bencode.decode can tell truncated input from malformed input.
-- NOTE(review): nREPL's "close" op closes one session, but here it also
-- stops servicing the whole connection.
local function client_loop(conn, sandbox, handlers, middleware, partial)
local input, r_err = receive(conn, partial)
if(input) then
local decoded, d_err = bencode.decode(input)
-- On success the second result is treated as the position of the last byte
-- consumed; any trailing bytes are carried over to the next iteration.
if decoded and d_err < #input then
partial = input:sub(d_err + 1)
else
partial = nil
end
-- Hand control back to the accept loop before dispatching.
coroutine.yield()
if(decoded and decoded.op == "close") then
d("End session.")
return send(conn, unregister_session(decoded))
elseif(decoded and decoded.op ~= "close") then
-- If we don't spin up a coroutine here, we can't io.read, because
-- that requires waiting for a response from the client. But most
-- messages don't need to stick around.
local coro = coroutine.create(handle)
if(middleware) then
-- middleware(dispatch, msg): dispatch runs the handler coroutine with
-- whatever message the middleware passes to it.
middleware(function(msg)
local ok, err = coroutine.resume(coro, conn, handlers,
sandbox, msg)
if(not ok) then print(" | Handler error", err) end
end, decoded)
else
local ok, err = coroutine.resume(coro, conn, handlers,
sandbox, decoded)
if(not ok) then print(" | Handler error", err) end
end
-- Handlers still waiting (e.g. on io.read) are resumed later by receive.
if(coroutine.status(coro) == "suspended") then
table.insert(handler_coros, coro)
end
else
print(" | Decoding error:", d_err)
end
return client_loop(conn, sandbox, handlers, middleware, partial)
else
return r_err
end
end
-- Coroutines servicing accepted client connections (one per client).
-- `loop` resumes them round-robin; finished ones are pruned.
local connections = {}

-- Resume each connection coroutine once, in order, then drop the ones that
-- have finished so the list does not grow without bound.
local function resume_connections()
  local count = #connections
  local kept = 0
  for i = 1, count do
    local c = connections[i]
    coroutine.resume(c)
    if(coroutine.status(c) ~= "dead") then
      kept = kept + 1
      connections[kept] = c
    end
  end
  for i = count, kept + 1, -1 do connections[i] = nil end
end

-- Accept loop. Each iteration does the following:
--   * sleep for `timeout` and poll for a new client;
--   * in background mode, yield to whoever resumes the server coroutine
--     (jeejah.stop resumes it with "stop");
--   * when no client is waiting, give every connection coroutine one turn.
-- The server socket is closed once a stop is requested or accept reports
-- "closed".
local function loop(server, sandbox, handlers, middleware, foreground)
  socket.sleep(timeout)
  local conn, err = server:accept()
  local stop = (not foreground) and (coroutine.yield() == "stop")
  if(conn) then
    conn:settimeout(timeout)
    d("Connected.")
    local coro = coroutine.create(function()
      local ok, h_err = pcall(client_loop, conn, sandbox, handlers, middleware)
      -- client_loop returns nil after a "close" op and "closed" when the
      -- peer hangs up. Report anything else, including raised errors;
      -- tostring guards against nil/non-string values in the message.
      if(not ok or (h_err ~= nil and h_err ~= "closed")) then
        print("Connection closed: " .. tostring(h_err))
      end
      -- Nothing reads from this socket any more; release it.
      conn:close()
    end)
    table.insert(connections, coro)
  else
    if(err ~= "timeout") then print(" | Socket error: " .. tostring(err)) end
    resume_connections()
  end
  -- Checked on both branches, so a stop request is honored even when a
  -- client connected during the same iteration.
  if(stop or err == "closed") then
    server:close()
    print("Server stopped.")
  else
    return loop(server, sandbox, handlers, middleware, foreground)
  end
end
return {
  -- Start an nrepl socket server on the given port (default 7888).
  -- Supported opts:
  --   * foreground=true: run in the foreground (the call blocks);
  --   * debug=true: verbose logging;
  --   * host=...: bind address (default "localhost");
  --   * timeout=<seconds>: socket/poll timeout;
  --   * fennel=true: route eval and stdin through jeejah.fenneleval;
  --   * middleware=function(dispatch, msg): wrap message dispatch;
  --   * sandbox={...}: evaluate all code in a sandbox (setfenv, so Lua
  --     5.1/LuaJIT only).
  -- You can also give an opts.handlers table keying ops to handler
  -- functions. They are called with the socket, the decoded message, the
  -- session table, and the `send` and `response_for` helpers.
  -- In background mode the server coroutine is returned. It yields once
  -- per accept-loop iteration, so resume it regularly to service clients.
  start = function(port, opts)
    port = port or 7888
    opts = opts or {}
    opts.handlers = opts.handlers or {}
    -- Validate before binding, so a rejected configuration doesn't leave
    -- the listening socket open.
    assert(not opts.sandbox or setfenv, "Can't use sandbox on 5.2+")
    if(opts.debug) then d = print end
    -- Keep the current timeout if opts.timeout isn't numeric, rather than
    -- leaving `timeout` nil for socket.sleep/settimeout.
    if(opts.timeout) then timeout = tonumber(opts.timeout) or timeout end
    if(opts.fennel) then
      local fenneleval = require("jeejah.fenneleval")
      opts.handlers.eval = fenneleval
      opts.handlers.stdin = fenneleval
    end
    -- host should always be localhost on a PC, but not always on a micro
    local server = assert(socket.bind(opts.host or "localhost", port))
    server:settimeout(timeout)
    print("Server started on port " .. port .. "...")
    if opts.foreground then
      return loop(server, opts.sandbox, opts.handlers,
                  opts.middleware, opts.foreground)
    else
      return coroutine.create(function()
        loop(server, opts.sandbox, opts.handlers, opts.middleware)
      end)
    end
  end,
  -- Pass in the coroutine from jeejah.start to this function to stop it.
  stop = function(coro)
    coroutine.resume(coro, "stop")
  end,
  -- Send `msg` unchanged to the connection of every registered session.
  broadcast = function(msg)
    for _,session in pairs(sessions) do
      send(session.conn, msg)
    end
  end,
}