From ce7bd4f8fa2c649ffa6f2bf371d4e8acc1796fcd Mon Sep 17 00:00:00 2001
From: Julien Breton <julien.breton@moncitron.fr>
Date: Sun, 4 Feb 2024 14:42:25 +0900
Subject: [PATCH] change hyperparameters

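Lower the epoch count and gradient accumulation, raise the per-device batch
size, and switch checkpointing from per-epoch to per-step saves.

A minimal sketch of the resulting configuration, assuming the Hugging Face
transformers defaults: with save_strategy="steps" a checkpoint is written
every save_steps optimizer steps, and save_steps defaults to 500 unless set
explicitly. The explicit save_steps below is an illustrative assumption, not
part of this patch.

    # Illustrative only: mirrors the values introduced by this patch and
    # spells out the checkpoint interval (assumption: the default of 500).
    from transformers import TrainingArguments

    training_arguments = TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        num_train_epochs=2,
        learning_rate=1e-4,
        logging_steps=2,
        optim="adamw_torch",
        save_strategy="steps",
        save_steps=500,          # assumed: matches the library default
        output_dir="./results",
    )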
---
 modules/llm/Mistral-7b/Mistral-7b_fine_tune.py     | 8 ++++----
 modules/llm/Mixtral-8x7b/Mixtral-8x7b_fine_tune.py | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/modules/llm/Mistral-7b/Mistral-7b_fine_tune.py b/modules/llm/Mistral-7b/Mistral-7b_fine_tune.py
index 4c9d6ec..a41fc61 100644
--- a/modules/llm/Mistral-7b/Mistral-7b_fine_tune.py
+++ b/modules/llm/Mistral-7b/Mistral-7b_fine_tune.py
@@ -56,13 +56,13 @@ def fine_tune(base_model, new_model):
     # Training Arguments
     # Hyperparameters should be adjusted based on the hardware you are using
     training_arguments = TrainingArguments(
-        per_device_train_batch_size=1,
-        gradient_accumulation_steps=4,
-        num_train_epochs=6,
+        per_device_train_batch_size=2,
+        gradient_accumulation_steps=1,
+        num_train_epochs=2,
         learning_rate=1e-4,
         logging_steps=2,
         optim="adamw_torch",
-        save_strategy="epoch",
+        save_strategy="steps",
         output_dir="./results"
     )
 
diff --git a/modules/llm/Mixtral-8x7b/Mixtral-8x7b_fine_tune.py b/modules/llm/Mixtral-8x7b/Mixtral-8x7b_fine_tune.py
index 7f40b7f..d9e8126 100644
--- a/modules/llm/Mixtral-8x7b/Mixtral-8x7b_fine_tune.py
+++ b/modules/llm/Mixtral-8x7b/Mixtral-8x7b_fine_tune.py
@@ -70,13 +70,13 @@ def fine_tuned(base_model, new_model):
         model=model,
         train_dataset=train_data,
         args=TrainingArguments(
-            per_device_train_batch_size=1,
-            gradient_accumulation_steps=4,
-            num_train_epochs=6,
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=1,
+            num_train_epochs=2,
             learning_rate=1e-4,
             logging_steps=2,
             optim="adamw_torch",
-            save_strategy="epoch",
+            save_strategy="steps",
             output_dir="./results"
         ),
         data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
-- 
GitLab