Skip to content
Snippets Groups Projects
Commit 92649128 authored by jbreton's avatar jbreton
Browse files

Change S4 workflow & add mistral7b

parent 2ab683d6
No related branches found
No related tags found
No related merge requests found
......@@ -161,4 +161,6 @@ cython_debug/
temp
models/Mixtral-8x7B-Instruct-v0.1/*
!models/Mixtral-8x7B-Instruct-v0.1/.gitkeep
models/Mistral-7B-Instruct-v0.2/*
!models/Mistral-7B-Instruct-v0.2/.gitkeep
"""Fetch model snapshots from the Hugging Face Hub into local model dirs.

Downloads go through a shared ../temp cache; the weights land in a
per-model directory next to this script.
"""
from huggingface_hub import snapshot_download

# Mixtral-8x7B was the model used before the S4 workflow change; kept for reference.
# snapshot_download(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", cache_dir="../temp", local_dir="./Mixtral-8x7B-Instruct-v0.1")
snapshot_download(repo_id="mistralai/Mistral-7B-Instruct-v0.2", cache_dir="../temp", local_dir="./Mistral-7B-Instruct-v0.2")
# NOTE(review): this region is scraped from a GitLab unified diff with the
# +/- markers stripped, so removed (pre-change) and added (post-change)
# lines appear interleaved and some statements are duplicated. The lines
# marked "pre-change" below were deleted by commit 92649128.
import json
from alive_progress import alive_bar
from torch import bfloat16
import torch
import transformers
import bitsandbytes, flash_attn
# pre-change model path (superseded by the assignment just below):
model_id = "./models/mixtral-8x7b"
# post-change model path:
model_id = "../../models/Mixtral-8x7B-Instruct-v0.1"
model = transformers.AutoModelForCausalLM.from_pretrained(
model_id,
# Diff hunk header: the remaining from_pretrained(...) arguments and the
# opening of the transformers.pipeline(...) call are omitted diff context.
......@@ -24,11 +23,12 @@ generate_text = transformers.pipeline(
return_full_text=False, # if using langchain set True
task="text-generation",
# we pass model parameters here too
# pre-change sampling setting (replaced by do_sample/temperature below):
temperature=1, # 'randomness' of outputs, 0.0 is the min and 1.0 the max
do_sample=True,
temperature=0.5, # 'randomness' of outputs, 0.0 is the min and 1.0 the max
top_p=0.15, # select from top tokens whose probability add up to 15%
top_k=0, # select from top 0 tokens (because zero, relies on top_p)
# pre-change generation limits (replaced by the two lines below):
max_new_tokens=512, # max number of tokens to generate in the output
repetition_penalty=1.1 # if output begins repeating increase
max_new_tokens=4096, # max number of tokens to generate in the output
repetition_penalty=1.0 # if output begins repeating increase
)
inst = """
......@@ -65,42 +65,25 @@ def instruction_format(sys_message: str, query: str):
return f'<s> [INST] {sys_message} [/INST]\nUser: {query}\nAssistant: '
# S4 workflow (post-change version reconstructed from the diff): build all
# prompts first, run one batched generation pass, then map each sentence to
# its raw generated answer.
# NOTE(review): this span of the scraped diff also contained the pre-change
# per-sentence loop (try/except plus French->English JSON key renaming);
# only the post-change batched version is kept here.
with open('../../data/evalQS.json', 'r', encoding='utf-8') as file:
    loaded = json.load(file)

# Renamed from `input`, which shadowed the builtin of the same name.
prompts = []
output = {}

# Build one instruction-formatted prompt per evaluation sentence.
with alive_bar(len(loaded)) as bar:
    for sentence in loaded:
        prompts.append(instruction_format(inst, sentence))
        bar()
print("Input creation finished")

# Single batched call to the text-generation pipeline; res is aligned
# index-for-index with `loaded`.
res = generate_text(prompts)

# zip replaces the manual i = 0 / i += 1 index bookkeeping of the original.
for sentence, generation in zip(loaded, res):
    output[sentence] = generation[0]["generated_text"]

with open('../../results/S4/MIXTRAL_raw_answers.json', 'w', encoding='utf-8') as file:
    json.dump(output, file) # in 44:36.6 (0.08/s)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment