diff --git a/filter.lua b/filter.lua index 492096e..4abbaa7 100644 --- a/filter.lua +++ b/filter.lua @@ -13,6 +13,57 @@ local function add_edge(u, v) table.insert(adj[u], v) end +--------------------------------------------------------- +-- Topological sort (DFS-based, with cycle detection) +--------------------------------------------------------- +local visited = {} -- "perm" mark: node completely processed +local in_stack = {} -- "temp" mark: node currently in recursion stack +local order = {} -- reverse topological order +local cycle = nil -- store cycle if found + +local function _dfs(u, path) + visited[u] = true + in_stack[u] = true + table.insert(path, u) + + for _, v in ipairs(adj[u] or {}) do + if not visited[v] then + _dfs(v, path) + if cycle then return end + elseif in_stack[v] then + -- found a cycle + cycle = {} + -- extract the cycle part from path + for i = #path, 1, -1 do + table.insert(cycle, 1, path[i]) + if path[i] == v then break end + end + table.insert(cycle, v) -- close the cycle + return + end + end + + in_stack[u] = false + table.remove(path) + table.insert(order, 1, u) +end + +local function topo_sort() + for u, _ in pairs(adj) do + if not visited[u] then + _dfs(u, {}) + if cycle then return nil, cycle end + end + end + + -- build rank map + local rank = {} + for i, u in ipairs(order) do + rank[u] = i + end + return rank +end + local function collect_labels(blk) if blk.identifier and blk.identifier ~= "" then label_map[blk.identifier] = blk:clone() @@ -100,6 +151,7 @@ return { dfs(blk, {}) end + -- sanity check show("edges:") for u, vs in pairs(adj) do for _, v in ipairs(vs) do @@ -107,10 +159,22 @@ return { end end - -- replace - for i, blk in ipairs(doc.blocks) do - doc.blocks[i] = replace(blk) + -- topological sort + local rank, cycle = topo_sort() + if cycle then + error("Cycle detected:" .. table.concat(cycle, " -> ") .. "\n") end + + -- sanity check + show("ranks:") + for k, v in pairs(rank) do + show(k .. "--" .. v) + end + + -- replace + -- for i, blk in ipairs(doc.blocks) do + -- doc.blocks[i] = replace(blk) + -- end return doc end } diff --git a/output.html b/output.html index cbdddac..308a0a0 100644 --- a/output.html +++ b/output.html @@ -4,12 +4,7 @@
thm2 - inthm2 (need thm1)
test thm2
-i need theorem 1
-This line will be ignored
test thm1
diff --git a/test.lua b/test.lua new file mode 100644 index 0000000..ad0de98 --- /dev/null +++ b/test.lua @@ -0,0 +1,79 @@ +-- adjacency list: adj[u] = {v1, v2, ...} +local adj = {} + +local function add_edge(u, v) + adj[u] = adj[u] or {} + table.insert(adj[u], v) +end + +-- Example graph +add_edge("A", "B") +add_edge("B", "C") +add_edge("C", "D") +-- Uncomment next line to introduce a cycle: +add_edge("A", "D") + +--------------------------------------------------------- +-- Topological sort (DFS-based, with cycle detection) +--------------------------------------------------------- +local visited = {} -- "perm" mark: node completely processed +local in_stack = {} -- "temp" mark: node currently in recursion stack +local order = {} -- reverse topological order +local cycle = nil -- store cycle if found + +local function dfs(u, path) + visited[u] = true + in_stack[u] = true + table.insert(path, u) + + for _, v in ipairs(adj[u] or {}) do + if not visited[v] then + dfs(v, path) + if cycle then return end + elseif in_stack[v] then + -- found a cycle + cycle = {} + -- extract the cycle part from path + for i = #path, 1, -1 do + table.insert(cycle, 1, path[i]) + if path[i] == v then break end + end + table.insert(cycle, v) -- close the cycle + return + end + end + + in_stack[u] = false + table.remove(path) + table.insert(order, 1, u) +end + +local function topo_sort() + for u, _ in pairs(adj) do + if not visited[u] then + dfs(u, {}) + if cycle then return nil, cycle end + end + end + + -- build rank map + local rank = {} + for i, u in ipairs(order) do + rank[u] = i + end + return rank +end + +--------------------------------------------------------- +-- Run it +--------------------------------------------------------- +local rank, cycle = topo_sort() +if cycle then + print("Cycle detected:") + print(table.concat(cycle, " -> ")) +else + print("Topological order ranks:") + for k, v in pairs(rank) do + print(k, v) + end +end