Last active
April 16, 2025 05:22
-
-
Save joe32140/3c38f377750202d7803b8c0fa0ef1e8b to your computer and use it in GitHub Desktop.
Code Retrieval Model Evaluation On COIR with MTEB Library
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Evaluate and compare multiple code retrieval models on various MTEB code-related tasks. | |
| Uses only the recommended subset of CoIR tasks based on provided analysis. | |
| """ | |
| import os | |
| import argparse | |
| import logging | |
| import json | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| from mteb import MTEB, get_tasks | |
| from tabulate import tabulate | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| # Configure logging | |
| logging.basicConfig(format='%(asctime)s - %(message)s', | |
| datefmt='%Y-%m-%d %H:%M:%S', | |
| level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
def evaluate_model(model_path_or_name, output_dir, batch_size=16, save_predictions=False, overwrite_results=False):
    """Run one model over the recommended subset of MTEB code-retrieval tasks.

    Args:
        model_path_or_name: Hugging Face model id or local path.
        output_dir: root directory for evaluation output; a per-model
            subdirectory (slashes replaced by underscores) is created inside.
        batch_size: encoding batch size passed through to the MTEB runner.
        save_predictions: forward raw predictions to disk when True.
        overwrite_results: re-run tasks even if cached results exist.

    Returns:
        Tuple of (MTEB results, original model identifier) — the identifier
        is returned unchanged so the caller can key its results dict by it.
    """
    logger.info(f"Loading model: {model_path_or_name}")
    # Prefer GPU when available; bfloat16 is only safe/beneficial on CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    model = SentenceTransformer(
        model_path_or_name,
        trust_remote_code=True,
        device=device,
        model_kwargs={"torch_dtype": dtype},
    )

    # Recommended CoIR subset; CosQA and StackOverflowQA are deliberately
    # excluded per the accompanying analysis.
    recommended_tasks = [
        "AppsRetrieval",
        "CodeFeedbackMT",
        "CodeFeedbackST",
        "CodeSearchNetCCRetrieval",
        "CodeTransOceanContest",
        "CodeTransOceanDL",
        "SyntheticText2SQL",
    ]
    logger.info(f"Evaluating on tasks: {recommended_tasks}")

    evaluation = MTEB(tasks=get_tasks(tasks=recommended_tasks))

    # Hub ids such as "org/model" contain a slash which would otherwise
    # create nested directories; flatten to a single safe directory name.
    safe_name = model_path_or_name.replace('/', '_')
    run_dir = os.path.join(output_dir, safe_name)
    os.makedirs(run_dir, exist_ok=True)

    logger.info(f"Running evaluation on MTEB recommended code tasks for {model_path_or_name}")
    # NOTE(review): assumes every listed task provides a 'test' split — confirm.
    results = evaluation.run(
        model=model,
        output_folder=run_dir,
        eval_splits=["test"],
        batch_size=batch_size,
        save_predictions=save_predictions,
        overwrite_results=overwrite_results,
    )
    return results, model_path_or_name
def compare_models(results_dict):
    """Build long-form and pivoted comparison tables from per-model results.

    Args:
        results_dict: mapping of model name -> iterable of MTEB task results
            (each exposing ``task_name`` and a ``scores`` dict), or None when
            evaluation failed for that model.

    Returns:
        Tuple ``(df, pivot_df)``: ``df`` is a long-form DataFrame with one
        row per (task, model) pair holding both the numeric ``Value`` and a
        display string ``ValueStr``; ``pivot_df`` has one row per
        (Task, Metric) and one column per model.  Both may be empty.
    """
    primary_metric = "ndcg_at_10"  # default primary metric for retrieval tasks
    task_metrics = {}
    all_tasks_found = set()

    # First pass: collect every task name seen across all models.
    for model_name, results in results_dict.items():
        if results is None:
            logger.warning(f"No results found for model: {model_name}")
            continue
        for task_result in results:
            all_tasks_found.add(task_result.task_name)
            # Record which metric we want to report for this task.
            task_metrics.setdefault(task_result.task_name, primary_metric)

    # Second pass: one row per (task, model), with NaN / "N/A" for gaps.
    all_data = []
    for task_name in sorted(all_tasks_found):
        metric = task_metrics.get(task_name, primary_metric)
        for model_name, results in results_dict.items():
            value = np.nan
            value_str = "N/A"
            if results is not None:
                matching_result = next((tr for tr in results if tr.task_name == task_name), None)
                if matching_result and "test" in matching_result.scores and matching_result.scores["test"]:
                    # MTEB score structures vary: scores["test"] may be a list
                    # of dicts or a plain dict.
                    raw = matching_result.scores["test"]
                    score_dict = raw[0] if isinstance(raw, list) else raw
                    value = score_dict.get(metric, np.nan)
                    if pd.isna(value):
                        # Fall back to the task's declared main score when the
                        # primary metric is absent.
                        main_score_key = getattr(matching_result, 'main_score', None)
                        if main_score_key:
                            value = score_dict.get(main_score_key, np.nan)
                    if pd.isna(value):
                        value_str = "N/A"
                    elif isinstance(value, (int, float)):
                        value_str = f"{value:.4f}"
                    else:
                        value_str = str(value)
            all_data.append({
                "Task": task_name,
                "Metric": metric,
                "Model": model_name,
                "Value": value,        # numeric value for calculations/plotting
                "ValueStr": value_str  # formatted string for display
            })

    if not all_data:
        logger.warning("No data collected for comparison.")
        return pd.DataFrame(), pd.DataFrame()

    df = pd.DataFrame(all_data)
    try:
        pivot_df = df.pivot_table(
            index=["Task", "Metric"],
            columns="Model",
            values="ValueStr",
            aggfunc="first"  # duplicates shouldn't occur with this structure
        ).reset_index()
    except Exception as e:
        logger.error(f"Failed to create pivot table: {e}")
        # BUG FIX: the original call passed ``df`` as a %-style argument with
        # no placeholder in the format string, so the DataFrame contents were
        # never actually logged.
        logger.error("DataFrame contents:\n%s", df)
        pivot_df = pd.DataFrame()  # return empty df on error
    return df, pivot_df
def plot_comparison(df, output_dir):
    """Render one bar chart per (task, metric) plus an average-score summary.

    PNG files land under ``<output_dir>/plots`` and
    ``<output_dir>/plots/summary``.  Non-numeric or missing values are
    skipped with a warning; nothing is returned.
    """
    if df.empty:
        logger.warning("Skipping plot generation due to empty DataFrame.")
        return

    plot_dir = os.path.join(output_dir, "plots")
    os.makedirs(plot_dir, exist_ok=True)

    def _label_bars(bars):
        # Write each bar's height just above it.
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2., height + 0.002,
                     f'{height:.4f}', ha='center', va='bottom', rotation=0, fontsize=8)

    def _add_headroom():
        # Stretch the y-axis by 10% so bar labels don't hit the frame.
        lo, hi = plt.ylim()
        plt.ylim(lo, hi * 1.1)

    for (task, metric), chunk in df.groupby(["Task", "Metric"]):
        plottable = chunk.dropna(subset=['Value'])
        if plottable.empty:
            logger.warning(f"Skipping plot for {task} - {metric}: No numeric data.")
            continue
        if not pd.api.types.is_numeric_dtype(plottable['Value']):
            logger.warning(f"Skipping plot for {task} - {metric}: 'Value' column is not numeric.")
            continue

        plt.figure(figsize=(10, 6))
        _label_bars(plt.bar(plottable["Model"], plottable["Value"]))
        plt.title(f'{task} - {metric}')
        plt.ylabel(metric)
        _add_headroom()
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()

        out_path = os.path.join(plot_dir, f'{task}_{metric}.png')
        try:
            plt.savefig(out_path)
            logger.info(f"Saved plot: {out_path}")
        except Exception as e:
            logger.error(f"Failed to save plot {out_path}: {e}")
        plt.close()  # free figure memory

    # Summary chart: mean ndcg_at_10 per model across all tasks.
    primary_metric = "ndcg_at_10"
    summary_source = df[df['Metric'] == primary_metric].dropna(subset=['Value'])
    if not summary_source.empty and pd.api.types.is_numeric_dtype(summary_source['Value']):
        per_model = summary_source.groupby('Model')['Value'].mean().reset_index()
        plt.figure(figsize=(10, 6))
        _label_bars(plt.bar(per_model["Model"], per_model["Value"]))
        plt.title(f'Average {primary_metric} Across Recommended Tasks')
        plt.ylabel(f'Average {primary_metric}')
        plt.xticks(rotation=45, ha='right')
        _add_headroom()
        plt.tight_layout()

        summary_dir = os.path.join(plot_dir, "summary")
        os.makedirs(summary_dir, exist_ok=True)
        summary_path = os.path.join(summary_dir, f'summary_avg_{primary_metric}.png')
        try:
            plt.savefig(summary_path)
            logger.info(f"Saved summary plot: {summary_path}")
        except Exception as e:
            logger.error(f"Failed to save summary plot {summary_path}: {e}")
        plt.close()
def main(args):
    """Evaluate each configured model, then compare, tabulate and plot results.

    Args:
        args: parsed argparse namespace with ``output_dir``, ``batch_size``,
            ``save_predictions`` and ``overwrite_results`` attributes.

    Side effects: writes CSVs, a text summary and PNG plots under
    ``args.output_dir``; logs progress and failures.
    """
    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)
    # Models to evaluate (example models)
    models = [
        "Alibaba-NLP/gte-base-en-v1.5", # Common general purpose baseline
        "codesage/codesage-small", # Code specific baseline
        # Add other models you want to compare here
        # "your-org/your-fine-tuned-model",
    ]
    # Evaluate each model; a failure for one model is logged and recorded as
    # None so the comparison can still proceed with the rest.
    results_dict = {}
    for model_path_or_name in models:
        try:
            results, model_key = evaluate_model(
                model_path_or_name,
                args.output_dir,
                args.batch_size,
                args.save_predictions,
                args.overwrite_results
            )
            # Use the original name/path as the key
            results_dict[model_key] = results
        except Exception as e:
            logger.error(f"Evaluation failed for model {model_path_or_name}: {e}", exc_info=True)
            results_dict[model_path_or_name] = None # Indicate failure
    # Compare models
    if not results_dict:
        logger.error("No models were successfully evaluated. Exiting comparison.")
        return
    df, pivot_df = compare_models(results_dict)
    if df.empty and pivot_df.empty:
        logger.error("Comparison resulted in empty dataframes. Check evaluation logs.")
        return
    # Save comparison results (raw data)
    comparison_path = os.path.join(args.output_dir, "model_comparison_raw.csv")
    try:
        df.to_csv(comparison_path, index=False)
        logger.info(f"Raw comparison data saved to {comparison_path}")
    except Exception as e:
        logger.error(f"Failed to save raw comparison data: {e}")
    # Save pivot table results
    pivot_path = os.path.join(args.output_dir, "model_comparison_pivot.csv")
    try:
        pivot_df.to_csv(pivot_path, index=False) # index=False as Task/Metric are columns now
        logger.info(f"Pivot table comparison saved to {pivot_path}")
    except Exception as e:
        logger.error(f"Failed to save pivot table comparison: {e}")
    # Calculate mean scores across tasks for each metric and model using the numeric 'Value'
    # Filter out non-numeric before calculating mean
    numeric_df = df.dropna(subset=['Value'])
    if not numeric_df.empty and pd.api.types.is_numeric_dtype(numeric_df['Value']):
        try:
            mean_scores = numeric_df.pivot_table(
                values='Value',
                index='Metric',
                columns='Model',
                aggfunc=np.mean # Use numpy mean which handles NaN
            ).round(4) # Round after aggregation
            # Create a pretty table for display
            if not pivot_df.empty:
                table = tabulate(pivot_df, headers='keys', tablefmt='pretty', showindex=False) # showindex=False as index is reset
                logger.info(f"\nModel Comparison Results (Formatted):\n{table}")
            else:
                logger.warning("Pivot table is empty, cannot display formatted results.")
            # Create a pretty table for mean scores
            mean_table = tabulate(mean_scores, headers='keys', tablefmt='pretty', showindex=True) # Show metric index
            logger.info(f"\nMean Scores Across Tasks:\n{mean_table}")
            # Save the tables as text
            table_path = os.path.join(args.output_dir, "model_comparison_summary.txt")
            with open(table_path, 'w') as f:
                if not pivot_df.empty:
                    f.write("Model Comparison Results (Formatted):\n")
                    f.write(table)
                else:
                    f.write("Model Comparison Results (Formatted): Pivot table empty.\n")
                f.write("\n\nMean Scores Across Tasks:\n")
                f.write(mean_table)
            logger.info(f"Comparison tables saved to {table_path}")
        except Exception as e:
            logger.error(f"Failed to calculate or display mean scores: {e}")
            logger.error("Mean Scores DataFrame calculation might have failed.")
    else:
        logger.warning("Could not calculate mean scores: No numeric data available.")
        # Save the pivot table text even if mean scores fail
        table_path = os.path.join(args.output_dir, "model_comparison_summary.txt")
        with open(table_path, 'w') as f:
            if not pivot_df.empty:
                table = tabulate(pivot_df, headers='keys', tablefmt='pretty', showindex=False)
                f.write("Model Comparison Results (Formatted):\n")
                f.write(table)
                f.write("\n\nMean Scores Across Tasks: Could not be calculated (no numeric data).\n")
            else:
                f.write("Model Comparison Results (Formatted): Pivot table empty.\n")
                f.write("\n\nMean Scores Across Tasks: Could not be calculated (no numeric data).\n")
        logger.info(f"Partial comparison table saved to {table_path}")
    # Create comparison plots using the raw numeric data
    plot_comparison(df, args.output_dir) # Pass the raw DataFrame df
    logger.info(f"Evaluation and comparison finished. Results are in {args.output_dir}")
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="Compare multiple code retrieval models on MTEB recommended code tasks") | |
| parser.add_argument("--output_dir", type=str, default="evaluation/mteb_recommended_code_tasks", | |
| help="Directory to save evaluation results") | |
| parser.add_argument("--batch_size", type=int, default=16, # Adjusted default based on common practice | |
| help="Batch size for encoding. Adjust based on GPU memory.") | |
| parser.add_argument("--save_predictions", action="store_true", | |
| help="Whether to save raw predictions (can consume significant disk space)") | |
| parser.add_argument("--overwrite_results", action="store_true", | |
| help="Whether to overwrite existing evaluation results for a model-task pair") | |
| args = parser.parse_args() | |
| main(args) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment