Skip to content

Instantly share code, notes, and snippets.

@joe32140
Last active April 16, 2025 05:22
Show Gist options
  • Select an option

  • Save joe32140/3c38f377750202d7803b8c0fa0ef1e8b to your computer and use it in GitHub Desktop.

Select an option

Save joe32140/3c38f377750202d7803b8c0fa0ef1e8b to your computer and use it in GitHub Desktop.
Code Retrieval Model Evaluation On COIR with MTEB Library
"""
Evaluate and compare multiple code retrieval models on various MTEB code-related tasks.
Uses only the recommended subset of CoIR tasks based on provided analysis.
"""
import os
import argparse
import logging
import json
import pandas as pd
from sentence_transformers import SentenceTransformer
from mteb import MTEB, get_tasks
from tabulate import tabulate
import matplotlib.pyplot as plt
import numpy as np
import torch
# Configure logging
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
def evaluate_model(model_path_or_name, output_dir, batch_size=16, save_predictions=False, overwrite_results=False):
    """Run one model over the recommended MTEB code-retrieval task subset.

    Args:
        model_path_or_name: Hugging Face model id or local checkpoint path.
        output_dir: Root results directory; a per-model subfolder is created.
        batch_size: Encoding batch size.
        save_predictions: Persist raw predictions alongside scores.
        overwrite_results: Re-run tasks even if cached results exist.

    Returns:
        Tuple of (MTEB results list, the original model name/path) so the
        caller can key its results dictionary consistently.
    """
    logger.info(f"Loading model: {model_path_or_name}")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")
    # bfloat16 is only safe on CUDA; CPU falls back to float32.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    model = SentenceTransformer(
        model_path_or_name,
        trust_remote_code=True,
        device=device,
        model_kwargs={"torch_dtype": dtype},
    )
    # Recommended CoIR subset (CosQA and StackOverflowQA deliberately excluded).
    # MTEB task names: CodeSearchNetCCRetrieval, AppsRetrieval, SyntheticText2SQL,
    # CodeTransOceanContest, CodeTransOceanDL, CodeFeedbackST, CodeFeedbackMT.
    recommended_tasks = [
        "AppsRetrieval",
        "CodeFeedbackMT",
        "CodeFeedbackST",
        "CodeSearchNetCCRetrieval",
        "CodeTransOceanContest",
        "CodeTransOceanDL",
        "SyntheticText2SQL",
    ]
    logger.info(f"Evaluating on tasks: {recommended_tasks}")
    evaluation = MTEB(tasks=get_tasks(tasks=recommended_tasks))
    # Hub ids contain '/', which would otherwise create nested directories.
    model_output_dir = os.path.join(output_dir, model_path_or_name.replace('/', '_'))
    os.makedirs(model_output_dir, exist_ok=True)
    logger.info(f"Running evaluation on MTEB recommended code tasks for {model_path_or_name}")
    results = evaluation.run(
        model=model,
        output_folder=model_output_dir,
        eval_splits=["test"],  # assumes 'test' split exists for all these tasks
        batch_size=batch_size,
        save_predictions=save_predictions,
        overwrite_results=overwrite_results,
    )
    return results, model_path_or_name
def compare_models(results_dict):
    """Build comparison tables from per-model MTEB results.

    Args:
        results_dict: Mapping of model name -> list of MTEB task results
            (objects exposing ``task_name`` and a ``scores`` dict with a
            "test" entry), or None for models whose evaluation failed.

    Returns:
        Tuple of (long-format DataFrame with one Task/Metric/Model/Value
        row per task-model pair, pivot DataFrame with one column per
        model). Both may be empty when no data was collected.
    """
    primary_metric = "ndcg_at_10"  # standard primary metric for retrieval tasks

    # Union of task names across all successfully evaluated models.
    all_tasks_found = set()
    for model_name, results in results_dict.items():
        if results is None:
            logger.warning(f"No results found for model: {model_name}")
            continue
        for task_result in results:
            all_tasks_found.add(task_result.task_name)

    all_data = []
    for task_name in sorted(all_tasks_found):
        for model_name, results in results_dict.items():
            value = np.nan  # NaN marks missing data for plots/aggregation
            if results is not None:
                matching_result = next((tr for tr in results if tr.task_name == task_name), None)
                if matching_result and "test" in matching_result.scores and matching_result.scores["test"]:
                    # MTEB stores split scores either as a list of dicts or a dict.
                    test_scores = matching_result.scores["test"]
                    score_dict = test_scores[0] if isinstance(test_scores, list) else test_scores
                    value = score_dict.get(primary_metric, np.nan)
                    if pd.isna(value):
                        # Fall back to the task's declared main score when
                        # the primary retrieval metric is absent.
                        main_score_key = getattr(matching_result, 'main_score', None)
                        if main_score_key:
                            value = score_dict.get(main_score_key, np.nan)
            if pd.isna(value):
                value_str = "N/A"
            elif isinstance(value, (int, float)):
                value_str = f"{value:.4f}"
            else:
                value_str = str(value)
            all_data.append({
                "Task": task_name,
                "Metric": primary_metric,
                "Model": model_name,
                "Value": value,        # numeric value for calculations/plotting
                "ValueStr": value_str, # formatted string for display
            })

    if not all_data:
        logger.warning("No data collected for comparison.")
        return pd.DataFrame(), pd.DataFrame()

    df = pd.DataFrame(all_data)
    try:
        pivot_df = df.pivot_table(
            index=["Task", "Metric"],
            columns="Model",
            values="ValueStr",
            aggfunc="first",  # duplicates shouldn't occur with this structure
        ).reset_index()
    except Exception as e:
        logger.error(f"Failed to create pivot table: {e}")
        # BUG FIX: the DataFrame argument needs a %s placeholder; without it
        # logging raised an internal formatting error and never printed df.
        logger.error("DataFrame contents:\n%s", df)
        pivot_df = pd.DataFrame()  # return empty df on error
    return df, pivot_df
def plot_comparison(df, output_dir):
    """Write one bar chart per (task, metric) plus an average summary chart.

    Args:
        df: Long-format comparison DataFrame (Task/Metric/Model/Value columns).
        output_dir: Root results directory; charts land in ``<output_dir>/plots``.
    """
    if df.empty:
        logger.warning("Skipping plot generation due to empty DataFrame.")
        return

    plot_dir = os.path.join(output_dir, "plots")
    os.makedirs(plot_dir, exist_ok=True)

    def _draw_bars(labels, heights, title, ylabel):
        # Shared figure setup: labelled bars, value annotations, rotated ticks.
        plt.figure(figsize=(10, 6))
        bars = plt.bar(labels, heights)
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2., height + 0.002,
                     f'{height:.4f}', ha='center', va='bottom', fontsize=8)
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xticks(rotation=45, ha='right')
        low, high = plt.ylim()
        plt.ylim(low, high * 1.1)  # 10% headroom so labels don't clip
        plt.tight_layout()

    def _save_figure(path, label):
        try:
            plt.savefig(path)
            logger.info(f"Saved {label}: {path}")
        except Exception as e:
            logger.error(f"Failed to save {label} {path}: {e}")
        plt.close()  # free figure memory

    # One chart per (task, metric) group, numeric values only.
    for (task, metric), group_df in df.groupby(["Task", "Metric"]):
        numeric_df = group_df.dropna(subset=['Value'])
        if numeric_df.empty:
            logger.warning(f"Skipping plot for {task} - {metric}: No numeric data.")
            continue
        if not pd.api.types.is_numeric_dtype(numeric_df['Value']):
            logger.warning(f"Skipping plot for {task} - {metric}: 'Value' column is not numeric.")
            continue
        _draw_bars(numeric_df["Model"], numeric_df["Value"], f'{task} - {metric}', metric)
        _save_figure(os.path.join(plot_dir, f'{task}_{metric}.png'), "plot")

    # Summary chart: average ndcg_at_10 across tasks per model.
    primary_metric = "ndcg_at_10"
    summary_df = df[df['Metric'] == primary_metric].dropna(subset=['Value'])
    if not summary_df.empty and pd.api.types.is_numeric_dtype(summary_df['Value']):
        avg_scores = summary_df.groupby('Model')['Value'].mean().reset_index()
        _draw_bars(avg_scores["Model"], avg_scores["Value"],
                   f'Average {primary_metric} Across Recommended Tasks',
                   f'Average {primary_metric}')
        summary_plot_dir = os.path.join(plot_dir, "summary")
        os.makedirs(summary_plot_dir, exist_ok=True)
        _save_figure(os.path.join(summary_plot_dir, f'summary_avg_{primary_metric}.png'),
                     "summary plot")
def _evaluate_all(models, args):
    """Evaluate each model; map model name -> results (None on failure)."""
    results_dict = {}
    for model_path_or_name in models:
        try:
            results, model_key = evaluate_model(
                model_path_or_name,
                args.output_dir,
                args.batch_size,
                args.save_predictions,
                args.overwrite_results,
            )
            results_dict[model_key] = results
        except Exception as e:
            logger.error(f"Evaluation failed for model {model_path_or_name}: {e}", exc_info=True)
            results_dict[model_path_or_name] = None  # keep failed model visible in tables
    return results_dict


def _save_csvs(df, pivot_df, output_dir):
    """Persist the raw and pivoted comparison data as CSV files."""
    comparison_path = os.path.join(output_dir, "model_comparison_raw.csv")
    try:
        df.to_csv(comparison_path, index=False)
        logger.info(f"Raw comparison data saved to {comparison_path}")
    except Exception as e:
        logger.error(f"Failed to save raw comparison data: {e}")
    pivot_path = os.path.join(output_dir, "model_comparison_pivot.csv")
    try:
        # index=False: Task/Metric are plain columns after reset_index.
        pivot_df.to_csv(pivot_path, index=False)
        logger.info(f"Pivot table comparison saved to {pivot_path}")
    except Exception as e:
        logger.error(f"Failed to save pivot table comparison: {e}")


def _write_summary(df, pivot_df, output_dir):
    """Log formatted comparison tables and write them to a summary text file."""
    table_path = os.path.join(output_dir, "model_comparison_summary.txt")
    numeric_df = df.dropna(subset=['Value'])
    if not numeric_df.empty and pd.api.types.is_numeric_dtype(numeric_df['Value']):
        try:
            # 'mean' (string) instead of np.mean: callable aggfuncs are
            # deprecated in modern pandas; the aggregation is identical.
            mean_scores = numeric_df.pivot_table(
                values='Value',
                index='Metric',
                columns='Model',
                aggfunc='mean',
            ).round(4)
            if not pivot_df.empty:
                table = tabulate(pivot_df, headers='keys', tablefmt='pretty', showindex=False)
                logger.info(f"\nModel Comparison Results (Formatted):\n{table}")
            else:
                logger.warning("Pivot table is empty, cannot display formatted results.")
            mean_table = tabulate(mean_scores, headers='keys', tablefmt='pretty', showindex=True)
            logger.info(f"\nMean Scores Across Tasks:\n{mean_table}")
            with open(table_path, 'w') as f:
                if not pivot_df.empty:
                    f.write("Model Comparison Results (Formatted):\n")
                    f.write(table)
                else:
                    f.write("Model Comparison Results (Formatted): Pivot table empty.\n")
                f.write("\n\nMean Scores Across Tasks:\n")
                f.write(mean_table)
            logger.info(f"Comparison tables saved to {table_path}")
        except Exception as e:
            logger.error(f"Failed to calculate or display mean scores: {e}")
            logger.error("Mean Scores DataFrame calculation might have failed.")
    else:
        logger.warning("Could not calculate mean scores: No numeric data available.")
        # Still save the pivot table text even if mean scores are unavailable.
        with open(table_path, 'w') as f:
            if not pivot_df.empty:
                table = tabulate(pivot_df, headers='keys', tablefmt='pretty', showindex=False)
                f.write("Model Comparison Results (Formatted):\n")
                f.write(table)
            else:
                f.write("Model Comparison Results (Formatted): Pivot table empty.\n")
            f.write("\n\nMean Scores Across Tasks: Could not be calculated (no numeric data).\n")
        logger.info(f"Partial comparison table saved to {table_path}")


def main(args):
    """Evaluate all configured models, then compare, tabulate, and plot."""
    os.makedirs(args.output_dir, exist_ok=True)
    # Models to evaluate — extend this list to compare your own checkpoints.
    models = [
        "Alibaba-NLP/gte-base-en-v1.5",  # common general-purpose baseline
        "codesage/codesage-small",       # code-specific baseline
        # "your-org/your-fine-tuned-model",
    ]
    results_dict = _evaluate_all(models, args)
    if not results_dict:
        logger.error("No models were successfully evaluated. Exiting comparison.")
        return
    df, pivot_df = compare_models(results_dict)
    if df.empty and pivot_df.empty:
        logger.error("Comparison resulted in empty dataframes. Check evaluation logs.")
        return
    _save_csvs(df, pivot_df, args.output_dir)
    _write_summary(df, pivot_df, args.output_dir)
    # Plots consume the raw long-format frame with numeric values.
    plot_comparison(df, args.output_dir)
    logger.info(f"Evaluation and comparison finished. Results are in {args.output_dir}")
if __name__ == "__main__":
    # CLI entry point: parse flags and hand off to main().
    parser = argparse.ArgumentParser(
        description="Compare multiple code retrieval models on MTEB recommended code tasks"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="evaluation/mteb_recommended_code_tasks",
        help="Directory to save evaluation results",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Batch size for encoding. Adjust based on GPU memory.",
    )
    # Boolean flags default to False unless supplied on the command line.
    parser.add_argument(
        "--save_predictions",
        action="store_true",
        help="Whether to save raw predictions (can consume significant disk space)",
    )
    parser.add_argument(
        "--overwrite_results",
        action="store_true",
        help="Whether to overwrite existing evaluation results for a model-task pair",
    )
    main(parser.parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment