diff --git a/lib/spoom/rbs.rb b/lib/spoom/rbs.rb index 81e40f5b..acb7a512 100644 --- a/lib/spoom/rbs.rb +++ b/lib/spoom/rbs.rb @@ -105,7 +105,7 @@ def node_rbs_comments(node) location = location.join(continuation_comment.location) end continuation_comments.clear - res.signatures << Signature.new(string, location) + res.signatures.prepend(Signature.new(string, location)) elsif string.start_with?("#|") continuation_comments << comment end diff --git a/lib/spoom/sorbet/translate.rb b/lib/spoom/sorbet/translate.rb index 74cac87a..bf815176 100644 --- a/lib/spoom/sorbet/translate.rb +++ b/lib/spoom/sorbet/translate.rb @@ -53,9 +53,14 @@ def sorbet_sigs_to_rbs_comments( # Converts all the RBS comments in the given Ruby code to `sig` nodes. # It also handles type members and class annotations. - #: (String ruby_contents, file: String, ?max_line_length: Integer?) -> String - def rbs_comments_to_sorbet_sigs(ruby_contents, file:, max_line_length: nil) - RBSCommentsToSorbetSigs.rewrite_if_needed(ruby_contents, file: file, max_line_length: max_line_length) + #: (String ruby_contents, file: String, ?max_line_length: Integer?, ?overloads_strategy: Symbol) -> String + def rbs_comments_to_sorbet_sigs(ruby_contents, file:, max_line_length: nil, overloads_strategy: :translate_all) + RBSCommentsToSorbetSigs.rewrite_if_needed( + ruby_contents, + file: file, + max_line_length: max_line_length, + overloads_strategy: overloads_strategy, + ) end # Converts all `T.let` and `T.cast` nodes to RBS comments in the given Ruby code. diff --git a/lib/spoom/sorbet/translate/rbs_comments_to_sorbet_sigs.rb b/lib/spoom/sorbet/translate/rbs_comments_to_sorbet_sigs.rb index 3929a2fb..989a9061 100644 --- a/lib/spoom/sorbet/translate/rbs_comments_to_sorbet_sigs.rb +++ b/lib/spoom/sorbet/translate/rbs_comments_to_sorbet_sigs.rb @@ -20,25 +20,33 @@ class RBSCommentsToSorbetSigs < Translator RBS_REWRITE_PATTERN = Regexp.union(["#:", "#|", *RBS_ANNOTATION_MARKERS]).freeze #: Regexp private_constant :RBS_ANNOTATION_MARKERS, :RBS_REWRITE_PATTERN + ALLOWED_OVERLOAD_STRATEGIES = [:translate_all, :translate_last, :raise].freeze #: Array[Symbol] + class << self #: (String source) -> bool def contains_rbs_syntax?(source) Sigils.contains_valid_sigil?(source) && source.match?(RBS_REWRITE_PATTERN) end - #: (String ruby_contents, file: String, ?max_line_length: Integer?) -> String - def rewrite_if_needed(ruby_contents, file:, max_line_length: nil) + #: (String ruby_contents, file: String, ?max_line_length: Integer?, ?overloads_strategy: Symbol) -> String + def rewrite_if_needed(ruby_contents, file:, max_line_length: nil, overloads_strategy: :translate_all) return ruby_contents unless contains_rbs_syntax?(ruby_contents) - new(ruby_contents, file:, max_line_length:).rewrite + new(ruby_contents, file:, max_line_length:, overloads_strategy:).rewrite end end - #: (String, file: String, ?max_line_length: Integer?) -> void - def initialize(ruby_contents, file:, max_line_length: nil) + #: (String, file: String, ?max_line_length: Integer?, ?overloads_strategy: Symbol) -> void + def initialize(ruby_contents, file:, max_line_length: nil, overloads_strategy: :translate_all) super(ruby_contents, file: file) + unless ALLOWED_OVERLOAD_STRATEGIES.include?(overloads_strategy) + raise ArgumentError, "Unknown overloads_strategy: #{overloads_strategy.inspect}. " \ + "Must be one of: #{ALLOWED_OVERLOAD_STRATEGIES.map(&:inspect).join(", ")}" + end + @max_line_length = max_line_length + @overloads_strategy = overloads_strategy end # @override @@ -107,7 +115,13 @@ def visit_attr(node) return if comments.signatures.empty? - comments.signatures.each do |signature| + signatures = apply_overloads_strategy( + comments.signatures, + method_name: node.message.to_s, + location: "#{@file}:#{node.location.start_line}", + ) + + signatures.each do |signature| attr_type = ::RBS::Parser.parse_type(signature.string) sig = RBI::Sig.new @@ -143,11 +157,17 @@ def rewrite_def(def_node, comments) return if comments.empty? return if comments.signatures.empty? + signatures = apply_overloads_strategy( + comments.signatures, + method_name: def_node.name.to_s, + location: "#{@file}:#{def_node.location.start_line}", + ) + builder = RBI::Parser::TreeBuilder.new(@ruby_contents, comments: [], file: @file) builder.visit(def_node) rbi_node = builder.tree.nodes.first #: as RBI::Method - comments.signatures.each do |signature| + signatures.each do |signature| begin method_type = ::RBS::Parser.parse_method_type(signature.string) rescue ::RBS::ParsingError @@ -180,6 +200,29 @@ def rewrite_def(def_node, comments) end end + #: (Array[RBS::Signature], method_name: String, location: String) -> Array[RBS::Signature] + def apply_overloads_strategy(signatures, method_name:, location:) + return signatures if signatures.size <= 1 + + case @overloads_strategy + when :translate_all + signatures + when :translate_last + kept = signatures.last #: as RBS::Signature + others = signatures[0...-1] #: as !nil + + # Delete all the signatures we didn't keep + others.each do |signature| + from = adjust_to_line_start(signature.location.start_offset) + to = adjust_to_line_end(signature.location.end_offset) + @rewriter << Source::Delete.new(from, to) + end + [kept] + else # :raise + raise Error, "Method `#{method_name}` at #{location} has multiple overloaded signatures" + end + end + #: (Prism::ClassNode | Prism::ModuleNode | Prism::SingletonClassNode) -> void def apply_class_annotations(node) comments = node_rbs_comments(node) diff --git a/rbi/spoom.rbi b/rbi/spoom.rbi index d8f8b9f2..8daf93c6 100644 --- a/rbi/spoom.rbi +++ b/rbi/spoom.rbi @@ -2854,8 +2854,15 @@ Spoom::Sorbet::Sigils::VALID_STRICTNESS = T.let(T.unsafe(nil), Array) module Spoom::Sorbet::Translate class << self - sig { params(ruby_contents: ::String, file: ::String, max_line_length: T.nilable(::Integer)).returns(::String) } - def rbs_comments_to_sorbet_sigs(ruby_contents, file:, max_line_length: T.unsafe(nil)); end + sig do + params( + ruby_contents: ::String, + file: ::String, + max_line_length: T.nilable(::Integer), + overloads_strategy: ::Symbol + ).returns(::String) + end + def rbs_comments_to_sorbet_sigs(ruby_contents, file:, max_line_length: T.unsafe(nil), overloads_strategy: T.unsafe(nil)); end sig do params( @@ -2893,8 +2900,15 @@ class Spoom::Sorbet::Translate::Error < ::Spoom::Error; end class Spoom::Sorbet::Translate::RBSCommentsToSorbetSigs < ::Spoom::Sorbet::Translate::Translator include ::Spoom::RBS::ExtractRBSComments - sig { params(ruby_contents: ::String, file: ::String, max_line_length: T.nilable(::Integer)).void } - def initialize(ruby_contents, file:, max_line_length: T.unsafe(nil)); end + sig do + params( + ruby_contents: ::String, + file: ::String, + max_line_length: T.nilable(::Integer), + overloads_strategy: ::Symbol + ).void + end + def initialize(ruby_contents, file:, max_line_length: T.unsafe(nil), overloads_strategy: T.unsafe(nil)); end sig { override.params(node: ::Prism::CallNode).void } def visit_call_node(node); end @@ -2930,6 +2944,15 @@ class Spoom::Sorbet::Translate::RBSCommentsToSorbetSigs < ::Spoom::Sorbet::Trans sig { params(annotations: T::Array[::Spoom::RBS::Annotation], sig: ::RBI::Sig).void } def apply_member_annotations(annotations, sig); end + sig do + params( + signatures: T::Array[::Spoom::RBS::Signature], + method_name: ::String, + location: ::String + ).returns(T::Array[::Spoom::RBS::Signature]) + end + def apply_overloads_strategy(signatures, method_name:, location:); end + sig { params(comments: T::Array[::Prism::Comment]).void } def apply_type_aliases(comments); end @@ -2946,11 +2969,19 @@ class Spoom::Sorbet::Translate::RBSCommentsToSorbetSigs < ::Spoom::Sorbet::Trans sig { params(source: ::String).returns(T::Boolean) } def contains_rbs_syntax?(source); end - sig { params(ruby_contents: ::String, file: ::String, max_line_length: T.nilable(::Integer)).returns(::String) } - def rewrite_if_needed(ruby_contents, file:, max_line_length: T.unsafe(nil)); end + sig do + params( + ruby_contents: ::String, + file: ::String, + max_line_length: T.nilable(::Integer), + overloads_strategy: ::Symbol + ).returns(::String) + end + def rewrite_if_needed(ruby_contents, file:, max_line_length: T.unsafe(nil), overloads_strategy: T.unsafe(nil)); end end end +Spoom::Sorbet::Translate::RBSCommentsToSorbetSigs::ALLOWED_OVERLOAD_STRATEGIES = T.let(T.unsafe(nil), Array) Spoom::Sorbet::Translate::RBSCommentsToSorbetSigs::RBS_ANNOTATION_MARKERS = T.let(T.unsafe(nil), Array) Spoom::Sorbet::Translate::RBSCommentsToSorbetSigs::RBS_REWRITE_PATTERN = T.let(T.unsafe(nil), Regexp) diff --git a/test/spoom/sorbet/translate/rbs_comments_to_sorbet_sigs_test.rb b/test/spoom/sorbet/translate/rbs_comments_to_sorbet_sigs_test.rb index 38b3330d..77342b9b 100644 --- a/test/spoom/sorbet/translate/rbs_comments_to_sorbet_sigs_test.rb +++ b/test/spoom/sorbet/translate/rbs_comments_to_sorbet_sigs_test.rb @@ -678,6 +678,72 @@ class Range RB end + def test_translate_overloads_translate_all_is_default + contents = <<~RB + class Foo + #: () { (Integer) -> void } -> void + #: () -> Enumerator[Integer, void] + def each(&block); end + end + RB + + assert_equal(<<~RB, rbs_comments_to_sorbet_sigs(contents)) + class Foo + sig { params(block: ::T.proc.params(arg0: Integer).void).void } + sig { returns(::T::Enumerator[Integer, void]) } + def each(&block); end + end + RB + end + + def test_translate_overloads_translate_last + contents = <<~RB + class Foo + #: () { (Integer) -> void } -> void + #: () -> Enumerator[Integer, void] + def each(&block); end + end + RB + + assert_equal(<<~RB, rbs_comments_to_sorbet_sigs(contents, overloads_strategy: :translate_last)) + class Foo + sig { returns(::T::Enumerator[Integer, void]) } + def each(&block); end + end + RB + end + + def test_translate_overloads_raise + contents = <<~RB + class Foo + #: () { (Integer) -> void } -> void + #: () -> Enumerator[Integer, void] + def each(&block); end + end + RB + + error = assert_raises(Translate::Error) do + rbs_comments_to_sorbet_sigs(contents, overloads_strategy: :raise) + end + assert_equal("Method `each` at test.rb:4 has multiple overloaded signatures", error.message) + end + + def test_translate_overloads_single_signature_unaffected + contents = <<~RB + class Foo + #: () -> void + def foo; end + end + RB + + assert_equal(<<~RB, rbs_comments_to_sorbet_sigs(contents, overloads_strategy: :translate_last)) + class Foo + sig { void } + def foo; end + end + RB + end + def test_contains_rbs_syntax_returns_true_for_supported_rbs_annotations [ "# @abstract", @@ -815,9 +881,14 @@ def foo; end private - #: (String, ?max_line_length: Integer?) -> String - def rbs_comments_to_sorbet_sigs(ruby_contents, max_line_length: nil) - RBSCommentsToSorbetSigs.new(ruby_contents, file: "test.rb", max_line_length: max_line_length).rewrite + #: (String, ?max_line_length: Integer?, ?overloads_strategy: Symbol) -> String + def rbs_comments_to_sorbet_sigs(ruby_contents, max_line_length: nil, overloads_strategy: :translate_all) + RBSCommentsToSorbetSigs.new( + ruby_contents, + file: "test.rb", + max_line_length: max_line_length, + overloads_strategy: overloads_strategy, + ).rewrite end end end