This commit is contained in:
yufeng zhao 2025-03-16 03:25:32 +00:00
parent a5abe18aa3
commit 78d94e7bbd
3 changed files with 35 additions and 0 deletions

View File

@ -0,0 +1,13 @@
from mmengine.config import read_base
with read_base():
from .groups.bbeh import bbeh_summary_groups
summarizer = dict(
dataset_abbrs=[
['bbeh', 'naive_average'],
['bbeh', 'harmonic_mean']
],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []),
)

View File

@ -203,6 +203,17 @@ class MultiModelSummarizer:
numerator = sum(results[k] * sg['weights'][k] for k in sg['weights'])
denominator = sum(sg['weights'].values())
metric = 'weighted_average'
elif 'harmonic_mean' in sg:
# Check for non-positive values that would cause issues in harmonic mean
if any(results[k] <= 0 for k in results):
self.logger.warning(f'Non-positive values found when calculating harmonic mean for {sg["name"]}')
# Handle non-positive values (either skip or use a small positive value)
numerator = len(results)
denominator = sum(1 / max(results[k], 1e-10) for k in results)
else:
numerator = len(results)
denominator = sum(1 / results[k] for k in results)
metric = 'harmonic_mean'
else:
numerator = sum(results[k] for k in results)
denominator = len(results)

View File

@ -115,6 +115,17 @@ class PretrainSummarizer:
numerator = sum(results[k] * sg['weights'][k] for k in sg['weights'])
denominator = sum(sg['weights'].values())
metric = 'weighted_average'
elif 'harmonic_mean' in sg:
# Check for non-positive values that would cause issues in harmonic mean
if any(results[k] <= 0 for k in results):
self.logger.warning(f'Non-positive values found when calculating harmonic mean for {sg["name"]}')
# Handle non-positive values (either skip or use a small positive value)
numerator = len(results)
denominator = sum(1 / max(results[k], 1e-10) for k in results)
else:
numerator = len(results)
denominator = sum(1 / results[k] for k in results)
metric = 'harmonic_mean'
else:
numerator = sum(results[k] for k in results)
denominator = len(results)