2023-08-21 15:57:30 +08:00
|
|
|
import re
from typing import Any

import torch
|
|
|
|
|
|
|
|
|
2023-08-25 15:44:32 +08:00
|
|
|
class VisualGLMBasePostProcessor:
    """Base post processor for VisualGLM.

    Converts ``output_token`` back into text with the supplied tokenizer
    and returns the decoded string unchanged.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.Tensor, tokenizer: Any) -> str:
        """Decode ``output_token`` into a string.

        Args:
            output_token: Token ids forwarded unchanged to
                ``tokenizer.decode`` (presumably the ids generated by the
                model — confirm against the caller).
            tokenizer: Any object exposing a ``decode`` method.

        Returns:
            Whatever ``tokenizer.decode(output_token)`` returns.
        """
        return tokenizer.decode(output_token)
|
2023-08-25 15:44:32 +08:00
|
|
|
|
|
|
|
|
|
|
|
class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
    """VSR post processor for VisualGLM.

    Decodes ``output_token`` and maps the free-form answer onto one of
    ``'yes'``, ``'no'`` or ``'unknown'``.
    """

    # Match whole words only: a plain substring test would classify text
    # containing "eyes" as "yes", and "not"/"know"/"cannot"/"nothing" as "no".
    _YES_PATTERN = re.compile(r'\byes\b')
    _NO_PATTERN = re.compile(r'\bno\b')

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, output_token: torch.Tensor, tokenizer: Any) -> str:
        """Decode ``output_token`` and classify it as a yes/no answer.

        Args:
            output_token: Token ids forwarded to ``tokenizer.decode``.
            tokenizer: Any object exposing a ``decode`` method.

        Returns:
            ``'yes'`` if the word "yes" appears (case-insensitive),
            otherwise ``'no'`` if the word "no" appears, otherwise
            ``'unknown'``.
        """
        # Reuse the base decoding step; lowercase once for both checks.
        output_text = super().__call__(output_token, tokenizer).lower()
        # "yes" takes precedence over "no", as in the original ordering.
        if self._YES_PATTERN.search(output_text):
            return 'yes'
        if self._NO_PATTERN.search(output_text):
            return 'no'
        return 'unknown'
|