package io.getquill.sources.finagle.mysql

import com.twitter.finagle.exp.mysql._
import com.twitter.util.{Await, Future, Local}
import com.typesafe.scalalogging.Logger
import io.getquill.FinagleMysqlSourceConfig
import io.getquill.naming.NamingStrategy
import io.getquill.sources.BindedStatementBuilder
import io.getquill.sources.sql.SqlSource
import io.getquill.sources.sql.idiom.MySQLDialect
import org.slf4j.LoggerFactory

import scala.util.Try

/**
 * Quill SQL source backed by a Finagle MySQL client.
 *
 * Every statement is expanded with a [[BindedStatementBuilder]], logged at INFO level and
 * executed as a prepared statement; results are exposed as Twitter [[Future]]s.
 *
 * Construction blocks on `client.ping` via `Await.result`, so creating an instance throws
 * if that ping fails.
 *
 * @param config supplies the Finagle client (`config.client`) and `config.dateTimezone`
 * @tparam N the [[NamingStrategy]] this source is parameterized with
 */
abstract class FinagleMysqlSource[N <: NamingStrategy](config: FinagleMysqlSourceConfig[N])
  extends SqlSource[MySQLDialect, N, Row, BindedStatementBuilder[List[Parameter]]]
  with FinagleMysqlDecoders
  with FinagleMysqlEncoders {

  protected val logger: Logger =
    Logger(LoggerFactory.getLogger(classOf[FinagleMysqlSource[_]]))

  type QueryResult[T] = Future[List[T]]
  type ActionResult[T] = Future[Result]
  type BatchedActionResult[T] = Future[List[Result]]

  /**
   * Function returned by [[executeBatch]]: runs the batch either for a list of values or
   * for a single value.
   */
  class ActionApply[T](f: List[T] => Future[List[Result]])
    extends Function1[List[T], Future[List[Result]]] {

    def apply(params: List[T]) = f(params)

    // Assumes `f` yields one Result per input value (true for the function built by
    // executeBatch), so a singleton input produces exactly one Result.
    def apply(param: T) = f(List(param)).map(_.head)
  }

  private[mysql] def dateTimezone = config.dateTimezone

  protected val client = config.client

  // Fail fast at construction time: Await.result rethrows if the ping future fails.
  Await.result(client.ping)

  override def close = Await.result(client.close())

  // NOTE(review): this Local is private and nothing in this class ever sets it, so
  // `withClient` always falls back to `client`. Because it is private, a subclass that
  // implements `transaction` cannot bind a transactional client here either — confirm how
  // statements issued inside a transaction are meant to reach the transactional connection.
  private val currentClient = new Local[Client]

  /** Runs `sql` synchronously against `client`, capturing any failure in a [[Try]]. */
  def probe(sql: String) = Try(Await.result(client.query(sql)))

  /**
   * Abstract; left to subclasses. Presumably runs `f` within a database transaction —
   * TODO confirm. The result type is the `Unit` that the former procedure-syntax
   * declaration implied, now written explicitly.
   */
  def transaction[T](f: FinagleMysqlSource[N] => Future[T]): Unit

  /**
   * Expands, logs and executes a single non-query statement.
   *
   * @param sql       statement text to expand
   * @param bind      adds this statement's parameter bindings to a fresh builder
   * @param generated currently ignored by this implementation
   * @return the [[Result]] of the prepared statement
   */
  def execute(
    sql: String,
    bind: BindedStatementBuilder[List[Parameter]] => BindedStatementBuilder[List[Parameter]],
    generated: Option[String] = None
  ): Future[Result] = {
    val (expanded, params) = expandAndLog(sql, bind)
    withClient(_.prepare(expanded)(params(List()): _*))
  }

  /**
   * Builds a batch runner for `sql`: each value is bound with `bindParams`, expanded,
   * logged and executed.
   *
   * @param sql        statement text to expand for every value
   * @param bindParams produces the bindings for one value
   * @param generated  currently ignored by this implementation
   * @return an [[ActionApply]] yielding one [[Result]] per value, in input order
   */
  def executeBatch[T](
    sql: String,
    bindParams: T => BindedStatementBuilder[List[Parameter]] => BindedStatementBuilder[List[Parameter]],
    generated: Option[String] = None
  ): ActionApply[T] = {
    // Sequential: the next value is only bound and executed after the previous
    // statement's Future has completed (flatMap), preserving input order.
    def run(values: List[T]): Future[List[Result]] =
      values match {
        case Nil => Future.value(List())
        case value :: tail =>
          val (expanded, params) = expandAndLog(sql, bindParams(value))
          withClient(_.prepare(expanded)(params(List()): _*))
            .flatMap(r => run(tail).map(r +: _))
      }
    new ActionApply(run _)
  }

  /**
   * Expands, logs and runs a select, mapping every returned [[Row]] with `extractor`.
   *
   * @param sql       query text to expand
   * @param bind      adds this query's parameter bindings to a fresh builder
   * @param extractor converts one row into a `T`
   * @return the extracted rows as a [[List]]
   */
  def query[T](
    sql: String,
    bind: BindedStatementBuilder[List[Parameter]] => BindedStatementBuilder[List[Parameter]],
    extractor: Row => T
  ): Future[List[T]] = {
    val (expanded, params) = expandAndLog(sql, bind)
    withClient(_.prepare(expanded).select(params(List()): _*)(extractor)).map(_.toList)
  }

  /**
   * Applies `bind` to a fresh [[BindedStatementBuilder]], expands `sql` with it and logs
   * the expanded statement at INFO level.
   * Shared by execute, executeBatch and query.
   *
   * @return the expanded SQL together with the parameter function produced by `build`
   */
  private def expandAndLog(
    sql: String,
    bind: BindedStatementBuilder[List[Parameter]] => BindedStatementBuilder[List[Parameter]]
  ) = {
    val (expanded, params) = bind(new BindedStatementBuilder).build(sql)
    logger.info(expanded)
    (expanded, params)
  }

  /** Applies `f` to the client bound in `currentClient`, or to the default `client` if none. */
  private def withClient[T](f: Client => T): T =
    currentClient().fold(f(client))(f)
}