8#ifndef Sawyer_Database_H
9#define Sawyer_Database_H
11#if __cplusplus >= 201103L
13#include <boost/iterator/iterator_facade.hpp>
14#include <boost/lexical_cast.hpp>
15#include <boost/numeric/conversion/cast.hpp>
17#include <Sawyer/Assert.h>
18#include <Sawyer/Map.h>
19#include <Sawyer/Optional.h>
177 class ConnectionBase;
181class Exception:
public std::runtime_error {
183 Exception(
const std::string &what)
184 : std::runtime_error(what) {}
186 ~Exception() noexcept {}
199 friend class ::Sawyer::Database::Statement;
200 friend class ::Sawyer::Database::Detail::ConnectionBase;
202 std::shared_ptr<Detail::ConnectionBase> pimpl_;
209 explicit Connection(
const std::shared_ptr<Detail::ConnectionBase> &pimpl);
215 ~Connection() =
default;
218 static Connection fromUri(
const std::string &uri);
223 static std::string uriDocString();
249 Statement stmt(
const std::string &sql);
252 Connection& run(
const std::string &sql);
260 Optional<T>
get(
const std::string &sql);
265 std::string driverName()
const;
272 size_t lastInsert()
const;
275 void pimpl(
const std::shared_ptr<Detail::ConnectionBase> &p) {
289 friend class ::Sawyer::Database::Detail::ConnectionBase;
291 std::shared_ptr<Detail::StatementBase> pimpl_;
308 explicit Statement(
const std::shared_ptr<Detail::StatementBase> &stmt)
313 Connection connection()
const;
325 Statement& bind(
const std::string &name,
const T &value);
332 Statement& rebind(
const std::string &name,
const T &value);
368 friend class ::Sawyer::Database::Iterator;
370 std::shared_ptr<Detail::StatementBase> stmt_;
377 explicit Row(
const std::shared_ptr<Detail::StatementBase> &stmt);
382 Optional<T>
get(
size_t columnIdx)
const;
387 size_t rowNumber()
const;
400class Iterator:
public boost::iterator_facade<Iterator, const Row, boost::forward_traversal_tag> {
401 friend class ::Sawyer::Database::Detail::StatementBase;
410 explicit Iterator(
const std::shared_ptr<Detail::StatementBase> &stmt);
419 explicit operator bool()
const {
424 friend class boost::iterator_core_access;
425 const Row& dereference()
const;
426 bool equal(
const Iterator&)
const;
451class ConnectionBase:
public std::enable_shared_from_this<ConnectionBase> {
452 friend class ::Sawyer::Database::Connection;
458 virtual ~ConnectionBase() {}
462 virtual void close() = 0;
466 virtual Statement prepareStatement(
const std::string &sql) = 0;
470 virtual size_t lastInsert()
const = 0;
472 Statement makeStatement(
const std::shared_ptr<Detail::StatementBase> &detail);
474 virtual std::string driverName()
const = 0;
481 friend class ::Sawyer::Database::Detail::StatementBase;
483 std::vector<size_t> indexes;
484 bool isBound =
false;
486 void append(
size_t idx) {
487 indexes.push_back(idx);
493 friend class ::Sawyer::Database::Detail::StatementBase;
494 Optional<T> operator()(StatementBase *stmt,
size_t idx);
505class StatementBase:
public std::enable_shared_from_this<StatementBase> {
506 friend class ::Sawyer::Database::Iterator;
507 friend class ::Sawyer::Database::Row;
508 friend class ::Sawyer::Database::Statement;
509 template<
class T>
friend class ::Sawyer::Database::Detail::ColumnReader;
511 using Parameters = Container::Map<std::string, Parameter>;
513 std::shared_ptr<ConnectionBase> connection_;
514 std::weak_ptr<ConnectionBase> weakConnection_;
516 Statement::State state_ = Statement::DEAD;
517 size_t sequence_ = 0;
518 size_t rowNumber_ = 0;
521 virtual ~StatementBase() {}
524 explicit StatementBase(
const std::shared_ptr<ConnectionBase> &connection)
525 : weakConnection_(connection) {
526 ASSERT_not_null(connection);
533 std::pair<std::string, size_t> parseParameters(
const std::string &highSql) {
536 bool inString =
false;
537 size_t nLowParams = 0;
538 state(Statement::READY);
539 for (
size_t i = 0; i < highSql.size(); ++i) {
540 if (
'\'' == highSql[i]) {
541 inString = !inString;
542 lowSql += highSql[i];
543 }
else if (
'?' == highSql[i] && !inString) {
545 std::string paramName;
546 while (i+1 < highSql.size() && (::isalnum(highSql[i+1]) ||
'_' == highSql[i+1]))
547 paramName += highSql[++i];
548 if (paramName.empty())
549 throw Exception(
"invalid parameter name at character position " + boost::lexical_cast<std::string>(i));
550 Parameter ¶m = params_.insertMaybeDefault(paramName);
551 param.append(nLowParams++);
552 state(Statement::UNBOUND);
554 lowSql += highSql[i];
558 state(Statement::DEAD);
559 throw Exception(
"mismatched quotes in SQL statement");
561 return std::make_pair(lowSql, nLowParams);
565 void invalidateIteratorsAndRows() {
570 size_t sequence()
const {
576 bool lockConnection() {
577 return (connection_ = weakConnection_.lock()) !=
nullptr;
582 void unlockConnection() {
588 bool isConnectionLocked()
const {
589 return connection_ !=
nullptr;
593 std::shared_ptr<ConnectionBase> connection()
const {
594 return weakConnection_.lock();
598 Statement::State state()
const {
605 void state(Statement::State newState) {
607 case Statement::DEAD:
608 case Statement::FINISHED:
609 case Statement::UNBOUND:
610 case Statement::READY:
611 invalidateIteratorsAndRows();
614 case Statement::EXECUTING:
615 ASSERT_require(isConnectionLocked());
622 bool hasUnboundParameters()
const {
623 ASSERT_forbid(state() == Statement::DEAD);
624 for (
const Parameter ¶m: params_.values()) {
633 virtual void unbindAllParams() {
634 ASSERT_forbid(state() == Statement::DEAD);
635 for (Parameter ¶m: params_.values())
636 param.isBound = false;
637 state(params_.isEmpty() ? Statement::READY : Statement::UNBOUND);
642 virtual void reset(
bool doUnbind) {
643 ASSERT_forbid(state() == Statement::DEAD);
644 invalidateIteratorsAndRows();
648 state(hasUnboundParameters() ? Statement::UNBOUND : Statement::READY);
655 void bind(
const std::string &name,
const T &value,
bool isRebind) {
657 throw Exception(
"connection is closed");
659 case Statement::DEAD:
660 throw Exception(
"statement is dead");
661 case Statement::FINISHED:
662 case Statement::EXECUTING:
665 case Statement::READY:
666 case Statement::UNBOUND: {
667 if (!params_.exists(name))
668 throw Exception(
"no such parameter \"" + name +
"\" in statement");
669 Parameter ¶m = params_[name];
670 bool wasUnbound = !param.isBound;
671 for (
size_t idx: param.indexes) {
673 bindLowDispatch(idx, value);
674 }
catch (
const Exception &e) {
675 if (param.indexes.size() > 1)
676 state(Statement::DEAD);
680 param.isBound =
true;
682 if (wasUnbound && !hasUnboundParameters())
683 state(Statement::READY);
693 bind(name, *value, isRebind);
695 bind(name, Nothing(), isRebind);
700 virtual void bindLow(
size_t idx,
int value) = 0;
701 virtual void bindLow(
size_t idx, int64_t value) = 0;
702 virtual void bindLow(
size_t idx,
size_t value) = 0;
703 virtual void bindLow(
size_t idx,
double value) = 0;
704 virtual void bindLow(
size_t idx,
const std::string &value) = 0;
705 virtual void bindLow(
size_t idx,
const char *cstring) = 0;
706 virtual void bindLow(
size_t idx, Nothing) = 0;
707 virtual void bindLow(
size_t idx,
const std::vector<uint8_t> &data) = 0;
711 void bindLowDispatch(
size_t idx,
int v) { bindLow(idx, v); }
712 void bindLowDispatch(
size_t idx, int64_t v) { bindLow(idx, v); }
713 void bindLowDispatch(
size_t idx,
size_t v) { bindLow(idx, v); }
714 void bindLowDispatch(
size_t idx,
double v) { bindLow(idx, v); }
715 void bindLowDispatch(
size_t idx,
const std::string &v) { bindLow(idx, v); }
716 void bindLowDispatch(
size_t idx,
const char *v) { bindLow(idx, v); }
717 void bindLowDispatch(
size_t idx, Nothing v) { bindLow(idx, v); }
718 void bindLowDispatch(
size_t idx,
const std::vector<uint8_t> &v) { bindLow(idx, v); }
721 template <bool Enable = !std::is_same<long, int64_t>::value>
722 typename std::enable_if<Enable, void>::type
723 bindLowDispatch(
size_t idx,
long v) {
724 bindLow(idx,
static_cast<int64_t
>(v));
728 template <bool Enable = !std::is_same<unsigned long, size_t>::value>
729 typename std::enable_if<Enable, void>::type
730 bindLowDispatch(
size_t idx,
unsigned long v) {
731 bindLow(idx,
static_cast<size_t>(v));
735 typename std::enable_if<
736 std::is_integral<T>::value &&
737 !std::is_same<T, int>::value &&
738 !std::is_same<T, long>::value &&
739 !std::is_same<T, unsigned long>::value &&
740 !std::is_same<T, int64_t>::value &&
741 !std::is_same<T, size_t>::value &&
742 !std::is_same<T, bool>::value,
744 bindLowDispatch(
size_t idx, T v) {
745 if (std::is_signed<T>::value) {
746 bindLow(idx,
static_cast<int64_t
>(v));
749 bindLow(idx,
static_cast<size_t>(v));
754 void bindLowDispatch(
size_t idx,
bool v) {
755 bindLow(idx, v ? 1 : 0);
// The `long` and `unsigned long` bindLowDispatch() overloads static_cast their
// argument to int64_t and size_t respectively before calling bindLow(). Refuse
// to compile on an ABI where either cast could silently truncate the value.
static_assert(sizeof(long) <= sizeof(int64_t), "unexpected ABI: long wider than int64_t");
static_assert(sizeof(unsigned long) <= sizeof(size_t), "unexpected ABI: unsigned long wider than size_t");
761 Iterator makeIterator() {
762 return Iterator(shared_from_this());
769 throw Exception(
"connection is closed");
771 case Statement::DEAD:
772 throw Exception(
"statement is dead");
773 case Statement::UNBOUND: {
775 for (Parameters::Node ¶m: params_.nodes()) {
776 if (!param.value().isBound)
777 s += (s.empty() ?
"" :
", ") + param.key();
779 ASSERT_forbid(s.empty());
780 throw Exception(
"unbound parameters: " + s);
782 case Statement::FINISHED:
783 case Statement::EXECUTING:
786 case Statement::READY: {
787 if (!lockConnection())
788 throw Exception(
"connection has been closed");
789 state(Statement::EXECUTING);
791 Iterator iter = beginLow();
796 ASSERT_not_reachable(
"invalid state");
801 virtual Iterator beginLow() = 0;
806 throw Exception(
"connection is closed");
807 ASSERT_require(state() == Statement::EXECUTING);
808 invalidateIteratorsAndRows();
814 size_t rowNumber()
const {
820 virtual Iterator nextLow() = 0;
824 Optional<T>
get(
size_t columnIdx) {
826 throw Exception(
"connection is closed");
827 ASSERT_require(state() == Statement::EXECUTING);
828 if (columnIdx >= nColumns())
829 throw Exception(
"column index " + boost::lexical_cast<std::string>(columnIdx) +
" is out of range");
830 return ColumnReader<T>()(
this, columnIdx);
834 virtual size_t nColumns()
const = 0;
837 virtual Optional<std::string> getString(
size_t idx) = 0;
838 virtual Optional<std::vector<std::uint8_t>> getBlob(
size_t idx) = 0;
843ColumnReader<T>::operator()(StatementBase *stmt,
size_t idx) {
845 if (!stmt->getString(idx).assignTo(str))
847 return boost::lexical_cast<T>(str);
851inline Optional<std::vector<uint8_t>>
852ColumnReader<std::vector<uint8_t>>::operator()(StatementBase *stmt,
size_t idx) {
853 return stmt->getBlob(idx);
857ConnectionBase::makeStatement(
const std::shared_ptr<Detail::StatementBase> &detail) {
858 return Statement(detail);
868inline Connection::Connection(
const std::shared_ptr<Detail::ConnectionBase> &pimpl)
872Connection::isOpen()
const {
873 return pimpl_ !=
nullptr;
883Connection::driverName()
const {
885 return pimpl_->driverName();
892Connection::stmt(
const std::string &sql) {
894 return pimpl_->prepareStatement(sql);
896 throw Exception(
"no active database connection");
901Connection::run(
const std::string &sql) {
908Connection::get(
const std::string &sql) {
909 for (
auto row: stmt(sql))
910 return row.
get<T>(0);
915Connection::lastInsert()
const {
917 return pimpl_->lastInsert();
919 throw Exception(
"no active database connection");
928Statement::connection()
const {
930 return Connection(pimpl_->connection());
938Statement::bind(
const std::string &name,
const T &value) {
940 pimpl_->bind(name, value,
false);
942 throw Exception(
"no active database connection");
949Statement::rebind(
const std::string &name,
const T &value) {
951 pimpl_->bind(name, value,
true);
953 throw Exception(
"no active database connection");
961 return pimpl_->begin();
963 throw Exception(
"no active database connection");
981 Iterator row = begin();
983 throw Exception(
"query did not return a row");
984 return row->get<T>(0);
992Iterator::Iterator(
const std::shared_ptr<Detail::StatementBase> &stmt)
996Iterator::dereference()
const {
998 throw Exception(
"dereferencing the end iterator");
999 if (row_.sequence_ != row_.stmt_->sequence())
1000 throw Exception(
"iterator has been invalidated");
1005Iterator::equal(
const Iterator &other)
const {
1006 return row_.stmt_ == other.row_.stmt_ && row_.sequence_ == other.row_.sequence_;
1010Iterator::increment() {
1012 throw Exception(
"incrementing the end iterator");
1013 *
this = row_.stmt_->next();
1021Row::Row(
const std::shared_ptr<Detail::StatementBase> &stmt)
1022 : stmt_(stmt), sequence_(stmt ? stmt->sequence() : 0) {}
1026Row::get(
size_t columnIdx)
const {
1027 ASSERT_not_null(stmt_);
1028 if (sequence_ != stmt_->sequence())
1029 throw Exception(
"row has been invalidated");
1030 return stmt_->get<T>(columnIdx);
1034Row::rowNumber()
const {
1035 ASSERT_not_null(stmt_);
1036 if (sequence_ != stmt_->sequence())
1037 throw Exception(
"row has been invalidated");
1038 return stmt_->rowNumber();
Holds a value or nothing.
bool get(const Word *words, size_t idx)
Return a single bit.
bool increment(Word *vec1, const BitRange &range1)
Increment.