Skip to main content

GraphDsl.h File

Lightweight DSL helpers for wiring graphs. More...

Included Headers

#include "graph/Graph.h"
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

Namespaces Index

namespace simaai
namespace neat
namespace graph
namespace dsl

Classes Index

struct PortRef
struct NodeRef

Description

Lightweight DSL helpers for wiring graphs.

File Listing

The file content with the documentation metadata removed is:

1
6#pragma once
7
8#include "graph/Graph.h"
9
10#include <stdexcept>
11#include <string>
12#include <utility>
13#include <vector>
14
16
17struct PortRef {
18 Graph* g = nullptr;
19 NodeId node = kInvalidNode;
20 PortId port = kInvalidPort;
21 bool is_output = true;
22};
23
24struct NodeRef {
25 Graph* g = nullptr;
26 NodeId id = kInvalidNode;
27
28 operator NodeId() const {
29 return id;
30 }
31
32 PortRef out() const;
33 PortRef out(const std::string& name) const;
34 PortRef in() const;
35 PortRef in(const std::string& name) const;
36 PortRef operator[](const std::string& name) const;
37 PortRef operator[](const char* name) const {
38 return (*this)[std::string(name)];
39 }
40};
41
42inline NodeRef ref(Graph& g, NodeId id) {
43 return NodeRef{&g, id};
44}
45
46inline NodeRef add(Graph& g, Graph::NodePtr node) {
47 return NodeRef{&g, g.add(std::move(node))};
48}
49
50inline bool has_port(const std::vector<PortDesc>& ports, const std::string& name) {
51 for (const auto& p : ports) {
52 if (p.name == name)
53 return true;
54 }
55 return false;
56}
57
58inline void ensure_graph(const Graph* g, const char* what) {
59 if (!g)
60 throw std::runtime_error(std::string("GraphDsl: null graph in ") + what);
61}
62
63inline const std::shared_ptr<Node>& get_node(const Graph* g, NodeId id, const char* what) {
64 if (!g)
65 throw std::runtime_error(std::string("GraphDsl: null graph in ") + what);
66 return g->node(id);
67}
68
69inline PortRef NodeRef::out() const {
70 const auto& node = get_node(g, id, "out()");
71 const auto ports = node->output_ports();
72 if (ports.size() != 1) {
73 throw std::runtime_error("GraphDsl: out() requires exactly one output port");
74 }
75 return out(ports.front().name);
76}
77
78inline PortRef NodeRef::out(const std::string& name) const {
79 const auto& node = get_node(g, id, "out(name)");
80 const auto ports = node->output_ports();
81 if (!has_port(ports, name)) {
82 throw std::runtime_error("GraphDsl: unknown output port: " + name);
83 }
84 return PortRef{g, id, g->intern_port(name), true};
85}
86
87inline PortRef NodeRef::in() const {
88 const auto& node = get_node(g, id, "in()");
89 const auto ports = node->input_ports();
90 if (ports.size() != 1) {
91 throw std::runtime_error("GraphDsl: in() requires exactly one input port");
92 }
93 return in(ports.front().name);
94}
95
96inline PortRef NodeRef::in(const std::string& name) const {
97 const auto& node = get_node(g, id, "in(name)");
98 const auto ports = node->input_ports();
99 if (!has_port(ports, name)) {
100 throw std::runtime_error("GraphDsl: unknown input port: " + name);
101 }
102 return PortRef{g, id, g->intern_port(name), false};
103}
104
105inline PortRef NodeRef::operator[](const std::string& name) const {
106 const auto& node = get_node(g, id, "operator[]");
107 const auto in_ports = node->input_ports();
108 const auto out_ports = node->output_ports();
109 const bool has_in = has_port(in_ports, name);
110 const bool has_out = has_port(out_ports, name);
111
112 if (has_out && !has_in)
113 return out(name);
114 if (has_in && !has_out)
115 return in(name);
116
117 if (!has_in && !has_out) {
118 throw std::runtime_error("GraphDsl: unknown port: " + name);
119 }
120 throw std::runtime_error("GraphDsl: ambiguous port name (use .in or .out): " + name);
121}
122
123inline void connect_ports(const PortRef& from, const PortRef& to) {
124 ensure_graph(from.g, "connect_ports(from)");
125 ensure_graph(to.g, "connect_ports(to)");
126 if (from.g != to.g) {
127 throw std::runtime_error("GraphDsl: cannot connect ports from different graphs");
128 }
129 if (!from.is_output) {
130 throw std::runtime_error("GraphDsl: left side is not an output port");
131 }
132 if (to.is_output) {
133 throw std::runtime_error("GraphDsl: right side is not an input port");
134 }
135 from.g->connect(from.node, to.node, from.g->port_name(from.port), to.g->port_name(to.port));
136}
137
138inline NodeRef operator>>(const NodeRef& from, const NodeRef& to) {
139 connect_ports(from.out(), to.in());
140 return to;
141}
142
143inline NodeRef operator>>(const PortRef& from, const NodeRef& to) {
144 connect_ports(from, to.in());
145 return to;
146}
147
148inline NodeRef operator>>(const NodeRef& from, const PortRef& to) {
149 connect_ports(from.out(), to);
150 return NodeRef{to.g, to.node};
151}
152
153inline NodeRef operator>>(const PortRef& from, const PortRef& to) {
154 connect_ports(from, to);
155 return NodeRef{to.g, to.node};
156}
157
158} // namespace simaai::neat::graph::dsl

Generated via doxygen2docusaurus 2.0.0 by Doxygen 1.9.1.