Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.List;

import org.apache.sysds.runtime.matrix.data.IJV;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
* This SparseBlock is an abstraction for different sparse matrix formats.
Expand Down Expand Up @@ -94,10 +95,23 @@ public enum Type {
* @param r row index
*/
public abstract void compact(int r);

/**
 * In-place compaction of non-zero-entries; removes zero entries
 * and shifts non-zero entries to the left if necessary.
 * Unlike {@link #compact(int)}, this variant takes no row index.
 */
public abstract void compact();

////////////////////////
//obtain basic meta data


/**
 * Get the type of the sparse block, i.e., the {@link Type} constant
 * reported by the concrete sparse block implementation.
 *
 * @return sparse block type
 */
public abstract SparseBlock.Type getSparseBlockType();

/**
* Get the number of rows in the sparse block.
*
Expand Down Expand Up @@ -501,18 +515,26 @@ public boolean contains(double pattern, int rl, int ru) {
}

public List<Integer> contains(double[] pattern, boolean earlyAbort) {
int pNnz = UtilFunctions.computeNnz(pattern, 0, pattern.length);
List<Integer> ret = new ArrayList<>();
int rlen = numRows();

for( int i=0; i<rlen; i++ ) {
int apos = pos(i);
int alen = size(i);
if(pNnz > alen) continue;

int[] aix = indexes(i);
double[] avals = values(i);
boolean lret = true;
int rNnz = 0;

//safe comparison on long representations, incl NaN
for(int k=apos; k<apos+alen & !lret; k++)
for(int k=apos; k<apos+alen && lret; k++) {
lret &= Double.compare(avals[k], pattern[aix[k]]) == 0;
if( lret )
if(avals[k] != 0) rNnz++;
}
if(lret && rNnz == pNnz)
ret.add(i);
if(earlyAbort && ret.size()>0)
return ret;
Expand Down Expand Up @@ -764,17 +786,31 @@ public void remove() {
* values are available.
*/
private void findNextNonZeroRow(int cl) {
while( _curRow<_rlen && (isEmpty(_curRow)
|| (cl>0 && posFIndexGTE(_curRow, cl) < 0)) )
while(_curRow < _rlen){
if(isEmpty(_curRow)){
_curRow++;
continue;
}

int pos = (cl == 0)? 0 : posFIndexGTE(_curRow, cl);
if(pos < 0){
_curRow++;
continue;
}

int sizeRow = size(_curRow);
int endPos = (_cu == Integer.MAX_VALUE)? sizeRow : posFIndexGTE(_curRow, _cu);
if(endPos < 0) endPos = sizeRow;

if(pos < endPos){
_curColIx = pos(_curRow)+pos;
_curIndexes = indexes(_curRow);
_curValues = values(_curRow);
return;
}
_curRow++;
if(_curRow >= _rlen)
_noNext = true;
else {
_curColIx = (cl==0) ?
pos(_curRow) : posFIndexGTE(_curRow, cl);
_curIndexes = indexes(_curRow);
_curValues = values(_curRow);
}
_noNext = true;
}
}

Expand Down
35 changes: 26 additions & 9 deletions src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,25 @@ public void compact(int r) {
//do nothing everything preallocated
}

/**
 * In-place compaction of the COO arrays: drops all zero values and
 * shifts the remaining (row index, column index, value) triples to
 * the left, then updates the logical size accordingly.
 */
@Override
public void compact() {
	int pos = 0;
	//scan only the logically used prefix [0, _size); slots beyond _size
	//are spare capacity and are not part of the block, so they must
	//not be considered as entries
	for(int i = 0; i < _size; i++) {
		if(_values[i] != 0) {
			_values[pos] = _values[i];
			_rindexes[pos] = _rindexes[i];
			_cindexes[pos] = _cindexes[i];
			pos++;
		}
	}
	//number of retained non-zero entries
	_size = pos;
}

/**
 * Reports the sparse block type of this implementation.
 *
 * @return the constant {@code Type.COO}
 */
@Override
public SparseBlock.Type getSparseBlockType() {
	return SparseBlock.Type.COO;
}

@Override
public int numRows() {
return _rlen;
Expand Down Expand Up @@ -221,12 +240,12 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) {
}

//2. correct array lengths
if(_size != nnz && _cindexes.length < nnz && _rindexes.length < nnz && _values.length < nnz) {
if(_size != nnz || _cindexes.length < nnz || _rindexes.length < nnz || _values.length < nnz) {
throw new RuntimeException("Incorrect array lengths.");
}

//3.1. sort order of row indices
for( int i=1; i<=nnz; i++ ) {
for( int i=1; i<nnz; i++ ) {
if(_rindexes[i] < _rindexes[i-1])
throw new RuntimeException("Wrong sorted order of row indices");
}
Expand All @@ -235,26 +254,24 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) {
for( int i=0; i<rlen; i++ ) {
int apos = pos(i);
int alen = size(i);
for(int k=apos+i; k<apos+alen; k++)
if( _cindexes[k+1] >= _cindexes[k] )
for(int k=apos+1; k<apos+alen; k++)
if(_cindexes[k-1] > _cindexes[k])
throw new RuntimeException("Wrong sparse row ordering: "
+ k + " "+_cindexes[k-1]+" "+_cindexes[k]);
for( int k=apos; k<apos+alen; k++ )
if(_values[k] == 0)
throw new RuntimeException("Wrong sparse row: zero at "
+ k + " at col index " + _cindexes[k]);
}

//4. non-existing zero values
for( int i=0; i<_size; i++ ) {
if( _values[i] == 0)
throw new RuntimeException("The values array should not contain zeros."
+ " The " + i + "th value is "+_values[i]);
if(_cindexes[i] < 0 || _rindexes[i] < 0)
throw new RuntimeException("Invalid index at pos=" + i);
}

//5. a capacity that is no larger than nnz times the resize factor
int capacity = _values.length;
if( capacity > nnz*RESIZE_FACTOR1 ) {
if( capacity > INIT_CAPACITY && capacity > nnz*RESIZE_FACTOR1 ) {
throw new RuntimeException("Capacity is larger than the nnz times a resize factor."
+ " Current size: "+capacity+ ", while Expected size:"+nnz*RESIZE_FACTOR1);
}
Expand Down
83 changes: 44 additions & 39 deletions src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,16 @@ public SparseBlockCSC(int rlen, int clen) {
_size = 0;
}

/**
 * Creates an empty CSC block with preallocated index and value arrays.
 *
 * @param rlen number of rows, stored as the known row count
 * @param clen number of columns; determines the pointer array length (clen+1)
 * @param capacity initial length of the index and value arrays
 */
public SparseBlockCSC(int rlen, int clen, int capacity) {
	_rlen = rlen;
	_ptr = new int[clen + 1]; //ix0=0
	_indexes = new int[capacity];
	_values = new double[capacity];
	//freshly allocated arrays hold no entries yet
	_size = 0;
}

public SparseBlockCSC(int[] rowPtr, int[] rowInd, double[] values, int nnz) {
_ptr = rowPtr;
public SparseBlockCSC(int[] colPtr, int[] rowInd, double[] values, int nnz) {
_ptr = colPtr;
_indexes = rowInd;
_values = values;
_size = nnz;
Expand All @@ -94,8 +95,9 @@ public SparseBlockCSC(SparseBlock sblock) {

private void initialize(SparseBlock sblock) {

if(_size > Integer.MAX_VALUE)
throw new RuntimeException("SparseBlockCSC supports nnz<=Integer.MAX_VALUE but got " + _size);
long size = sblock.size();
if(size > Integer.MAX_VALUE)
throw new RuntimeException("SparseBlockCSC supports nnz<=Integer.MAX_VALUE but got " + size);

//special case SparseBlockCSC
if(sblock instanceof SparseBlockCSC) {
Expand Down Expand Up @@ -223,27 +225,6 @@ public SparseBlockCSC(int cols, int[] rowInd, int[] colInd, double[] values) {

}

public SparseBlockCSC(int cols, int nnz, int[] rowInd) {

_clenInferred = cols;
_ptr = new int[cols + 1];
_indexes = Arrays.copyOf(rowInd, nnz);
_values = new double[nnz];
Arrays.fill(_values, 1);
_size = nnz;

//single-pass construction of col pointers
//and copy of row indexes if necessary
for(int i = 0, pos = 0; i < cols; i++) {
if(rowInd[i] >= 0) {
if(cols > nnz)
_indexes[pos] = rowInd[i];
pos++;
}
_ptr[i + 1] = pos;
}
}

/**
* Initializes the CSC sparse block from an ordered input stream of ultra-sparse ijv triples.
*
Expand Down Expand Up @@ -288,7 +269,6 @@ public void initSparse(int clen, int nnz, DataInput in) throws IOException {
// Allocate space if necessary
if(_values.length < nnz) {
resize(newCapacity(nnz));
System.out.println("hallo");
}
// Read sparse columns, append and update pointers
_ptr[0] = 0;
Expand Down Expand Up @@ -377,12 +357,36 @@ public void compact(int r) {
//do nothing everything preallocated
}

/**
 * In-place compaction of the CSC arrays: drops all zero values column
 * by column, shifts the remaining (row index, value) pairs to the left,
 * and rewrites the column pointers and the logical size to match.
 */
@Override
public void compact() {
	//write cursor into the compacted value/index arrays
	int dst = 0;
	for(int col = 0; col < numCols(); col++) {
		//read the old extent of this column before its pointer is overwritten
		final int begin = posCol(col);
		final int end = begin + sizeCol(col);
		_ptr[col] = dst;
		for(int src = begin; src < end; src++) {
			final double v = _values[src];
			if(v == 0)
				continue;
			_values[dst] = v;
			_indexes[dst] = _indexes[src];
			dst++;
		}
	}
	//final pointer marks the end of the last column
	_ptr[numCols()] = dst;
	_size = dst;
}

/**
 * Reports the sparse block type of this implementation.
 *
 * @return the constant {@code Type.CSC}
 */
@Override
public SparseBlock.Type getSparseBlockType() {
	return SparseBlock.Type.CSC;
}

@Override
public int numRows() {
if(_rlen > -1)
return _rlen;
else {
int rlen = Arrays.stream(_indexes).max().getAsInt();
int rlen = Arrays.stream(_indexes).max().getAsInt()+1;
_rlen = rlen;
return rlen;
}
Expand Down Expand Up @@ -550,12 +554,12 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) {
}

//2. correct array lengths
if(_size != nnz && _ptr.length < clen + 1 && _values.length < nnz && _indexes.length < nnz) {
if(_size != nnz || _ptr.length < clen + 1 || _values.length < nnz || _indexes.length < nnz) {
throw new RuntimeException("Incorrect array lengths.");
}

//3. non-decreasing row pointers
for(int i = 1; i < clen; i++) {
//3. non-decreasing col pointers
for(int i = 1; i <= clen; i++) {
if(_ptr[i - 1] > _ptr[i] && strict)
throw new RuntimeException(
"Column pointers are decreasing at column: " + i + ", with pointers " + _ptr[i - 1] + " > " +
Expand All @@ -569,10 +573,9 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) {
for(int k = apos + 1; k < apos + alen; k++)
if(_indexes[k - 1] >= _indexes[k])
throw new RuntimeException(
"Wrong sparse column ordering: " + k + " " + _indexes[k - 1] + " " + _indexes[k]);
for(int k = apos; k < apos + alen; k++)
if(_values[k] == 0)
throw new RuntimeException("Wrong sparse column: zero at " + k + " at row index " + _indexes[k]);
"Wrong sparse column ordering, at column=" + i + ", pos=" + k + " with row indexes " +
_indexes[k - 1] + ">=" + _indexes[k]
);
}

//5. non-existing zero values
Expand All @@ -581,11 +584,13 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) {
throw new RuntimeException(
"The values array should not contain zeros." + " The " + i + "th value is " + _values[i]);
}
if(_indexes[i] < 0)
throw new RuntimeException("Invalid index at pos=" + i);
}

//6. a capacity that is no larger than nnz times resize factor.
int capacity = _values.length;
if(capacity > nnz * RESIZE_FACTOR1) {
if(capacity > INIT_CAPACITY && capacity > nnz * RESIZE_FACTOR1) {
throw new RuntimeException(
"Capacity is larger than the nnz times a resize factor." + " Current size: " + capacity +
", while Expected size:" + nnz * RESIZE_FACTOR1);
Expand Down Expand Up @@ -938,7 +943,7 @@ public void deleteIndexRangeCol(int c, int rl, int ru) {
int len = sizeCol(c);
int end = internPosFIndexGTECol(ru, c);
if(end < 0) //delete all remaining
end = start + len;
end = posCol(c) + len;

//overlapping array copy (shift rhs values left)
System.arraycopy(_indexes, end, _indexes, start, _size - end);
Expand Down Expand Up @@ -1059,7 +1064,7 @@ public int posFIndexGTCol(int r, int c) {
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("SparseBlockCSR: clen=");
sb.append("SparseBlockCSC: clen=");
sb.append(numCols());
sb.append(", nnz=");
sb.append(size());
Expand Down
Loading
Loading