parent
29cdeec30a
commit
8a777a1f7a
@ -1,182 +0,0 @@ |
|||||||
package xyz.fycz.myreader.ai; |
|
||||||
|
|
||||||
import android.util.Log; |
|
||||||
|
|
||||||
import java.util.ArrayList; |
|
||||||
import java.util.Arrays; |
|
||||||
import java.util.List; |
|
||||||
import java.util.Random; |
|
||||||
|
|
||||||
import xyz.fycz.myreader.greendao.entity.Book; |
|
||||||
import xyz.fycz.myreader.greendao.entity.Chapter; |
|
||||||
import xyz.fycz.myreader.greendao.service.ChapterService; |
|
||||||
|
|
||||||
/** |
|
||||||
* 预测书籍字数 |
|
||||||
* |
|
||||||
* @author fengyue |
|
||||||
* @date 2021/4/18 16:54 |
|
||||||
*/ |
|
||||||
public class BookWordCountPre { |
|
||||||
private static final String TAG = BookWordCountPre.class.getSimpleName(); |
|
||||||
|
|
||||||
private double lr = 0.001; |
|
||||||
private Book book; |
|
||||||
private List<Chapter> chapters; |
|
||||||
private double[][] trainData; |
|
||||||
private double[][] target; |
|
||||||
private double[][] weights1; |
|
||||||
//private double[][] weights2;
|
|
||||||
private double[][] testData; |
|
||||||
private double median; |
|
||||||
private static Random rand = new Random(); |
|
||||||
|
|
||||||
|
|
||||||
public BookWordCountPre(Book book) { |
|
||||||
this.book = book; |
|
||||||
this.chapters = ChapterService.getInstance().findBookAllChapterByBookId(book.getId()); |
|
||||||
} |
|
||||||
|
|
||||||
//进行训练
|
|
||||||
public boolean train() { |
|
||||||
if (!preData()) { |
|
||||||
Log.i(TAG, String.format("《%s》缓存章节数量过少,无法进行训练", book.getName())); |
|
||||||
return false; |
|
||||||
} |
|
||||||
Log.i(TAG, String.format("《%s》开始进行训练", book.getName())); |
|
||||||
double loss = 0; |
|
||||||
double eps = 0.0000000001; |
|
||||||
double[][] gradient1; |
|
||||||
//double[][] gradient2;
|
|
||||||
double[][] adagrad1 = new double[trainData[0].length][1]; |
|
||||||
//double[][] adagrad2 = new double[trainData[0].length][1];
|
|
||||||
//double[][] dl_dw = new double[trainData[0].length][1];
|
|
||||||
int maxEpoch; |
|
||||||
maxEpoch = 1000 / trainData.length; |
|
||||||
if (maxEpoch < 10) maxEpoch = 10; |
|
||||||
for (int epoch = 0; epoch < maxEpoch; epoch++) { |
|
||||||
shuffle(trainData, target); |
|
||||||
for (int j = 0; j < trainData.length; j++) { |
|
||||||
double[][] oneData = MatrixUtil.to2dMatrix(trainData[j], false); |
|
||||||
double[][] oneTarget = MatrixUtil.to2dMatrix(target[j], true); |
|
||||||
double[][] out = getOut(oneData); |
|
||||||
loss = Math.sqrt(MatrixUtil.sum(MatrixUtil.pow(MatrixUtil.sub(out, oneTarget), 2)) / 2); |
|
||||||
/*dl_dw = MatrixUtil.sub( |
|
||||||
MatrixUtil.add( |
|
||||||
MatrixUtil.dot(MatrixUtil.pow(oneData, 2), weights2), |
|
||||||
MatrixUtil.dot(oneData, weights2)), |
|
||||||
oneTarget |
|
||||||
);*/ |
|
||||||
|
|
||||||
gradient1 = MatrixUtil.dot(MatrixUtil.transpose(oneData), MatrixUtil.sub(out, oneTarget)); |
|
||||||
//gradient1 = MatrixUtil.dot(MatrixUtil.transpose(MatrixUtil.pow(oneData, 2)), dl_dw);
|
|
||||||
//gradient2 = MatrixUtil.dot(MatrixUtil.transpose(oneData), dl_dw);
|
|
||||||
adagrad1 = MatrixUtil.add(adagrad1, MatrixUtil.pow(gradient1, 2)); |
|
||||||
//adagrad2 = MatrixUtil.add(adagrad1, MatrixUtil.pow(gradient2, 2));
|
|
||||||
weights1 = MatrixUtil.sub(weights1, MatrixUtil.divide(MatrixUtil.dot(gradient1, lr), |
|
||||||
MatrixUtil.sqrt(MatrixUtil.add(adagrad1, eps)))); |
|
||||||
/*weights2 = MatrixUtil.sub(weights2, MatrixUtil.divide(MatrixUtil.dot(gradient2, lr), |
|
||||||
MatrixUtil.sqrt(MatrixUtil.add(adagrad2, eps))));*/ |
|
||||||
} |
|
||||||
Log.i(TAG, String.format("《%s》-> epoch=%d,loss=%f", book.getName(), epoch, loss)); |
|
||||||
} |
|
||||||
return true; |
|
||||||
} |
|
||||||
|
|
||||||
//进行预测并获得书籍总字数
|
|
||||||
public int predict() { |
|
||||||
double[][] pre = getOut(testData); |
|
||||||
double[] preVec = MatrixUtil.toVector(pre); |
|
||||||
Arrays.sort(preVec); |
|
||||||
int k = (int) (preVec[preVec.length / 2 + 1] / median); |
|
||||||
//int k = (int) ((MatrixUtil.sum(pre) / pre.length) / median);
|
|
||||||
pre = MatrixUtil.divide(pre, k); |
|
||||||
/*for (int i = 0; i < pre.length; i++) { |
|
||||||
pre[i][0] = median; |
|
||||||
}*/ |
|
||||||
Log.i(TAG, String.format("k=%d->《%s》的预测数据%s", k, book.getName(), |
|
||||||
Arrays.toString(MatrixUtil.toVector(pre)))); |
|
||||||
return (int) (MatrixUtil.sum(pre) + MatrixUtil.sum(target)); |
|
||||||
} |
|
||||||
|
|
||||||
private double[][] getOut(double[][] data) { |
|
||||||
/*return MatrixUtil.add(MatrixUtil.dot(MatrixUtil.pow(data, 2), weights2), |
|
||||||
MatrixUtil.dot(data, weights1));*/ |
|
||||||
return MatrixUtil.dot(data, weights1); |
|
||||||
} |
|
||||||
|
|
||||||
//准备训练数据
|
|
||||||
private boolean preData() { |
|
||||||
rand.setSeed(10); |
|
||||||
List<Chapter> catheChapters = new ArrayList<>(); |
|
||||||
List<Chapter> unCatheChapters = new ArrayList<>(); |
|
||||||
//章节最长标题长度
|
|
||||||
int maxTitleLen = 0; |
|
||||||
//获取已缓存章节
|
|
||||||
for (Chapter chapter : chapters) { |
|
||||||
if (ChapterService.isChapterCached(book.getId(), chapter.getTitle())) { |
|
||||||
catheChapters.add(chapter); |
|
||||||
} else { |
|
||||||
unCatheChapters.add(chapter); |
|
||||||
} |
|
||||||
if (maxTitleLen < chapter.getTitle().length()) { |
|
||||||
maxTitleLen = chapter.getTitle().length(); |
|
||||||
} |
|
||||||
} |
|
||||||
Log.i(TAG, String.format("《%s》已缓存章节数量:%d,最大章节标题长度:%d", |
|
||||||
book.getName(), catheChapters.size(), maxTitleLen)); |
|
||||||
if (catheChapters.size() <= 10) return false; |
|
||||||
//创建训练数据
|
|
||||||
trainData = new double[catheChapters.size()][maxTitleLen + 1]; |
|
||||||
//创建测试数据
|
|
||||||
testData = new double[chapters.size() - catheChapters.size()][maxTitleLen + 1]; |
|
||||||
//创建权重矩阵
|
|
||||||
weights1 = new double[maxTitleLen + 1][1]; |
|
||||||
//weights2 = new double[maxTitleLen + 1][1];
|
|
||||||
//创建目标矩阵
|
|
||||||
target = new double[catheChapters.size()][1]; |
|
||||||
for (int i = 0; i < catheChapters.size(); i++) { |
|
||||||
Chapter chapter = catheChapters.get(i); |
|
||||||
char[] charArr = chapter.getTitle().replaceAll("[((【{]", "").toCharArray(); |
|
||||||
for (int j = 0; j < charArr.length; j++) { |
|
||||||
trainData[i][j] = charArr[j]; |
|
||||||
} |
|
||||||
trainData[i][maxTitleLen] = 1; |
|
||||||
target[i][0] = ChapterService.countChar(book.getId(), chapter.getTitle()); |
|
||||||
} |
|
||||||
for (int i = 0; i < maxTitleLen + 1; i++) { |
|
||||||
weights1[i][0] = rand.nextDouble(); |
|
||||||
//weights2[i][0] = Math.random();
|
|
||||||
} |
|
||||||
for (int i = 0; i < unCatheChapters.size(); i++) { |
|
||||||
Chapter chapter = unCatheChapters.get(i); |
|
||||||
char[] charArr = chapter.getTitle().toCharArray(); |
|
||||||
for (int j = 0; j < charArr.length; j++) { |
|
||||||
testData[i][j] = charArr[j]; |
|
||||||
} |
|
||||||
testData[i][maxTitleLen] = 1; |
|
||||||
} |
|
||||||
/*double[] tem = MatrixUtil.toVector(target); |
|
||||||
Arrays.sort(tem); |
|
||||||
median = tem[tem.length / 2 + 1];*/ |
|
||||||
median = MatrixUtil.sum(target) / target.length; |
|
||||||
return true; |
|
||||||
} |
|
||||||
|
|
||||||
|
|
||||||
public static <T> void swap(T[] a, int i, int j) { |
|
||||||
T temp = a[i]; |
|
||||||
a[i] = a[j]; |
|
||||||
a[j] = temp; |
|
||||||
} |
|
||||||
|
|
||||||
public static <T> void shuffle(T[]... arr) { |
|
||||||
int length = arr[0].length; |
|
||||||
for (int i = length; i > 0; i--) { |
|
||||||
int randInd = rand.nextInt(i); |
|
||||||
for (T[] ts : arr) { |
|
||||||
swap(ts, randInd, i - 1); |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
@ -1,295 +0,0 @@ |
|||||||
package xyz.fycz.myreader.ai; |
|
||||||
|
|
||||||
/** |
|
||||||
* @author fengyue |
|
||||||
* @date 2021/4/7 16:10 |
|
||||||
*/ |
|
||||||
public class MatrixUtil { |
|
||||||
//矩阵加法 C=A+B
|
|
||||||
public static double[][] add(double[][] m1, double[][] m2) { |
|
||||||
if (m1 == null || m2 == null || |
|
||||||
m1.length != m2.length || |
|
||||||
m1[0].length != m2[0].length) { |
|
||||||
return null; |
|
||||||
} |
|
||||||
|
|
||||||
double[][] m = new double[m1.length][m1[0].length]; |
|
||||||
|
|
||||||
for (int i = 0; i < m.length; ++i) { |
|
||||||
for (int j = 0; j < m[i].length; ++j) { |
|
||||||
m[i][j] = m1[i][j] + m2[i][j]; |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return m; |
|
||||||
} |
|
||||||
|
|
||||||
public static double[][] add(double[][] m, double a) { |
|
||||||
if (m == null) { |
|
||||||
return null; |
|
||||||
} |
|
||||||
|
|
||||||
double[][] retM = new double[m.length][m[0].length]; |
|
||||||
|
|
||||||
for (int i = 0; i < retM.length; ++i) { |
|
||||||
for (int j = 0; j < retM[i].length; ++j) { |
|
||||||
retM[i][j] = m[i][j] + a; |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return retM; |
|
||||||
} |
|
||||||
|
|
||||||
public static double[][] sub(double[][] m1, double[][] m2) { |
|
||||||
if (m1 == null || m2 == null || |
|
||||||
m1.length != m2.length || |
|
||||||
m1[0].length != m2[0].length) { |
|
||||||
return null; |
|
||||||
} |
|
||||||
|
|
||||||
double[][] m = new double[m1.length][m1[0].length]; |
|
||||||
|
|
||||||
for (int i = 0; i < m.length; ++i) { |
|
||||||
for (int j = 0; j < m[i].length; ++j) { |
|
||||||
m[i][j] = m1[i][j] - m2[i][j]; |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return m; |
|
||||||
} |
|
||||||
|
|
||||||
//矩阵转置
|
|
||||||
public static double[][] transpose(double[][] m) { |
|
||||||
if (m == null) return null; |
|
||||||
double[][] mt = new double[m[0].length][m.length]; |
|
||||||
for (int i = 0; i < m.length; ++i) { |
|
||||||
for (int j = 0; j < m[i].length; ++j) { |
|
||||||
mt[j][i] = m[i][j]; |
|
||||||
} |
|
||||||
} |
|
||||||
return mt; |
|
||||||
} |
|
||||||
|
|
||||||
//矩阵相乘 C=A*B
|
|
||||||
public static double[][] dot(double[][] m1, double[][] m2) { |
|
||||||
if (m1 == null || m2 == null || m1[0].length != m2.length) |
|
||||||
return null; |
|
||||||
|
|
||||||
double[][] m = new double[m1.length][m2[0].length]; |
|
||||||
for (int i = 0; i < m1.length; ++i) { |
|
||||||
for (int j = 0; j < m2[0].length; ++j) { |
|
||||||
for (int k = 0; k < m1[i].length; ++k) { |
|
||||||
m[i][j] += m1[i][k] * m2[k][j]; |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return m; |
|
||||||
} |
|
||||||
|
|
||||||
//数乘矩阵
|
|
||||||
public static double[][] dot(double[][] m, double k) { |
|
||||||
if (m == null) return null; |
|
||||||
double[][] retM = new double[m.length][m[0].length]; |
|
||||||
for (int i = 0; i < m.length; i++) { |
|
||||||
for (int j = 0; j < m[0].length; j++) { |
|
||||||
retM[i][j] = m[i][j] * k; |
|
||||||
} |
|
||||||
} |
|
||||||
return retM; |
|
||||||
} |
|
||||||
|
|
||||||
//同型矩阵除法
|
|
||||||
public static double[][] divide(double[][] m1, double[][] m2) { |
|
||||||
if (m1 == null || m2 == null || |
|
||||||
m1.length != m2.length || |
|
||||||
m1[0].length != m2[0].length) { |
|
||||||
return null; |
|
||||||
} |
|
||||||
double[][] retM = new double[m1.length][m1[0].length]; |
|
||||||
for (int i = 0; i < retM.length; ++i) { |
|
||||||
for (int j = 0; j < retM[i].length; ++j) { |
|
||||||
retM[i][j] = m1[i][j] / m2[i][j]; |
|
||||||
} |
|
||||||
} |
|
||||||
return retM; |
|
||||||
} |
|
||||||
|
|
||||||
//矩阵除数
|
|
||||||
public static double[][] divide(double[][] m, double k) { |
|
||||||
if (m == null) return null; |
|
||||||
double[][] retM = new double[m.length][m[0].length]; |
|
||||||
for (int i = 0; i < m.length; i++) { |
|
||||||
for (int j = 0; j < m[0].length; j++) { |
|
||||||
retM[i][j] = m[i][j] / k; |
|
||||||
} |
|
||||||
} |
|
||||||
return retM; |
|
||||||
} |
|
||||||
|
|
||||||
//求矩阵行列式(需为方阵)
|
|
||||||
public static double det(double[][] m) { |
|
||||||
if (m == null || m.length != m[0].length) |
|
||||||
return 0; |
|
||||||
|
|
||||||
if (m.length == 1) |
|
||||||
return m[0][0]; |
|
||||||
else if (m.length == 2) |
|
||||||
return det2(m); |
|
||||||
else if (m.length == 3) |
|
||||||
return det3(m); |
|
||||||
else { |
|
||||||
int re = 0; |
|
||||||
for (int i = 0; i < m.length; ++i) { |
|
||||||
re += (((i + 1) % 2) * 2 - 1) * det(companion(m, i, 0)) * m[i][0]; |
|
||||||
} |
|
||||||
return re; |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
//求二阶行列式
|
|
||||||
public static double det2(double[][] m) { |
|
||||||
if (m == null || m.length != 2 || m[0].length != 2) |
|
||||||
return 0; |
|
||||||
|
|
||||||
return m[0][0] * m[1][1] - m[1][0] * m[0][1]; |
|
||||||
} |
|
||||||
|
|
||||||
//求三阶行列式
|
|
||||||
public static double det3(double[][] m) { |
|
||||||
if (m == null || m.length != 3 || m[0].length != 3) |
|
||||||
return 0; |
|
||||||
|
|
||||||
double re = 0; |
|
||||||
for (int i = 0; i < 3; ++i) { |
|
||||||
int temp1 = 1; |
|
||||||
for (int j = 0, k = i; j < 3; ++j, ++k) { |
|
||||||
temp1 *= m[j][k % 3]; |
|
||||||
} |
|
||||||
re += temp1; |
|
||||||
temp1 = 1; |
|
||||||
for (int j = 0, k = i; j < 3; ++j, --k) { |
|
||||||
if (k < 0) k += 3; |
|
||||||
temp1 *= m[j][k]; |
|
||||||
} |
|
||||||
re -= temp1; |
|
||||||
} |
|
||||||
|
|
||||||
return re; |
|
||||||
} |
|
||||||
|
|
||||||
//求矩阵的逆(需方阵)
|
|
||||||
public static double[][] inv(double[][] m) { |
|
||||||
if (m == null || m.length != m[0].length) |
|
||||||
return null; |
|
||||||
|
|
||||||
double A = det(m); |
|
||||||
double[][] mi = new double[m.length][m[0].length]; |
|
||||||
for (int i = 0; i < m.length; ++i) { |
|
||||||
for (int j = 0; j < m[i].length; ++j) { |
|
||||||
double[][] temp = companion(m, i, j); |
|
||||||
mi[j][i] = (((i + j + 1) % 2) * 2 - 1) * det(temp) / A; |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return mi; |
|
||||||
} |
|
||||||
|
|
||||||
//求方阵代数余子式
|
|
||||||
public static double[][] companion(double[][] m, int x, int y) { |
|
||||||
if (m == null || m.length <= x || m[0].length <= y || |
|
||||||
m.length == 1 || m[0].length == 1) |
|
||||||
return null; |
|
||||||
|
|
||||||
double[][] cm = new double[m.length - 1][m[0].length - 1]; |
|
||||||
|
|
||||||
int dx = 0; |
|
||||||
for (int i = 0; i < m.length; ++i) { |
|
||||||
if (i != x) { |
|
||||||
int dy = 0; |
|
||||||
for (int j = 0; j < m[i].length; ++j) { |
|
||||||
if (j != y) { |
|
||||||
cm[dx][dy++] = m[i][j]; |
|
||||||
} |
|
||||||
} |
|
||||||
++dx; |
|
||||||
} |
|
||||||
} |
|
||||||
return cm; |
|
||||||
} |
|
||||||
|
|
||||||
//生成全为0的矩阵
|
|
||||||
public static double[][] zeros(int rows, int cols){ |
|
||||||
return new double[rows][cols]; |
|
||||||
} |
|
||||||
|
|
||||||
//生成全为1的矩阵
|
|
||||||
public static double[][] ones(int rows, int cols){ |
|
||||||
return add(zeros(rows, cols), 1); |
|
||||||
} |
|
||||||
|
|
||||||
public static double sum(double[][] matrix){ |
|
||||||
double sum = 0; |
|
||||||
for (double[] doubles : matrix) { |
|
||||||
for (double aDouble : doubles) { |
|
||||||
sum += aDouble; |
|
||||||
} |
|
||||||
} |
|
||||||
return sum; |
|
||||||
} |
|
||||||
|
|
||||||
public static double[][] pow(double[][] matrix, int exponent){ |
|
||||||
if (matrix == null) return null; |
|
||||||
double[][] retM = new double[matrix.length][matrix[0].length]; |
|
||||||
for (int i = 0; i < matrix.length; i++) { |
|
||||||
for (int j = 0; j < matrix[i].length; j++) { |
|
||||||
retM[i][j] = Math.pow(matrix[i][j], exponent); |
|
||||||
} |
|
||||||
} |
|
||||||
return retM; |
|
||||||
} |
|
||||||
|
|
||||||
public static double[][] sqrt(double[][] matrix){ |
|
||||||
if (matrix == null) return null; |
|
||||||
double[][] retM = new double[matrix.length][matrix[0].length]; |
|
||||||
for (int i = 0; i < matrix.length; i++) { |
|
||||||
for (int j = 0; j < matrix[i].length; j++) { |
|
||||||
retM[i][j] = Math.sqrt(matrix[i][j]); |
|
||||||
} |
|
||||||
} |
|
||||||
return matrix; |
|
||||||
} |
|
||||||
|
|
||||||
public static double[][] to2dMatrix(double[] vector, boolean isCol){ |
|
||||||
if (vector == null) return null; |
|
||||||
double[][] retM; |
|
||||||
if (isCol) { |
|
||||||
retM = new double[vector.length][1]; |
|
||||||
}else { |
|
||||||
retM = new double[1][vector.length]; |
|
||||||
} |
|
||||||
for (int i = 0; i < vector.length; i++) { |
|
||||||
if (isCol) { |
|
||||||
retM[i][0] = vector[i]; |
|
||||||
}else { |
|
||||||
retM[0][i] = vector[i]; |
|
||||||
} |
|
||||||
} |
|
||||||
return retM; |
|
||||||
} |
|
||||||
|
|
||||||
public static double[] toVector(double[][] matrix){ |
|
||||||
double[] retV = null; |
|
||||||
if (matrix.length == 1){ |
|
||||||
retV = new double[matrix[0].length]; |
|
||||||
double[] doubles = matrix[0]; |
|
||||||
System.arraycopy(doubles, 0, retV, 0, doubles.length); |
|
||||||
}else if (matrix[0].length == 1){ |
|
||||||
retV = new double[matrix.length]; |
|
||||||
for (int i = 0; i < matrix.length; i++) { |
|
||||||
retV[i] = matrix[i][0]; |
|
||||||
} |
|
||||||
} |
|
||||||
return retV; |
|
||||||
} |
|
||||||
} |
|
Loading…
Reference in new issue