#include "DB.h"
#include "Exception.h"

#include <climits>

using namespace std;

// Opens the database file at `path` with SQLITE_OPEN_READWRITE (no
// SQLITE_OPEN_CREATE, so the file must already exist).
// Throws Exception carrying the sqlite3_open_v2 result code on failure.
DB::DB(const string & path)
{
	sqlite3 * handle = nullptr;
	int ret = sqlite3_open_v2(path.c_str(), &handle, SQLITE_OPEN_READWRITE, nullptr);
	// Take ownership before inspecting `ret`: SQLite may hand back a handle
	// even when the open fails, and that handle still has to be closed.
	this->db = {
		handle,
		[] (sqlite3 * p) { sqlite3_close_v2(p); },
	};
	if (ret != SQLITE_OK)
		throw Exception("sqlite3_open_v2: %d", ret);
}

// Compiles `query` into a statement bound to this connection.
DBStatement DB::prepare(const string & query)
{
	return DBStatement(*this, query);
}

// Compiles `query` with sqlite3_prepare_v2 against the parent connection.
// The resulting sqlite3_stmt is finalized by the owning smart pointer.
// Throws Exception with the SQLite result code if compilation fails.
DBStatement::DBStatement(DB & parent, const string & query)
	: parent(parent)
{
	// nByte is an int; refuse lengths that cannot be represented instead of
	// letting the conversion wrap.
	if (query.size() > static_cast<string::size_type>(INT_MAX))
		throw Exception("sqlite3_prepare_v2: %d", SQLITE_TOOBIG);

	sqlite3_stmt * raw = nullptr;
	int ret = sqlite3_prepare_v2(this->parent.db.get(), query.c_str(),
		static_cast<int>(query.size()), &raw, nullptr);
	// Assigned even on failure; sqlite3_finalize(nullptr) is a harmless no-op.
	this->stmt = {
		raw,
		[] (sqlite3_stmt * p) { sqlite3_finalize(p); },
	};
	if (ret != SQLITE_OK)
		throw Exception("sqlite3_prepare_v2: %d", ret);
}

// Binds `text` to the next positional parameter and advances the cursor.
// Throws Exception with the SQLite result code on failure.
DBStatement & DBStatement::bind(const string & text)
{
	if (text.size() > static_cast<string::size_type>(INT_MAX))
		throw Exception("sqlite3_bind_text: %d", SQLITE_TOOBIG);

	// SQLITE_TRANSIENT makes SQLite copy the bytes immediately. `text` is often
	// a temporary that dies long before the statement is stepped, and bindings
	// also survive sqlite3_reset, so SQLite must not keep pointing at it.
	int ret = sqlite3_bind_text(this->stmt.get(), this->param_index, text.data(),
		static_cast<int>(text.size()), SQLITE_TRANSIENT);
	if (ret != SQLITE_OK)
		throw Exception("sqlite3_bind_text: %d", ret);
	this->param_index++;
	return *this;
}

// Binds `number` to the next positional parameter and advances the cursor.
// Throws Exception with the SQLite result code on failure.
DBStatement & DBStatement::bind(const int & number)
{
	int ret = sqlite3_bind_int(this->stmt.get(), this->param_index, number);
	if (ret != SQLITE_OK)
		throw Exception("sqlite3_bind_int: %d", ret);
	this->param_index++;
	return *this;
}

// Rewinds the statement so it can be stepped again and rewinds the parameter
// cursor to 1. Existing bindings are kept (sqlite3_reset does not clear them).
// Throws Exception if sqlite3_reset reports an error.
DBStatement & DBStatement::reset()
{
	int ret = sqlite3_reset(this->stmt.get());
	// sqlite3_reset rewinds the statement even when it reports the error of a
	// previously failed step, so the parameter cursor must rewind regardless.
	this->param_index = 1;
	if (ret != SQLITE_OK)
		throw Exception("sqlite3_reset: %d", ret);
	return *this;
}

// Steps the statement once and requires it to finish (SQLITE_DONE).
// Any other result, including SQLITE_ROW, throws Exception.
void DBStatement::execute()
{
	int ret = sqlite3_step(this->stmt.get());
	if (ret != SQLITE_DONE)
		throw Exception("sqlite3_step: %d", ret);
}

// Steps the statement once and returns a view of the current row.
// Throws Exception unless the step yields SQLITE_ROW or SQLITE_DONE.
// NOTE(review): on SQLITE_DONE there is no current row, and sqlite3_column_*
// results are undefined; callers cannot tell the two cases apart here.
DBQueryRow DBStatement::row()
{
	int ret = sqlite3_step(this->stmt.get());
	if (ret != SQLITE_ROW && ret != SQLITE_DONE)
		throw Exception("sqlite3_step: %d", ret);
	return { *this };
}

// Lightweight view reading columns from `parent`'s current row.
DBQueryRow::DBQueryRow(DBStatement & parent)
	: parent(parent)
{
}

// Returns column `index` as UTF-8 text, or `default_value` if it is SQL NULL.
// The pointer is owned by SQLite and is only valid until the statement is
// stepped, reset or finalized.
template <>
const char * DBQueryRow::col<const char *>(int index, const char * const & default_value)
{
	sqlite3_stmt * statement = this->parent.stmt.get();
	if (sqlite3_column_type(statement, index) == SQLITE_NULL)
		return default_value;
	return reinterpret_cast<const char *>(sqlite3_column_text(statement, index));
}

// Returns column `index` as text, using "" for SQL NULL.
template <>
const char * DBQueryRow::col<const char *>(int index)
{
	return this->col<const char *>(index, "");
}

// Returns column `index` as an int, or `default_value` if it is SQL NULL.
template <>
int DBQueryRow::col<int>(int index, const int & default_value)
{
	sqlite3_stmt * statement = this->parent.stmt.get();
	if (sqlite3_column_type(statement, index) == SQLITE_NULL)
		return default_value;
	return sqlite3_column_int(statement, index);
}

// Returns column `index` as an int, using 0 for SQL NULL.
template <>
int DBQueryRow::col<int>(int index)
{
	return this->col<int>(index, 0);
}

// Returns a range over the remaining result rows, for use in a range-for loop.
DBQueryRowRange DBStatement::rows()
{
	return { *this };
}

// Steps the underlying statement. On SQLITE_DONE the iterator becomes the end
// iterator; any result other than SQLITE_ROW/SQLITE_DONE throws Exception.
DBQueryRowRange::DBQueryRowIterator & DBQueryRowRange::DBQueryRowIterator::operator ++ ()
{
	int ret = sqlite3_step(this->parent.parent.stmt.get());
	if (ret == SQLITE_DONE) {
		this->end = true;
		return *this;
	}
	if (ret != SQLITE_ROW)
		throw Exception("sqlite3_step: %d", ret);
	return *this;
}

// The right-hand operand is ignored: iteration continues until this iterator
// has seen SQLITE_DONE, which is all a range-for comparison against end() needs.
bool DBQueryRowRange::DBQueryRowIterator::operator != (const DBQueryRowIterator &) const
{
	return !this->end;
}

// Every iterator yields the same row view owned by the range; its column
// values reflect whichever row the statement currently points at.
DBQueryRow & DBQueryRowRange::DBQueryRowIterator::operator * () const
{
	return this->parent.row;
}

// Steps once so the first row (if any) is current before the loop body runs.
DBQueryRowRange::DBQueryRowIterator DBQueryRowRange::begin()
{
	DBQueryRowRange::DBQueryRowIterator it = { *this };
	return ++it;
}

// Sentinel iterator; only used as the (ignored) right-hand side of operator!=.
DBQueryRowRange::DBQueryRowIterator DBQueryRowRange::end()
{
	return { *this };
}

DBQueryRowRange::DBQueryRowRange(DBStatement & parent)
	: parent(parent), row(parent)
{
}

DBQueryRowRange::DBQueryRowIterator::DBQueryRowIterator(DBQueryRowRange & parent)
	: parent(parent)
{
}