Skip to content

Commit

Permalink
fix: 整理和改正 attention 构造问题
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Feb 1, 2024
1 parent f6ecea9 commit 6816f00
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/05computation/include/computation/operators/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
namespace refactor::computation {

struct Attention final : public Operator {
dim_t maxSeqLen;

constexpr Attention(decltype(maxSeqLen) maxSeqLen_) noexcept
: Operator(), maxSeqLen(maxSeqLen_) {}
constexpr Attention() noexcept = default;

static size_t typeId() noexcept;
size_t opTypeId() const noexcept final;
std::string_view name() const noexcept final;
kernel::CollectorBox candidateKernels(Target) const final;
std::string serialize() const noexcept final;
};

}// namespace refactor::computation
Expand Down
8 changes: 8 additions & 0 deletions src/05computation/src/operators/attention.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "computation/operators/attention.h"
#include "kernel/collectors/attention.h"

namespace refactor::computation {
using Op = Attention;
Expand All @@ -9,5 +10,12 @@ namespace refactor::computation {
}
auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
auto Op::name() const noexcept -> std::string_view { return "Attention"; }
auto Op::candidateKernels(Target target) const -> kernel::CollectorBox {
using Collector_ = kernel::AttentionCollector;
return std::make_unique<Collector_>(target);
}
auto Op::serialize() const noexcept -> std::string {
return "Attention()";
}

}// namespace refactor::computation
4 changes: 2 additions & 2 deletions src/08-01llm/src/operators/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace refactor::llm {
: Operator(), maxSeqLen(maxSeqLen_) {}

auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).float_();
auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).int_();
return OpBox(std::make_unique<Op>(maxSeqLen));
}
auto Op::typeId() -> size_t {
Expand Down Expand Up @@ -129,7 +129,7 @@ namespace refactor::llm {

auto Op::lower(TensorRefs) const -> computation::OpBox {
using Op_ = computation::Attention;
return std::make_unique<Op_>(maxSeqLen);
return std::make_unique<Op_>();
}

}// namespace refactor::llm

0 comments on commit 6816f00

Please sign in to comment.