# torch-geometric


## Install

```bash
npx skills add https://github.com/davila7/claude-code-templates --skill torch-geometric
```

## PyTorch Geometric (PyG) Overview

PyTorch Geometric is a library built on PyTorch for developing and training Graph Neural Networks (GNNs). Apply this skill for deep learning on graphs and irregular structures, including mini-batch processing, multi-GPU training, and geometric deep learning applications.

## When to Use This Skill

This skill should be used when working with:

- Graph-based machine learning: Node classification, graph classification, link prediction
- Molecular property prediction: Drug discovery, chemical property prediction
- Social network analysis: Community detection, influence prediction
- Citation networks: Paper classification, recommendation systems
- 3D geometric data: Point clouds, meshes, molecular structures
- Heterogeneous graphs: Multi-type nodes and edges (e.g., knowledge graphs)
- Large-scale graph learning: Neighbor sampling, distributed training

## Quick Start

### Installation

```bash
uv pip install torch_geometric
```

For additional dependencies (sparse operations, clustering):

```bash
uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
```
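Here `${TORCH}` and `${CUDA}` are placeholders for your installed PyTorch and CUDA versions. As an illustration only (check https://data.pyg.org/whl/ for the index that matches your environment), a PyTorch 2.4 / CUDA 12.1 setup would look like:

```bash
uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv \
  -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
```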

### Basic Graph Creation

```python
import torch
from torch_geometric.data import Data

# Create a simple graph with 3 nodes
edge_index = torch.tensor([[0, 1, 1, 2],   # source nodes
                           [1, 0, 2, 1]],  # target nodes
                          dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)  # node features

data = Data(x=x, edge_index=edge_index)
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
```

### Loading a Benchmark Dataset

```python
from torch_geometric.datasets import Planetoid

# Load the Cora citation network
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]  # Get the first (and only) graph

print(f"Dataset: {dataset}")
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
print(f"Features: {data.num_node_features}, Classes: {dataset.num_classes}")
```

## Core Concepts

### Data Structure

PyG represents graphs with the `torch_geometric.data.Data` class, which has these key attributes:

- `data.x`: Node feature matrix `[num_nodes, num_node_features]`
- `data.edge_index`: Graph connectivity in COO format `[2, num_edges]`
- `data.edge_attr`: Edge feature matrix `[num_edges, num_edge_features]` (optional)
- `data.y`: Target labels for nodes or graphs
- `data.pos`: Node spatial positions `[num_nodes, num_dimensions]` (optional)
- Custom attributes: Add any attribute you need (e.g., `data.train_mask`, `data.batch`)

Important: none of these attributes is mandatory; extend `Data` objects with custom attributes as needed.
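For instance, a minimal sketch of attaching a custom boolean split mask (the attribute name `train_mask` is just a convention):

```python
import torch
from torch_geometric.data import Data

data = Data(x=torch.randn(3, 4),
            edge_index=torch.tensor([[0, 1], [1, 2]]))
# Any tensor can be attached; PyG carries it through data.to(device), batching, etc.
data.train_mask = torch.tensor([True, True, False])
print(data)  # Data(x=[3, 4], edge_index=[2, 2], train_mask=[3])
```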

### Edge Index Format

Edges are stored in COO (coordinate) format as a `[2, num_edges]` tensor:

- First row: source node indices
- Second row: target node indices

```python
# Edge list: (0→1), (1→0), (1→2), (2→1)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
```

### Mini-Batch Processing

PyG batches graphs by building a block-diagonal adjacency matrix, concatenating multiple graphs into one large disconnected graph:

- Adjacency matrices are stacked diagonally
- Node features are concatenated along the node dimension
- A `batch` vector maps each node to its source graph
- No padding is needed, which keeps batching computationally efficient

```python
from torch_geometric.loader import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    print(f"Batch size: {batch.num_graphs}")
    print(f"Total nodes: {batch.num_nodes}")
    # batch.batch maps each node to its graph
```
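To see the batch vector directly, two toy graphs can be collated by hand with `Batch.from_data_list` (a minimal sketch):

```python
import torch
from torch_geometric.data import Batch, Data

g1 = Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0], [1]]))
g2 = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]]))

batch = Batch.from_data_list([g1, g2])
print(batch.batch)       # tensor([0, 0, 1, 1, 1]): node-to-graph map
print(batch.edge_index)  # g2's node indices are shifted by 2 (block-diagonal stacking)
```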

## Building Graph Neural Networks

### Message Passing Paradigm

GNNs in PyG follow a neighborhood aggregation scheme:

1. Transform node features
2. Propagate messages along edges
3. Aggregate messages from neighbors
4. Update node representations
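In the notation of the PyG documentation, a single layer of this scheme computes

$$\mathbf{x}_i' = \gamma\Big(\mathbf{x}_i,\ \bigoplus_{j \in \mathcal{N}(i)} \phi\big(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{j,i}\big)\Big)$$

where $\phi$ builds messages, $\bigoplus$ is a differentiable, permutation-invariant aggregation (e.g., sum, mean, or max), and $\gamma$ updates the node state.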

### Using Pre-Built Layers

PyG provides 40+ convolutional layers. Common ones include:

`GCNConv` (Graph Convolutional Network):

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
```

`GATConv` (Graph Attention Network):

```python
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
        self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
```

`SAGEConv` (GraphSAGE):

```python
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(num_features, 64)
        self.conv2 = SAGEConv(64, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
```

### Custom Message Passing Layers

For custom layers, inherit from `MessagePassing`:

```python
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class CustomConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "add", "mean", or "max"
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Add self-loops to the adjacency matrix
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Transform node features
        x = self.lin(x)

        # Compute symmetric normalization (as in GCN)
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Propagate messages
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j: features of source nodes, shape [num_edges, out_channels]
        return norm.view(-1, 1) * x_j
```

Key methods:

- `forward()`: Main entry point
- `message()`: Constructs messages from source to target nodes
- `aggregate()`: Aggregates messages (usually not overridden; set the `aggr` parameter instead)
- `update()`: Updates node embeddings after aggregation

Variable naming convention: appending `_i` or `_j` to argument names in `message()` automatically maps tensors to target or source nodes, as in the sketch below.
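For example, a hypothetical layer whose messages use both endpoint features (PyG splits the `x` tensor passed to `propagate` into per-edge views based on the suffix):

```python
from torch_geometric.nn import MessagePassing

class PairConv(MessagePassing):  # hypothetical illustrative layer
    def __init__(self):
        super().__init__(aggr='mean')

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i: target-node features, x_j: source-node features,
        # both of shape [num_edges, num_features]
        return x_j - x_i  # e.g., message = feature difference
```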

## Working with Datasets

### Loading Built-in Datasets

PyG provides extensive benchmark datasets:

```python
# Citation networks (node classification)
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')  # or 'CiteSeer', 'PubMed'

# Graph classification
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')

# Molecular datasets
from torch_geometric.datasets import QM9
dataset = QM9(root='/tmp/QM9')

# Large-scale datasets
from torch_geometric.datasets import Reddit
dataset = Reddit(root='/tmp/Reddit')
```

Check `references/datasets_reference.md` for a comprehensive list.

### Creating Custom Datasets

For datasets that fit in memory, inherit from `InMemoryDataset`:

```python
import torch
from torch_geometric.data import InMemoryDataset, Data

class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['my_data.csv']  # Files expected in raw_dir

    @property
    def processed_file_names(self):
        return ['data.pt']  # Files written to processed_dir

    def download(self):
        # Download raw data to self.raw_dir
        pass

    def process(self):
        # Read raw data and create Data objects
        data_list = []

        # Example: create a simple graph
        edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        x = torch.randn(2, 16)
        y = torch.tensor([0], dtype=torch.long)

        data = Data(x=x, edge_index=edge_index, y=y)
        data_list.append(data)

        # Apply pre_filter and pre_transform
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        # Save processed data
        self.save(data_list, self.processed_paths[0])
```

For large datasets that don't fit in memory, inherit from `Dataset` and implement `len()` and `get(idx)`.
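A minimal sketch of such an on-disk dataset, assuming one `data_{i}.pt` file per graph in `processed_dir` (the file names and count here are illustrative):

```python
import torch
from torch_geometric.data import Dataset

class MyLargeDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)

    @property
    def processed_file_names(self):
        # Assumes 1000 pre-processed graphs; adjust to your data
        return [f'data_{i}.pt' for i in range(1000)]

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        # Each graph is loaded from disk only when requested
        return torch.load(self.processed_paths[idx])
```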

### Loading Graphs from CSV

```python
import pandas as pd
import torch
from torch_geometric.data import Data

# Load nodes
nodes_df = pd.read_csv('nodes.csv')
x = torch.tensor(nodes_df[['feat1', 'feat2']].values, dtype=torch.float)

# Load edges
edges_df = pd.read_csv('edges.csv')
edge_index = torch.tensor([edges_df['source'].values,
                           edges_df['target'].values], dtype=torch.long)

data = Data(x=x, edge_index=edge_index)
```

## Training Workflows

### Node Classification (Single Graph)

```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid

# Load dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

# Create the model (the GCN defined above)
model = GCN(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Training
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')

# Evaluation
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Test Accuracy: {acc:.4f}')
```

### Graph Classification (Multiple Graphs)

```python
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        # Global pooling (aggregate node features to graph level)
        x = global_mean_pool(x, batch)

        x = self.lin(x)
        return F.log_softmax(x, dim=1)

# Load dataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = GraphClassifier(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training
model.train()
for epoch in range(100):
    total_loss = 0
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {total_loss / len(loader):.4f}')
```

### Large-Scale Graphs with Neighbor Sampling

For large graphs, use `NeighborLoader` to sample subgraphs:

```python
from torch_geometric.loader import NeighborLoader

# Create a neighbor sampler
train_loader = NeighborLoader(
    data,
    num_neighbors=[25, 10],  # Sample 25 neighbors at the 1st hop, 10 at the 2nd
    batch_size=128,
    input_nodes=data.train_mask,
)

# Training
model.train()
for batch in train_loader:
    optimizer.zero_grad()
    out = model(batch)
    # Only compute the loss on seed nodes (the first batch_size nodes)
    loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()
```

Important:

- Output subgraphs are directed
- Node indices are relabeled (0 to `batch.num_nodes - 1`)
- Only use seed-node predictions for loss computation
- Sampling beyond 2-3 hops is generally not feasible

## Advanced Features

### Heterogeneous Graphs

For graphs with multiple node and edge types, use `HeteroData`:

```python
from torch_geometric.data import HeteroData

data = HeteroData()

# Add node features for different types
data['paper'].x = torch.randn(100, 128)   # 100 papers with 128 features
data['author'].x = torch.randn(200, 64)   # 200 authors with 64 features

# Add edges for different types: (source_type, edge_type, target_type)
data['author', 'writes', 'paper'].edge_index = torch.stack([
    torch.randint(0, 200, (500,)),  # author (source) indices
    torch.randint(0, 100, (500,)),  # paper (target) indices
])
data['paper', 'cites', 'paper'].edge_index = torch.randint(0, 100, (2, 300))

print(data)
```

Convert homogeneous models to heterogeneous:

```python
from torch_geometric.nn import to_hetero

# Define a homogeneous model
model = GNN(...)

# Convert it to a heterogeneous model
model = to_hetero(model, data.metadata(), aggr='sum')

# Use as normal
out = model(data.x_dict, data.edge_index_dict)
```

Or use `HeteroConv` for custom edge-type-specific operations:

```python
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv

class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata):
        super().__init__()
        self.conv1 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
        }, aggr='sum')

        self.conv2 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(64, 32),
            ('author', 'writes', 'paper'): SAGEConv((64, 64), 32),
        }, aggr='sum')

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict
```

### Transforms

Apply transforms to modify graph structure or features:

```python
from torch_geometric.transforms import NormalizeFeatures, AddSelfLoops, Compose

# Single transform
transform = NormalizeFeatures()
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)

# Compose multiple transforms
transform = Compose([
    AddSelfLoops(),
    NormalizeFeatures(),
])
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
```

Common transforms:

- Structure: `ToUndirected`, `AddSelfLoops`, `RemoveSelfLoops`, `KNNGraph`, `RadiusGraph`
- Features: `NormalizeFeatures`, `NormalizeScale`, `Center`
- Sampling: `RandomNodeSplit`, `RandomLinkSplit` (see the sketch below)
- Positional encoding: `AddLaplacianEigenvectorPE`, `AddRandomWalkPE`
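A minimal sketch of `RandomLinkSplit` for link-prediction splits (the split proportions here are illustrative):

```python
from torch_geometric.transforms import RandomLinkSplit

# Split edges into train/val/test sets, with negative sampling for evaluation
transform = RandomLinkSplit(num_val=0.05, num_test=0.1,
                            is_undirected=True,
                            add_negative_train_samples=False)
train_data, val_data, test_data = transform(data)
# Each split carries edge_label and edge_label_index for supervision
print(train_data.edge_label_index.shape)
```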

See `references/transforms_reference.md` for the full list.

### Model Explainability

PyG provides explainability tools to understand model predictions:

```python
from torch_geometric.explain import Explainer, GNNExplainer

# Create explainer
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',  # or 'phenomenon'
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)

# Generate an explanation for a specific node
node_idx = 10
explanation = explainer(data.x, data.edge_index, index=node_idx)

# Inspect the result
print(f'Node {node_idx} explanation:')
print(f'Important edges: {explanation.edge_mask.topk(5).indices}')
print(f'Important features: {explanation.node_mask[node_idx].topk(5).indices}')
```

### Pooling Operations

For hierarchical graph representations:

```python
from torch_geometric.nn import TopKPooling, global_mean_pool

class HierarchicalGNN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.pool1 = TopKPooling(64, ratio=0.8)
        self.conv2 = GCNConv(64, 64)
        self.pool2 = TopKPooling(64, ratio=0.8)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)

        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)
```

## Common Patterns and Best Practices

### Check Graph Properties

```python
# Undirected check
from torch_geometric.utils import is_undirected
print(f"Is undirected: {is_undirected(data.edge_index)}")

# Contains self-loops
from torch_geometric.utils import contains_self_loops
print(f"Has self-loops: {contains_self_loops(data.edge_index)}")

# Connected components (via scipy; PyG provides the sparse-matrix converter)
import scipy.sparse.csgraph as csgraph
from torch_geometric.utils import to_scipy_sparse_matrix
num_components, _ = csgraph.connected_components(to_scipy_sparse_matrix(data.edge_index))
print(f"Connected components: {num_components}")
```

### GPU Training

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)

# With a DataLoader
for batch in loader:
    batch = batch.to(device)
    # Train...
```

### Save and Load Models

```python
# Save
torch.save(model.state_dict(), 'model.pth')

# Load
model = GCN(num_features, num_classes)
model.load_state_dict(torch.load('model.pth'))
model.eval()
```

### Layer Capabilities

When choosing layers, consider these capabilities:

- SparseTensor: Supports efficient sparse matrix operations
- `edge_weight`: Handles one-dimensional edge weights
- `edge_attr`: Processes multi-dimensional edge features
- Bipartite: Works with bipartite graphs (different source/target dimensions)
- Lazy: Enables initialization without specifying input dimensions (see the sketch below)
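For example, lazy initialization lets the input size be inferred on the first forward pass by passing `-1` (a minimal sketch):

```python
from torch_geometric.nn import GCNConv

# in_channels=-1: the layer infers the input feature size on first use
conv = GCNConv(-1, 64)
out = conv(data.x, data.edge_index)  # parameter shapes are resolved here
```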

See the GNN cheatsheet at `references/layer_capabilities.md`.

## Resources

### Bundled References

This skill includes detailed reference documentation:

- `references/layers_reference.md`: Complete listing of all 40+ GNN layers with descriptions and capabilities
- `references/datasets_reference.md`: Comprehensive dataset catalog organized by category
- `references/transforms_reference.md`: All available transforms and their use cases
- `references/api_patterns.md`: Common API patterns and coding examples

### Scripts

Utility scripts are provided in `scripts/`:

- `scripts/visualize_graph.py`: Visualize graph structure using networkx and matplotlib
- `scripts/create_gnn_template.py`: Generate boilerplate code for common GNN architectures
- `scripts/benchmark_model.py`: Benchmark model performance on standard datasets

Execute scripts directly or read them for implementation patterns.

### Official Resources

- Documentation: https://pytorch-geometric.readthedocs.io/
- GitHub: https://github.com/pyg-team/pytorch_geometric
- Tutorials: https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html
- Examples: https://github.com/pyg-team/pytorch_geometric/tree/master/examples
