diff --git a/README.md b/README.md index 052885e..81ef56d 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,10 @@ A lightweight, distributed SQL database engine. Designed for cloud environments - **Analytics Performance**: - **Columnar Storage**: Binary-per-column persistence for efficient analytical scanning. - **Vectorized Execution**: Batch-at-a-time processing model for high-throughput query execution. -- **Multi-Node Transactions**: ACID guarantees across the cluster via Two-Phase Commit (2PC). +- **Multi-Node Transactions**: ACID guarantees across the cluster via Two-Phase Commit (2PC) and connection-aware execution state supporting `BEGIN`, `COMMIT`, and `ROLLBACK`. +- **Advanced Execution Engine**: + - **Full Outer Join Support**: Specialized `HashJoinOperator` implementing `LEFT`, `RIGHT`, and `FULL` outer join semantics with automatic null-padding. + - **B+ Tree Indexing**: Persistent indexing for high-speed point lookups and optimized query planning. - **Type-Safe Value System**: Robust handling of SQL data types using `std::variant`. - **Volcano & Vectorized Engine**: Flexible execution models supporting traditional row-based and high-performance columnar processing. - **PostgreSQL Wire Protocol**: Handshake and simple query protocol implementation for tool compatibility. @@ -46,17 +49,18 @@ A lightweight, distributed SQL database engine. Designed for cloud environments mkdir build cd build cmake .. 
-make -j$(nproc) +make -j$(nproc) # Or ../tests/run_test.sh for automated multi-OS build ``` ### Running Tests ```bash -# Run all tests +# Run the integrated test suite (Unit + E2E + Logic) +./tests/run_test.sh + +# Or run individual binaries ./build/sqlEngine_tests -# Run distributed-specific tests ./build/distributed_tests -./build/distributed_txn_tests ``` ### Starting the Cluster diff --git a/docs/phases/README.md b/docs/phases/README.md index d153481..0e44249 100644 --- a/docs/phases/README.md +++ b/docs/phases/README.md @@ -56,10 +56,10 @@ This directory contains the technical documentation for the lifecycle of the clo ### Phase 9 — Stability & Testing Refinement **Focus**: Engine Robustness & E2E Validation. -- Slotted-page layout fixes for large table support. -- Buffer Pool Manager lifecycle management (destructor flushing). -- Robust Python E2E client with partial-read handling and numeric validation. -- Standardized test orchestration via `run_test.sh`. +- **Advanced Execution**: Full support for `LEFT`, `RIGHT`, and `FULL` outer joins. +- **Transactional Integrity**: Persistent connection-based execution state and comprehensive `ROLLBACK` support for all DML operations. +- **Logic Validation**: Integration of the SqlLogicTest (SLT) suite with 80+ logic test cases covering Joins, Transactions, Aggregates, and Indexes. +- **Automation**: Standardized cross-platform test orchestration via `run_test.sh` with automatic CPU detection. 
--- diff --git a/include/executor/operator.hpp b/include/executor/operator.hpp index e63738b..f44820a 100644 --- a/include/executor/operator.hpp +++ b/include/executor/operator.hpp @@ -326,12 +326,12 @@ class HashJoinOperator : public Operator { class LimitOperator : public Operator { private: std::unique_ptr child_; - uint64_t limit_; - uint64_t offset_; + int64_t limit_; + int64_t offset_; uint64_t current_count_ = 0; public: - LimitOperator(std::unique_ptr child, uint64_t limit, uint64_t offset = 0); + LimitOperator(std::unique_ptr child, int64_t limit, int64_t offset = 0); bool init() override; bool open() override; diff --git a/include/executor/query_executor.hpp b/include/executor/query_executor.hpp index 2e3fc79..aa0200b 100644 --- a/include/executor/query_executor.hpp +++ b/include/executor/query_executor.hpp @@ -44,7 +44,7 @@ class QueryExecutor { transaction::TransactionManager& transaction_manager, recovery::LogManager* log_manager = nullptr, cluster::ClusterManager* cluster_manager = nullptr); - ~QueryExecutor() = default; + ~QueryExecutor(); // Disable copy/move for executor QueryExecutor(const QueryExecutor&) = delete; @@ -74,6 +74,7 @@ class QueryExecutor { QueryResult execute_select(const parser::SelectStatement& stmt, transaction::Transaction* txn); QueryResult execute_create_table(const parser::CreateTableStatement& stmt); + QueryResult execute_create_index(const parser::CreateIndexStatement& stmt); QueryResult execute_drop_table(const parser::DropTableStatement& stmt); QueryResult execute_drop_index(const parser::DropIndexStatement& stmt); QueryResult execute_insert(const parser::InsertStatement& stmt, transaction::Transaction* txn); diff --git a/include/parser/statement.hpp b/include/parser/statement.hpp index dfed208..946d86c 100644 --- a/include/parser/statement.hpp +++ b/include/parser/statement.hpp @@ -73,8 +73,8 @@ class SelectStatement : public Statement { std::vector> group_by_; std::unique_ptr having_; std::vector> order_by_; - int64_t 
limit_ = 0; - int64_t offset_ = 0; + int64_t limit_ = -1; + int64_t offset_ = -1; bool distinct_ = false; public: @@ -112,7 +112,7 @@ class SelectStatement : public Statement { [[nodiscard]] int64_t limit() const { return limit_; } [[nodiscard]] int64_t offset() const { return offset_; } [[nodiscard]] bool distinct() const { return distinct_; } - [[nodiscard]] bool has_limit() const { return limit_ > 0; } + [[nodiscard]] bool has_limit() const { return limit_ >= 0; } [[nodiscard]] bool has_offset() const { return offset_ > 0; } [[nodiscard]] std::string to_string() const override; diff --git a/include/parser/token.hpp b/include/parser/token.hpp index 52ab882..e73d832 100644 --- a/include/parser/token.hpp +++ b/include/parser/token.hpp @@ -49,6 +49,7 @@ enum class TokenType : uint8_t { Join, Left, Right, + Full, Inner, Outer, Order, diff --git a/include/storage/heap_table.hpp b/include/storage/heap_table.hpp index 0f34bfc..131a666 100644 --- a/include/storage/heap_table.hpp +++ b/include/storage/heap_table.hpp @@ -170,6 +170,12 @@ class HeapTable { */ bool physical_remove(const TupleId& tuple_id); + /** + * @brief Resets xmax to 0 (used for rollback of a DELETE) + * @return true on success + */ + bool undo_remove(const TupleId& tuple_id); + /** * @brief Replaces an existing record with new data * @param tuple_id The record to update diff --git a/include/transaction/transaction.hpp b/include/transaction/transaction.hpp index 870fb77..68bca41 100644 --- a/include/transaction/transaction.hpp +++ b/include/transaction/transaction.hpp @@ -7,7 +7,9 @@ #define CLOUDSQL_TRANSACTION_TRANSACTION_HPP #include +#include #include +#include #include #include @@ -55,6 +57,7 @@ struct UndoLog { Type type = Type::INSERT; std::string table_name; storage::HeapTable::TupleId rid; + std::optional old_rid; }; /** @@ -120,7 +123,17 @@ class Transaction { void add_undo_log(UndoLog::Type type, const std::string& table_name, const storage::HeapTable::TupleId& rid) { - 
undo_logs_.push_back({type, table_name, rid}); + /* Enforce invariant: non-UPDATE types should not provide old_rid through this overload */ + assert(type != UndoLog::Type::UPDATE); + undo_logs_.push_back({type, table_name, rid, std::nullopt}); + } + + void add_undo_log(UndoLog::Type type, const std::string& table_name, + const storage::HeapTable::TupleId& rid, + const storage::HeapTable::TupleId& old_rid) { + /* Enforce invariant: this overload is primarily for UPDATE types providing old_rid */ + assert(type == UndoLog::Type::UPDATE); + undo_logs_.push_back({type, table_name, rid, old_rid}); } [[nodiscard]] const std::vector& get_undo_logs() const { return undo_logs_; } diff --git a/include/transaction/transaction_manager.hpp b/include/transaction/transaction_manager.hpp index 65e71c6..486e260 100644 --- a/include/transaction/transaction_manager.hpp +++ b/include/transaction/transaction_manager.hpp @@ -82,7 +82,7 @@ class TransactionManager { /** * @brief Undo changes made by a transaction */ - void undo_transaction(Transaction* txn); + bool undo_transaction(Transaction* txn); }; } // namespace cloudsql::transaction diff --git a/src/executor/operator.cpp b/src/executor/operator.cpp index a09ed66..7f6cc49 100644 --- a/src/executor/operator.cpp +++ b/src/executor/operator.cpp @@ -162,8 +162,12 @@ bool IndexScanOperator::next(Tuple& out_tuple) { while (current_match_index_ < matching_ids_.size()) { const auto& tid = matching_ids_[current_match_index_++]; + storage::HeapTable::TupleId rid; + rid.page_num = tid.page_num; + rid.slot_num = tid.slot_num; + storage::HeapTable::TupleMeta meta; - if (table_->get_meta(tid, meta)) { + if (table_->get_meta(rid, meta)) { /* MVCC Visibility Check */ bool visible = true; const Transaction* const txn = get_txn(); @@ -734,7 +738,7 @@ void HashJoinOperator::add_child(std::unique_ptr child) { /* --- LimitOperator --- */ -LimitOperator::LimitOperator(std::unique_ptr child, uint64_t limit, uint64_t offset) 
+LimitOperator::LimitOperator(std::unique_ptr child, int64_t limit, int64_t offset) : Operator(OperatorType::Limit, child->get_txn(), child->get_lock_manager()), child_(std::move(child)), limit_(limit), @@ -750,9 +754,12 @@ bool LimitOperator::open() { } /* Skip offset rows */ + current_count_ = 0; Tuple tuple; - while (current_count_ < offset_ && child_->next(tuple)) { - current_count_++; + if (offset_ > 0) { + while (current_count_ < static_cast(offset_) && child_->next(tuple)) { + current_count_++; + } } current_count_ = 0; set_state(ExecState::Open); @@ -760,7 +767,7 @@ bool LimitOperator::open() { } bool LimitOperator::next(Tuple& out_tuple) { - if (current_count_ >= limit_) { + if (limit_ >= 0 && current_count_ >= static_cast(limit_)) { set_state(ExecState::Done); return false; } diff --git a/src/executor/query_executor.cpp b/src/executor/query_executor.cpp index ebd31b0..9c47ab7 100644 --- a/src/executor/query_executor.cpp +++ b/src/executor/query_executor.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -39,6 +40,29 @@ namespace cloudsql::executor { +namespace { +enum class IndexOp { Insert, Remove }; + +/** + * @brief Helper to perform index writes and check for success + */ +bool apply_index_write(storage::BTreeIndex& index, const common::Value& key, + const storage::HeapTable::TupleId& rid, IndexOp op, std::string& error_msg) { + bool success = false; + if (op == IndexOp::Insert) { + success = index.insert(key, rid); + } else { + success = index.remove(key, rid); + } + + if (!success) { + error_msg = "Index operation failed for key: " + key.to_string(); + return false; + } + return true; +} +} // namespace + void ShardStateMachine::apply(const raft::LogEntry& entry) { if (entry.data.empty()) return; @@ -93,6 +117,12 @@ QueryExecutor::QueryExecutor(Catalog& catalog, storage::BufferPoolManager& bpm, log_manager_(log_manager), cluster_manager_(cluster_manager) {} +QueryExecutor::~QueryExecutor() { + if (current_txn_ 
!= nullptr) { + transaction_manager_.abort(current_txn_); + } +} + QueryResult QueryExecutor::execute(const parser::Statement& stmt) { const auto start = std::chrono::high_resolution_clock::now(); QueryResult result; @@ -123,6 +153,8 @@ QueryResult QueryExecutor::execute(const parser::Statement& stmt) { result = execute_select(dynamic_cast(stmt), txn); } else if (stmt.type() == parser::StmtType::CreateTable) { result = execute_create_table(dynamic_cast(stmt)); + } else if (stmt.type() == parser::StmtType::CreateIndex) { + result = execute_create_index(dynamic_cast(stmt)); } else if (stmt.type() == parser::StmtType::DropTable) { result = execute_drop_table(dynamic_cast(stmt)); } else if (stmt.type() == parser::StmtType::DropIndex) { @@ -268,6 +300,82 @@ QueryResult QueryExecutor::execute_create_table(const parser::CreateTableStateme return result; } +QueryResult QueryExecutor::execute_create_index(const parser::CreateIndexStatement& stmt) { + QueryResult result; + + /* Reject composite indexes */ + if (stmt.columns().size() != 1) { + result.set_error("Composite indexes not supported"); + return result; + } + + auto table_meta_opt = catalog_.get_table_by_name(stmt.table_name()); + if (!table_meta_opt.has_value()) { + result.set_error("Table not found: " + stmt.table_name()); + return result; + } + const auto* table_meta = table_meta_opt.value(); + + std::vector col_positions; + common::ValueType key_type = common::ValueType::TYPE_NULL; + + const auto& col_name = stmt.columns()[0]; + bool found = false; + for (const auto& col : table_meta->columns) { + if (col.name == col_name) { + col_positions.push_back(col.position); + key_type = col.type; + found = true; + break; + } + } + if (!found) { + result.set_error("Column not found: " + col_name); + return result; + } + + /* Update Catalog */ + const oid_t index_id = catalog_.create_index(stmt.index_name(), table_meta->table_id, + col_positions, IndexType::BTree, stmt.unique()); + if (index_id == 0) { + 
result.set_error("Failed to create index in catalog"); + return result; + } + + /* Create Physical Index File */ + storage::BTreeIndex index(stmt.index_name(), bpm_, key_type); + if (!index.create()) { + static_cast(catalog_.drop_index(index_id)); + result.set_error("Failed to create index file"); + return result; + } + + /* Populate Index with existing data (Backfill) */ + Schema schema; + for (const auto& col : table_meta->columns) { + schema.add_column(col.name, col.type); + } + storage::HeapTable table(stmt.table_name(), bpm_, schema); + auto iter = table.scan(); + storage::HeapTable::TupleMeta meta; + std::string err; + while (iter.next_meta(meta)) { + if (meta.xmax == 0) { + /* Extract key from tuple */ + const common::Value& key = meta.tuple.get(col_positions[0]); + if (!apply_index_write(index, key, iter.current_id(), IndexOp::Insert, err)) { + static_cast(index.drop()); + static_cast(catalog_.drop_index(index_id)); + result.set_error(err); + return result; + } + } + } + + result.set_rows_affected(1); + return result; +} + QueryResult QueryExecutor::execute_insert(const parser::InsertStatement& stmt, transaction::Transaction* txn) { QueryResult result; @@ -328,6 +436,19 @@ QueryResult QueryExecutor::execute_insert(const parser::InsertStatement& stmt, const auto tid = table.insert(tuple, xmin); + /* Update Indexes */ + std::string err; + for (const auto& idx_info : table_meta->indexes) { + if (!idx_info.column_positions.empty()) { + uint16_t pos = idx_info.column_positions[0]; + common::ValueType ktype = table_meta->columns[pos].type; + storage::BTreeIndex index(idx_info.name, bpm_, ktype); + if (!apply_index_write(index, tuple.get(pos), tid, IndexOp::Insert, err)) { + throw std::runtime_error(err); + } + } + } + /* Log INSERT */ if (log_manager_ != nullptr && txn != nullptr) { recovery::LogRecord log(txn->get_id(), txn->get_prev_lsn(), @@ -410,13 +531,31 @@ QueryResult QueryExecutor::execute_delete(const parser::DeleteStatement& stmt, } } - /* Retrieve old 
tuple for logging */ + /* Retrieve old tuple for logging and index maintenance (unconditional) */ Tuple old_tuple; - if (log_manager_ != nullptr && txn != nullptr) { - static_cast(table.get(rid, old_tuple)); + if (!table.get(rid, old_tuple)) { + result.set_error("Failed to retrieve tuple for deletion maintenance: " + + rid.to_string()); + return result; } if (table.remove(rid, xmax)) { + /* Update Indexes */ + std::string err; + if (!old_tuple.empty()) { + for (const auto& idx_info : table_meta->indexes) { + if (!idx_info.column_positions.empty()) { + uint16_t pos = idx_info.column_positions[0]; + common::ValueType ktype = table_meta->columns[pos].type; + storage::BTreeIndex index(idx_info.name, bpm_, ktype); + if (!apply_index_write(index, old_tuple.get(pos), rid, IndexOp::Remove, + err)) { + throw std::runtime_error(err); + } + } + } + } + /* Log DELETE */ if (log_manager_ != nullptr && txn != nullptr) { recovery::LogRecord log(txn->get_id(), txn->get_prev_lsn(), @@ -460,6 +599,7 @@ QueryResult QueryExecutor::execute_update(const parser::UpdateStatement& stmt, /* Phase 1: Collect RIDs and compute new values to avoid Halloween Problem */ struct UpdateOp { storage::HeapTable::TupleId rid; + Tuple old_tuple; Tuple new_tuple; }; std::vector updates; @@ -482,30 +622,51 @@ QueryResult QueryExecutor::execute_update(const parser::UpdateStatement& stmt, new_tuple.set(idx, val_expr->evaluate(&meta.tuple, &schema)); } } - updates.push_back({iter.current_id(), std::move(new_tuple)}); + updates.push_back({iter.current_id(), meta.tuple, std::move(new_tuple)}); } } /* Phase 2: Apply Updates */ for (const auto& op : updates) { - /* Retrieve old tuple for logging */ - Tuple old_tuple; - if (log_manager_ != nullptr && txn != nullptr) { - static_cast(table.get(op.rid, old_tuple)); - } - if (table.remove(op.rid, txn_id)) { + /* Update Indexes - Remove old, Insert new */ + std::string err; + for (const auto& idx_info : table_meta->indexes) { + if (!idx_info.column_positions.empty()) 
{ + uint16_t pos = idx_info.column_positions[0]; + common::ValueType ktype = table_meta->columns[pos].type; + storage::BTreeIndex index(idx_info.name, bpm_, ktype); + if (!apply_index_write(index, op.old_tuple.get(pos), op.rid, IndexOp::Remove, + err)) { + throw std::runtime_error(err); + } + } + } + /* Log DELETE part of update */ if (log_manager_ != nullptr && txn != nullptr) { recovery::LogRecord log(txn->get_id(), txn->get_prev_lsn(), recovery::LogRecordType::MARK_DELETE, table_name, op.rid, - old_tuple); + op.old_tuple); const auto lsn = log_manager_->append_log_record(log); txn->set_prev_lsn(lsn); } const auto new_tid = table.insert(op.new_tuple, txn_id); + /* Update Indexes - Insert new */ + for (const auto& idx_info : table_meta->indexes) { + if (!idx_info.column_positions.empty()) { + uint16_t pos = idx_info.column_positions[0]; + common::ValueType ktype = table_meta->columns[pos].type; + storage::BTreeIndex index(idx_info.name, bpm_, ktype); + if (!apply_index_write(index, op.new_tuple.get(pos), new_tid, IndexOp::Insert, + err)) { + throw std::runtime_error(err); + } + } + } + /* Log INSERT part of update */ if (log_manager_ != nullptr && txn != nullptr) { recovery::LogRecord log(txn->get_id(), txn->get_prev_lsn(), @@ -516,8 +677,7 @@ QueryResult QueryExecutor::execute_update(const parser::UpdateStatement& stmt, } if (txn != nullptr) { - txn->add_undo_log(transaction::UndoLog::Type::UPDATE, table_name, op.rid); - txn->add_undo_log(transaction::UndoLog::Type::INSERT, table_name, new_tid); + txn->add_undo_log(transaction::UndoLog::Type::UPDATE, table_name, new_tid, op.rid); } rows_updated++; } @@ -529,7 +689,7 @@ QueryResult QueryExecutor::execute_update(const parser::UpdateStatement& stmt, std::unique_ptr QueryExecutor::build_plan(const parser::SelectStatement& stmt, transaction::Transaction* txn) { - /* 1. Base: SeqScan of the initial table */ + /* 1. 
Base: Initial table access (Sequential Scan or Index Scan) */ if (!stmt.from()) { return nullptr; } @@ -540,9 +700,7 @@ std::unique_ptr QueryExecutor::build_plan(const parser::SelectStatemen if (cluster_manager_ != nullptr && cluster_manager_->has_shuffle_data(context_id_, base_table_name)) { auto data = cluster_manager_->fetch_shuffle_data(context_id_, base_table_name); - /* We need a schema for the buffered data. For simplicity, we assume - * the first table in the FROM clause has a catalog entry we can use. - */ + /* We need a schema for the buffered data. */ auto meta_opt = catalog_.get_table_by_name(base_table_name); Schema buffer_schema; if (meta_opt.has_value()) { @@ -565,9 +723,60 @@ std::unique_ptr QueryExecutor::build_plan(const parser::SelectStatemen base_schema.add_column(col.name, col.type); } - std::unique_ptr current_root = std::make_unique( - std::make_unique(base_table_name, bpm_, base_schema), txn, - &lock_manager_); + /* Index Selection Optimization: + * If there's a simple equality filter on an indexed column, use IndexScanOperator. 
+ */ + std::unique_ptr current_root = nullptr; + bool index_used = false; + + if (stmt.where() && stmt.where()->type() == parser::ExprType::Binary && stmt.joins().empty()) { + const auto* bin_expr = dynamic_cast(stmt.where()); + if (bin_expr->op() == parser::TokenType::Eq) { + std::string col_name; + common::Value const_val; + bool eligible = false; + + if (bin_expr->left().type() == parser::ExprType::Column && + bin_expr->right().type() == parser::ExprType::Constant) { + col_name = bin_expr->left().to_string(); + const_val = bin_expr->right().evaluate(); + eligible = true; + } else if (bin_expr->right().type() == parser::ExprType::Column && + bin_expr->left().type() == parser::ExprType::Constant) { + col_name = bin_expr->right().to_string(); + const_val = bin_expr->left().evaluate(); + eligible = true; + } + + if (eligible) { + /* Check if col_name is indexed */ + for (const auto& idx_info : base_table_meta->indexes) { + if (!idx_info.column_positions.empty()) { + uint16_t pos = idx_info.column_positions[0]; + /* Handle both qualified and unqualified names */ + if (base_table_meta->columns[pos].name == col_name || + (base_table_name + "." + base_table_meta->columns[pos].name) == + col_name) { + common::ValueType ktype = base_table_meta->columns[pos].type; + current_root = std::make_unique( + std::make_unique(base_table_name, bpm_, + base_schema), + std::make_unique(idx_info.name, bpm_, ktype), + std::move(const_val), txn, &lock_manager_); + index_used = true; + break; + } + } + } + } + } + } + + if (!index_used) { + current_root = std::make_unique( + std::make_unique(base_table_name, bpm_, base_schema), txn, + &lock_manager_); + } /* 2. Add JOINs */ for (const auto& join : stmt.joins()) { @@ -605,11 +814,6 @@ std::unique_ptr QueryExecutor::build_plan(const parser::SelectStatemen &lock_manager_); } - /* For now, we use HashJoin if a condition exists, otherwise NestedLoop would be needed. - * Note: HashJoin requires equality condition. 
We'll assume equality for now or default to - * NLJ. Currently cloudSQL only has HashJoin implemented in operator.cpp. - */ - bool use_hash_join = false; std::unique_ptr left_key = nullptr; std::unique_ptr right_key = nullptr; @@ -667,8 +871,8 @@ std::unique_ptr QueryExecutor::build_plan(const parser::SelectStatemen } } - /* 3. Filter (WHERE) */ - if (stmt.where()) { + /* 3. Filter (WHERE) - Only if not already handled by IndexScan */ + if (stmt.where() && !index_used) { current_root = std::make_unique(std::move(current_root), stmt.where()->clone()); } diff --git a/src/network/server.cpp b/src/network/server.cpp index 3cc8f3b..0cc4259 100644 --- a/src/network/server.cpp +++ b/src/network/server.cpp @@ -335,6 +335,8 @@ void Server::handle_connection(int client_fd) { static_cast(send(client_fd, ready.data(), ready.size(), 0)); // 2. Query Loop + executor::QueryExecutor exec(catalog_, bpm_, lock_manager_, transaction_manager_); + while (true) { char type = 0; n = recv(client_fd, &type, 1, 0); @@ -365,8 +367,6 @@ void Server::handle_connection(int client_fd) { executor::DistributedExecutor dist_exec(catalog_, *cluster_manager_); res = dist_exec.execute(*stmt, sql); } else { - executor::QueryExecutor exec(catalog_, bpm_, lock_manager_, - transaction_manager_); res = exec.execute(*stmt); } diff --git a/src/parser/lexer.cpp b/src/parser/lexer.cpp index e086fdc..45a4e29 100644 --- a/src/parser/lexer.cpp +++ b/src/parser/lexer.cpp @@ -58,6 +58,7 @@ std::map Lexer::init_keywords() { {"RIGHT", TokenType::Right}, {"INNER", TokenType::Inner}, {"OUTER", TokenType::Outer}, + {"FULL", TokenType::Full}, {"GROUP", TokenType::Group}, {"BY", TokenType::By}, {"ORDER", TokenType::Order}, diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index fd1729e..5fa4b54 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -129,10 +129,23 @@ std::unique_ptr Parser::parse_select() { if (consume(TokenType::Join)) { join_type = SelectStatement::JoinType::Inner; } else if 
(consume(TokenType::Left)) { + static_cast(consume(TokenType::Outer)); if (!consume(TokenType::Join)) { return nullptr; } join_type = SelectStatement::JoinType::Left; + } else if (consume(TokenType::Right)) { + static_cast(consume(TokenType::Outer)); + if (!consume(TokenType::Join)) { + return nullptr; + } + join_type = SelectStatement::JoinType::Right; + } else if (consume(TokenType::Full)) { + static_cast(consume(TokenType::Outer)); + if (!consume(TokenType::Join)) { + return nullptr; + } + join_type = SelectStatement::JoinType::Full; } else { break; } @@ -234,23 +247,30 @@ std::unique_ptr Parser::parse_select() { } } - /* LIMIT */ - if (consume(TokenType::Limit)) { - const Token val = next_token(); - if (val.type() == TokenType::Number) { - stmt->set_limit(val.as_int64()); - } else { - return nullptr; - } - } - - /* OFFSET */ - if (consume(TokenType::Offset)) { - const Token val = next_token(); - if (val.type() == TokenType::Number) { - stmt->set_offset(val.as_int64()); + /* LIMIT and OFFSET */ + bool limit_set = false; + bool offset_set = false; + while (true) { + if (consume(TokenType::Limit)) { + if (limit_set) return nullptr; + const Token val = next_token(); + if (val.type() == TokenType::Number) { + stmt->set_limit(val.as_int64()); + limit_set = true; + } else { + return nullptr; + } + } else if (consume(TokenType::Offset)) { + if (offset_set) return nullptr; + const Token val = next_token(); + if (val.type() == TokenType::Number) { + stmt->set_offset(val.as_int64()); + offset_set = true; + } else { + return nullptr; + } } else { - return nullptr; + break; } } diff --git a/src/storage/heap_table.cpp b/src/storage/heap_table.cpp index 52dceab..73f6ebb 100644 --- a/src/storage/heap_table.cpp +++ b/src/storage/heap_table.cpp @@ -306,6 +306,13 @@ bool HeapTable::physical_remove(const TupleId& tuple_id) { return write_page(tuple_id.page_num, buffer.data()); } +/** + * @brief Reset xmax to 0 (used for rollback of a DELETE) + */ +bool HeapTable::undo_remove(const 
TupleId& tuple_id) { + return remove(tuple_id, 0); +} + bool HeapTable::update(const TupleId& tuple_id, const executor::Tuple& tuple, uint64_t txn_id) { if (!remove(tuple_id, txn_id)) { return false; diff --git a/src/transaction/transaction_manager.cpp b/src/transaction/transaction_manager.cpp index 2d79491..c000f87 100644 --- a/src/transaction/transaction_manager.cpp +++ b/src/transaction/transaction_manager.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -15,6 +16,7 @@ #include "executor/types.hpp" #include "recovery/log_manager.hpp" #include "recovery/log_record.hpp" +#include "storage/btree_index.hpp" #include "storage/buffer_pool_manager.hpp" #include "storage/heap_table.hpp" #include "transaction/lock_manager.hpp" @@ -155,38 +157,134 @@ void TransactionManager::abort(Transaction* txn) { } } -void TransactionManager::undo_transaction(Transaction* txn) { +bool TransactionManager::undo_transaction(Transaction* txn) { const auto& logs = txn->get_undo_logs(); + bool success = true; /* Undo in reverse order */ for (auto it = logs.rbegin(); it != logs.rend(); ++it) { const auto& log = *it; - auto table_meta = catalog_.get_table_by_name(log.table_name); - if (!table_meta) { + auto table_meta_opt = catalog_.get_table_by_name(log.table_name); + if (!table_meta_opt) { + std::cerr << "Rollback ERROR: Table metadata not found for '" << log.table_name + << "' during undo. 
Transaction: " << txn->get_id() << "\n"; + success = false; continue; } + const auto* table_meta = table_meta_opt.value(); /* Reconstruct schema for HeapTable */ executor::Schema schema; - for (const auto& col : (*table_meta)->columns) { + for (const auto& col : table_meta->columns) { schema.add_column(col.name, col.type); } storage::HeapTable table(log.table_name, bpm_, schema); switch (log.type) { - case UndoLog::Type::INSERT: - static_cast(table.physical_remove(log.rid)); + case UndoLog::Type::INSERT: { + /* For INSERT undo, remove from indexes and then physical remove from heap */ + executor::Tuple tuple; + if (table.get(log.rid, tuple)) { + for (const auto& idx_info : table_meta->indexes) { + if (!idx_info.column_positions.empty()) { + uint16_t pos = idx_info.column_positions[0]; + common::ValueType ktype = table_meta->columns[pos].type; + storage::BTreeIndex index(idx_info.name, bpm_, ktype); + if (!index.remove(tuple.get(pos), log.rid)) { + std::cerr << "Rollback ERROR: Index remove failed for table '" + << log.table_name << "', index '" << idx_info.name + << "'\n"; + success = false; + } + } + } + } + if (!table.physical_remove(log.rid)) { + std::cerr << "Rollback ERROR: physical_remove failed for INSERT undo\n"; + success = false; + } break; - case UndoLog::Type::DELETE: - /* TODO: Implement DELETE undo */ - static_cast(0); + } + case UndoLog::Type::DELETE: { + /* For DELETE undo, reset xmax and re-insert into indexes */ + if (!table.undo_remove(log.rid)) { + std::cerr << "Rollback ERROR: undo_remove failed for DELETE undo\n"; + success = false; + } else { + executor::Tuple tuple; + if (table.get(log.rid, tuple)) { + for (const auto& idx_info : table_meta->indexes) { + if (!idx_info.column_positions.empty()) { + uint16_t pos = idx_info.column_positions[0]; + common::ValueType ktype = table_meta->columns[pos].type; + storage::BTreeIndex index(idx_info.name, bpm_, ktype); + if (!index.insert(tuple.get(pos), log.rid)) { + std::cerr << "Rollback ERROR: Index 
insert failed for table '" + << log.table_name << "', index '" << idx_info.name + << "'\n"; + success = false; + } + } + } + } + } break; - case UndoLog::Type::UPDATE: - /* TODO: Implement UPDATE undo */ - static_cast(1); + } + case UndoLog::Type::UPDATE: { + /* For UPDATE undo, remove new version from indexes/heap and restore old version's + * xmax/indexes */ + executor::Tuple new_tuple; + if (table.get(log.rid, new_tuple)) { + for (const auto& idx_info : table_meta->indexes) { + if (!idx_info.column_positions.empty()) { + uint16_t pos = idx_info.column_positions[0]; + common::ValueType ktype = table_meta->columns[pos].type; + storage::BTreeIndex index(idx_info.name, bpm_, ktype); + if (!index.remove(new_tuple.get(pos), log.rid)) { + std::cerr << "Rollback ERROR: Index remove failed for table '" + << log.table_name << "', index '" << idx_info.name + << "'\n"; + success = false; + } + } + } + } + if (!table.physical_remove(log.rid)) { + std::cerr << "Rollback ERROR: physical_remove failed for new version in UPDATE " + "undo\n"; + success = false; + } + + if (log.old_rid.has_value()) { + if (!table.undo_remove(log.old_rid.value())) { + std::cerr << "Rollback ERROR: undo_remove failed for old version in UPDATE " + "undo\n"; + success = false; + } else { + executor::Tuple old_tuple; + if (table.get(log.old_rid.value(), old_tuple)) { + for (const auto& idx_info : table_meta->indexes) { + if (!idx_info.column_positions.empty()) { + uint16_t pos = idx_info.column_positions[0]; + common::ValueType ktype = table_meta->columns[pos].type; + storage::BTreeIndex index(idx_info.name, bpm_, ktype); + if (!index.insert(old_tuple.get(pos), log.old_rid.value())) { + std::cerr + << "Rollback ERROR: Index insert failed for table '" + << log.table_name << "', index '" << idx_info.name + << "'\n"; + success = false; + } + } + } + } + } + } break; + } } } + return success; } Transaction* TransactionManager::get_transaction(txn_id_t txn_id) { diff --git a/test_data/idx_id.idx 
b/test_data/idx_id.idx new file mode 100644 index 0000000..e69de29 diff --git a/tests/logic/indexes.slt b/tests/logic/indexes.slt new file mode 100644 index 0000000..980e1f1 --- /dev/null +++ b/tests/logic/indexes.slt @@ -0,0 +1,48 @@ +# Index Tests + +statement ok +CREATE TABLE idx_test (id INT, val TEXT); + +statement ok +INSERT INTO idx_test VALUES (1, 'one'), (2, 'two'), (3, 'three'), (4, 'four'), (5, 'five'); + +# 1. Create a BTree index on 'id' +statement ok +CREATE INDEX idx_id ON idx_test (id); + +# 2. Point lookup using indexed column +query IT +SELECT id, val FROM idx_test WHERE id = 3; +---- +3 three + +# 3. Point lookup with no match +query I +SELECT id FROM idx_test WHERE id = 10; +---- + +# 4. Range-like query (if supported, but usually SeqScan fallback) +query I +SELECT id FROM idx_test WHERE id > 2 ORDER BY id; +---- +3 +4 +5 + +# 5. Drop index and verify it can be recreated (proving it was actually dropped) +statement ok +DROP INDEX idx_id; + +statement ok +CREATE INDEX idx_id ON idx_test (id); + +statement ok +DROP INDEX idx_id; + +query IT +SELECT id, val FROM idx_test WHERE id = 1; +---- +1 one + +statement ok +DROP TABLE idx_test; diff --git a/tests/logic/joins.slt b/tests/logic/joins.slt index eee732d..eec8c8c 100644 --- a/tests/logic/joins.slt +++ b/tests/logic/joins.slt @@ -12,7 +12,7 @@ INSERT INTO users_j VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'); statement ok INSERT INTO orders_j VALUES (101, 1, 50.0), (102, 1, 25.0), (103, 2, 100.0); -# Inner Join +# 1. Inner Join query TR SELECT users_j.name, orders_j.amount FROM users_j JOIN orders_j ON users_j.id = orders_j.user_id ORDER BY orders_j.amount; ---- @@ -20,13 +20,13 @@ Alice 25.0 Alice 50.0 Bob 100.0 -# Join with where +# 2. Join with where query TI SELECT users_j.name, orders_j.id FROM users_j JOIN orders_j ON users_j.id = orders_j.user_id WHERE orders_j.amount > 60; ---- Bob 103 -# Left Join (Charlie has no orders) +# 3. 
Left Join (Charlie has no orders) query TR SELECT users_j.name, orders_j.amount FROM users_j LEFT JOIN orders_j ON users_j.id = orders_j.user_id ORDER BY users_j.name, orders_j.amount; ---- @@ -35,6 +35,29 @@ Alice 50.0 Bob 100.0 Charlie NULL +# 4. Add an order with no user +statement ok +INSERT INTO orders_j VALUES (104, 99, 75.0); + +# 5. Right Join (User 99 doesn't exist) +query TRR +SELECT users_j.name, orders_j.amount, orders_j.id FROM users_j RIGHT JOIN orders_j ON users_j.id = orders_j.user_id ORDER BY orders_j.amount; +---- +Alice 25.0 102 +Alice 50.0 101 +NULL 75.0 104 +Bob 100.0 103 + +# 6. Full Join (Charlie has no orders, User 99 doesn't exist) +query TRR +SELECT users_j.name, orders_j.amount, orders_j.id FROM users_j FULL JOIN orders_j ON users_j.id = orders_j.user_id ORDER BY users_j.name, orders_j.amount; +---- +Alice 25.0 102 +Alice 50.0 101 +Bob 100.0 103 +Charlie NULL NULL +NULL 75.0 104 + statement ok DROP TABLE users_j; diff --git a/tests/logic/limit_offset.slt b/tests/logic/limit_offset.slt new file mode 100644 index 0000000..ff86705 --- /dev/null +++ b/tests/logic/limit_offset.slt @@ -0,0 +1,61 @@ +# Limit and Offset Tests + +statement ok +CREATE TABLE lim_off_test (id INT, val TEXT); + +statement ok +INSERT INTO lim_off_test VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e'); + +# 1. Simple Limit +query I +SELECT id FROM lim_off_test ORDER BY id LIMIT 3; +---- +1 +2 +3 + +# 2. Simple Offset +query I +SELECT id FROM lim_off_test ORDER BY id OFFSET 2; +---- +3 +4 +5 + +# 3. Limit and Offset combined +query I +SELECT id FROM lim_off_test ORDER BY id LIMIT 2 OFFSET 1; +---- +2 +3 + +# 4. Limit 0 (should return empty) +query I +SELECT id FROM lim_off_test ORDER BY id LIMIT 0; +---- + +# 5. Offset equal to row count +query I +SELECT id FROM lim_off_test ORDER BY id OFFSET 5; +---- + +# 6. Offset greater than row count +query I +SELECT id FROM lim_off_test ORDER BY id OFFSET 10; +---- + +# 7. 
Limit greater than available rows +query I +SELECT id FROM lim_off_test ORDER BY id OFFSET 3 LIMIT 10; +---- +4 +5 + +# 8. Limit and Offset with Filter (OFFSET 3 LIMIT 1) +query I +SELECT id FROM lim_off_test WHERE id > 0 ORDER BY id OFFSET 3 LIMIT 1; +---- +4 + +statement ok +DROP TABLE lim_off_test; diff --git a/tests/logic/slt_runner.py b/tests/logic/slt_runner.py index 7c19f61..9fd88ac 100644 --- a/tests/logic/slt_runner.py +++ b/tests/logic/slt_runner.py @@ -1,7 +1,7 @@ import socket import struct import sys -import time +import os import math PROTOCOL_VERSION_3 = 196608 @@ -17,6 +17,7 @@ def connect(self): self.sock.settimeout(5.0) self.sock.connect((self.host, self.port)) + # PostgreSQL Startup Message length = 8 packet = struct.pack('!II', length, PROTOCOL_VERSION_3) self.sock.sendall(packet) @@ -79,7 +80,7 @@ def query(self, sql): idx += col_len rows.append(row_data) elif type_char == 'C': - pass # CommandComplete + status = body.decode('utf-8').strip('\0') elif type_char == 'E': status = "ERROR" elif type_char == 'Z': @@ -87,9 +88,36 @@ def query(self, sql): return rows, status +def normalize_value(val): + if val is None: + return "NULL" + return str(val) + +def compare_values(actual, expected, col_type): + if expected == "NULL": + return actual is None or actual == "NULL" # valuesort pre-normalizes None to "NULL" + if actual is None: + return expected == "NULL" + + if col_type == 'R': # Float/Real + try: + return math.isclose(float(actual), float(expected), rel_tol=1e-6) + except (ValueError, TypeError): + return str(actual) == str(expected) + + return str(actual) == str(expected) + def run_slt(file_path, port): client = CloudSQLClient(port=port) - client.connect() + try: + client.connect() + except Exception as e: + print(f"ERROR: Connection failed: {e}") + return False + + if not os.path.exists(file_path): + print(f"ERROR: File not found: {file_path}") + return False with open(file_path, 'r') as f: lines = f.readlines() @@ -104,8 +132,12 @@ def run_slt(file_path, port): line_idx += 1 continue + start_line =
line_idx + 1 + if line.startswith('statement'): - expected_status = line.split()[1] # ok or error + parts = line.split() + expected_status = parts[1] # ok or error + sql_lines = [] line_idx += 1 while line_idx < len(lines) and lines[line_idx].strip(): @@ -116,8 +148,11 @@ def run_slt(file_path, port): total_tests += 1 _, actual_status = client.query(sql) - if actual_status.lower() != expected_status.lower(): - print(f"FAILURE at {file_path}:{line_idx}") + is_error = actual_status == "ERROR" + matches = (expected_status == "error" and is_error) or (expected_status == "ok" and not is_error) + + if not matches: + print(f"FAILURE at {file_path}:{start_line}") print(f" SQL: {sql}") print(f" Expected status: {expected_status}, got: {actual_status}") failed_tests += 1 @@ -128,6 +163,10 @@ def run_slt(file_path, port): types = parts[1] sort_mode = parts[2] if len(parts) > 2 else None + if sort_mode and sort_mode not in ['rowsort', 'valuesort']: + print(f"ERROR at {file_path}:{start_line}: Unsupported sort mode '{sort_mode}'") + sys.exit(1) + sql_lines = [] line_idx += 1 while line_idx < len(lines) and lines[line_idx].strip() != '----': @@ -146,36 +185,38 @@ def run_slt(file_path, port): actual_rows, status = client.query(sql) if status == "ERROR": - print(f"FAILURE at {file_path}:{line_idx}") + print(f"FAILURE at {file_path}:{start_line}") print(f" SQL: {sql}") - print(f" Query failed with ERROR status") + print(" Query failed with ERROR status") failed_tests += 1 continue - # Apply sort mode + # Apply SLT sort modes if sort_mode == 'rowsort': actual_rows.sort() expected_rows.sort() elif sort_mode == 'valuesort': - actual_values = sorted([str(val) if val is not None else "NULL" for row in actual_rows for val in row]) - expected_values = sorted([val for row in expected_rows for val in row]) - actual_rows = [[v] for v in actual_values] - expected_rows = [[v] for v in expected_values] - elif sort_mode: - print(f"ERROR: Unsupported sort mode: {sort_mode}") - sys.exit(1) + # 
Valuesort sorts every individual value in the result set + actual_vals = sorted([normalize_value(v) for row in actual_rows for v in row]) + expected_vals = sorted([v for row in expected_rows for v in row]) + actual_rows = [[v] for v in actual_vals] + expected_rows = [[v] for v in expected_vals] + # Update types to all be 'T' since we flattened everything to strings for valuesort + types = 'T' * len(actual_vals) - # Compare results + # Compare row counts if len(actual_rows) != len(expected_rows): - print(f"FAILURE at {file_path}:{line_idx}") + print(f"FAILURE at {file_path}:{start_line}") print(f" SQL: {sql}") print(f" Expected {len(expected_rows)} rows, got {len(actual_rows)}") + print(f" Actual rows: {actual_rows}") failed_tests += 1 continue + # Compare cell by cell for i in range(len(actual_rows)): if len(actual_rows[i]) != len(expected_rows[i]): - print(f"FAILURE at {file_path}:{line_idx}, row {i}") + print(f"FAILURE at {file_path}:{start_line}, row {i}") print(f" Expected {len(expected_rows[i])} columns, got {len(actual_rows[i])}") failed_tests += 1 break @@ -184,28 +225,18 @@ def run_slt(file_path, port): for j in range(len(actual_rows[i])): act = actual_rows[i][j] exp = expected_rows[i][j] + col_type = types[j] if j < len(types) else 'T' - if exp == "NULL" and act is None: - continue - - # Basic numeric normalization for float comparison - if types[j] == 'R' and sort_mode != 'valuesort': - try: - if not math.isclose(float(act), float(exp), rel_tol=1e-6): - match = False - except: - match = False - else: - if str(act) != str(exp): - match = False - - if not match: - print(f"FAILURE at {file_path}:{line_idx}, row {i} col {j}") - print(f" Expected '{exp}', got '{act}'") - failed_tests += 1 + if not compare_values(act, exp, col_type): + print(f"FAILURE at {file_path}:{start_line}, row {i} col {j}") + print(f" SQL: {sql}") + print(f" Expected '{exp}', got '{normalize_value(act)}'") + print(f" Full row: {[normalize_value(v) for v in actual_rows[i]]}") + match = 
False break - if not match: break - + if not match: + failed_tests += 1 + break else: line_idx += 1 diff --git a/tests/logic/transactions.slt b/tests/logic/transactions.slt new file mode 100644 index 0000000..7f595ac --- /dev/null +++ b/tests/logic/transactions.slt @@ -0,0 +1,85 @@ +# Transaction Tests + +statement ok +CREATE TABLE txn_test (id INT, val TEXT); + +# 1. Basic Commit +statement ok +BEGIN; + +statement ok +INSERT INTO txn_test VALUES (1, 'commit_me'); + +query IT +SELECT id, val FROM txn_test WHERE id = 1; +---- +1 commit_me + +statement ok +COMMIT; + +query IT +SELECT id, val FROM txn_test WHERE id = 1; +---- +1 commit_me + +# 2. Basic Rollback +statement ok +BEGIN; + +statement ok +INSERT INTO txn_test VALUES (2, 'rollback_me'); + +query IT +SELECT id, val FROM txn_test WHERE id = 2; +---- +2 rollback_me + +statement ok +ROLLBACK; + +query IT +SELECT id, val FROM txn_test WHERE id = 2; +---- + +# 3. Visibility of Updates within Transaction +statement ok +BEGIN; + +statement ok +UPDATE txn_test SET val = 'updated' WHERE id = 1; + +query IT +SELECT id, val FROM txn_test WHERE id = 1; +---- +1 updated + +statement ok +ROLLBACK; + +query IT +SELECT id, val FROM txn_test WHERE id = 1; +---- +1 commit_me + +# 4. Delete within transaction and rollback +statement ok +BEGIN; + +statement ok +DELETE FROM txn_test WHERE id = 1; + +query IT +SELECT * FROM txn_test; +---- + +statement ok +ROLLBACK; + +query IT +SELECT id, val FROM txn_test; +---- +1 commit_me + +statement ok +DROP TABLE txn_test;