Wasm generation rewrite, lots of bugfixes

This commit is contained in:
Jeremy Penner 2024-08-05 22:30:05 -04:00
parent f7eee74e2d
commit 2a69d98b49
9 changed files with 277 additions and 135 deletions

View file

@ -3,8 +3,7 @@
[tock.compiler.bind :refer [bind]]
[tock.compiler.specials :refer [specials]]
[tock.compiler.type :refer [typecheck]]
;; [tock.compiler.wasm :refer [ctx-to-wasm]]
[meander.epsilon :as m]))
[tock.compiler.wasm :refer [collect-definitions]]))
;; compiler structure:
;; a quoted form is passed through a series of passes:
@ -30,32 +29,26 @@
;; * binding expressions must be checked against the metadata in scope before typechecking happens - perhaps l/lookup replacement
;; should happen here??
;; 5. type-lowering pass
;; * static allocation of globals
;; * all function definitions are lifted to a top-level collection, and replaced
;; with references
;; * allocation of locals
;; * static allocation of globals
;; * memory-based stack frame allocation (so structs can be passed by reference)
;; * tock types are converted into wasm types
;; * struct access is converted into pointer arithmetic
;; * tock types are converted into wasm types
;; 6. codegen pass
;; * function expression trees are recursively walked to generate linear wasm bytecode
(def empty-program tock.compiler.wasm/empty-program)
(def generate-wasm tock.compiler.wasm/generate-wasm)
;; (defmulti compile-toplevel (fn [form _ctx] (form-dispatch form)))
;; (defmethod compile-toplevel :default [form _ctx]
;; (throw (compile-error form "Unrecognized form")))
;; (defmethod compile-toplevel 'defn [form ctx]
;; (m/match form
;; (_ (m/pred simple-symbol? ?name) . !fndef ...)
;; (let [funcref (typecheck-expr (apply list 'fn !fndef) ctx)
;; exported-funcs (lookup ctx ::u/exported-funcs)]
;; (bind! ctx ?name {:funcref funcref})
;; (swap! exported-funcs assoc (name ?name) (last funcref)))))
(defn new-ctx [] [(atom {:function-count 0})])
(defn compile [form]
(let [ctx (new-ctx)]
(defn compile-form [program form]
(-> form
(desugar specials)
(bind specials ctx)
(typecheck specials))))
(bind specials [(atom (:globals program))])
(typecheck specials)
(collect-definitions program)))
(defn compile [forms]
(-> (reduce compile-form empty-program forms)
generate-wasm))

View file

@ -1,7 +1,7 @@
(ns tock.compiler.bind
(:require [meander.epsilon :as m]
[tock.compiler.util :refer [get-special new-scope bind! lookup]]
[tock.compiler.meander :refer [bottom-up all-subexpressions m+ merge-metafield] :include-macros true]))
[tock.compiler.meander :refer [bottom-up all-subexpressions m+ merge-metafield]]))
(defn decorate-ctx [specials ctx form]
(let [special (get-special specials form)
@ -12,8 +12,7 @@
marker (or (:mark-bound-subexpressions special) identity)
form (marker form)
form (if new-bindings (merge-metafield form {:bindings bindings}) form)
add-ctx? (or (= (first form) 'l/lookup) (:binding (second form)) new-bindings)
form (if add-ctx? (merge-metafield form {:ctx ctx}) form)]
form (merge-metafield form {:ctx ctx})]
(doseq [[symbol binding] bindings]
(bind! ctx symbol binding))
((all-subexpressions (partial decorate-ctx specials ctx)) form)))

View file

@ -2,7 +2,7 @@
(:require [meander.epsilon :as m]
[meander.strategy.epsilon :as r]
[tock.compiler.specials :refer [specials]]
[tock.compiler.meander :refer [parse-type to-sym label] :include-macros true]))
[tock.compiler.meander :refer [parse-type to-sym label]]))
(def leaf-pass
(r/pipe

View file

@ -1,4 +1,5 @@
(ns tock.compiler.meander
#?(:cljs (:require-macros [tock.compiler.meander]))
(:require [meander.epsilon :as m]
[meander.strategy.epsilon :as r]
[meander.syntax.epsilon :as r.syntax]))
@ -13,7 +14,7 @@
[meta-pattern pattern]
(case (::r.syntax/phase &env)
:meander/substitute `(m/app merge-metafield ~pattern ~meta-pattern)
:meander/match `(m/and (_ ~meta-pattern . _ ...) ~pattern)
:meander/match `(m/and (~'_ ~meta-pattern ~'. ~'_ ~'...) ~pattern)
&form))
(m/defsyntax label [form label]
@ -50,3 +51,41 @@
(defn bottom-up [s]
(fn rec [t]
((r/pipe (all-subexpressions rec) s) t)))
(defn reduce-all-subexpressions [s]
(fn [[t ctx]]
(let [subexprs (rest (rest t))
[reduced-subexprs reduced-ctx]
(reduce (fn [[new-subexprs ctx] subexpr]
(let [[subexpr2 ctx2] (s [subexpr ctx])]
[(conj new-subexprs subexpr2) ctx2]))
[[] ctx]
subexprs)]
[(apply list (first t) (second t) reduced-subexprs) reduced-ctx])))
(defn reduce-td [s]
(fn rec [t]
((r/pipe s (reduce-all-subexpressions rec)) t)))
(defn reduce-bu [s]
(fn rec [t]
((r/pipe (reduce-all-subexpressions rec) s) t)))
(defn join-ctx [s] (fn [t ctx] (s [t ctx])))
(defn tree-reducer [s]
(fn [t-ctx]
(let [[_ ctx] t-ctx
new-t (s t-ctx)]
(if (r/fail? new-t) new-t [new-t ctx]))))
(defn ctx-reducer [s]
(fn [t-ctx]
(let [[t _] t-ctx
new-ctx (s t-ctx)]
(if (r/fail? new-ctx) new-ctx [t new-ctx]))))
(defn rewrite-map [m]
(fn [form]
(let [special (first form)
rewriter (get m special)]
(if rewriter (rewriter form) form))))

View file

@ -1,8 +1,7 @@
(ns tock.compiler.specials
(:require [meander.epsilon :as m]
[meander.strategy.epsilon :as r]
[tock.compiler.meander :refer [parse-type to-sym label m+] :include-macros true]
[tock.compiler.util :refer [get-meta]]))
[tock.compiler.meander :refer [parse-type to-sym label m+]]))
;; no namespace - source symbol
;; l/sym - "lowered" form - special form not directly writable from source
@ -18,10 +17,10 @@
(defn left-binop-desugar [symbol]
(r/choice (left-associative symbol) (simple-identity symbol)))
(defn equatable-type? [typesym] (contains? #{'f64 'i32 'bool} typesym))
(defn ordered-type? [typesym] (= typesym 'f64))
(defn equatable-type? [typesym] (contains? #{'f64 'i64 'i32 'bool} typesym))
(defn ordered-type? [typesym] (contains? #{'f64 'i64 'i32} typesym))
(defn logical-type? [typesym] (= typesym 'bool))
(defn numerical-type? [typesym] (= typesym 'f64))
(defn numerical-type? [typesym] (contains? #{'f64 'i64 'i32} typesym))
(defn combinator-typecheck [valid?]
(r/rewrite (_ (m/pred valid? ?l) (m/pred valid? ?r)) [[?l ?l ?l] [?r ?r ?r]]))

View file

@ -1,7 +1,7 @@
(ns tock.compiler.type
(:require [meander.epsilon :as m]
[meander.strategy.epsilon :as r]
[tock.compiler.meander :refer [bottom-up m+ merge-metafield] :include-macros true]
[tock.compiler.meander :refer [bottom-up m+ merge-metafield]]
[tock.compiler.util :refer [lookup get-meta get-special lookup update-binding!]]))
; typechecking happens bottom-up. by the time a node is called to be typechecked, the system has verified that all of the children
@ -19,14 +19,16 @@
; the metadata, or that the form is a binding lookup and the type can be read from context
(def default-typechecker
(r/rewrite
({:ctx ?ctx :name ?name}) [[(m/app #(get-meta (lookup %1 %2) :type) ?ctx ?name)]]
({:ctx (m/some ?ctx) :name (m/some ?name)}) [[(m/app #(get-meta (lookup %1 %2) :type) ?ctx ?name)]]
({:type ?type}) [[?type]]))
(defn coerce [expr to-type]
(m/rewrite [(get-meta expr :type) to-type]
[?t ?t] ~expr
[(m/or 'i32 'i64) 'f64] ('l/cast {:type 'f64} ~expr)
['i32 'i64] ('l/cast {:type 'i64} ~expr)
[_ 'void] ('l/cast {:type 'void} ~expr)
_ (m+ {:type-mismatch ~to-type} ~expr)))
?coercion (m+ {:type-mismatch ?coercion} ~expr)))
(defn coerce-form [form typing]
(m/rewrite [form typing]
@ -37,9 +39,9 @@
(let [special (get-special specials form)
typechecker (or (:typecheck special) default-typechecker)
input (m/rewrite form (_ ?m . (m/and (_ {:type !subtype} . _ ...) _) ...) (?m . !subtype ...))
_ (print (first form) input)
typings (typechecker input)
rewrites (map #(coerce-form form %) typings)
;; _ (print (first form) typings input)
valid-rewrites (filter #(nil? (get-meta % :type-mismatch)) rewrites)
rewrite (or (first valid-rewrites) (first rewrites))]

View file

@ -22,5 +22,8 @@
(defn update-binding! [ctx key f & rest]
(apply swap! (first (filter #(contains? @% key) ctx)) update key f rest))
(defn update-root! [ctx key f & rest]
(apply swap! (peek ctx) update key f rest))
(defn get-special [specials form] (get specials (first form)))
(defn get-meta [form key] (get (second form) key))

View file

@ -3,104 +3,210 @@
[helins.wasm.ir :as ir]
[helins.wasm.bin :as op]
[helins.binf.string :as binf.string]
[tock.compiler.util :refer [compile-error form-dispatch lookup] :as u]
[meander.epsilon :as m]))
[tock.compiler.meander :refer [typed m+ join-ctx reduce-bu ctx-reducer]]
[tock.compiler.util :refer [get-meta]]
[meander.epsilon :as m]
[meander.strategy.epsilon :as r]))
(defn expr-type [form] (m/match form (_ {:type ?type} & _) ?type))
;; Wasm preprocessing:
;; functions need to be assigned numeric ids that correspond to their wasm index.
;; function _types_ also need to be assigned numeric IDs corresponding to their wasm index.
;; replace function definitions with 'l/funcref nodes.
(def operator-assembly
{['= 'f64] [[op/f64-eq]]
['= 'i32] [[op/i32-eq]]
['= 'bool] [[op/i32-eq]]
['not= 'f64] [[op/f64-eq] [op/i32-eqz]]
['not= 'i32] [[op/i32-ne]]
['not= 'bool] [[op/i32-ne]]
['< 'f64] [[op/f64-lt]]
['<= 'f64] [[op/f64-le]]
['> 'f64] [[op/f64-gt]]
['>= 'f64] [[op/f64-ge]]
['< 'i32] [[op/i32-lt_s]]
['<= 'i32] [[op/i32-le_s]]
['> 'i32] [[op/i32-gt_s]]
['>= 'i32] [[op/i32-ge_s]]
['and 'bool] [[op/i32-and]]
['or 'bool] [[op/i32-or]]
['+ 'f64] [[op/f64-add]]
['+ 'i32] [[op/i32-add]]
['- 'f64] [[op/f64-sub]]
['- 'i32] [[op/i32-sub]]
['* 'f64] [[op/f64-mul]]
['* 'i32] [[op/i32-mul]]
['/ 'f64] [[op/f64-div]]
['/ 'i32] [[op/i32-div_s]]})
;; local definitions also need to be hoisted
;; 'bool is 'i32 internally, but it's simpler to support it directly in the emitter than rewrite it
(def lift-functions
(r/rewrite
[(m/and ('l/fn {:type ?type} . _ ...) ?fn) {:fndefs ?funcs & ?r}]
[('l/funcref {:id (m/app count ?funcs)
:type ?type})
{:fndefs (m/app conj ?funcs (m+ {:id (m/app count ?funcs)} ?fn)) & ?r}]))
(def collect-function-types
(ctx-reducer
(r/rewrite
[(typed ?form (m/and ['fn . !params ...] ?type))
{:fntypes (m/and ?types (m/not {?type _index})) & ?r}]
{:fntypes {?type (m/app count ?types) & ?types} & ?r})))
(def collect-globals
(r/rewrite
[('def {:name ?name} ?expr) {:globals ?g & ?r}]
{:globals {?name ?expr & ?g} & ?r}))
(def collect-definitions
(join-ctx
(r/pipe
(reduce-bu
(r/pipe
(r/attempt lift-functions)
(r/attempt collect-function-types)))
collect-globals)))
(def wasm-specials
{'l/if
{:emit (fn [form emit]
(m/match form
('l/if {:type ?type} ?cond ?l ?r)
(concat (emit ?cond)
[[op/if- [:wasm/valtype ?type] (emit ?l) (emit ?r)]])))}
'+ {:ops {['i64 'i64] [[op/i64-add]]
['i32 'i32] [[op/i32-add]]
['f64 'f64] [[op/f64-add]]}}
'- {:ops {['i64 'i64] [[op/i64-sub]]
['i32 'i32] [[op/i32-sub]]
['f64 'f64] [[op/f64-sub]]}}
'* {:ops {['i64 'i64] [[op/i64-mul]]
['i32 'i32] [[op/i32-mul]]
['f64 'f64] [[op/f64-mul]]}}
'/ {:ops {['i64 'i64] [[op/i64-div_s]]
['i32 'i32] [[op/i32-div_s]]
['f64 'f64] [[op/f64-div]]}}
'= {:ops {['i64 'i64] [[op/i64-eq]]
['i32 'i32] [[op/i32-eq]]
['f64 'f64] [[op/f64-eq]]
['fn 'fn] [[op/i32-eq]]
['bool 'bool] [[op/i32-eq]]}}
'not= {:ops {['i64 'i64] [[op/i64-ne]]
['i32 'i32] [[op/i32-ne]]
['f64 'f64] [[op/f64-eq] [op/i32-eqz]]
['fn 'fn] [[op/i32-ne]]
['bool 'bool] [[op/i32-ne]]}}
'< {:ops {['i64 'i64] [[op/i64-lt_s]]
['i32 'i32] [[op/i32-lt_s]]
['f64 'f64] [[op/f64-lt]]}}
'<= {:ops {['i64 'i64] [[op/i64-le_s]]
['i32 'i32] [[op/i32-le_s]]
['f64 'f64] [[op/f64-le]]}}
'> {:ops {['i64 'i64] [[op/i64-gt_s]]
['i32 'i32] [[op/i32-gt_s]]
['f64 'f64] [[op/f64-gt]]}}
'>= {:ops {['i64 'i64] [[op/i64-ge_s]]
['i32 'i32] [[op/i32-ge_s]]
['f64 'f64] [[op/f64-ge]]}}
'not {:ops {['bool] [[op/i32-eqz]]}}
'and
{:emit (fn [form emit]
(m/match form
('and _ ?l ?r)
[[op/if- [:wasm/valtype 'i32] (emit ?l) (emit ?r) [[op/i32-const 0]]]]))}
'or
{:emit (fn [form emit]
(m/match form
('or _ ?l ?r)
[[op/if- [:wasm/valtype 'i32] (emit ?l) [[op/i32-const -1]] (emit ?r)]]))}
'do
{:emit (fn [form emit]
(m/match form
('do _ . !exprs ...)
(mapcat emit !exprs)))}
'l/call
{:emit-meta (fn [form {:keys [emit program]}]
(m/match form
('l/call _ ('l/funcref {:id ?id}) . !args ...)
(concat (mapcat emit !args) [[op/call ?id]])
('l/call _ (typed ?expr ?fntype) . !args ...)
(concat (mapcat emit !args)
[[op/call_indirect (get-in program [:fntypes ?fntype])]])))}
'l/funcref
{:emit (fn [form _emit]
(m/match form
('l/funcref {:id ?id}) [[op/i32-const ?id]]))}
'l/lit
{:emit (fn [form _emit]
(m/match form
('l/lit {:type 'i64 :value ?num}) [[op/i64-const ?num]]
('l/lit {:type 'i32 :value ?num}) [[op/i32-const ?num]]
('l/lit {:type 'bool :value ?b}) [[op/i32-const (if ?b -1 0)]]))}
'l/param
{:emit (fn [form _emit]
(m/match form
('l/param {:index ?id}) [[op/local-get ?id]]))}
'l/cast
{:emit (fn [form emit]
(m/match form
('l/cast {:type ?type} (typed ?expr ?type))
(emit ?expr)
('l/cast {:type 'void} ?expr)
(concat (emit ?expr) [[op/drop]])
('l/cast {:type 'f64} ('l/lit {:type (m/or 'i32 'i64) :value ?val}))
[[op/f64-const ?val]]
('l/cast {:type 'i64} ('l/lit {:type 'i32 :value ?val}))
[[op/i32-const ?val]]
('l/cast {:type 'i64} (typed ?expr 'i32))
(concat (emit ?expr) [[op/i64-extend_i32_s]])
('l/cast {:type 'f64} (typed ?expr 'i32))
(concat (emit ?expr) [[op/f64-convert_i32_s]])
('l/cast {:type 'f64} (typed ?expr 'i64))
(concat (emit ?expr) [[op/f64-convert_i64_s]])))}
'l/fn
{:emit (fn [form emit]
(m/match form ('l/fn _ ?expr) (emit ?expr)))}})
(defn type-based-emit [form opmap emit]
(let [subexprs (rest (rest form))
types (into [] (map #(get-meta % :type) subexprs))
ops (get opmap types)]
(or (concat (mapcat emit subexprs) ops) [[:compile-error form "Unexpected types" types]])))
(defn make-emitter [specials]
(fn [form program]
((fn emit [form]
(let [special (first form)
emitter (get specials special)]
(try
(cond
(contains? emitter :ops) (type-based-emit form (:ops emitter) emit)
(contains? emitter :emit) ((:emit emitter) form emit)
(contains? emitter :emit-meta) ((:emit-meta emitter) form {:emit emit :program program})
:else [[:compile-error form "No wasm emitter defined for special" special]])
(catch :default e
[[:compile-error form "Error during emission:" e]]))))
form)))
(def wasm-emit (make-emitter wasm-specials))
(def empty-program {:fndefs [] :fntypes {} :globals {}})
(defn type-to-wasmtype [type]
(m/match type
'f64 op/numtype-f64
'bool op/numtype-i32
'i32 op/numtype-i32
'i64 op/numtype-i64
'void op/blocktype-nil
['fn . _ ...] op/numtype-i32))
_ op/numtype-i32))
(defn wasm-function-signature [type]
(m/match type
['fn . !types ... ?return-type]
[(apply vector (map type-to-wasmtype !types))
(if (= ?return-type 'void) [] [(type-to-wasmtype ?return-type)])]))
(defmulti emit-code (fn [form] (form-dispatch form)))
(defmethod emit-code :default [form]
(m/match form
((m/and (funcref ?func-id) (m/app meta {:type ?type})) . !params ...)
(concat (mapcat emit-code !params) [[op/call ?func-id]])
(?op . !params ...)
(if-let [ops (get operator-assembly [?op (expr-type form)])]
(concat (mapcat emit-code !params) ops)
(throw (compile-error form ["Don't know how to compile" ?op])))))
(defmethod emit-code `u/local [form]
(m/match form (_ ?local-id) [[op/local-get ?local-id]]))
(defmethod emit-code `u/lit [form]
(m/match form (m/and (_ ?lit) (m/app meta {:type ?type}))
(cond
(= ?type 'i32) [[op/i32-const ?lit]]
(= ?type 'i64) [[op/i64-const ?lit]]
(= ?type 'f64) [[op/f64-const ?lit]]
(= ?type 'bool) [[op/i32-const (if ?lit 1 0)]]
:else (throw (compile-error form "Invalid literal")))))
(defmethod emit-code `u/call-func [form]
(m/match form
(_ ?funcref . !args ...)
(concat (mapcat emit-code !args) [[op/call ?funcref]])))
(defmethod emit-code `u/cast-void [form]
(m/match form
(_ ?expr)
(concat (emit-code ?expr) [[op/drop]])))
(defmethod emit-code 'do [form]
(mapcat #(emit-code %) (rest form)))
(defmethod emit-code 'if [form]
(m/match form
(_ ?cond ?when-true ?when-false)
(concat (emit-code ?cond) [[op/if- [:wasm/valtype (type-to-wasmtype (expr-type form))] (emit-code ?when-true) (emit-code ?when-false)]])))
(defn ctx-to-wasm [ctx]
(let [funcs (deref (lookup ctx ::u/module-funcs))
exported-funcs (deref (lookup ctx ::u/exported-funcs))]
(pr "generating" funcs exported-funcs)
(defn generate-wasm [{:keys [fndefs fntypes globals]}]
(let [index-to-fntype (into {} (map (fn [[sym index]] [index sym]) fntypes))
fntypes (into [] (map #(get index-to-fntype %) (range (count index-to-fntype))))]
(as-> (wasm/ctx) wasm
(reduce (fn [wasm i]
(let [{:keys [body type]} (get funcs i)]
(reduce (fn [wasm fntype]
(ir/assoc-type wasm (ir/type-signature {} (wasm-function-signature fntype))))
wasm fntypes)
(reduce (fn [wasm funcdef]
(let [id (get-meta funcdef :id)]
(-> wasm
(ir/assoc-type (ir/type-signature {} (wasm-function-signature type)))
(ir/assoc-func (ir/func {} i))
(assoc-in [:wasm/codesec i] (ir/func' {} [] (apply vector (emit-code body)))))))
wasm
(range (count funcs)))
(reduce (fn [wasm [name funcid]]
(assoc-in wasm [:wasm/exportsec :wasm.export/func funcid] [(ir/export' {} (binf.string/encode name))]))
wasm
exported-funcs))))
(ir/assoc-func (ir/func {} id))
(assoc-in [:wasm/codesec id]
(ir/func' {} [] (wasm-emit funcdef))))))
wasm fndefs)
(reduce (fn [wasm [name form]]
(m/match form
('l/funcref {:id ?id})
(assoc-in wasm [:wasm/exportsec :wasm.export/func ?id] [(ir/export' {} (binf.string/encode name))])
_ wasm))
wasm globals))))

View file

@ -1,16 +1,17 @@
(ns tock.experiment
(:require [helins.wasm :as wasm]
[helins.binf :as binf]
[tock.compiler :refer [compile]]))
[tock.compiler :refer [compile compile-form empty-program]]
[cljs.pprint :as pp]))
;; ; https://github.com/kalai-transpiler/kalai
;; (def test-wasm
;; (compile
;; '[(defn add [^f64 left ^f64 right -> f64] (+ left right))
;; (defn double [^f64 val -> f64] (* val 2))
;; (defn add_double [^f64 left ^f64 right -> f64] (double (add left right)))]))
(def test-wasm
(compile
'[(fn add [^f64 left ^f64 right -> f64] (+ left right))
(fn double [^f64 val -> f64] (* val 2))
(fn add_double [^f64 left ^f64 right -> f64] (double (add left right)))]))
(defn decompile-url [url]
(-> (js/fetch url)
@ -26,9 +27,9 @@
#_{:clj-kondo/ignore [:clojure-lsp/unused-public-var]}
(defn main []
(js/console.log (compile `(fn add [^f64 left ^f64 right -> f64] (+ left right)))))
;; (js/console.log test-wasm)
;; (-> (instantiate-wasm test-wasm #js {})
;; (.then #(js/console.log (-> % (.-instance) (.-exports) (.add-double 2 3)))))
(js/console.log test-wasm)
(-> (instantiate-wasm test-wasm #js {})
(.then #(js/console.log (-> % (.-instance) (.-exports) (.add-double 2 3)))))
;; (-> (decompile-url "release.wasm")
;; (.then #(js/console.log (-> % :wasm/exportsec :wasm.export/func (get 0))))))
;; (.then #(js/console.log (-> % :wasm/exportsec :wasm.export/func (get 0)))))
)