use super::prelude::*;
use std::collections::HashMap;
use project_schemas::tables::{
orders::dsl as orders_dsl, users::dsl as users_dsl,
currencies::dsl as currencies_dsl, taxes::dsl as taxes_dsl,
};
#[derive(ParquetRecordWriter)]
struct CombinedOrderRecord {
order_id: i32,
user_id: Option<i32>,
country: Option<String>,
currency: String,
conversion_rate: f64,
user_email: Option<String>,
amount: f64,
amount_usd: f64,
amount_with_tax: f64,
created_at: Option<NaiveDateTime>,
}
pub fn combined_orders(pg_uri: &str) -> (String, i64) {
let conn = PgConnection::establish(pg_uri).unwrap();
let orders_load = Instant::now();
let orders = orders_dsl::orders.load::<Order>(&conn).unwrap();
trace!("load orders took: {:?}", orders_load.elapsed());
let users_load = Instant::now();
let users = users_dsl::users.load::<User>(&conn).unwrap();
trace!("load users took: {:?}", users_load.elapsed());
let users_ids: Vec<i32> = users.iter().map(|x| x.id).collect();
trace!("{:?}", users_ids.len());
let orders_ids: Vec<i32> = orders.iter().map(|x| x.id).collect();
let currencies_load = Instant::now();
let currencies: Vec<Currency> = currencies_dsl::currencies
.filter(currencies_dsl::type_.eq("Order"))
.filter(currencies_dsl::type_id.eq(any(&orders_ids[..])))
.load::<Currency>(&conn)
.unwrap();
trace!(
"load currencies ({}) took: {:?}",
currencies.len(),
currencies_load.elapsed()
);
let currencies_by_order_id: HashMap<i32, &Currency> = currencies
.iter()
.map(|x| (x.type_id.unwrap(), x))
.collect();
let taxes_load = Instant::now();
let taxes: Vec<Tax> = taxes_dsl::taxes
.filter(taxes_dsl::taxable_type.eq("Order"))
.filter(taxes_dsl::type_.eq("RetailTax"))
.filter(taxes_dsl::taxable_id.eq(any(&orders_ids[..])))
.load::<Tax>(&conn)
.unwrap();
trace!(
"load taxes_time ({}) took: {:?}",
taxes.len(),
taxes_load.elapsed()
);
let taxes_by_order_id: HashMap<i32, &Tax> =
taxes.iter().map(|x| (x.taxable_id.unwrap(), x)).collect();
let users_by_user_id: HashMap<i32, &User> = users.iter().map(|x| (x.id, x)).collect();
let mut count = 0;
let path = "/tmp/combined_orders.parquet";
let path_meta = <&str>::clone(&path);
let parquet_records: Vec<CombinedOrderRecord> = orders
.iter()
.filter(|order| order.user_id.is_some())
.filter(|order| order.active)
.map(|o| {
let currency = currencies_by_order_id.get(&o.id);
let conversion_rate = currency
.map(|x| {
x.conversion_rate
.clone()
.map(|cr| cr.to_f64().expect("bigdecimal to f64"))
.expect("Unwrapping currency in Orders")
})
.unwrap_or(1.0);
let currency_name = currency
.map(|x| {
x.currency
.clone()
.expect("Unwrapping currency name in Orders")
})
.unwrap_or_else(|| "USD".to_string());
let amount =
o.amount.to_f64().expect("big decimal price");
let amount_usd = amount / conversion_rate;
let order_taxes = taxes_by_order_id.get(&o.id);
let _tax = order_taxes.map(|x| x.amount.to_f64().expect("tax amount bigdecimal to f64"));
let tax_rate = order_taxes.map(|x| x.rate.to_f64().expect("tax rate bigdecimal to f64"));
let user_id = o.user_id;
let user = users_by_user_id.get(&user_id.unwrap_or(0));
let user_email = if let Some(u) = user {
u.email.clone()
} else {
None
};
let country = if let Some(u) = user {
u.country.clone()
} else {
None
};
CombinedOrderRecord {
order_id: o.id,
user_id: o.user_id,
country,
currency: currency_name,
conversion_rate,
user_email,
amount,
amount_usd,
amount_with_tax: amount * (1.0 + tax_rate.unwrap_or(0.0)),
created_at: o.created_at,
}
})
.collect();
let schema = parquet_records.as_slice().schema().unwrap();
println!("{:?} schema", &schema);
// let props = Arc::new(WriterProperties::builder().build());
let file = std::fs::File::create(path).unwrap();
let mut pfile = SerializedFileWriter::new(file, schema, props()).unwrap();
{
let mut row_group = pfile.next_row_group().unwrap();
(&parquet_records[..])
.write_to_row_group(&mut row_group)
.expect("can't 'write_to_row_group' ...");
pfile.close_row_group(row_group).unwrap();
count += 1;
println!("{} count", count);
}
pfile.close().unwrap();
let reader = SerializedFileReader::try_from(path_meta).unwrap();
let parquet_metadata = reader.metadata();
let file_metadata = parquet_metadata.file_metadata();
let rows_number = file_metadata.num_rows();
(path.into(), rows_number)
}
pub struct CombinedOrderTask {}
impl ProjectTask for CombinedOrderTask {
fn run(&self, postgres_uri: &str) -> (String, i64) {
combined_orders(postgres_uri)
}
}
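The props() helper comes from project_parquet and is not included in the gist; judging by the commented-out WriterProperties line above, a minimal sketch (assuming a parquet version where the writer takes an Arc<WriterProperties>) would be:

// project_parquet (hypothetical sketch): shared default writer properties for every task.
use std::sync::Arc;
use parquet::file::properties::WriterProperties;

pub fn props() -> Arc<WriterProperties> {
    Arc::new(WriterProperties::builder().build())
}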
pub async fn create_crawler(
crawler_name: String,
path: String,
_greedy: bool,
) -> Result<(), Box<dyn std::error::Error>> {
let _crawler_targets = path.clone();
let iam_role =
"arn:aws:iam::id:role/service-role/AWSGlueServiceRole-role".to_string();
let config = aws_config::from_env().region(REGION).load().await;
use aws_sdk_glue::Client;
let glue = Client::new(&config);
let get_crawler = glue
.get_crawler()
.name(crawler_name.clone())
.send()
.await;
// GetCrawler errors (EntityNotFound) when the crawler is missing, so only an Ok
// response carrying a named crawler means it already exists.
let must_create = match get_crawler {
Ok(GetCrawlerOutput {
crawler: Some(Crawler { name: Some(_), .. }),
..
}) => false,
_ => true,
};
if must_create {
let create_crawler = glue
.create_crawler()
.name(crawler_name.clone())
.database_name("database_name".to_string())
.role(iam_role)
.targets(
CrawlerTargets::builder()
.s3_targets(S3Target::builder().path(path).build())
.build(),
)
.send()
.await;
info!("create crawler success {:?}", create_crawler.unwrap())
} else {
info!("crawler already exists")
}
Ok(())
}
create-model:
echo "[print_schema]\nfile = 'project_schemas/src/tables/$(table)s_tl.rs'" > diesel.toml
echo "\nmod $(table)s_tl;\npub use self::$(table)s_tl::*;" >> project_schemas/src/tables/mod.rs
echo "\nmod $(table);\npub use self::$(table)::*;" >> project_schemas/src/models/mod.rs
echo "\nmod $(table)s;\npub use self::$(table)s::*;" >> project_tasks/src/tasks/mod.rs
diesel print-schema --database-url=postgres://postgres@localhost/database_name --only-tables -- $(table)s > project_schemas/src/tables/$(table)s_tl.rs
diesel_ext > project_schemas/src/models/$(table).rs
sed s/magic/${table}/g project_tasks/src/tasks/template_task > project_tasks/src/tasks/$(table)s.rs
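With these recipes in the project Makefile, scaffolding a new table (here `order`, a placeholder name) is a single call:

make create-model table=order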
src/lib.rs
// Order matters!
extern crate openssl;
extern crate diesel;
use project_sqlx::ProjectStreamingTask;
use project_tasks::tasks::ProjectTask;
#[derive(Debug, serde::Deserialize)]
#[allow(non_snake_case)]
pub struct Args {
pub arg_POSTGRES_URI: String,
pub flag_table: String,
pub flag_limit: usize,
pub flag_upload: String,
pub flag_file: String,
pub flag_extra: String,
}
pub const USAGE: &str = "
project
Usage:
project (<POSTGRES-URI>) [--table=<table>] [--upload=<S3_URL>] [--file=<file>] [--extra=<file>]
Options:
-l --limit=<LIMIT> Number of documents per request [default: 1000]
-h --help Show this screen.
-t --table=<TABLE> Postgres table to process
-u --upload=<S3_URL> Target file [default: s3://bucket]
-n --no-upload Skip uploading to S3
-b --file=<file> Use factors from the given file
-k --extra=<file> Run an extra data load from the given file
";
pub const DATABASE: &str = "database"; // Athena database name
pub fn tasks_list() -> Vec<(&'static str, Box<dyn ProjectTask>)> {
let tasks: Vec<(&str, Box<dyn ProjectTask>)> = vec![
("combined_orders", Box::new(project_tasks::tasks::CombinedOrderTask {})),
// ... remaining tasks elided
];
tasks
}
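The all(...) helper called from main is not included in the gist; a rough sketch, assuming it simply runs every task from tasks_list() against the Postgres URI and sums the row counts (in the gist it is invoked with args.flag_table, so the real signature likely differs):

// Hypothetical sketch of `all`: run every registered task and accumulate row counts.
pub fn all(postgres_uri: &str) -> (String, i64) {
    let mut total_rows = 0_i64;
    for (name, task) in tasks_list() {
        let (path, rows) = task.run(postgres_uri);
        info!("{}: wrote {} rows to {}", name, rows, path);
        total_rows += rows;
    }
    ("all".to_string(), total_rows)
}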
#[tokio::main]
async fn main() {
// Required to make static musl builds happy
openssl_probe::init_ssl_cert_env_vars();
pretty_env_logger::init();
let project_time = Instant::now();
let mut batch_size = 0_i64;
let mut labels = HashMap::new();
// uncomment it during debugging
// #[cfg(not(debug_assertions))]
let _guard = sentry::init((
"https://..@..ingest.sentry.io/..?timeout=10,verify_ssl=0",
sentry::ClientOptions {
release: sentry::release_name!(),
environment: Some("production".into()),
..Default::default()
},
));
::rayon::ThreadPoolBuilder::new()
.num_threads(2)
.build_global()
.unwrap();
let args: Args = Docopt::new(USAGE)
.and_then(|d| d.deserialize())
.unwrap_or_else(|e| e.exit());
let _current_timestamp = chrono::offset::Utc::now();
if args.flag_table == "all" {
// use async_std::task::block_on;
let (_size, count) = all(args.flag_table);
// let (_size, count) = block_on(all(args.flag_table));
// let (_size, count) = all().await;
let s3_path = format!("{}/{}", BASE_PATH, MAIN_FOLDER);
create_crawler(CRAWLER_NAME.to_string(), s3_path, true)
.await
.expect("create crawler");
start_crawler(CRAWLER_NAME.to_string(), true)
.await
.expect("start crawler");
let project_time = project_time.elapsed().as_millis();
info!("{} project time", project_time);
batch_size += count as i64;
labels.insert("batch_name".to_string(), "all".to_string());
} else if args.flag_table == "stream_tasks" {
let (_size, count) = stream_tasks().await;
let s3_path = format!("{}/{}", BASE_PATH, MAIN_FOLDER);
create_crawler(CRAWLER_NAME.to_string(), s3_path, true)
.await
.expect("create crawler");
start_crawler(CRAWLER_NAME.to_string(), true)
.await
.expect("start crawler");
let project_time = project_time.elapsed().as_millis();
info!("{} project time", project_time);
batch_size += count;
labels.insert("batch_name".to_string(), "stream_tasks".to_string());
}
}
pub async fn start_crawler(
crawler_name: String,
poll_to_completion: bool,
) -> Result<(), Box<dyn std::error::Error>> {
let config = aws_config::from_env().region(REGION).load().await;
let glue = aws_sdk_glue::Client::new(&config);
let mut attempts = 0;
loop {
let start_crawler = glue.start_crawler().name(crawler_name.clone()).send().await;
attempts += 1;
match start_crawler {
Ok(_) => {
println!("crawling away on {}", crawler_name);
break;
}
Err(crawler_error) => {
if let SdkError::ServiceError(err) = crawler_error {
match err.err() {
StartCrawlerError::CrawlerRunningException(_) => {
if !poll_to_completion {
info!("crawler already running. bailing out.");
break;
} else if attempts < 20 {
info!("crawler already running, retrying in 5 seconds");
std::thread::sleep(DELAY_TIME);
} else {
panic!("crawler has tried 20 times. dying")
}
}
StartCrawlerError::EntityNotFoundException(_) => {
println!("not found")
}
StartCrawlerError::OperationTimeoutException(_) => {
println!("timed out")
}
StartCrawlerError::Unhandled(_) => {
panic!("unhandled StartCrawlerErrorKind")
}
_ => {
println!("no idea")
}
}
}
if poll_to_completion {
wait_for_crawler(&glue, &crawler_name).await?
}
}
}
}
Ok(())
}
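wait_for_crawler is called above but not shown in the gist; a minimal sketch, assuming it polls GetCrawler until the crawler leaves the running state (the CrawlerState enum path and variants are taken from aws_sdk_glue and may differ by SDK version):

// Hypothetical sketch: poll GetCrawler until the crawler is no longer running.
async fn wait_for_crawler(
    glue: &aws_sdk_glue::Client,
    crawler_name: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    loop {
        let out = glue.get_crawler().name(crawler_name).send().await?;
        // Crawler state lives on the returned Crawler struct (assumption: aws_sdk_glue::types).
        let state = out.crawler().and_then(|c| c.state().cloned());
        match state {
            Some(aws_sdk_glue::types::CrawlerState::Running)
            | Some(aws_sdk_glue::types::CrawlerState::Stopping) => {
                info!("crawler {} still running, waiting", crawler_name);
                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
            }
            _ => break,
        }
    }
    Ok(())
}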
#![allow(unknown_lints)]
#![allow(incomplete_features)]
#![recursion_limit = "512"]
#![allow(proc_macro_derive_resolution_fallback)]
use std::panic::RefUnwindSafe;
use std::panic::UnwindSafe;
pub mod tasks_sqlx;
use async_trait::async_trait;
use chrono::NaiveDate;
use std::fmt::Debug;
use std::path::PathBuf;
pub use tasks_sqlx::*;
#[async_trait]
pub trait ProjectStreamingTask: Debug + Sync + Send + RefUnwindSafe + UnwindSafe {
async fn run(&self, postgres_uri: &str) -> (String, i64);
}
#[async_trait]
pub trait HugeStreamTask: Debug + Sync + Send + RefUnwindSafe + UnwindSafe {
async fn run(&self, postgres_uri: &str) -> Vec<(NaiveDate, PathBuf, u128, i64)>;
}
pub mod prelude_sqlx {
pub use std::time::Instant;
pub use diesel::dsl::any;
pub use diesel::pg::PgConnection;
pub use diesel::prelude::*;
pub use bigdecimal::{BigDecimal, ToPrimitive};
pub use chrono::prelude::*;
pub use diesel::result::Error;
pub use project_parquet::prelude::*;
pub use project_parquet::props;
pub use project_parquet::FileWriterRows;
pub use project_schemas::models::*;
pub use futures_util::stream::StreamExt;
pub use log::*;
pub use parquet::file::properties::WriterProperties;
pub use rayon::prelude::*;
pub use sqlx::postgres::PgPool;
pub use std::collections::HashMap;
pub use std::convert::TryFrom;
pub use std::fs;
pub use std::fs::File;
pub use std::panic::catch_unwind;
pub use std::panic::RefUnwindSafe;
pub use std::panic::UnwindSafe;
pub use std::sync::Arc;
pub use std::path::Path;
}
use super::prelude::*;
pub use futures_util::stream::StreamExt;
use parquet::record::RecordWriter;
pub use sqlx::postgres::PgPool;
#[derive(ParquetRecordWriter, Default, sqlx::FromRow, Debug)]
struct ProductRecordStream {
id: i64,
email: Option<String>,
completed_at: Option<NaiveDateTime>,
// uuid: Option<Uuid>,
uuid: Option<String>,
created_at: Option<NaiveDateTime>,
created_on: Option<NaiveDate>,
updated_at: Option<NaiveDateTime>,
name: Option<String>,
organization_id: Option<i32>,
}
pub async fn products(pg_uri: &str) -> anyhow::Result<(String, i64)> {
let pool = PgPool::connect(pg_uri).await?;
let fake_products = vec![ProductRecordStream {
..Default::default()
}];
let schema = fake_products.as_slice().schema().unwrap();
let schema_2 = fake_products.as_slice().schema().unwrap();
let schema_vec = schema_2.get_fields();
let mut fields: Vec<&str> = vec![];
for i in schema_vec {
if i.name() == "uuid" {
fields.push("uuid::varchar")
// } else if ... {  // further type-specific casts elided in the gist
} else {
fields.push(i.name())
}
}
println!("{:?} fields!", fields);
let products_load = Instant::now();
let path = "/tmp/products.parquet";
// let props = Arc::new(WriterProperties::builder().build());
let file = std::fs::File::create(&path).unwrap();
let mut pfile = SerializedFileWriter::new(file, schema, props()).unwrap();
let table: &str = "products";
let mut query = "SELECT ".to_owned();
let fields: &str = &fields.join(", ");
query.push_str(fields);
query.push_str(" FROM ");
query.push_str(table);
let q = sqlx::query_as::<sqlx::Postgres, ProductRecordStream>(&query);
let products_stream = q.fetch(&pool);
println!("{} query", query);
println!(" before stream");
trace!("load requests took: {:?}", products_load.elapsed());
let mut chunk_stream = products_stream.map(|fs| fs.unwrap()).chunks(5000);
while let Some(chunks) = chunk_stream.next().await {
let mut row_group = pfile.next_row_group().unwrap();
(&chunks[..])
.write_to_row_group(&mut row_group)
.expect("can't 'write_to_row_group' ...");
pfile.close_row_group(row_group).unwrap();
}
pfile.close().unwrap();
let reader = SerializedFileReader::try_from(path).unwrap();
let parquet_metadata = reader.metadata();
println!("{:?} num row group", parquet_metadata.num_row_groups());
let file_metadata = parquet_metadata.file_metadata();
let rows_number = file_metadata.num_rows();
println!("{:?} file_metadata.num_rows()", file_metadata.num_rows());
Ok((path.into(), rows_number))
}
use async_trait::async_trait;
#[derive(Debug)]
pub struct ProductStreamingTask {}
#[async_trait]
impl ProjectStreamingTask for ProductStreamingTask {
async fn run(&self, postgres_uri: &str) -> (String, i64) {
products(postgres_uri).await.unwrap()
}
}
#!/bin/bash
query_result="
column_name | data_type | is_nullable
---------------------------------+-----------------------------+-------------
id | integer | NO
active | boolean | NO
name | character varying | YES
created_at | timestamp without time zone | NO
data | jsonb | YES
"
# Function to map PostgreSQL types to Diesel types
map_pg_to_diesel() {
local pg_type=$1
case $pg_type in
integer) echo "Int4" ;;
bigint) echo "Int8" ;;
smallint) echo "Int2" ;;
boolean) echo "Bool" ;;
"character varying") echo "Varchar" ;;
timestamp*) echo "Timestamp" ;;
jsonb) echo "Jsonb" ;;
json) echo "Json" ;;
*) echo "Unknown" ;;
esac
}
# Initialize the Diesel table definition
table_name="orders"
diesel_table="table! {
$table_name (id) {
"
# Process the query result
while IFS='|' read -r column_name data_type is_nullable; do
# Trim whitespace
column_name=$(echo "$column_name" | xargs)
data_type=$(echo "$data_type" | xargs)
is_nullable=$(echo "$is_nullable" | xargs)
# Skip header and separator lines
if [[ -z "$column_name" || "$column_name" == "column_name" || "$column_name" =~ ^[-+]+$ ]]; then
continue
fi
# Map PostgreSQL type to Diesel type
diesel_type=$(map_pg_to_diesel "$data_type")
# Add Nullable if column is nullable
if [[ "$is_nullable" == "YES" ]]; then
diesel_type="Nullable<$diesel_type>"
fi
# Append to Diesel table definition
diesel_table+=" $column_name -> $diesel_type,\n"
done <<< "$query_result"
# Close the Diesel table definition
diesel_table+=" }
}"
# Output the Diesel table definition
echo -e "$diesel_table"
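For the sample query_result above, the script emits roughly:

table! {
    orders (id) {
        id -> Int4,
        active -> Bool,
        name -> Nullable<Varchar>,
        created_at -> Timestamp,
        data -> Nullable<Jsonb>,
    }
}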
// orders_tl.rs for Postgres diesel schema
table! {
orders (id) {
id -> Int4,
user_id -> Nullable<Int4>,
active -> Bool,
amount -> Numeric,
created_at -> Nullable<Timestamp>,
uuid -> Uuid,
}
}
// order.rs for the struct
use uuid::Uuid;
use bigdecimal::BigDecimal;
use chrono::NaiveDateTime;
#[derive(Queryable, Debug)]
pub struct Order {
pub id: i32,
pub user_id: Option<i32>,
pub active: bool,
pub amount: BigDecimal,
pub created_at: Option<NaiveDateTime>,
pub uuid: Uuid,
}
// users_tl.rs for Postgres diesel schema
table! {
users (id) {
id -> Int4,
name -> Nullable<Varchar>,
country -> Nullable<Varchar>,
created_at -> Nullable<Timestamp>,
email -> Nullable<Varchar>,
}
}
// user.rs for the struct
use chrono::NaiveDateTime;
#[derive(Queryable, Debug)]
pub struct User {
pub id: i32,
pub name: Option<String>,
pub country: Option<String>,
pub created_at: Option<NaiveDateTime>,
pub email: Option<String>,
}
// products_tl.rs for Postgres diesel schema
table! {
products (id) {
id -> Int4,
name -> Varchar,
quantity -> Nullable<Int4>,
created_at -> Nullable<Timestamp>,
approved -> Nullable<Bool>,
price -> Numeric,
}
}
// product.rs for the struct
use bigdecimal::BigDecimal;
use chrono::NaiveDateTime;
#[derive(Queryable, Debug)]
pub struct Product {
pub id: i32,
pub name: String,
pub quantity: Option<i32>,
pub created_at: Option<NaiveDateTime>,
pub approved: Option<bool>,
pub price: BigDecimal,
}
pub mod prelude {
pub use std::time::Instant;
pub use diesel::dsl::any;
pub use diesel::pg::PgConnection;
pub use diesel::prelude::*;
pub use bigdecimal::{BigDecimal, ToPrimitive};
pub use chrono::prelude::*;
pub use std::collections::HashMap;
pub use rayon::prelude::*;
pub use super::page_iter::*;
pub use super::ProjectStreamingTask;
pub use super::ProjectTask;
pub use super::HugeTask;
pub use ::function_name::named;
pub use diesel::result::Error;
pub use project_parquet::prelude::*;
pub use project_parquet::props;
pub use project_parquet::FileWriterRows;
pub use project_schemas::models::*;
pub use log::*;
pub use parquet::file::properties::WriterProperties;
pub use std::convert::TryFrom;
pub use std::fs;
pub use std::fs::File;
pub use std::sync::Arc;
pub use std::path::Path;
}
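// template_task below uses "magic" as a placeholder; the Makefile's sed call rewrites it to the actual table name.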
use super::prelude::*;
use project_schemas::tables::magics::dsl as magics_dsl;
pub fn magics(pg_uri: &str) -> (String, i64) {
let conn = PgConnection::establish(pg_uri).unwrap();
let magics_load = Instant::now();
let magics = magics_dsl::magics
.load::<magic>(&conn)
.unwrap();
trace!("magics: {:?}", magics_load.elapsed());
let path = "/tmp/magics.parquet";
let path_meta = <&str>::clone(&path);
let vector_for_schema = &magics;
let schema = vector_for_schema.as_slice().schema().unwrap();
let file = std::fs::File::create(path).unwrap();
let mut pfile = SerializedFileWriter::new(file, schema, props())
.unwrap();
{
let mut row_group = pfile.next_row_group().unwrap();
(&magics[..]).write_to_row_group(&mut row_group).expect("can't 'write_to_row_group' ...");
pfile.close_row_group(row_group).unwrap();
}
// let rows_number = *pfile.total_num_rows() as i64;
pfile.close().unwrap();
let reader = SerializedFileReader::try_from(path_meta).unwrap();
let parquet_metadata = reader.metadata();
let file_metadata = parquet_metadata.file_metadata();
let rows_number = file_metadata.num_rows();
(path.into(), rows_number)
}
pub struct magicTask {
}
impl ProjectTask for magicTask {
fn run(&self, postgres_uri: &str) -> (String, i64) {
magics(postgres_uri)
}
}
pub async fn upload(
path: PathBuf,
bucket_name: &str,
key: &str,
) -> Result<(), Box<dyn std::error::Error>> {
use aws_sdk_s3::primitives::ByteStream;
let body = ByteStream::from_path(Path::new(&path)).await;
let config = aws_config::from_env().region(REGION).load().await;
let client = S3Client::new(&config);
client
.put_object()
.bucket(bucket_name)
.key(key)
.body(body?)
.send()
.await?;
info!("Uploaded file: {}", key);
Ok(())
}
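As a usage sketch (bucket name and key are placeholders), a task's parquet output can be pushed to S3 like this:

// Hypothetical wiring: run a task, then upload its parquet output.
let (path, rows) = combined_orders(&postgres_uri);
info!("wrote {} rows to {}", rows, path);
upload(PathBuf::from(&path), "bucket", "main_folder/combined_orders.parquet")
    .await
    .expect("upload to S3");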