Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Atraxus committed Feb 7, 2025
1 parent 02bf34a commit 07283bd
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 131 deletions.
126 changes: 74 additions & 52 deletions Intern/rayx-core/src/Beamline/Beamline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,100 +70,122 @@ void Group::accumulateLightSourcesWorldPositions(const Group& group, const glm::
glm::dvec4 currentPos = parentOri * group.getPosition() + parentPos;
glm::dmat4 currentOri = parentOri * group.getOrientation();

// Assuming you have a free function (or lambda) that traverses the group tree.
traverseGroup(group, [&](const BeamlineNode& node) {
if (std::holds_alternative<std::shared_ptr<DesignSource>>(node)) {
const auto& src = std::get<std::shared_ptr<DesignSource>>(node);
glm::dvec4 worldPos = currentOri * src->getPosition() + currentPos;
positions.push_back(worldPos);
} else if (std::holds_alternative<std::shared_ptr<Group>>(node)) {
const auto& childGroup = std::get<std::shared_ptr<Group>>(node);
accumulateLightSourcesWorldPositions(*childGroup, currentPos, currentOri, positions);
if (std::holds_alternative<std::unique_ptr<DesignSource>>(node)) {
// Extract the unique_ptr, then access the raw pointer
auto& srcPtr = std::get<std::unique_ptr<DesignSource>>(node);
if (srcPtr) {
glm::dvec4 worldPos = currentOri * srcPtr->getPosition() + currentPos;
positions.push_back(worldPos);
}
} else if (std::holds_alternative<std::unique_ptr<Group>>(node)) {
auto& childGroupPtr = std::get<std::unique_ptr<Group>>(node);
if (childGroupPtr) {
accumulateLightSourcesWorldPositions(*childGroupPtr, currentPos, currentOri, positions);
}
}
// DesignElements are ignored.
});
}

std::vector<OpticalElement> Group::compileElements() const {
std::vector<OpticalElement> elements;

auto recurse = [&](auto& self, const Group& grp, const glm::dvec4& parentPos, const glm::dmat4& parentOri) -> void {
glm::dvec4 thisGroupPos = parentOri * grp.getPosition() + parentPos;
glm::dmat4 thisGroupOri = parentOri * grp.getOrientation();
for (const auto& child : grp.children) {
if (std::holds_alternative<std::shared_ptr<DesignElement>>(child)) {
const auto& de = std::get<std::shared_ptr<DesignElement>>(child);
elements.push_back(de->compile(thisGroupPos, thisGroupOri));
} else if (std::holds_alternative<std::shared_ptr<Group>>(child)) {
self(self, *std::get<std::shared_ptr<Group>>(child), thisGroupPos, thisGroupOri);
}
// Sources are ignored for optical element compilation.

for (const auto& child : grp.children) { // For each child...
if (std::holds_alternative<std::unique_ptr<DesignElement>>(child)) {
const auto& dePtr = std::get<std::unique_ptr<DesignElement>>(child);
if (dePtr) {
// Compile an OpticalElement from the DesignElement
elements.push_back(dePtr->compile(thisGroupPos, thisGroupOri));
}
} else if (std::holds_alternative<std::unique_ptr<Group>>(child)) {
const auto& groupPtr = std::get<std::unique_ptr<Group>>(child);
if (groupPtr) {
// Recurse into the child group
self(self, *groupPtr, thisGroupPos, thisGroupOri);
}
} // Ignore DesignSources
}
};

// Start recursion at this group
recurse(recurse, *this, glm::dvec4(0, 0, 0, 1), glm::dmat4(1.0));
return elements;
}

std::vector<Ray> Group::compileSources(int thread_count) const {
RAYX_PROFILE_FUNCTION_STDOUT();
std::vector<Ray> rays;

auto traverseWithTransforms = [&rays, thread_count](const Group& group, const glm::dvec4& parentPosition, const glm::dmat4& parentOrientation,
const auto& self) -> void {
// Compute the transform for this group
glm::dvec4 currentPosition = parentOrientation * group.getPosition() + parentPosition;
glm::dmat4 currentOrientation = parentOrientation * group.getOrientation();

for (const auto& child : group.children) {
if (std::holds_alternative<std::shared_ptr<DesignSource>>(child)) {
const auto& src = std::get<std::shared_ptr<DesignSource>>(child);
auto sourceRays = src->compile(thread_count, currentPosition, currentOrientation);
for (auto& ray : sourceRays) {
ray.m_sourceID = static_cast<uint32_t>(rays.size());
for (const auto& child : group.children) { // For each child...
if (std::holds_alternative<std::unique_ptr<DesignSource>>(child)) {
const auto& srcPtr = std::get<std::unique_ptr<DesignSource>>(child);
if (srcPtr) {
// Compile the rays for this source
auto sourceRays = srcPtr->compile(thread_count, currentPosition, currentOrientation);
for (auto& ray : sourceRays) {
ray.m_sourceID = static_cast<uint32_t>(rays.size());
}
rays.insert(rays.end(), sourceRays.begin(), sourceRays.end());
}
} else if (std::holds_alternative<std::unique_ptr<Group>>(child)) {
const auto& childGroupPtr = std::get<std::unique_ptr<Group>>(child);
if (childGroupPtr) {
self(*childGroupPtr, currentPosition, currentOrientation, self);
}
rays.insert(rays.end(), sourceRays.begin(), sourceRays.end());
} else if (std::holds_alternative<std::shared_ptr<Group>>(child)) {
self(*std::get<std::shared_ptr<Group>>(child), currentPosition, currentOrientation, self);
}
}
};

traverseWithTransforms(*this, glm::dvec4(0, 0, 0, 1), glm::dmat4(1), traverseWithTransforms);
return rays;
}

// Retrieve all DesignElements (deep)
std::vector<std::shared_ptr<DesignElement>> Group::getElements() const {
std::vector<std::shared_ptr<DesignElement>> elements;
traverseGroup(*this, [&elements](const BeamlineNode& node) {
if (std::holds_alternative<std::shared_ptr<DesignElement>>(node)) {
elements.push_back(std::get<std::shared_ptr<DesignElement>>(node));
std::vector<DesignElement*> Group::getElements() const {
std::vector<DesignElement*> elements;
for (const auto& node : children) {
if (std::holds_alternative<std::unique_ptr<DesignElement>>(node)) {
elements.push_back(std::get<std::unique_ptr<DesignElement>>(node).get());
}
});
}
return elements;
}

// Retrieve all DesignSources (deep)
std::vector<std::shared_ptr<DesignSource>> Group::getSources() const {
std::vector<std::shared_ptr<DesignSource>> sources;
traverseGroup(*this, [&sources](const BeamlineNode& node) {
if (std::holds_alternative<std::shared_ptr<DesignSource>>(node)) {
sources.push_back(std::get<std::shared_ptr<DesignSource>>(node));
std::vector<DesignSource*> Group::getSources() const {
std::vector<DesignSource*> sources;
for (const auto& node : children) {
if (std::holds_alternative<std::unique_ptr<DesignSource>>(node)) {
sources.push_back(std::get<std::unique_ptr<DesignSource>>(node).get());
}
});
}
return sources;
}

// Retrieve all Groups (deep)
std::vector<std::shared_ptr<Group>> Group::getGroups() const {
std::vector<std::shared_ptr<Group>> groups;
traverseGroup(*this, [&groups](const BeamlineNode& node) {
if (std::holds_alternative<std::shared_ptr<Group>>(node)) {
groups.push_back(std::get<std::shared_ptr<Group>>(node));
std::vector<Group*> Group::getGroups() const {
std::vector<Group*> groups;
for (const auto& node : children) {
if (std::holds_alternative<std::unique_ptr<Group>>(node)) {
groups.push_back(std::get<std::unique_ptr<Group>>(node).get());
}
});
}
return groups;
}

size_t Group::numElements() const {
size_t count = 0;
traverseGroup(*this, [&count](const BeamlineNode& node) {
if (std::holds_alternative<std::shared_ptr<DesignElement>>(node)) {
if (std::holds_alternative<std::unique_ptr<DesignElement>>(node)) {
++count;
}
});
Expand All @@ -173,7 +195,7 @@ size_t Group::numElements() const {
size_t Group::numSources() const {
size_t count = 0;
traverseGroup(*this, [&count](const BeamlineNode& node) {
if (std::holds_alternative<std::shared_ptr<DesignSource>>(node)) {
if (std::holds_alternative<std::unique_ptr<DesignSource>>(node)) {
++count;
}
});
Expand All @@ -187,8 +209,8 @@ void traverseGroup(Group& group, Callback&& callback) {
for (auto& child : group.getChildren()) {
callback(child);
// If the child is a Group, then recursively traverse it.
if (std::holds_alternative<std::shared_ptr<Group>>(child)) {
traverseGroup(*std::get<std::shared_ptr<Group>>(child), callback);
if (std::holds_alternative<std::unique_ptr<Group>>(child)) {
traverseGroup(*std::get<std::unique_ptr<Group>>(child), callback);
}
}
}
Expand All @@ -199,10 +221,10 @@ void traverseGroup(const Group& group, Callback&& callback) {
// Iterate over the children (as read‑only).
for (const auto& child : group.getChildren()) {
callback(child);
// Even in a const context, our variant is defined as holding std::shared_ptr<Group>.
// Even in a const context, our variant is defined as holding std::unique_ptr<Group>.
// We can still call traverseGroup on the pointed-to Group.
if (std::holds_alternative<std::shared_ptr<Group>>(child)) {
traverseGroup(*std::get<std::shared_ptr<Group>>(child), callback);
if (std::holds_alternative<std::unique_ptr<Group>>(child)) {
traverseGroup(*std::get<std::unique_ptr<Group>>(child), callback);
}
}
}
Expand Down
28 changes: 14 additions & 14 deletions Intern/rayx-core/src/Beamline/Beamline.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
namespace RAYX {

class Group;
using BeamlineNode = std::variant<std::shared_ptr<DesignElement>, std::shared_ptr<DesignSource>, std::shared_ptr<Group>>;
using BeamlineNode = std::variant<std::unique_ptr<DesignElement>, std::unique_ptr<DesignSource>, std::unique_ptr<Group>>;
enum class NodeType { OpticalElement, LightSource, Group };

class RAYX_API Group {
Expand All @@ -35,19 +35,19 @@ class RAYX_API Group {
// Add a child (by move).
void addChild(BeamlineNode&& child);

// Other member functions.
MaterialTables calcMinimalMaterialTables() const;
static void accumulateLightSourcesWorldPositions(const Group& group, const glm::dvec4& parentPos, const glm::dmat4& parentOri,
std::vector<glm::dvec4>& positions);
std::vector<OpticalElement> compileElements() const;
std::vector<Ray> compileSources(int thread_count = 1) const;
// Getters returning raw pointers. This follows more the old way of handling beamlines and is to be used with care.
std::vector<DesignElement*> getElements() const;
std::vector<DesignSource*> getSources() const;
std::vector<Group*> getGroups() const;

// Getters returning smart pointers.
std::vector<std::shared_ptr<DesignElement>> getElements() const;
std::vector<std::shared_ptr<DesignSource>> getSources() const;
std::vector<std::shared_ptr<Group>> getGroups() const;
// Helper
size_t numElements() const;
size_t numSources() const;
MaterialTables calcMinimalMaterialTables() const;
std::vector<OpticalElement> compileElements() const;
std::vector<Ray> compileSources(int thread_count = 1) const;
static void accumulateLightSourcesWorldPositions(const Group& group, const glm::dvec4& parentPos, const glm::dmat4& parentOri,
std::vector<glm::dvec4>& positions);

// Getters & setters for transforms.
const glm::dvec4& getPosition() const { return m_position; }
Expand All @@ -73,11 +73,11 @@ inline NodeType getNodeType(const BeamlineNode& node) {
return std::visit(
[](auto&& element) -> NodeType {
using T = std::decay_t<decltype(element)>;
if constexpr (std::is_same_v<T, std::shared_ptr<DesignElement>>) {
if constexpr (std::is_same_v<T, std::unique_ptr<DesignElement>>) {
return NodeType::OpticalElement;
} else if constexpr (std::is_same_v<T, std::shared_ptr<DesignSource>>) {
} else if constexpr (std::is_same_v<T, std::unique_ptr<DesignSource>>) {
return NodeType::LightSource;
} else if constexpr (std::is_same_v<T, std::shared_ptr<Group>>) {
} else if constexpr (std::is_same_v<T, std::unique_ptr<Group>>) {
return NodeType::Group;
}
},
Expand Down
8 changes: 4 additions & 4 deletions Intern/rayx-core/tests/testSources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ TEST_F(TestSuite, DipoleEnergyDistribution) {
TEST_F(TestSuite, PixelPositionTest) {
auto beamline = loadBeamline("PixelSource");
auto rays = beamline.compileSources();
const DesignSource* src = beamline.getSources()[0].get();
const DesignSource* src = beamline.getSources()[0];
auto width = src->getSourceWidth();
auto height = src->getSourceHeight();
auto hordiv = src->getHorDivergence();
Expand Down Expand Up @@ -145,7 +145,7 @@ TEST_F(TestSuite, testInterpolationFunctionDipole) {
};

auto beamline = loadBeamline("dipole_plain");
const DesignSource* src = beamline.getSources()[0].get();
const DesignSource* src = beamline.getSources()[0];
DipoleSource dipolesource(*src);

for (auto values : inouts) {
Expand All @@ -167,7 +167,7 @@ TEST_F(TestSuite, testVerDivergenceDipole) {
}};

auto beamline = loadBeamline("dipole_plain");
const DesignSource* src = beamline.getSources()[0].get();
const DesignSource* src = beamline.getSources()[0];
DipoleSource dipolesource(*src);

for (auto values : inouts) {
Expand All @@ -194,7 +194,7 @@ TEST_F(TestSuite, testLightsourceGetters) {
}};
for (auto values : rmlinputs) {
auto beamline = loadBeamline(values.rmlFile);
const DesignSource* src = beamline.getSources()[0].get();
const DesignSource* src = beamline.getSources()[0];

auto test2 = values.horDivergence;
auto test4 = values.sourceDepth;
Expand Down
30 changes: 12 additions & 18 deletions Intern/rayx-ui/src/Application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,24 +151,18 @@ void Application::run() {
m_Scene = std::make_unique<Scene>(m_Device);

// Update elements and sources for UI
m_UIParams.beamlineInfo.elements = m_Beamline->getElements();
m_UIParams.beamlineInfo.sources = m_Beamline->getSources();

// Store source positions in UI parameters
m_UIParams.beamlineInfo.rSourcePositions.clear();
RAYX::Group::accumulateLightSourcesWorldPositions(*m_Beamline, glm::dvec4(0, 0, 0, 1), glm::dmat4(1),
m_UIParams.beamlineInfo.rSourcePositions);

// Set default selection in UI based on available sources or elements
if (m_UIParams.beamlineInfo.sources.size() > 0) {
m_UIParams.beamlineInfo.selectedType = SelectedType::LightSource;
m_UIParams.beamlineInfo.selectedIndex = 0;
} else if (m_UIParams.beamlineInfo.elements.size() > 0) {
m_UIParams.beamlineInfo.selectedType = SelectedType::OpticalElement;
m_UIParams.beamlineInfo.selectedIndex = 0;
} else {
m_UIParams.beamlineInfo.selectedType = SelectedType::None;
}
m_UIParams.beamlineInfo.beamline = m_Beamline.get();

// TODO: Set default selection in UI based on available sources or elements
// if (m_UIParams.beamlineInfo.sources.size() > 0) {
// m_UIParams.beamlineInfo.selectedType = SelectedType::LightSource;
// m_UIParams.beamlineInfo.selectedIndex = 0;
// } else if (m_UIParams.beamlineInfo.elements.size() > 0) {
// m_UIParams.beamlineInfo.selectedType = SelectedType::OpticalElement;
// m_UIParams.beamlineInfo.selectedIndex = 0;
// } else {
// m_UIParams.beamlineInfo.selectedType = SelectedType::None;
// }

// Prepare for ray loading or element preparation
size_t numElements = m_Beamline->numElements();
Expand Down
40 changes: 20 additions & 20 deletions Intern/rayx-ui/src/UserInterface/BeamlineDesignHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,33 @@
#include <string>
#include <unordered_set>

#include "Beamline/Beamline.h"
#include "Beamline/StringConversion.h"
#include "Data/Strings.cpp"
#include "Debug/Instrumentor.h"

void BeamlineDesignHandler::showBeamlineDesignWindow(UIBeamlineInfo& uiBeamlineInfo) {
// RAYX_PROFILE_FUNCTION_STDOUT();
void BeamlineDesignHandler::showBeamlineDesignWindow(UIBeamlineInfo& uiInfo) {
// If no node is selected, do nothing
if (!uiInfo.selectedNode) {
ImGui::Text("No node selected");
return;
}

if (uiBeamlineInfo.selectedType == SelectedType::LightSource) { // source
if (uiBeamlineInfo.selectedIndex >= 0 && uiBeamlineInfo.selectedIndex < static_cast<int>(uiBeamlineInfo.sources.size())) {
showParameters(uiBeamlineInfo.sources[uiBeamlineInfo.selectedIndex]->m_elementParameters, //
uiBeamlineInfo.elementsChanged, //
uiBeamlineInfo.selectedType);
} else {
// Handle out-of-bounds access for sources
RAYX_EXIT << "Error: selectedIndex is out of bounds for sources.";
// The node is a std::variant<...>, so figure out which alternative we have
const auto& node = *uiInfo.selectedNode;

if (std::holds_alternative<std::unique_ptr<RAYX::DesignSource>>(node)) {
const auto& srcPtr = std::get<std::unique_ptr<RAYX::DesignSource>>(node);
if (srcPtr) {
// showParameters can now edit srcPtr->m_elementParameters
showParameters(srcPtr->m_elementParameters, uiInfo.elementsChanged, SelectedType::LightSource);
}
} else if (uiBeamlineInfo.selectedType == SelectedType::OpticalElement) { // element
if (uiBeamlineInfo.selectedIndex >= 0 && uiBeamlineInfo.selectedIndex < static_cast<int>(uiBeamlineInfo.elements.size())) {
showParameters(uiBeamlineInfo.sources[uiBeamlineInfo.selectedIndex]->m_elementParameters, //
uiBeamlineInfo.elementsChanged, //
uiBeamlineInfo.selectedType);
} else {
// Handle out-of-bounds access for elements
RAYX_EXIT << "Error: selectedIndex is out of bounds for elements.";
} else if (std::holds_alternative<std::unique_ptr<RAYX::DesignElement>>(node)) {
const auto& elemPtr = std::get<std::unique_ptr<RAYX::DesignElement>>(node);
if (elemPtr) {
showParameters(elemPtr->m_elementParameters, uiInfo.elementsChanged, SelectedType::OpticalElement);
}
} else if (uiBeamlineInfo.selectedType == SelectedType::Group) { // group
// Handle group if needed
} else if (std::holds_alternative<std::unique_ptr<RAYX::Group>>(node)) {
}
}

Expand Down
Loading

0 comments on commit 07283bd

Please sign in to comment.