Files
the-trial/scripts/cpu_training.py

279 lines
10 KiB
Python

#!/usr/bin/env python3
"""
CPU-Compatible Training Script for The Trial SLM
Simplified approach that works without GPU
"""
import json
import logging
import os
from pathlib import Path
# Module-level logger named after this module; its records propagate to the
# root logger, which is configured right below.
logger = logging.getLogger(name=__name__)

# Root logging configuration: emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)
class SimpleMonteCristoTrainer:
    """Build a prompt-based knowledge package for The Trial SLM on CPU.

    No model weights are trained. The trainer reads instruction-style JSON
    datasets from ``<data_dir>/training``, groups their records into a
    knowledge base, and writes an Ollama ``Modelfile``, ``test_prompts.json``
    and ``training_summary.json`` into ``output_dir``.

    The class name looks like a leftover from a "Count of Monte Cristo"
    version of this script; it is kept unchanged so existing callers work.
    """

    # Dataset key -> file name inside ``<data_dir>/training``. Insertion order
    # is the load order and therefore the key order of the returned datasets.
    _DATASET_FILES = {
        "factual": "factual_qa.json",
        "analysis": "literary_analysis.json",
        "creative": "creative_writing.json",
    }

    # Knowledge-base category -> heading used in the Modelfile comment dump.
    # Single source of truth for the category names and their order.
    _KB_SECTIONS = (
        ("characters", "Character Information"),
        ("themes", "Theme Analysis"),
        ("plot_points", "Plot Points"),
        ("symbols", "Symbolism"),
        ("style_elements", "Style Elements"),
    )

    def __init__(self, data_dir: str = "data", output_dir: str = "models"):
        """Store the input/output locations and create ``output_dir``.

        Args:
            data_dir: directory whose ``training`` subfolder holds the datasets.
            output_dir: directory for generated artifacts; created (including
                parents) if it does not exist.
        """
        self.data_dir = Path(data_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def load_training_data(self):
        """Load whichever training datasets exist.

        Returns:
            dict mapping dataset key ("factual", "analysis", "creative") to the
            parsed JSON content of its file. Missing files are skipped with a
            warning.

        Raises:
            json.JSONDecodeError: if a present file is not valid JSON.
        """
        training_dir = self.data_dir / "training"
        datasets = {}
        for key, filename in self._DATASET_FILES.items():
            path = training_dir / filename
            if not path.exists():
                # Surface the skip: a wrong data_dir otherwise yields an empty
                # knowledge base and a "successful" run.
                logger.warning("Dataset not found, skipping: %s", path)
                continue
            with open(path, "r", encoding="utf-8") as f:
                datasets[key] = json.load(f)
        return datasets

    def create_knowledge_base(self, datasets):
        """Group dataset records into a structured knowledge base.

        Args:
            datasets: mapping as returned by ``load_training_data``. Each
                dataset is iterated as a sequence of dict records. Records that
                match a category must carry ``"output"``; factual records with
                a ``"character"`` or a plot ``"topic"`` also need
                ``"instruction"``.

        Returns:
            dict with keys ``characters`` (name -> {"questions", "answers"}),
            ``themes``, ``symbols``, ``style_elements`` (label -> list of
            outputs) and ``plot_points`` (instruction -> output).

        Raises:
            KeyError: if a matching record lacks a required key.
        """
        knowledge_base = {key: {} for key, _ in self._KB_SECTIONS}

        # Factual records feed per-character Q&A and plot points.
        characters = knowledge_base["characters"]
        plot_points = knowledge_base["plot_points"]
        for item in datasets.get("factual", []):
            if "character" in item:
                entry = characters.setdefault(
                    item["character"], {"questions": [], "answers": []}
                )
                entry["questions"].append(item["instruction"])
                entry["answers"].append(item["output"])
            if item.get("topic") == "plot":
                plot_points[item["instruction"]] = item["output"]

        # Analysis records are filed under their theme and/or their symbol.
        for item in datasets.get("analysis", []):
            for field, category in (("theme", "themes"), ("symbol", "symbols")):
                if field in item:
                    knowledge_base[category].setdefault(item[field], []).append(
                        item["output"]
                    )

        # Creative records are filed under their writing style.
        style_elements = knowledge_base["style_elements"]
        for item in datasets.get("creative", []):
            if "style" in item:
                style_elements.setdefault(item["style"], []).append(item["output"])

        return knowledge_base

    def create_system_prompts(self):
        """Return the system prompts keyed by context.

        Keys: "default" (written into the Modelfile), "factual", "analysis"
        and "creative".
        """
        return {
            "default": 'You are a specialized AI assistant expert on "The Trial" by Franz Kafka. You have deep knowledge of the novel\'s plot, characters, themes, historical context, and literary significance. Provide accurate, insightful, and engaging responses about all aspects of this classic work of literature.',
            "factual": 'You provide factual information about "The Trial". Focus on accurate details about plot events, character descriptions, historical context, and verifiable information from the novel. Be precise and cite specific chapters or events when possible.',
            "analysis": 'You provide literary analysis of "The Trial". Focus on themes, symbolism, narrative techniques, character development, and the work\'s place in literary history. Offer insightful interpretations supported by textual evidence.',
            "creative": 'You write in the style of Franz Kafka and "The Trial". Use precise, matter-of-fact prose, long and winding sentences, an atmosphere of bureaucratic absurdity and quiet dread, and the narrative voice characteristic of early 20th-century German-language modernist fiction.',
        }

    def create_ollama_modelfile(self, knowledge_base, system_prompts):
        """Write ``<output_dir>/Modelfile`` and return its path.

        The Modelfile derives from ``llama3.2:3b``, uses
        ``system_prompts["default"]`` as its SYSTEM message, sets sampling and
        context parameters, and appends every knowledge-base section as ``#``
        comment lines.

        NOTE(review): ``#`` lines are Modelfile comments, so the knowledge base
        is only a human-readable record and does not reach the model. Moving
        it into SYSTEM would have to be budgeted against ``num_ctx``.

        Args:
            knowledge_base: mapping as returned by ``create_knowledge_base``.
            system_prompts: mapping containing at least a "default" prompt.

        Returns:
            Path of the written Modelfile.
        """
        lines = [
            "# The Trial Literary Analysis SLM",
            "# Based on llama3.2:3b with specialized knowledge",
            "",
            "FROM llama3.2:3b",
            "",
            "# System prompt",
            # Triple quotes let the value span lines and hold double quotes
            # (the prompt quotes the title "The Trial").
            f'SYSTEM """{system_prompts["default"]}"""',
            "",
            "# Parameters for better literary analysis",
            "PARAMETER temperature 0.7",
            "PARAMETER top_p 0.9",
            "PARAMETER top_k 40",
            "PARAMETER repeat_penalty 1.1",
            "",
            "# Context window for longer passages",
            "PARAMETER num_ctx 4096",
            "",
            "# The Trial Knowledge Base",
        ]

        for index, (key, title) in enumerate(self._KB_SECTIONS):
            if index:
                lines.append("")
            lines.append(f"# {title}:")
            dumped = json.dumps(knowledge_base[key], indent=2, ensure_ascii=False)
            # indent=2 output spans many lines: comment out every one of them,
            # otherwise the JSON continuation lines would be read as (invalid)
            # Modelfile instructions.
            lines.extend(f"# {line}" for line in dumped.splitlines())

        modelfile_path = self.output_dir / "Modelfile"
        modelfile_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
        logger.info("Created Modelfile: %s", modelfile_path)
        return modelfile_path

    def create_test_prompts(self):
        """Write ``<output_dir>/test_prompts.json`` and return its path.

        Each entry has a ``category``, a ``prompt`` and ``expected_elements``
        (presumably keywords a good answer should mention; how they are
        matched is up to the consumer of this file).
        """
        test_prompts = [
            {
                "category": "factual",
                "prompt": "Who is Josef K. and what happens to him at the beginning of the novel?",
                "expected_elements": [
                    "arrested",
                    "bank",
                    "birthday",
                    "charge",
                    "court",
                ],
            },
            {
                "category": "analysis",
                "prompt": "Analyze the theme of guilt in The Trial.",
                "expected_elements": [
                    "guilt",
                    "law",
                    "justice",
                    "bureaucracy",
                    "alienation",
                ],
            },
            {
                "category": "creative",
                "prompt": "Write a short passage in Kafka's style describing a visit to the offices of the Court.",
                "expected_elements": ["bureaucratic", "absurd", "oppressive", "unsettling"],
            },
        ]
        test_file = self.output_dir / "test_prompts.json"
        with open(test_file, "w", encoding="utf-8") as f:
            json.dump(test_prompts, f, indent=2, ensure_ascii=False)
        logger.info("Created test prompts: %s", test_file)
        return test_file

    def train_model(self):
        """Run the whole pipeline and return the summary dict.

        Loads the datasets, builds the knowledge base and system prompts,
        writes the Modelfile and the test prompts, then writes the returned
        summary to ``<output_dir>/training_summary.json``.
        """
        logger.info("Starting simplified The Trial SLM training...")

        datasets = self.load_training_data()
        logger.info("Loaded datasets: %s", list(datasets))

        knowledge_base = self.create_knowledge_base(datasets)
        logger.info("Created structured knowledge base")

        system_prompts = self.create_system_prompts()
        logger.info("Created system prompts")

        modelfile_path = self.create_ollama_modelfile(knowledge_base, system_prompts)
        test_file = self.create_test_prompts()

        summary = {
            "training_method": "cpu_knowledge_injection",
            "datasets_used": list(datasets),
            "total_examples": sum(len(items) for items in datasets.values()),
            "knowledge_base_size": {
                key: len(knowledge_base[key]) for key, _ in self._KB_SECTIONS
            },
            "output_files": {
                "modelfile": str(modelfile_path),
                "test_prompts": str(test_file),
            },
        }
        summary_file = self.output_dir / "training_summary.json"
        with open(summary_file, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2, ensure_ascii=False)

        logger.info("CPU-based training completed successfully!")
        logger.info("Training summary: %s", summary_file)
        return summary
def main():
    """Run the CPU-only pipeline with the default directories.

    Logs a banner and runs ``SimpleMonteCristoTrainer().train_model()``. On
    success it logs the Ollama commands for building and running the model,
    using the paths the trainer actually wrote. Any error from training is
    logged and re-raised to the caller.
    """
    logger.info("The Trial SLM - CPU-Compatible Training")
    logger.info("=" * 60)
    trainer = SimpleMonteCristoTrainer()
    try:
        summary = trainer.train_model()
    except Exception as e:
        # Script boundary: record the failure, then propagate it unchanged.
        logger.error("Training failed: %s", e)
        raise

    # Success reporting sits outside the try so only pipeline errors are
    # reported as "Training failed".
    output_files = summary["output_files"]
    logger.info("=" * 60)
    logger.info("TRAINING COMPLETED SUCCESSFULLY!")
    logger.info("=" * 60)
    logger.info("Next steps:")
    logger.info(
        "1. Create the model: ollama create the-trial -f %s",
        output_files["modelfile"],
    )
    logger.info("2. Run the model: ollama run the-trial")
    logger.info("3. Test with the prompts in %s", output_files["test_prompts"])
    logger.info("=" * 60)
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()