# Cardinality helpers for enumerables: fetch exactly one, at least one, at
# most one, or an exact number of entries, raising when the expectation is
# not met.
#
# The host object must respond to #take and #to_a (and to #count when the
# default error is raised by #exactly).
module EnumerableExtensions

  # An exception raised when an invalid number of entries is returned
  class InvalidCountError < StandardError

    # Initialize an exception to report an invalid enumerable count
    #
    # @param expectation [#to_s]
    #   description of the expected count (e.g. 'one or more', or an Integer)
    # @param entries [#count]
    #   the entries that were found; only their count is reported
    #
    # @return [undefined]
    #
    # @api public
    def initialize(expectation, entries)
      super(
        format(
          'Found %{count}, expected %{expectation}',
          count:       entries.count,
          expectation: expectation
        )
      )
    end

  end # InvalidCountError

  # Sentinel distinguishing "no argument given" from an explicit nil/false
  Undefined = Object.new.freeze

  # Return exactly one entry from the enumerable
  #
  # @param default [Object]
  #   value returned when the number of entries is not exactly one
  #
  # @yield [count, entries]
  #   called when the number of entries is not exactly one
  #
  # @yieldparam [Integer] count
  #   the expected count (always 1)
  # @yieldparam [Enumerable] entries
  #   the receiver
  #
  # @yieldreturn [Object]
  #   return the default from the block, if provided
  #
  # @return [Object]
  #   the single entry, or the default / block result otherwise
  #
  # @raise [InvalidCountError]
  #   raised if zero or more than one entry is found and there is no default
  # @raise [ArgumentError]
  #   raised if both a default and a block are provided
  #
  # @api public
  def one(default = Undefined)
    # #exactly returns a collection, so the fallback value is wrapped in a
    # one-element array and unwrapped again by fetch(0) below.
    block = -> (*block_args) { [yield(*block_args)] } if block_given?

    result =
      if block || default.equal?(Undefined)
        exactly(1, default, &block)
      else
        exactly(1, [default])
      end

    result.fetch(0)
  end

  # Return one or more entries from the enumerable
  #
  # @return [Array]
  #   returned if one or more entries
  #
  # @raise [InvalidCountError]
  #   raised if zero entries are found
  #
  # @api public
  def min_one
    entries = to_a

    # empty? rather than none?: none? tests truthiness, so a collection such
    # as [nil] or [false] would be misreported as having no entries.
    fail InvalidCountError.new('one or more', entries) if entries.empty?

    entries
  end

  # Return zero or one entry from the enumerable
  #
  # @return [Object, nil]
  #   the single entry, or nil when there are no entries
  #
  # @raise [InvalidCountError]
  #   raised if more than one entry is found
  #
  # @api public
  def max_one
    # Two entries are enough to detect "more than one"; as a consequence the
    # error message reports a found count of at most 2.
    entries = take(2).to_a
    fail InvalidCountError.new('zero or one', entries) if entries.size > 1

    entries.first
  end

  # Return an exact number of entries from the enumerable
  #
  # @param count [Integer]
  #   the number of entries expected
  # @param default [Object]
  #   value returned when the number of entries does not match
  #
  # @yield [count, entries]
  #   called when the number of entries does not match
  #
  # @yieldparam [Integer] count
  #   the expected count
  # @yieldparam [Enumerable] entries
  #   the receiver
  #
  # @yieldreturn [Object]
  #   return the default from the block, if provided
  #
  # @return [Array]
  #   returned if exactly +count+ entries are found
  # @return [Object]
  #   the default or the block result otherwise
  #
  # @raise [InvalidCountError]
  #   raised if an invalid number of entries is found and neither a default
  #   nor a block is provided
  # @raise [ArgumentError]
  #   raised if both a default and a block are provided
  #
  # @api public
  def exactly(count, default = Undefined, &block)
    assert_default_or_block(default, &block)

    # Fetch one entry more than expected so that a surplus is detectable
    # without loading the whole enumerable.
    entries = take(count.succ).to_a

    # Compare by value; identity (equal?) is only coincidentally true for
    # small Integers.
    return entries if entries.size == count
    return default unless default.equal?(Undefined)
    return block.call(count, self) if block

    fail InvalidCountError.new(count, self)
  end

private

  # Assert that a block and default argument cannot be provided together
  #
  # @param default [Object]
  #
  # @return [undefined]
  #
  # @raise [ArgumentError]
  #   raised if a block and default value are provided
  #
  # @api private
  def assert_default_or_block(default)
    return unless block_given? && !default.equal?(Undefined)

    fail ArgumentError, 'Must pass in a block or a default argument, not both'
  end

end # EnumerableExtensions

ActiveRecord::Base.extend(EnumerableExtensions)
Array.module_eval { include EnumerableExtensions }