Skip to content

Commit

Permalink
Polymorphism. May god have mercy
Browse files Browse the repository at this point in the history
  • Loading branch information
Atraxus committed Feb 14, 2025
1 parent baf74f7 commit cb40102
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 133 deletions.
121 changes: 48 additions & 73 deletions Intern/rayx-core/src/Beamline/Beamline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,12 @@ Group& Group::operator=(Group&& other) noexcept {
}

// Deep-copy (clone) implementation.
Group Group::clone() const {
Group copy;
copy.setPosition(m_position);
copy.setOrientation(m_orientation);
// For each child, clone it and add a new shared pointer.
std::unique_ptr<BeamlineNode> Group::clone() const {
auto copy = std::make_unique<Group>();
copy->setPosition(m_position);
copy->setOrientation(m_orientation);
for (const auto& child : m_children) {
std::visit(
[&copy](auto& ptr) {
using T = std::decay_t<decltype(ptr)>;
if constexpr (std::is_same_v<T, std::shared_ptr<DesignElement>>) {
copy.addChild(BeamlineNode(std::make_shared<DesignElement>(ptr->clone())));
} else if constexpr (std::is_same_v<T, std::shared_ptr<DesignSource>>) {
copy.addChild(BeamlineNode(std::make_shared<DesignSource>(ptr->clone())));
} else if constexpr (std::is_same_v<T, std::shared_ptr<Group>>) {
copy.addChild(BeamlineNode(std::make_shared<Group>(ptr->clone())));
}
},
child->data);
copy->addChild(child->clone());
}
return copy;
}
Expand All @@ -51,14 +39,18 @@ template <typename Callback>
void Group::traverse(Callback&& callback) const {
for (const auto& child : m_children) {
callback(*child);
if (const auto* groupPtr = std::get_if<Group>(&(child->data))) {
if (groupPtr) groupPtr->traverse(std::forward<Callback>(callback));
if (child->isGroup()) {
auto groupPtr = static_cast<Group*>(child.get());
groupPtr->traverse(std::forward<Callback>(callback));
}
}
}

// Add a child node.
void Group::addChild(BeamlineNode&& child) { m_children.push_back(std::make_unique<BeamlineNode>(std::move(child))); }
void Group::addChild(std::unique_ptr<BeamlineNode> child) {
if (!child.get()) return;
m_children.push_back(std::move(child));
}

MaterialTables Group::calcMinimalMaterialTables() const {
auto elements = getElements();
Expand All @@ -79,16 +71,14 @@ void Group::accumulateLightSourcesWorldPositions(const Group& group, const glm::
glm::dmat4 currentOri = parentOri * group.getOrientation();

for (const auto& child : group) {
if (std::holds_alternative<DesignSource>(child->data)) {
const auto* srcPtr = &std::get<DesignSource>(child->data);
if (srcPtr) {
glm::dvec4 worldPos = currentOri * srcPtr->getPosition() + currentPos;
positions.push_back(worldPos);
}
} else if (std::holds_alternative<Group>(child->data)) {
const auto* childGroupPtr = &std::get<Group>(child->data);
if (childGroupPtr) {
accumulateLightSourcesWorldPositions(*childGroupPtr, currentPos, currentOri, positions);
if (child->isSource()) {
DesignSource* srcPtr = dynamic_cast<DesignSource*>(child.get());
glm::dvec4 worldPos = currentOri * srcPtr->getPosition() + currentPos;
positions.push_back(worldPos);
} else if (child->isGroup()) {
Group* grpPtr = dynamic_cast<Group*>(child.get());
if (grpPtr) {
accumulateLightSourcesWorldPositions(*grpPtr, currentPos, currentOri, positions);
}
}
}
Expand All @@ -101,19 +91,13 @@ std::vector<OpticalElement> Group::compileElements() const {
glm::dvec4 thisGroupPos = parentOri * grp.getPosition() + parentPos;
glm::dmat4 thisGroupOri = parentOri * grp.getOrientation();

for (const auto& child : grp.m_children) { // For each child...
if (std::holds_alternative<DesignElement>(child->data)) {
const auto* dePtr = &std::get<DesignElement>(child->data);
if (dePtr) {
// Compile an OpticalElement from the DesignElement
elements.push_back(dePtr->compile(thisGroupPos, thisGroupOri));
}
} else if (std::holds_alternative<Group>(child->data)) {
const auto* groupPtr = &std::get<Group>(child->data);
if (groupPtr) {
// Recurse into the child group
self(self, *groupPtr, thisGroupPos, thisGroupOri);
}
for (const auto& child : grp.m_children) {
if (child->isElement()) {
const auto* dePtr = static_cast<DesignElement*>(child.get());
elements.push_back(dePtr->compile(thisGroupPos, thisGroupOri));
} else if (child->isGroup()) {
const auto* groupPtr = static_cast<Group*>(child.get());
self(self, *groupPtr, thisGroupPos, thisGroupOri);
} // Ignore DesignSources
}
};
Expand All @@ -134,21 +118,16 @@ std::vector<Ray> Group::compileSources(int thread_count) const {
glm::dmat4 currentOrientation = parentOrientation * group.getOrientation();

for (const auto& child : group.m_children) { // For each child...
if (std::holds_alternative<DesignSource>(child->data)) {
const auto* srcPtr = &std::get<DesignSource>(child->data);
if (srcPtr) {
// Compile the rays for this source
auto sourceRays = srcPtr->compile(thread_count, currentPosition, currentOrientation);
for (auto& ray : sourceRays) {
ray.m_sourceID = static_cast<uint32_t>(rays.size());
}
rays.insert(rays.end(), sourceRays.begin(), sourceRays.end());
}
} else if (std::holds_alternative<Group>(child->data)) {
const auto* childGroupPtr = &std::get<Group>(child->data);
if (childGroupPtr) {
self(*childGroupPtr, currentPosition, currentOrientation, self);
if (child->isSource()) {
const auto* srcPtr = static_cast<DesignSource*>(child.get());
auto sourceRays = srcPtr->compile(thread_count, currentPosition, currentOrientation);
for (auto& ray : sourceRays) {
ray.m_sourceID = static_cast<uint32_t>(rays.size());
}
rays.insert(rays.end(), sourceRays.begin(), sourceRays.end());
} else if (child->isGroup()) {
const auto* childGroupPtr = static_cast<Group*>(child.get());
self(*childGroupPtr, currentPosition, currentOrientation, self);
}
}
};
Expand All @@ -160,8 +139,11 @@ std::vector<Ray> Group::compileSources(int thread_count) const {
std::vector<DesignElement*> Group::getElements() const {
std::vector<DesignElement*> elements;
for (const auto& node : m_children) {
if (std::holds_alternative<DesignElement>(node->data)) {
elements.push_back(&std::get<DesignElement>(node->data));
if (node->isElement()) {
elements.push_back(static_cast<DesignElement*>(node.get()));
} else if (node->isGroup()) {
auto nodeElements = static_cast<Group*>(node.get())->getElements();
elements.insert(elements.end(), nodeElements.begin(), nodeElements.end());
}
}
return elements;
Expand All @@ -170,27 +152,20 @@ std::vector<DesignElement*> Group::getElements() const {
std::vector<DesignSource*> Group::getSources() const {
std::vector<DesignSource*> sources;
for (const auto& node : m_children) {
if (std::holds_alternative<DesignSource>(node->data)) {
sources.push_back(&std::get<DesignSource>(node->data));
if (node->isSource()) {
sources.push_back(static_cast<DesignSource*>(node.get()));
} else if (node->isGroup()) {
auto nodeElements = static_cast<Group*>(node.get())->getSources();
sources.insert(sources.end(), nodeElements.begin(), nodeElements.end());
}
}
return sources;
}

std::vector<Group*> Group::getGroups() const {
std::vector<Group*> groups;
for (const auto& node : m_children) {
if (std::holds_alternative<Group>(node->data)) {
groups.push_back(&std::get<Group>(node->data));
}
}
return groups;
}

size_t Group::numElements() const {
size_t count = 0;
traverse([&count](const BeamlineNode& node) {
if (std::holds_alternative<DesignElement>(node.data)) {
if (node.isElement()) {
++count;
}
});
Expand All @@ -200,7 +175,7 @@ size_t Group::numElements() const {
size_t Group::numSources() const {
size_t count = 0;
traverse([&count](const BeamlineNode& node) {
if (std::holds_alternative<DesignSource>(node.data)) {
if (node.isSource()) {
++count;
}
});
Expand Down
37 changes: 7 additions & 30 deletions Intern/rayx-core/src/Beamline/Beamline.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@
#include "Core.h"
#include "Design/DesignElement.h"
#include "Design/DesignSource.h"
#include "Node.h"

namespace RAYX {

struct BeamlineNode;
enum class NodeType { OpticalElement, LightSource, Group };

class RAYX_API Group {
class RAYX_API Group : public BeamlineNode {
public:
Group() = default;
~Group() = default;
Expand All @@ -26,7 +24,9 @@ class RAYX_API Group {
Group& operator=(const Group&) = delete;

// Clone returns a deep copy of the group and its children.
Group clone() const;
std::unique_ptr<BeamlineNode> clone() const override;

bool isGroup() const override { return true; }

template <typename Callback>
void traverse(Callback&& callback) const;
Expand All @@ -44,12 +44,11 @@ class RAYX_API Group {
const std::vector<std::unique_ptr<BeamlineNode>>& getChildren() const { return m_children; }

// Add a child (by move).
void addChild(BeamlineNode&& child);
void addChild(std::unique_ptr<BeamlineNode> child);

// Getters returning raw pointers. This follows more the old way of handling beamlines and is to be used with care.
std::vector<DesignElement*> getElements() const;
std::vector<DesignSource*> getSources() const;
std::vector<Group*> getGroups() const;

// Helper
size_t numElements() const;
Expand All @@ -69,32 +68,10 @@ class RAYX_API Group {
private:
glm::dvec4 m_position = glm::dvec4(0, 0, 0, 1);
glm::dmat4 m_orientation = glm::dmat4(1);
// m_children vec is not checked for nullptrs anywhere, because with the current implementation it won't have any
std::vector<std::unique_ptr<BeamlineNode>> m_children;
};

using Beamline = Group; // Conceptually, a Beamline is a Group
using NodeData = std::variant<Group, DesignElement, DesignSource>;

struct BeamlineNode {
BeamlineNode(NodeData&& other) : data(std::move(other)), parent(nullptr) {}
NodeData data; // The variant storing the node-specific fields
BeamlineNode* parent = nullptr; // or a parent index if you want
// ? Store node type
};

inline NodeType getNodeType(const BeamlineNode& node) {
return std::visit(
[](auto&& element) -> NodeType {
using T = std::decay_t<decltype(element)>;
if constexpr (std::is_same_v<T, DesignElement>) {
return NodeType::OpticalElement;
} else if constexpr (std::is_same_v<T, DesignSource>) {
return NodeType::LightSource;
} else if constexpr (std::is_same_v<T, Group>) {
return NodeType::Group;
}
},
node.data);
}

} // namespace RAYX
20 changes: 20 additions & 0 deletions Intern/rayx-core/src/Beamline/Node.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#pragma once

#include <memory>

#include "Core.h"

namespace RAYX {

class RAYX_API BeamlineNode {
public:
virtual std::unique_ptr<BeamlineNode> clone() const = 0;

virtual bool isGroup() const { return false; }
virtual bool isElement() const { return false; }
virtual bool isSource() const { return false; }

private:
BeamlineNode* parent;
};
} // namespace RAYX
6 changes: 3 additions & 3 deletions Intern/rayx-core/src/Data/Importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ void addBeamlineObjectFromXML(rapidxml::xml_node<>* node, Group& group, std::fil

// TODO: could likely be made nicer
if (isSource) {
group.addChild(BeamlineNode(std::move(*ds)));
group.addChild(std::move(ds));
} else {
parseElement(parser, de.get());
group.addChild(BeamlineNode(std::move(*de)));
group.addChild(std::move(de));
}
}

Expand All @@ -138,7 +138,7 @@ void handleObjectCollection(rapidxml::xml_node<>* collection, Group& group, cons
handleObjectCollection(object, *nestedGroup, filename);

// Add the group to the beamline.
group.addChild(BeamlineNode(std::move(*nestedGroup)));
group.addChild(std::move(nestedGroup));
} else if (strcmp(object->name(), "param") != 0) {
RAYX_EXIT << "received weird object->name(): " << object->name();
}
Expand Down
24 changes: 13 additions & 11 deletions Intern/rayx-core/src/Design/DesignElement.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "DesignElement.h"

#include <iostream>
#include <memory>

#include "Beamline/Objects/Objects.h"
#include "Debug/Debug.h"
Expand All @@ -15,10 +16,10 @@ DesignElement& DesignElement::operator=(DesignElement&& other) noexcept {
return *this;
}

DesignElement DesignElement::clone() const {
std::unique_ptr<BeamlineNode> DesignElement::clone() const {
DesignElement clone;
clone.m_elementParameters = m_elementParameters.clone();
return clone;
return std::make_unique<DesignElement>(std::move(clone));
}

OpticalElement DesignElement::compile(const glm::dvec4& parentPos, const glm::dmat4& parentOri) const {
Expand All @@ -27,21 +28,22 @@ OpticalElement DesignElement::compile(const glm::dvec4& parentPos, const glm::dm
glm::dmat4 worldOri = parentOri * getOrientation();

// Then produce the final OpticalElement with these "world" coords
DesignElement de = clone();
de.setPosition(worldPos);
de.setOrientation(worldOri);
std::unique_ptr<BeamlineNode> de = clone();
DesignElement* dePtr = dynamic_cast<DesignElement*>(de.get());
dePtr->setPosition(worldPos);
dePtr->setOrientation(worldOri);

if (getType() == ElementType::ExpertsMirror) {
return makeElement(de, serializeMirror(), makeQuadric(de));
return makeElement(*dePtr, serializeMirror(), makeQuadric(*dePtr));
} else {
Surface surface = makeSurface(de);
Behaviour behavior = makeBehaviour(de);
Surface surface = makeSurface(*dePtr);
Behaviour behavior = makeBehaviour(*dePtr);
if (getType() == ElementType::Slit) {
return makeElement(de, behavior, surface, {}, DesignPlane::XY);
return makeElement(*dePtr, behavior, surface, {}, DesignPlane::XY);
} else if (getType() == ElementType::ImagePlane) {
return makeElement(de, behavior, surface, serializeUnlimited(), DesignPlane::XY);
return makeElement(*dePtr, behavior, surface, serializeUnlimited(), DesignPlane::XY);
} else {
return makeElement(de, behavior, surface);
return makeElement(*dePtr, behavior, surface);
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions Intern/rayx-core/src/Design/DesignElement.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#pragma once

#include "Element/Element.h"
#include "Beamline/Node.h"
#include "Value.h"

namespace RAYX {

struct RAYX_API DesignElement {
struct RAYX_API DesignElement : public BeamlineNode {
DesignElement() = default;
~DesignElement() = default;
// Delete copy constructor because shallow copies of DesignMap lead to unexpected behavior
Expand All @@ -15,13 +16,14 @@ struct RAYX_API DesignElement {
DesignElement(DesignElement&& other) noexcept;
DesignElement& operator=(DesignElement&& other) noexcept;
// Allow intentional copies
DesignElement clone() const;
std::unique_ptr<BeamlineNode> clone() const override;

DesignMap m_elementParameters;
OpticalElement compile(const glm::dvec4& groupPosition, const glm::dmat4& groupOrientation) const;

void setName(std::string s);
void setType(ElementType s);
bool isElement() const override { return true; }

std::string getName() const;
ElementType getType() const;
Expand Down
Loading

0 comments on commit cb40102

Please sign in to comment.