
Support non-Chinese text. Version 1.1.0

chenhaiyang 4 years ago
parent
commit
ebaf76ec46
3 changed files with 63 additions and 27 deletions
  1. +3 -2    setup.py
  2. +53 -23  sim/text_sim.py
  3. +7 -2    tests/test_text_sim.py

+3 -2   setup.py

@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 
 import pathlib
+
 from setuptools import setup
 
 HERE = pathlib.Path(__file__).parent
@@ -9,10 +10,10 @@ README = (HERE / 'README.md').read_text()
 
 setup(
     name='short-text-sim',
-    version='1.0.0',
+    version='1.1.0',
     author='Chen Haiyang',
     license='MIT',
     packages=['sim'],
     include_package_data=True,
-    install_requires=['jieba', 'numpy'],
+    install_requires=['jieba', 'numpy', 'scikit-learn'],
 )
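
The new scikit-learn requirement backs the CountVectorizer import that sim/text_sim.py now uses for its character-level Jaccard measure. A minimal smoke check that the dependency behaves as the module expects (the two toy documents here are made up for illustration):

    from sklearn.feature_extraction.text import CountVectorizer

    # Count vectors over a whitespace-tokenized two-document corpus.
    cv = CountVectorizer(tokenizer=lambda s: s.split())
    vectors = cv.fit_transform(['a b a', 'b c']).toarray()
    print(vectors)  # [[2 1 0] [0 1 1]] over the vocabulary ['a', 'b', 'c']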

+53 -23   sim/text_sim.py

@@ -6,6 +6,7 @@ from operator import itemgetter
 
 import jieba
 import numpy as np
+from sklearn.feature_extraction.text import CountVectorizer
 
 from .words_sim import SimCilin
 
@@ -36,29 +37,42 @@ def get_similarity(s1, s2):
     :param s2:
     :return:
     """
-    all_sim_1 = list()
-    for w1 in s1:
-        if is_contains_chinese(w1):
-            sim_list = list()
-            for w2 in s2:
-                sim_list.append(ci_lin.compute_word_sim(w1, w2))
-            sim_list.sort()
-            all_sim_1.append(sim_list[-1])
-
-    all_sim_2 = list()
-    for w1 in s2:
-        if is_contains_chinese(w1):
-            sim_list = list()
-            for w2 in s1:
-                sim_list.append(ci_lin.compute_word_sim(w1, w2))
-            sim_list.sort()
-            all_sim_2.append(sim_list[-1])
-
-    if not all_sim_1:
-        all_sim_1 = [0]
-    if not all_sim_2:
-        all_sim_2 = [0]
-    return (np.mean(all_sim_1) + np.mean(all_sim_2)) / 2
+    if len(s1) == 0 or len(s2) == 0:
+        return 0.0
+
+    chinese_list1 = []
+    chinese_list2 = []
+    chinese_str1 = ''
+    chinese_str2 = ''
+    non_chinese_str1 = ''
+    non_chinese_str2 = ''
+    for word in s1:
+        if is_contains_chinese(word):
+            chinese_list1.append(word)
+            chinese_str1 += word
+        else:
+            non_chinese_str1 += word
+    for word in s2:
+        if is_contains_chinese(word):
+            chinese_list2.append(word)
+            chinese_str2 += word
+        else:
+            non_chinese_str2 += word
+
+    sim_matrix = np.zeros((len(chinese_list1), len(chinese_list2)), dtype=float)
+    for i in range(len(chinese_list1)):
+        for j in range(len(chinese_list2)):
+            sim_matrix[i, j] = ci_lin.compute_word_sim(chinese_list1[i], chinese_list2[j])
+    chinese_sim = 0.0
+    if sim_matrix.any():
+        chinese_sim = (np.max(sim_matrix, axis=0).mean() + np.max(sim_matrix, axis=1).mean()) / 2
+
+    non_chinese_sim = jaccard_similarity(non_chinese_str1, non_chinese_str2)
+
+    chinese_len = len(chinese_str1 + chinese_str2)
+    non_chinese_len = len(non_chinese_str1 + non_chinese_str2)
+
+    return (chinese_sim * chinese_len + non_chinese_sim * non_chinese_len) / (chinese_len + non_chinese_len)
 
 
 def most_similar_items(src_s, sentences, n=3):
@@ -99,3 +113,19 @@ def merge(word_list):
     for w in word_list:
         s += w.split('/')[0]
     return s
+
+
+def jaccard_similarity(s1, s2):
+    if len(s1) == 0 or len(s2) == 0:
+        return 0
+
+    def add_space(s):
+        return ' '.join(list(s))
+
+    s1, s2 = add_space(s1), add_space(s2)
+    cv = CountVectorizer(tokenizer=lambda s: s.split())
+    corpus = [s1, s2]
+    vectors = cv.fit_transform(corpus).toarray()
+    numerator = np.sum(np.min(vectors, axis=0))
+    denominator = np.sum(np.max(vectors, axis=0))
+    return 1.0 * numerator / denominator
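
A quick worked pass through the new jaccard_similarity (illustrative values): add_space turns 'abc' and 'abd' into the character tokens 'a b c' and 'a b d', CountVectorizer builds the count vectors [1, 1, 1, 0] and [1, 1, 0, 1] over the vocabulary {a, b, c, d}, so the element-wise minimum sums to 2, the maximum to 4, and the score is 0.5. Assuming the package and its Cilin data are importable:

    from sim.text_sim import jaccard_similarity

    # Intersection of character counts (2) over their union (4).
    print(jaccard_similarity('abc', 'abd'))  # 0.5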

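Taken together, the rewritten get_similarity scores the Chinese words with the mean of the row and column maxima of the Cilin similarity matrix, scores the remaining text with jaccard_similarity, and blends the two in proportion to character counts. A standalone sketch of that aggregation (the 2x2 matrix and the lengths below are made-up numbers, not output of the library):

    import numpy as np

    sim_matrix = np.array([[0.9, 0.2],
                           [0.1, 0.7]])  # pairwise Cilin word similarities
    non_chinese_sim = 0.5                # e.g. a Jaccard score

    # Best match per word, averaged over both directions, as in get_similarity.
    chinese_sim = (np.max(sim_matrix, axis=0).mean()
                   + np.max(sim_matrix, axis=1).mean()) / 2  # 0.8

    # Blend weighted by character counts: 6 Chinese vs. 4 non-Chinese chars.
    chinese_len, non_chinese_len = 6, 4
    score = ((chinese_sim * chinese_len + non_chinese_sim * non_chinese_len)
             / (chinese_len + non_chinese_len))
    print(score)  # ~0.68
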
+7 -2   tests/test_text_sim.py

@@ -2,10 +2,15 @@
 
 """Test for text sim."""
 
-from sim.text_sim import most_similar_items
+from sim.text_sim import most_similar_items, get_similarity
+from sim.words_sim import SimCilin
+
+ci_lin = SimCilin()
 
 if __name__ == '__main__':
     str1 = '我喜欢吃苹果'
     str2 = ['他喜欢肯地瓜']
-    str_l = ['我喜欢吃梨', '你喜欢吃苹果', 'unbelievable', 'GT-faf-我和你']
+    str_l = ['我喜欢吃梨', '你喜欢吃苹果', 'unbelievable', 'GT-faf-我和你', '']
+
     print(most_similar_items(str1, str_l, 5))
+    print(get_similarity('', ''))
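
Beyond the smoke test above, two hedged assertions (hypothetical, not part of the commit) pin down the behaviour the new guards introduce; note that get_similarity iterates its arguments, so a raw string is treated as a sequence of single-character words:

    from sim.text_sim import get_similarity

    assert get_similarity('', '') == 0.0                          # empty-input guard
    assert get_similarity('unbelievable', 'unbelievable') == 1.0  # non-Chinese path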