diff --git a/callbacks.py b/callbacks.py index 8213b6102a7d722b53a0de519799f34c502c1418..18d5e4911eb33f625892413868794c6c6d15d8bf 100644 --- a/callbacks.py +++ b/callbacks.py @@ -12,6 +12,7 @@ from utils import extract_data from pages.application.RandomForest.utils import xrf from pages.application.RandomForest.utils.xrf import * + sys.modules['xrf'] = xrf from sklearn.ensemble._voting import VotingClassifier @@ -255,7 +256,7 @@ def register_callbacks(app): try: if ml_type is None: return warning_selection_model, None, {}, {} - elif pretrained_model is None : + elif pretrained_model is None: return warning_selection_pretrained_model, None, {}, {} else: pretrained_model = parse_contents_graph(pretrained_model, model_filename) @@ -337,11 +338,11 @@ def register_callbacks(app): if ml_type == "RandomForest": pretrained_model = parse_contents_graph(pretrained_model, model_filename) if isinstance(pretrained_model, xrf.rndmforest.RF2001): - return int(pretrained_model.forest.n_estimators) + return int(pretrained_model.forest.n_estimators) - 1 elif isinstance(pretrained_model, RandomForestClassifier): - return pretrained_model.n_estimators + return pretrained_model.n_estimators - 1 elif isinstance(pretrained_model, VotingClassifier): - return len(pretrained_model.estimators) + return len(pretrained_model.estimators) - 1 else: return 0