parent
29cdeec30a
commit
8a777a1f7a
@ -1,182 +0,0 @@ |
||||
package xyz.fycz.myreader.ai; |
||||
|
||||
import android.util.Log; |
||||
|
||||
import java.util.ArrayList; |
||||
import java.util.Arrays; |
||||
import java.util.List; |
||||
import java.util.Random; |
||||
|
||||
import xyz.fycz.myreader.greendao.entity.Book; |
||||
import xyz.fycz.myreader.greendao.entity.Chapter; |
||||
import xyz.fycz.myreader.greendao.service.ChapterService; |
||||
|
||||
/** |
||||
* 预测书籍字数 |
||||
* |
||||
* @author fengyue |
||||
* @date 2021/4/18 16:54 |
||||
*/ |
||||
public class BookWordCountPre { |
||||
private static final String TAG = BookWordCountPre.class.getSimpleName(); |
||||
|
||||
private double lr = 0.001; |
||||
private Book book; |
||||
private List<Chapter> chapters; |
||||
private double[][] trainData; |
||||
private double[][] target; |
||||
private double[][] weights1; |
||||
//private double[][] weights2;
|
||||
private double[][] testData; |
||||
private double median; |
||||
private static Random rand = new Random(); |
||||
|
||||
|
||||
public BookWordCountPre(Book book) { |
||||
this.book = book; |
||||
this.chapters = ChapterService.getInstance().findBookAllChapterByBookId(book.getId()); |
||||
} |
||||
|
||||
//进行训练
|
||||
public boolean train() { |
||||
if (!preData()) { |
||||
Log.i(TAG, String.format("《%s》缓存章节数量过少,无法进行训练", book.getName())); |
||||
return false; |
||||
} |
||||
Log.i(TAG, String.format("《%s》开始进行训练", book.getName())); |
||||
double loss = 0; |
||||
double eps = 0.0000000001; |
||||
double[][] gradient1; |
||||
//double[][] gradient2;
|
||||
double[][] adagrad1 = new double[trainData[0].length][1]; |
||||
//double[][] adagrad2 = new double[trainData[0].length][1];
|
||||
//double[][] dl_dw = new double[trainData[0].length][1];
|
||||
int maxEpoch; |
||||
maxEpoch = 1000 / trainData.length; |
||||
if (maxEpoch < 10) maxEpoch = 10; |
||||
for (int epoch = 0; epoch < maxEpoch; epoch++) { |
||||
shuffle(trainData, target); |
||||
for (int j = 0; j < trainData.length; j++) { |
||||
double[][] oneData = MatrixUtil.to2dMatrix(trainData[j], false); |
||||
double[][] oneTarget = MatrixUtil.to2dMatrix(target[j], true); |
||||
double[][] out = getOut(oneData); |
||||
loss = Math.sqrt(MatrixUtil.sum(MatrixUtil.pow(MatrixUtil.sub(out, oneTarget), 2)) / 2); |
||||
/*dl_dw = MatrixUtil.sub( |
||||
MatrixUtil.add( |
||||
MatrixUtil.dot(MatrixUtil.pow(oneData, 2), weights2), |
||||
MatrixUtil.dot(oneData, weights2)), |
||||
oneTarget |
||||
);*/ |
||||
|
||||
gradient1 = MatrixUtil.dot(MatrixUtil.transpose(oneData), MatrixUtil.sub(out, oneTarget)); |
||||
//gradient1 = MatrixUtil.dot(MatrixUtil.transpose(MatrixUtil.pow(oneData, 2)), dl_dw);
|
||||
//gradient2 = MatrixUtil.dot(MatrixUtil.transpose(oneData), dl_dw);
|
||||
adagrad1 = MatrixUtil.add(adagrad1, MatrixUtil.pow(gradient1, 2)); |
||||
//adagrad2 = MatrixUtil.add(adagrad1, MatrixUtil.pow(gradient2, 2));
|
||||
weights1 = MatrixUtil.sub(weights1, MatrixUtil.divide(MatrixUtil.dot(gradient1, lr), |
||||
MatrixUtil.sqrt(MatrixUtil.add(adagrad1, eps)))); |
||||
/*weights2 = MatrixUtil.sub(weights2, MatrixUtil.divide(MatrixUtil.dot(gradient2, lr), |
||||
MatrixUtil.sqrt(MatrixUtil.add(adagrad2, eps))));*/ |
||||
} |
||||
Log.i(TAG, String.format("《%s》-> epoch=%d,loss=%f", book.getName(), epoch, loss)); |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
//进行预测并获得书籍总字数
|
||||
public int predict() { |
||||
double[][] pre = getOut(testData); |
||||
double[] preVec = MatrixUtil.toVector(pre); |
||||
Arrays.sort(preVec); |
||||
int k = (int) (preVec[preVec.length / 2 + 1] / median); |
||||
//int k = (int) ((MatrixUtil.sum(pre) / pre.length) / median);
|
||||
pre = MatrixUtil.divide(pre, k); |
||||
/*for (int i = 0; i < pre.length; i++) { |
||||
pre[i][0] = median; |
||||
}*/ |
||||
Log.i(TAG, String.format("k=%d->《%s》的预测数据%s", k, book.getName(), |
||||
Arrays.toString(MatrixUtil.toVector(pre)))); |
||||
return (int) (MatrixUtil.sum(pre) + MatrixUtil.sum(target)); |
||||
} |
||||
|
||||
private double[][] getOut(double[][] data) { |
||||
/*return MatrixUtil.add(MatrixUtil.dot(MatrixUtil.pow(data, 2), weights2), |
||||
MatrixUtil.dot(data, weights1));*/ |
||||
return MatrixUtil.dot(data, weights1); |
||||
} |
||||
|
||||
//准备训练数据
|
||||
private boolean preData() { |
||||
rand.setSeed(10); |
||||
List<Chapter> catheChapters = new ArrayList<>(); |
||||
List<Chapter> unCatheChapters = new ArrayList<>(); |
||||
//章节最长标题长度
|
||||
int maxTitleLen = 0; |
||||
//获取已缓存章节
|
||||
for (Chapter chapter : chapters) { |
||||
if (ChapterService.isChapterCached(book.getId(), chapter.getTitle())) { |
||||
catheChapters.add(chapter); |
||||
} else { |
||||
unCatheChapters.add(chapter); |
||||
} |
||||
if (maxTitleLen < chapter.getTitle().length()) { |
||||
maxTitleLen = chapter.getTitle().length(); |
||||
} |
||||
} |
||||
Log.i(TAG, String.format("《%s》已缓存章节数量:%d,最大章节标题长度:%d", |
||||
book.getName(), catheChapters.size(), maxTitleLen)); |
||||
if (catheChapters.size() <= 10) return false; |
||||
//创建训练数据
|
||||
trainData = new double[catheChapters.size()][maxTitleLen + 1]; |
||||
//创建测试数据
|
||||
testData = new double[chapters.size() - catheChapters.size()][maxTitleLen + 1]; |
||||
//创建权重矩阵
|
||||
weights1 = new double[maxTitleLen + 1][1]; |
||||
//weights2 = new double[maxTitleLen + 1][1];
|
||||
//创建目标矩阵
|
||||
target = new double[catheChapters.size()][1]; |
||||
for (int i = 0; i < catheChapters.size(); i++) { |
||||
Chapter chapter = catheChapters.get(i); |
||||
char[] charArr = chapter.getTitle().replaceAll("[((【{]", "").toCharArray(); |
||||
for (int j = 0; j < charArr.length; j++) { |
||||
trainData[i][j] = charArr[j]; |
||||
} |
||||
trainData[i][maxTitleLen] = 1; |
||||
target[i][0] = ChapterService.countChar(book.getId(), chapter.getTitle()); |
||||
} |
||||
for (int i = 0; i < maxTitleLen + 1; i++) { |
||||
weights1[i][0] = rand.nextDouble(); |
||||
//weights2[i][0] = Math.random();
|
||||
} |
||||
for (int i = 0; i < unCatheChapters.size(); i++) { |
||||
Chapter chapter = unCatheChapters.get(i); |
||||
char[] charArr = chapter.getTitle().toCharArray(); |
||||
for (int j = 0; j < charArr.length; j++) { |
||||
testData[i][j] = charArr[j]; |
||||
} |
||||
testData[i][maxTitleLen] = 1; |
||||
} |
||||
/*double[] tem = MatrixUtil.toVector(target); |
||||
Arrays.sort(tem); |
||||
median = tem[tem.length / 2 + 1];*/ |
||||
median = MatrixUtil.sum(target) / target.length; |
||||
return true; |
||||
} |
||||
|
||||
|
||||
public static <T> void swap(T[] a, int i, int j) { |
||||
T temp = a[i]; |
||||
a[i] = a[j]; |
||||
a[j] = temp; |
||||
} |
||||
|
||||
public static <T> void shuffle(T[]... arr) { |
||||
int length = arr[0].length; |
||||
for (int i = length; i > 0; i--) { |
||||
int randInd = rand.nextInt(i); |
||||
for (T[] ts : arr) { |
||||
swap(ts, randInd, i - 1); |
||||
} |
||||
} |
||||
} |
||||
} |
@ -1,295 +0,0 @@ |
||||
package xyz.fycz.myreader.ai; |
||||
|
||||
/**
 * Static helpers for dense matrices stored as rectangular {@code double[][]} arrays
 * (row-major: {@code m[row][col]}).
 *
 * <p>Unless stated otherwise, a method returns {@code null} when an argument is {@code null}
 * or the shapes of the arguments do not fit the operation. No method modifies its arguments.
 *
 * @author fengyue
 * @date 2021/4/7 16:10
 */
public class MatrixUtil {

    /** Column count of {@code m}; a matrix without rows is treated as having no columns. */
    private static int cols(double[][] m) {
        return m.length == 0 ? 0 : m[0].length;
    }

    /** True when both matrices are non-null and have identical row and column counts. */
    private static boolean sameShape(double[][] m1, double[][] m2) {
        return m1 != null && m2 != null
                && m1.length == m2.length
                && cols(m1) == cols(m2);
    }

    /**
     * Matrix addition, C = A + B.
     *
     * @return the element-wise sum, or {@code null} if either argument is null or shapes differ
     */
    public static double[][] add(double[][] m1, double[][] m2) {
        if (!sameShape(m1, m2)) {
            return null;
        }
        double[][] m = new double[m1.length][cols(m1)];
        for (int i = 0; i < m.length; ++i) {
            for (int j = 0; j < m[i].length; ++j) {
                m[i][j] = m1[i][j] + m2[i][j];
            }
        }
        return m;
    }

    /**
     * Adds the scalar {@code a} to every element of {@code m}.
     *
     * @return a new matrix, or {@code null} if {@code m} is null
     */
    public static double[][] add(double[][] m, double a) {
        if (m == null) {
            return null;
        }
        double[][] retM = new double[m.length][cols(m)];
        for (int i = 0; i < retM.length; ++i) {
            for (int j = 0; j < retM[i].length; ++j) {
                retM[i][j] = m[i][j] + a;
            }
        }
        return retM;
    }

    /**
     * Matrix subtraction, C = A - B.
     *
     * @return the element-wise difference, or {@code null} if either argument is null or
     *         shapes differ
     */
    public static double[][] sub(double[][] m1, double[][] m2) {
        if (!sameShape(m1, m2)) {
            return null;
        }
        double[][] m = new double[m1.length][cols(m1)];
        for (int i = 0; i < m.length; ++i) {
            for (int j = 0; j < m[i].length; ++j) {
                m[i][j] = m1[i][j] - m2[i][j];
            }
        }
        return m;
    }

    /**
     * Matrix transpose.
     *
     * @return a new matrix with rows and columns exchanged, or {@code null} if {@code m} is null
     */
    public static double[][] transpose(double[][] m) {
        if (m == null) {
            return null;
        }
        double[][] mt = new double[cols(m)][m.length];
        for (int i = 0; i < m.length; ++i) {
            for (int j = 0; j < m[i].length; ++j) {
                mt[j][i] = m[i][j];
            }
        }
        return mt;
    }

    /**
     * Matrix product, C = A * B.
     *
     * @return the product, or {@code null} if either argument is null or the column count of
     *         {@code m1} differs from the row count of {@code m2}
     */
    public static double[][] dot(double[][] m1, double[][] m2) {
        if (m1 == null || m2 == null || cols(m1) != m2.length) {
            return null;
        }
        int outCols = cols(m2);
        double[][] m = new double[m1.length][outCols];
        for (int i = 0; i < m1.length; ++i) {
            for (int j = 0; j < outCols; ++j) {
                for (int k = 0; k < m1[i].length; ++k) {
                    m[i][j] += m1[i][k] * m2[k][j];
                }
            }
        }
        return m;
    }

    /**
     * Scalar multiplication: every element of {@code m} times {@code k}.
     *
     * @return a new matrix, or {@code null} if {@code m} is null
     */
    public static double[][] dot(double[][] m, double k) {
        if (m == null) {
            return null;
        }
        double[][] retM = new double[m.length][cols(m)];
        for (int i = 0; i < retM.length; i++) {
            for (int j = 0; j < retM[i].length; j++) {
                retM[i][j] = m[i][j] * k;
            }
        }
        return retM;
    }

    /**
     * Element-wise division of two matrices of the same shape.
     *
     * @return a new matrix holding {@code m1[i][j] / m2[i][j]}, or {@code null} if either
     *         argument is null or shapes differ
     */
    public static double[][] divide(double[][] m1, double[][] m2) {
        if (!sameShape(m1, m2)) {
            return null;
        }
        double[][] retM = new double[m1.length][cols(m1)];
        for (int i = 0; i < retM.length; ++i) {
            for (int j = 0; j < retM[i].length; ++j) {
                retM[i][j] = m1[i][j] / m2[i][j];
            }
        }
        return retM;
    }

    /**
     * Divides every element of {@code m} by the scalar {@code k}.
     *
     * @return a new matrix, or {@code null} if {@code m} is null
     */
    public static double[][] divide(double[][] m, double k) {
        if (m == null) {
            return null;
        }
        double[][] retM = new double[m.length][cols(m)];
        for (int i = 0; i < retM.length; i++) {
            for (int j = 0; j < retM[i].length; j++) {
                retM[i][j] = m[i][j] / k;
            }
        }
        return retM;
    }

    /**
     * Determinant of a square matrix, computed by cofactor expansion along the first column
     * for orders above 3.
     *
     * @return the determinant, or 0 if {@code m} is null, empty or not square
     */
    public static double det(double[][] m) {
        if (m == null || m.length == 0 || m.length != m[0].length) {
            return 0;
        }
        if (m.length == 1) {
            return m[0][0];
        }
        if (m.length == 2) {
            return det2(m);
        }
        if (m.length == 3) {
            return det3(m);
        }
        // Accumulate in double: an integer accumulator would truncate fractional terms.
        double re = 0;
        for (int i = 0; i < m.length; ++i) {
            int sign = (i % 2 == 0) ? 1 : -1;
            re += sign * det(companion(m, i, 0)) * m[i][0];
        }
        return re;
    }

    /**
     * Determinant of a 2x2 matrix.
     *
     * @return the determinant, or 0 if {@code m} is null or not 2x2
     */
    public static double det2(double[][] m) {
        if (m == null || m.length != 2 || m[0].length != 2) {
            return 0;
        }
        return m[0][0] * m[1][1] - m[1][0] * m[0][1];
    }

    /**
     * Determinant of a 3x3 matrix: the three wrapped "down-right" diagonal products are added
     * and the three wrapped "down-left" diagonal products are subtracted.
     *
     * @return the determinant, or 0 if {@code m} is null or not 3x3
     */
    public static double det3(double[][] m) {
        if (m == null || m.length != 3 || m[0].length != 3) {
            return 0;
        }
        double re = 0;
        for (int i = 0; i < 3; ++i) {
            // Products are kept as double so fractional entries are not truncated.
            double down = 1;
            double up = 1;
            for (int j = 0; j < 3; ++j) {
                down *= m[j][(i + j) % 3];
                up *= m[j][((i - j) % 3 + 3) % 3];
            }
            re += down;
            re -= up;
        }
        return re;
    }

    /**
     * Inverse of a square matrix via the adjugate: {@code inv[j][i] = cofactor(i, j) / det}.
     *
     * @return the inverse, or {@code null} if {@code m} is null, empty, not square, or its
     *         determinant is exactly 0
     */
    public static double[][] inv(double[][] m) {
        if (m == null || m.length == 0 || m.length != m[0].length) {
            return null;
        }
        double A = det(m);
        if (A == 0) {
            // Singular: there is no inverse.
            return null;
        }
        if (m.length == 1) {
            // A 1x1 matrix has no minors to expand; its inverse is the reciprocal.
            return new double[][]{{1 / A}};
        }
        double[][] mi = new double[m.length][m.length];
        for (int i = 0; i < m.length; ++i) {
            for (int j = 0; j < m[i].length; ++j) {
                int sign = ((i + j) % 2 == 0) ? 1 : -1;
                mi[j][i] = sign * det(companion(m, i, j)) / A;
            }
        }
        return mi;
    }

    /**
     * Returns the sub-matrix of {@code m} obtained by deleting row {@code x} and column
     * {@code y} (the matrix whose determinant is the (x, y) minor).
     *
     * @return the reduced matrix, or {@code null} if {@code m} is null, the indices are past
     *         the end, or {@code m} has only one row or one column
     */
    public static double[][] companion(double[][] m, int x, int y) {
        if (m == null || m.length <= x || m[0].length <= y
                || m.length == 1 || m[0].length == 1) {
            return null;
        }
        double[][] cm = new double[m.length - 1][m[0].length - 1];
        int dx = 0;
        for (int i = 0; i < m.length; ++i) {
            if (i == x) {
                continue;
            }
            int dy = 0;
            for (int j = 0; j < m[i].length; ++j) {
                if (j != y) {
                    cm[dx][dy++] = m[i][j];
                }
            }
            ++dx;
        }
        return cm;
    }

    /** Creates a {@code rows x cols} matrix filled with 0. */
    public static double[][] zeros(int rows, int cols) {
        return new double[rows][cols];
    }

    /** Creates a {@code rows x cols} matrix filled with 1. */
    public static double[][] ones(int rows, int cols) {
        return add(zeros(rows, cols), 1);
    }

    /**
     * Sum of all elements.
     *
     * @throws NullPointerException if {@code matrix} is null
     */
    public static double sum(double[][] matrix) {
        double sum = 0;
        for (double[] row : matrix) {
            for (double value : row) {
                sum += value;
            }
        }
        return sum;
    }

    /**
     * Raises every element to the power {@code exponent}.
     *
     * @return a new matrix, or {@code null} if {@code matrix} is null
     */
    public static double[][] pow(double[][] matrix, int exponent) {
        if (matrix == null) {
            return null;
        }
        double[][] retM = new double[matrix.length][cols(matrix)];
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                retM[i][j] = Math.pow(matrix[i][j], exponent);
            }
        }
        return retM;
    }

    /**
     * Element-wise square root.
     *
     * @return a new matrix holding the roots, or {@code null} if {@code matrix} is null
     */
    public static double[][] sqrt(double[][] matrix) {
        if (matrix == null) {
            return null;
        }
        double[][] retM = new double[matrix.length][cols(matrix)];
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                retM[i][j] = Math.sqrt(matrix[i][j]);
            }
        }
        // Return the computed roots, not the untouched input.
        return retM;
    }

    /**
     * Wraps a vector as a matrix.
     *
     * @param vector the values to wrap
     * @param isCol  {@code true} for an n x 1 column, {@code false} for a 1 x n row
     * @return the matrix, or {@code null} if {@code vector} is null
     */
    public static double[][] to2dMatrix(double[] vector, boolean isCol) {
        if (vector == null) {
            return null;
        }
        if (!isCol) {
            double[][] row = new double[1][vector.length];
            System.arraycopy(vector, 0, row[0], 0, vector.length);
            return row;
        }
        double[][] column = new double[vector.length][1];
        for (int i = 0; i < vector.length; i++) {
            column[i][0] = vector[i];
        }
        return column;
    }

    /**
     * Flattens a single-row or single-column matrix into a vector.
     *
     * @return a copy of the values; an empty array for a matrix without rows; {@code null} if
     *         {@code matrix} is null or has more than one row and more than one column
     */
    public static double[] toVector(double[][] matrix) {
        if (matrix == null) {
            return null;
        }
        if (matrix.length == 0) {
            return new double[0];
        }
        if (matrix.length == 1) {
            double[] retV = new double[matrix[0].length];
            System.arraycopy(matrix[0], 0, retV, 0, retV.length);
            return retV;
        }
        if (matrix[0].length == 1) {
            double[] retV = new double[matrix.length];
            for (int i = 0; i < matrix.length; i++) {
                retV[i] = matrix[i][0];
            }
            return retV;
        }
        return null;
    }
}
Loading…
Reference in new issue