In a given CSR matrix, I want to know how many non-zero elements there are in each row and each column. Which function should I use to count them?

1 Answer

Best answer

You can use the getnnz() method of scipy.sparse.csr_matrix with the axis parameter. To get the count for each column, set axis=0; to get the count for each row, set axis=1. Here is an example:

```python
import numpy as np
from scipy.sparse import csr_matrix

X = np.array([[0, 5, 0, 0, 2],
              [3, 0, 0, 7, 0],
              [0, 0, 4, 0, 0],
              [0, 1, 0, 0, 6],
              [8, 0, 0, 0, 4]])

# Convert to CSR format
X_csr = csr_matrix(X)

# Count of non-zero elements per column
nonzero_counts_per_col = X_csr.getnnz(axis=0)
print(f"nonzero_counts_per_col: {nonzero_counts_per_col}")
```

Output of this code: nonzero_counts_per_col: [2 2 1 1 3]
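
For a CSR matrix, the per-column counts can also be read directly from the stored column indices: every stored entry contributes exactly one value to X_csr.indices, so counting how often each column index occurs gives the count for that column. The sketch below uses np.bincount for this and cross-checks it against getnnz(axis=0); it is an illustration of what the column count means for the CSR layout, not a replacement for the method above.

```python
import numpy as np
from scipy.sparse import csr_matrix

X = np.array([[0, 5, 0, 0, 2],
              [3, 0, 0, 7, 0],
              [0, 0, 4, 0, 0],
              [0, 1, 0, 0, 6],
              [8, 0, 0, 0, 4]])
X_csr = csr_matrix(X)

# X_csr.indices holds the column index of every stored entry, so counting
# the occurrences of each column index gives the per-column count.
# minlength ensures that trailing empty columns still get a count of 0.
col_counts = np.bincount(X_csr.indices, minlength=X_csr.shape[1])
print(col_counts)  # [2 2 1 1 3]

# Cross-check against getnnz(axis=0)
print(np.array_equal(col_counts, X_csr.getnnz(axis=0)))  # True
```

The minlength argument matters when the last columns contain no stored entries; without it, bincount would return a shorter array.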

```python
# Count of non-zero elements per row
nonzero_counts_per_row = X_csr.getnnz(axis=1)
print(f"nonzero_counts_per_row: {nonzero_counts_per_row}")
```

Output of this code: nonzero_counts_per_row: [2 2 1 2 2]
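
The per-row counts come straight from the CSR row-pointer array: the entries of row i are stored at positions indptr[i] through indptr[i+1] - 1, so np.diff(X_csr.indptr) gives the count for each row. One caveat applies to both approaches: getnnz() counts stored entries, including any zeros that happen to be stored explicitly, whereas count_nonzero() ignores them and eliminate_zeros() removes them. In the minimal sketch below, the explicit zero is created artificially by overwriting a stored value, purely for illustration.

```python
import numpy as np
from scipy.sparse import csr_matrix

X = np.array([[0, 5, 0, 0, 2],
              [3, 0, 0, 7, 0],
              [0, 0, 4, 0, 0],
              [0, 1, 0, 0, 6],
              [8, 0, 0, 0, 4]])
X_csr = csr_matrix(X)

# Row i's entries live in data[indptr[i]:indptr[i + 1]], so the
# consecutive differences of indptr are the per-row counts.
row_counts = np.diff(X_csr.indptr)
print(row_counts)  # [2 2 1 2 2]
print(np.array_equal(row_counts, X_csr.getnnz(axis=1)))  # True

# Caveat: getnnz() counts *stored* entries. Overwrite one stored value
# with 0 (illustration only) to create an explicitly stored zero.
Y = X_csr.copy()
Y.data[0] = 0             # entry (0, 1) is now a stored zero
print(Y.getnnz(axis=1))   # [2 2 1 2 2] (stored zero is still counted)
print(Y.count_nonzero())  # 8 (explicit zero is ignored)

# eliminate_zeros() drops explicit zeros in place; afterwards getnnz() agrees.
Y.eliminate_zeros()
print(Y.getnnz(axis=1))   # [1 2 1 2 2]
```

For a matrix built directly from a dense array, as in the example above, no explicit zeros are stored, so getnnz() and a true non-zero count coincide. The difference only shows up after operations that leave zeros in the data array.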


...