Skip to content

Instantly share code, notes, and snippets.

@kalexmills
Last active December 23, 2019 17:01
Show Gist options
  • Select an option

  • Save kalexmills/4e2724ffa5ba7d1f90fdf7f0de9242e0 to your computer and use it in GitHub Desktop.

Select an option

Save kalexmills/4e2724ffa5ba7d1f90fdf7f0de9242e0 to your computer and use it in GitHub Desktop.

Revisions

  1. kalexmills revised this gist Dec 23, 2019. 1 changed file with 8 additions and 4 deletions.
    12 changes: 8 additions & 4 deletions MultiSetSpec.scala
    Original file line number Diff line number Diff line change
    @@ -84,19 +84,23 @@ class MultiSetSpec extends FlatSpec with Matchers {
    }

    it should "admit non-deterministic monadic computations" in {
    val x = MultiSet(Set(1,2,3))
    var x = MultiSet(Set(1,2,3))

    x.flatMap(x => MultiSet().addMany(x, 2)).iterator.toList.sorted shouldEqual (List(1,1,2,2,3,3))

    x = MultiSet(Set(1,3,5))

    x.flatMap(x => MultiSet(x, x+1)).iterator.toList.sorted shouldEqual (List(1,2,3,4,5,6))
    }
    it should "multiply existing elements when asked" in {
    val x = MultiSet(Set('a','b','c')).mult(3)

    x.iterator.toList.sorted shouldEqual (List('a','a','a','b','b','b','c','c','c'))
    }

    it should "admit non-deterministic computations" in {
    val x = MultiSet(Set(1,3,5))
    it should "allow unary multplication" in {
    val x = MultiSet().addMany(1, 3)

    x.flatMap(x => MultiSet(Set(x, x+1))).iterator.toList.sorted shouldEqual (List(1,2,3,4,5,6))
    x.flatMap(x => MultiSet(x,x)).iterator.toList.sorted shouldEqual(List(1,1,1,1,1,1))
    }
    }
  2. kalexmills revised this gist Dec 23, 2019. 1 changed file with 0 additions and 1 deletion.
    1 change: 0 additions & 1 deletion MultiSet.scala
    Original file line number Diff line number Diff line change
    @@ -145,7 +145,6 @@ object MultiSet {
    implicit val monadForMultiset = new Monad[MultiSet] {
    def pure[A](x: A): MultiSet[A] = MultiSet(x)

    // TODO: actually -- count needs to be multiplied for this to be lawful
    def flatMap[A, B](fa: MultiSet[A])(f: A => MultiSet[B]): MultiSet[B] =
    fa.foldLeft(MultiSet[B]())((set, a) => set.union(f(a).mult(fa.multiplicity(a))))

  3. kalexmills revised this gist Dec 23, 2019. No changes.
  4. kalexmills revised this gist Dec 23, 2019. 2 changed files with 49 additions and 0 deletions.
    31 changes: 31 additions & 0 deletions MultiSet.scala
    Original file line number Diff line number Diff line change
    @@ -2,6 +2,7 @@ package com.niftysoft.gennit.util

    import cats._
    import cats.implicits._
    import scala.annotation.tailrec

    case class MultiSet[V] private (data: Map[V,Int]) {

    @@ -11,6 +12,11 @@ case class MultiSet[V] private (data: Map[V,Int]) {

    def contains(elem: V): Boolean = data.contains(elem)

    def mult(factor: Int): MultiSet[V] =
    MultiSet(
    data.map{case (x -> count) => (x -> count * factor)}
    )

    def addMany(elem: V, num: Int): MultiSet[V] =
    MultiSet(
    data + (elem -> (multiplicity(elem) + num))
    @@ -135,4 +141,29 @@ object MultiSet {
    def foldRight[A, B](fa: MultiSet[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] =
    fa.iterator.toList.foldRight(lb)(f)
    }

    implicit val monadForMultiset = new Monad[MultiSet] {
    def pure[A](x: A): MultiSet[A] = MultiSet(x)

    // TODO: actually -- count needs to be multiplied for this to be lawful
    def flatMap[A, B](fa: MultiSet[A])(f: A => MultiSet[B]): MultiSet[B] =
    fa.foldLeft(MultiSet[B]())((set, a) => set.union(f(a).mult(fa.multiplicity(a))))

    def tailRecM[A, B](a: A)(f: A => MultiSet[Either[A,B]]): MultiSet[B] = {
    var buf = MultiSet[B]()
    @tailrec
    def go(sets: List[MultiSet[Either[A,B]]]): Unit = sets match {
    case set :: tail => set.data.toList match {
    case (x -> count) :: rest => x match {
    case Right(b) => buf.addMany(b, count); go(MultiSet(rest.toMap) :: tail)
    case Left(a) => go(f(a) :: MultiSet(rest.toMap) :: tail)
    }
    case Nil => go(tail)
    }
    case Nil => ()
    }
    go(f(a) :: Nil)
    buf
    }
    }
    }
    18 changes: 18 additions & 0 deletions MultiSetSpec.scala
    Original file line number Diff line number Diff line change
    @@ -1,6 +1,7 @@
    package com.niftysoft.gennit.util

    import org.scalatest._
    import cats.implicits._

    class MultiSetSpec extends FlatSpec with Matchers {
    "MultiSet" should "work on empty set" in {
    @@ -81,4 +82,21 @@ class MultiSetSpec extends FlatSpec with Matchers {

    x.iterator.toList.sorted shouldEqual (List(1,2,3))
    }

    it should "admit non-deterministic monadic computations" in {
    val x = MultiSet(Set(1,2,3))

    x.flatMap(x => MultiSet().addMany(x, 2)).iterator.toList.sorted shouldEqual (List(1,1,2,2,3,3))
    }
    it should "multiply existing elements when asked" in {
    val x = MultiSet(Set('a','b','c')).mult(3)

    x.iterator.toList.sorted shouldEqual (List('a','a','a','b','b','b','c','c','c'))
    }

    it should "admit non-deterministic computations" in {
    val x = MultiSet(Set(1,3,5))

    x.flatMap(x => MultiSet(Set(x, x+1))).iterator.toList.sorted shouldEqual (List(1,2,3,4,5,6))
    }
    }
  5. kalexmills created this gist Dec 15, 2019.
    138 changes: 138 additions & 0 deletions MultiSet.scala
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,138 @@
    package com.niftysoft.gennit.util

    import cats._
    import cats.implicits._

    case class MultiSet[V] private (data: Map[V,Int]) {

    def filter(f: V => Boolean): MultiSet[V] = MultiSet(data.filter{case(v, mul) => f(v)})

    def multiplicity(elem: V): Int = data.getOrElse(elem, 0)

    def contains(elem: V): Boolean = data.contains(elem)

    def addMany(elem: V, num: Int): MultiSet[V] =
    MultiSet(
    data + (elem -> (multiplicity(elem) + num))
    )

    def excl(elem: V): MultiSet[V] =
    MultiSet(
    if(multiplicity(elem) == 0) {
    data
    } else if (multiplicity(elem) == 1) {
    data - elem
    } else {
    data + (elem -> (multiplicity(elem) - 1))
    }
    )

    def exclAll(elem: V): MultiSet[V] =
    MultiSet(
    data - elem
    )

    def incl(elem: V): MultiSet[V] = addMany(elem, 1)

    def diff(other: Set[V]): MultiSet[V] =
    diff(MultiSet(other))

    def diff(other: Seq[V]): MultiSet[V] =
    diff(MultiSet(other:_*))

    def diff(other: MultiSet[V]): MultiSet[V] =
    MultiSet(
    data.map{case (v,mul) => (v, mul - other.multiplicity(v))}
    .filter{case (v, mul) => mul > 0}
    )

    def sum(other: Set[V]): MultiSet[V] =
    sum(MultiSet(other))

    def sum(other: Seq[V]): MultiSet[V] =
    sum(MultiSet(other:_*))

    def sum(other: MultiSet[V]): MultiSet[V] =
    MultiSet(
    data.map{case (v, mul) => (v, other.multiplicity(v) + mul)} ++
    (other.data -- data.keySet)
    )

    def union(other: Seq[V]): MultiSet[V] =
    union(MultiSet(other:_*))

    def union(other: Set[V]): MultiSet[V] =
    union(MultiSet(other))

    def union(other: MultiSet[V]): MultiSet[V] =
    MultiSet(
    data.map{case (v, mul) => (v, Math.max(other.multiplicity(v), mul))} ++
    (other.data -- data.keySet))

    def intersect(other: Seq[V]): MultiSet[V] =
    intersect(MultiSet(other:_*))

    def intersect(other: MultiSet[V]): MultiSet[V] =
    MultiSet(
    data.map{case (v, mul) => (v, Math.min(other.multiplicity(v), mul))}
    .filter{case (v, mul) => mul > 0}
    )

    def toList: List[V] = iterator.toList

    def iterator: Iterator[V] = new Iterator[V] {
    private[this] val keys = data.keysIterator
    private[this] var curr: Option[V] = if (keys.hasNext) Some(keys.next()) else None
    private[this] var valLeft: Int = currMult()

    def hasNext: Boolean = keys.hasNext || valLeft > 0

    def next(): V =
    if (valLeft > 0) {
    valLeft -= 1
    curr.get
    } else {
    curr = Some(keys.next()) // throws NoSuchElementException as needed
    valLeft = currMult() - 1
    curr.get
    }
    private[this] def currMult(): Int = {
    curr.map(data(_)).getOrElse(0)
    }
    }
    override def equals(o: Any): Boolean = {
    o match {
    case ms @ MultiSet(data) => this.data.equals(data)
    case _ => false
    }
    }
    override def toString(): String = {
    data.toList.map{case (x, count) =>
    List(x.toString)
    .replicateA(count)
    .flatten
    .intercalate(", ")}
    .intercalate(", ")
    }
    }

    object MultiSet {
    def apply[A](): MultiSet[A] = new MultiSet(Map())
    def apply[A](x: A*): MultiSet[A] = new MultiSet(x.groupBy(identity).map{case (v, s) => (v, s.length)})
    def apply[A](x: Set[A]): MultiSet[A] = new MultiSet(x.map{x => (x -> 1)}.toMap)

    implicit val functorForMultiset = new Functor[MultiSet] {
    def map[A, B](fa: MultiSet[A])(f: A => B): MultiSet[B] =
    MultiSet(fa.data.map{case(a, mul) => (f(a), mul)})
    }

    implicit val foldableForMultiset = new Foldable[MultiSet] {
    import cats._
    import cats.implicits._
    def foldLeft[A, B](fa: MultiSet[A], b: B)(f: (B, A) => B): B =
    fa.iterator.foldLeft(b)(f)

    def foldRight[A, B](fa: MultiSet[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] =
    fa.iterator.toList.foldRight(lb)(f)
    }
    }
    84 changes: 84 additions & 0 deletions MultiSetSpec.scala
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,84 @@
    package com.niftysoft.gennit.util

    import org.scalatest._

    class MultiSetSpec extends FlatSpec with Matchers {
    "MultiSet" should "work on empty set" in {
    val x: MultiSet[Int] = MultiSet()

    x.contains(1) shouldEqual (false)
    x.iterator.toList shouldEqual (List.empty)
    }

    it should "sum arguments to apply correctly" in {
    val x: MultiSet[Int] = MultiSet(1,1)

    x.multiplicity(1) shouldEqual (2)
    x.iterator.toList shouldEqual (List(1,1))
    }

    it should "remove things when excl is called" in {
    val x: MultiSet[Int] = MultiSet().addMany(1,2)

    x.excl(1).iterator.toList shouldEqual (List(1))
    }

    it should "admit multiple elements" in {
    val x: MultiSet[Int] = MultiSet().addMany(2, 6)

    x.multiplicity(2) shouldEqual (6)
    x.contains(2) shouldEqual (true)
    x.iterator.toList shouldEqual (List(2,2,2,2,2,2))
    }

    it should "implement difference" in {
    val x = MultiSet().addMany(0, 4).addMany(1,3)
    val y = MultiSet().addMany(0, 2).addMany(1,5)

    val xdiffy = x.diff(y)
    val ydiffx = y.diff(x)

    xdiffy.iterator.toList.sorted shouldEqual (List(0,0))
    ydiffx.iterator.toList.sorted shouldEqual (List(1,1))
    }

    it should "implement sums" in {
    val x = MultiSet().addMany(0, 4).addMany(1,3)
    val y = MultiSet().addMany(0, 2).addMany(1,5)

    val xplusy = x.sum(y)

    xplusy.iterator.toList.sorted shouldEqual (List(0,0,0,0,0,0,1,1,1,1,1,1,1,1))
    }

    it should "admit unions with sets" in {
    val x = MultiSet().addMany(0,4)
    val y = Set(1,2,3)

    x.union(y).iterator.toList.sorted shouldEqual (List(0,0,0,0,1,2,3))
    }

    it should "implement unions" in {
    val x = MultiSet().addMany(0, 4).addMany(1,3)
    val y = MultiSet().addMany(0, 2).addMany(1,5)

    val xuy = x.union(y)

    xuy.iterator.toList.sorted shouldEqual (List(0,0,0,0,1,1,1,1,1))
    }

    it should "implement intersections" in {
    val x = MultiSet().addMany(0, 4).addMany(1,3)
    val y = MultiSet().addMany(0, 2).addMany(1,5)

    val xny = x.intersect(y)

    xny.iterator.toList.sorted shouldEqual (List(0,0,1,1,1))
    }

    it should "be creatable from sets" in {
    val x = MultiSet(Set(1,2,3))

    x.iterator.toList.sorted shouldEqual (List(1,2,3))
    }
    }