node_classification_with_gnn
Node classification is the problem of predicting the correct label for a node based on its neighbors' labels and structural similarities.
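For intuition only, the following Python sketch shows the simplest way a node's neighbors can inform its label: a majority vote over the known labels of adjacent nodes. This is a classic baseline, not a GNN and not what this module does (a GNN learns from node features and graph structure). The function name `majority_vote_label` is hypothetical, and the edges and labels are taken from part of the example graph at the end of this page.

```python
from collections import Counter


def majority_vote_label(node, edges, labels):
    """Guess a label for `node` from the known labels of its neighbors.

    Edges are treated as undirected, so both citing and cited papers count.
    Returns None when no neighbor has a known label.
    """
    neighbors = [v for u, v in edges if u == node] + [u for u, v in edges if v == node]
    known = [labels[n] for n in neighbors if n in labels]
    if not known:
        return None
    # Most frequent label among labeled neighbors.
    return Counter(known).most_common(1)[0][0]


# CITES edges and labels from part of the example graph on this page.
edges = [(10, 11), (11, 12), (12, 13), (13, 10), (13, 14), (14, 15)]
labels = {10: 0, 11: 0, 12: 0, 15: 1}
print(majority_vote_label(13, edges, labels))  # 0
```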
About the query module
This query module contains all the functions you need to train a GNN model on Memgraph. The node_classification module supports the following:
- homogeneous and heterogeneous graphs
- multiple-label and multi-edge-type graphs
- any-size datasets
- the following model architectures:
  - Graph Attention with Jumping Knowledge
  - multiple versions of Graph Attention Networks (GAT)
  - GraphSAGE
- early stopping
- calculation of various metrics
- predictions for specified nodes
- model saving and loading
- recommendation system use cases
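GraphSAGE, and the mean aggregator that this module supports, rely on neighborhood aggregation. The pure-Python sketch below shows one GraphSAGE-style layer with a mean aggregator: each node concatenates its own feature vector with the mean of its neighbors' vectors, multiplies by a weight matrix, and applies a ReLU. It illustrates the technique only and is not the module's implementation; `mean_aggregate`, `sage_layer`, and the weights are hypothetical.

```python
def mean_aggregate(node, neighbors, features):
    """Element-wise mean of the feature vectors of `node`'s neighbors."""
    vecs = [features[u] for u in neighbors.get(node, [])]
    if not vecs:  # isolated node: aggregate to a zero vector
        return [0.0] * len(features[node])
    return [sum(col) / len(vecs) for col in zip(*vecs)]


def sage_layer(node, neighbors, features, weight):
    """One GraphSAGE-style update: ReLU(W @ concat(h_v, mean of neighbor h_u))."""
    h = features[node] + mean_aggregate(node, neighbors, features)  # list concatenation
    out = [sum(w * x for w, x in zip(row, h)) for row in weight]
    return [max(0.0, z) for z in out]  # ReLU


features = {10: [1.0, 2.0], 11: [3.0, 4.0], 13: [5.0, 0.0]}
neighbors = {10: [11, 13]}
weight = [[1, 0, 0, 0], [0, 0, 1, -1], [0, -1, 0, 0]]  # 3 outputs x 4 inputs
print(sage_layer(10, neighbors, features, weight))  # [1.0, 2.0, 0.0]
```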
The easiest way to test node_classification is to download Memgraph Platform and use some of the preloaded datasets in Memgraph Lab. If you want to explore our implementation, go to github/memgraph/mage and find python/node_classification.py. Feel free to give us a ⭐ if you like the code.
Feel free to open a GitHub issue or start a discussion on Discord if you want to speed up development.
Usage
Load a dataset into Memgraph, call set_model_parameters, and start training your model. When training is done, the query module saves the models.
Afterwards, you can test the model on other data (for example, data the model has not seen yet) and inspect the results.
The module reports the mean average precision for every batch training or evaluation epoch.
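Under its standard definition, mean average precision is the average precision (AP) of each class averaged over all classes. AP ranks nodes by their predicted probability for that class and averages precision@k at every rank k that holds a true member of the class. The Python sketch below implements that definition. It is not the module's code, and it is an assumption that the module computes the metric exactly this way; `average_precision`, `mean_average_precision`, and the sample probabilities are hypothetical.

```python
def average_precision(scores, is_positive):
    """AP for one class: mean of precision@k over the ranks k that hold a positive."""
    ranked = sorted(zip(scores, is_positive), key=lambda pair: pair[0], reverse=True)
    hits, precisions = 0, []
    for k, (_, positive) in enumerate(ranked, start=1):
        if positive:
            hits += 1
            precisions.append(hits / k)
    return sum(precisions) / len(precisions) if precisions else 0.0


def mean_average_precision(probs, labels):
    """probs[i][c] is the predicted probability that node i belongs to class c."""
    num_classes = len(probs[0])
    aps = [
        average_precision([p[c] for p in probs], [y == c for y in labels])
        for c in range(num_classes)
    ]
    return sum(aps) / num_classes


probs = [[0.9, 0.1], [0.6, 0.4], [0.3, 0.7], [0.2, 0.8]]
labels = [0, 1, 1, 0]
print(round(mean_average_precision(probs, labels), 4))  # 0.6667
```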
To summarize, the basic node classification workflow is as follows:
- load data into Memgraph
- set parameters by calling the set_model_parameters() function; make sure the node_features property is in place on the nodes
- call the train() function
- (optional) inspect the training results by calling the get_training_data() function
- optionally use save_model() and load_model()
- predict the node class by calling the predict() procedure
This MAGE module is still in an early stage. We intend it to be used only for exploring or learning about node classification. If you want it to be production-ready, open a GitHub issue or drop us a comment on Discord.
Procedures
If you want to execute this algorithm on graph projections, subgraphs or portions of the graph, be sure to check out the guide on How to run a MAGE module on subgraphs.
set_model_parameters(params)
The procedure initializes all global variables. You can change the global variables via the params dictionary. The procedure checks whether the variables in params are defined appropriately. If they are, the map of default global parameters is overridden with the user-defined params dictionary. After that, the procedure executes the previously defined functions declare_globals and declare_model_and_data, which set each global variable to its value.
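The override behavior described above can be sketched in Python as a validated dictionary merge. The defaults are copied from the parameter table in this section (metrics is omitted for brevity). The helper `merge_model_parameters`, its type checks, and the choice to let unrecognized keys (such as class_name in the usage example) pass through unchanged are assumptions, not the module's actual validation rules.

```python
# Defaults copied from the parameter table (metrics omitted for brevity).
DEFAULT_PARAMS = {
    "hidden_features_size": [16, 16],
    "layer_type": "GATJK",
    "aggregator": "mean",
    "learning_rate": 0.1,
    "weight_decay": 5e-4,
    "split_ratio": 0.8,
    "node_id_property": "id",
    "num_epochs": 100,
    "console_log_freq": 5,
    "checkpoint_freq": 5,
    "device_type": "cpu",
    "path_to_model": "/tmp/torch_models",
}
SUPPORTED_LAYER_TYPES = {"GATJK", "GAT", "GRAPHSAGE"}


def _type_matches(value, default):
    if isinstance(value, bool):  # bool is a subclass of int in Python
        return isinstance(default, bool)
    if isinstance(default, float):  # accept ints where floats are expected
        return isinstance(value, (int, float))
    return isinstance(value, type(default))


def merge_model_parameters(params=None):
    """Hypothetical helper: validate user params, then overlay them on the defaults."""
    params = dict(params or {})
    for key, value in params.items():
        if key in DEFAULT_PARAMS and not _type_matches(value, DEFAULT_PARAMS[key]):
            raise Exception(f"Parameter {key!r} has the wrong type: {value!r}")
    if params.get("layer_type", DEFAULT_PARAMS["layer_type"]) not in SUPPORTED_LAYER_TYPES:
        raise Exception(f"Unsupported layer_type: {params['layer_type']!r}")
    # Keys that are not in the defaults (e.g. class_name) are kept as given.
    return {**DEFAULT_PARAMS, **params}


params = merge_model_parameters({"layer_type": "GAT", "learning_rate": 0.001, "class_name": "label"})
print(params["layer_type"], params["num_epochs"])  # GAT 100
```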
Input:
params (mgp.Map, optional): User-defined parameters for the query module. Defaults to {}.
Name | Type | Default | Description |
---|---|---|---|
hidden_features_size | List[Int] | [16, 16] | Embedding dimension for each node in a new layer. |
layer_type | String | GATJK | Type of layer used; supported types: GATJK, GAT, GRAPHSAGE. |
aggregator | String | mean | Type of aggregator used; supported type: mean. |
learning_rate | Float | 0.1 | Optimizer's learning rate. |
weight_decay | Float | 5e-4 | Optimizer's weight decay. |
split_ratio | Float | 0.8 | Ratio between training and validation data. |
metrics | List[String] | ["loss","accuracy","f1_score","precision","recall","num_wrong_examples"] | List of metrics to report; supports any combination of "loss", "accuracy", "f1_score", "precision", "recall", "num_wrong_examples". |
node_id_property | String | id | Name of the node property that holds the node ID. |
num_epochs | Integer | 100 | Number of epochs for model training. |
console_log_freq | Integer | 5 | How often results are printed. |
checkpoint_freq | Integer | 5 | How often the model is saved. The model is persisted on disk. |
device_type | String | cpu | Whether the model is trained on cpu or cuda. To run on a CUDA GPU, check that the system supports it with torch.cuda.is_available(), then set this flag to cuda. |
path_to_model | String | "/tmp/torch_models" | Path for loading and storing the model. |
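One plausible reading of split_ratio is the share of labeled nodes used for training, with the remainder used for validation. The sketch below illustrates that reading. The random shuffle, the `seed` argument, and `split_nodes` itself are hypothetical and are not taken from the module.

```python
import random


def split_nodes(node_ids, split_ratio=0.8, seed=0):
    """Shuffle node IDs, then put the first split_ratio share into the training set."""
    ids = list(node_ids)
    random.Random(seed).shuffle(ids)  # seeded so the split is reproducible
    cut = int(len(ids) * split_ratio)  # e.g. 9 nodes * 0.8 -> 7 training nodes
    return ids[:cut], ids[cut:]


train_ids, val_ids = split_nodes(range(10, 19))  # the nine PAPER nodes from the example
print(len(train_ids), len(val_ids))  # 7 2
```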
Exceptions:
Exception: Raised if a variable in the params dictionary is not correctly defined.
Output
mgp.Record(hidden_features_size=list, layer_type=str, aggregator=str, learning_rate=float, weight_decay=float, split_ratio=float, metrics=mgp.Any, node_id_property=str, num_epochs=int, console_log_freq=int, checkpoint_freq=int, device_type=str, path_to_model=str): Map of parameters set for training.
Usage:
CALL node_classification.set_model_parameters(
{layer_type: "GATJK", learning_rate: 0.001, hidden_features_size: [16,16], class_name: "fraud", features_name: "embedding"}
) YIELD * RETURN *;
train(num_epochs)
This procedure performs model training. First, it declares the data, model, optimizer, and criterion. Then, it runs the training.
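The loop below is a generic, framework-free sketch of what such a training procedure typically does each epoch: one training step, one validation pass, logging every console_log_freq epochs, checkpointing every checkpoint_freq epochs, and early stopping. `run_training`, the `patience` argument, and the stop-after-no-improvement rule are assumptions for illustration; the module's actual early stopping criterion is not documented on this page.

```python
def run_training(train_step, eval_step, num_epochs=100, console_log_freq=5,
                 checkpoint_freq=5, save_checkpoint=None, patience=None):
    """Generic epoch loop; train_step and eval_step are caller-supplied and return a loss."""
    history = []
    best_val_loss = float("inf")
    epochs_without_improvement = 0
    for epoch in range(1, num_epochs + 1):
        loss = train_step(epoch)     # e.g. one optimizer pass over the training nodes
        val_loss = eval_step(epoch)  # e.g. loss on the validation nodes, no weight updates
        history.append({"epoch": epoch, "loss": loss, "val_loss": val_loss})
        if epoch % console_log_freq == 0:
            print(f"epoch {epoch}: loss={loss:.4f} val_loss={val_loss:.4f}")
        if save_checkpoint is not None and epoch % checkpoint_freq == 0:
            save_checkpoint(epoch)
        if val_loss < best_val_loss:
            best_val_loss, epochs_without_improvement = val_loss, 0
        else:
            epochs_without_improvement += 1
            if patience is not None and epochs_without_improvement >= patience:
                break  # early stopping: validation loss stopped improving
    return history


val_losses = [0.9, 0.8, 0.85, 0.86, 0.87, 0.5, 0.4]
saved = []
history = run_training(
    train_step=lambda epoch: 1.0 / epoch,
    eval_step=lambda epoch: val_losses[epoch - 1],
    num_epochs=7, console_log_freq=1, checkpoint_freq=2,
    save_checkpoint=saved.append, patience=3,
)
print(len(history), saved)  # 5 [2, 4]
```

Here training stops after epoch 5 because the validation loss has not improved for three consecutive epochs, so the later, lower losses are never reached.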
Input
num_epochs (int, optional): Number of epochs (default: 100).
Exceptions
Exception: Raised if the graph is empty.
Outputs
epoch (int): Epoch number.
loss (float): Loss of the model on training data.
val_loss (float): Loss of the model on validation data.
train_log (list): List of metrics on training data.
val_log (list): List of metrics on validation data.
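The train_log and val_log lists hold the metrics selected through the metrics parameter. As a sketch, the function below computes those metrics from true and predicted classes. Macro averaging (an unweighted mean over classes) is an assumption, since the module may average differently, and `classification_metrics` is a hypothetical name.

```python
def classification_metrics(y_true, y_pred, loss=None):
    """Macro-averaged metrics, keyed by the names used in the metrics parameter."""
    pairs = list(zip(y_true, y_pred))
    classes = sorted(set(y_true) | set(y_pred))
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(1 for t, p in pairs if t == c and p == c)
        fp = sum(1 for t, p in pairs if t != c and p == c)
        fn = sum(1 for t, p in pairs if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    num_wrong = sum(1 for t, p in pairs if t != p)
    return {
        "loss": loss,
        "accuracy": 1 - num_wrong / len(pairs),
        "f1_score": sum(f1s) / len(classes),
        "precision": sum(precisions) / len(classes),
        "recall": sum(recalls) / len(classes),
        "num_wrong_examples": num_wrong,
    }


print(classification_metrics([0, 0, 1, 1], [0, 1, 1, 1], loss=0.73))
```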
Usage
CALL node_classification.train() YIELD * RETURN *;
get_training_data()
Use the following procedure to get the logged data from training.
Return values
epoch (int): Epoch number for the current record's logged data.
loss (float): Loss in the epoch.
train_log (mgp.Any): Metrics logged on training data in the epoch.
val_log (mgp.Any): Metrics logged on validation data in the epoch.
Usage
CALL node_classification.get_training_data() YIELD * RETURN *;
save_model()
This procedure saves the model to the specified folder. If the folder already contains max_models_to_keep models, the oldest one is deleted.
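The "keep at most max_models_to_keep files" behavior can be sketched with pathlib as follows. The model_<number>.pt naming scheme, where a higher number means a newer model, and the default limit of 3 are hypothetical; the module's real file names and default limit are not stated on this page.

```python
import tempfile
from pathlib import Path


def _index(path):
    # Hypothetical naming scheme: model_<number>.pt, higher number = newer.
    return int(path.stem.split("_")[1])


def save_with_rotation(data: bytes, folder: str, max_models_to_keep: int = 3) -> Path:
    """Write `data` as a new model file after deleting the oldest files over the limit."""
    directory = Path(folder)
    directory.mkdir(parents=True, exist_ok=True)
    existing = sorted(directory.glob("model_*.pt"), key=_index)  # oldest first
    next_index = _index(existing[-1]) + 1 if existing else 0
    # Make room so that at most max_models_to_keep files remain after saving.
    while len(existing) >= max_models_to_keep:
        existing.pop(0).unlink()
    path = directory / f"model_{next_index:06d}.pt"
    path.write_bytes(data)
    return path


with tempfile.TemporaryDirectory() as tmp:
    for i in range(5):
        save_with_rotation(f"weights {i}".encode(), tmp, max_models_to_keep=3)
    kept = sorted(p.name for p in Path(tmp).iterdir())
print(kept)  # ['model_000002.pt', 'model_000003.pt', 'model_000004.pt']
```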
Exception
Exception: Raised if the model is not initialized or defined.
Return values
path (str): Path to the stored model.
status (str): Status of the stored model.
Usage
CALL node_classification.save_model() YIELD * RETURN *;
load_model(num)
This function loads the model from the specified folder.
Input
num (int, optional): Ordinal number of the model to load from the default path on disk (default: 0, i.e., the newest model).
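Loading by ordinal number, where num=0 means the newest model, can be sketched as sorting the saved files from newest to oldest and indexing into that list. The sketch assumes a hypothetical model_<number>.pt naming scheme in which a higher number means a newer model; `pick_model` is illustrative only and is not the module's loader.

```python
import tempfile
from pathlib import Path


def pick_model(folder: str, num: int = 0) -> Path:
    """Return the num-th newest model file in `folder` (num=0 is the newest)."""
    models = sorted(
        Path(folder).glob("model_*.pt"),
        key=lambda p: int(p.stem.split("_")[1]),  # hypothetical: higher number = newer
        reverse=True,
    )
    if not 0 <= num < len(models):
        raise IndexError(f"No model with ordinal number {num} in {folder}")
    return models[num]


with tempfile.TemporaryDirectory() as tmp:
    for i in range(3):
        (Path(tmp) / f"model_{i:06d}.pt").write_bytes(b"")
    newest = pick_model(tmp).name
    second = pick_model(tmp, 1).name
print(newest, second)  # model_000002.pt model_000001.pt
```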
Return values
path (str): Path of the loaded model.
Usage
CALL node_classification.load_model() YIELD * RETURN *;
predict(vertex)
This procedure predicts the class of a single node. We suggest loading the test data (data without labels) as well. Test data is not part of the training or validation process.
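Assuming the model outputs one raw score per class for the node, the predicted class is the one with the highest probability after a softmax. The sketch below shows only that final step; `predict_class` and the example scores are hypothetical, and the module's internals may differ.

```python
import math


def predict_class(logits):
    """Turn one node's raw class scores into (predicted_class, probabilities)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    predicted = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return predicted, probs


cls, probs = predict_class([2.0, 0.5])
print(cls)  # 0
```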
Input
vertex (mgp.Vertex): Node to predict the class for.
Return values
predicted_class (int): Predicted class for the specified node.
Usage:
MATCH (n {id: 1}) CALL node_classification.predict(n) YIELD * RETURN predicted_class;
reset()
This function resets all variables to default values.
Return values
status (str): Status of the reset function.
Usage:
CALL node_classification.reset() YIELD * RETURN *;
Example
- Step 1: Input graph
- Step 2: Load commands
- Step 3: Set model parameters
- Step 4: Train
- Step 5: Train results
- Step 6: Predict
- Step 7: Predict results
CREATE (v1:PAPER {id: 10, features: [1, 2, 3], label:0});
CREATE (v2:PAPER {id: 11, features: [1.54, 0.3, 1.78], label:0});
CREATE (v3:PAPER {id: 12, features: [0.5, 1, 4.5], label:0});
CREATE (v4:PAPER {id: 13, features: [0.78, 0.234, 1.2], label:0});
CREATE (v5:PAPER {id: 14, features: [3, 4, 100], label:0});
CREATE (v6:PAPER {id: 15, features: [2.1, 2.2, 2.3], label:1});
CREATE (v7:PAPER {id: 16, features: [2.2, 2.3, 2.4], label:1});
CREATE (v8:PAPER {id: 17, features: [2.3, 2.4, 2.5], label:1});
CREATE (v9:PAPER {id: 18, features: [2.4, 2.5, 2.6], label:1});
MATCH (v1:PAPER {id:10}), (v2:PAPER {id:11}) CREATE (v1)-[e:CITES {}]->(v2);
MATCH (v2:PAPER {id:11}), (v3:PAPER {id:12}) CREATE (v2)-[e:CITES {}]->(v3);
MATCH (v3:PAPER {id:12}), (v4:PAPER {id:13}) CREATE (v3)-[e:CITES {}]->(v4);
MATCH (v4:PAPER {id:13}), (v1:PAPER {id:10}) CREATE (v4)-[e:CITES {}]->(v1);
MATCH (v4:PAPER {id:13}), (v5:PAPER {id:14}) CREATE (v4)-[e:CITES {}]->(v5);
MATCH (v5:PAPER {id:14}), (v6:PAPER {id:15}) CREATE (v5)-[e:CITES {}]->(v6);
MATCH (v6:PAPER {id:15}), (v7:PAPER {id:16}) CREATE (v6)-[e:CITES {}]->(v7);
MATCH (v7:PAPER {id:16}), (v8:PAPER {id:17}) CREATE (v7)-[e:CITES {}]->(v8);
MATCH (v8:PAPER {id:17}), (v9:PAPER {id:18}) CREATE (v8)-[e:CITES {}]->(v9);
MATCH (v9:PAPER {id:18}), (v6:PAPER {id:15}) CREATE (v9)-[e:CITES {}]->(v6);
CALL node_classification.set_model_parameters({layer_type: "GAT", learning_rate: 0.001,
hidden_features_size: [2,2],
class_name: "label", features_name: "features", console_log_freq:1}) YIELD *
RETURN *;
CALL node_classification.train(5) YIELD epoch, loss RETURN *;
+----------+----------+
| epoch | loss |
+----------+----------+
| 1 | 0.788709 |
| 2 | 0.765075 |
| 3 | 0.776351 |
| 4 | 0.727615 |
| 5 | 0.727735 |
+----------+----------+
MATCH (v1:PAPER {id: 10})
CALL node_classification.predict(v1) YIELD predicted_class RETURN predicted_class, v1.label as correct_class;
+-----------------+-----------------+
| predicted_class | correct_class |
+-----------------+-----------------+
| 0 | 0 |
+-----------------+-----------------+