Skip to content

Commit

Permalink
feat (Repository): #db can be DB::Database or DB::Connection
Browse files Browse the repository at this point in the history
It allows to use temporary connections, e.g. in transctions
  • Loading branch information
vladfaust committed Apr 16, 2019
1 parent e1b899c commit f41f9e6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
4 changes: 4 additions & 0 deletions spec/repository_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ module Onyx::SQL
class Repository
def initialize(@db : MockDB, @logger = Onyx::SQL::Repository::Logger::Dummy.new)
end

protected def db_driver
db.driver
end
end
end

Expand Down
16 changes: 12 additions & 4 deletions src/onyx-sql/repository.cr
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ module Onyx::SQL
# # 442μs
# ```
class Repository
# A `DB::Database` instance for this repository.
# A `::DB::Database | ::DB::Connection` instance for this repository.
property db

# A `Repository::Logger` instance for this repository.
property logger

# Initialize the repository.
def initialize(@db : DB::Database, @logger : Logger = Logger::Standard.new)
def initialize(@db : ::DB::Database | ::DB::Connection, @logger : Logger = Logger::Standard.new)
end

protected def postgresql?
Expand All @@ -45,7 +45,7 @@ module Onyx::SQL
# If the `#db` driver is `PG::Driver`, replace all `?` with `$1`, `$2` etc. Otherwise return *sql_query* untouched.
def prepare_query(sql_query : String)
{% begin %}
case db.driver
case db_driver
{% if Object.all_subclasses.any? { |sc| sc.stringify == "PG::Driver" } %}
when PG::Driver
counter = 0
Expand All @@ -59,7 +59,7 @@ module Onyx::SQL
# Return `#db` driver name, e.g. `"postgresql"` for `PG::Driver`.
def driver
{% begin %}
case db.driver
case db_driver
{% if Object.all_subclasses.any? { |sc| sc.stringify == "PG::Driver" } %}
when PG::Driver then "postgresql"
{% end %}
Expand All @@ -70,5 +70,13 @@ module Onyx::SQL
end
{% end %}
end

protected def db_driver
if db.is_a?(::DB::Database)
db.as(::DB::Database).driver
else
db.as(::DB::Connection).context.as(::DB::Database).driver
end
end
end
end

0 comments on commit f41f9e6

Please sign in to comment.