From c50b4759868908570f4b6e258de02c208baa2122 Mon Sep 17 00:00:00 2001
From: "Julien B." <xm9q8f80@jlnbrtn.me>
Date: Sun, 25 Aug 2024 21:34:16 +0200
Subject: [PATCH] fix(inferer+annotator): fix issues

---
 api/internal_services/annotator.py |  2 ++
 api/internal_services/database.py  |  5 +++--
 api/internal_services/gpt.py       |  8 +++-----
 api/models/InfererConfig.py        |  5 +++--
 api/routers/endpoints.py           | 14 +++++++++-----
 microservices/inferer/setup.py     |  2 +-
 6 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/api/internal_services/annotator.py b/api/internal_services/annotator.py
index 35e1eea..34e4c29 100644
--- a/api/internal_services/annotator.py
+++ b/api/internal_services/annotator.py
@@ -71,6 +71,8 @@ def annotation_process(job):
                             get_last_concept_index(),
                             word_id
                         )
+        case 'NONE':
+            logger.error("No annotator configured, please provide an annotator before trying to add sentences.")
 
 
 def separate_intervals(data):
diff --git a/api/internal_services/database.py b/api/internal_services/database.py
index b8a287b..9695e41 100644
--- a/api/internal_services/database.py
+++ b/api/internal_services/database.py
@@ -31,12 +31,13 @@ def update_last_concept_index(value):
 def get_annotator_config():
     result = db.search(where('key') == 'annotator_config')
     if not result:
-        created_object = {'key': 'annotator_config', 'value': { 'provider': 'GPT' }}
+        created_object = {'key': 'annotator_config', 'value': { 'provider': 'NONE' }}
         db.insert(created_object)
-        return 0
+        return created_object['value']
     else:
         return result[0]['value']
 
 def update_annotator_config(value):
+    get_annotator_config()
     db.update({'value': value}, where('key') == 'annotator_config')
     return value
\ No newline at end of file
diff --git a/api/internal_services/gpt.py b/api/internal_services/gpt.py
index a457c38..04f666d 100644
--- a/api/internal_services/gpt.py
+++ b/api/internal_services/gpt.py
@@ -1,14 +1,12 @@
 import json
-import os
-import traceback
-
 from openai import OpenAI
+from api.internal_services import database
 from api.internal_services.logger import logger
 
-client = OpenAI()
-
 def gpt_process(sentence):
     try:
+        annotator_config = database.get_annotator_config()
+        client = OpenAI(api_key=annotator_config['openai_key'])
         completion = client.chat.completions.create(
             model="gpt-4",
             messages=[
diff --git a/api/models/InfererConfig.py b/api/models/InfererConfig.py
index 88b8d58..ee4910c 100644
--- a/api/models/InfererConfig.py
+++ b/api/models/InfererConfig.py
@@ -2,5 +2,6 @@ from pydantic import BaseModel
 
 class InfererConfig(BaseModel):
     provider: str
-    model_id: str
-    server_url: str
\ No newline at end of file
+    model_id: str | None = None
+    server_url: str | None = None
+    openai_key: str | None = None
\ No newline at end of file
diff --git a/api/routers/endpoints.py b/api/routers/endpoints.py
index b1659d4..c609da9 100644
--- a/api/routers/endpoints.py
+++ b/api/routers/endpoints.py
@@ -40,9 +40,13 @@ def add_sentence_to_process(training_body: TrainingBody):
 
 @router.post("/actions/inferer/config")
 def add_sentence_to_process(infererConfig: InfererConfig):
-    new_config = {'model_id': infererConfig.provider}
-    if infererConfig.provider == 'BERT':
-        new_config['model_id'] = infererConfig.model_id
-        new_config['server_url'] = infererConfig.server_url
+    new_config = {'provider': infererConfig.provider}
 
-    database.update_annotator_config({'model_id': 'GPT'})
\ No newline at end of file
+    match infererConfig.provider:
+        case 'BERT':
+            new_config['model_id'] = infererConfig.model_id
+            new_config['server_url'] = infererConfig.server_url
+        case 'GPT':
+            new_config['openai_key'] = infererConfig.openai_key
+
+    database.update_annotator_config(new_config)
\ No newline at end of file
diff --git a/microservices/inferer/setup.py b/microservices/inferer/setup.py
index 35f4ac0..6659f7e 100644
--- a/microservices/inferer/setup.py
+++ b/microservices/inferer/setup.py
@@ -1,7 +1,7 @@
 from setuptools import setup, find_packages
 
 setup(
-    name='ALA-Plateform-trainer',
+    name='ALA-Plateform-inferer',
     version='0.1',
     packages=find_packages(),
     install_requires=[
-- 
GitLab