Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Atraxus committed Feb 12, 2025
1 parent 02bf34a commit 1efb357
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 183 deletions.
169 changes: 86 additions & 83 deletions Intern/rayx-core/src/Beamline/Beamline.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "Beamline.h"

#include <stack>
#include <stdexcept>
#include <sstream>

Expand All @@ -11,14 +12,14 @@ namespace RAYX {

// Move constructor
Group::Group(Group&& other) noexcept
: m_position(std::move(other.m_position)), m_orientation(std::move(other.m_orientation)), children(std::move(other.children)) {}
: m_position(std::move(other.m_position)), m_orientation(std::move(other.m_orientation)), m_children(std::move(other.m_children)) {}

// Move assignment operator
Group& Group::operator=(Group&& other) noexcept {
if (this != &other) {
m_position = std::move(other.m_position);
m_orientation = std::move(other.m_orientation);
children = std::move(other.children);
m_children = std::move(other.m_children);
}
return *this;
}
Expand All @@ -29,7 +30,7 @@ Group Group::clone() const {
copy.setPosition(m_position);
copy.setOrientation(m_orientation);
// For each child, clone it and add a new shared pointer.
for (const auto& child : children) {
for (const auto& child : m_children) {
std::visit(
[&copy](auto& ptr) {
using T = std::decay_t<decltype(ptr)>;
Expand All @@ -46,11 +47,21 @@ Group Group::clone() const {
return copy;
}

template <typename Callback>
void Group::traverse(Callback&& callback) const {
for (const auto& child : m_children) {
callback(child);
if (auto* groupPtr = std::get_if<std::unique_ptr<Group>>(&child)) {
if (*groupPtr) (*groupPtr)->traverse(std::forward<Callback>(callback));
}
}
}

// A Group is always a Group.
NodeType Group::getNodeType() const { return NodeType::Group; }

// Add a child node.
void Group::addChild(BeamlineNode&& child) { children.push_back(std::move(child)); }
void Group::addChild(BeamlineNode&& child) { m_children.push_back(std::move(child)); }

MaterialTables Group::calcMinimalMaterialTables() const {
auto elements = getElements();
Expand All @@ -70,100 +81,119 @@ void Group::accumulateLightSourcesWorldPositions(const Group& group, const glm::
glm::dvec4 currentPos = parentOri * group.getPosition() + parentPos;
glm::dmat4 currentOri = parentOri * group.getOrientation();

traverseGroup(group, [&](const BeamlineNode& node) {
if (std::holds_alternative<std::shared_ptr<DesignSource>>(node)) {
const auto& src = std::get<std::shared_ptr<DesignSource>>(node);
glm::dvec4 worldPos = currentOri * src->getPosition() + currentPos;
positions.push_back(worldPos);
} else if (std::holds_alternative<std::shared_ptr<Group>>(node)) {
const auto& childGroup = std::get<std::shared_ptr<Group>>(node);
accumulateLightSourcesWorldPositions(*childGroup, currentPos, currentOri, positions);
for (const auto& child : group) {
if (std::holds_alternative<std::unique_ptr<DesignSource>>(child)) {
auto& srcPtr = std::get<std::unique_ptr<DesignSource>>(child);
if (srcPtr) {
glm::dvec4 worldPos = currentOri * srcPtr->getPosition() + currentPos;
positions.push_back(worldPos);
}
} else if (std::holds_alternative<std::unique_ptr<Group>>(child)) {
auto& childGroupPtr = std::get<std::unique_ptr<Group>>(child);
if (childGroupPtr) {
accumulateLightSourcesWorldPositions(*childGroupPtr, currentPos, currentOri, positions);
}
}
// DesignElements are ignored.
});
}
}

std::vector<OpticalElement> Group::compileElements() const {
std::vector<OpticalElement> elements;

auto recurse = [&](auto& self, const Group& grp, const glm::dvec4& parentPos, const glm::dmat4& parentOri) -> void {
glm::dvec4 thisGroupPos = parentOri * grp.getPosition() + parentPos;
glm::dmat4 thisGroupOri = parentOri * grp.getOrientation();
for (const auto& child : grp.children) {
if (std::holds_alternative<std::shared_ptr<DesignElement>>(child)) {
const auto& de = std::get<std::shared_ptr<DesignElement>>(child);
elements.push_back(de->compile(thisGroupPos, thisGroupOri));
} else if (std::holds_alternative<std::shared_ptr<Group>>(child)) {
self(self, *std::get<std::shared_ptr<Group>>(child), thisGroupPos, thisGroupOri);
}
// Sources are ignored for optical element compilation.

for (const auto& child : grp.m_children) { // For each child...
if (std::holds_alternative<std::unique_ptr<DesignElement>>(child)) {
const auto& dePtr = std::get<std::unique_ptr<DesignElement>>(child);
if (dePtr) {
// Compile an OpticalElement from the DesignElement
elements.push_back(dePtr->compile(thisGroupPos, thisGroupOri));
}
} else if (std::holds_alternative<std::unique_ptr<Group>>(child)) {
const auto& groupPtr = std::get<std::unique_ptr<Group>>(child);
if (groupPtr) {
// Recurse into the child group
self(self, *groupPtr, thisGroupPos, thisGroupOri);
}
} // Ignore DesignSources
}
};

// Start recursion at this group
recurse(recurse, *this, glm::dvec4(0, 0, 0, 1), glm::dmat4(1.0));
return elements;
}

std::vector<Ray> Group::compileSources(int thread_count) const {
RAYX_PROFILE_FUNCTION_STDOUT();
std::vector<Ray> rays;

auto traverseWithTransforms = [&rays, thread_count](const Group& group, const glm::dvec4& parentPosition, const glm::dmat4& parentOrientation,
const auto& self) -> void {
// Compute the transform for this group
glm::dvec4 currentPosition = parentOrientation * group.getPosition() + parentPosition;
glm::dmat4 currentOrientation = parentOrientation * group.getOrientation();

for (const auto& child : group.children) {
if (std::holds_alternative<std::shared_ptr<DesignSource>>(child)) {
const auto& src = std::get<std::shared_ptr<DesignSource>>(child);
auto sourceRays = src->compile(thread_count, currentPosition, currentOrientation);
for (auto& ray : sourceRays) {
ray.m_sourceID = static_cast<uint32_t>(rays.size());
for (const auto& child : group.m_children) { // For each child...
if (std::holds_alternative<std::unique_ptr<DesignSource>>(child)) {
const auto& srcPtr = std::get<std::unique_ptr<DesignSource>>(child);
if (srcPtr) {
// Compile the rays for this source
auto sourceRays = srcPtr->compile(thread_count, currentPosition, currentOrientation);
for (auto& ray : sourceRays) {
ray.m_sourceID = static_cast<uint32_t>(rays.size());
}
rays.insert(rays.end(), sourceRays.begin(), sourceRays.end());
}
} else if (std::holds_alternative<std::unique_ptr<Group>>(child)) {
const auto& childGroupPtr = std::get<std::unique_ptr<Group>>(child);
if (childGroupPtr) {
self(*childGroupPtr, currentPosition, currentOrientation, self);
}
rays.insert(rays.end(), sourceRays.begin(), sourceRays.end());
} else if (std::holds_alternative<std::shared_ptr<Group>>(child)) {
self(*std::get<std::shared_ptr<Group>>(child), currentPosition, currentOrientation, self);
}
}
};

traverseWithTransforms(*this, glm::dvec4(0, 0, 0, 1), glm::dmat4(1), traverseWithTransforms);
return rays;
}

// Retrieve all DesignElements (deep)
std::vector<std::shared_ptr<DesignElement>> Group::getElements() const {
std::vector<std::shared_ptr<DesignElement>> elements;
traverseGroup(*this, [&elements](const BeamlineNode& node) {
if (std::holds_alternative<std::shared_ptr<DesignElement>>(node)) {
elements.push_back(std::get<std::shared_ptr<DesignElement>>(node));
std::vector<DesignElement*> Group::getElements() const {
std::vector<DesignElement*> elements;
for (const auto& node : m_children) {
if (std::holds_alternative<std::unique_ptr<DesignElement>>(node)) {
elements.push_back(std::get<std::unique_ptr<DesignElement>>(node).get());
}
});
}
return elements;
}

// Retrieve all DesignSources (deep)
std::vector<std::shared_ptr<DesignSource>> Group::getSources() const {
std::vector<std::shared_ptr<DesignSource>> sources;
traverseGroup(*this, [&sources](const BeamlineNode& node) {
if (std::holds_alternative<std::shared_ptr<DesignSource>>(node)) {
sources.push_back(std::get<std::shared_ptr<DesignSource>>(node));
std::vector<DesignSource*> Group::getSources() const {
std::vector<DesignSource*> sources;
for (const auto& node : m_children) {
if (std::holds_alternative<std::unique_ptr<DesignSource>>(node)) {
sources.push_back(std::get<std::unique_ptr<DesignSource>>(node).get());
}
});
}
return sources;
}

// Retrieve all Groups (deep)
std::vector<std::shared_ptr<Group>> Group::getGroups() const {
std::vector<std::shared_ptr<Group>> groups;
traverseGroup(*this, [&groups](const BeamlineNode& node) {
if (std::holds_alternative<std::shared_ptr<Group>>(node)) {
groups.push_back(std::get<std::shared_ptr<Group>>(node));
std::vector<Group*> Group::getGroups() const {
std::vector<Group*> groups;
for (const auto& node : m_children) {
if (std::holds_alternative<std::unique_ptr<Group>>(node)) {
groups.push_back(std::get<std::unique_ptr<Group>>(node).get());
}
});
}
return groups;
}

size_t Group::numElements() const {
size_t count = 0;
traverseGroup(*this, [&count](const BeamlineNode& node) {
if (std::holds_alternative<std::shared_ptr<DesignElement>>(node)) {
traverse([&count](const BeamlineNode& node) {
if (std::holds_alternative<std::unique_ptr<DesignElement>>(node)) {
++count;
}
});
Expand All @@ -172,39 +202,12 @@ size_t Group::numElements() const {

size_t Group::numSources() const {
size_t count = 0;
traverseGroup(*this, [&count](const BeamlineNode& node) {
if (std::holds_alternative<std::shared_ptr<DesignSource>>(node)) {
traverse([&count](const BeamlineNode& node) {
if (std::holds_alternative<std::unique_ptr<DesignSource>>(node)) {
++count;
}
});
return count;
}

// Non‑const overload.
template <typename Callback>
void traverseGroup(Group& group, Callback&& callback) {
// Iterate over the mutable children.
for (auto& child : group.getChildren()) {
callback(child);
// If the child is a Group, then recursively traverse it.
if (std::holds_alternative<std::shared_ptr<Group>>(child)) {
traverseGroup(*std::get<std::shared_ptr<Group>>(child), callback);
}
}
}

// Const overload.
template <typename Callback>
void traverseGroup(const Group& group, Callback&& callback) {
// Iterate over the children (as read‑only).
for (const auto& child : group.getChildren()) {
callback(child);
// Even in a const context, our variant is defined as holding std::shared_ptr<Group>.
// We can still call traverseGroup on the pointed-to Group.
if (std::holds_alternative<std::shared_ptr<Group>>(child)) {
traverseGroup(*std::get<std::shared_ptr<Group>>(child), callback);
}
}
}

} // namespace RAYX
51 changes: 29 additions & 22 deletions Intern/rayx-core/src/Beamline/Beamline.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
namespace RAYX {

class Group;
using BeamlineNode = std::variant<std::shared_ptr<DesignElement>, std::shared_ptr<DesignSource>, std::shared_ptr<Group>>;
using BeamlineNode = std::variant<std::unique_ptr<DesignElement>, std::unique_ptr<DesignSource>, std::unique_ptr<Group>>;
enum class NodeType { OpticalElement, LightSource, Group };

class RAYX_API Group {
Expand All @@ -29,25 +29,38 @@ class RAYX_API Group {
// Clone returns a deep copy of the group and its children.
Group clone() const;

template <typename Callback>
void traverse(Callback&& callback) const;

// Iterators for non-const access
auto begin() { return m_children.begin(); }
auto end() { return m_children.end(); }

// Iterators for const access
auto begin() const { return m_children.cbegin(); }
auto end() const { return m_children.cend(); }
auto cbegin() const { return m_children.cbegin(); }
auto cend() const { return m_children.cend(); }

NodeType getNodeType() const;
const std::vector<BeamlineNode>& getChildren() const { return children; }
std::vector<BeamlineNode>& getChildren() { return m_children; }

// Add a child (by move).
void addChild(BeamlineNode&& child);

// Other member functions.
MaterialTables calcMinimalMaterialTables() const;
static void accumulateLightSourcesWorldPositions(const Group& group, const glm::dvec4& parentPos, const glm::dmat4& parentOri,
std::vector<glm::dvec4>& positions);
std::vector<OpticalElement> compileElements() const;
std::vector<Ray> compileSources(int thread_count = 1) const;
// Getters returning raw pointers. This follows more the old way of handling beamlines and is to be used with care.
std::vector<DesignElement*> getElements() const;
std::vector<DesignSource*> getSources() const;
std::vector<Group*> getGroups() const;

// Getters returning smart pointers.
std::vector<std::shared_ptr<DesignElement>> getElements() const;
std::vector<std::shared_ptr<DesignSource>> getSources() const;
std::vector<std::shared_ptr<Group>> getGroups() const;
// Helper
size_t numElements() const;
size_t numSources() const;
MaterialTables calcMinimalMaterialTables() const;
std::vector<OpticalElement> compileElements() const;
std::vector<Ray> compileSources(int thread_count = 1) const;
static void accumulateLightSourcesWorldPositions(const Group& group, const glm::dvec4& parentPos, const glm::dmat4& parentOri,
std::vector<glm::dvec4>& positions);

// Getters & setters for transforms.
const glm::dvec4& getPosition() const { return m_position; }
Expand All @@ -58,26 +71,20 @@ class RAYX_API Group {
private:
glm::dvec4 m_position = glm::dvec4(0, 0, 0, 1);
glm::dmat4 m_orientation = glm::dmat4(1);
std::vector<BeamlineNode> children;
std::vector<BeamlineNode> m_children;
};

using Beamline = Group; // Conceptually, a Beamline is a Group

template <typename Callback>
void traverseGroup(Group& group, Callback&& callback);
template <typename Callback>
void traverseGroup(const Group& group, Callback&& callback);

// Utility function to determine node type.
inline NodeType getNodeType(const BeamlineNode& node) {
return std::visit(
[](auto&& element) -> NodeType {
using T = std::decay_t<decltype(element)>;
if constexpr (std::is_same_v<T, std::shared_ptr<DesignElement>>) {
if constexpr (std::is_same_v<T, std::unique_ptr<DesignElement>>) {
return NodeType::OpticalElement;
} else if constexpr (std::is_same_v<T, std::shared_ptr<DesignSource>>) {
} else if constexpr (std::is_same_v<T, std::unique_ptr<DesignSource>>) {
return NodeType::LightSource;
} else if constexpr (std::is_same_v<T, std::shared_ptr<Group>>) {
} else if constexpr (std::is_same_v<T, std::unique_ptr<Group>>) {
return NodeType::Group;
}
},
Expand Down
6 changes: 6 additions & 0 deletions Intern/rayx-core/src/Design/DesignElement.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,15 @@ struct RAYX_API DesignElement {
void setPosition(glm::dvec4 p);
glm::dvec4 getPosition() const;

void setWorldPosition(glm::dvec4 p);
glm::dvec4 getWorldPosition() const;

void setOrientation(glm::dmat4x4 o);
glm::dmat4x4 getOrientation() const;

void setWorldOrientation(glm::dmat4x4 o);
glm::dmat4x4 getWorldOrientation() const;

void setMisalignment(Misalignment m);
Misalignment getMisalignment() const;

Expand Down
Loading

0 comments on commit 1efb357

Please sign in to comment.