# domainapi/main.py

import csv
import datetime
import hashlib
import io
import os
import random
import traceback
from pathlib import Path
from typing import Dict, List, Optional, Set

import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from tinydb import TinyDB, Query as TinyDBQuery

# Initialize the FastAPI app
app = FastAPI(title="Domain Viewer", description="Upload and view domains and DNS records")

# Set up Jinja2 templates
templates = Jinja2Templates(directory="templates")

# Set up TinyDB with a configurable storage path (DB_DIR defaults to the
# current working directory)
DB_DIR = os.environ.get("DB_DIR", ".")
DB_PATH = Path(DB_DIR) / "domains_db.json"
db = TinyDB(DB_PATH)
domains_table = db.table("domains")
dns_records_table = db.table("dns_records")
uploads_table = db.table("uploads")
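
# The upload parser below expects zone-export-style CSV rows of the form
# domain,ttl,class,type,data[,data...]; a hypothetical example file:
#
#   example.com.,3600,IN,A,93.184.216.34
#   www.example.com.,3600,IN,CNAME,example.com.
#   example.com.,3600,IN,MX,10,mail.example.com.
#
# Rows with fewer than five columns still contribute to the domains table
# but are skipped for the dns_records table.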

def process_domain_entry(domain_entry):
    """Split a domain entry into its components; return None if it has no dot."""
    # Remove the trailing dot of a fully qualified name if present
    if domain_entry.endswith("."):
        domain_entry = domain_entry[:-1]
    # Parse domain components
    parts = domain_entry.split(".")
    if len(parts) > 1:
        if len(parts) == 2:
            # domain.tld format
            domain_info = {
                "sld": parts[0],  # second-level domain
                "tld": parts[1],  # top-level domain
                "full_domain": domain_entry,
            }
        else:
            # subdomain.domain.tld format
            domain_info = {
                "sld": parts[-2],  # second-level domain
                "tld": parts[-1],  # top-level domain
                "full_domain": domain_entry,
                "subdomain": ".".join(parts[:-2]),  # everything before the SLD
            }
        return domain_info
    return None
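
# Illustrative behaviour with hypothetical inputs:
#   process_domain_entry("example.com.")     -> {"sld": "example", "tld": "com", "full_domain": "example.com"}
#   process_domain_entry("www.example.com")  -> same keys plus "subdomain": "www"
#   process_domain_entry("localhost")        -> None (no dot to split on)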

async def process_csv_upload(file_content, upload_id, description=None):
    """Parse an uploaded CSV and insert its domains and DNS records.

    Returns (domains_count, records_count). The description argument is
    currently unused here; the upload route stores it in uploads_table.
    """
    domains_to_insert = []
    dns_records_to_insert = []
    unique_domains = set()
    timestamp = datetime.datetime.now().isoformat()
    row_count = 0
    try:
        # Decode the CSV content
        file_text = file_content.decode("utf-8")
        csv_reader = csv.reader(io.StringIO(file_text))
        # Print the first few lines for debugging
        preview_lines = file_text.split("\n")[:5]
        print("CSV preview (first 5 lines):")
        for i, line in enumerate(preview_lines):
            print(f"  Line {i + 1}: {line}")
        for row_num, row in enumerate(csv_reader, 1):
            row_count += 1
            if not row:  # Skip empty rows
                print(f"Skipping empty row at line {row_num}")
                continue
            # The first column always holds the domain name
            domain_entry = row[0]
            # Process the domain for the domains table
            domain_info = process_domain_entry(domain_entry)
            if domain_info:
                # Use sld.tld as a key to avoid duplicates within this upload
                unique_key = f"{domain_info['sld']}.{domain_info['tld']}"
                if unique_key not in unique_domains:
                    unique_domains.add(unique_key)
                    # Attach upload metadata
                    domain_info["upload_id"] = upload_id
                    domain_info["timestamp"] = timestamp
                    domains_to_insert.append(domain_info)
            else:
                print(f"Warning: could not process domain entry at line {row_num}: {domain_entry}")
            # Treat the row as a DNS record if it has enough fields
            if len(row) >= 5:
                domain = row[0]
                ttl = row[1]
                record_class = row[2]
                record_type = row[3]
                record_data = ",".join(row[4:])  # Re-join the remaining columns as record data
                # Remove the trailing dot from the domain if present
                if domain.endswith("."):
                    domain = domain[:-1]
                # Parse domain components
                domain_parts = domain.split(".")
                entry = {
                    "domain": domain,
                    "ttl": ttl,
                    "record_class": record_class,
                    "record_type": record_type,
                    "record_data": record_data,
                    "upload_id": upload_id,
                    "timestamp": timestamp,
                }
                # Add domain components
                if len(domain_parts) > 1:
                    if domain_parts[0].startswith("_"):  # service labels such as _dmarc
                        entry["service"] = domain_parts[0]
                        # Drop the service label before splitting further
                        domain_parts = domain_parts[1:]
                    if len(domain_parts) == 2:
                        # domain.tld format
                        entry["sld"] = domain_parts[0]
                        entry["tld"] = domain_parts[1]
                    elif len(domain_parts) > 2:
                        # subdomain.domain.tld format
                        entry["sld"] = domain_parts[-2]
                        entry["tld"] = domain_parts[-1]
                        entry["subdomain"] = ".".join(domain_parts[:-2])
                dns_records_to_insert.append(entry)
        print(f"Processed {row_count} rows from CSV")
        print(f"Records to insert: {len(domains_to_insert)} domains, {len(dns_records_to_insert)} DNS records")
        # Insert the batches into their tables
        if domains_to_insert:
            print(f"Inserting {len(domains_to_insert)} domains into database")
            domains_table.insert_multiple(domains_to_insert)
        else:
            print("No domains to insert")
        if dns_records_to_insert:
            print(f"Inserting {len(dns_records_to_insert)} DNS records into database")
            dns_records_table.insert_multiple(dns_records_to_insert)
        else:
            print("No DNS records to insert")
        return len(domains_to_insert), len(dns_records_to_insert)
    except Exception as e:
        print(f"Error processing CSV file: {e}")
        print(traceback.format_exc())
        return 0, 0
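
# A minimal sketch of calling the parser directly (e.g. from a test); the
# sample row and upload id are made up for illustration:
#
#   domains_count, records_count = await process_csv_upload(
#       b"example.com.,3600,IN,A,93.184.216.34\n", "upload_test_000000")
#   # -> (1, 1): one domain row and one DNS record row inserted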

def load_domains(specific_upload_id: Optional[str] = None) -> List[Dict]:
    """Load domains from the database, deduplicated by full domain name."""
    try:
        domains = domains_table.all()
        # If a specific upload ID is provided, return that upload's domains as-is
        if specific_upload_id:
            domains = [d for d in domains if d.get("upload_id") == specific_upload_id]
            return domains
        # Sort by timestamp in descending order (newest first)
        domains.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
        # Track unique domains by full domain name
        unique_domains = {}
        for domain in domains:
            unique_key = domain.get("full_domain", "")
            # Keep only the most recent entry for each unique domain
            if unique_key and unique_key not in unique_domains:
                domain["is_latest"] = True
                unique_domains[unique_key] = domain
        # Return the deduplicated list
        return list(unique_domains.values())
    except Exception as e:
        print(f"Error loading domains from database: {e}")
        return []

def load_dns_entries(specific_upload_id: Optional[str] = None) -> List[Dict]:
    """Load DNS entries, deduplicated by domain, class, and type (no history)."""
    try:
        entries = dns_records_table.all()
        # If a specific upload ID is provided, return that upload's records as-is
        if specific_upload_id:
            entries = [e for e in entries if e.get("upload_id") == specific_upload_id]
            return entries
        # Sort by timestamp in descending order (newest first)
        entries.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
        # Track unique entries, keeping only the most recent of each
        unique_entries = {}
        for entry in entries:
            # Key each record on domain, class, and type
            unique_key = f"{entry.get('domain')}:{entry.get('record_class')}:{entry.get('record_type')}"
            if unique_key not in unique_entries:
                # Mark as the most recent entry
                entry["is_latest"] = True
                unique_entries[unique_key] = entry
        # Return only the most recent entry per key
        return list(unique_entries.values())
    except Exception as e:
        print(f"Error loading DNS records from database: {e}")
        return []
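
# Note: the dedup key is domain:class:type only, so several records of the
# same type for one name (for example multiple MX hosts) collapse to the
# single newest row; record_data is not part of the key.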

def get_unique_values(entries: List[Dict]) -> Dict[str, List]:
    """Collect the unique values used to populate the filter dropdowns."""
    unique_values: Dict[str, Set] = {
        "record_type": set(),
        "record_class": set(),
        "tld": set(),
        "sld": set(),
    }
    for entry in entries:
        for key in unique_values:
            if entry.get(key):
                unique_values[key].add(entry[key])
    # Convert the sets to sorted lists
    return {k: sorted(v) for k, v in unique_values.items()}

def get_uploads():
    """Return all uploads, newest first."""
    uploads = uploads_table.all()
    uploads.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
    return uploads


def delete_upload(upload_id):
    """Delete an upload and all domain/DNS records associated with it."""
    try:
        # One reusable TinyDB query object covers all three tables
        record = TinyDBQuery()
        uploads_table.remove(record.id == upload_id)
        domains_table.remove(record.upload_id == upload_id)
        dns_records_table.remove(record.upload_id == upload_id)
        return True
    except Exception as e:
        print(f"Error deleting upload {upload_id}: {e}")
        return False
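
# Example with a hypothetical id:
#   delete_upload("upload_20250408_234124_abc123")
# removes the upload row plus every domain and DNS record tagged with that id.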

# Routes
@app.get("/", response_class=HTMLResponse)
async def home(request: Request, upload_id: Optional[str] = None):
    """Home page with the upload form and the SLD listing."""
    domains = load_domains(upload_id)
    uploads = get_uploads()
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "domains": domains,
            "uploads": uploads,
        },
    )


@app.get("/delete-upload/{upload_id}", response_class=RedirectResponse)
async def delete_upload_route(upload_id: str):
    """Delete an upload and all associated records."""
    success = delete_upload(upload_id)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to delete upload")
    # Redirect back to the home page
    return RedirectResponse(url="/", status_code=303)
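
# Note: deletion is exposed as a GET rather than a DELETE so it can be
# triggered from a plain link; the 303 redirect returns the browser to "/".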
@app.post("/upload", response_class=RedirectResponse)
async def upload_csv(request: Request, file: UploadFile = File(...), description: str = Form(None)):
"""Handle file upload"""
try:
# Read file content
content = await file.read()
# Ensure content is not empty
if not content or len(content) == 0:
raise ValueError("Uploaded file is empty")
# Generate a unique ID for this upload with timestamp and a random suffix for extra uniqueness
now = datetime.datetime.now()
import random
random_suffix = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=6))
upload_id = f"upload_{now.strftime('%Y%m%d_%H%M%S')}_{random_suffix}"
print(f"Processing upload: ID={upload_id}, Filename={file.filename}, Content size={len(content)} bytes")
# Process the CSV content
domains_count, records_count = await process_csv_upload(content, upload_id, description)
print(f"Upload processed: {domains_count} domains and {records_count} DNS records inserted")
if domains_count == 0 and records_count == 0:
print("WARNING: No records were inserted. Content may be invalid or empty.")
# Try to decode and print the first few lines for debugging
try:
preview = content.decode('utf-8')[:500]
print(f"File preview: {preview}")
except:
print("Could not decode file content for preview")
# Store upload information with file hash to identify content changes
import hashlib
content_hash = hashlib.md5(content).hexdigest()
upload_info = {
"id": upload_id,
"filename": file.filename,
"description": description,
"timestamp": datetime.datetime.now().isoformat(),
"domains_count": domains_count,
"records_count": records_count,
"content_hash": content_hash
}
uploads_table.insert(upload_info)
# Redirect back to home page
return RedirectResponse(url="/", status_code=303)
except Exception as e:
print(f"Error in upload_csv: {e}")
return {"error": str(e)}
@app.get("/dns-records", response_class=HTMLResponse)
async def dns_records(
request: Request,
upload_id: Optional[str] = None,
record_type: Optional[str] = None,
record_class: Optional[str] = None,
tld: Optional[str] = None,
sld: Optional[str] = None,
domain: Optional[str] = None
):
"""DNS Records page with filtering"""
# Get all entries first, based on upload_id if provided
entries = load_dns_entries(upload_id)
# Apply additional filters if provided
if record_type:
entries = [e for e in entries if e.get("record_type") == record_type]
if record_class:
entries = [e for e in entries if e.get("record_class") == record_class]
if tld:
entries = [e for e in entries if e.get("tld") == tld]
if sld:
entries = [e for e in entries if e.get("sld") == sld]
if domain:
entries = [e for e in entries if domain.lower() in e.get("domain", "").lower()]
# Get unique values for filter dropdowns from all entries (not filtered)
all_entries = load_dns_entries(upload_id)
unique_values = get_unique_values(all_entries)
uploads = get_uploads()
return templates.TemplateResponse(
"dns_records.html",
{
"request": request,
"entries": entries,
"unique_values": unique_values,
"uploads": uploads
}
)

# API Routes
@app.get("/api/uploads", response_model=List[Dict])
async def get_all_uploads():
    """API endpoint that returns all uploads."""
    return get_uploads()


@app.get("/api/slds", response_model=List[Dict])
async def get_slds(upload_id: Optional[str] = None):
    """API endpoint that returns all SLDs, optionally filtered by upload_id."""
    # load_domains handles deduplication and upload_id filtering
    return load_domains(upload_id)


@app.get("/api/slds/{sld}", response_model=List[Dict])
async def get_domains_by_sld(sld: str, upload_id: Optional[str] = None):
    """API endpoint that returns the domains for a specific SLD, optionally filtered by upload_id."""
    # Get domains, already deduplicated and optionally filtered by upload_id
    all_domains = load_domains(upload_id)
    # Filter by SLD; use .get() because some records may lack the field
    filtered = [item for item in all_domains if item.get("sld", "").lower() == sld.lower()]
    if not filtered:
        raise HTTPException(status_code=404, detail=f"No domains found with SLD: {sld}")
    return filtered
@app.get("/api/dns", response_model=List[Dict])
async def get_dns_entries(
record_type: Optional[str] = None,
record_class: Optional[str] = None,
tld: Optional[str] = None,
sld: Optional[str] = None,
domain: Optional[str] = None,
upload_id: Optional[str] = None
):
"""API endpoint that returns filtered DNS entries"""
# Get entries - if upload_id is specified, only those entries are returned
entries = load_dns_entries(upload_id)
# Apply additional filters if provided
if record_type:
entries = [e for e in entries if e.get("record_type") == record_type]
if record_class:
entries = [e for e in entries if e.get("record_class") == record_class]
if tld:
entries = [e for e in entries if e.get("tld") == tld]
if sld:
entries = [e for e in entries if e.get("sld") == sld]
if domain:
entries = [e for e in entries if domain.lower() in e.get("domain", "").lower()]
return entries
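
# Example query (the address assumes the default dev server below):
#   curl "http://localhost:8000/api/dns?record_type=MX&tld=com"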
@app.get("/api/dns/types", response_model=Dict[str, List])
async def get_unique_filter_values(upload_id: Optional[str] = None):
"""API endpoint that returns unique values for filters"""
# Get entries - if upload_id is specified, only those entries are returned
entries = load_dns_entries(upload_id)
return get_unique_values(entries)

# Ensure the templates directory exists
Path("templates").mkdir(exist_ok=True)

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
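
# Run directly (python main.py) or via the uvicorn CLI:
#   uvicorn main:app --host 0.0.0.0 --port 8000 --reload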