local socket = require "socket" local serpent = require "serpent" local bencode = require "bencode" local load = loadstring or load local timeout = 0.001 local d = os.getenv("DEBUG") and print or function(_) end local serpent_pp = function(p) return function(x) local serpent_opts = {maxlevel=8,maxnum=64,nocode=true} p(serpent.block(x, serpent_opts)) end end local sessions = {} local response_for = function(old_msg, msg) -- certain implementations break when the ns field is empty; see -- https://gitlab.com/technomancy/jeejah/issues/5 msg.session, msg.id, msg.ns = old_msg.session, old_msg.id, ">" return msg end local send = function(conn, msg) d("Sending", bencode.encode(msg)) conn:send(bencode.encode(msg)) end local write_for = function(conn, msg) return function(...) send(conn, response_for(msg, {out=table.concat({...}, "\t")})) end end local print_for = function(write) return function(...) local args = {...} for i,x in ipairs(args) do args[i] = tostring(x) end table.insert(args, "\n") write(table.concat(args, " ")) end end local read_for = function(conn, msg) return function() send(conn, response_for(msg, {status={"need-input"}})) while(not sessions[msg.session].input) do coroutine.yield() d("yielded") end local input = sessions[msg.session].input sessions[msg.session].input = nil return input end end local sandbox_for = function(write, provided_sandbox) local sandbox = { io = { write = write }, print = print_for(write), } for k,v in pairs(provided_sandbox) do sandbox[k] = v end return sandbox end -- for stuff that's shared between eval and load_file local execute_chunk = function(session, chunk, pp) local old_write, old_print, old_read = io.write, print, io.read if(session.sandbox) then setfenv(chunk, session.sandbox) pp = pp or serpent_pp(session.sandbox.print) else _G.print = print_for(session.write) _G.io.write, _G.io.read = session.write, session.read pp = pp or serpent_pp(_G.print) end local trace, err local result = {xpcall(chunk, function(e) trace = debug.traceback() err = e end)} _G.print, _G.io.write, _G.io.read = old_print, old_write, old_read if(result[1]) then local res, i = pp(result[2]), 3 while i <= #result do res = res .. ', ' .. pp(result[i]) i = i + 1 end return res else return nil, (err or "Unknown error") .. "\n" .. trace end end local eval = function(session, code, pp) local chunk, err = load("return " .. code, "*socket*") if(err and not chunk) then -- statement, not expression chunk, err = load(code, "*socket*") if(not chunk) then return nil, "Compilation error: " .. (err or "unknown") end end return execute_chunk(session, chunk, pp) end local load_file = function(session, file, loader) local chunk, err = (loader or loadfile)(file) if(not chunk) then return nil, "Compilation error in " .. file ": ".. (err or "unknown") end return execute_chunk(session, chunk) end local register_session = function(conn, msg, provided_sandbox) local id = tostring(math.random(999999999)) local write = write_for(conn, msg) local sandbox = provided_sandbox and sandbox_for(write, provided_sandbox) sessions[id] = { conn = conn, write = write, print = print_for(write), sandbox = sandbox, coros = {}, id = id} return response_for(msg, {["new-session"]=id, status={"done"}}) end local unregister_session = function(msg) sessions[msg.session] = nil return response_for(msg, {status={"done"}}) end local describe = function(msg, handlers) local ops = { "clone", "close", "describe", "eval", "load-file", "ls-sessions", "complete", "stdin", "interrupt" } for op in handlers do table.insert(ops, op) end return response_for(msg, {ops=ops, status={"done"}}) end local session_for = function(conn, msg, sandbox) local s = sessions[msg.session] or register_session(conn, msg, sandbox) s.write = write_for(conn, msg) s.read = read_for(conn, msg) return s end local complete = function(msg, sandbox) local clone = function(t) local n = {} for k,v in pairs(t) do n[k] = v end return n end local top_ctx = clone(sandbox or _G) for k,v in pairs(msg.libs or {}) do top_ctx[k] = require(v:sub(2,-2)) end local function cpl_for(input_parts, ctx) if type(ctx) ~= "table" then return {} end if #input_parts == 0 and ctx ~= top_ctx then return ctx elseif #input_parts == 1 then local matches = {} for k in pairs(ctx) do if k:find('^' .. input_parts[1]) then table.insert(matches, k) end end return matches else local token1 = table.remove(input_parts, 1) return cpl_for(input_parts, ctx[token1]) end end local input_parts = {} for i in string.gmatch(msg.input, "([^.%s]+)") do table.insert(input_parts, i) end return response_for(msg, {completions = cpl_for(input_parts, top_ctx)}) end -- see https://github.com/clojure/tools.nrepl/blob/master/doc/ops.md local handle = function(conn, handlers, sandbox, msg) if(handlers and handlers[msg.op]) then d("Custom op:", msg.op) handlers[msg.op](conn, msg, session_for(conn, msg, sandbox), send, response_for) elseif(msg.op == "clone") then d("New session.") send(conn, register_session(conn, msg, sandbox)) elseif(msg.op == "describe") then d("Describe.") send(conn, describe(msg, handlers)) elseif(msg.op == "eval") then d("Evaluating", msg.code) local value, err = eval(session_for(conn, msg, sandbox), msg.code, msg.pp) d("Got", value, err) -- monroe bug means you have to send done status separately send(conn, response_for(msg, {value=value, ex=err})) send(conn, response_for(msg, {status={"done"}})) elseif(msg.op == "load-file") then d("Loading file", msg.file) local value, err = load_file(session_for(conn, msg, sandbox), msg.file, msg.loader) d("Got", value, err) send(conn, response_for(msg, {value=value, ex=err, status={"done"}})) elseif(msg.op == "ls-sessions") then d("List sessions") local session_ids = {} for id in pairs(sessions) do table.insert(session_ids, id) end send(conn, response_for(msg, {sessions=session_ids, status={"done"}})) elseif(msg.op == "complete") then d("Complete", msg.input) local session_sandbox = session_for(conn, msg, sandbox).sandbox send(conn, complete(msg, session_sandbox)) elseif(msg.op == "stdin") then d("Stdin", serpent.block(msg)) sessions[msg.session].input = msg.stdin send(conn, response_for(msg, {status={"done"}})) return elseif(msg.op ~= "interrupt") then -- silently ignore interrupt send(conn, response_for(msg, {status={"unknown-op"}})) print(" | Unknown op", serpent.block(msg)) end end local handler_coros = {} local function receive(conn, partial) local s, err = conn:receive(1) -- wow this is primitive -- iterate backwards so we can safely remove for i=#handler_coros, 1, -1 do local ok, err2 = coroutine.resume(handler_coros[i]) if(coroutine.status(handler_coros[i]) ~= "suspended") then if(not ok) then print(" | Handler error", err2) end table.remove(handler_coros, i) end end if(s) then return receive(conn, (partial or "") .. s) elseif(err == "timeout" and partial == nil) then coroutine.yield() return receive(conn) elseif(err == "timeout") then return partial else return nil, err end end local function client_loop(conn, sandbox, handlers, middleware, partial) local input, r_err = receive(conn, partial) if(input) then local decoded, d_err = bencode.decode(input) if decoded and d_err < #input then partial = input:sub(d_err + 1) else partial = nil end coroutine.yield() if(decoded and decoded.op == "close") then d("End session.") return send(conn, unregister_session(decoded)) elseif(decoded and decoded.op ~= "close") then -- If we don't spin up a coroutine here, we can't io.read, because -- that requires waiting for a response from the client. But most -- messages don't need to stick around. local coro = coroutine.create(handle) if(middleware) then middleware(function(msg) local ok, err = coroutine.resume(coro, conn, handlers, sandbox, msg) if(not ok) then print(" | Handler error", err) end end, decoded) else local ok, err = coroutine.resume(coro, conn, handlers, sandbox, decoded) if(not ok) then print(" | Handler error", err) end end if(coroutine.status(coro) == "suspended") then table.insert(handler_coros, coro) end else print(" | Decoding error:", d_err) end return client_loop(conn, sandbox, handlers, middleware, partial) else return r_err end end local connections = {} local function loop(server, sandbox, handlers, middleware, foreground) socket.sleep(timeout) local conn, err = server:accept() local stop = (not foreground) and (coroutine.yield() == "stop") if(conn) then conn:settimeout(timeout) d("Connected.") local coro = coroutine.create(function() local _, h_err = pcall(client_loop, conn, sandbox, handlers, middleware) if(h_err ~= "closed") then print("Connection closed: " .. h_err) end end) table.insert(connections, coro) return loop(server, sandbox, handlers, middleware, foreground) else if(err ~= "timeout") then print(" | Socket error: " .. err) end for _,c in ipairs(connections) do coroutine.resume(c) end if(stop or err == "closed") then server:close() print("Server stopped.") else return loop(server, sandbox, handlers, middleware, foreground) end end end return { -- Start an nrepl socket server on the given port. For opts you can pass a -- table with foreground=true to run in the foreground, debug=true for -- verbose logging, and sandbox={...} to evaluate all code in a sandbox. You -- can also give an opts.handlers table keying ops to handler functions which -- take the socket, the decoded message, and the optional sandbox table. start = function(port, opts) port = port or 7888 opts = opts or {} opts.handlers = opts.handlers or {} -- host should always be localhost on a PC, but not always on a micro local server = assert(socket.bind(opts.host or "localhost", port)) if(opts.debug) then d = print end if(opts.timeout) then timeout = tonumber(opts.timeout) end if(opts.fennel) then local fenneleval = require("jeejah.fenneleval") opts.handlers.eval = fenneleval opts.handlers.stdin = fenneleval end assert(not opts.sandbox or setfenv, "Can't use sandbox on 5.2+") server:settimeout(timeout) print("Server started on port " .. port .. "...") if opts.foreground then return loop(server, opts.sandbox, opts.handlers, opts.middleware, opts.foreground) else return coroutine.create(function() loop(server, opts.sandbox, opts.handlers, opts.middleware) end) end end, -- Pass in the coroutine from jeejah.start to this function to stop it. stop = function(coro) coroutine.resume(coro, "stop") end, broadcast = function(msg) for _,session in pairs(sessions) do send(session.conn, msg) end end, }