177 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Lua
		
	
	
	
	
	
			
		
		
	
	
			177 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Lua
		
	
	
	
	
	
-- helper functions for debugging

-- Write one debug line to stderr, prefixed with "[Debug] ".
--   s: the value to print. It is converted with tostring() first, so nil,
--      booleans and tables are printed rather than raising the
--      "attempt to concatenate" error that a bare `..` would produce.
local function show(s)
    io.stderr:write("[Debug] " .. tostring(s) .. "\n")
end
 | 
						|
 | 
						|
------------------------------------------------
 | 
						|
-- Directed graph stored as adjacency lists: adj[u] is the sequence of nodes
-- that u points to. A node without outgoing edges has no entry at all.
local adj = {}
-- Identifiers of the include-nodes found while scanning the document.
local include_nodes = {}

-- Record the directed edge u -> v.
-- The adjacency list of u is created on first use; duplicate edges are
-- appended as-is.
local function add_edge(u, v)
    local successors = adj[u]
    if not successors then
        successors = {}
        adj[u] = successors
    end
    successors[#successors + 1] = v
end
 | 
						|
 | 
						|
---------------------------------------------------------
 | 
						|
-- Topological sort (DFS-based, with cycle detection) -- chatgpt
 | 
						|
---------------------------------------------------------
 | 
						|
-- Traversal state shared by _dfs and topo_sort.
local visited = {}  -- permanent mark: set as soon as a node is entered
local in_stack = {} -- temporary mark: true while the node is on the current DFS path
local order = {}    -- finished nodes, each inserted at the front (topological order)
local cycle = nil   -- node sequence of the first cycle found, or nil

-- Depth-first visit of u.
--   u:    node to visit
--   path: sequence of nodes on the current recursion path; u is pushed on
--         entry and popped again once u has been fully processed.
-- When an edge leads back to a node that is still on the path, the cycle is
-- stored in `cycle` (first and last element are the same node) and the
-- traversal unwinds immediately without touching `order`.
local function _dfs(u, path)
    visited[u] = true
    in_stack[u] = true
    path[#path + 1] = u

    local successors = adj[u] or {}
    for idx = 1, #successors do
        local nxt = successors[idx]
        if visited[nxt] then
            if in_stack[nxt] then
                -- Back edge: walk the path backwards to where the loop starts.
                local first = #path
                while first > 1 and path[first] ~= nxt do
                    first = first - 1
                end
                local loop = {}
                for i = first, #path do
                    loop[#loop + 1] = path[i]
                end
                loop[#loop + 1] = nxt -- close the cycle
                cycle = loop
                return
            end
        else
            _dfs(nxt, path)
            if cycle then
                return
            end
        end
    end

    -- u is finished: take it off the path and put it in front of everything
    -- that was finished earlier.
    in_stack[u] = false
    path[#path] = nil
    table.insert(order, 1, u)
end
 | 
						|
 | 
						|
-- Topologically sort the graph in `adj`.
-- Returns a table mapping each reached node to its 1-based position in
-- `order`. If a cycle is detected, returns the (still empty) rank table
-- followed by the cycle as a second value.
local function topo_sort()
    local rank = {}

    -- Start a DFS from every node that has outgoing edges and has not been
    -- reached yet.
    for node in pairs(adj) do
        if not visited[node] then
            _dfs(node, {})
            if cycle then
                return rank, cycle
            end
        end
    end

    -- Turn the ordered list into a node -> position lookup.
    for position = 1, #order do
        rank[order[position]] = position
    end
    return rank
end
 | 
						|
 | 
						|
-- Split a comma-separated string into its items.
-- Whitespace around each item is stripped; items that are empty or consist
-- only of whitespace are dropped. Returns a sequence of strings.
local function words(s)
    local items = {}
    for chunk in s:gmatch("[^,]+") do
        -- Skip leading whitespace, then capture up to the last
        -- non-whitespace character; no match means the chunk is blank.
        local item = chunk:match("^%s*(.*%S)")
        if item then
            items[#items + 1] = item
        end
    end
    return items
end
 | 
						|
 | 
						|
-- Depth-first scan of one block, building the dependency graph.
--   blk:   the block to inspect
--   stack: identifiers of the labelled blocks currently enclosing blk
-- Two kinds of block matter:
--   * include-nodes (blocks carrying an "include" attribute), which are
--     treated as leaves, and
--   * labelled blocks (non-empty identifier), which are pushed on `stack`
--     while their children are scanned.
local function dfs(blk, stack)
    local attrs = blk.attributes
    local spec = attrs and attrs["include"]

    if spec then
        -- Include-node. The node is named by its own identifier, so every
        -- include-node is expected to carry a label.
        local id = blk.identifier
        table.insert(include_nodes, id)
        -- The labels it lists may only show up later in the scan; the edges
        -- are added now on the assumption that each of them exists.
        for _, target in ipairs(words(spec)) do
            add_edge(id, target)
        end
        -- Every enclosing labelled block also points to this include-node.
        for _, ancestor in ipairs(stack) do
            add_edge(ancestor, id)
        end
        -- Leaf: nothing was pushed and children are not scanned.
        return
    end

    local id = blk.identifier
    local pushed = false
    if id and id ~= "" then
        stack[#stack + 1] = id
        pushed = true
    end

    -- Only Divs are descended into. For the block types see
    -- https://hackage-content.haskell.org/package/pandoc-types-1.23.1/docs/Text-Pandoc-Definition.html#t:Block
    if blk.t == 'Div' then
        for _, child in ipairs(blk.content) do
            dfs(child, stack)
        end
    end

    -- Undo the push made for this block.
    if pushed then
        stack[#stack] = nil
    end
end
 | 
						|
 | 
						|
-- Build the replacement for one include-node.
--   doc: the document to look the labels up in
--   l:   the include-node; its "include" attribute lists the wanted labels
-- The include list is ordered, so the document is traversed once more to
-- index every labelled Div, and the requested ones are then put together in
-- the order they are listed. Each inserted copy has its identifier cleared.
-- A label that cannot be found is reported on stderr and skipped.
-- Returns a new Div holding the copies and carrying l's attr.
local function collect_node(doc, l)
    local wanted = words(l.attributes["include"])

    -- Index all labelled Divs by identifier (a later Div with the same
    -- identifier overwrites an earlier one).
    local by_label = {}
    doc:walk({
        Div = function(div)
            local id = div.identifier
            if id and id ~= "" then
                by_label[id] = div:clone()
            end
            return nil
        end,
    })

    local gathered = {}
    for i = 1, #wanted do
        local label = wanted[i]
        local found = by_label[label]
        if found then
            local copy = found:clone()
            copy.identifier = ""
            gathered[#gathered + 1] = copy
        else
            io.stderr:write("Cannot find AST node with label " .. label .. "\n")
        end
    end
    return pandoc.Div(gathered, l.attr)
end
 | 
						|
 | 
						|
-- The filter: expands every include-node of the document in dependency order.
return {
    -- traverse = 'topdown',
    -- Pandoc(doc): returns the document with each include-node replaced by
    -- the Div built by collect_node. Raises an error if the include
    -- relations contain a cycle.
    Pandoc = function(doc)
        -- collect labels & build the graph
        for _, blk in ipairs(doc.blocks) do
            dfs(blk, {})
        end

        -- topological sort
        local rank, cycle = topo_sort()
        if cycle then
            error("Cycle detected:" .. table.concat(cycle, " -> ") .. "\n")
        end

        -- Expand the include-nodes with the highest rank first, so that a
        -- node's targets are already expanded when the node itself is.
        -- An include-node without any edge (empty include list and no
        -- labelled ancestor) never gets a rank; default it to 0 so that the
        -- comparator does not compare nil.
        table.sort(include_nodes, function(x, y)
            return (rank[x] or 0) > (rank[y] or 0)
        end)
        for _, v in ipairs(include_nodes) do
            doc = doc:walk { Div = function(div)
                -- Also require the include attribute: an include-node with
                -- an empty identifier would otherwise match every unlabelled
                -- Div, and collect_node fails on a Div without that
                -- attribute.
                if div.identifier == v and div.attributes["include"] then
                    return collect_node(doc, div)
                end
            end }
        end
        return doc
    end
}
 |