mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
30 lines
757 B
Python
30 lines
757 B
Python
from typing import Any
|
|
|
|
import torch
|
|
|
|
|
|
class VisualGLMBasePostProcessor:
    """Base post processor for VisualGLM.

    Converts the model's output tokens into text by handing them to the
    tokenizer's ``decode`` method; no further clean-up is applied.
    """

    def __init__(self) -> None:
        """Create the post processor; it keeps no state."""
        return None

    def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
        """Decode ``output_token`` and return the decoded result unchanged.

        Args:
            output_token: Token ids produced by the model; passed straight
                to ``tokenizer.decode``.
            tokenizer: Any object exposing a ``decode`` method.

        Returns:
            Whatever ``tokenizer.decode(output_token)`` returns.
        """
        decoded_text = tokenizer.decode(output_token)
        return decoded_text
|
|
|
|
|
|
class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
    """VSR post processor for VisualGLM.

    Reduces the decoded model output to one of three labels: ``'yes'``,
    ``'no'`` or ``'unknown'``.
    """

    # Match the answer only when it is not embedded in a longer run of ASCII
    # letters, so "eyes" is not read as "yes" and "know"/"cannot" are not
    # read as "no". Order matters: 'yes' is checked before 'no'.
    _ANSWER_PATTERNS = (
        ('yes', r'(?<![a-z])yes(?![a-z])'),
        ('no', r'(?<![a-z])no(?![a-z])'),
    )

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
        """Decode ``output_token`` and classify it as yes / no / unknown.

        Args:
            output_token: Token ids produced by the model; passed straight
                to ``tokenizer.decode``.
            tokenizer: Any object exposing a ``decode`` method that returns
                a string.

        Returns:
            ``'yes'`` if the decoded text contains the standalone word
            "yes" (case-insensitive); otherwise ``'no'`` if it contains the
            standalone word "no"; otherwise ``'unknown'``.
        """
        # Function-local import keeps this block self-contained; the module
        # is cached after the first call.
        import re

        # Lower-case once so the patterns only need to handle a-z.
        output_text = tokenizer.decode(output_token).lower()
        for answer, pattern in self._ANSWER_PATTERNS:
            if re.search(pattern, output_text):
                return answer
        return 'unknown'
|