[Fix] Fix Local Runner Params Save Path (#1768)

* update local runner params save dir

* fix remove

* fix directory remove

* Fix *_params.py by uuid4
This commit is contained in:
Junnan Liu 2024-12-19 16:07:34 +08:00 committed by GitHub
parent 9a5adbde6a
commit 499302857f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 8 deletions

View File

@ -147,7 +147,7 @@ class OpenAI(BaseAPIModel):
self.path = path self.path = path
self.max_completion_tokens = max_completion_tokens self.max_completion_tokens = max_completion_tokens
self.logger.warning( self.logger.warning(
f'Max Completion tokens for {path} is :{max_completion_tokens}') f'Max Completion tokens for {path} is {max_completion_tokens}')
def generate(self, def generate(self,
inputs: List[PromptType], inputs: List[PromptType],
@ -594,7 +594,6 @@ class OpenAISDK(OpenAI):
model=self.path, model=self.path,
max_completion_tokens=self.max_completion_tokens, max_completion_tokens=self.max_completion_tokens,
n=1, n=1,
temperature=self.temperature,
messages=messages, messages=messages,
extra_body=self.extra_body, extra_body=self.extra_body,
) )

View File

@ -207,9 +207,14 @@ class LocalRunner(BaseRunner):
task_name = task.name task_name = task.name
pwd = os.getcwd()
# Dump task config to file # Dump task config to file
mmengine.mkdir_or_exist('tmp/') mmengine.mkdir_or_exist('tmp/')
param_file = f'tmp/{os.getpid()}_{index}_params.py' # Using uuid to avoid filename conflict
import uuid
uuid_str = str(uuid.uuid4())
param_file = f'{pwd}/tmp/{uuid_str}_params.py'
try: try:
task.cfg.dump(param_file) task.cfg.dump(param_file)
tmpl = get_command_template(gpu_ids) tmpl = get_command_template(gpu_ids)
@ -236,5 +241,8 @@ class LocalRunner(BaseRunner):
logger.error(f'task {task_name} fail, see\n{out_path}') logger.error(f'task {task_name} fail, see\n{out_path}')
finally: finally:
# Clean up # Clean up
if not self.keep_tmp_file:
os.remove(param_file) os.remove(param_file)
else:
pass
return task_name, result.returncode return task_name, result.returncode