Move wasm emission logic into specials.cljc, specials into program

This commit is contained in:
Jeremy Penner 2024-08-05 22:46:47 -04:00
parent 2a69d98b49
commit 7bf14eb508
4 changed files with 155 additions and 164 deletions

View file

@ -3,7 +3,7 @@
[tock.compiler.bind :refer [bind]] [tock.compiler.bind :refer [bind]]
[tock.compiler.specials :refer [specials]] [tock.compiler.specials :refer [specials]]
[tock.compiler.type :refer [typecheck]] [tock.compiler.type :refer [typecheck]]
[tock.compiler.wasm :refer [collect-definitions]])) [tock.compiler.wasm :refer [make-empty-program collect-definitions]]))
;; compiler structure: ;; compiler structure:
;; a quoted form is passed through a series of passes: ;; a quoted form is passed through a series of passes:
@ -39,16 +39,18 @@
;; 6. codegen pass ;; 6. codegen pass
;; * function expression trees are recursively walked to generate linear wasm bytecode ;; * function expression trees are recursively walked to generate linear wasm bytecode
(def empty-program tock.compiler.wasm/empty-program)
(def generate-wasm tock.compiler.wasm/generate-wasm) (def generate-wasm tock.compiler.wasm/generate-wasm)
(defn compile-form [program form] (defn compile-form [program form]
(-> form (let [specials (:specials program)]
(desugar specials) (-> form
(bind specials [(atom (:globals program))]) (desugar specials)
(typecheck specials) (bind specials [(atom (:globals program))])
(collect-definitions program))) (typecheck specials)
(collect-definitions program))))
(defn compile [forms] (defn compile
(-> (reduce compile-form empty-program forms) ([forms] (compile forms specials))
generate-wasm)) ([forms specials]
(-> (reduce compile-form (make-empty-program specials) forms)
generate-wasm)))

View file

@ -1,6 +1,7 @@
(ns tock.compiler.specials (ns tock.compiler.specials
(:require [meander.epsilon :as m] (:require [meander.epsilon :as m]
[meander.strategy.epsilon :as r] [meander.strategy.epsilon :as r]
[helins.wasm.bin :as op]
[tock.compiler.meander :refer [parse-type to-sym label m+]])) [tock.compiler.meander :refer [parse-type to-sym label m+]]))
;; no namespace - source symbol ;; no namespace - source symbol
@ -35,11 +36,16 @@
('i/if ?m) ('do ?m) ('i/if ?m) ('do ?m)
('i/if _ ?else) (label ?else "else block") ('i/if _ ?else) (label ?else "else block")
('i/if ?m ?cond ?body & ?more) ('l/if ?m (label ?cond "condition") ('i/if ?m ?cond ?body & ?more) ('l/if ?m (label ?cond "condition")
(label ?body "body") (label ?body "body")
('i/if {} & ?more))) ('i/if {} & ?more)))
:validate (r/rewrite ('if _) "if statement needs at least a condition and a body" :validate (r/rewrite ('if _) "if statement needs at least a condition and a body"
('if _ _) "if statement needs at least one body expression") ('if _ _) "if statement needs at least one body expression")
:typecheck (r/rewrite (_ ?cond ?l ?r) [[?l 'bool ?l ?l] [?r 'bool ?r ?r] ['void 'bool 'void 'void]])} :typecheck (r/rewrite (_ ?cond ?l ?r) [[?l 'bool ?l ?l] [?r 'bool ?r ?r] ['void 'bool 'void 'void]])
:emit (fn [form emit]
(m/match form
('l/if {:type ?type} ?cond ?l ?r)
(concat (emit ?cond)
[[op/if- [:wasm/valtype ?type] (emit ?l) (emit ?r)]])))}
'l/fn 'l/fn
{:desugar (r/rewrite {:desugar (r/rewrite
@ -50,7 +56,7 @@
(m/app to-sym '->) (m/app parse-type {:type !types})] & ?body) (m/app to-sym '->) (m/app parse-type {:type !types})] & ?body)
('l/fn (m/app merge ?m {:type ['fn . !types ...] ('l/fn (m/app merge ?m {:type ['fn . !types ...]
:params [!names ...]}) :params [!names ...]})
('do {} & ?body)) ('do {} & ?body))
('fn {& ?m} [(m/app (fn [params] [params (parse-type (meta params))]) [(m/pred symbol? !names) {:type !types}]) ...] & ?body) ('fn {& ?m} [(m/app (fn [params] [params (parse-type (meta params))]) [(m/pred symbol? !names) {:type !types}]) ...] & ?body)
('l/fn (m/app merge ?m {:type ['fn . !types ... 'void] ('l/fn (m/app merge ?m {:type ['fn . !types ... 'void]
@ -62,27 +68,71 @@
(into {} (map-indexed (fn [index [name type]] [name (list 'l/param {:type type :name name :index index})]) (into {} (map-indexed (fn [index [name type]] [name (list 'l/param {:type type :name name :index index})])
(m/rewrite form (m/rewrite form
('l/fn {:params [!names ...] :type ['fn . !types ... _]} _) ('l/fn {:params [!names ...] :type ['fn . !types ... _]} _)
[[!names !types] ...]))))} [[!names !types] ...]))))
:emit (fn [form emit]
(m/match form ('l/fn _ ?expr) (emit ?expr)))}
'+ {:desugar (left-binop-desugar '+) '+ {:desugar (left-binop-desugar '+)
:typecheck (combinator-typecheck numerical-type?)} :typecheck (combinator-typecheck numerical-type?)
:ops {['i64 'i64] [[op/i64-add]]
['i32 'i32] [[op/i32-add]]
['f64 'f64] [[op/f64-add]]}}
'- {:desugar (r/choice '- {:desugar (r/choice
(r/rewrite ('- ?m ?v) ('- ?m ('l/lit {:value 0 :type 'i64}) ?v)) (r/rewrite ('- ?m ?v) ('- ?m ('l/lit {:value 0 :type 'i64}) ?v))
(left-associative '-)) (left-associative '-))
:typecheck (combinator-typecheck numerical-type?)} :typecheck (combinator-typecheck numerical-type?)
:ops {['i64 'i64] [[op/i64-sub]]
['i32 'i32] [[op/i32-sub]]
['f64 'f64] [[op/f64-sub]]}}
'* {:desugar (left-binop-desugar '*) '* {:desugar (left-binop-desugar '*)
:typecheck (combinator-typecheck numerical-type?)} :typecheck (combinator-typecheck numerical-type?)
:ops {['i64 'i64] [[op/i64-mul]]
['i32 'i32] [[op/i32-mul]]
['f64 'f64] [[op/f64-mul]]}}
'/ {:desugar (left-binop-desugar '/) '/ {:desugar (left-binop-desugar '/)
:typecheck (combinator-typecheck numerical-type?)} :typecheck (combinator-typecheck numerical-type?)
'= {:typecheck (comparitor-typecheck equatable-type?)} :ops {['i64 'i64] [[op/i64-div_s]]
'not= {:typecheck (comparitor-typecheck equatable-type?)} ['i32 'i32] [[op/i32-div_s]]
'< {:typecheck (comparitor-typecheck ordered-type?)} ['f64 'f64] [[op/f64-div]]}}
'<= {:typecheck (comparitor-typecheck ordered-type?)} '= {:typecheck (comparitor-typecheck equatable-type?)
'> {:typecheck (comparitor-typecheck ordered-type?)} :ops {['i64 'i64] [[op/i64-eq]]
'>= {:typecheck (comparitor-typecheck ordered-type?)} ['i32 'i32] [[op/i32-eq]]
['f64 'f64] [[op/f64-eq]]
['fn 'fn] [[op/i32-eq]]
['bool 'bool] [[op/i32-eq]]}}
'not= {:typecheck (comparitor-typecheck equatable-type?)
:ops {['i64 'i64] [[op/i64-ne]]
['i32 'i32] [[op/i32-ne]]
['f64 'f64] [[op/f64-eq] [op/i32-eqz]]
['fn 'fn] [[op/i32-ne]]
['bool 'bool] [[op/i32-ne]]}}
'< {:typecheck (comparitor-typecheck ordered-type?)
:ops {['i64 'i64] [[op/i64-lt_s]]
['i32 'i32] [[op/i32-lt_s]]
['f64 'f64] [[op/f64-lt]]}}
'<= {:typecheck (comparitor-typecheck ordered-type?)
:ops {['i64 'i64] [[op/i64-le_s]]
['i32 'i32] [[op/i32-le_s]]
['f64 'f64] [[op/f64-le]]}}
'> {:typecheck (comparitor-typecheck ordered-type?)
:ops {['i64 'i64] [[op/i64-gt_s]]
['i32 'i32] [[op/i32-gt_s]]
['f64 'f64] [[op/f64-gt]]}}
'>= {:typecheck (comparitor-typecheck ordered-type?)
:ops {['i64 'i64] [[op/i64-ge_s]]
['i32 'i32] [[op/i32-ge_s]]
['f64 'f64] [[op/f64-ge]]}}
'and {:desugar (left-binop-desugar 'and) 'and {:desugar (left-binop-desugar 'and)
:typecheck (comparitor-typecheck logical-type?)} :typecheck (comparitor-typecheck logical-type?)
:emit (fn [form emit]
(m/match form
('and _ ?l ?r)
[[op/if- [:wasm/valtype 'i32] (emit ?l) (emit ?r) [[op/i32-const 0]]]]))}
'or {:desugar (left-binop-desugar 'or) 'or {:desugar (left-binop-desugar 'or)
:typecheck (comparitor-typecheck logical-type?)} :typecheck (comparitor-typecheck logical-type?)
:emit (fn [form emit]
(m/match form
('or _ ?l ?r)
[[op/if- [:wasm/valtype 'i32] (emit ?l) [[op/i32-const -1]] (emit ?r)]]))}
'def {:desugar (r/rewrite 'def {:desugar (r/rewrite
('def (m/pred symbol? ?name) ?expr) ('def {:name ?name} ?expr)) ('def (m/pred symbol? ?name) ?expr) ('def {:name ?name} ?expr))
:typecheck (r/rewrite (_ ?t) [[?t ?t]]) :typecheck (r/rewrite (_ ?t) [[?t ?t]])
@ -90,12 +140,59 @@
:mark-bound-subexpressions (r/rewrite ('def (m/and ?m {:name ?symbol}) ?expr) ('def ?m (m+ {:binding ?symbol} ?expr)))} :mark-bound-subexpressions (r/rewrite ('def (m/and ?m {:name ?symbol}) ?expr) ('def ?m (m+ {:binding ?symbol} ?expr)))}
'do {:typecheck (r/rewrite (_) [['void]] 'do {:typecheck (r/rewrite (_) [['void]]
(_ . !stmt ... ?last) [[?last . (m/app (constantly 'void) !stmt) ... ?last]]) (_ . !stmt ... ?last) [[?last . (m/app (constantly 'void) !stmt) ... ?last]])
:scope {}} :scope {}
:emit (fn [form emit]
(m/match form
('do _ . !exprs ...)
(mapcat emit !exprs)))}
'l/lookup {} 'l/lookup {}
'l/local {} 'l/local {}
'l/param {} 'l/param
'l/lit {} {:emit (fn [form _emit]
(m/match form
('l/param {:index ?id}) [[op/local-get ?id]]))}
'l/lit
{:emit (fn [form _emit]
(m/match form
('l/lit {:type 'i64 :value ?num}) [[op/i64-const ?num]]
('l/lit {:type 'i32 :value ?num}) [[op/i32-const ?num]]
('l/lit {:type 'bool :value ?b}) [[op/i32-const (if ?b -1 0)]]))}
'l/call 'l/call
{:typecheck (r/rewrite (_ (m/and ['fn . !param-types ... ?return-type] ?fn-type) . _ ...) {:typecheck (r/rewrite (_ (m/and ['fn . !param-types ... ?return-type] ?fn-type) . _ ...)
[[?return-type ?fn-type . !param-types ...]])} [[?return-type ?fn-type . !param-types ...]])
}) :emit-meta (fn [form {:keys [emit program]}]
(m/match form
('l/call _ ('l/funcref {:id ?id}) . !args ...)
(concat (mapcat emit !args) [[op/call ?id]])
('l/call _ (typed ?expr ?fntype) . !args ...)
(concat (mapcat emit !args)
[[op/call_indirect (get-in program [:fntypes ?fntype])]])))}
'l/cast
{:emit (fn [form emit]
(m/match form
('l/cast {:type ?type} (typed ?expr ?type))
(emit ?expr)
('l/cast {:type 'void} ?expr)
(concat (emit ?expr) [[op/drop]])
('l/cast {:type 'f64} ('l/lit {:type (m/or 'i32 'i64) :value ?val}))
[[op/f64-const ?val]]
('l/cast {:type 'i64} ('l/lit {:type 'i32 :value ?val}))
[[op/i32-const ?val]]
('l/cast {:type 'i64} (typed ?expr 'i32))
(concat (emit ?expr) [[op/i64-extend_i32_s]])
('l/cast {:type 'f64} (typed ?expr 'i32))
(concat (emit ?expr) [[op/f64-convert_i32_s]])
('l/cast {:type 'f64} (typed ?expr 'i64))
(concat (emit ?expr) [[op/f64-convert_i64_s]])))}
'l/funcref
{:emit (fn [form _emit]
(m/match form
('l/funcref {:id ?id}) [[op/i32-const ?id]]))}})

View file

@ -44,137 +44,28 @@
(r/attempt collect-function-types))) (r/attempt collect-function-types)))
collect-globals))) collect-globals)))
(def wasm-specials
{'l/if
{:emit (fn [form emit]
(m/match form
('l/if {:type ?type} ?cond ?l ?r)
(concat (emit ?cond)
[[op/if- [:wasm/valtype ?type] (emit ?l) (emit ?r)]])))}
'+ {:ops {['i64 'i64] [[op/i64-add]]
['i32 'i32] [[op/i32-add]]
['f64 'f64] [[op/f64-add]]}}
'- {:ops {['i64 'i64] [[op/i64-sub]]
['i32 'i32] [[op/i32-sub]]
['f64 'f64] [[op/f64-sub]]}}
'* {:ops {['i64 'i64] [[op/i64-mul]]
['i32 'i32] [[op/i32-mul]]
['f64 'f64] [[op/f64-mul]]}}
'/ {:ops {['i64 'i64] [[op/i64-div_s]]
['i32 'i32] [[op/i32-div_s]]
['f64 'f64] [[op/f64-div]]}}
'= {:ops {['i64 'i64] [[op/i64-eq]]
['i32 'i32] [[op/i32-eq]]
['f64 'f64] [[op/f64-eq]]
['fn 'fn] [[op/i32-eq]]
['bool 'bool] [[op/i32-eq]]}}
'not= {:ops {['i64 'i64] [[op/i64-ne]]
['i32 'i32] [[op/i32-ne]]
['f64 'f64] [[op/f64-eq] [op/i32-eqz]]
['fn 'fn] [[op/i32-ne]]
['bool 'bool] [[op/i32-ne]]}}
'< {:ops {['i64 'i64] [[op/i64-lt_s]]
['i32 'i32] [[op/i32-lt_s]]
['f64 'f64] [[op/f64-lt]]}}
'<= {:ops {['i64 'i64] [[op/i64-le_s]]
['i32 'i32] [[op/i32-le_s]]
['f64 'f64] [[op/f64-le]]}}
'> {:ops {['i64 'i64] [[op/i64-gt_s]]
['i32 'i32] [[op/i32-gt_s]]
['f64 'f64] [[op/f64-gt]]}}
'>= {:ops {['i64 'i64] [[op/i64-ge_s]]
['i32 'i32] [[op/i32-ge_s]]
['f64 'f64] [[op/f64-ge]]}}
'not {:ops {['bool] [[op/i32-eqz]]}}
'and
{:emit (fn [form emit]
(m/match form
('and _ ?l ?r)
[[op/if- [:wasm/valtype 'i32] (emit ?l) (emit ?r) [[op/i32-const 0]]]]))}
'or
{:emit (fn [form emit]
(m/match form
('or _ ?l ?r)
[[op/if- [:wasm/valtype 'i32] (emit ?l) [[op/i32-const -1]] (emit ?r)]]))}
'do
{:emit (fn [form emit]
(m/match form
('do _ . !exprs ...)
(mapcat emit !exprs)))}
'l/call
{:emit-meta (fn [form {:keys [emit program]}]
(m/match form
('l/call _ ('l/funcref {:id ?id}) . !args ...)
(concat (mapcat emit !args) [[op/call ?id]])
('l/call _ (typed ?expr ?fntype) . !args ...)
(concat (mapcat emit !args)
[[op/call_indirect (get-in program [:fntypes ?fntype])]])))}
'l/funcref
{:emit (fn [form _emit]
(m/match form
('l/funcref {:id ?id}) [[op/i32-const ?id]]))}
'l/lit
{:emit (fn [form _emit]
(m/match form
('l/lit {:type 'i64 :value ?num}) [[op/i64-const ?num]]
('l/lit {:type 'i32 :value ?num}) [[op/i32-const ?num]]
('l/lit {:type 'bool :value ?b}) [[op/i32-const (if ?b -1 0)]]))}
'l/param
{:emit (fn [form _emit]
(m/match form
('l/param {:index ?id}) [[op/local-get ?id]]))}
'l/cast
{:emit (fn [form emit]
(m/match form
('l/cast {:type ?type} (typed ?expr ?type))
(emit ?expr)
('l/cast {:type 'void} ?expr)
(concat (emit ?expr) [[op/drop]])
('l/cast {:type 'f64} ('l/lit {:type (m/or 'i32 'i64) :value ?val}))
[[op/f64-const ?val]]
('l/cast {:type 'i64} ('l/lit {:type 'i32 :value ?val}))
[[op/i32-const ?val]]
('l/cast {:type 'i64} (typed ?expr 'i32))
(concat (emit ?expr) [[op/i64-extend_i32_s]])
('l/cast {:type 'f64} (typed ?expr 'i32))
(concat (emit ?expr) [[op/f64-convert_i32_s]])
('l/cast {:type 'f64} (typed ?expr 'i64))
(concat (emit ?expr) [[op/f64-convert_i64_s]])))}
'l/fn
{:emit (fn [form emit]
(m/match form ('l/fn _ ?expr) (emit ?expr)))}})
(defn type-based-emit [form opmap emit] (defn type-based-emit [form opmap emit]
(let [subexprs (rest (rest form)) (let [subexprs (rest (rest form))
types (into [] (map #(get-meta % :type) subexprs)) types (into [] (map #(get-meta % :type) subexprs))
ops (get opmap types)] ops (get opmap types)]
(or (concat (mapcat emit subexprs) ops) [[:compile-error form "Unexpected types" types]]))) (or (concat (mapcat emit subexprs) ops) [[:compile-error form "Unexpected types" types]])))
(defn make-emitter [specials] (defn make-emitter [program]
(fn [form program] (fn emit [form]
((fn emit [form] (let [specials (:specials program)
(let [special (first form) special (first form)
emitter (get specials special)] emitter (get specials special)]
(try (try
(cond (cond
(contains? emitter :ops) (type-based-emit form (:ops emitter) emit) (contains? emitter :ops) (type-based-emit form (:ops emitter) emit)
(contains? emitter :emit) ((:emit emitter) form emit) (contains? emitter :emit) ((:emit emitter) form emit)
(contains? emitter :emit-meta) ((:emit-meta emitter) form {:emit emit :program program}) (contains? emitter :emit-meta) ((:emit-meta emitter) form {:emit emit :program program})
:else [[:compile-error form "No wasm emitter defined for special" special]]) :else [[:compile-error form "No wasm emitter defined for special" special]])
(catch :default e (catch :default e
[[:compile-error form "Error during emission:" e]])))) [[:compile-error form "Error during emission:" e]])))))
form)))
(def wasm-emit (make-emitter wasm-specials)) (defn make-empty-program [specials]
{:fndefs [] :fntypes {} :globals {} :specials specials})
(def empty-program {:fndefs [] :fntypes {} :globals {}})
(defn type-to-wasmtype [type] (defn type-to-wasmtype [type]
(m/match type (m/match type
@ -190,8 +81,9 @@
[(apply vector (map type-to-wasmtype !types)) [(apply vector (map type-to-wasmtype !types))
(if (= ?return-type 'void) [] [(type-to-wasmtype ?return-type)])])) (if (= ?return-type 'void) [] [(type-to-wasmtype ?return-type)])]))
(defn generate-wasm [{:keys [fndefs fntypes globals]}] (defn generate-wasm [{:keys [fndefs fntypes globals] :as program}]
(let [index-to-fntype (into {} (map (fn [[sym index]] [index sym]) fntypes)) (let [wasm-emit (make-emitter program)
index-to-fntype (into {} (map (fn [[sym index]] [index sym]) fntypes))
fntypes (into [] (map #(get index-to-fntype %) (range (count index-to-fntype))))] fntypes (into [] (map #(get index-to-fntype %) (range (count index-to-fntype))))]
(as-> (wasm/ctx) wasm (as-> (wasm/ctx) wasm
(reduce (fn [wasm fntype] (reduce (fn [wasm fntype]

View file

@ -1,7 +1,7 @@
(ns tock.experiment (ns tock.experiment
(:require [helins.wasm :as wasm] (:require [helins.wasm :as wasm]
[helins.binf :as binf] [helins.binf :as binf]
[tock.compiler :refer [compile compile-form empty-program]] [tock.compiler :refer [compile]]
[cljs.pprint :as pp])) [cljs.pprint :as pp]))