
twenty_newsgroups.py contents for Python Machine Learning and Practice

Published: 2023/12/20, by 豆豆
This post, collected by 生活随笔, gives the contents of twenty_newsgroups.py as used with Python Machine Learning and Practice, shared here for reference.

File location (on this Python 2.7 installation):

D:\software\python27\Lib\site-packages\sklearn\datasets
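The D:\ path above is specific to that particular Python 2.7 installation. A minimal sketch for finding the equivalent folder on another machine: it takes the directory of an imported package's `__file__`. The helper name `module_dir` is hypothetical, not part of scikit-learn.

```python
import os


def module_dir(module):
    """Return the directory holding an imported module or package.

    For a package, __file__ points at its __init__.py, so the dirname
    is the package folder itself.
    """
    return os.path.dirname(os.path.abspath(module.__file__))
```

With scikit-learn installed, `import sklearn.datasets` followed by `module_dir(sklearn.datasets)` should give the folder that contains twenty_newsgroups.py. The snippet avoids Python 3 only features, so it also runs on 2.7.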

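Replace the contents of twenty_newsgroups.py with the following. It is scikit-learn's loader with the download lines in download_20newsgroups commented out, so instead of downloading, the function extracts a 20news-bydate.tar.gz archive that must already be in the data folder (by default ~/scikit_learn_data/20news_home):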




```python
"""Caching loader for the 20 newsgroups text classification dataset


The description of the dataset is available on the official website at:

    http://people.csail.mit.edu/jrennie/20Newsgroups/

Quoting the introduction:

    The 20 Newsgroups data set is a collection of approximately 20,000
    newsgroup documents, partitioned (nearly) evenly across 20 different
    newsgroups. To the best of my knowledge, it was originally collected
    by Ken Lang, probably for his Newsweeder: Learning to filter netnews
    paper, though he does not explicitly mention this collection. The 20
    newsgroups collection has become a popular data set for experiments
    in text applications of machine learning techniques, such as text
    classification and text clustering.

This dataset loader will download the recommended "by date" variant of the
dataset, which features a point in time split between the train and
test sets. The compressed dataset size is around 14 MB. Once
uncompressed the train set is 52 MB and the test set is 34 MB.

The data is downloaded, extracted and cached in the '~/scikit_learn_data'
folder.

The `fetch_20newsgroups` function will not vectorize the data into numpy
arrays but the dataset lists the filenames of the posts and their categories
as target labels.

The `fetch_20newsgroups_vectorized` function will in addition do a simple
tf-idf vectorization step.

"""
# Copyright (c) 2011 Olivier Grisel <olivier.grisel@ensta.org>
# License: BSD 3 clause

import os
import logging
import tarfile
import pickle
import shutil
import re
import codecs

import numpy as np
import scipy.sparse as sp

from .base import get_data_home
from .base import Bunch
from .base import load_files
from ..utils import check_random_state
from ..feature_extraction.text import CountVectorizer
from ..preprocessing import normalize
from ..externals import joblib, six

if six.PY3:
    from urllib.request import urlopen
else:
    from urllib2 import urlopen


logger = logging.getLogger(__name__)


URL = ("http://people.csail.mit.edu/jrennie/"
       "20Newsgroups/20news-bydate.tar.gz")
ARCHIVE_NAME = "20news-bydate.tar.gz"
CACHE_NAME = "20news-bydate.pkz"
TRAIN_FOLDER = "20news-bydate-train"
TEST_FOLDER = "20news-bydate-test"


def download_20newsgroups(target_dir, cache_path):
    """Download the 20 newsgroups data and store it as a zipped pickle."""
    archive_path = os.path.join(target_dir, ARCHIVE_NAME)
    train_path = os.path.join(target_dir, TRAIN_FOLDER)
    test_path = os.path.join(target_dir, TEST_FOLDER)

    # The download step is disabled in this version: 20news-bydate.tar.gz
    # must already be in target_dir (<data_home>/20news_home) before this
    # function runs.
    # if not os.path.exists(target_dir):
    #     os.makedirs(target_dir)
    #
    # if os.path.exists(archive_path):
    #     # Download is not complete as the .tar.gz file is removed after
    #     # download.
    #     logger.warn("Download was incomplete, downloading again.")
    #     os.remove(archive_path)

    # logger.warn("Downloading dataset from %s (14 MB)", URL)
    # opener = urlopen(URL)
    # open(archive_path, 'wb').write(opener.read())

    logger.info("Decompressing %s", archive_path)
    tarfile.open(archive_path, "r:gz").extractall(path=target_dir)
    os.remove(archive_path)

    # Store a zipped pickle
    cache = dict(train=load_files(train_path, encoding='latin1'),
                 test=load_files(test_path, encoding='latin1'))
    compressed_content = codecs.encode(pickle.dumps(cache), 'zlib_codec')
    open(cache_path, 'wb').write(compressed_content)

    shutil.rmtree(target_dir)
    return cache


def strip_newsgroup_header(text):
    """
    Given text in "news" format, strip the headers, by removing everything
    before the first blank line.
    """
    _before, _blankline, after = text.partition('\n\n')
    return after


_QUOTE_RE = re.compile(r'(writes in|writes:|wrote:|says:|said:'
                       r'|^In article|^Quoted from|^\||^>)')


def strip_newsgroup_quoting(text):
    """
    Given text in "news" format, strip lines beginning with the quote
    characters > or |, plus lines that often introduce a quoted section
    (for example, because they contain the string 'writes:'.)
    """
    good_lines = [line for line in text.split('\n')
                  if not _QUOTE_RE.search(line)]
    return '\n'.join(good_lines)


def strip_newsgroup_footer(text):
    """
    Given text in "news" format, attempt to remove a signature block.

    As a rough heuristic, we assume that signatures are set apart by either
    a blank line or a line made of hyphens, and that it is the last such line
    in the file (disregarding blank lines at the end).
    """
    lines = text.strip().split('\n')
    for line_num in range(len(lines) - 1, -1, -1):
        line = lines[line_num]
        if line.strip().strip('-') == '':
            break

    if line_num > 0:
        return '\n'.join(lines[:line_num])
    else:
        return text


def fetch_20newsgroups(data_home=None, subset='train', categories=None,
                       shuffle=True, random_state=42,
                       remove=(),
                       download_if_missing=True):
    """Load the filenames and data from the 20 newsgroups dataset.

    Parameters
    ----------
    subset: 'train' or 'test', 'all', optional
        Select the dataset to load: 'train' for the training set, 'test'
        for the test set, 'all' for both, with shuffled ordering.

    data_home: optional, default: None
        Specify a download and cache folder for the datasets. If None,
        all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

    categories: None or collection of string or unicode
        If None (default), load all the categories.
        If not None, list of category names to load (other categories
        ignored).

    shuffle: bool, optional
        Whether or not to shuffle the data: might be important for models that
        make the assumption that the samples are independent and identically
        distributed (i.i.d.), such as stochastic gradient descent.

    random_state: numpy random number generator or seed integer
        Used to shuffle the dataset.

    download_if_missing: optional, True by default
        If False, raise an IOError if the data is not locally available
        instead of trying to download the data from the source site.

    remove: tuple
        May contain any subset of ('headers', 'footers', 'quotes'). Each of
        these are kinds of text that will be detected and removed from the
        newsgroup posts, preventing classifiers from overfitting on
        metadata.

        'headers' removes newsgroup headers, 'footers' removes blocks at the
        ends of posts that look like signatures, and 'quotes' removes lines
        that appear to be quoting another post.

        'headers' follows an exact standard; the other filters are not always
        correct.
    """

    data_home = get_data_home(data_home=data_home)
    cache_path = os.path.join(data_home, CACHE_NAME)
    twenty_home = os.path.join(data_home, "20news_home")
    cache = None
    if os.path.exists(cache_path):
        try:
            with open(cache_path, 'rb') as f:
                compressed_content = f.read()
            uncompressed_content = codecs.decode(
                compressed_content, 'zlib_codec')
            cache = pickle.loads(uncompressed_content)
        except Exception as e:
            print(80 * '_')
            print('Cache loading failed')
            print(80 * '_')
            print(e)

    if cache is None:
        if download_if_missing:
            cache = download_20newsgroups(target_dir=twenty_home,
                                          cache_path=cache_path)
        else:
            raise IOError('20Newsgroups dataset not found')

    if subset in ('train', 'test'):
        data = cache[subset]
    elif subset == 'all':
        data_lst = list()
        target = list()
        filenames = list()
        for subset in ('train', 'test'):
            data = cache[subset]
            data_lst.extend(data.data)
            target.extend(data.target)
            filenames.extend(data.filenames)

        data.data = data_lst
        data.target = np.array(target)
        data.filenames = np.array(filenames)
        data.description = 'the 20 newsgroups by date dataset'
    else:
        raise ValueError(
            "subset can only be 'train', 'test' or 'all', got '%s'" % subset)

    if 'headers' in remove:
        data.data = [strip_newsgroup_header(text) for text in data.data]
    if 'footers' in remove:
        data.data = [strip_newsgroup_footer(text) for text in data.data]
    if 'quotes' in remove:
        data.data = [strip_newsgroup_quoting(text) for text in data.data]

    if categories is not None:
        labels = [(data.target_names.index(cat), cat) for cat in categories]
        # Sort the categories to have the ordering of the labels
        labels.sort()
        labels, categories = zip(*labels)
        mask = np.in1d(data.target, labels)
        data.filenames = data.filenames[mask]
        data.target = data.target[mask]
        # searchsorted to have continuous labels
        data.target = np.searchsorted(labels, data.target)
        data.target_names = list(categories)
        # Use an object array to shuffle: avoids memory copy
        data_lst = np.array(data.data, dtype=object)
        data_lst = data_lst[mask]
        data.data = data_lst.tolist()

    if shuffle:
        random_state = check_random_state(random_state)
        indices = np.arange(data.target.shape[0])
        random_state.shuffle(indices)
        data.filenames = data.filenames[indices]
        data.target = data.target[indices]
        # Use an object array to shuffle: avoids memory copy
        data_lst = np.array(data.data, dtype=object)
        data_lst = data_lst[indices]
        data.data = data_lst.tolist()

    return data


def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None):
    """Load the 20 newsgroups dataset and transform it into tf-idf vectors.

    This is a convenience function; the tf-idf transformation is done using the
    default settings for `sklearn.feature_extraction.text.Vectorizer`. For more
    advanced usage (stopword filtering, n-gram extraction, etc.), combine
    fetch_20newsgroups with a custom `Vectorizer` or `CountVectorizer`.

    Parameters
    ----------
    subset: 'train' or 'test', 'all', optional
        Select the dataset to load: 'train' for the training set, 'test'
        for the test set, 'all' for both, with shuffled ordering.

    data_home: optional, default: None
        Specify a download and cache folder for the datasets. If None,
        all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

    remove: tuple
        May contain any subset of ('headers', 'footers', 'quotes'). Each of
        these are kinds of text that will be detected and removed from the
        newsgroup posts, preventing classifiers from overfitting on
        metadata.

        'headers' removes newsgroup headers, 'footers' removes blocks at the
        ends of posts that look like signatures, and 'quotes' removes lines
        that appear to be quoting another post.

    Returns
    -------
    bunch : Bunch object
        bunch.data: sparse matrix, shape [n_samples, n_features]
        bunch.target: array, shape [n_samples]
        bunch.target_names: list, length [n_classes]
    """
    data_home = get_data_home(data_home=data_home)
    filebase = '20newsgroup_vectorized'
    if remove:
        filebase += 'remove-' + ('-'.join(remove))
    target_file = os.path.join(data_home, filebase + ".pk")

    # we shuffle but use a fixed seed for the memoization
    data_train = fetch_20newsgroups(data_home=data_home,
                                    subset='train',
                                    categories=None,
                                    shuffle=True,
                                    random_state=12,
                                    remove=remove)

    data_test = fetch_20newsgroups(data_home=data_home,
                                   subset='test',
                                   categories=None,
                                   shuffle=True,
                                   random_state=12,
                                   remove=remove)

    if os.path.exists(target_file):
        X_train, X_test = joblib.load(target_file)
    else:
        vectorizer = CountVectorizer(dtype=np.int16)
        X_train = vectorizer.fit_transform(data_train.data).tocsr()
        X_test = vectorizer.transform(data_test.data).tocsr()
        joblib.dump((X_train, X_test), target_file, compress=9)

    # the data is stored as int16 for compactness
    # but normalize needs floats
    X_train = X_train.astype(np.float64)
    X_test = X_test.astype(np.float64)
    normalize(X_train, copy=False)
    normalize(X_test, copy=False)

    target_names = data_train.target_names

    if subset == "train":
        data = X_train
        target = data_train.target
    elif subset == "test":
        data = X_test
        target = data_test.target
    elif subset == "all":
        data = sp.vstack((X_train, X_test)).tocsr()
        target = np.concatenate((data_train.target, data_test.target))
    else:
        raise ValueError("%r is not a valid subset: should be one of "
                         "['train', 'test', 'all']" % subset)

    return Bunch(data=data, target=target, target_names=target_names)
```

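fetch_20newsgroups only calls download_20newsgroups when `<data_home>/20news-bydate.pkz` is missing or fails to load, so a cache from an earlier run makes the archive unnecessary. The sketch below mirrors the loader's cache format (a pickle compressed with the `zlib_codec` codec) and its fallback behaviour; `write_cache` and `read_cache` are hypothetical names. Unpickling a real cache still needs scikit-learn installed, since it contains Bunch objects.

```python
import codecs
import os
import pickle


def write_cache(obj, cache_path):
    """Store obj the way download_20newsgroups does: a zlib-compressed pickle."""
    with open(cache_path, "wb") as f:
        f.write(codecs.encode(pickle.dumps(obj), "zlib_codec"))


def read_cache(cache_path):
    """Return the unpickled cache, or None if it is missing or unreadable."""
    if not os.path.exists(cache_path):
        return None
    try:
        with open(cache_path, "rb") as f:
            return pickle.loads(codecs.decode(f.read(), "zlib_codec"))
    except Exception:
        # Same idea as the loader: a broken cache is treated as absent.
        return None
```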

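When `categories` is given, fetch_20newsgroups keeps only the matching posts and renumbers their labels to 0..k-1 in the order of the original label ids, not in the order the categories were listed. Below is a pure-Python restatement of that step, with `bisect_left` standing in for `np.searchsorted` and a list comprehension for `np.in1d`; `select_categories` is a hypothetical name used only for illustration.

```python
from bisect import bisect_left


def select_categories(targets, target_names, categories):
    """Sketch of the categories= branch: returns (positions, new_targets, new_names)."""
    # Pair each requested name with its original label, sorted by label.
    labels = sorted((target_names.index(cat), cat) for cat in categories)
    kept = [label for label, _ in labels]
    new_names = [cat for _, cat in labels]
    kept_set = set(kept)
    # Positions whose label was requested (the boolean mask in the loader).
    positions = [i for i, t in enumerate(targets) if t in kept_set]
    # Relabel by rank among the kept labels, giving 0..k-1.
    new_targets = [bisect_left(kept, targets[i]) for i in positions]
    return positions, new_targets, new_names
```

For example, asking for `["d", "b"]` out of `["a", "b", "c", "d"]` returns the names as `["b", "d"]`, so label 0 means "b".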

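Although the docstring of fetch_20newsgroups_vectorized speaks of tf-idf, the code builds raw term counts with `CountVectorizer(dtype=np.int16)` and then calls `normalize`, whose default is L2 normalization of each row; no idf weighting is applied. A pure-Python sketch of that row scaling on small count vectors (`l2_normalize_rows` is a hypothetical helper):

```python
import math


def l2_normalize_rows(rows):
    """Scale each row of counts to unit Euclidean length; all-zero rows stay zero."""
    out = []
    for row in rows:
        norm = math.sqrt(sum(float(v) * v for v in row))
        # Avoid dividing by zero for documents with no known terms.
        out.append([v / norm for v in row] if norm > 0 else [0.0] * len(row))
    return out
```

A document with counts `[3, 4, 0]` becomes roughly `[0.6, 0.8, 0.0]`. If idf weighting is wanted, it has to be added separately, for example by combining fetch_20newsgroups with a TfidfVectorizer as the docstring suggests for advanced use.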