use crate::{ domain::{Error, Unit}, service::Cache as CacheConfig, }; use async_trait::async_trait; use deadpool_redis::{ redis::{ cmd, from_redis_value, AsyncCommands, FromRedisValue, JsonAsyncCommands, RedisError, RedisResult, RedisWrite, ToRedisArgs, }, Connection, Pool, PoolConfig, Runtime, }; use futures::Future; use secrecy::ExposeSecret; use serde::{Deserialize, Serialize}; use serde_json::Value; #[async_trait] pub trait CacheExt { async fn get_or_insert_with( &self, key: &str, f: F, expire: Option, ) -> Result where T: for<'de> Deserialize<'de> + Serialize + Send + Clone, F: FnOnce() -> Fut + Send, Fut: Future> + Send; async fn get(&self, key: &str) -> Result, Error> where T: for<'de> Deserialize<'de>; async fn set(&self, key: &str, value: T, expire: Option) -> Result where T: Serialize + Send; async fn remove(&self, key: &str) -> Result; async fn clear(&self) -> Result; } #[derive(Debug, Clone)] struct CacheEntry { key: String, value: Option, } impl CacheEntry { pub fn new(key: String, value: Value) -> Self { Self { key, value: Some(ValueWrapper(value)), } } fn get_as(&self) -> Option where T: for<'de> Deserialize<'de>, { match &self.value { Some(value) => serde_json::from_value(value.0.clone()).ok(), None => None, } } } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(transparent)] pub struct ValueWrapper(Value); impl ToRedisArgs for ValueWrapper { fn write_redis_args(&self, out: &mut W) where W: ?Sized + RedisWrite, { let json_str = serde_json::to_string(&self).unwrap(); out.write_arg(json_str.as_bytes()); } } impl FromRedisValue for ValueWrapper { fn from_redis_value(v: &deadpool_redis::redis::Value) -> RedisResult { let json_str: Option = from_redis_value(v)?; match json_str { Some(json_str) => { let json: Value = serde_json::from_str(json_str.as_str()).map_err(|error| { RedisError::from(std::io::Error::new( std::io::ErrorKind::InvalidData, error.to_string(), )) })?; Ok(Self(json)) } None => Ok(Self(Value::Null)), } } } impl ToRedisArgs for CacheEntry { fn write_redis_args(&self, out: &mut W) where W: ?Sized + RedisWrite, { let json = serde_json::json!( { "key": self.key, "value": self.value } ); let json_str = json.to_string(); out.write_arg(json_str.as_bytes()); } } impl FromRedisValue for CacheEntry { fn from_redis_value(v: &deadpool_redis::redis::Value) -> RedisResult { let json_str: Option = from_redis_value(v)?; match json_str { Some(json_str) => { let json: Value = serde_json::from_str(json_str.as_str()).map_err(|error| { RedisError::from(std::io::Error::new( std::io::ErrorKind::InvalidData, error.to_string(), )) })?; let key = json.get("key").ok_or(RedisError::from(std::io::Error::new( std::io::ErrorKind::InvalidData, "key not found", )))?; let key = key.as_str().ok_or(RedisError::from(std::io::Error::new( std::io::ErrorKind::InvalidData, "key not found", )))?; let key = key.to_string(); let value = json .get("value") .ok_or(RedisError::from(std::io::Error::new( std::io::ErrorKind::InvalidData, "value not found", )))?; let value = value.clone(); Ok(Self { key, value: Some(ValueWrapper(value)), }) } None => Ok(Self { key: Default::default(), value: None, }), } } } pub struct Cache { pool: Pool, } impl Cache { pub fn new(configuration: &CacheConfig) -> Result { let redis = deadpool_redis::Config { url: Some(configuration.url().expose_secret().into()), connection: None, pool: Some(PoolConfig { max_size: configuration.pool_size(), timeouts: configuration.timeouts(), ..Default::default() }), }; let pool = redis.create_pool(Some(Runtime::Tokio1)).map_err(|error| { Error::cache_generic( "There was an error with the configuration", None, None, error.into(), ) })?; Ok(Self { pool }) } pub async fn connection(&self) -> Result { let connection = self.pool.get().await.map_err(|error| { Error::cache_generic( "There was an error with the connection", None, None, error.into(), ) })?; Ok(connection) } } #[async_trait] impl CacheExt for Cache { #[tracing::instrument(name = "cache::get_or_insert_with", skip(self, key, f))] async fn get_or_insert_with( &self, key: &str, f: F, expire: Option, ) -> Result where T: for<'de> Deserialize<'de> + Serialize + Send + Clone, F: FnOnce() -> Fut + Send, Fut: Future> + Send, { match self.get(key).await? { Some(entry) => { tracing::debug!("Cache hit for key: {}", key); Ok(entry) } None => { let value = f().await?; self.set(key, value.clone(), expire).await?; Ok(value) } } } #[tracing::instrument(name = "cache::get", skip(self, key))] async fn get(&self, key: &str) -> Result, Error> where T: for<'de> Deserialize<'de>, { let entry: CacheEntry = self.connection().await?.get(key).await?; match entry.value { Some(_) => Ok(entry.get_as()), None => Ok(None), } } #[tracing::instrument(name = "cache::insert", skip(self, key, value, expire))] async fn set(&self, key: &str, value: T, expire: Option) -> Result where T: Serialize + Send, { let entry = CacheEntry::new(key.to_string(), serde_json::to_value(value)?); self.connection() .await? .set_ex::<_, CacheEntry, Option>( entry.key.clone(), entry.clone(), expire.unwrap_or(86400), ) .await .map(|_| ())?; Ok(()) } #[tracing::instrument(name = "cache::remove", skip(self, key))] async fn remove(&self, key: &str) -> Result { self.connection() .await? .del::<_, i64>(key) .await .map(|_| ())?; Ok(()) } #[tracing::instrument(name = "cache::clear", skip(self))] async fn clear(&self) -> Result { cmd("FLUSHALL") .query_async(&mut *self.connection().await?) .await .map(|_: ()| ())?; Ok(()) } }