Skip to content
Snippets Groups Projects
Commit 2e8ddf31 authored by Julien B.'s avatar Julien B.
Browse files

fix(trainer): change trainer function

parent 7c0e1544
No related branches found
No related tags found
No related merge requests found
......@@ -43,7 +43,6 @@ def start_training(job):
'tags': tags
})
logger.debug(training_data)
with grpc.insecure_channel(job.job_data['server_url']) as channel:
stub = trainer_pb2_grpc.TrainerStub(channel)
request = trainer_pb2.TrainingInput(
......@@ -52,10 +51,9 @@ def start_training(job):
finetuned_repo_name=job.job_data['finetuned_repo_name'],
huggingface_token=job.job_data['huggingface_token'],
)
logger.debug(request)
responses = stub.StartTraining(request)
for response in responses:
logger.debug(f"gRPC message : {response.status}")
logger.debug(f"Incoming gRPC message : {response.status}")
logger.debug(f"fin de la connexion gRPC")
# passer toutes les origines des concepts en BERT
......@@ -18,6 +18,7 @@ from transformers.utils import (
)
import trainer_pb2_grpc
from microservices.trainer import trainer_pb2
is_busy = False
......@@ -43,6 +44,7 @@ class TrainerServicer(trainer_pb2_grpc.TrainerServicer):
is_busy = False
else:
print(f"gRPC server is already busy")
return trainer_pb2.TrainingEvent(status="Training ended successfully !")
return response_messages()
......@@ -57,7 +59,6 @@ def serve():
def training_process(training_data, fondation_model_id, finetuned_repo_name, huggingface_token):
training_data = json.loads(training_data)
print(training_data)
MAX_LENGTH = 256
tag2id = {'action': 1, 'actor': 2, 'artifact': 3, 'condition': 4, 'location': 5, 'modality': 6, 'reference': 7,
'time': 8}
......@@ -241,8 +242,8 @@ def training_process(training_data, fondation_model_id, finetuned_repo_name, hug
)
trainer.train()
trainer.model.push_to_hub(finetuned_repo_name, use_auth_token=huggingface_token)
tokenizer.push_to_hub(finetuned_repo_name, use_auth_token=huggingface_token)
trainer.model.push_to_hub(finetuned_repo_name, token=huggingface_token)
tokenizer.push_to_hub(finetuned_repo_name, token=huggingface_token)
if __name__ == '__main__':
......
......@@ -5,4 +5,11 @@ services:
- "80:80"
restart: unless-stopped
environment:
- PYTHONUNBUFFERED=1
\ No newline at end of file
- PYTHONUNBUFFERED=1
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [ gpu ]
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment