From b291997d4b7d34f8f39bca1571cd1d1fa0b14de1 Mon Sep 17 00:00:00 2001
From: Guilherme Henrique <guihss.cs@gmail.com>
Date: Thu, 21 Sep 2023 15:19:22 +0200
Subject: [PATCH] improve code readability

---
 nlp.py               |  58 ++--
 property_matching.py | 696 ++++++++++++++++++++++++++-----------------
 utils.py             |  93 ++++--
 3 files changed, 507 insertions(+), 340 deletions(-)

diff --git a/nlp.py b/nlp.py
index e118e69..087c1f4 100644
--- a/nlp.py
+++ b/nlp.py
@@ -2,44 +2,36 @@ import nltk
 
 nltk.download('averaged_perceptron_tagger')
 
-def get_core_concept(e1):
-    t1 = nltk.pos_tag(e1)
-    v1 = []
-    sn = False
-    for t in t1:
-        if 'V' in t[1] and len(t[0]) > 4:
 
-            v1.append(t[0])
+def get_core_concept(tokens):
+    """
+    Get the core concept of an entity label: the first verb longer than four characters or, failing that, the
+    label's nouns together with any adjectives that appear before its first preposition.
+    :param tokens: list of word tokens from an entity label
+    :return: list of words
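+
+    Example (illustrative; the exact result depends on NLTK's tagger):
+        >>> get_core_concept(['has', 'birth', 'date'])
+        ['birth', 'date']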
+    """
+    tags = nltk.pos_tag(tokens)
+    core_concept = []
+    seen_preposition = False
+    for (word, tag) in tags:
+        if 'V' in tag and len(word) > 4:
+            core_concept.append(word)
             break
 
-        if 'N' in t[1] or 'J' in t[1] and not sn:
-            if 'IN' in t[1]:
-                sn = True
+        if 'N' in tag or 'J' in tag and not seen_preposition:
+            if 'IN' in tag:
+                seen_preposition = True
             else:
-                v1.append(t[0])
+                core_concept.append(word)
 
-    return v1
+    return core_concept
 
 
-def get_core_tagged(e1):
-    t1 = nltk.pos_tag(e1)
-    v1 = []
-    sn = False
-    for t in t1:
-        if 'V' in t[1] and len(t[0]) > 4:
-
-            v1.append(t)
-            break
-
-        if 'N' in t[1] or 'J' in t[1] and not sn:
-            if 'IN' in t[1]:
-                sn = True
-            else:
-                v1.append(t)
-
-    return v1
-
-
-def filter_jj(words):
+def filter_adjectives(words):
+    """
+    Keep only the nouns from a list of words, dropping adjectives and other parts of speech.
+    :param words: list of words
+    :return: list containing only the nouns
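+
+    Example (illustrative; the exact result depends on NLTK's tagger):
+        >>> filter_adjectives(['red', 'wine'])
+        ['wine']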
+    """
     tags = nltk.pos_tag(words)
-    return list(map(lambda x: x[0], filter(lambda x: x[1][0] == 'N', tags)))
\ No newline at end of file
+    return [word for (word, tag) in tags if tag.startswith('N')]
diff --git a/property_matching.py b/property_matching.py
index f3bf11f..ecda9f5 100644
--- a/property_matching.py
+++ b/property_matching.py
@@ -1,6 +1,5 @@
 import numpy as np
-from nlp import filter_jj, get_core_concept
-from transformers import AutoTokenizer
+from nlp import filter_adjectives, get_core_concept
 import torch.nn as nn
 from om.match import onts, aligns
 from om.ont import get_n, tokenize
@@ -12,290 +11,410 @@ from py_stringmatching import SoftTfIdf, JaroWinkler
 from utils import metrics
 from collections import Counter
 from tqdm.auto import tqdm
-import math
 
 
-def get_type_h(e, g, ml=1):
-    if type(e) is Literal:
-        return [e.datatype]
+def get_type_hierarchy(entity, graph, max_depth=1):
+    """
+    Get the type hierarchy of an entity in a graph up to a certain depth: the datatype for literals, otherwise the
+    entity's DCTERMS.subject followed upwards along SKOS.broader links.
+    :param entity: RDF entity
+    :param graph: RDFLib graph
+    :param max_depth: maximum number of SKOS.broader links to follow
+    :return: list of types
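+
+    Example (illustrative; a literal's hierarchy is just its datatype):
+        >>> get_type_hierarchy(Literal(5), Graph())
+        [rdflib.term.URIRef('http://www.w3.org/2001/XMLSchema#integer')]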
+    """
+    if type(entity) is Literal:
+        return [entity.datatype]
 
-    tp = g.value(e, DCTERMS.subject)
+    entity_parent = graph.value(entity, DCTERMS.subject)
 
-    if tp is None:
-        return [e]
+    if entity_parent is None:
+        return [entity]
 
-    h = [tp]
+    hierarchy = [entity_parent]
 
-    for _ in range(ml):
-        if g.value(tp, SKOS.broader) is None:
+    for _ in range(max_depth):
+        if graph.value(entity_parent, SKOS.broader) is None:
             break
 
-        tp = g.value(tp, SKOS.broader)
-        h.append(tp)
+        entity_parent = graph.value(entity_parent, SKOS.broader)
+        hierarchy.append(entity_parent)
 
-    return h
+    return hierarchy
 
 
-def most_common_dr_hist(g, ml=1, mh=5):
-    props = set()
-    for s, p, o in g.triples((None, RDF.type, RDF.Property)):
-        props.add(s)
+def most_common_domain_range_pair(graph, max_depth=1, most_common_count=5):
+    """
+    Build a new graph where each property gets, as domain and range, synthetic URIs joining its most common subject
+    and object types.
+    :param graph: RDFLib graph
+    :param max_depth: maximum depth of the type hierarchy to consider
+    :param most_common_count: number of most common types joined into the synthetic URI
+    :return: new RDFLib graph with one domain and one range triple per property
+    """
+    properties = set()
+    for s, p, o in graph.triples((None, RDF.type, RDF.Property)):
+        properties.add(s)
 
-    ng = Graph()
+    new_graph = Graph()
 
-    pc = {}
-    for prop in props:
+    pair_counter = {}
+    for prop in properties:
+
+        for s, p, o in graph.triples((None, prop, None)):
+            subject_type = get_type_hierarchy(s, graph, max_depth=max_depth)
+            object_type = get_type_hierarchy(o, graph, max_depth=max_depth)
+
+            if p not in pair_counter:
+                pair_counter[p] = {'domain': Counter(), 'range': Counter()}
 
-        for s, p, o in g.triples((None, prop, None)):
-            st = get_type_h(s, g, ml=ml)
-            ot = get_type_h(o, g, ml=ml)
+            for s in subject_type:
+                pair_counter[p]['domain'][s] += 1
 
-            if p not in pc:
-                pc[p] = {'domain': Counter(), 'range': Counter()}
+            for o in object_type:
+                pair_counter[p]['range'][o] += 1
 
-            for s in st:
-                pc[p]['domain'][s] += 1
+    for prop in pair_counter:
+        counts = pair_counter[prop]
+        domain = counts['domain'].most_common(most_common_count)
+        rng = counts['range'].most_common(most_common_count)
 
-            for o in ot:
-                pc[p]['range'][o] += 1
+        joined_domain = join_entities(domain)
+        joined_range = join_entities(rng)
 
-    for k in pc:
-        c = pc[k]
-        d = c['domain'].most_common(mh)
-        r = c['range'].most_common(mh)
+        new_graph.add((prop, RDFS.domain, URIRef(joined_domain)))
+        new_graph.add((prop, RDFS.range, URIRef(joined_range)))
 
-        jd = '_'.join([x[0].split('/')[-1].split('#')[-1].split(':')[-1] for x in d])
-        jr = '_'.join([x[0].split('/')[-1].split('#')[-1].split(':')[-1] for x in r])
+    new_graph.namespace_manager = graph.namespace_manager
+    return new_graph
 
-        ng.add((k, RDFS.domain, URIRef(jd)))
-        ng.add((k, RDFS.range, URIRef(jr)))
 
-    ng.namespace_manager = g.namespace_manager
-    return ng
+def join_entities(entities):
+    """
+    Join the local names of a list of (entity, count) pairs with underscores, e.g. [('http://ex.org/Person', 2),
+    ('http://ex.org/Place', 1)] becomes 'Person_Place'.
+    :param entities: list of (entity, count) tuples, as returned by Counter.most_common()
+    :return: underscore-separated local names
+    """
+    return '_'.join([x[0].split('/')[-1].split('#')[-1].split(':')[-1] for x in entities])
 
 
-def get_type(e, g):
-    if type(e) is Literal:
-        return e.datatype
+def get_type(entity, graph):
+    """
+    Get the type of an entity in a graph: the datatype for literals, otherwise the entity's DCTERMS.subject.
+    :param entity: RDF entity
+    :param graph: RDFLib graph
+    :return: the entity's type, or the entity itself when no type is found
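+
+    Example (illustrative):
+        >>> g = Graph()
+        >>> _ = g.add((URIRef('http://ex.org/e'), DCTERMS.subject, URIRef('http://ex.org/T')))
+        >>> get_type(URIRef('http://ex.org/e'), g)
+        rdflib.term.URIRef('http://ex.org/T')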
+    """
+    if type(entity) is Literal:
+        return entity.datatype
 
-    tp = g.value(e, DCTERMS.subject)
+    parent = graph.value(entity, DCTERMS.subject)
 
-    if tp is None:
-        return e
+    if parent is None:
+        return entity
 
-    return tp
+    return parent
 
 
-def most_common_pair(g):
+def most_common_pair(graph):
+    """
+    Build a new graph where each property's domain and range come from its most frequent (subject type, object type)
+    pair, counted jointly.
+    :param graph: RDFLib graph
+    :return: new RDFLib graph with one domain and one range triple per property
+    """
     props = set()
-    for s, p, o in g.triples((None, RDF.type, RDF.Property)):
+    for s, p, o in graph.triples((None, RDF.type, RDF.Property)):
         props.add(s)
 
-    ng = Graph()
+    new_graph = Graph()
 
-    pc = {}
+    pair_counter = {}
     for prop in props:
 
-        for s, p, o in g.triples((None, prop, None)):
-            st = get_type(s, g)
-            ot = get_type(o, g)
+        for s, p, o in graph.triples((None, prop, None)):
+            subject_type = get_type(s, graph)
+            object_type = get_type(o, graph)
 
-            if p not in pc:
-                pc[p] = Counter()
+            if p not in pair_counter:
+                pair_counter[p] = Counter()
 
-            pc[p][(st, ot)] += 1
+            pair_counter[p][(subject_type, object_type)] += 1
 
-    for k in pc:
-        c = pc[k]
-        d, r = c.most_common()[0][0]
-        ng.add((k, RDFS.domain, d))
-        ng.add((k, RDFS.range, r))
+    for prop in pair_counter:
+        counts = pair_counter[prop]
+        domain, rng = counts.most_common()[0][0]
+        new_graph.add((prop, RDFS.domain, domain))
+        new_graph.add((prop, RDFS.range, rng))
 
-    ng.namespace_manager = g.namespace_manager
-    return ng
+    new_graph.namespace_manager = graph.namespace_manager
+    return new_graph
 
 
-def most_common_dr(g):
+def most_common_dr(graph):
+    """
+    Build a new graph where each property gets its most frequent subject type as domain and its most frequent object
+    type as range, counted independently.
+    :param graph: RDFLib graph
+    :return: new RDFLib graph with one domain and one range triple per property
+    """
     props = set()
-    for s, p, o in g.triples((None, RDF.type, RDF.Property)):
+    for s, p, o in graph.triples((None, RDF.type, RDF.Property)):
         props.add(s)
 
-    ng = Graph()
+    new_graph = Graph()
 
-    pc = {}
+    pair_counter = {}
     for prop in props:
 
-        for s, p, o in g.triples((None, prop, None)):
-            st = get_type(s, g)
-            ot = get_type(o, g)
-
-            if p not in pc:
-                pc[p] = {'domain': Counter(), 'range': Counter()}
-
-            pc[p]['domain'][st] += 1
-            pc[p]['range'][ot] += 1
-
-    for k in pc:
-        c = pc[k]
-        d = c['domain'].most_common()[0][0]
-        r = c['range'].most_common()[0][0]
-
-        ng.add((k, RDFS.domain, d))
-        ng.add((k, RDFS.range, r))
-
-    ng.namespace_manager = g.namespace_manager
-    return ng
-
-
-def is_property(e, g):
-    return (e, RDFS.domain, None) in g and (e, RDFS.range, None) in g
-
-
-def get_entity_label_docs(a_entities, g1):
-    slist = []
-    for e in a_entities:
-        slist.append(list(map(str.lower, tokenize(get_n(e, g1)))))
-
-    return slist
-
-
-def flat_fr_chain(e, g):
-    if g.value(e, RDF.rest) == RDF.nil:
-        return [g.value(e, RDF.first)]
+        for s, p, o in graph.triples((None, prop, None)):
+            subject_type = get_type(s, graph)
+            object_type = get_type(o, graph)
+
+            if p not in pair_counter:
+                pair_counter[p] = {'domain': Counter(), 'range': Counter()}
+
+            pair_counter[p]['domain'][subject_type] += 1
+            pair_counter[p]['range'][object_type] += 1
+
+    for prop in pair_counter:
+        counts = pair_counter[prop]
+        domain = counts['domain'].most_common()[0][0]
+        rng = counts['range'].most_common()[0][0]
+
+        new_graph.add((prop, RDFS.domain, domain))
+        new_graph.add((prop, RDFS.range, rng))
+
+    new_graph.namespace_manager = graph.namespace_manager
+    return new_graph
+
+
+def is_property(entity, graph):
+    """
+    Check if an entity is a property in a graph, i.e. whether it has both a domain and a range.
+    :param entity: RDF entity
+    :param graph: RDFLib graph
+    :return: True if the entity has both an RDFS.domain and an RDFS.range
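+
+    Example (illustrative):
+        >>> g = Graph()
+        >>> p = URIRef('http://ex.org/p')
+        >>> _ = g.add((p, RDFS.domain, URIRef('http://ex.org/D')))
+        >>> _ = g.add((p, RDFS.range, URIRef('http://ex.org/R')))
+        >>> is_property(p, g)
+        True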
+    """
+    return (entity, RDFS.domain, None) in graph and (entity, RDFS.range, None) in graph
+
+
+def get_entity_label_docs(entities, graph):
+    """
+    Process the labels of a set of entities in a graph. The labels are tokenized and lowercased.
+    :param entities: iterable of RDF entities
+    :param graph: RDFLib graph
+    :return: list of token lists, one per entity
+    """
+    result = []
+    for entity in entities:
+        result.append(list(map(str.lower, tokenize(get_n(entity, graph)))))
+
+    return result
+
+
+def flat_rdf_list_chain(entity, graph):
+    """
+    Convert an RDF first/rest chain to a Python list.
+    :param entity: head node of the RDF list
+    :param graph: RDFLib graph
+    :return: list of the chain's members
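+
+    Example (illustrative):
+        >>> g = Graph()
+        >>> node = BNode()
+        >>> _ = g.add((node, RDF.first, Literal('a')))
+        >>> _ = g.add((node, RDF.rest, RDF.nil))
+        >>> flat_rdf_list_chain(node, g)
+        [rdflib.term.Literal('a')]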
+    """
+    if graph.value(entity, RDF.rest) == RDF.nil:
+        return [graph.value(entity, RDF.first)]
     else:
-        return [g.value(e, RDF.first)] + flat_fr_chain(g.value(e, RDF.rest), g)
-
-
-def get_cpe(e, g):
-    cp = list(set(g.predicates(e)).difference({RDF.type}))
-    objs = list(map(lambda x: get_n(x, g), cp + flat_fr_chain(g.value(e, cp[0]), g)))
+        return [graph.value(entity, RDF.first)] + flat_rdf_list_chain(graph.value(entity, RDF.rest), graph)
+
+
+def get_concatenated_predicate_entities(entity, graph):
+    """
+    Join the names of an entity's non-type predicates and of the members of its RDF list value (e.g. a unionOf
+    node) into one underscore-separated label.
+    :param entity: RDF entity (typically a blank node)
+    :param graph: RDFLib graph
+    :return: the joined label and the number of names joined
+    """
+    not_type_predicates = list(set(graph.predicates(entity)).difference({RDF.type}))
+    nodes = not_type_predicates + flat_rdf_list_chain(graph.value(entity, not_type_predicates[0]), graph)
+    objs = list(map(lambda x: get_n(x, graph), nodes))
     return '_'.join(objs), len(objs)
 
 
-def is_joinable(e, g):
-    preds = set(g.predicates(e)).difference({RDF.type})
+def is_joinable(entity, graph):
+    """
+    Check if an entity is joinable in a graph. An entity is joinable if it has exactly one non-type predicate and
+    that predicate is owl:unionOf.
+    :param entity: RDF entity
+    :param graph: RDFLib graph
+    :return: True if the entity is joinable
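+
+    Example (illustrative):
+        >>> g = Graph()
+        >>> node = BNode()
+        >>> _ = g.add((node, OWL.unionOf, BNode()))
+        >>> is_joinable(node, g)
+        True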
+    """
+    preds = set(graph.predicates(entity)).difference({RDF.type})
     return len(preds) == 1 and OWL.unionOf in preds
 
 
-def flat_restriction(e, g):
+def flat_restriction(entity, graph):
+    """
+    Flatten a restriction in a graph, recursively collecting its predicates and objects while skipping rdf:type.
+    :param entity: restriction node
+    :param graph: RDFLib graph
+    :return: flat list of predicate and object nodes
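+
+    Example (illustrative):
+        >>> g = Graph()
+        >>> r = BNode()
+        >>> _ = g.add((r, RDF.type, OWL.Restriction))
+        >>> _ = g.add((r, OWL.onProperty, URIRef('http://ex.org/p')))
+        >>> flat_restriction(r, g)
+        [rdflib.term.URIRef('http://www.w3.org/2002/07/owl#onProperty'), rdflib.term.URIRef('http://ex.org/p')]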
+    """
     nodes = []
-    for s, p, o in g.triples((e, None, None)):
+    for s, p, o in graph.triples((entity, None, None)):
         if p == RDF.type:
             continue
         if type(o) is BNode:
-            nodes.extend(flat_restriction(o, g))
+            nodes.extend(flat_restriction(o, graph))
         else:
             nodes.extend([p, o])
 
     return nodes
 
 
-def is_restriction(e, g):
-    return g.value(e, RDF.type) == OWL.Restriction
+def is_restriction(entity, graph):
+    """
+    Check if an entity is a restriction in a graph.
+    :param entity: RDF entity
+    :param graph: RDFLib graph
+    :return: True if the entity's rdf:type is owl:Restriction
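+
+    Example (illustrative):
+        >>> g = Graph()
+        >>> node = BNode()
+        >>> _ = g.add((node, RDF.type, OWL.Restriction))
+        >>> is_restriction(node, g)
+        True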
+    """
+    return graph.value(entity, RDF.type) == OWL.Restriction
+
+
+def join_nodes(nodes, graph):
+    """
+    Join the names of a list of nodes in a graph with underscores.
+    :param nodes: list of RDF nodes
+    :param graph: RDFLib graph
+    :return: underscore-separated node names
+    """
+    return '_'.join(get_n(node, graph) for node in nodes)
+
+
+def get_gen_docs(graph):
+    """
+    Build one document per subject in the graph: its tokenized, lowercased label, extended with its domain and range
+    tokens when the subject is a property.
+    :param graph: RDFLib graph
+    :return: list of documents (strings of space-separated tokens)
+    """
+    out = []
+    for subject in set(graph.subjects()):
 
+        if type(subject) is BNode:
+            if is_joinable(subject, graph):
+                label, _ = get_concatenated_predicate_entities(subject, graph)
+            elif is_restriction(subject, graph):
+                label = join_nodes(flat_restriction(subject, graph), graph)
+            else:
+                label = get_n(subject, graph)
 
-def join_nodes(nodes, g):
-    return '_'.join(list(map(lambda x: get_n(x, g), nodes)))
+        else:
+            label = get_n(subject, graph)
 
+        tokens = list(map(str.lower, tokenize(label)))
 
-def get_gen_docs(g1):
-    out = []
-    for e in set(g1.subjects()):
+        if is_property(subject, graph):
 
-        if type(e) is BNode:
-            if is_joinable(e, g1):
-                label, _ = get_cpe(e, g1)
-            elif is_restriction(e, g1):
-                label = join_nodes(flat_restriction(e, g1), g1)
-            else:
-                label = get_n(e, g1)
+            domain_sentence = get_predicate_sentence(subject, RDFS.domain, graph)
 
+            range_sentence = get_predicate_sentence(subject, RDFS.range, graph)
+
+            out.append(' '.join(tokens + range_sentence + domain_sentence))
         else:
-            label = get_n(e, g1)
-
-        ns = list(map(str.lower, tokenize(label)))
-
-        if is_property(e, g1):
-
-            ds = []
-            if (e, RDFS.domain, None) in g1:
-                domain = g1.value(e, RDFS.domain)
-                if type(domain) is BNode and is_joinable(domain, g1):
-                    dn, _ = get_cpe(domain, g1)
-                else:
-                    dn = get_n(domain, g1)
-                ds = list(map(str.lower, tokenize(dn)))
-
-            rs = []
-            if (e, RDFS.range, None) in g1:
-                rg = g1.value(e, RDFS.range)
-                if type(rg) is BNode and is_joinable(rg, g1):
-                    rn, _ = get_cpe(rg, g1)
-                else:
-                    rn = get_n(rg, g1)
-                rs = list(map(str.lower, tokenize(rn)))
-
-            out.append(' '.join(ns + rs + ds))
-        else:
-            out.append(' '.join(ns))
+            out.append(' '.join(tokens))
 
     return out
 
 
-def cosine_similarity(v1, v2):
-    if np.linalg.norm(v1) * np.linalg.norm(v2) == 0:
+def get_predicate_sentence(subject, predicate, graph):
+    """
+    Tokenize the name of the object of (subject, predicate) in a graph, resolving joinable blank nodes.
+    :param subject: RDF subject
+    :param predicate: RDF predicate, e.g. RDFS.domain or RDFS.range
+    :param graph: RDFLib graph
+    :return: list of lowercased tokens, empty when the triple is absent
+    """
+    sentence = []
+    if (subject, predicate, None) in graph:
+        obj = graph.value(subject, predicate)
+        if type(obj) is BNode and is_joinable(obj, graph):
+            name, _ = get_concatenated_predicate_entities(obj, graph)
+        else:
+            name = get_n(obj, graph)
+        sentence = list(map(str.lower, tokenize(name)))
+
+    return sentence
+
+
+def cosine_similarity(vector1, vector2):
+    """
+    Compute the cosine similarity between two vectors.
+    :param vector1: first vector
+    :param vector2: second vector
+    :return: cosine similarity, or 0 when either vector has zero norm
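+
+    Example (illustrative):
+        >>> float(cosine_similarity(np.array([1.0, 0.0]), np.array([0.0, 1.0])))
+        0.0
+        >>> float(cosine_similarity(np.array([1.0, 0.0]), np.array([1.0, 0.0])))
+        1.0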
+    """
+    if np.linalg.norm(vector1) * np.linalg.norm(vector2) == 0:
         return 0
-    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
+    return np.dot(vector1, vector2) / (np.linalg.norm(vector1) * np.linalg.norm(vector2))
 
 
-def get_document_similarity(domain_a, domain_b, m):
+def get_document_similarity(domain_a, domain_b, models):
+    """
+    Compute the TF-IDF cosine similarity between two token lists, in both directions.
+    :param domain_a: first list of tokens
+    :param domain_b: second list of tokens
+    :param models: (soft_metric, general_metric) tuple; only the general TF-IDF model is used here
+    :return: pair of similarity scores, 0 when either list is empty
+    """
     if len(domain_a) <= 0 or len(domain_b) <= 0:
 
         domain_conf_a = 0
         domain_conf_b = 0
     else:
-        v = m[1].transform([' '.join(domain_a), ' '.join(domain_b)])
-        v = v.toarray()
-        domain_conf_a = cosine_similarity(v[0], v[1])
-        v = m[1].transform([' '.join(domain_b), ' '.join(domain_a)])
-        v = v.toarray()
-        domain_conf_b = cosine_similarity(v[0], v[1])
+        vector = models[1].transform([' '.join(domain_a), ' '.join(domain_b)])
+        vector = vector.toarray()
+        domain_conf_a = cosine_similarity(vector[0], vector[1])
+        vector = models[1].transform([' '.join(domain_b), ' '.join(domain_a)])
+        vector = vector.toarray()
+        domain_conf_b = cosine_similarity(vector[0], vector[1])
 
     return domain_conf_a, domain_conf_b
 
 
-def get_prop(e, g, p):
-    s = []
-    objs = list(g.objects(e, p))
-    objc = len(objs)
-    for d in objs:
-        if type(d) is BNode:
-            if is_joinable(d, g):
-                name, oc = get_cpe(d, g)
-                objc += oc - 1
-            elif is_restriction(d, g):
-                name = join_nodes(flat_restriction(d, g), g)
+def get_property_sentence(entity, graph, predicate):
+    """
+    Tokenize the names of the objects of (entity, predicate) in a graph, resolving joinable and restriction blank
+    nodes, and count the objects.
+    :param entity: RDF entity
+    :param graph: RDFLib graph
+    :param predicate: RDF predicate, e.g. RDFS.domain or RDFS.range
+    :return: list of lowercased tokens and the number of objects
+    """
+    sentence = []
+    objects = list(graph.objects(entity, predicate))
+    object_count = len(objects)
+    for obj in objects:
+        if type(obj) is BNode:
+            if is_joinable(obj, graph):
+                name, joined_count = get_concatenated_predicate_entities(obj, graph)
+                object_count += joined_count - 1
+            elif is_restriction(obj, graph):
+                name = join_nodes(flat_restriction(obj, graph), graph)
             else:
-                name = get_n(d, g)
+                name = get_n(obj, graph)
         else:
-            name = get_n(d, g)
-        s.extend(map(str.lower, tokenize(name)))
+            name = get_n(obj, graph)
+        sentence.extend(map(str.lower, tokenize(name)))
 
-    return s, objc
+    return sentence, object_count
 
 
-def build_tf_models(o1, o2):
-    a_entities = set(filter(lambda x: is_property(x, o1), o1.subjects()))
-    b_entities = set(filter(lambda x: is_property(x, o2), o2.subjects()))
+def build_tf_models(ontology1, ontology2):
+    """
+    Build the similarity models: a soft TF-IDF over the property labels and a general TF-IDF over all entity
+    documents of both ontologies.
+    :param ontology1: first RDFLib graph
+    :param ontology2: second RDFLib graph
+    :return: (soft_metric, general_metric) tuple
+    """
+    properties1 = set(filter(lambda x: is_property(x, ontology1), ontology1.subjects()))
+    properties2 = set(filter(lambda x: is_property(x, ontology2), ontology2.subjects()))
 
-    slist = get_entity_label_docs(a_entities, o1) + get_entity_label_docs(b_entities, o2)
-    soft_metric = SoftTfIdf(slist, sim_func=JaroWinkler().get_raw_score, threshold=0.8)
+    sentences_list = get_entity_label_docs(properties1, ontology1) + get_entity_label_docs(properties2, ontology2)
+    soft_metric = SoftTfIdf(sentences_list, sim_func=JaroWinkler().get_raw_score, threshold=0.8)
 
-    qlist = get_gen_docs(o1) + get_gen_docs(o2)
+    document_list = get_gen_docs(ontology1) + get_gen_docs(ontology2)
 
     general_metric = TfidfVectorizer()
 
-    general_metric.fit(qlist)
+    general_metric.fit(document_list)
 
     return soft_metric, general_metric
 
@@ -307,30 +426,43 @@ class PropertyMatcher:
         self.class_model = class_model
         self.sentence_model = sentence_model
 
-    def match_property(self, e1, e2, g1, g2, m, ds, sim_weights=None, disable_dr=False):
-
-        exact_label_a = list(map(str.lower, tokenize(get_n(e1, g1))))
-
-        domain_a, dca = get_prop(e1, g1, RDFS.domain)
-        range_a, rca = get_prop(e1, g1, RDFS.range)
+    def match_property(self, entity1, entity2, graph1, graph2, models, confidence_map, similarity_weights=None,
+                       disable_domain_range_similarity=False):
+        """
+        Match two properties in two graphs. The matching is done by comparing the labels, domains and ranges of the
+        properties.
+        :param entity1: property from the first graph
+        :param entity2: property from the second graph
+        :param graph1: first RDFLib graph
+        :param graph2: second RDFLib graph
+        :param models: (soft_metric, general_metric) tuple from build_tf_models
+        :param confidence_map: extra confidence for domain/range pairs of already matched properties
+        :param similarity_weights: indices of the confidences to keep (0: domain, 1: label, 2: range)
+        :param disable_domain_range_similarity: if True, ignore domain and range similarity
+        :return: confidence of the match
+        """
+        exact_label_a = list(map(str.lower, tokenize(get_n(entity1, graph1))))
+
+        domain_a, dca = get_property_sentence(entity1, graph1, RDFS.domain)
+        range_a, rca = get_property_sentence(entity1, graph1, RDFS.range)
 
         if len(range_a) == 1 and exact_label_a[-1] == range_a[0]:
             exact_label_a.pop(-1)
 
         string_a = get_core_concept(exact_label_a)
 
-        exact_label_b = list(map(str.lower, tokenize(get_n(e2, g2))))
+        exact_label_b = list(map(str.lower, tokenize(get_n(entity2, graph2))))
 
-        domain_b, dcb = get_prop(e2, g2, RDFS.domain)
-        range_b, rcb = get_prop(e2, g2, RDFS.range)
+        domain_b, dcb = get_property_sentence(entity2, graph2, RDFS.domain)
+        range_b, rcb = get_property_sentence(entity2, graph2, RDFS.range)
 
         if len(range_b) == 1 and exact_label_b[-1] == range_b[0]:
             exact_label_b.pop(-1)
 
         string_b = get_core_concept(exact_label_b)
 
-        range_a = filter_jj(range_a)
-        range_b = filter_jj(range_b)
+        range_a = filter_adjectives(range_a)
+        range_b = filter_adjectives(range_b)
 
         if exact_label_a == exact_label_b:
             label_conf_a = 1
@@ -340,11 +472,11 @@ class PropertyMatcher:
             label_conf_a = 0
             label_conf_b = 0
         else:
-            label_conf_a = m[0].get_raw_score(string_a, string_b)
-            label_conf_b = m[0].get_raw_score(string_b, string_a)
+            label_conf_a = models[0].get_raw_score(string_a, string_b)
+            label_conf_b = models[0].get_raw_score(string_b, string_a)
 
-        domain_conf_a, domain_conf_b = get_document_similarity(domain_a, domain_b, m)
-        range_conf_a, range_conf_b = get_document_similarity(range_a, range_b, m)
+        domain_conf_a, domain_conf_b = get_document_similarity(domain_a, domain_b, models)
+        range_conf_a, range_conf_b = get_document_similarity(range_a, range_b, models)
 
         label_confidence = (label_conf_a + label_conf_b) / 2
         domain_confidence = (domain_conf_a + domain_conf_b) / 2
@@ -354,51 +486,63 @@ class PropertyMatcher:
             if len(domain_a) == 1 and len(domain_b) == 1:
                 domain_confidence = self.class_model.sim(domain_a[0], domain_b[0])
 
-        dsp = (g1.value(e1, RDFS.domain), g2.value(e2, RDFS.domain))
-        if dsp in ds:
-            domain_confidence += ds[dsp]
+        domain_pair = (graph1.value(entity1, RDFS.domain), graph2.value(entity2, RDFS.domain))
+        if domain_pair in confidence_map:
+            domain_confidence += confidence_map[domain_pair]
 
-        rsp = (g1.value(e1, RDFS.range), g2.value(e2, RDFS.range))
-        if rsp in ds:
-            range_confidence += ds[rsp]
+        range_pair = (graph1.value(entity1, RDFS.range), graph2.value(entity2, RDFS.range))
+        if range_pair in confidence_map:
+            range_confidence += confidence_map[range_pair]
 
-        if disable_dr:
+        if disable_domain_range_similarity:
             domain_confidence = 0
             range_confidence = 0
-            sim_weights = [1]
+            similarity_weights = [1]
 
         if domain_confidence > 0.95 and range_confidence > 0.95 and label_confidence < 0.1:
             if len(string_a) <= 1 and len(string_b) <= 1:
-                sr = [' '.join(domain_a + list(map(str.lower, tokenize(get_n(e1, g1)))) + range_a)]
-                tg = [' '.join(domain_b + list(map(str.lower, tokenize(get_n(e2, g2)))) + range_b)]
-                e1 = self.sentence_model.encode(sr, convert_to_tensor=True)
-                e2 = self.sentence_model.encode(tg, convert_to_tensor=True)
-                sim = nn.functional.cosine_similarity(e1, e2).item()
+                source = [' '.join(domain_a + list(map(str.lower, tokenize(get_n(entity1, graph1)))) + range_a)]
+                target = [' '.join(domain_b + list(map(str.lower, tokenize(get_n(entity2, graph2)))) + range_b)]
+                embedding1 = self.sentence_model.encode(source, convert_to_tensor=True)
+                embedding2 = self.sentence_model.encode(target, convert_to_tensor=True)
+                sim = nn.functional.cosine_similarity(embedding1, embedding2).item()
                 if sim < 0.8:
                     sim = 0
                 label_confidence = sim
 
-        if sim_weights:
+        if similarity_weights:
             conf = []
-            if 0 in sim_weights:
+            if 0 in similarity_weights:
                 conf.append(domain_confidence)
-            if 1 in sim_weights:
+            if 1 in similarity_weights:
                 conf.append(label_confidence)
 
-            if 2 in sim_weights:
+            if 2 in similarity_weights:
                 conf.append(range_confidence)
         else:
             conf = [label_confidence, domain_confidence, range_confidence]
         return min(conf)
 
-    def match(self, base, ref, th=0.65, process_strategy=None, sim_weights=None, steps=2, disable_dr=False, tr=None):
+    def match(self, base, ref, threshold=0.65, process_strategy=None, sim_weights=None, steps=2, disable_dr=False,
+              start_metrics=None):
+        """
+        Match ontologies in a folder according to a reference alignment.
+        :param base: Path to the ontologies.
+        :param ref: Path to the reference alignments.
+        :param threshold: minimum confidence to accept a match
+        :param process_strategy: optional strategy used to pre-process the ontologies
+        :param sim_weights: indices of the confidences to keep (0: domain, 1: label, 2: range)
+        :param steps: number of matching passes
+        :param disable_dr: if True, ignore domain and range similarity
+        :param start_metrics: optional list of extra thresholds to also report metrics for
+        :return: (precision, recall, f-measure), or a list of them when start_metrics is given
+        """
         correct = 0
         pred = 0
         total = 0
         iterations = 0
 
-        if tr is not None:
-            trm = [[0, 0] for _ in tr]
+        if start_metrics is not None:
+            total_metrics = [[0, 0] for _ in start_metrics]
 
         for r, k1, k2 in tqdm(list(onts(base, ref))):
 
@@ -425,18 +569,11 @@ class PropertyMatcher:
                     current_total += 1
                     pa.add((a1, a2))
 
-                    # d1 = o1.value(a1, RDFS.domain)
-                    # d2 = o2.value(a2, RDFS.domain)
-                    #
-                    # r1 = o1.value(a1, RDFS.range)
-                    # r2 = o2.value(a2, RDFS.range)
-                    #
-                    # print(colored('#', 'blue'), get_n(d1, o1), get_n(a1, o1), get_n(r1, o1), colored('<>', 'green'),
-                    #       get_n(d2, o2), get_n(a2, o2), get_n(r2, o2))
+
             print(current_total)
             a_entities = set(filter(lambda x: is_property(x, o1), o1.subjects()))
             b_entities = set(filter(lambda x: is_property(x, o2), o2.subjects()))
-            p, it = self.match_ontologies(o1, o2, th, sim_weights=sim_weights, steps=steps, disable_dr=disable_dr)
+            p, it = self.match_ontologies(o1, o2, threshold, sim_weights=sim_weights, steps=steps, disable_dr=disable_dr)
             iterations += it
             oi = it
             current_pred = len(p)
@@ -444,51 +581,49 @@ class PropertyMatcher:
             pred += len(p)
             correct += len(pa.intersection(set(p.keys())))
 
-            if tr is not None:
-                for i, t in enumerate(tr):
+            if start_metrics is not None:
+                for i, t in enumerate(start_metrics):
                     cp = set()
                     for pair, sim in p.items():
                         if sim >= t:
                             cp.add(pair)
 
-                    trm[i][0] += len(pa.intersection(cp))
-                    trm[i][1] += len(cp)
+                    total_metrics[i][0] += len(pa.intersection(cp))
+                    total_metrics[i][1] += len(cp)
                     print(
                         f'ontology iterations: {oi}, {metrics(len(pa.intersection(cp)), len(cp), current_total)}, aligns: {current_total}, po1: {len(a_entities)}, po2: {len(b_entities)}')
 
-            # for a1, a2 in pa.intersection(p):
-            #     print(colored('✓', 'green'), get_n(a1, o1), get_n(a2, o2))
-            #
-            # for a1, a2 in p.difference(pa):
-            #     d1 = o1.value(a1, RDFS.domain)
-            #     d2 = o2.value(a2, RDFS.domain)
-            #
-            #     r1 = o1.value(a1, RDFS.range)
-            #     r2 = o2.value(a2, RDFS.range)
-            #     print(colored('X', 'red'), get_n(d1, o1), get_n(a1, o1), get_n(r1, o1), colored('<>', 'green'),
-            #           get_n(d2, o2), get_n(a2, o2), get_n(r2, o2))
 
             print(
                 f'ontology iterations: {oi}, {metrics(current_correct, current_pred, current_total)}, aligns: {current_total}, po1: {len(a_entities)}, po2: {len(b_entities)}')
 
         print(f'iterations: {iterations}, {metrics(correct, pred, total)}')
-        if tr is not None:
+        if start_metrics is not None:
             res = []
-            for q, w in trm:
+            for q, w in total_metrics:
                 res.append(metrics(q, w, total))
 
             return res
 
         return metrics(correct, pred, total)
 
-    def match_ontologies(self, o1, o2, th, sim_weights=None, steps=2, disable_dr=False):
-
+    def match_ontologies(self, o1, o2, threshold, sim_weights=None, steps=2, disable_dr=False):
+        """
+        Match two ontologies.
+        :param o1: first RDFLib graph
+        :param o2: second RDFLib graph
+        :param threshold: minimum confidence to accept a match
+        :param sim_weights: indices of the confidences to keep (0: domain, 1: label, 2: range)
+        :param steps: number of matching passes
+        :param disable_dr: if True, ignore domain and range similarity
+        :return: (final_alignment, iterations)
+        """
         soft_metric, general_metric = build_tf_models(o1, o2)
-        p = {}
+        final_alignment = {}
 
-        ds = {}
+        confidence_map = {}
 
-        pm = {}
+        property_map = {}
 
         iterations = 0
         for step in range(steps):
@@ -499,43 +634,44 @@ class PropertyMatcher:
                     if not is_property(e2, o2):
                         continue
 
-                    sim = self.match_property(e1, e2, o1, o2, (soft_metric, general_metric), ds,
-                                              sim_weights=sim_weights, disable_dr=disable_dr)
+                    sim = self.match_property(e1, e2, o1, o2, (soft_metric, general_metric), confidence_map,
+                                              similarity_weights=sim_weights,
+                                              disable_domain_range_similarity=disable_dr)
 
                     iterations += 1
-                    if sim <= th:
+                    if sim <= threshold:
                         continue
 
-                    if e1 in pm:
-                        if pm[e1][1] >= sim:
+                    if e1 in property_map:
+                        if property_map[e1][1] >= sim:
                             continue
-                        elif pm[e1][1] < sim:
-                            p.pop((e1, pm[e1][0]))
-                            pm.pop(pm[e1][0])
-                            pm.pop(e1)
+                        elif property_map[e1][1] < sim:
+                            final_alignment.pop((e1, property_map[e1][0]))
+                            property_map.pop(property_map[e1][0])
+                            property_map.pop(e1)
 
-                    if e2 in pm:
-                        if pm[e2][1] >= sim:
+                    if e2 in property_map:
+                        if property_map[e2][1] >= sim:
                             continue
-                        elif pm[e2][1] < sim:
-                            p.pop((pm[e2][0], e2))
-                            pm.pop(pm[e2][0])
-                            pm.pop(e2)
+                        elif property_map[e2][1] < sim:
+                            final_alignment.pop((property_map[e2][0], e2))
+                            property_map.pop(property_map[e2][0])
+                            property_map.pop(e2)
 
                     d1 = o1.value(e1, RDFS.domain)
                     d2 = o2.value(e2, RDFS.domain)
-                    ds[(d1, d2)] = 0.66
-                    p[(e1, e2)] = sim
-                    pm[e1] = (e2, sim)
-                    pm[e2] = (e1, sim)
+                    confidence_map[(d1, d2)] = 0.66
+                    final_alignment[(e1, e2)] = sim
+                    property_map[e1] = (e2, sim)
+                    property_map[e2] = (e1, sim)
                     if (e1, OWL.inverseOf, None) in o1 and (e2, OWL.inverseOf, None) in o2:
                         d1 = o1.value(o1.value(e1, OWL.inverseOf), RDFS.domain)
                         d2 = o2.value(o2.value(e2, OWL.inverseOf), RDFS.domain)
 
-                        ds[(d1, d2)] = 0.66
+                        confidence_map[(d1, d2)] = 0.66
                         iv1, iv2 = o1.value(e1, OWL.inverseOf), o2.value(e2, OWL.inverseOf)
-                        p[(iv1, iv2)] = sim
-                        pm[iv1] = (iv2, sim)
-                        pm[iv2] = (iv1, sim)
+                        final_alignment[(iv1, iv2)] = sim
+                        property_map[iv1] = (iv2, sim)
+                        property_map[iv2] = (iv1, sim)
 
-        return p, iterations
+        return final_alignment, iterations
diff --git a/utils.py b/utils.py
index f371be5..8c4446e 100644
--- a/utils.py
+++ b/utils.py
@@ -10,42 +10,81 @@ def metrics(correct, tries, total):
     return precision, recall, fm
 
 
-def gn(e, g):
-    if type(e) is str:
-        e = Literal(e)
-    ns = get_n(e, g)
+def get_name(entity, graph):
+    """
+    Get the display name of an entity, falling back to stripping the YAGO resource prefix when the name degenerates
+    to a '//...' form.
+    :param entity: RDF entity or string
+    :param graph: RDFLib graph
+    :return: name string
+    """
+    if type(entity) is str:
+        entity = Literal(entity)
+    name = get_n(entity, graph)
 
-    if ns.startswith('//'):
-        ns = e.split('http://yago-knowledge.org/resource/')[-1]
+    if name.startswith('//'):
+        name = entity.split('http://yago-knowledge.org/resource/')[-1]
 
-    return ns
+    return name
 
 
+def pad_encode(sentences, word_map):
+    """
+    Encodes a list of sentences into a padded tensor of integer values using a word mapping.
 
-def pad_encode(s, wm):
-    l1 = []
+    Example:
+        >>> word_map = {
+        ...     'I': 1,
+        ...     'love': 2,
+        ...     'coding': 3,
+        ...     'Python': 4,
+        ...     'great': 5,
+        ...     'fun': 6,
+        ...     'is': 7,
+        ... }
+        >>> sentences = ["I love coding Python", "Python is great", "Coding is fun"]
+        >>> encoded_sentences = pad_encode(sentences, word_map)
+        >>> print(encoded_sentences)
+        tensor([[1, 2, 3, 4],
+                [4, 7, 5, 0],
+                [3, 7, 6, 0]])
+
+    :param sentences: A list of input sentences to be encoded into tensors.
+    :param word_map: A dictionary mapping words to their corresponding integer representations.
+    :return: A tensor containing the padded and encoded sentences, where each sentence is represented
+        as a list of integers. The tensor has dimensions (num_sentences, max_sentence_length), where
+        num_sentences is the number of input sentences, and max_sentence_length is the length of the longest
+        sentence in terms of the number of words.
+    """
+    sentence_list = []
     max_len = -1
-    for q in s:
-        w = list(map(lambda q: wm[q], q.split()))
-        if len(w) > max_len:
-            max_len = len(w)
-        l1.append(w)
+    for sentence in sentences:
+        encoded = [word_map[word] for word in sentence.split()]
+        if len(encoded) > max_len:
+            max_len = len(encoded)
+        sentence_list.append(encoded)
+
+    padded_sentences = []
+    for encoded in sentence_list:
+        padded_sentences.append(encoded + [0] * (max_len - len(encoded)))
 
-    nl1 = []
-    for w in l1:
-        nl1.append(w + [0] * (max_len - len(w)))
+    return torch.LongTensor(padded_sentences)
 
-    return torch.LongTensor(nl1)
 
+def emb_average(sentence_ids, model):
+    """
+    Calculates the average word embedding for a batch of sentences, ignoring padding (id 0).
 
-def emb_average(ids, emb):
-    xe = torch.cat(list(map(lambda q: q.unsqueeze(0), ids)))
-    xem = emb(xe).sum(dim=1)
-    cf = torch.sum((xe != 0).float(), dim=1).unsqueeze(1)
-    cf[cf == 0] = 1
-    return xem / cf
+    :param sentence_ids: a list of equal-length 1-D tensors of token ids, where 0 is padding.
+    :param model: an embedding module (e.g. torch.nn.Embedding) that maps token ids to vectors.
+    :return: A tensor with the average embedding of the non-padding tokens of each sentence.
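+
+    Example (illustrative):
+        >>> embedding = torch.nn.Embedding(5, 4, padding_idx=0)
+        >>> batch = pad_encode(['1 2', '3'], {'1': 1, '2': 2, '3': 3})
+        >>> emb_average(list(batch), embedding).shape
+        torch.Size([2, 4])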
+    """
+    batch = torch.cat(list(map(lambda sentence: sentence.unsqueeze(0), sentence_ids)))
+    embedding_sum = model(batch).sum(dim=1)
+    token_counts = torch.sum((batch != 0).float(), dim=1).unsqueeze(1)
+    # avoid dividing by zero for sentences that are all padding
+    token_counts[token_counts == 0] = 1
+    return embedding_sum / token_counts
 
 
-def calc_acc(pred, cty):
-    acc = (torch.LongTensor(pred) == cty).float().sum() / cty.shape[0]
-    return acc.item()
\ No newline at end of file
+def calc_acc(predicted, correct):
+    """
+    Calculates the accuracy of a model's predictions.
+    :param predicted: A list of predicted labels.
+    :param correct: A tensor of correct labels.
+    :return: The accuracy as a float between 0 and 1.
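+
+    Example (illustrative):
+        >>> round(calc_acc([1, 0, 1], torch.LongTensor([1, 1, 1])), 2)
+        0.67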
+    """
+    acc = (torch.LongTensor(predicted) == correct).float().sum() / correct.shape[0]
+    return acc.item()
-- 
GitLab