SCIKIT-LEARN 决策树实现csv文档简单的推论预测

news/2025/2/27 8:22:16

一.学习背景

        原文来自scikit-learn的学习拓展,根据樱花分类示例衍生而来。源文开源地址:scikit-learn: machine learning in Python — scikit-learn 0.16.1 documentation,想学机器学习和数据挖掘的可以去瞧瞧!

二.读取csv文档

df = pd.read_csv('./test_data.csv')
print("df: ", df)
# 假设目标变量是列名为 'target' 的列
#X = df.drop('target', axis=1)  # 特征axis:轴的方向,0为行,1为列,默认为0
X = df.drop(columns=['z'], inplace=False)
y = df['z']  # 目标变量

三.划分数据集

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_test2=[[2,10],[0,3],[3,3],[3,7],[0,6],[1,6],[1,3],[1,4]]
print("X_test: ", X_test)
print("y_test: ", y_test)

四. 创建决策树分类器 

clf = DecisionTreeClassifier(random_state=42)

 五.训练模型

clf.fit(X_train, y_train)

六.预测测试集

y_pred = clf.predict(X_test2)
print("y_pred: ", y_pred)

七.源码 

import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, accuracy_score
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

# 读取CSV文件
df = pd.read_csv('./test_data.csv')
print("df: ", df)
# 假设目标变量是列名为 'target' 的列
#X = df.drop('target', axis=1)  # 特征axis:轴的方向,0为行,1为列,默认为0
X = df.drop(columns=['z'], inplace=False)
y = df['z']  # 目标变量

# iris = load_iris()
# X, y = iris.data, iris.target
# print("X: ", X)
# print("y: ", y)
# df = pd.DataFrame(X, columns=iris.feature_names)
# df['target'] = y
#print("df: ", df)

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_test2=[[2,10],[0,3],[3,3],[3,7],[0,6],[1,6],[1,3],[1,4]]
print("X_test: ", X_test)
print("y_test: ", y_test)


# 创建决策树分类器
clf = DecisionTreeClassifier(random_state=42)

# 训练模型
clf.fit(X_train, y_train)

# 预测测试集
y_pred = clf.predict(X_test2)
print("y_pred: ", y_pred)
#

八.推测演示结果

        源数据csv文档

 

根据数据x,y预测z的结果,预测[2,10],[0,3],[3,3],[3,7],[0,6],[1,6],[1,3],[1,4]结果,下面是推演过程。


http://www.niftyadmin.cn/n/5869792.html

相关文章

Java高频面试之SE-23

hello啊,各位观众姥爷们!!!本baby今天又来了!哈哈哈哈哈嗝🐶 Java 中的 Stream 是 Java 8 引入的一种全新的数据处理方式,它基于函数式编程思想,提供了一种高效、简洁且灵活的方式来…

Spark内存迭代计算

一、宽窄依赖 窄依赖:父RDD的一个分区数据全部发往子RDD的一个分区 宽依赖:父RDD的一个分区数据发往子RDD的多个分区,也称为shuffle 二、Spark是如何进行内存计算的?DAG的作用?Stage阶段划分的作用? &a…

WSL2下,向github进行push时出现timeout的问题

昨晚在完成15445 Project2.2后,笔者兴致冲冲地准备把代码提交到github上,谁知一连提交几次都出现 ssh:connect to host github.com port 22: Connection timed out 这个问题。我开始还以为是网络波动,测试了多次之后才发现应该是22端口出问题…

Linux:(3)

一:Linux和Linux互传(压缩包) scp:Linux scp 命令用于 Linux 之间复制文件和目录。 scp 是 secure copy 的缩写, scp 是 linux 系统下基于 ssh 登陆进行安全的远程文件拷贝命令。 scp 是加密的,rcp 是不加密的,scp 是…

【图形学入门笔记】线性代数的本质

【笔记未完待续】如果我的分享对你有帮助,请记得点赞关注不迷路。 视频源地址:https://www.youtube.com/watch?vfNk_zzaMoSs 作者:3Blue1Brown 此处仅做个人笔记使用。 01 - 向量究竟是什么? 线性代数中最基础、最根源的…

Tomcat 目录结构和应用实现

Tomcat 是一款开源的、轻量级的 Web 服务器,它不仅能够提供 HTTP 服务,还能够运行 Java Servlet 和 JavaServer Pages(JSP)。对于许多开发者来说,理解 Tomcat 的目录结构以及如何在该结构中组织应用,往往是…

解决Deepseek“服务器繁忙,请稍后再试”问题,基于硅基流动和chatbox的解决方案

文章目录 前言操作步骤步骤1:注册账号步骤2:在线体验步骤3:获取API密钥步骤4:安装chatbox步骤5:chatbox设置 价格方面 前言 最近在使用DeepSeek时,开启深度思考功能后,频繁遇到“服务器繁忙&am…

C++ Primer Plus第八章课后习题总结

1. 编写通常接受一个参数(字符串的地址),并打印该字符串的函数。然而,如果提供了第二个参数(int类型),且该参数不为0,则该函数打印字符串的次数将为该函数被调用的次数(注…