"""Estimate query-plan size for every dbt model and draw it on the model DAG.

For each model in ``./target/manifest.json`` the compiled SQL is run through
``EXPLAIN (TYPE DISTRIBUTED, FORMAT JSON)`` on Starburst/Trino, the plan nodes
that carry real cost estimates are counted, and the result is drawn with
Graphviz. A model whose count is at least ``threshold`` times the average of
its dependencies' counts is filled red. Per-model counts are also written to
``results.tsv``.
"""

import concurrent.futures
import contextlib
import datetime
import json
import logging
import math
import os
import sys
from concurrent.futures import ThreadPoolExecutor

import graphviz
import trino
from cachier import cachier

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging.getLogger("trino")
logger.setLevel(logging.DEBUG)

# Logger for this script's own messages (separate from the trino client logger).
_log = logging.getLogger(__name__)


def create_trino_connection():
    """Open a DB-API connection to the Starburst cluster over HTTPS on port 443.

    Host and user are read from the ``STARBURST_HOST`` and
    ``STARBURST_DEV_USER`` environment variables; authentication uses
    ``trino.auth.OAuth2Authentication``.
    """
    return trino.dbapi.connect(
        host=os.getenv("STARBURST_HOST"),
        port=443,
        http_scheme="https",
        verify=True,
        user=os.getenv("STARBURST_DEV_USER"),
        auth=trino.auth.OAuth2Authentication(),
    )


def read_manifest(file_path):
    """Load and return the dbt manifest JSON at *file_path*."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def build_graphviz_graph(manifest, previous_stages_dict, *, threshold=10.0,
                         results_path='results.tsv', max_workers=4):
    """Build a left-to-right Graphviz DAG of dbt models labelled with plan sizes.

    Args:
        manifest: parsed dbt manifest; ``manifest['nodes']`` is iterated.
        previous_stages_dict: dict updated in place with ``{node_name: stages}``
            for every model that was explained successfully. Pre-existing
            entries are used for dependencies that are not explained here.
        threshold: a model is filled red when its count divided by the mean
            count of its dependencies is ``>= threshold``.
        results_path: TSV file (overwritten) receiving ``node_name\\tstages``.
        max_workers: number of concurrent EXPLAIN requests.

    Returns:
        The ``graphviz.Digraph``.
    """
    node_stages = _collect_stages(manifest, max_workers)
    # Record every count before comparing so a model's dependencies are known
    # regardless of the order in which manifest['nodes'] lists them.
    previous_stages_dict.update(node_stages)
    _write_results(results_path, node_stages)

    dot = graphviz.Digraph(comment='DBT Models')
    dot.graph_attr['rankdir'] = 'LR'
    for node_name, stages in node_stages.items():
        dependencies = manifest['nodes'][node_name].get('depends_on', {}).get('nodes', [])
        # Dependencies without a known count (e.g. sources) contribute 0.
        input_stages = [previous_stages_dict.get(dep, 0) for dep in dependencies]
        label = f"{node_name.split('.')[-1]}\n{stages} stages"
        if _is_stage_spike(stages, input_stages, threshold):
            dot.node(node_name, label, style="filled", fillcolor="red")
        else:
            dot.node(node_name, label)
        for dep in dependencies:
            dot.edge(dep, node_name)
    return dot


def _collect_stages(manifest, max_workers):
    """EXPLAIN every non-test node with compiled SQL; return {node_name: stages} in manifest order.

    A node whose EXPLAIN raises is logged and omitted, so one broken model
    does not abort the whole run.
    """
    node_stages = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(get_query_stages, node_name, node_data['compiled_code']): node_name
            for node_name, node_data in manifest['nodes'].items()
            # Nodes without a 'compiled_code' key (or with empty SQL) are skipped.
            if node_data.get('compiled_code') and not node_name.startswith('test.')
        }
        for future, node_name in futures.items():
            try:
                _, stages = future.result()
            except Exception:
                _log.exception("EXPLAIN failed for %s; leaving it out of the graph", node_name)
                continue
            node_stages[node_name] = stages
    return node_stages


def _write_results(results_path, node_stages):
    """Overwrite *results_path* with a ``node_name\\tstages`` TSV of *node_stages*."""
    with open(results_path, 'w', encoding='utf-8') as f:
        f.write('node_name\tstages\n')
        f.writelines(f"{node_name}\t{stages}\n" for node_name, stages in node_stages.items())


def _is_stage_spike(stages, input_stages, threshold):
    """Return True when *stages* is at least *threshold* times the mean of *input_stages*.

    No inputs, or a mean of zero, never counts as a spike (avoids dividing by zero).
    """
    if not input_stages:
        return False
    avg_input_stages = sum(input_stages) / len(input_stages)
    return avg_input_stages > 0 and stages / avg_input_stages >= threshold


@cachier(stale_after=datetime.timedelta(minutes=60))
def get_plan(node_name, sql):
    """Return the parsed ``EXPLAIN (TYPE DISTRIBUTED, FORMAT JSON)`` output for *sql*.

    Results are cached by cachier for 60 minutes, keyed on the arguments;
    *node_name* is not used in the body and only takes part in the cache key.
    """
    # Every ';' is removed so a statement terminator does not break the
    # EXPLAIN wrapper. NOTE(review): this also alters ';' inside string
    # literals; tolerated because only the plan's structure is used.
    sql = sql.replace(';', '')
    # closing() guarantees both cursor and connection are released even if
    # execute/fetchone raises.
    with contextlib.closing(create_trino_connection()) as conn:
        with contextlib.closing(conn.cursor()) as cur:
            cur.execute(f"EXPLAIN (TYPE DISTRIBUTED, FORMAT JSON) {sql}")
            row = cur.fetchone()
    return json.loads(row[0])


def get_query_stages(node_name, sql):
    """Return ``(node_name, stage_count)`` for the plan of *sql*."""
    stages = count_stages_in_query_plan(get_plan(node_name, sql))
    return node_name, stages


def _is_nan(value):
    """Return True for the string 'NaN' or a float NaN (json.loads maps a bare NaN token to float)."""
    return value == 'NaN' or (isinstance(value, float) and math.isnan(value))


def is_virtual_stage(stage):
    """Return True when every entry in ``stage['estimates']`` carries no cost.

    An entry is cost-free when ``outputSizeInBytes`` and ``cpuCost`` are NaN
    (or absent) and ``memoryCost`` and ``networkCost`` are 0.0 (or absent).
    A dict with no ``estimates`` key is treated as virtual.
    """
    for estimates in stage.get('estimates', []):
        if not (
            _is_nan(estimates.get('outputSizeInBytes', 'NaN'))
            and _is_nan(estimates.get('cpuCost', 'NaN'))
            and estimates.get('memoryCost', 0.0) == 0.0
            and estimates.get('networkCost', 0.0) == 0.0
        ):
            return False
    return True


def count_stages_in_query_plan(query_plan):
    """Recursively count the dicts in *query_plan* that carry real cost estimates.

    Every dict reached through dict values or list items is counted unless
    is_virtual_stage() classifies it as cost-free. Scalars, nested lists, and
    non-dict input contribute 0.
    """
    if not isinstance(query_plan, dict):
        return 0
    stage_count = 0 if is_virtual_stage(query_plan) else 1
    for value in query_plan.values():
        children = value if isinstance(value, list) else [value]
        stage_count += sum(count_stages_in_query_plan(child) for child in children)
    return stage_count


if __name__ == "__main__":
    manifest_file_path = "./target/manifest.json"
    manifest = read_manifest(manifest_file_path)
    previous_stages_dict = {}
    dot = build_graphviz_graph(manifest, previous_stages_dict)
    # Save the DOT source, then render output.png and open it in the default viewer.
    dot.save('output.dot')
    dot.render('output', view=True, format='png')